@@ -309,37 +309,26 @@ def __iter__(self):
         for i in np.random.choice(range(len(self)), len(self), replace=False):
             yield self.batches_of_indices[i]
 
-    def check_omp_num_threads(self):
-        if get_world_size() == 1:
-            return
-        val = os.environ.get("OMP_NUM_THREADS", "unset")
-        if not (val != "unset" and int(val) > 1):
-            logging.warning(
-                f"You are using multi-processing but OMP_NUM_THREADS is {val}. "
-                f"Index building in uniform batching mode may be slower than expected. "
-                f"Set OMP_NUM_THREADS to a value greater than 1 to improve performance."
-            )
-
     def build_or_sync_indices(self):
-        self.check_omp_num_threads()
+        orig_device = self.positions.device
+        self.positions = self.positions.to(torch.get_default_device())
+
         if get_rank() == 0:
             self.build_indices()
 
         if get_world_size() > 1:
             # Temporarily move indices to GPU.
             if get_rank() == 0:
                 batch_lengths = torch.tensor(
-                    [len(batch) for batch in self.batches_of_indices], device=torch.get_default_device(), dtype=torch.long
+                    [len(batch) for batch in self.batches_of_indices], device=self.positions.device, dtype=torch.long
                 )
-                flat_indices = torch.cat(self.batches_of_indices).to(torch.get_default_device())
+                flat_indices = torch.cat(self.batches_of_indices).to(self.positions.device)
             else:
-                batch_lengths = torch.empty(len(self), dtype=torch.long, device=torch.get_default_device())
-                flat_indices = torch.empty(len(self.positions), dtype=torch.long, device=torch.get_default_device())
+                batch_lengths = torch.empty(len(self), dtype=torch.long, device=self.positions.device)
+                flat_indices = torch.empty(len(self.positions), dtype=torch.long, device=self.positions.device)
 
             torch.distributed.broadcast(batch_lengths, src=0)
             torch.distributed.broadcast(flat_indices, src=0)
-            batch_lengths = batch_lengths.to(self.positions.device)
-            flat_indices = flat_indices.to(self.positions.device)
 
             # Re-assemble batch index list.
             if get_rank() != 0:
@@ -350,6 +339,9 @@ def build_or_sync_indices(self):
                 batches.append(flat_indices[start:end].clone())
                 start = end
             self.batches_of_indices = tuple(batches)
+
+        # Move back to original device.
+        self.batches_of_indices = [x.to(orig_device) for x in self.batches_of_indices]
 
     def build_indices(self):
         dist_mat = torch.cdist(self.positions, self.positions, p=2)
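
The synchronization step in this diff follows a standard pattern for broadcasting a ragged list of tensors: torch.distributed.broadcast only handles tensors whose shapes are known on every rank, so rank 0 sends one flat index tensor plus a tensor of per-batch lengths, and the other ranks pre-allocate matching buffers and re-split after receiving. Below is a minimal standalone sketch of that pattern, assuming an already-initialized process group; the function name and arguments are illustrative, not this repository's API (in the class above, the sizes come from len(self) and len(self.positions)):

    import torch
    import torch.distributed as dist

    def sync_index_batches(batches, num_batches, total_indices, device):
        # On rank 0, flatten the ragged list into one long tensor and record
        # per-batch lengths; other ranks pre-allocate buffers of the globally
        # known sizes so every rank passes same-shaped tensors to broadcast.
        if dist.get_rank() == 0:
            lengths = torch.tensor([len(b) for b in batches],
                                   dtype=torch.long, device=device)
            flat = torch.cat(batches).to(device)
        else:
            lengths = torch.empty(num_batches, dtype=torch.long, device=device)
            flat = torch.empty(total_indices, dtype=torch.long, device=device)

        dist.broadcast(lengths, src=0)
        dist.broadcast(flat, src=0)

        # Re-split the flat tensor into batches; clone() detaches each batch
        # from the large flat buffer so that buffer can be freed.
        return [chunk.clone() for chunk in flat.split(lengths.tolist())]

An alternative would be torch.distributed.broadcast_object_list, which can ship the Python list directly, but it pickles through the CPU; keeping the payload as two flat tensors stays on the collective's fast path, which matches the commit's comment about temporarily moving indices to GPU (presumably because the NCCL backend only operates on CUDA tensors).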