|
1 | 1 | # Standard |
2 | 2 | from dataclasses import dataclass, field |
3 | 3 | from multiprocessing import Pool |
4 | | -from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar, Union |
| 4 | +from typing import Any, Dict, List, TypedDict, TypeVar, Union |
5 | 5 | import gc |
6 | 6 | import glob |
7 | | -import importlib |
8 | 7 | import logging |
9 | 8 | import math |
10 | 9 | import os |
11 | 10 | import re |
12 | | -import sys |
13 | 11 |
|
14 | 12 | # Third Party |
15 | 13 | from datasets import concatenate_datasets, load_dataset |
|
20 | 18 | import torch |
21 | 19 |
|
22 | 20 | # Local |
| 21 | +from .encoders import get_encoder_class |
23 | 22 | from .utils.subset_selection_utils import ( |
24 | 23 | compute_pairwise_dense, |
25 | 24 | get_default_num_gpus, |
@@ -171,19 +170,14 @@ class DataProcessor: |
171 | 170 | Enhanced data processor with support for combined files and multiple selection methods. |
172 | 171 | """ |
173 | 172 |
|
174 | | - def __init__(self, config: ProcessingConfig, encoder_cls): |
| 173 | + def __init__(self, config: ProcessingConfig): |
175 | 174 | """ |
176 | | - Initializes the DataProcessor with the given configuration and encoder class.
| 175 | + Initializes the DataProcessor with the given configuration.
177 | 176 |
|
178 | 177 | Args: |
179 | 178 | config (ProcessingConfig): The processing configuration. |
180 | | - encoder_cls: The encoder class to use for generating embeddings. |
181 | 179 | """ |
182 | 180 | self.config = config |
183 | | - self.encoder = encoder_cls( |
184 | | - model_name=config.encoder.encoder_model, |
185 | | - testing_mode=config.encoder.testing_mode, |
186 | | - ) |
187 | 181 | self.env = Environment(loader=BaseLoader()) |
188 | 182 | self.templates = { |
189 | 183 | k: self.env.from_string(v) for k, v in config.template.templates.items() |
@@ -282,60 +276,6 @@ def get_subset_name(self, size_spec: Union[int, float], actual_size: int) -> str |
282 | 276 | return f"percent_{size_spec:.1f}" |
283 | 277 | return f"samples_{actual_size}" |
284 | 278 |
|
285 | | - def get_last_processed_batch(self, output_dir: str) -> Tuple[int, Optional[str]]: |
286 | | - """ |
287 | | - Retrieves the last processed batch number and its file path from the output directory. |
288 | | -
|
289 | | - Args: |
290 | | - output_dir (str): The directory where batch files are stored. |
291 | | -
|
292 | | - Returns: |
293 | | - Tuple[int, Optional[str]]: The last batch number and the corresponding batch file path. |
294 | | - """ |
295 | | - batch_files = glob.glob(os.path.join(output_dir, "batch_*.h5")) |
296 | | - if not batch_files: |
297 | | - return -1, None |
298 | | - |
299 | | - # Sort batch files by batch number |
300 | | - batch_files.sort(key=self.extract_batch_number) |
301 | | - max_batch_file = batch_files[-1] |
302 | | - max_batch_number = self.extract_batch_number(max_batch_file) |
303 | | - |
304 | | - # Return the max batch number and the corresponding batch file path |
305 | | - return max_batch_number, max_batch_file |
306 | | - |
307 | | - @retry_on_exception |
308 | | - def process_batch(self, batch_texts: List[str], output_file: str) -> Optional[int]: |
309 | | - """ |
310 | | - Processes a batch of texts by generating embeddings and saving them to a file. |
311 | | - Returns the embedding dimension or None if no embeddings were generated. |
312 | | - """ |
313 | | - embeddings = ( |
314 | | - self.encoder.encode( |
315 | | - inputs=batch_texts, |
316 | | - instruction=self.config.encoder.instruction, |
317 | | - ) |
318 | | - .cpu() |
319 | | - .numpy() |
320 | | - ) |
321 | | - |
322 | | - if embeddings.size == 0: |
323 | | - logger.warning( |
324 | | - f"No embeddings generated for batch, skipping file {output_file}" |
325 | | - ) |
326 | | - return None |
327 | | - |
328 | | - embedding_dim = int(embeddings.shape[1]) # Cast to int |
329 | | - logger.info(f"Embedding dimension for batch: {embedding_dim}") |
330 | | - |
331 | | - with h5py.File(output_file, "w") as h5f: |
332 | | - h5f.create_dataset( |
333 | | - "embeddings", data=embeddings, dtype="float32", chunks=True |
334 | | - ) |
335 | | - h5f.flush() |
336 | | - |
337 | | - return embedding_dim |
338 | | - |
339 | 279 | @retry_on_exception |
340 | 280 | def generate_embeddings(self, dataset, output_dir: str) -> str: |
341 | 281 | """ |
@@ -405,104 +345,6 @@ def generate_embeddings(self, dataset, output_dir: str) -> str: |
405 | 345 |
|
406 | 346 | return merged_path |
407 | 347 |
|
408 | | - def extract_batch_number(self, filename): |
409 | | - """ |
410 | | - Extracts the batch number from the filename. |
411 | | - Assumes the filename is in the format 'batch_<number>.h5'. |
412 | | -
|
413 | | - Args: |
414 | | - filename (str): The filename from which to extract the batch number. |
415 | | -
|
416 | | - Returns: |
417 | | - int: The batch number extracted from the filename. |
418 | | - """ |
419 | | - basename = os.path.basename(filename) |
420 | | - match = re.search(r"batch_(\d+)\.h5$", basename) |
421 | | - if match: |
422 | | - return int(match.group(1)) |
423 | | - raise ValueError(f"Filename {filename} does not match expected pattern.") |
424 | | - |
425 | | - def get_embedding_size_dim_from_file(self, batch_file: str) -> Tuple[int, int]: |
426 | | - """ |
427 | | - Reads the batch file to determine the embedding size (number of embeddings) and dimension. |
428 | | - """ |
429 | | - with h5py.File(batch_file, "r") as h5f: |
430 | | - if "embeddings" not in h5f: |
431 | | - raise ValueError( |
432 | | - f"The file {batch_file} does not contain 'embeddings' dataset." |
433 | | - ) |
434 | | - embeddings = h5f["embeddings"] |
435 | | - embedding_size = int(embeddings.shape[0]) # Cast to int |
436 | | - embedding_dim = int(embeddings.shape[1]) # Cast to int |
437 | | - logger.info(f"Embedding dimension from {batch_file}: {embedding_dim}") |
438 | | - return embedding_size, embedding_dim |
439 | | - |
440 | | - def merge_embeddings(self, output_dir, merged_file, total_samples): |
441 | | - """ |
442 | | - Merges all batch embedding files into a single embeddings file. |
443 | | -
|
444 | | - Args: |
445 | | - output_dir (str): The directory where batch embedding files are stored. |
446 | | - merged_file (str): The path to the merged embeddings file. |
447 | | - total_samples (int): The total number of samples (embeddings). |
448 | | -
|
449 | | - """ |
450 | | - # Find all batch files |
451 | | - batch_files = glob.glob(os.path.join(output_dir, "batch_*.h5")) |
452 | | - if not batch_files: |
453 | | - logger.warning("No batch files found to merge") |
454 | | - return |
455 | | - |
456 | | - # Sort batch files by batch number |
457 | | - batch_files.sort(key=self.extract_batch_number) |
458 | | - |
459 | | - # Retrieve embedding_dim from the first batch file |
460 | | - _, embedding_dim = self.get_embedding_size_dim_from_file(batch_files[0]) |
461 | | - |
462 | | - if os.path.exists(merged_file): |
463 | | - logger.info(f"Merged file {merged_file} already exists, skipping merge") |
464 | | - return |
465 | | - |
466 | | - logger.info( |
467 | | - f"Merging {len(batch_files)} batch files into {merged_file} with {total_samples} samples" |
468 | | - ) |
469 | | - |
470 | | - with h5py.File(merged_file, "w") as h5f_merged: |
471 | | - # Initialize the dataset in the merged file with the retrieved embedding dimension |
472 | | - embeddings_ds = h5f_merged.create_dataset( |
473 | | - "embeddings", shape=(total_samples, embedding_dim), dtype="float32" |
474 | | - ) |
475 | | - |
476 | | - start_idx = 0 |
477 | | - for batch_file in batch_files: |
478 | | - with h5py.File(batch_file, "r") as h5f_batch: |
479 | | - if "embeddings" not in h5f_batch: |
480 | | - logger.error( |
481 | | - f"File {batch_file} does not contain 'embeddings' dataset" |
482 | | - ) |
483 | | - continue |
484 | | - |
485 | | - embeddings = h5f_batch["embeddings"][:] |
486 | | - batch_size = embeddings.shape[0] |
487 | | - end_idx = start_idx + batch_size |
488 | | - |
489 | | - # Check that each file's embedding dimension matches the retrieved embedding_dim |
490 | | - if embeddings.shape[1] != embedding_dim: |
491 | | - logger.error( |
492 | | - f"Embedding dimension mismatch in {batch_file}. Expected {embedding_dim}, got {embeddings.shape[1]}" |
493 | | - ) |
494 | | - continue |
495 | | - |
496 | | - # Copy embeddings into the merged dataset |
497 | | - embeddings_ds[start_idx:end_idx] = embeddings |
498 | | - start_idx = end_idx |
499 | | - |
500 | | - # Remove the batch file after processing |
501 | | - os.remove(batch_file) |
502 | | - logger.info(f"Processed and removed {batch_file}") |
503 | | - |
504 | | - gc.collect() |
505 | | - |
506 | 348 | def select_subsets( |
507 | 349 | self, dataset_name: str, embeddings: torch.Tensor |
508 | 350 | ) -> Dict[Union[int, float], List[int]]: |
@@ -750,22 +592,7 @@ def _process_dataset_shard(args): |
750 | 592 | device = f"cuda:{gpu_id}" |
751 | 593 | logger.info(f"GPU {gpu_id} started processing {len(dataset_shard)} samples") |
752 | 594 |
|
753 | | - # Import the encoder directly using the system path |
754 | | - # Standard |
755 | | - |
756 | | - sys.path.append( |
757 | | - os.path.dirname( |
758 | | - os.path.dirname( |
759 | | - os.path.dirname(os.path.dirname(os.path.dirname(__file__))) |
760 | | - ) |
761 | | - ) |
762 | | - ) |
763 | | - |
764 | | - # Import the encoder class using string-based absolute import |
765 | | - |
766 | | - module_name = f"sdg.src.instructlab.sdg.encoders.{encoder_type}_encoder" |
767 | | - module = importlib.import_module(module_name) |
768 | | - encoder_cls = getattr(module, f"{encoder_type.capitalize()}EmbedEncoder") |
| 595 | + encoder_cls = get_encoder_class(encoder_type) |
769 | 596 |
|
770 | 597 | # Create encoder instance |
771 | 598 | encoder = encoder_cls( |
@@ -845,7 +672,7 @@ def _process_dataset_shard(args): |
845 | 672 | # pylint: disable=broad-exception-caught |
846 | 673 | except Exception as e: |
847 | 674 | logger.error(f"Error processing shard on GPU {gpu_id}: {str(e)}") |
848 | | - return None |
| 675 | + raise |
849 | 676 |
|
850 | 677 |
|
851 | 678 | def _merge_shard_files(shard_files, merged_file): |
@@ -1014,25 +841,6 @@ def get_supported_encoders(): |
1014 | 841 | ] |
1015 | 842 |
|
1016 | 843 |
|
1017 | | -def get_encoder_class(encoder_type: str): |
1018 | | - """Get the encoder class based on the encoder type.""" |
1019 | | - try: |
1020 | | - # Convert encoder_type to class name (e.g., 'arctic' -> 'ArcticEmbedEncoder') |
1021 | | - class_name = f"{encoder_type.capitalize()}EmbedEncoder" |
1022 | | - # Import the module dynamically |
1023 | | - module = __import__( |
1024 | | - f"instructlab.sdg.encoders.{encoder_type}_encoder", fromlist=[class_name] |
1025 | | - ) |
1026 | | - # Get the class from the module |
1027 | | - return getattr(module, class_name) |
1028 | | - except (ImportError, AttributeError) as e: |
1029 | | - supported_encoders = get_supported_encoders() |
1030 | | - raise ValueError( |
1031 | | - f"Unsupported encoder type: '{encoder_type}'. " |
1032 | | - f"Supported types are: {[f'{t}' for t in supported_encoders]}" |
1033 | | - ) from e |
1034 | | - |
1035 | | - |
1036 | 844 | def subset_datasets( |
1037 | 845 | input_files: List[str], |
1038 | 846 | subset_sizes: List[Union[int, float]], |
@@ -1081,9 +889,7 @@ def subset_datasets( |
1081 | 889 |
|
1082 | 890 | try: |
1083 | 891 | logger.info(f"Processing configuration: {config}") |
1084 | | - processor = DataProcessor( |
1085 | | - config, get_encoder_class(config.encoder.encoder_type) |
1086 | | - ) |
| 892 | + processor = DataProcessor(config) |
1087 | 893 | processor.process_files(input_files, config.basic.output_dir) |
1088 | 894 |
|
1089 | 895 | except Exception as e: |
|
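For orientation, a minimal sketch of how the pieces in this diff fit together: DataProcessor is constructed from the config alone, and encoder classes are resolved through the shared get_encoder_class helper imported from .encoders. The constructor keyword arguments (model_name, testing_mode) are taken from the removed DataProcessor.__init__; the .encoders package itself is not shown here, so the import path and helper behavior are assumptions, not part of this change.

# Illustrative sketch only; mirrors the call sites in the diff above.
# The import path is assumed from the removed dynamic-import code.
from instructlab.sdg.encoders import get_encoder_class


def build_encoder(config):
    """Resolve and construct an embedding encoder from a ProcessingConfig."""
    # e.g. encoder_type "arctic" is expected to resolve to ArcticEmbedEncoder
    encoder_cls = get_encoder_class(config.encoder.encoder_type)
    # Same keyword arguments the removed DataProcessor.__init__ passed through
    return encoder_cls(
        model_name=config.encoder.encoder_model,
        testing_mode=config.encoder.testing_mode,
    )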