20 | 20 | import torch
21 | 21 |
22 | 22 | # Local
| 23 | +from .encoders import get_encoder_class
23 | 24 | from .utils.subset_selection_utils import (
24 | 25 |     compute_pairwise_dense,
25 | 26 |     get_default_num_gpus,
@@ -171,19 +172,14 @@ class DataProcessor:
171 | 172 |     Enhanced data processor with support for combined files and multiple selection methods.
172 | 173 |     """
173 | 174 |
174 | | -    def __init__(self, config: ProcessingConfig, encoder_cls):
| 175 | +    def __init__(self, config: ProcessingConfig):
175 | 176 |         """
176 | 177 |         Initializes the DataProcessor with the given configuration and encoder class.
177 | 178 |
178 | 179 |         Args:
179 | 180 |             config (ProcessingConfig): The processing configuration.
180 | | -            encoder_cls: The encoder class to use for generating embeddings.
181 | 181 |         """
182 | 182 |         self.config = config
183 | | -        self.encoder = encoder_cls(
184 | | -            model_name=config.encoder.encoder_model,
185 | | -            testing_mode=config.encoder.testing_mode,
186 | | -        )
187 | 183 |         self.env = Environment(loader=BaseLoader())
188 | 184 |         self.templates = {
189 | 185 |             k: self.env.from_string(v) for k, v in config.template.templates.items()
@@ -750,22 +746,7 @@ def _process_dataset_shard(args):
750 | 746 |     device = f"cuda:{gpu_id}"
751 | 747 |     logger.info(f"GPU {gpu_id} started processing {len(dataset_shard)} samples")
752 | 748 |
753 | | -    # Import the encoder directly using the system path
754 | | -    # Standard
755 | | -
756 | | -    sys.path.append(
757 | | -        os.path.dirname(
758 | | -            os.path.dirname(
759 | | -                os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
760 | | -            )
761 | | -        )
762 | | -    )
763 | | -
764 | | -    # Import the encoder class using string-based absolute import
765 | | -
766 | | -    module_name = f"sdg.src.instructlab.sdg.encoders.{encoder_type}_encoder"
767 | | -    module = importlib.import_module(module_name)
768 | | -    encoder_cls = getattr(module, f"{encoder_type.capitalize()}EmbedEncoder")
| 749 | +    encoder_cls = get_encoder_class(encoder_type)
769 | 750 |
770 | 751 |     # Create encoder instance
771 | 752 |     encoder = encoder_cls(
@@ -845,7 +826,7 @@ def _process_dataset_shard(args):
845 | 826 |     # pylint: disable=broad-exception-caught
846 | 827 |     except Exception as e:
847 | 828 |         logger.error(f"Error processing shard on GPU {gpu_id}: {str(e)}")
848 | | -        return None
| 829 | +        raise
849 | 830 |
850 | 831 |
851 | 832 | def _merge_shard_files(shard_files, merged_file):
@@ -1014,24 +995,6 @@ def get_supported_encoders():
1014 | 995 |     ]
1015 | 996 |
1016 | 997 |
1017 | | -def get_encoder_class(encoder_type: str):
1018 | | -    """Get the encoder class based on the encoder type."""
1019 | | -    try:
1020 | | -        # Convert encoder_type to class name (e.g., 'arctic' -> 'ArcticEmbedEncoder')
1021 | | -        class_name = f"{encoder_type.capitalize()}EmbedEncoder"
1022 | | -        # Import the module dynamically
1023 | | -        module = __import__(
1024 | | -            f"instructlab.sdg.encoders.{encoder_type}_encoder", fromlist=[class_name]
1025 | | -        )
1026 | | -        # Get the class from the module
1027 | | -        return getattr(module, class_name)
1028 | | -    except (ImportError, AttributeError) as e:
1029 | | -        supported_encoders = get_supported_encoders()
1030 | | -        raise ValueError(
1031 | | -            f"Unsupported encoder type: '{encoder_type}'. "
1032 | | -            f"Supported types are: {[f'{t}' for t in supported_encoders]}"
1033 | | -        ) from e
1034 | | -
1035 | 998 |
1036 | 999 | def subset_datasets(
1037 | 1000 |     input_files: List[str],
@@ -1081,9 +1044,7 @@ def subset_datasets(
1081 | 1044 |
1082 | 1045 |     try:
1083 | 1046 |         logger.info(f"Processing configuration: {config}")
1084 | | -        processor = DataProcessor(
1085 | | -            config, get_encoder_class(config.encoder.encoder_type)
1086 | | -        )
| 1047 | +        processor = DataProcessor(config)
1087 | 1048 |         processor.process_files(input_files, config.basic.output_dir)
1088 | 1049 |
1089 | 1050 |     except Exception as e:
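
The `get_encoder_class` helper deleted above is now imported from the `.encoders` package (`from .encoders import get_encoder_class`), whose contents are not part of this diff. A minimal sketch of what the relocated helper might look like, assuming it keeps the deleted copy's naming convention (e.g. `arctic` -> `ArcticEmbedEncoder`) and its `ValueError` on unknown types:

```python
# encoders/__init__.py -- illustrative sketch only; the real module is not shown in this diff
import importlib


def get_encoder_class(encoder_type: str):
    """Resolve a short encoder name (e.g. 'arctic') to its encoder class."""
    class_name = f"{encoder_type.capitalize()}EmbedEncoder"
    try:
        # A relative import works wherever the package is installed, replacing
        # the old sys.path manipulation in _process_dataset_shard.
        module = importlib.import_module(f".{encoder_type}_encoder", package=__package__)
        return getattr(module, class_name)
    except (ImportError, AttributeError) as e:
        raise ValueError(f"Unsupported encoder type: '{encoder_type}'.") from e
```

With a single shared resolver, `DataProcessor(config)` and the per-GPU shard workers construct encoders the same way, and the brittle `sys.path.append(...)` plus the `sdg.src.instructlab.sdg...` absolute import are no longer needed.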