Skip to content

testcases for community detection #3163

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 20, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 130 additions & 0 deletions tests/test_util.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

import numpy as np
import pytest
import sklearn
import torch

from sentence_transformers import SentenceTransformer, util
from sentence_transformers.util import community_detection


def test_normalize_embeddings() -> None:
Expand Down Expand Up @@ -145,3 +147,131 @@ def test_dot_score_cos_sim() -> None:

assert np.allclose(cosine_calculated, dot_and_cosine_expected)
assert np.allclose(dot_calculated, dot_and_cosine_expected)


def test_community_detection_two_clear_communities():
    """Two well-separated groups of vectors should come back as two communities."""
    # Points 0-2 lie close to the x-axis, points 3-5 close to the y-axis.
    near_x_axis = [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.8, 0.2, 0.0]]
    near_y_axis = [[0.1, 0.9, 0.0], [0.0, 1.0, 0.0], [0.2, 0.8, 0.0]]
    embeddings = torch.tensor(near_x_axis + near_y_axis)

    communities = community_detection(embeddings, threshold=0.8, min_community_size=2)

    # Neither the order of the communities nor the order of their members is checked.
    found = sorted(map(sorted, communities))
    wanted = sorted(map(sorted, [[0, 1, 2], [3, 4, 5]]))
    assert found == wanted


def test_community_detection_no_communities_high_threshold():
    """Mutually orthogonal vectors with a 0.99 threshold should produce no communities."""
    # The 3x3 identity matrix holds the three unit vectors as rows.
    embeddings = torch.eye(3)

    communities = community_detection(embeddings, threshold=0.99, min_community_size=2)

    assert communities == []


def test_community_detection_all_points_in_one_community():
    """With a low threshold (0.5), three nearby vectors should form one community."""
    rows = [
        [1.0, 0.0, 0.0],
        [0.9, 0.1, 0.0],
        [0.8, 0.2, 0.0],
    ]
    embeddings = torch.tensor(rows)

    communities = community_detection(embeddings, threshold=0.5, min_community_size=2)

    # Compare order-insensitively: sort members, then sort the communities.
    found = sorted(map(sorted, communities))
    wanted = sorted(map(sorted, [[0, 1, 2]]))
    assert found == wanted


def test_community_detection_min_community_size_filtering():
    """Only groups with at least ``min_community_size`` members should be returned."""
    close_trio = [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.8, 0.2, 0.0]]
    # A lone point far from the trio; it must not appear in the output.
    outlier = [[0.1, 0.9, 0.0]]
    embeddings = torch.tensor(close_trio + outlier)

    communities = community_detection(embeddings, threshold=0.8, min_community_size=3)

    # Only the trio is large enough to count as a community.
    found = sorted(map(sorted, communities))
    wanted = sorted(map(sorted, [[0, 1, 2]]))
    assert found == wanted


def test_community_detection_overlapping_communities():
    """Overlap between candidate communities is expected to be resolved by the function."""
    rows = [
        [1.0, 0.0, 0.0],  # Point 0
        [0.9, 0.1, 0.0],  # Point 1
        [0.8, 0.2, 0.0],  # Point 2
        [0.7, 0.3, 0.0],  # Point 3 (leans toward the second group, expected in the first)
        [0.1, 0.9, 0.0],  # Point 4
        [0.0, 1.0, 0.0],  # Point 5
    ]
    embeddings = torch.tensor(rows)

    communities = community_detection(embeddings, threshold=0.8, min_community_size=2)

    # Each point is expected in exactly one community; ordering is not checked.
    found = sorted(map(sorted, communities))
    wanted = sorted(map(sorted, [[0, 1, 2, 3], [4, 5]]))
    assert found == wanted


def test_community_detection_numpy_input():
    """A numpy array should be accepted as input in place of a torch tensor."""
    rows = [
        [1.0, 0.0, 0.0],
        [0.9, 0.1, 0.0],
        [0.8, 0.2, 0.0],
    ]
    embeddings = np.array(rows)

    communities = community_detection(embeddings, threshold=0.8, min_community_size=2)

    # All three points are expected in a single community; ordering is not checked.
    found = sorted(map(sorted, communities))
    wanted = sorted(map(sorted, [[0, 1, 2]]))
    assert found == wanted


def test_community_detection_large_batch_size() -> None:
    """Run community detection on 1000 random embeddings in batches of 256.

    The input has more rows than ``batch_size``, so the call has to process
    several batches. The embeddings are drawn from a locally seeded generator
    so the test sees the same data on every run.
    """
    # A dedicated generator keeps the data reproducible without changing the
    # global torch RNG state that other tests may depend on.
    generator = torch.Generator().manual_seed(0)
    embeddings = torch.rand(1000, 128, generator=generator)

    result = community_detection(embeddings, threshold=0.8, min_community_size=10, batch_size=256)

    # Every returned community must meet the minimum size requirement.
    assert all(len(community) >= 10 for community in result)
    # No embedding index may be reported in more than one community.
    members = [idx for community in result for idx in community]
    assert len(members) == len(set(members))


@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available")
def test_community_detection_gpu_support():
    """Embeddings on a CUDA device should be handled; skipped when no GPU is available."""
    host_embeddings = torch.tensor(
        [
            [1.0, 0.0, 0.0],
            [0.9, 0.1, 0.0],
            [0.8, 0.2, 0.0],
        ]
    )
    # Move the data to the GPU before calling the function under test.
    embeddings = host_embeddings.cuda()

    communities = community_detection(embeddings, threshold=0.8, min_community_size=2)

    # All three points are expected in a single community; ordering is not checked.
    found = sorted(map(sorted, communities))
    wanted = sorted(map(sorted, [[0, 1, 2]]))
    assert found == wanted
Loading