2424from torchmetrics .utilities .exceptions import TorchMetricsUserError
2525from torchmetrics .utilities .imports import _TORCH_GREATER_EQUAL_2_1
2626
27- from unittests import NUM_PROCESSES
27+ from unittests import NUM_PROCESSES , USE_PYTEST_POOL
2828from unittests ._helpers import seed_all
2929from unittests ._helpers .testers import DummyListMetric , DummyMetric , DummyMetricSum
3030
@@ -88,6 +88,7 @@ def _test_ddp_compositional_tensor(rank: int, worldsize: int = NUM_PROCESSES) ->
8888
8989@pytest .mark .DDP ()
9090@pytest .mark .skipif (sys .platform == "win32" , reason = "DDP not available on windows" )
91+ @pytest .mark .skipif (not USE_PYTEST_POOL , reason = "DDP pool is not available." )
9192@pytest .mark .parametrize (
9293 "process" ,
9394 [
@@ -125,6 +126,7 @@ def compute(self):
125126
126127@pytest .mark .DDP ()
127128@pytest .mark .skipif (sys .platform == "win32" , reason = "DDP not available on windows" )
129+ @pytest .mark .skipif (not USE_PYTEST_POOL , reason = "DDP pool is not available." )
128130def test_non_contiguous_tensors ():
129131 """Test that gather_all operation works for non-contiguous tensors."""
130132 pytest .pool .map (_test_non_contiguous_tensors , range (NUM_PROCESSES ))
@@ -232,6 +234,7 @@ def reload_state_dict(state_dict, expected_x, expected_c):
232234
233235@pytest .mark .DDP ()
234236@pytest .mark .skipif (sys .platform == "win32" , reason = "DDP not available on windows" )
237+ @pytest .mark .skipif (not USE_PYTEST_POOL , reason = "DDP pool is not available." )
235238def test_state_dict_is_synced (tmpdir ):
236239 """Tests that metrics are synced while creating the state dict but restored after to continue accumulation."""
237240 pytest .pool .map (partial (_test_state_dict_is_synced , tmpdir = tmpdir ), range (NUM_PROCESSES ))
@@ -260,6 +263,7 @@ def _test_sync_on_compute_list_state(rank, sync_on_compute):
260263
261264@pytest .mark .DDP ()
262265@pytest .mark .skipif (sys .platform == "win32" , reason = "DDP not available on windows" )
266+ @pytest .mark .skipif (not USE_PYTEST_POOL , reason = "DDP pool is not available." )
263267@pytest .mark .parametrize ("sync_on_compute" , [True , False ])
264268@pytest .mark .parametrize ("test_func" , [_test_sync_on_compute_list_state , _test_sync_on_compute_tensor_state ])
265269def test_sync_on_compute (sync_on_compute , test_func ):
@@ -276,6 +280,7 @@ def _test_sync_with_empty_lists(rank):
276280@pytest .mark .DDP ()
277281@pytest .mark .skipif (not _TORCH_GREATER_EQUAL_2_1 , reason = "test only works on newer torch versions" )
278282@pytest .mark .skipif (sys .platform == "win32" , reason = "DDP not available on windows" )
283+ @pytest .mark .skipif (not USE_PYTEST_POOL , reason = "DDP pool is not available." )
279284def test_sync_with_empty_lists ():
280285 """Test that synchronization of states can be enabled and disabled for compute."""
281286 pytest .pool .map (_test_sync_with_empty_lists , range (NUM_PROCESSES ))
0 commit comments