Skip to content

Commit e2d157b

Browse files
committed
fix assertion
1 parent e418350 commit e2d157b

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

torch_xla/distributed/spmd/xla_sharding.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,11 @@ def __init__(self,
6767
assert (len(device_ids) == np.prod(mesh_shape))
6868
# device ids are unique
6969
assert len(device_ids) == len(np.unique(device_ids))
70-
# device ids are continous
71-
assert all(d < self.size() for d in device_ids - np.min(device_ids))
7270
self.device_ids = device_ids
7371
self.mesh_shape = mesh_shape
7472
self.axis_names = axis_names
73+
# device ids are continous
74+
assert all(d < self.size() for d in device_ids - np.min(device_ids))
7575

7676
def size(self):
7777
return np.prod(self.mesh_shape)

0 commit comments

Comments
 (0)