diff --git a/tests/test_util.py b/tests/test_util.py index 71d194e54..81e0d253f 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,10 +1,12 @@ from __future__ import annotations import numpy as np +import pytest import sklearn import torch from sentence_transformers import SentenceTransformer, util +from sentence_transformers.util import community_detection def test_normalize_embeddings() -> None: @@ -145,3 +147,109 @@ def test_dot_score_cos_sim() -> None: assert np.allclose(cosine_calculated, dot_and_cosine_expected) assert np.allclose(dot_calculated, dot_and_cosine_expected) + +def test_two_clear_communities(): + """Test case with two clear communities.""" + embeddings = torch.tensor([ + [1.0, 0.0, 0.0], # Point 0 + [0.9, 0.1, 0.0], # Point 1 + [0.8, 0.2, 0.0], # Point 2 + [0.1, 0.9, 0.0], # Point 3 + [0.0, 1.0, 0.0], # Point 4 + [0.2, 0.8, 0.0], # Point 5 + ]) + expected = [ + [0, 1, 2], # Community 1 + [3, 4, 5], # Community 2 + ] + result = community_detection(embeddings, threshold=0.8, min_community_size=2) + assert sorted([sorted(community) for community in result]) == sorted([sorted(community) for community in expected]) + +def test_no_communities_high_threshold(): + """Test case where no communities are found due to a high threshold.""" + embeddings = torch.tensor([ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + ]) + expected = [] + result = community_detection(embeddings, threshold=0.99, min_community_size=2) + assert result == expected + +def test_all_points_in_one_community(): + """Test case where all points form a single community due to a low threshold.""" + embeddings = torch.tensor([ + [1.0, 0.0, 0.0], + [0.9, 0.1, 0.0], + [0.8, 0.2, 0.0], + ]) + expected = [ + [0, 1, 2], # Single community + ] + result = community_detection(embeddings, threshold=0.5, min_community_size=2) + assert sorted([sorted(community) for community in result]) == sorted([sorted(community) for community in expected]) + +def test_min_community_size_filtering(): + """Test case where communities are filtered based on minimum size.""" + embeddings = torch.tensor([ + [1.0, 0.0, 0.0], + [0.9, 0.1, 0.0], + [0.8, 0.2, 0.0], + [0.1, 0.9, 0.0], + ]) + expected = [ + [0, 1, 2], # Only one community meets the min size requirement + ] + result = community_detection(embeddings, threshold=0.8, min_community_size=3) + assert sorted([sorted(community) for community in result]) == sorted([sorted(community) for community in expected]) + +def test_overlapping_communities(): + """Test case with overlapping communities (resolved by the function).""" + embeddings = torch.tensor([ + [1.0, 0.0, 0.0], # Point 0 + [0.9, 0.1, 0.0], # Point 1 + [0.8, 0.2, 0.0], # Point 2 + [0.7, 0.3, 0.0], # Point 3 (overlaps with both communities) + [0.1, 0.9, 0.0], # Point 4 + [0.0, 1.0, 0.0], # Point 5 + ]) + expected = [ + [0, 1, 2, 3], # Community 1 (includes overlapping point 3) + [4, 5], # Community 2 + ] + result = community_detection(embeddings, threshold=0.8, min_community_size=2) + assert sorted([sorted(community) for community in result]) == sorted([sorted(community) for community in expected]) + +def test_numpy_input(): + """Test case where input is a numpy array instead of a torch tensor.""" + embeddings = np.array([ + [1.0, 0.0, 0.0], + [0.9, 0.1, 0.0], + [0.8, 0.2, 0.0], + ]) + expected = [ + [0, 1, 2], # Single community + ] + result = community_detection(embeddings, threshold=0.8, min_community_size=2) + assert sorted([sorted(community) for community in result]) == sorted([sorted(community) for community in expected]) + +def test_large_batch_size(): + """Test case with a large dataset and batching.""" + embeddings = torch.rand(1000, 128) # Random embeddings + result = community_detection(embeddings, threshold=0.8, min_community_size=10, batch_size=256) + # Check that all communities meet the minimum size requirement + assert all(len(community) >= 10 for community in result) + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available") +def test_gpu_support(): + """Test case for GPU support (if available).""" + embeddings = torch.tensor([ + [1.0, 0.0, 0.0], + [0.9, 0.1, 0.0], + [0.8, 0.2, 0.0], + ]).cuda() + expected = [ + [0, 1, 2], # Single community + ] + result = community_detection(embeddings, threshold=0.8, min_community_size=2) + assert sorted([sorted(community) for community in result]) == sorted([sorted(community) for community in expected])