Skip to content

Commit 2268a6f

Browse files
authored
Merge pull request #672 from instructlab/mergify/bp/release-v0.8/pr-652
Filter out blank content when saving mixed dataset (backport #652)
2 parents 7381ca5 + 3cdc70b commit 2268a6f

File tree

4 files changed

+133
-0
lines changed

4 files changed

+133
-0
lines changed

src/instructlab/sdg/datamixing.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,16 @@ def save_mixed_dataset(self, output_path, num_proc):
231231
as a jsonl file.
232232
"""
233233
mixed_ds = self._create_mixed_dataset(num_proc)
234+
235+
# filter out any records where any non-system message content is None or empty
236+
mixed_ds = mixed_ds.filter(
237+
lambda x: all(
238+
message.get("content")
239+
for message in x["messages"]
240+
if message.get("role") != "system"
241+
)
242+
)
243+
234244
mixed_ds.to_json(output_path, orient="records", lines=True)
235245
logger.info(f"Mixed Dataset saved to {output_path}")
236246

src/instructlab/sdg/generate_data.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,10 @@ def _gen_train_data(
107107
if len(synth_example.get("context", "")) > 0:
108108
user += "\n" + synth_example["context"]
109109
assistant = _unescape(_get_response_hack(synth_example))
110+
# filter out any assistant message that is empty
111+
if not assistant:
112+
continue
113+
110114
train_entry = {
111115
"system": system_prompt,
112116
"user": _unescape(user),
@@ -594,6 +598,13 @@ def postprocess_taxonomy(
594598
if leaf_node_type == "knowledge":
595599
is_knowledge = True
596600

601+
if is_knowledge:
602+
# Filter out rows with no document; they cause errors in the datamixing code
603+
for i in range(len(samples) - 1, -1, -1):
604+
if not samples[i].get("document"):
605+
logger.warning("Removing sample without document: %s", samples[i])
606+
samples.pop(i)
607+
597608
samples_ds = Dataset.from_list(samples)
598609
logger.debug("Postprocessing from samples: %s", samples_ds)
599610
all_generated_data.append(samples_ds)

tests/test_datamixing.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,3 +422,67 @@ def test_mix_instructlab_07x_precomputed_skills_with_unmask(tmp_path):
422422
assert (
423423
sample.get("unmask", None) is not None
424424
), "Mixed sample does not have unmask"
425+
426+
427+
def test_save_mixed_dataset_with_none_content(tmp_path):
428+
"""
429+
Test that we filter out mixed dataset records where any message content is None or empty.
430+
"""
431+
432+
# Create a knowledge dataset
433+
knowledge_dataset = load_auxiliary_dataset()
434+
number_of_records = len(knowledge_dataset)
435+
# append a record with content=None and content="", both should be filtered out
436+
knowledge_dataset = knowledge_dataset.add_item(
437+
{
438+
"id": "test_001",
439+
"messages": [
440+
{"role": "system", "content": "You are a helpful assistant."},
441+
{"role": "user", "content": "What is the capital of Ireland?"},
442+
{"role": "assistant", "content": None},
443+
],
444+
}
445+
)
446+
knowledge_dataset = knowledge_dataset.add_item(
447+
{
448+
"id": "test_002",
449+
"messages": [
450+
{"role": "system", "content": "You are a helpful assistant."},
451+
{"role": "user", "content": "What is the capital of Ireland?"},
452+
{"role": "assistant", "content": "Dublin"},
453+
],
454+
}
455+
)
456+
457+
knowledge_dataset = knowledge_dataset.add_item(
458+
{
459+
"id": "test_003",
460+
"messages": [
461+
{"role": "system", "content": "You are a helpful assistant."},
462+
{"role": "user", "content": "What is the capital of Ireland?"},
463+
{"role": "assistant", "content": ""},
464+
],
465+
}
466+
)
467+
468+
knowledge_path = os.path.join(tmp_path, "knowledge.jsonl")
469+
jldump(knowledge_dataset, knowledge_path)
470+
471+
output_path = os.path.join(tmp_path, "output.jsonl")
472+
recipe = Recipe()
473+
recipe.add_dataset(knowledge_path, 1.0)
474+
recipe.save_mixed_dataset(output_path, TEST_NUM_PROCS)
475+
476+
# Ensure the mixed dataset is saved correctly
477+
mixed_samples = load_dataset("json", data_files=output_path, split="train")
478+
479+
# the rows with content=None and content="" should have been removed
480+
assert (
481+
len(mixed_samples) == number_of_records + 1
482+
), f"Expected {number_of_records + 1} records in mixed dataset"
483+
484+
# None of the mixed samples should have content=None
485+
for sample in mixed_samples:
486+
assert all(
487+
[message.get("content") is not None for message in sample["messages"]]
488+
), "Mixed sample has content=None"

tests/test_generate_data.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,12 @@
2323
from instructlab.sdg import LLMBlock, PipelineContext
2424
from instructlab.sdg.generate_data import (
2525
_context_init,
26+
_gen_train_data,
2627
_locate_docling_models,
2728
_sdg_init,
2829
generate_data,
2930
)
31+
from instructlab.sdg.utils.json import jlload
3032

3133
# Local
3234
from .taxonomy import load_test_skills
@@ -586,3 +588,49 @@ def test_locate_docling_models_config_not_found(testdata_path):
586588
os.environ["XDG_DATA_HOME"] = str(testdata_path.joinpath("nonexistent_dir"))
587589
docling_model_path = _locate_docling_models()
588590
assert docling_model_path is None
591+
592+
593+
class TestGenTrainData(unittest.TestCase):
594+
"""Test the _gen_train_data function with small synthetic examples."""
595+
596+
def setUp(self):
597+
self.test_dir = tempfile.mkdtemp()
598+
self.system_prompt = "Test system prompt"
599+
600+
def tearDown(self):
601+
shutil.rmtree(self.test_dir)
602+
603+
def test_gen_train_data_with_empty_response(self):
604+
"""Test _gen_train_data with synthetic examples with blank responses."""
605+
# Create mock synthetic examples with blank responses
606+
machine_instruction_data = [
607+
[
608+
{"question": "Q1", "response": "", "context": "C1"},
609+
{"question": "Q2", "response": "A2", "context": "C2"},
610+
]
611+
]
612+
613+
output_file_train = os.path.join(self.test_dir, "train_test.jsonl")
614+
output_file_messages = os.path.join(self.test_dir, "messages_test.jsonl")
615+
616+
# Call the function
617+
_gen_train_data(
618+
machine_instruction_data,
619+
output_file_train,
620+
output_file_messages,
621+
self.system_prompt,
622+
)
623+
624+
# Verify train file was created and only has a single sample
625+
self.assertTrue(os.path.exists(output_file_train))
626+
train_data = jlload(output_file_train)
627+
self.assertEqual(len(train_data), 1)
628+
629+
# Check first sample
630+
first_sample = train_data[0]
631+
self.assertEqual(first_sample["system"], self.system_prompt)
632+
self.assertEqual(first_sample["user"], "Q2\nC2")
633+
self.assertEqual(first_sample["assistant"], "A2")
634+
635+
# Verify messages file was created
636+
self.assertTrue(os.path.exists(output_file_messages))

0 commit comments

Comments
 (0)