Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

testcases for community detection #3163

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 108 additions & 0 deletions tests/test_util.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

import numpy as np
import pytest
import sklearn
import torch

from sentence_transformers import SentenceTransformer, util
from sentence_transformers.util import community_detection


def test_normalize_embeddings() -> None:
Expand Down Expand Up @@ -145,3 +147,109 @@ def test_dot_score_cos_sim() -> None:

assert np.allclose(cosine_calculated, dot_and_cosine_expected)
assert np.allclose(dot_calculated, dot_and_cosine_expected)

def test_two_clear_communities():
    """Two well-separated groups of three points should come back as two communities."""
    near_x_axis = [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.8, 0.2, 0.0]]  # points 0-2
    near_y_axis = [[0.1, 0.9, 0.0], [0.0, 1.0, 0.0], [0.2, 0.8, 0.0]]  # points 3-5
    points = torch.tensor(near_x_axis + near_y_axis)

    detected = community_detection(points, threshold=0.8, min_community_size=2)

    # Neither the order of the communities nor the order of the members inside
    # a community is part of what this test checks, so normalise both sides.
    wanted = [[0, 1, 2], [3, 4, 5]]
    assert sorted(map(sorted, detected)) == sorted(map(sorted, wanted))

def test_no_communities_high_threshold():
    """Three mutually orthogonal unit vectors with a 0.99 threshold must give no communities."""
    # The 3x3 identity matrix is exactly the three axis-aligned unit vectors.
    points = torch.eye(3)

    detected = community_detection(points, threshold=0.99, min_community_size=2)

    assert detected == []

def test_all_points_in_one_community():
    """With a low threshold (0.5), three nearby points should form one community."""
    points = torch.tensor([[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.8, 0.2, 0.0]])

    detected = community_detection(points, threshold=0.5, min_community_size=2)

    # Normalise member and community order before comparing.
    wanted = [[0, 1, 2]]
    assert sorted(map(sorted, detected)) == sorted(map(sorted, wanted))

def test_min_community_size_filtering() -> None:
    """Groups smaller than ``min_community_size`` must be dropped from the result.

    Points 0-2 form a group of three. Points 3 and 4 are very close to each
    other (cosine similarity ~0.994) but far from points 0-2 (cosine
    similarity at most ~0.35), so they form a tight group of only two members.
    With ``min_community_size=3`` only the first group may be returned.

    A lone outlier alone would not exercise the parameter: it would be dropped
    for any minimum size of 2 or more. The two-member group is rejected only
    because the requested minimum is 3.
    """
    embeddings = torch.tensor(
        [
            [1.0, 0.0, 0.0],  # Point 0
            [0.9, 0.1, 0.0],  # Point 1
            [0.8, 0.2, 0.0],  # Point 2
            [0.1, 0.9, 0.0],  # Point 3 (too-small group)
            [0.0, 1.0, 0.0],  # Point 4 (too-small group)
        ]
    )
    expected = [
        [0, 1, 2],  # Only this group meets the minimum size requirement
    ]

    result = community_detection(embeddings, threshold=0.8, min_community_size=3)

    # Compare independently of community order and member order.
    assert sorted(sorted(community) for community in result) == sorted(
        sorted(community) for community in expected
    )

def test_overlapping_communities() -> None:
    """A point similar to members of two groups must end up in exactly one community.

    Points 0-2 and points 4-6 are two groups that are not similar to each other
    (cosine similarity between points 2 and 4 is ~0.47). Point 3 sits between
    them: its cosine similarity to point 2 and to point 4 is ~0.857, above the
    0.8 threshold, while its similarity to every other point is below it
    (~0.78 at most). Assuming ``community_detection`` thresholds on cosine
    similarity, point 3 is therefore a candidate member of both groups, and the
    function has to resolve that overlap.

    Which group wins point 3 is a tie-breaking detail, so both assignments are
    accepted. What must hold is that the two cores stay intact and point 3
    appears exactly once.
    """
    embeddings = torch.tensor(
        [
            [1.0, 0.0, 0.0],  # Point 0
            [0.9, 0.1, 0.0],  # Point 1
            [0.8, 0.2, 0.0],  # Point 2
            [0.5, 0.5, 0.0],  # Point 3 (within threshold of points 2 and 4)
            [0.2, 0.8, 0.0],  # Point 4
            [0.1, 0.9, 0.0],  # Point 5
            [0.0, 1.0, 0.0],  # Point 6
        ]
    )

    result = community_detection(embeddings, threshold=0.8, min_community_size=2)

    communities = sorted(sorted(community) for community in result)
    assert communities in (
        [[0, 1, 2, 3], [4, 5, 6]],  # Point 3 assigned to the first group
        [[0, 1, 2], [3, 4, 5, 6]],  # Point 3 assigned to the second group
    )

def test_numpy_input():
    """A numpy array should be accepted as input in place of a torch tensor."""
    rows = [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.8, 0.2, 0.0]]
    points = np.array(rows)

    detected = community_detection(points, threshold=0.8, min_community_size=2)

    # Normalise member and community order before comparing.
    wanted = [[0, 1, 2]]
    assert sorted(map(sorted, detected)) == sorted(map(sorted, wanted))

def test_large_batch_size() -> None:
    """Run on 1000 random embeddings with a batch size smaller than the dataset.

    The exact communities are not pinned. The test checks structural
    invariants of the result instead: every community has at least the
    requested size, every index refers to an input row, and no index is
    assigned to more than one community.
    """
    num_embeddings = 1000
    min_community_size = 10

    # A dedicated, seeded generator keeps the data reproducible from run to run
    # without touching the global torch RNG state that other tests may rely on.
    generator = torch.Generator().manual_seed(42)
    embeddings = torch.rand(num_embeddings, 128, generator=generator)

    # batch_size=256 < 1000 forces the input to be processed in several batches.
    result = community_detection(
        embeddings,
        threshold=0.8,
        min_community_size=min_community_size,
        batch_size=256,
    )

    # Check that all communities meet the minimum size requirement
    assert all(len(community) >= min_community_size for community in result)

    members = [index for community in result for index in community]
    # Every member must be a valid row index of the input.
    assert all(0 <= index < num_embeddings for index in members)
    # Communities must not share members.
    assert len(members) == len(set(members))

@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available")
def test_gpu_support():
    """Embeddings that live on a CUDA device should be clustered like CPU ones (skipped without a GPU)."""
    rows = [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.8, 0.2, 0.0]]
    # Create the tensor directly on the GPU.
    points = torch.tensor(rows, device="cuda")

    detected = community_detection(points, threshold=0.8, min_community_size=2)

    # Normalise member and community order before comparing.
    wanted = [[0, 1, 2]]
    assert sorted(map(sorted, detected)) == sorted(map(sorted, wanted))
Loading