We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent e418350 commit e2d157bCopy full SHA for e2d157b
torch_xla/distributed/spmd/xla_sharding.py
@@ -67,11 +67,11 @@ def __init__(self,
67
assert (len(device_ids) == np.prod(mesh_shape))
68
# device ids are unique
69
assert len(device_ids) == len(np.unique(device_ids))
70
- # device ids are continous
71
- assert all(d < self.size() for d in device_ids - np.min(device_ids))
72
self.device_ids = device_ids
73
self.mesh_shape = mesh_shape
74
self.axis_names = axis_names
+ # device ids are continous
+ assert all(d < self.size() for d in device_ids - np.min(device_ids))
75
76
def size(self):
77
return np.prod(self.mesh_shape)
0 commit comments