Skip to content

Commit 3993541

Browse files
committed
refactor: Refactor DataProcessor to use centralized encoder class retrieval
- Removed the encoder class initialization from the DataProcessor constructor.
- Introduced a centralized method for obtaining encoder classes in the encoders module.
- Updated relevant tests to reflect the changes in DataProcessor initialization.
- Cleaned up unused imports and code related to dynamic encoder imports.

Signed-off-by: eshwarprasadS <[email protected]>
1 parent 19beb30 commit 3993541

File tree

3 files changed

+24
-77
lines changed

3 files changed

+24
-77
lines changed
Lines changed: 17 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -1,19 +1,22 @@
1-
# Standard
2-
import importlib
1+
# Import all encoder classes directly
2+
# Local
3+
from .arctic_encoder import ArcticEmbedEncoder
4+
5+
# Create a mapping of encoder types to their classes
6+
ENCODER_REGISTRY = {
7+
"arctic": ArcticEmbedEncoder,
8+
}
39

410

511
def get_encoder_class(encoder_type: str):
612
"""Get the encoder class based on the encoder type."""
713
try:
8-
# Convert encoder_type to class name (e.g., 'arctic' -> 'ArcticEmbedEncoder')
9-
class_name = f"{encoder_type.capitalize()}EmbedEncoder"
10-
11-
# Use absolute import instead of relative
12-
module_name = f"sdg.src.instructlab.sdg.encoders.{encoder_type}_encoder"
13-
14-
module = importlib.import_module(module_name)
15-
16-
# Get the class from the module
17-
return getattr(module, class_name)
18-
except (ImportError, AttributeError) as e:
19-
raise ValueError(f"Unsupported encoder type: '{encoder_type}'") from e
14+
if encoder_type not in ENCODER_REGISTRY:
15+
supported_encoders = list(ENCODER_REGISTRY.keys())
16+
raise ValueError(
17+
f"Unsupported encoder type: '{encoder_type}'. "
18+
f"Supported types are: {supported_encoders}"
19+
)
20+
return ENCODER_REGISTRY[encoder_type]
21+
except Exception as e:
22+
raise ValueError(f"Error getting encoder class: {str(e)}") from e

src/instructlab/sdg/subset_selection.py

Lines changed: 5 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
import torch
2121

