Skip to content

Commit 4bd07f5

Browse files
Merge pull request #272 from relyt0925/issue-240
Handle empty dataset from output of sdg leaf node without raising error
2 parents 2b73f2b + 59ba8a5 commit 4bd07f5

File tree

2 files changed

+71
-4
lines changed

2 files changed

+71
-4
lines changed

src/instructlab/sdg/generate_data.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from instructlab.sdg.pipeline import (
2626
FULL_PIPELINES_PACKAGE,
2727
SIMPLE_PIPELINES_PACKAGE,
28-
EmptyDatasetError,
2928
Pipeline,
3029
PipelineContext,
3130
)
@@ -359,6 +358,7 @@ def generate_data(
359358
)
360359

361360
generated_data = None
361+
empty_sdg_leaf_nodes = []
362362
for leaf_node in leaf_nodes.values():
363363
is_knowledge = False
364364
leaf_node_path = leaf_node[0]["taxonomy_path"].replace("->", "_")
@@ -382,9 +382,9 @@ def generate_data(
382382
logger.debug("Dataset: %s", ds)
383383
new_generated_data = pipe.generate(ds, leaf_node_path)
384384
if len(new_generated_data) == 0:
385-
raise EmptyDatasetError(
386-
"Pipeline stopped: Empty dataset after running pipe"
387-
)
385+
empty_sdg_leaf_nodes.append(leaf_node_path)
386+
logger.warning("Empty dataset for qna node: %s", leaf_node_path)
387+
continue
388388
generated_data = (
389389
[new_generated_data]
390390
if generated_data is None
@@ -418,3 +418,9 @@ def generate_data(
418418

419419
generate_duration = time.time() - generate_start
420420
logger.info(f"Generation took {generate_duration:.2f}s")
421+
if len(empty_sdg_leaf_nodes) > 0:
422+
logger.warning(
423+
"Leaf nodes with empty sdg output: {}".format(
424+
" ".join(empty_sdg_leaf_nodes)
425+
)
426+
)

tests/test_generate_data.py

+61
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import glob
1010
import json
1111
import os
12+
import re
1213
import shutil
1314
import tempfile
1415
import unittest
@@ -267,6 +268,11 @@ def strip_q(q):
267268
return output
268269

269270

271+
def _empty_llmblock_generate(self, samples):
272+
"""Return an empty set of generated samples."""
273+
return []
274+
275+
270276
@patch.object(LLMBlock, "_generate", _noop_llmblock_generate)
271277
class TestGenerateCompositionalData(unittest.TestCase):
272278
@pytest.fixture(autouse=True)
@@ -452,6 +458,61 @@ def __exit__(self, *args):
452458
self.teardown()
453459

454460

461+
@patch.object(LLMBlock, "_generate", _empty_llmblock_generate)
class TestGenerateEmptyDataset(unittest.TestCase):
    """Check that ``generate_data`` warns instead of failing on empty output.

    ``LLMBlock._generate`` is patched to return an empty list, so the
    pipeline run for the untracked leaf node yields no samples.
    """

    @pytest.fixture(autouse=True)
    def _init_taxonomy(self, taxonomy_dir):
        """Capture the ``taxonomy_dir`` fixture on the test instance."""
        self.test_taxonomy = taxonomy_dir

    def setUp(self):
        """Create the output directory and populate the test taxonomy."""
        # mkdtemp() returns a directory that stays on disk until we remove it.
        # The previous ``TemporaryDirectory().name`` dropped the only reference
        # to the TemporaryDirectory object, leaving tmp_path unowned.
        self.tmp_path = tempfile.mkdtemp()
        # unittest only calls ``tearDown`` (camel case); register the removal
        # explicitly so the directory is deleted even if the test fails.
        self.addCleanup(shutil.rmtree, self.tmp_path, ignore_errors=True)

        test_valid_knowledge_skill_file = os.path.join(
            TEST_DATA_DIR, "test_valid_knowledge_skill.yaml"
        )
        tracked_knowledge_file = os.path.join("knowledge ", "tracked", "qna.yaml")
        untracked_knowledge_file = os.path.join("knowledge", "new", "qna.yaml")
        test_valid_knowledge_skill = load_test_skills(test_valid_knowledge_skill_file)
        self.test_taxonomy.add_tracked(
            tracked_knowledge_file, test_valid_knowledge_skill
        )
        self.test_taxonomy.create_untracked(
            untracked_knowledge_file, test_valid_knowledge_skill
        )

    def test_generate(self):
        """An empty dataset must produce a warning naming the leaf node."""
        with patch("logging.Logger.info") as mocked_logger:
            generate_data(
                client=MagicMock(),
                logger=mocked_logger,
                model_family="merlinite",
                model_name="models/merlinite-7b-lab-Q4_K_M.gguf",
                num_instructions_to_generate=10,
                taxonomy=self.test_taxonomy.root,
                taxonomy_base=TEST_TAXONOMY_BASE,
                output_dir=self.tmp_path,
                chunk_word_count=1000,
                server_ctx_size=4096,
                pipeline="simple",
            )
        mocked_logger.warning.assert_called()
        # call_args holds the most recent warning() call; its first positional
        # argument is the message text.
        assert re.search(
            "empty sdg output: knowledge_new", mocked_logger.warning.call_args.args[0]
        )

    def teardown(self) -> None:
        """Recursively remove the temporary output directory and its contents.

        Safe to call more than once: a directory that is already gone is
        ignored, so this does not conflict with the cleanup registered in
        ``setUp``.
        """
        shutil.rmtree(self.tmp_path, ignore_errors=True)

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.teardown()
514+
515+
455516
def test_context_init_batch_size_optional():
456517
"""Test that the _context_init function can handle a missing batch size by
457518
delegating to the default in PipelineContext.

0 commit comments

Comments
 (0)