File tree Expand file tree Collapse file tree 2 files changed +30
-4
lines changed
Expand file tree Collapse file tree 2 files changed +30
-4
lines changed Original file line number Diff line number Diff line change 1111import contextlib
1212import logging
1313import random
14+ import sys
1415import time
1516from typing import Any , Generator
1617
1718import torch
1819
1920logger : logging .Logger = logging .getLogger (__name__ )
2021
21- torch .ops .load_library ("//caffe2/torch/fb/retrieval:faster_hash_cpu" )
22- torch .ops .load_library ("//caffe2/torch/fb/retrieval:faster_hash_cuda" )
22+
23+ def load_required_libraries () -> bool :
24+ try :
25+ torch .ops .load_library ("//torchrec/ops:faster_hash_cpu" )
26+ torch .ops .load_library ("//torchrec/ops:faster_hash_cuda" )
27+ return True
28+ except Exception as e :
29+ logger .error (f"Failed to load faster_hash libraries, skipping test: { e } " )
30+ return False
2331
2432
2533@contextlib .contextmanager
@@ -347,6 +355,9 @@ def _run_benchmark_with_eviction(
347355
348356
349357if __name__ == "__main__" :
358+ if not load_required_libraries ():
359+ print ("Skipping test because libraries were not loaded" )
360+ sys .exit (0 )
350361 logger .setLevel (logging .INFO )
351362 handler = logging .StreamHandler ()
352363 handler .setLevel (logging .INFO )
Original file line number Diff line number Diff line change 1313import torch
1414from hypothesis import settings
1515
16- torch .ops .load_library ("//torchrec/ops:faster_hash_cpu" )
17- torch .ops .load_library ("//torchrec/ops:faster_hash_cuda" )
16+
17+ def load_required_libraries () -> bool :
18+ try :
19+ torch .ops .load_library ("//torchrec/ops:faster_hash_cpu" )
20+ torch .ops .load_library ("//torchrec/ops:faster_hash_cuda" )
21+ return True
22+ except Exception as e :
23+ print (f"Skipping tests because libraries were not loaded: { e } " )
24+ return False
1825
1926
2027class HashZchKernelEvictionPolicy (IntEnum ):
@@ -23,6 +30,14 @@ class HashZchKernelEvictionPolicy(IntEnum):
2330
2431
2532class FasterHashTest (unittest .TestCase ):
33+
34+ @classmethod
35+ def setUpClass (cls ):
36+ if not load_required_libraries ():
37+ raise unittest .SkipTest (
38+ "Libraries not loaded, skipping all tests in MyTestCase"
39+ )
40+
2641 @unittest .skipIf (not torch .cuda .is_available (), "Skip when CUDA is not available" )
2742 @settings (deadline = None )
2843 def test_simple_zch_no_evict (self ) -> None :
You can’t perform that action at this time.
0 commit comments