@@ -167,19 +167,25 @@ def callable(input):
167167 return input .cpu ()
168168
169169 @staticmethod
170- def _all_gather_into_tensor (use_dynamo : bool ):
170+ def _all_gather_into_tensor (use_dynamo : bool , mode : str ):
171171 met .clear_all ()
172+ allowed_modes = ["stack" , "concat" ]
173+ if mode not in allowed_modes :
174+ raise ValueError (f"mode must be one of { allowed_modes } " )
172175
173176 def callable (output , input ):
174- dist .all_gather_into_tensor (output_tensor , input , None )
175- return output_tensor
177+ dist .all_gather_into_tensor (output , input , None )
178+ return output
176179
177180 dist .init_process_group ("xla" , init_method = 'xla://' )
178181 device = torch_xla .device ()
179182 input = torch .tensor ([xr .global_ordinal ()],
180183 dtype = torch .float ,
181184 device = device )
182- output_tensor = torch .empty ((1 , xr .world_size ()), device = device )
185+ if mode == "stack" :
186+ output_tensor = torch .empty ((xr .world_size (), 1 ), device = device )
187+ elif mode == "concat" :
188+ output_tensor = torch .empty ((xr .world_size (),), device = device )
183189 f = torch .compile (callable , backend = 'openxla' ) if use_dynamo else callable
184190 f (output_tensor , input )
185191 torch_xla .sync ()
@@ -278,13 +284,17 @@ def test_all_reduce(self, use_dynamo):
278284 for index , val in results .items ():
279285 torch .testing .assert_close (val , expected )
280286
281- @parameterized .named_parameters (('dynamo' , True ), ('nondynamo' , False ))
282- def test_all_gather_into_tensor (self , use_dynamo ):
287+ @parameterized .product (dynamo = [True , False ], mode = ["stack" , "concat" ])
288+ def test_all_gather_into_tensor (self , dynamo , mode ):
289+ if dynamo and mode == "stack" :
290+ self .skipTest ("https://github.com/pytorch/pytorch/issues/155632" )
283291 results = pjrt .run_multiprocess (
284- self ._all_gather_into_tensor , use_dynamo = use_dynamo )
292+ self ._all_gather_into_tensor , use_dynamo = dynamo , mode = mode )
285293 expected = torch .arange (
286- tpu .num_expected_global_devices (), dtype = torch .float ).unsqueeze (0 )
287- for index , val in results .items ():
294+ tpu .num_expected_global_devices (), dtype = torch .float )
295+ if mode == "stack" :
296+ expected = expected .unsqueeze (1 )
297+ for _ , val in results .items ():
288298 torch .testing .assert_close (val , expected )
289299
290300 @parameterized .named_parameters (('dynamo' , True ), ('nondynamo' , False ))
0 commit comments