
Conversation

Collaborator

@lingzhq lingzhq commented Jun 12, 2025

Introduces a new data selection op based on semantic diversity across domains, inspired by the DaaR paper. It is designed to automatically select the most diverse subset of data samples:

  • Converts input samples into embeddings
  • Uses the embeddings to cluster samples into pseudo-domains
  • Selects samples based on various centroid distances to maximize diversity (see the sketch below)
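For orientation, here is a minimal sketch of the idea, not the operator's actual code: the model name, n_domains, and the helper select_diverse are assumptions, it uses sentence-transformers and scikit-learn, and it implements only an "inter"-style ranking.

import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer

def select_diverse(texts, n_domains=5, select_ratio=0.5, seed=42):
    # 1) Convert input samples into embeddings (any local embedding model works).
    model = SentenceTransformer("all-MiniLM-L6-v2")
    embeddings = model.encode(texts, batch_size=32)           # shape (N, D)

    # 2) Cluster the embeddings into pseudo-domains.
    kmeans = KMeans(n_clusters=n_domains, random_state=seed).fit(embeddings)
    labels, centroids = kmeans.labels_, kmeans.cluster_centers_

    # 3) Within each pseudo-domain, keep the samples least similar to the
    #    other domains' centroids ("inter" strategy).
    selected = []
    for c in range(n_domains):
        idx = np.where(labels == c)[0]
        sim = cosine_similarity(embeddings[idx], centroids)   # (|c|, n_domains)
        other = sim.sum(axis=1) - sim[:, c]                   # drop own-centroid term
        keep = idx[np.argsort(other)][: max(1, int(len(idx) * select_ratio))]
        selected.extend(keep.tolist())
    return sorted(selected)                                   # indices into `texts`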

[WIP] Ongoing development of additional operators derived from the DaaR.

Collaborator

yxdyc commented Nov 17, 2025

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new domain_diversity_selector operator, which is a valuable addition for curating datasets based on semantic diversity. The implementation follows the DaaR paper's principles, including embedding generation, clustering, and strategic sample selection. The code is well-structured, and the inclusion of configuration, documentation, and tests is appreciated.

However, there are several critical areas for improvement, primarily concerning performance and memory efficiency. The current implementation processes samples one by one, which will be very slow on large datasets. I've provided suggestions to use batching and vectorized operations to significantly speed up embedding and similarity calculations. Additionally, the memory usage can be optimized by avoiding the storage of redundant data.

The tests are a good start, but they rely on external APIs, making them difficult to run in automated environments. I've suggested mocking these dependencies. Finally, a few minor refactoring opportunities and configuration improvements are noted to enhance code clarity and usability.

Comment on lines +81 to +93
if self.is_hf_model:
    # Embeddings extract via local models
    for sample in tqdm(dataset, desc="Embedding", unit="sample"):
        text = sample["text"]
        with torch.no_grad():
            embedding = model.encode(text)
        embeddings.append(embedding)
else:
    # Embeddings extract via API
    for sample in tqdm(dataset, desc="Embedding", unit="sample"):
        text = sample["text"]
        embedding = model(text, dimensions=self.ebd_dim, encoding_format="float")
        embeddings.append(embedding)
Contributor


critical

The current implementation generates embeddings one sample at a time, which is highly inefficient for large datasets. Both the HuggingFace encode method and many embedding APIs support batch processing; refactoring this to process samples in batches will significantly improve performance.

Suggested change
if self.is_hf_model:
    # Embeddings extract via local models
    for sample in tqdm(dataset, desc="Embedding", unit="sample"):
        text = sample["text"]
        with torch.no_grad():
            embedding = model.encode(text)
        embeddings.append(embedding)
else:
    # Embeddings extract via API
    for sample in tqdm(dataset, desc="Embedding", unit="sample"):
        text = sample["text"]
        embedding = model(text, dimensions=self.ebd_dim, encoding_format="float")
        embeddings.append(embedding)
if self.is_hf_model:
    # Extract embeddings via local models in batches for efficiency
    texts = [sample["text"] for sample in dataset]
    with torch.no_grad():
        embeddings = model.encode(texts, batch_size=self.batch_size, show_progress_bar=True)
else:
    # Extract embeddings via API. Consider batching if the API supports it.
    for sample in tqdm(dataset, desc="Embedding", unit="sample"):
        text = sample["text"]
        embedding = model(text, dimensions=self.ebd_dim, encoding_format="float")
        embeddings.append(embedding)

Comment on lines +117 to +150
# Sample-level cos-similarity to other centroids
for i, entry in tqdm(enumerate(dataset), total=len(dataset), desc="Calculating similarity:"):
    current_embedding = embeddings_array[i]
    current_label = int(labels[i])

    similarities = []
    for j, centroid in enumerate(centroid_embeddings):
        if j != current_label:
            similarity = torch.nn.functional.cosine_similarity(
                torch.tensor(current_embedding).unsqueeze(0), torch.tensor(centroid).unsqueeze(0)
            ).item()
            similarities.append(similarity)

    own_centroid_similarity = torch.nn.functional.cosine_similarity(
        torch.tensor(current_embedding).unsqueeze(0),
        torch.tensor(centroid_embeddings[current_label]).unsqueeze(0),
    ).item()

    global_centroid_similarity = torch.nn.functional.cosine_similarity(
        torch.tensor(current_embedding).unsqueeze(0), torch.tensor(global_centroid).unsqueeze(0)
    ).item()
    total_similarity = sum(similarities)

    data_status.append(
        {
            "text": entry["text"],
            "label": current_label,
            "similarity_with_other_centroids": similarities,
            "total_similarity": total_similarity,
            "similarity_with_own_centroid": own_centroid_similarity,
            "global_centroid_similarity": global_centroid_similarity,
            "original_index": i,
        }
    )
Contributor


critical

This section has critical performance and memory issues that should be addressed:

  1. Performance: Cosine similarities are calculated in a loop, creating and destroying PyTorch tensors for each sample. This is very slow. These calculations should be vectorized using sklearn.metrics.pairwise.cosine_similarity or torch.nn.functional.cosine_similarity on the entire matrices.
  2. Memory: The data_status list stores a dictionary for each sample, including the original text. For large datasets, this will consume a very large amount of memory. It's more memory-efficient to work with NumPy arrays for embeddings, labels, and calculated similarities, and only use indices to refer back to the original dataset.

A full refactor is recommended to process these calculations in a batched/vectorized manner and to avoid creating large intermediate data structures.
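As a rough illustration of the vectorized direction (a sketch only; it assumes embeddings_array, labels, centroid_embeddings, and global_centroid are NumPy arrays as in the code above):

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

labels = np.asarray(labels)
# One call computes all sample-to-centroid similarities at once: shape (N, K).
sims = cosine_similarity(embeddings_array, centroid_embeddings)
own_sim = sims[np.arange(len(labels)), labels]    # similarity to own centroid, (N,)
total_other_sim = sims.sum(axis=1) - own_sim      # sum over the other K-1 centroids, (N,)
global_sim = cosine_similarity(embeddings_array, global_centroid.reshape(1, -1)).ravel()  # (N,)
# Selection can then sort these arrays and map back to the dataset via indices,
# avoiding per-sample dicts that duplicate the raw text.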

from data_juicer.ops.selector.domain_diversity_selector import DomainDiversitySelector
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase

@unittest.skipIf(FROM_FORK, "Skipping the test because running from a fork repo")
Contributor


high

This test is skipped when running from a fork, and its implementation relies on an external API (text-embedding-v3), which requires credentials. This makes the test difficult to run in CI/CD environments and for external contributors. It's highly recommended to either mock the API calls or use a small, publicly available HuggingFace model (by setting is_hf_model=True and providing a model name) for testing to ensure the operator's logic can be verified automatically and reliably.
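A mocked variant could take roughly this shape (a sketch using unittest.mock; the patch target data_juicer.utils.model_utils.get_model is an assumption about how the operator resolves its embedding model and would need to match the actual import path):

import unittest
from unittest import mock

class DomainDiversitySelectorMockedTest(unittest.TestCase):
    # Patch target is an assumption; point it at wherever the op fetches its model.
    @mock.patch("data_juicer.utils.model_utils.get_model")
    def test_selection_with_mocked_embeddings(self, mock_get_model):
        # Deterministic fake embeddings replace the text-embedding-v3 API call.
        mock_get_model.return_value = lambda text, **kwargs: [float(len(text)), 1.0, 0.0]
        # ... build the op and a small dataset here, then run op.process(dataset)
        # and assert on the selected samples.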

api_endpoint: '/embeddings' # embedding URL endpoint for the API
response_path: 'data.0.embedding' # path to extract content from the API response
model_params: {} # parameters for initializing the API model
select_ratio: # the ratio to be sampled
Contributor


medium

The select_ratio is left empty in this configuration example. While the code correctly handles a None value by skipping the operation, it would be more user-friendly for an example configuration to provide a default value (e.g., 0.5) and add a comment explaining its purpose.

      select_ratio: 0.5                                           # the ratio to be sampled

Comment on lines +162 to +169
if self.strategy == "inter":
    label_data_status.sort(key=lambda x: x["total_similarity"])
elif self.strategy == "intra":
    label_data_status.sort(key=lambda x: x["similarity_with_own_centroid"], reverse=True)
elif self.strategy == "global":
    label_data_status.sort(key=lambda x: x["global_centroid_similarity"])
else:
    raise ValueError("Invalid strategy. Use 'inter', 'intra' or 'global'.")
Contributor


medium

The if/elif/else chain for selecting the sorting key based on the strategy can be made more concise and extensible by using a dictionary to map strategy names to their corresponding sorting logic (key and reverse flag).

Suggested change
if self.strategy == "inter":
    label_data_status.sort(key=lambda x: x["total_similarity"])
elif self.strategy == "intra":
    label_data_status.sort(key=lambda x: x["similarity_with_own_centroid"], reverse=True)
elif self.strategy == "global":
    label_data_status.sort(key=lambda x: x["global_centroid_similarity"])
else:
    raise ValueError("Invalid strategy. Use 'inter', 'intra' or 'global'.")
strategy_configs = {
    "inter": {"key": "total_similarity", "reverse": False},
    "intra": {"key": "similarity_with_own_centroid", "reverse": True},
    "global": {"key": "global_centroid_similarity", "reverse": False},
}
if self.strategy not in strategy_configs:
    raise ValueError("Invalid strategy. Use 'inter', 'intra' or 'global'.")
config = strategy_configs[self.strategy]
label_data_status.sort(key=lambda x: x[config["key"]], reverse=config["reverse"])

def _run_domain_diversity_selector(self, dataset: Dataset, target_num, op):
    dataset = op.process(dataset)
    res_list = dataset.to_list()
    self.assertEqual(len(res_list), target_num)
Contributor


medium

The assertion only checks the number of selected samples, which is not sufficient to verify that the diversity selection logic is working correctly. The test should be strengthened to check which samples are selected. You could add a unique ID to each sample in ds_list, and then after processing, verify that the IDs of the selected samples match a list of expected IDs for a given seed and input.
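Such a check could look roughly like this (a sketch; the id field and expected_ids are illustrative and would be fixed by a chosen seed and input):

def _run_domain_diversity_selector(self, dataset: Dataset, expected_ids, op):
    dataset = op.process(dataset)
    res_list = dataset.to_list()
    self.assertEqual(len(res_list), len(expected_ids))
    # Check *which* samples were selected, not just how many.
    self.assertEqual(sorted(sample["id"] for sample in res_list), sorted(expected_ids))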