2222
# Local
23+
from .encoders import get_encoder_class
2324
from .utils.subset_selection_utils import (
2425
compute_pairwise_dense,
2526
get_default_num_gpus,
@@ -171,19 +172,14 @@ class DataProcessor:
171172
Enhanced data processor with support for combined files and multiple selection methods.
172173
"""
173174

174-
def __init__(self, config: ProcessingConfig, encoder_cls):
175+
def __init__(self, config: ProcessingConfig):
175176
"""
176177
Initializes the DataProcessor with the given configuration and encoder class.
177178
178179
Args:
179180
config (ProcessingConfig): The processing configuration.
180-
encoder_cls: The encoder class to use for generating embeddings.
181181
"""
182182
self.config = config
183-
self.encoder = encoder_cls(
184-
model_name=config.encoder.encoder_model,
185-
testing_mode=config.encoder.testing_mode,
186-
)
187183
self.env = Environment(loader=BaseLoader())
188184
self.templates = {
189185
k: self.env.from_string(v) for k, v in config.template.templates.items()
@@ -750,22 +746,7 @@ def _process_dataset_shard(args):
750746
device = f"cuda:{gpu_id}"
751747
logger.info(f"GPU {gpu_id} started processing {len(dataset_shard)} samples")
752748

753-
# Import the encoder directly using the system path
754-
# Standard
755-
756-
sys.path.append(
757-
os.path.dirname(
758-
os.path.dirname(
759-
os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
760-
)
761-
)
762-
)
763-
764-
# Import the encoder class using string-based absolute import
765-
766-
module_name = f"sdg.src.instructlab.sdg.encoders.{encoder_type}_encoder"
767-
module = importlib.import_module(module_name)
768-
encoder_cls = getattr(module, f"{encoder_type.capitalize()}EmbedEncoder")
749+
encoder_cls = get_encoder_class(encoder_type)
769750

770751
# Create encoder instance
771752
encoder = encoder_cls(
@@ -845,7 +826,7 @@ def _process_dataset_shard(args):
845826
# pylint: disable=broad-exception-caught
846827
except Exception as e:
847828
logger.error(f"Error processing shard on GPU {gpu_id}: {str(e)}")
848-
return None
829+
raise
849830

850831

851832
def _merge_shard_files(shard_files, merged_file):
@@ -1014,24 +995,6 @@ def get_supported_encoders():
1014995
]
1015996

1016997

1017-
def get_encoder_class(encoder_type: str):
1018-
"""Get the encoder class based on the encoder type."""
1019-
try:
1020-
# Convert encoder_type to class name (e.g., 'arctic' -> 'ArcticEmbedEncoder')
1021-
class_name = f"{encoder_type.capitalize()}EmbedEncoder"
1022-
# Import the module dynamically
1023-
module = __import__(
1024-
f"instructlab.sdg.encoders.{encoder_type}_encoder", fromlist=[class_name]
1025-
)
1026-
# Get the class from the module
1027-
return getattr(module, class_name)
1028-
except (ImportError, AttributeError) as e:
1029-
supported_encoders = get_supported_encoders()
1030-
raise ValueError(
1031-
f"Unsupported encoder type: '{encoder_type}'. "
1032-
f"Supported types are: {[f'{t}' for t in supported_encoders]}"
1033-
) from e
1034-
1035998

1036999
def subset_datasets(
10371000
input_files: List[str],
@@ -1081,9 +1044,7 @@ def subset_datasets(
10811044

10821045
try:
10831046
logger.info(f"Processing configuration: {config}")
1084-
processor = DataProcessor(
1085-
config, get_encoder_class(config.encoder.encoder_type)
1086-
)
1047+
processor = DataProcessor(config)
10871048
processor.process_files(input_files, config.basic.output_dir)
10881049

10891050
except Exception as e:

tests/test_subset_selection.py

Lines changed: 2 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -51,7 +51,7 @@ def data_processor(mock_encoder, mock_gpu_environment):
5151
input_files=["test.jsonl"],
5252
subset_sizes=[10, 20.5],
5353
)
54-
return DataProcessor(config, mock_encoder)
54+
return DataProcessor(config)
5555

5656

5757
def test_format_text(data_processor):
@@ -123,23 +123,6 @@ def test_invalid_subset_sizes(mock_gpu_environment):
123123
subset_sizes=[-10],
124124
)
125125

126-
127-
def test_process_batch(mock_gpu_environment, data_processor, tmp_path):
128-
"""Test batch processing of texts"""
129-
130-
batch_texts = ["text1", "text2", "text3"]
131-
output_file = str(tmp_path / "test_batch.h5")
132-
133-
embedding_dim = data_processor.process_batch(batch_texts, output_file)
134-
135-
assert embedding_dim is not None
136-
assert os.path.exists(output_file)
137-
138-
with h5py.File(output_file, "r") as f:
139-
embeddings = f["embeddings"][:]
140-
assert embeddings.shape == (3, embedding_dim)
141-
142-
143126
def test_generate_embeddings_parallel(mock_gpu_environment, tmp_path, mock_encoder):
144127
"""Test the parallelized embedding generation feature."""
145128
# Create a sample dataset
@@ -165,7 +148,7 @@ def test_generate_embeddings_parallel(mock_gpu_environment, tmp_path, mock_encod
165148
config.system.num_gpus = 2
166149

167150
# Create processor
168-
processor = DataProcessor(config, mock_encoder)
151+
processor = DataProcessor(config)
169152

170153
# Test case 1: File exists, should return early
171154
result_path = processor.generate_embeddings(dataset, output_dir)

0 commit comments

Comments (0)