[NewOp] Add domain_diversity_selector based on DaaR principles #699
/gemini review
Code Review
This pull request introduces a new domain_diversity_selector operator, which is a valuable addition for curating datasets based on semantic diversity. The implementation follows the DaaR paper's principles, including embedding generation, clustering, and strategic sample selection. The code is well-structured, and the inclusion of configuration, documentation, and tests is appreciated.
However, there are several critical areas for improvement, primarily concerning performance and memory efficiency. The current implementation processes samples one by one, which will be very slow on large datasets. I've provided suggestions to use batching and vectorized operations to significantly speed up embedding and similarity calculations. Additionally, the memory usage can be optimized by avoiding the storage of redundant data.
The tests are a good start, but they rely on external APIs, making them difficult to run in automated environments. I've suggested mocking these dependencies. Finally, a few minor refactoring opportunities and configuration improvements are noted to enhance code clarity and usability.
    if self.is_hf_model:
        # Embeddings extract via local models
        for sample in tqdm(dataset, desc="Embedding", unit="sample"):
            text = sample["text"]
            with torch.no_grad():
                embedding = model.encode(text)
            embeddings.append(embedding)
    else:
        # Embeddings extract via API
        for sample in tqdm(dataset, desc="Embedding", unit="sample"):
            text = sample["text"]
            embedding = model(text, dimensions=self.ebd_dim, encoding_format="float")
            embeddings.append(embedding)
The current implementation for generating embeddings processes samples one by one, which is highly inefficient for large datasets. Both HuggingFace encode methods and many embedding APIs support batch processing. Refactoring this to process samples in batches will significantly improve performance.
    if self.is_hf_model:
        # Embeddings extract via local models in batches for efficiency
        texts = [sample['text'] for sample in dataset]
        with torch.no_grad():
            embeddings = model.encode(texts, batch_size=self.batch_size, show_progress_bar=True)
    else:
        # Embeddings extract via API. Consider batching if the API supports it.
        for sample in tqdm(dataset, desc="Embedding", unit="sample"):
            text = sample["text"]
            embedding = model(text, dimensions=self.ebd_dim, encoding_format="float")
            embeddings.append(embedding)
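For the API branch, one possible shape for batching is sketched below. It is not taken from the PR: `iter_batches`, `embed_in_batches`, and the `embed_batch` callable are hypothetical names, and the sketch assumes the embedding backend accepts a list of texts and returns one vector per text in the same order. Whether the operator's API wrapper supports that would need to be checked.

```python
from typing import Callable, Iterator, List, Sequence


def iter_batches(items: Sequence[str], batch_size: int) -> Iterator[Sequence[str]]:
    """Yield consecutive slices of at most batch_size items."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def embed_in_batches(
    texts: Sequence[str],
    embed_batch: Callable[[Sequence[str]], List[List[float]]],
    batch_size: int = 16,
) -> List[List[float]]:
    """Call embed_batch once per chunk and concatenate results in input order."""
    embeddings: List[List[float]] = []
    for chunk in iter_batches(texts, batch_size):
        vectors = embed_batch(chunk)
        # Guard against a backend that drops or duplicates entries.
        if len(vectors) != len(chunk):
            raise RuntimeError("embedding backend returned a mismatched batch")
        embeddings.extend(vectors)
    return embeddings


# Stand-in backend for illustration only: maps each text to a 2-d vector.
def fake_embed_batch(chunk):
    return [[float(len(t)), 1.0] for t in chunk]


vectors = embed_in_batches(["a", "bb", "ccc"], fake_embed_batch, batch_size=2)
```

With a batch size of 2, the three texts above result in two backend calls instead of three; the saving grows with the dataset and the batch size the API permits.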
    # Sample-level cos-similarity to other centroids
    for i, entry in tqdm(enumerate(dataset), total=len(dataset), desc="Calculating similarity:"):
        current_embedding = embeddings_array[i]
        current_label = int(labels[i])

        similarities = []
        for j, centroid in enumerate(centroid_embeddings):
            if j != current_label:
                similarity = torch.nn.functional.cosine_similarity(
                    torch.tensor(current_embedding).unsqueeze(0), torch.tensor(centroid).unsqueeze(0)
                ).item()
                similarities.append(similarity)

        own_centroid_similarity = torch.nn.functional.cosine_similarity(
            torch.tensor(current_embedding).unsqueeze(0),
            torch.tensor(centroid_embeddings[current_label]).unsqueeze(0),
        ).item()

        global_centroid_similarity = torch.nn.functional.cosine_similarity(
            torch.tensor(current_embedding).unsqueeze(0), torch.tensor(global_centroid).unsqueeze(0)
        ).item()
        total_similarity = sum(similarities)

        data_status.append(
            {
                "text": entry["text"],
                "label": current_label,
                "similarity_with_other_centroids": similarities,
                "total_similarity": total_similarity,
                "similarity_with_own_centroid": own_centroid_similarity,
                "global_centroid_similarity": global_centroid_similarity,
                "original_index": i,
            }
        )
This section has critical performance and memory issues that should be addressed:
- Performance: Cosine similarities are calculated in a loop, creating and destroying PyTorch tensors for each sample. This is very slow. These calculations should be vectorized using sklearn.metrics.pairwise.cosine_similarity or torch.nn.functional.cosine_similarity on the entire matrices.
- Memory: The data_status list stores a dictionary for each sample, including the original text. For large datasets, this will consume a very large amount of memory. It's more memory-efficient to work with NumPy arrays for embeddings, labels, and calculated similarities, and only use indices to refer back to the original dataset.
A full refactor is recommended to process these calculations in a batched/vectorized manner and to avoid creating large intermediate data structures.
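As a rough illustration of the vectorized approach, the sketch below computes the same three quantities as the loop (similarity to the sample's own centroid, summed similarity to all other centroids, and similarity to the global centroid) with plain NumPy matrix operations. The helper names `l2_normalize` and `centroid_similarities` are hypothetical, the per-centroid list `similarity_with_other_centroids` is deliberately not materialized, and the zero-norm guard uses a different epsilon than PyTorch, so results may differ slightly for degenerate vectors.

```python
import numpy as np


def l2_normalize(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Scale vectors along the last axis to unit length, guarding against zero norms."""
    norms = np.linalg.norm(x, axis=-1, keepdims=True)
    return x / np.maximum(norms, eps)


def centroid_similarities(embeddings, labels, centroids, global_centroid):
    """Return (own, total_other, global_sim), each an array of shape (n_samples,)."""
    emb = l2_normalize(np.asarray(embeddings, dtype=np.float64))
    cen = l2_normalize(np.asarray(centroids, dtype=np.float64))
    glob = l2_normalize(np.asarray(global_centroid, dtype=np.float64))
    labels = np.asarray(labels, dtype=np.int64)

    # Cosine similarity of every sample to every centroid: (n_samples, n_clusters).
    sim = emb @ cen.T
    # Similarity of each sample to the centroid of its own cluster.
    own = sim[np.arange(len(labels)), labels]
    # Sum over all centroids minus the own-cluster term.
    total_other = sim.sum(axis=1) - own
    # Similarity of each sample to the global centroid.
    global_sim = emb @ glob
    return own, total_other, global_sim


own, total_other, global_sim = centroid_similarities(
    embeddings=[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
    labels=[0, 1, 0],
    centroids=[[1.0, 0.0], [0.0, 1.0]],
    global_centroid=[1.0, 1.0],
)
```

The three returned arrays are indexed by the sample's position in the dataset, so selection can work on indices alone and the text never needs to be copied into an intermediate structure.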
    from data_juicer.ops.selector.domain_diversity_selector import DomainDiversitySelector
    from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase


    @unittest.skipIf(FROM_FORK, "Skipping the test because running from a fork repo")
This test is skipped when running from a fork, and its implementation relies on an external API (text-embedding-v3), which requires credentials. This makes the test difficult to run in CI/CD environments and for external contributors. It's highly recommended to either mock the API calls or use a small, publicly available HuggingFace model (by setting is_hf_model=True and providing a model name) for testing to ensure the operator's logic can be verified automatically and reliably.
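A minimal sketch of the mocking option follows. It does not use the real operator: `ToyEmbeddingSelector` and its `_embed` method are hypothetical stand-ins for whatever method in the operator performs the remote call, and `fake_embed` is an invented deterministic replacement. The patching mechanism itself, `unittest.mock.patch.object` with `side_effect`, is standard library behavior.

```python
import unittest
from unittest import mock

import numpy as np


class ToyEmbeddingSelector:
    """Hypothetical stand-in for an operator that embeds text through a remote API."""

    def _embed(self, text: str) -> np.ndarray:
        raise RuntimeError("would call a remote embedding API")

    def embed_all(self, texts):
        return np.stack([self._embed(t) for t in texts])


def fake_embed(text: str) -> np.ndarray:
    # Deterministic 2-d vector derived from the text alone, no network needed.
    return np.array([len(text), text.count("a")], dtype=np.float64)


class ToyEmbeddingSelectorTest(unittest.TestCase):
    def test_embed_all_without_network(self):
        op = ToyEmbeddingSelector()
        # Replace the remote call on this instance for the duration of the block.
        with mock.patch.object(op, "_embed", side_effect=fake_embed) as patched:
            result = op.embed_all(["aa", "b"])
        self.assertEqual(patched.call_count, 2)
        np.testing.assert_array_equal(result, np.array([[2.0, 2.0], [1.0, 0.0]]))
```

Because the fake embedding is a pure function of the input text, the clustering and selection steps downstream see the same vectors on every run, which is what makes the test repeatable from a fork without credentials.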
    api_endpoint: '/embeddings'          # embedding URL endpoint for the API
    response_path: 'data.0.embedding'    # path to extract content from the API response
    model_params: {}                     # parameters for initializing the API model
    select_ratio:                        # the ratio to be sampled
The select_ratio is left empty in this configuration example. While the code correctly handles a None value by skipping the operation, it would be more user-friendly for an example configuration to provide a default value (e.g., 0.5) and add a comment explaining its purpose.
    select_ratio: 0.5                    # the ratio to be sampled

    if self.strategy == "inter":
        label_data_status.sort(key=lambda x: x["total_similarity"])
    elif self.strategy == "intra":
        label_data_status.sort(key=lambda x: x["similarity_with_own_centroid"], reverse=True)
    elif self.strategy == "global":
        label_data_status.sort(key=lambda x: x["global_centroid_similarity"])
    else:
        raise ValueError("Invalid strategy. Use 'inter', 'intra' or 'global'.")
The if/elif/else chain for selecting the sorting key based on the strategy can be made more concise and extensible by using a dictionary to map strategy names to their corresponding sorting logic (key and reverse flag).
    strategy_configs = {
        'inter': {'key': 'total_similarity', 'reverse': False},
        'intra': {'key': 'similarity_with_own_centroid', 'reverse': True},
        'global': {'key': 'global_centroid_similarity', 'reverse': False},
    }
    if self.strategy not in strategy_configs:
        raise ValueError("Invalid strategy. Use 'inter', 'intra' or 'global'.")
    config = strategy_configs[self.strategy]
    label_data_status.sort(key=lambda x: x[config['key']], reverse=config['reverse'])
    def _run_domain_diversity_selector(self, dataset: Dataset, target_num, op):
        dataset = op.process(dataset)
        res_list = dataset.to_list()
        self.assertEqual(len(res_list), target_num)
The assertion only checks the number of selected samples, which is not sufficient to verify that the diversity selection logic is working correctly. The test should be strengthened to check which samples are selected. You could add a unique ID to each sample in ds_list, and then after processing, verify that the IDs of the selected samples match a list of expected IDs for a given seed and input.
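The idea of asserting on sample IDs can be sketched as follows. This is an illustration only: `select_indices` is a hypothetical toy selector, and its rule (per cluster, keep the rounded-down `select_ratio` share of samples after sorting by score) is an assumption that may not match how the operator actually allocates its quota. The scores and labels are invented fixtures standing in for values a mocked embedding model would produce.

```python
import numpy as np


def select_indices(scores, labels, select_ratio, reverse=False):
    """Per cluster, keep the rounded-down share of samples ordered by score."""
    scores = np.asarray(scores, dtype=np.float64)
    labels = np.asarray(labels)
    selected = []
    for label in np.unique(labels):
        members = np.flatnonzero(labels == label)
        # Stable sort keeps ties in dataset order, so the result is deterministic.
        order = np.argsort(scores[members], kind="stable")
        if reverse:
            order = order[::-1]
        keep = int(len(members) * select_ratio)
        selected.extend(members[order[:keep]].tolist())
    return sorted(selected)


# Each sample carries a unique ID so the test can say which ones survive.
ds_list = [{"id": f"s{i}", "text": f"sample {i}"} for i in range(6)]
scores = [0.9, 0.1, 0.5, 0.3, 0.2, 0.8]   # invented per-sample scores
labels = [0, 0, 0, 0, 1, 1]               # invented cluster assignments

picked = select_indices(scores, labels, select_ratio=0.5)
picked_ids = [ds_list[i]["id"] for i in picked]
```

A test built this way would compare `picked_ids` against a hand-derived expected list rather than only its length, so a regression in the ordering or in the per-cluster quota would be caught.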
Introduces a new data selection op, inspired by the DaaR paper, that uses semantic diversity across domains to automatically select the most diverse subset of data samples.
[WIP] Ongoing development of additional operators derived from the DaaR paper.