@@ -1,15 +1,13 @@
 # Standard
 from dataclasses import dataclass, field
 from multiprocessing import Pool
-from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar, Union
+from typing import Any, Dict, List, TypedDict, TypeVar, Union
 import gc
 import glob
-import importlib
 import logging
 import math
 import os
 import re
-import sys
 
 # Third Party
 from datasets import concatenate_datasets, load_dataset
@@ -20,6 +18,7 @@
 import torch
 
 # Local
+from .encoders import get_encoder_class
 from .utils.subset_selection_utils import (
     compute_pairwise_dense,
     get_default_num_gpus,
@@ -171,19 +170,14 @@ class DataProcessor:
     Enhanced data processor with support for combined files and multiple selection methods.
     """
 
-    def __init__(self, config: ProcessingConfig, encoder_cls):
+    def __init__(self, config: ProcessingConfig):
         """
         Initializes the DataProcessor with the given configuration and encoder class.
 
         Args:
             config (ProcessingConfig): The processing configuration.
-            encoder_cls: The encoder class to use for generating embeddings.
         """
         self.config = config
-        self.encoder = encoder_cls(
-            model_name=config.encoder.encoder_model,
-            testing_mode=config.encoder.testing_mode,
-        )
         self.env = Environment(loader=BaseLoader())
         self.templates = {
             k: self.env.from_string(v) for k, v in config.template.templates.items()
@@ -282,60 +276,6 @@ def get_subset_name(self, size_spec: Union[int, float], actual_size: int) -> str
             return f"percent_{size_spec:.1f}"
         return f"samples_{actual_size}"
 
-    def get_last_processed_batch(self, output_dir: str) -> Tuple[int, Optional[str]]:
-        """
-        Retrieves the last processed batch number and its file path from the output directory.
-
-        Args:
-            output_dir (str): The directory where batch files are stored.
-
-        Returns:
-            Tuple[int, Optional[str]]: The last batch number and the corresponding batch file path.
-        """
-        batch_files = glob.glob(os.path.join(output_dir, "batch_*.h5"))
-        if not batch_files:
-            return -1, None
-
-        # Sort batch files by batch number
-        batch_files.sort(key=self.extract_batch_number)
-        max_batch_file = batch_files[-1]
-        max_batch_number = self.extract_batch_number(max_batch_file)
-
-        # Return the max batch number and the corresponding batch file path
-        return max_batch_number, max_batch_file
-
-    @retry_on_exception
-    def process_batch(self, batch_texts: List[str], output_file: str) -> Optional[int]:
-        """
-        Processes a batch of texts by generating embeddings and saving them to a file.
-        Returns the embedding dimension or None if no embeddings were generated.
-        """
-        embeddings = (
-            self.encoder.encode(
-                inputs=batch_texts,
-                instruction=self.config.encoder.instruction,
-            )
-            .cpu()
-            .numpy()
-        )
-
-        if embeddings.size == 0:
-            logger.warning(
-                f"No embeddings generated for batch, skipping file {output_file}"
-            )
-            return None
-
-        embedding_dim = int(embeddings.shape[1])  # Cast to int
-        logger.info(f"Embedding dimension for batch: {embedding_dim}")
-
-        with h5py.File(output_file, "w") as h5f:
-            h5f.create_dataset(
-                "embeddings", data=embeddings, dtype="float32", chunks=True
-            )
-            h5f.flush()
-
-        return embedding_dim
-
     @retry_on_exception
     def generate_embeddings(self, dataset, output_dir: str) -> str:
         """
@@ -405,104 +345,6 @@ def generate_embeddings(self, dataset, output_dir: str) -> str:
 
         return merged_path
 
-    def extract_batch_number(self, filename):
-        """
-        Extracts the batch number from the filename.
-        Assumes the filename is in the format 'batch_<number>.h5'.
-
-        Args:
-            filename (str): The filename from which to extract the batch number.
-
-        Returns:
-            int: The batch number extracted from the filename.
-        """
-        basename = os.path.basename(filename)
-        match = re.search(r"batch_(\d+)\.h5$", basename)
-        if match:
-            return int(match.group(1))
-        raise ValueError(f"Filename {filename} does not match expected pattern.")
-
-    def get_embedding_size_dim_from_file(self, batch_file: str) -> Tuple[int, int]:
-        """
-        Reads the batch file to determine the embedding size (number of embeddings) and dimension.
-        """
-        with h5py.File(batch_file, "r") as h5f:
-            if "embeddings" not in h5f:
-                raise ValueError(
-                    f"The file {batch_file} does not contain 'embeddings' dataset."
-                )
-            embeddings = h5f["embeddings"]
-            embedding_size = int(embeddings.shape[0])  # Cast to int
-            embedding_dim = int(embeddings.shape[1])  # Cast to int
-            logger.info(f"Embedding dimension from {batch_file}: {embedding_dim}")
-            return embedding_size, embedding_dim
-
-    def merge_embeddings(self, output_dir, merged_file, total_samples):
-        """
-        Merges all batch embedding files into a single embeddings file.
-
-        Args:
-            output_dir (str): The directory where batch embedding files are stored.
-            merged_file (str): The path to the merged embeddings file.
-            total_samples (int): The total number of samples (embeddings).
-
-        """
-        # Find all batch files
-        batch_files = glob.glob(os.path.join(output_dir, "batch_*.h5"))
-        if not batch_files:
-            logger.warning("No batch files found to merge")
-            return
-
-        # Sort batch files by batch number
-        batch_files.sort(key=self.extract_batch_number)
-
-        # Retrieve embedding_dim from the first batch file
-        _, embedding_dim = self.get_embedding_size_dim_from_file(batch_files[0])
-
-        if os.path.exists(merged_file):
-            logger.info(f"Merged file {merged_file} already exists, skipping merge")
-            return
-
-        logger.info(
-            f"Merging {len(batch_files)} batch files into {merged_file} with {total_samples} samples"
-        )
-
-        with h5py.File(merged_file, "w") as h5f_merged:
-            # Initialize the dataset in the merged file with the retrieved embedding dimension
-            embeddings_ds = h5f_merged.create_dataset(
-                "embeddings", shape=(total_samples, embedding_dim), dtype="float32"
-            )
-
-            start_idx = 0
-            for batch_file in batch_files:
-                with h5py.File(batch_file, "r") as h5f_batch:
-                    if "embeddings" not in h5f_batch:
-                        logger.error(
-                            f"File {batch_file} does not contain 'embeddings' dataset"
-                        )
-                        continue
-
-                    embeddings = h5f_batch["embeddings"][:]
-                    batch_size = embeddings.shape[0]
-                    end_idx = start_idx + batch_size
-
-                    # Check that each file's embedding dimension matches the retrieved embedding_dim
-                    if embeddings.shape[1] != embedding_dim:
-                        logger.error(
-                            f"Embedding dimension mismatch in {batch_file}. Expected {embedding_dim}, got {embeddings.shape[1]}"
-                        )
-                        continue
-
-                    # Copy embeddings into the merged dataset
-                    embeddings_ds[start_idx:end_idx] = embeddings
-                    start_idx = end_idx
-
-                # Remove the batch file after processing
-                os.remove(batch_file)
-                logger.info(f"Processed and removed {batch_file}")
-
-        gc.collect()
-
     def select_subsets(
         self, dataset_name: str, embeddings: torch.Tensor
     ) -> Dict[Union[int, float], List[int]]:
@@ -750,22 +592,7 @@ def _process_dataset_shard(args):
         device = f"cuda:{gpu_id}"
         logger.info(f"GPU {gpu_id} started processing {len(dataset_shard)} samples")
 
-        # Import the encoder directly using the system path
-        # Standard
-
-        sys.path.append(
-            os.path.dirname(
-                os.path.dirname(
-                    os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
-                )
-            )
-        )
-
-        # Import the encoder class using string-based absolute import
-
-        module_name = f"sdg.src.instructlab.sdg.encoders.{encoder_type}_encoder"
-        module = importlib.import_module(module_name)
-        encoder_cls = getattr(module, f"{encoder_type.capitalize()}EmbedEncoder")
+        encoder_cls = get_encoder_class(encoder_type)
 
         # Create encoder instance
         encoder = encoder_cls(
@@ -845,7 +672,7 @@ def _process_dataset_shard(args):
     # pylint: disable=broad-exception-caught
     except Exception as e:
        logger.error(f"Error processing shard on GPU {gpu_id}: {str(e)}")
-        return None
+        raise
 
 
 def _merge_shard_files(shard_files, merged_file):
@@ -1014,25 +841,6 @@ def get_supported_encoders():
     ]
 
 
-def get_encoder_class(encoder_type: str):
-    """Get the encoder class based on the encoder type."""
-    try:
-        # Convert encoder_type to class name (e.g., 'arctic' -> 'ArcticEmbedEncoder')
-        class_name = f"{encoder_type.capitalize()}EmbedEncoder"
-        # Import the module dynamically
-        module = __import__(
-            f"instructlab.sdg.encoders.{encoder_type}_encoder", fromlist=[class_name]
-        )
-        # Get the class from the module
-        return getattr(module, class_name)
-    except (ImportError, AttributeError) as e:
-        supported_encoders = get_supported_encoders()
-        raise ValueError(
-            f"Unsupported encoder type: '{encoder_type}'. "
-            f"Supported types are: {[f'{t}' for t in supported_encoders]}"
-        ) from e
-
-
 def subset_datasets(
     input_files: List[str],
     subset_sizes: List[Union[int, float]],
@@ -1081,9 +889,7 @@ def subset_datasets(
 
         try:
             logger.info(f"Processing configuration: {config}")
-            processor = DataProcessor(
-                config, get_encoder_class(config.encoder.encoder_type)
-            )
+            processor = DataProcessor(config)
             processor.process_files(input_files, config.basic.output_dir)
 
         except Exception as e: