
Commit c5522bc

Merge pull request #610 from instructlab/mergify/bp/release-v0.8/pr-608
Fix subset selection dynamic imports (backport #608)
2 parents 401ca6c + cba23a9 commit c5522bc

File tree: 3 files changed, +25 -232 lines changed
src/instructlab/sdg/encoders/__init__.py

+17-14
@@ -1,19 +1,22 @@
-# Standard
-import importlib
+# Import all encoder classes directly
+# Local
+from .arctic_encoder import ArcticEmbedEncoder
+
+# Create a mapping of encoder types to their classes
+ENCODER_REGISTRY = {
+    "arctic": ArcticEmbedEncoder,
+}


 def get_encoder_class(encoder_type: str):
     """Get the encoder class based on the encoder type."""
     try:
-        # Convert encoder_type to class name (e.g., 'arctic' -> 'ArcticEmbedEncoder')
-        class_name = f"{encoder_type.capitalize()}EmbedEncoder"
-
-        # Use absolute import instead of relative
-        module_name = f"sdg.src.instructlab.sdg.encoders.{encoder_type}_encoder"
-
-        module = importlib.import_module(module_name)
-
-        # Get the class from the module
-        return getattr(module, class_name)
-    except (ImportError, AttributeError) as e:
-        raise ValueError(f"Unsupported encoder type: '{encoder_type}'") from e
+        if encoder_type not in ENCODER_REGISTRY:
+            supported_encoders = list(ENCODER_REGISTRY.keys())
+            raise ValueError(
+                f"Unsupported encoder type: '{encoder_type}'. "
+                f"Supported types are: {supported_encoders}"
+            )
+        return ENCODER_REGISTRY[encoder_type]
+    except Exception as e:
+        raise ValueError(f"Error getting encoder class: {str(e)}") from e

src/instructlab/sdg/subset_selection.py

+6-200
@@ -1,15 +1,13 @@
 # Standard
 from dataclasses import dataclass, field
 from multiprocessing import Pool
-from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar, Union
+from typing import Any, Dict, List, TypedDict, TypeVar, Union
 import gc
 import glob
-import importlib
 import logging
 import math
 import os
 import re
-import sys

 # Third Party
 from datasets import concatenate_datasets, load_dataset
@@ -20,6 +18,7 @@
 import torch

 # Local
+from .encoders import get_encoder_class
 from .utils.subset_selection_utils import (
     compute_pairwise_dense,
     get_default_num_gpus,
@@ -171,19 +170,14 @@ class DataProcessor:
     Enhanced data processor with support for combined files and multiple selection methods.
     """

-    def __init__(self, config: ProcessingConfig, encoder_cls):
+    def __init__(self, config: ProcessingConfig):
         """
         Initializes the DataProcessor with the given configuration and encoder class.

         Args:
             config (ProcessingConfig): The processing configuration.
-            encoder_cls: The encoder class to use for generating embeddings.
         """
         self.config = config
-        self.encoder = encoder_cls(
-            model_name=config.encoder.encoder_model,
-            testing_mode=config.encoder.testing_mode,
-        )
         self.env = Environment(loader=BaseLoader())
         self.templates = {
             k: self.env.from_string(v) for k, v in config.template.templates.items()
@@ -282,60 +276,6 @@ def get_subset_name(self, size_spec: Union[int, float], actual_size: int) -> str
            return f"percent_{size_spec:.1f}"
        return f"samples_{actual_size}"

-    def get_last_processed_batch(self, output_dir: str) -> Tuple[int, Optional[str]]:
-        """
-        Retrieves the last processed batch number and its file path from the output directory.
-
-        Args:
-            output_dir (str): The directory where batch files are stored.
-
-        Returns:
-            Tuple[int, Optional[str]]: The last batch number and the corresponding batch file path.
-        """
-        batch_files = glob.glob(os.path.join(output_dir, "batch_*.h5"))
-        if not batch_files:
-            return -1, None
-
-        # Sort batch files by batch number
-        batch_files.sort(key=self.extract_batch_number)
-        max_batch_file = batch_files[-1]
-        max_batch_number = self.extract_batch_number(max_batch_file)
-
-        # Return the max batch number and the corresponding batch file path
-        return max_batch_number, max_batch_file
-
-    @retry_on_exception
-    def process_batch(self, batch_texts: List[str], output_file: str) -> Optional[int]:
-        """
-        Processes a batch of texts by generating embeddings and saving them to a file.
-        Returns the embedding dimension or None if no embeddings were generated.
-        """
-        embeddings = (
-            self.encoder.encode(
-                inputs=batch_texts,
-                instruction=self.config.encoder.instruction,
-            )
-            .cpu()
-            .numpy()
-        )
-
-        if embeddings.size == 0:
-            logger.warning(
-                f"No embeddings generated for batch, skipping file {output_file}"
-            )
-            return None
-
-        embedding_dim = int(embeddings.shape[1])  # Cast to int
-        logger.info(f"Embedding dimension for batch: {embedding_dim}")
-
-        with h5py.File(output_file, "w") as h5f:
-            h5f.create_dataset(
-                "embeddings", data=embeddings, dtype="float32", chunks=True
-            )
-            h5f.flush()
-
-        return embedding_dim
-
     @retry_on_exception
     def generate_embeddings(self, dataset, output_dir: str) -> str:
         """
@@ -405,104 +345,6 @@ def generate_embeddings(self, dataset, output_dir: str) -> str:

         return merged_path

-    def extract_batch_number(self, filename):
-        """
-        Extracts the batch number from the filename.
-        Assumes the filename is in the format 'batch_<number>.h5'.
-
-        Args:
-            filename (str): The filename from which to extract the batch number.
-
-        Returns:
-            int: The batch number extracted from the filename.
-        """
-        basename = os.path.basename(filename)
-        match = re.search(r"batch_(\d+)\.h5$", basename)
-        if match:
-            return int(match.group(1))
-        raise ValueError(f"Filename {filename} does not match expected pattern.")
-
-    def get_embedding_size_dim_from_file(self, batch_file: str) -> Tuple[int, int]:
-        """
-        Reads the batch file to determine the embedding size (number of embeddings) and dimension.
-        """
-        with h5py.File(batch_file, "r") as h5f:
-            if "embeddings" not in h5f:
-                raise ValueError(
-                    f"The file {batch_file} does not contain 'embeddings' dataset."
-                )
-            embeddings = h5f["embeddings"]
-            embedding_size = int(embeddings.shape[0])  # Cast to int
-            embedding_dim = int(embeddings.shape[1])  # Cast to int
-            logger.info(f"Embedding dimension from {batch_file}: {embedding_dim}")
-            return embedding_size, embedding_dim
-
-    def merge_embeddings(self, output_dir, merged_file, total_samples):
-        """
-        Merges all batch embedding files into a single embeddings file.
-
-        Args:
-            output_dir (str): The directory where batch embedding files are stored.
-            merged_file (str): The path to the merged embeddings file.
-            total_samples (int): The total number of samples (embeddings).
-
-        """
-        # Find all batch files
-        batch_files = glob.glob(os.path.join(output_dir, "batch_*.h5"))
-        if not batch_files:
-            logger.warning("No batch files found to merge")
-            return
-
-        # Sort batch files by batch number
-        batch_files.sort(key=self.extract_batch_number)
-
-        # Retrieve embedding_dim from the first batch file
-        _, embedding_dim = self.get_embedding_size_dim_from_file(batch_files[0])
-
-        if os.path.exists(merged_file):
-            logger.info(f"Merged file {merged_file} already exists, skipping merge")
-            return
-
-        logger.info(
-            f"Merging {len(batch_files)} batch files into {merged_file} with {total_samples} samples"
-        )
-
-        with h5py.File(merged_file, "w") as h5f_merged:
-            # Initialize the dataset in the merged file with the retrieved embedding dimension
-            embeddings_ds = h5f_merged.create_dataset(
-                "embeddings", shape=(total_samples, embedding_dim), dtype="float32"
-            )
-
-            start_idx = 0
-            for batch_file in batch_files:
-                with h5py.File(batch_file, "r") as h5f_batch:
-                    if "embeddings" not in h5f_batch:
-                        logger.error(
-                            f"File {batch_file} does not contain 'embeddings' dataset"
-                        )
-                        continue
-
-                    embeddings = h5f_batch["embeddings"][:]
-                    batch_size = embeddings.shape[0]
-                    end_idx = start_idx + batch_size
-
-                    # Check that each file's embedding dimension matches the retrieved embedding_dim
-                    if embeddings.shape[1] != embedding_dim:
-                        logger.error(
-                            f"Embedding dimension mismatch in {batch_file}. Expected {embedding_dim}, got {embeddings.shape[1]}"
-                        )
-                        continue
-
-                    # Copy embeddings into the merged dataset
-                    embeddings_ds[start_idx:end_idx] = embeddings
-                    start_idx = end_idx
-
-                # Remove the batch file after processing
-                os.remove(batch_file)
-                logger.info(f"Processed and removed {batch_file}")
-
-        gc.collect()
-
     def select_subsets(
         self, dataset_name: str, embeddings: torch.Tensor
     ) -> Dict[Union[int, float], List[int]]:
@@ -750,22 +592,7 @@ def _process_dataset_shard(args):
         device = f"cuda:{gpu_id}"
         logger.info(f"GPU {gpu_id} started processing {len(dataset_shard)} samples")

-        # Import the encoder directly using the system path
-        # Standard
-
-        sys.path.append(
-            os.path.dirname(
-                os.path.dirname(
-                    os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
-                )
-            )
-        )
-
-        # Import the encoder class using string-based absolute import
-
-        module_name = f"sdg.src.instructlab.sdg.encoders.{encoder_type}_encoder"
-        module = importlib.import_module(module_name)
-        encoder_cls = getattr(module, f"{encoder_type.capitalize()}EmbedEncoder")
+        encoder_cls = get_encoder_class(encoder_type)

         # Create encoder instance
         encoder = encoder_cls(
@@ -845,7 +672,7 @@ def _process_dataset_shard(args):
     # pylint: disable=broad-exception-caught
     except Exception as e:
         logger.error(f"Error processing shard on GPU {gpu_id}: {str(e)}")
-        return None
+        raise


 def _merge_shard_files(shard_files, merged_file):
@@ -1014,25 +841,6 @@ def get_supported_encoders():
    ]


-def get_encoder_class(encoder_type: str):
-    """Get the encoder class based on the encoder type."""
-    try:
-        # Convert encoder_type to class name (e.g., 'arctic' -> 'ArcticEmbedEncoder')
-        class_name = f"{encoder_type.capitalize()}EmbedEncoder"
-        # Import the module dynamically
-        module = __import__(
-            f"instructlab.sdg.encoders.{encoder_type}_encoder", fromlist=[class_name]
-        )
-        # Get the class from the module
-        return getattr(module, class_name)
-    except (ImportError, AttributeError) as e:
-        supported_encoders = get_supported_encoders()
-        raise ValueError(
-            f"Unsupported encoder type: '{encoder_type}'. "
-            f"Supported types are: {[f'{t}' for t in supported_encoders]}"
-        ) from e
-
-
 def subset_datasets(
     input_files: List[str],
     subset_sizes: List[Union[int, float]],
@@ -1081,9 +889,7 @@ def subset_datasets(

     try:
         logger.info(f"Processing configuration: {config}")
-        processor = DataProcessor(
-            config, get_encoder_class(config.encoder.encoder_type)
-        )
+        processor = DataProcessor(config)
         processor.process_files(input_files, config.basic.output_dir)

     except Exception as e:
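
The net effect in subset_selection.py: DataProcessor no longer receives an encoder class, each GPU worker resolves it through get_encoder_class, and shard failures now propagate instead of silently returning None. A rough sketch of the per-shard flow, assuming the constructor and encode() keyword arguments shown in the code removed above (model_name, testing_mode, inputs, instruction) are still the encoder API:

    from instructlab.sdg.encoders import get_encoder_class

    def embed_texts(config, texts):
        # Registry lookup replaces the old sys.path + importlib gymnastics.
        encoder_cls = get_encoder_class(config.encoder.encoder_type)
        encoder = encoder_cls(
            model_name=config.encoder.encoder_model,  # assumed constructor args
            testing_mode=config.encoder.testing_mode,
        )
        # encode() arguments mirror the removed process_batch() call.
        return encoder.encode(inputs=texts, instruction=config.encoder.instruction)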

tests/test_subset_selection.py

+2-18
@@ -51,7 +51,7 @@ def data_processor(mock_encoder, mock_gpu_environment):
         input_files=["test.jsonl"],
         subset_sizes=[10, 20.5],
     )
-    return DataProcessor(config, mock_encoder)
+    return DataProcessor(config)


 def test_format_text(data_processor):
@@ -124,22 +124,6 @@ def test_invalid_subset_sizes(mock_gpu_environment):
     )


-def test_process_batch(mock_gpu_environment, data_processor, tmp_path):
-    """Test batch processing of texts"""
-
-    batch_texts = ["text1", "text2", "text3"]
-    output_file = str(tmp_path / "test_batch.h5")
-
-    embedding_dim = data_processor.process_batch(batch_texts, output_file)
-
-    assert embedding_dim is not None
-    assert os.path.exists(output_file)
-
-    with h5py.File(output_file, "r") as f:
-        embeddings = f["embeddings"][:]
-        assert embeddings.shape == (3, embedding_dim)
-
-
 def test_generate_embeddings_parallel(mock_gpu_environment, tmp_path, mock_encoder):
     """Test the parallelized embedding generation feature."""
     # Create a sample dataset
@@ -165,7 +149,7 @@ def test_generate_embeddings_parallel(mock_gpu_environment, tmp_path, mock_encoder):
     config.system.num_gpus = 2

     # Create processor
-    processor = DataProcessor(config, mock_encoder)
+    processor = DataProcessor(config)

     # Test case 1: File exists, should return early
     result_path = processor.generate_embeddings(dataset, output_dir)
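
test_process_batch goes away along with DataProcessor.process_batch, so the remaining tests only exercise the processor path. A hypothetical sketch of a test for the registry lookup itself, not part of this commit (import paths follow those visible in the diff):

    import pytest

    from instructlab.sdg.encoders import get_encoder_class
    from instructlab.sdg.encoders.arctic_encoder import ArcticEmbedEncoder


    def test_get_encoder_class_known_type():
        # A registered type resolves to its class.
        assert get_encoder_class("arctic") is ArcticEmbedEncoder


    def test_get_encoder_class_unknown_type():
        # Unknown types are reported as ValueError.
        with pytest.raises(ValueError):
            get_encoder_class("does_not_exist")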
