Commit 0fe0770

Refactor pipeline threading and simplify SDG batch processing

Added logging to track the number of threads remaining during pipeline execution, for easier debugging. Removed redundant batching logic from per-block processing to fix a concurrency bug that was slowing down SDG.

1 parent 6454380 commit 0fe0770

File tree

1 file changed (+11, -19 lines)

src/instructlab/sdg/pipeline.py

Lines changed: 11 additions & 19 deletions
@@ -170,11 +170,20 @@ def generate(self, dataset, checkpoint_name=None) -> Dataset:
                 executor.submit(self._generate_single, input_split)
                 for input_split in input_splits
             ]
-
+            threads_remaining_to_execute = len(futures)
+            logger.info(
+                "Total of %d pipeline threads to execute",
+                len(futures),
+            )
             # Collect the results of each batch as they finish. This needs to
             # wait for them all, so the order of waiting doesn't matter
             for future in futures:
                 ds = future.result()
+                threads_remaining_to_execute -= 1
+                logger.info(
+                    "Total of %d pipeline threads to check for completion",
+                    threads_remaining_to_execute,
+                )
                 output_splits.append(ds)
                 checkpointer.checkpoint(ds)
             checkpointer.done()
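
The added logging is a simple countdown over the submitted futures: record how many there are up front, then decrement as each result is collected. A minimal standalone sketch of the same pattern, assuming nothing beyond the standard library (the work function and the sample data are hypothetical, not from the commit):

    import logging
    from concurrent.futures import ThreadPoolExecutor

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    def work(split):
        # Stand-in for Pipeline._generate_single: process one input split.
        return [x * 2 for x in split]

    input_splits = [[1, 2], [3, 4], [5, 6]]
    output_splits = []

    with ThreadPoolExecutor() as executor:
        futures = [executor.submit(work, split) for split in input_splits]
        threads_remaining_to_execute = len(futures)
        logger.info("Total of %d pipeline threads to execute", len(futures))
        for future in futures:
            ds = future.result()  # blocks until this split's thread finishes
            threads_remaining_to_execute -= 1
            logger.info(
                "Total of %d pipeline threads to check for completion",
                threads_remaining_to_execute,
            )
            output_splits.append(ds)

Note that the loop waits on futures in submission order, so the counter tracks results not yet collected rather than threads still running; since every result must be gathered before the pipeline can proceed, the order of waiting doesn't matter, as the surrounding comment already notes.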
@@ -197,24 +206,7 @@ def _generate_single(self, dataset) -> Dataset:
             drop_duplicates_cols = block_prop.get("drop_duplicates", False)
             block = block_type(self.ctx, self, block_name, **block_config)
             logger.info("Running block: %s", block_name)
-
-            # Check if batching is enabled
-            if not self.ctx.batching_enabled:
-                logger.info(
-                    "Batching disabled; processing block '%s' single-threaded.",
-                    block_name,
-                )
-                dataset = block.generate(dataset)
-            else:
-                # Split the dataset into batches
-                input_splits = self._split_dataset(dataset)
-                # Process each batch in sequence
-                output_splits = [
-                    block.generate(input_split) for input_split in input_splits
-                ]
-                # Combine the processed splits back into a single dataset
-                dataset = concatenate_datasets(output_splits)
-
+            dataset = block.generate(dataset)
             # If the dataset is empty after processing, terminate early
             if len(dataset) == 0:
                 return dataset
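
The removed branch duplicated work that generate already does: the dataset is split once at the top level and each split runs through _generate_single on its own thread, so re-splitting inside each thread only fragmented the batches further and processed the fragments sequentially. A condensed sketch of the surviving flow, with checkpointing and configuration omitted (the function name generate_batched is hypothetical; _split_dataset, _generate_single, and concatenate_datasets are taken from the diff above):

    from concurrent.futures import ThreadPoolExecutor
    from datasets import concatenate_datasets

    def generate_batched(pipeline, dataset):
        # Split once, at the top level only.
        input_splits = pipeline._split_dataset(dataset)
        with ThreadPoolExecutor() as executor:
            # One thread per split; each thread runs every block on its
            # whole split, with no further re-splitting inside the thread.
            futures = [
                executor.submit(pipeline._generate_single, split)
                for split in input_splits
            ]
            output_splits = [future.result() for future in futures]
        # Merge the processed splits back into a single dataset.
        return concatenate_datasets(output_splits)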
