|
9 | 9 | import glob
|
10 | 10 | import json
|
11 | 11 | import os
|
| 12 | +import re |
12 | 13 | import shutil
|
13 | 14 | import tempfile
|
14 | 15 | import unittest
|
@@ -267,6 +268,11 @@ def strip_q(q):
|
267 | 268 | return output
|
268 | 269 |
|
269 | 270 |
|
| 271 | +def _empty_llmblock_generate(self, samples): |
| 272 | + """Return an empty set of generated samples.""" |
| 273 | + return [] |
| 274 | + |
| 275 | + |
270 | 276 | @patch.object(LLMBlock, "_generate", _noop_llmblock_generate)
|
271 | 277 | class TestGenerateCompositionalData(unittest.TestCase):
|
272 | 278 | @pytest.fixture(autouse=True)
|
@@ -452,6 +458,61 @@ def __exit__(self, *args):
|
452 | 458 | self.teardown()
|
453 | 459 |
|
454 | 460 |
|
@patch.object(LLMBlock, "_generate", _empty_llmblock_generate)
class TestGenerateEmptyDataset(unittest.TestCase):
    """Check that ``generate_data`` warns when generation yields no samples.

    The class decorator swaps ``LLMBlock._generate`` for
    ``_empty_llmblock_generate`` while each test method runs, so every
    block returns an empty list.
    """

    @pytest.fixture(autouse=True)
    def _init_taxonomy(self, taxonomy_dir):
        """Store the ``taxonomy_dir`` fixture on the instance for ``setUp``."""
        self.test_taxonomy = taxonomy_dir

    def setUp(self):
        """Create the output directory and seed the taxonomy with test files."""
        # mkdtemp() leaves the directory in place until it is removed
        # explicitly.  TemporaryDirectory().name would drop the only reference
        # to the object whose finalizer deletes the directory.
        self.tmp_path = tempfile.mkdtemp()
        # unittest never calls ``teardown`` (its hook is ``tearDown``), so
        # register it as a cleanup; cleanups also run when setUp fails later.
        self.addCleanup(self.teardown)

        test_valid_knowledge_skill_file = os.path.join(
            TEST_DATA_DIR, "test_valid_knowledge_skill.yaml"
        )
        # Both files live under the same "knowledge" directory; the directory
        # name must not carry trailing whitespace.
        tracked_knowledge_file = os.path.join("knowledge", "tracked", "qna.yaml")
        untracked_knowledge_file = os.path.join("knowledge", "new", "qna.yaml")
        test_valid_knowledge_skill = load_test_skills(test_valid_knowledge_skill_file)
        self.test_taxonomy.add_tracked(
            tracked_knowledge_file, test_valid_knowledge_skill
        )
        self.test_taxonomy.create_untracked(
            untracked_knowledge_file, test_valid_knowledge_skill
        )

    def test_generate(self):
        """An empty generation result must be reported through ``warning``."""
        # patch() yields a MagicMock; it is passed as the logger so that its
        # auto-created ``warning`` attribute records every warning call.
        with patch("logging.Logger.info") as mocked_logger:
            generate_data(
                client=MagicMock(),
                logger=mocked_logger,
                model_family="merlinite",
                model_name="models/merlinite-7b-lab-Q4_K_M.gguf",
                num_instructions_to_generate=10,
                taxonomy=self.test_taxonomy.root,
                taxonomy_base=TEST_TAXONOMY_BASE,
                output_dir=self.tmp_path,
                chunk_word_count=1000,
                server_ctx_size=4096,
                pipeline="simple",
            )
        mocked_logger.warning.assert_called()
        # Look at every recorded warning, not only the most recent one, so an
        # unrelated later warning cannot mask the message under test.
        warning_messages = [
            str(call.args[0])
            for call in mocked_logger.warning.call_args_list
            if call.args
        ]
        assert any(
            re.search(r"empty sdg output: knowledge_new", message)
            for message in warning_messages
        ), f"no empty-output warning among: {warning_messages}"

    def teardown(self) -> None:
        """Recursively remove the temporary output directory and all of its
        subdirectories and files.

        Safe to call more than once: a directory that is already gone is
        ignored.
        """
        shutil.rmtree(self.tmp_path, ignore_errors=True)

    def __enter__(self):
        """Return the test case itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, *args):
        """Remove the temporary output directory when the ``with`` block ends."""
        self.teardown()
| 514 | + |
| 515 | + |
455 | 516 | def test_context_init_batch_size_optional():
|
456 | 517 | """Test that the _context_init function can handle a missing batch size by
|
457 | 518 | delegating to the default in PipelineContext.
|
|
0 commit comments