Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions xla/hlo/ir/hlo_sharding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ void HloSharding::Print(Printer* printer, bool include_metadata) const {
return;
}

if (maximal_) {
if (single_device_) {
AppendCat(printer, "{maximal device=",
static_cast<int64_t>(*tile_assignment_.array().begin()));
print_shard_group();
Expand Down Expand Up @@ -584,7 +584,8 @@ bool HloSharding::UsesDevice(int64_t device) const {
}

std::vector<int64_t> HloSharding::TileIndexForDevice(int64_t device) const {
CHECK(!maximal_);
CHECK(!replicated_);
CHECK(!single_device_);
CHECK(!IsManual());
CHECK(!IsUnknown());
CHECK(!IsTuple());
Expand All @@ -605,7 +606,7 @@ std::vector<int64_t> HloSharding::TileOffsetForDevice(const Shape& shape,
CHECK(!IsManual());
CHECK(!IsUnknown());

if (maximal_) {
if (replicated_ || single_device_) {
return std::vector<int64_t>(shape.dimensions().size(), 0);
}
CHECK_EQ(shape.dimensions().size(), TiledDataRank());
Expand All @@ -624,7 +625,7 @@ std::vector<int64_t> HloSharding::TileLimitForDevice(const Shape& shape,
CHECK(!IsManual());
CHECK(!IsUnknown());

if (maximal_) {
if (replicated_ || single_device_) {
return std::vector<int64_t>(shape.dimensions().begin(),
shape.dimensions().end());
}
Expand Down Expand Up @@ -656,7 +657,8 @@ absl::Status HloSharding::EachTile(
CHECK(!IsTuple());
CHECK(!IsManual());
CHECK(!IsUnknown());
CHECK(!maximal_);
CHECK(!replicated_);
CHECK(!single_device_);

// At the high-level, sharding_dims[i] describes the number of ways the shape
// is partitioned along i-th dimension. Note that sharding_dims[i] with i >=
Expand Down
44 changes: 24 additions & 20 deletions xla/hlo/ir/hlo_sharding.h
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ class HloSharding {
if (UseNamedShardingLeaf()) {
return named_sharding_->IsTileMaximal();
}
return maximal_;
return replicated_ || single_device_;
}

// Returns whether the sharding represents manual partitioning.
Expand Down Expand Up @@ -439,7 +439,8 @@ class HloSharding {
// REQUIRES: !IsTuple()
// REQUIRES: !IsManual()
// REQUIRES: !IsUnknown()
// REQUIRES: !maximal_
// REQUIRES: !replicated_
// REQUIRES: !single_device_
//
// For NamedSharding we convert it to tile based HloShardingV2 and then invoke
// callback on the tile based sharding.
Expand Down Expand Up @@ -505,7 +506,8 @@ class HloSharding {

bool operator==(const HloSharding& other) const {
if (named_sharding_.has_value() == other.named_sharding_.has_value()) {
return replicated_ == other.replicated_ && maximal_ == other.maximal_ &&
return replicated_ == other.replicated_ &&
single_device_ == other.single_device_ &&
manual_ == other.manual_ && unknown_ == other.unknown_ &&
unreduced_ == other.unreduced_ &&
tile_assignment_ == other.tile_assignment_ &&
Expand Down Expand Up @@ -738,7 +740,7 @@ class HloSharding {

explicit HloSharding(NamedSharding named_sharding)
: replicated_(false),
maximal_(false),
single_device_(false),
tuple_(false),
manual_(false),
unknown_(false),
Expand All @@ -751,7 +753,7 @@ class HloSharding {
bool unreduced, absl::Span<const OpMetadata> metadata)
: metadata_(metadata.begin(), metadata.end()),
replicated_(replicated),
maximal_(replicated),
single_device_(false),
tuple_(false),
manual_(manual),
unknown_(unknown),
Expand All @@ -768,7 +770,7 @@ class HloSharding {
: tile_assignment_(device_id),
metadata_(metadata.begin(), metadata.end()),
replicated_(false),
maximal_(true),
single_device_(true),
tuple_(false),
manual_(false),
unknown_(false),
Expand All @@ -781,7 +783,7 @@ class HloSharding {
: tile_assignment_(std::move(tile_assignment)),
metadata_(metadata.begin(), metadata.end()),
replicated_(false),
maximal_(false),
single_device_(false),
tuple_(false),
manual_(false),
unknown_(false),
Expand All @@ -795,7 +797,7 @@ class HloSharding {
metadata_(metadata.begin(), metadata.end()),
subgroup_types_(subgroup_types.begin(), subgroup_types.end()),
replicated_(false),
maximal_(false),
single_device_(false),
tuple_(false),
manual_(false),
unknown_(false),
Expand All @@ -805,7 +807,7 @@ class HloSharding {
explicit HloSharding(std::vector<HloSharding> tuple_shardings)
: tuple_elements_(std::move(tuple_shardings)),
replicated_(false),
maximal_(false),
single_device_(false),
tuple_(true),
manual_(false),
unknown_(false),
Expand All @@ -821,7 +823,7 @@ class HloSharding {
metadata_(other.metadata_),
subgroup_types_(other.subgroup_types_),
replicated_(other.replicated_),
maximal_(other.maximal_),
single_device_(other.single_device_),
tuple_(other.tuple_),
manual_(other.manual_),
unknown_(other.unknown_),
Expand All @@ -848,14 +850,15 @@ class HloSharding {

const TileAssignment& TileAgnosticDeviceAssignment() const;

// This field is only used if replicated_ is false. If maximal_ is true, then
// the field contains a rank 1 array with a single element, which is the
// device the HLO is assigned to. If maximal_ is false, the field contains an
// array with the same rank as the corresponding HLO. The dimension sizes of
// the array describe the number of ways the HLO is partitioned along each
// dimension. The values of the array specify which device each tile of
// the HLO is assigned to. The index of each value determines which tile it
// takes.
// This field is only used if replicated_ is false. If single_device_ is true,
// then the field contains a rank 1 array with a single element, which is the
// device the HLO is assigned to. If single_device_ is false, the field
// contains an array with the same rank as the corresponding HLO. The
// dimension sizes of the array describe the number of ways the HLO is
// partitioned along each dimension. The values of the array specify which
// device each tile of the HLO is assigned to. The index of each value
// determines which tile it takes.
//
// For example, {{{2, 3}}, {{5, 7}}} (whose ToString representation is
// "{devices=[2,1,2]2,3,5,7}"), means that dimension 1 is split two way and
// dimension 3 is split 2 way. Core 5, whose index is [2,1,1] will take the
Expand All @@ -878,9 +881,10 @@ class HloSharding {
// When creating HloSharding, subgroup dims of the same type will be merged,
// so that the elements in subgroup_types_ are unique.
std::vector<OpSharding::Type> subgroup_types_;

bool replicated_ : 1; // When non-tuple, true if the sharding is trivial.
bool maximal_ : 1; // When non-tuple, true if the tile size is the same as
// the input size.
bool single_device_ : 1; // When non-tuple, true if the tensor is on a single
// device.
bool tuple_ : 1; // True if this is a tuple.
bool manual_ : 1; // When non-tuple, true if the sharding represents manual
// partitioning.
Expand Down
Loading