Commit b5e2ae6

Merge pull request #542 from eshwarprasadS/subset-selection-integration

Subset Selection Integration

2 parents cf434b3 + 4b77d1f

File tree

8 files changed: +2169 additions, -0 deletions

requirements.txt

Lines changed: 4 additions & 0 deletions

@@ -6,12 +6,16 @@ docling[tesserocr]>=2.28.4; sys_platform != 'darwin'
 docling>=2.28.4; sys_platform == 'darwin'
 GitPython>=3.1.42,<4.0.0
 gguf>=0.6.0
+h5py>=3.12.1
 httpx>=0.25.0,<1.0.0
 instructlab-schema>=0.4.0
 jinja2>=3.0.0
 langchain-text-splitters
 openai>=1.13.3,<2.0.0
+numba
 sentencepiece>=0.2.0
+# Note: this dependency has to be built from source
+submodlib-py==0.0.1; sys_platform == 'linux'
 tabulate>=0.9.0

 # Note: this dependency goes along with langchain-text-splitters and may be
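The new dependencies back the subset-selection feature: h5py (HDF5 storage), numba (JIT-compiled numeric kernels), and submodlib-py, a submodular optimization library (Linux-only and built from source, per the comment). As a rough illustration of the kind of selection submodlib enables (this snippet is not from the PR and assumes submodlib's documented FacilityLocationFunction API; the commit's actual selection code lives in files not shown on this page):

# Illustration only, not code from this PR: pick a representative
# subset of embedding rows via submodlib's facility-location objective.
import numpy as np
from submodlib import FacilityLocationFunction

embeddings = np.random.rand(1000, 1024).astype(np.float32)  # stand-in data
fl = FacilityLocationFunction(
    n=embeddings.shape[0], mode="dense", data=embeddings, metric="cosine"
)
# Greedy maximization returns (index, marginal_gain) pairs
selected = fl.maximize(budget=100, optimizer="NaiveGreedy")
subset_indices = [idx for idx, _gain in selected]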
Lines changed: 19 additions & 0 deletions

@@ -0,0 +1,19 @@
# Standard
import importlib


def get_encoder_class(encoder_type: str):
    """Get the encoder class based on the encoder type."""
    try:
        # Convert encoder_type to class name (e.g., 'arctic' -> 'ArcticEmbedEncoder')
        class_name = f"{encoder_type.capitalize()}EmbedEncoder"

        # Use absolute import instead of relative
        module_name = f"sdg.src.instructlab.sdg.encoders.{encoder_type}_encoder"

        module = importlib.import_module(module_name)

        # Get the class from the module
        return getattr(module, class_name)
    except (ImportError, AttributeError) as e:
        raise ValueError(f"Unsupported encoder type: '{encoder_type}'") from e
Lines changed: 266 additions & 0 deletions

@@ -0,0 +1,266 @@
# SPDX-License-Identifier: Apache-2.0

# Standard
from dataclasses import dataclass
from typing import Dict, List, Optional, TypedDict, Union
import logging
import os

# Third Party
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F

logger = logging.getLogger(__name__)
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def safe_print(rank, msg):
    """Only log from rank 0."""
    if rank == 0:
        logger.info(msg)


# Define model configuration
class ModelConfig(TypedDict):
    pooling_method: str
    normalize_embeddings: bool
    max_length: int
    default_instruction: str
    batch_size: int


MODEL_CONFIGS: Dict[str, ModelConfig] = {
    "Snowflake/snowflake-arctic-embed-l-v2.0": {
        "pooling_method": "cls",
        "normalize_embeddings": True,
        "max_length": 4096,
        "default_instruction": "Retrieve relevant passages:",
        "batch_size": 24,
    }
}


# pylint: disable=too-many-instance-attributes
@dataclass
class EncoderConfig:
    model_name: str
    model_config: ModelConfig
    device: torch.device
    num_gpus: int
    batch_size: int
    use_default_instruction: bool
    use_fp16: bool
    testing_mode: bool = False


class ArcticEmbedEncoder:
    def __init__(
        self,
        model_name: str = "Snowflake/snowflake-arctic-embed-l-v2.0",
        device: Optional[torch.device] = None,
        use_fp16: bool = False,
        use_default_instruction: bool = True,
        testing_mode: bool = False,
    ) -> None:
        """Initialize the Arctic encoder."""
        if model_name not in MODEL_CONFIGS:
            raise ValueError(
                f"Model {model_name} not supported. Supported models: {list(MODEL_CONFIGS.keys())}"
            )

        # Use the provided device or default to CUDA
        self.device = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )

        # Get device ID for logging
        self.device_id = self.device.index if hasattr(self.device, "index") else 0

        # We don't need multi-GPU inside this encoder instance since each instance
        # will run on a dedicated GPU
        self.cfg = EncoderConfig(
            model_name=model_name,
            model_config=MODEL_CONFIGS[model_name],
            device=self.device,
            num_gpus=1,  # Only use 1 GPU per encoder instance
            batch_size=MODEL_CONFIGS[model_name]["batch_size"],
            use_default_instruction=use_default_instruction,
            use_fp16=use_fp16,
            testing_mode=testing_mode,
        )

        self._initialize_model()

    def _initialize_model(self) -> None:
        """Initialize model on the specific GPU."""
        home_dir = os.path.expanduser("~")
        model_path = os.path.join(
            home_dir, ".cache", "instructlab", "models", self.cfg.model_name
        )

        # In testing mode, allow direct download from HuggingFace
        if hasattr(self.cfg, "testing_mode") and self.cfg.testing_mode:
            logger.warning(
                f"Model not found locally at {model_path}. "
                "Testing mode enabled - downloading from HuggingFace..."
            )
            self.tokenizer = AutoTokenizer.from_pretrained(self.cfg.model_name)
            self.model = AutoModel.from_pretrained(
                self.cfg.model_name,
                add_pooling_layer=False,
                trust_remote_code=True,
            )
        else:
            if not os.path.exists(model_path):
                raise ValueError(
                    f"Model not found in available models: {self.cfg.model_name}\n"
                    "Please run `ilab model download` and download the necessary model"
                )

            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.model = AutoModel.from_pretrained(
                model_path,
                add_pooling_layer=False,
                trust_remote_code=True,
                local_files_only=True,
            )

        if self.cfg.use_fp16:
            self.model = self.model.half()

        self.model = self.model.to(self.cfg.device)
        logger.info(f"Model loaded on device: {self.cfg.device}")

        # No need for DataParallel since we're running one encoder per GPU
        self.model.eval()

    def _prepare_inputs(
        self, texts: Union[str, List[str]], instruction: str = ""
    ) -> List[str]:
        """Prepare inputs with model-specific formatting."""
        if isinstance(texts, str):
            texts = [texts]

        # Ensure we always have an instruction
        if not instruction and not self.cfg.use_default_instruction:
            raise ValueError(
                "An instruction must be provided when use_default_instruction is False. "
                "Either provide an instruction or set use_default_instruction to True."
            )

        if (
            not instruction
            and self.cfg.use_default_instruction
            and self.cfg.model_config["default_instruction"]
        ):
            instruction = str(self.cfg.model_config["default_instruction"])

        if not instruction:  # catch if default_instruction is empty
            raise ValueError(
                "No instruction available. Either provide an instruction or ensure "
                "the model config has a valid default_instruction."
            )

        texts = [f"{instruction}: {text}" for text in texts]
        return texts

    @torch.no_grad()
    def encode(
        self,
        inputs: Union[str, List[str]],
        instruction: str = "",
        return_tensors: bool = True,
        show_progress: bool = True,
    ) -> Union[torch.Tensor, np.ndarray]:
        """Encode texts into embeddings."""
        input_was_string = isinstance(inputs, str)
        inputs = self._prepare_inputs(inputs, instruction)

        encodings = self.tokenizer(
            inputs,
            max_length=self.cfg.model_config["max_length"],
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).to(self.cfg.device)

        embeddings_list = []
        for i in tqdm(
            range(0, len(inputs), self.cfg.batch_size),
            disable=not show_progress or len(inputs) < 256,
        ):
            batch = {k: v[i : i + self.cfg.batch_size] for k, v in encodings.items()}
            outputs = self.model(**batch)
            # Take the first token embedding (CLS) and normalize it
            embeddings = F.normalize(outputs.last_hidden_state[:, 0], p=2, dim=1)
            embeddings_list.append(embeddings.cpu())

        embeddings = torch.cat(embeddings_list, dim=0)
        if input_was_string:
            embeddings = embeddings[0]

        return embeddings if return_tensors else embeddings.numpy()


def cleanup():
    if dist.is_initialized():
        dist.destroy_process_group()


# FIXME: Use / Adapt below for unit / functional test for the encoder later
# def run_demo():
#     try:
#         encoder = ArcticEmbedEncoder(batch_size=2, max_length=512)
#         # Create some sample conversation texts. Multiply to have enough samples.
#         conversations = [
#             "User: I've been feeling really down lately...",
#             "User: I have a big presentation tomorrow...",
#             "User: I just read about the rapid decline in bee populations...",
#             "User: I'm planning a trip to Japan next year...",
#         ] * 10  # Adjust the number as needed
#
#         if encoder.cfg.rank == 0:
#             print("Last four conversations:")
#             print(conversations)
#
#         # Encode the texts using the encoder.encode method.
#         embeddings = encoder.encode(
#             conversations, instruction="Retrieve relevant passages."
#         )
#         if encoder.cfg.rank == 0:
#             print("\nEncode results:")
#             for i, (text, emb) in enumerate(zip(conversations, embeddings)):
#                 print(f"{i+1}. {text[:50]}... -> Embedding shape: {emb.shape}")
#
#         # Demonstrate using embed_dataset directly.
#         dataset = Dataset.from_dict(
#             {"text": conversations, "idx": list(range(len(conversations)))}
#         )
#         embedded_ds = encoder.embed_dataset(
#             dataset, instruction="Retrieve relevant passages.", add_to_dataset=True
#         )
#         if encoder.cfg.rank == 0:
#             print("\nDataset results:")
#             print(embedded_ds)
#
#         # Also show an example of returning numpy arrays.
#         embeddings_np = encoder.encode(
#             conversations,
#             instruction="Retrieve relevant passages.",
#             return_tensors=False,
#         )
#         if encoder.cfg.rank == 0:
#             print("\nNumpy array results:")
#             print(embeddings_np, embeddings_np.shape)
#     except Exception as e:
#         safe_print(dist.get_rank(), f"Demo failed: {str(e)}")
#     finally:
#         cleanup()


# if __name__ == "__main__":
#     run_demo()
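A minimal sketch of exercising the encoder above (not part of the commit: testing_mode=True downloads from HuggingFace instead of requiring the local ilab model cache, and the 1024-dim output is what arctic-embed-l-v2.0 typically produces). Note that the commented-out run_demo passes batch_size/max_length arguments that __init__ does not accept, so it would need adapting, as its FIXME suggests.

# Sketch only: smoke test for ArcticEmbedEncoder as defined above.
encoder = ArcticEmbedEncoder(testing_mode=True)
embeddings = encoder.encode(
    ["I'm planning a trip to Japan next year.", "How do bees make honey?"],
    instruction="Retrieve relevant passages:",
)
print(embeddings.shape)  # expected: torch.Size([2, 1024])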
