Skip to content

Commit

Permalink
Fix #961 by instantiating the batch fetcher with the step count and incrementing locally instead of globally
Browse files Browse the repository at this point in the history
  • Loading branch information
bghira committed Sep 10, 2024
1 parent c37cec1 commit 28cbb34
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
9 changes: 5 additions & 4 deletions helpers/data_backend/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -1299,33 +1299,34 @@ def random_dataloader_iterator(step, backends: dict):


class BatchFetcher:
def __init__(self, max_size=10, datasets={}):
def __init__(self, step, max_size=10, datasets={}):
self.queue = queue.Queue(max_size)
self.datasets = datasets
self.keep_running = True
self.step = step

def start_fetching(self):
    """Spawn the background retrieval thread and return it to the caller."""
    worker = threading.Thread(target=self.fetch_responses)
    worker.start()
    return worker

def fetch_responses(self):
    """Worker loop that keeps the prefetch queue topped up with batches.

    Runs until ``self.keep_running`` is cleared. Uses ``self.step``
    (refreshed by ``next_response``) instead of the former module-level
    ``step`` global, so step bookkeeping stays local to the fetcher.
    """
    prefetch_log_debug("Launching retrieval thread.")
    while self.keep_running:
        if self.queue.qsize() < self.queue.maxsize:
            prefetch_log_debug(
                f"Queue size: {self.queue.qsize()}. Fetching more data."
            )
            self.queue.put(random_dataloader_iterator(self.step, self.datasets))
        if self.queue.qsize() >= self.queue.maxsize:
            prefetch_log_debug("Completed fetching data. Queue is full.")
            # Back off here too: a bare `continue` busy-spins (and logs)
            # at full speed for as long as the queue stays full.
            time.sleep(0.5)
            continue
        else:
            time.sleep(0.5)
    prefetch_log_debug("Exiting retrieval thread.")

def next_response(self):
def next_response(self, step: int):
self.step = step
if self.queue.empty():
prefetch_log_debug("Queue is empty. Waiting for data.")
while self.queue.empty():
Expand Down
4 changes: 3 additions & 1 deletion helpers/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,8 @@ def init_data_backend(self):
message_level="critical",
)

return False
raise e

self.init_validation_prompts()
# We calculate the number of steps per epoch by dividing the number of images by the effective batch divisor.
# Gradient accumulation steps mean that we only update the model weights every /n/ steps.
Expand Down Expand Up @@ -1585,6 +1586,7 @@ def train(self):
self.bf = BatchFetcher(
datasets=train_backends,
max_size=self.config.dataloader_prefetch_qlen,
step=step,
)
if fetch_thread is not None:
fetch_thread.join()
Expand Down

0 comments on commit 28cbb34

Please sign in to comment.