Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 25 additions & 48 deletions xla/service/gpu/thunk_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1583,28 +1583,34 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitCollectiveThunk(
const auto& hlo_config = ir_emitter_context_->hlo_module().config();
int64_t replica_count = hlo_config.replica_count();
int64_t partition_count = hlo_config.num_partitions();
int64_t operand_count = inst->operand_count();
VLOG(2) << CollectiveThunkType::GetHloOpName()
<< "; replica count: " << replica_count
<< "; partition count: " << partition_count
<< "; operand count: " << inst->operand_count();
<< "; operand count: " << operand_count;

// Stash relevant information in CollectiveThunk::Buffer even if
// we may not generate a CollectiveThunk.
std::vector<CollectiveThunk::Buffer> buffers;

int64_t operand_count = inst->operand_count();
buffers.reserve(operand_count);

// Adds a source and destination buffer pair to `buffers`.
auto add_buffer = [&](int64_t element_count, BufferAllocation::Slice src,
int64_t src_memory_space, BufferAllocation::Slice dst,
int64_t dst_memory_space) {
buffers.push_back(
CollectiveThunk::Buffer{/*element_count=*/element_count,
/*source_buffer=*/src,
/*destination_buffer=*/dst,
/*source_memory_space=*/src_memory_space,
/*destination_memory_space=*/dst_memory_space});
auto add_buffer = [&](const HloInstruction* src, const HloInstruction* dst,
const ShapeIndex& dst_shape_index) -> absl::Status {
const Shape& src_shape = src->shape();
const Shape& dst_shape =
ShapeUtil::GetSubshape(dst->shape(), dst_shape_index);
TF_ASSIGN_OR_RETURN(auto src_slice, GetAllocationSliceForHlo(src));
TF_ASSIGN_OR_RETURN(auto dst_slice,
GetAllocationSliceForHlo(dst, dst_shape_index));

buffers.push_back(CollectiveThunk::Buffer{
/*element_count=*/ShapeUtil::ElementsIn(src_shape),
/*source_buffer=*/src_slice,
/*destination_buffer=*/dst_slice,
/*source_memory_space=*/src_shape.layout().memory_space(),
/*destination_memory_space=*/dst_shape.layout().memory_space()});
return absl::OkStatus();
};

if (kind == Thunk::Kind::kAllGatherStart) {
Expand All @@ -1613,56 +1619,27 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitCollectiveThunk(
// multiple operands).
for (int64_t i = 0; i < operand_count; i++) {
ShapeIndex idx = operand_count > 1 ? ShapeIndex({1, i}) : ShapeIndex({1});
const Shape& src_shape = inst->operand(i)->shape();
const Shape& dst_shape = ShapeUtil::GetSubshape(inst->shape(), idx);
TF_ASSIGN_OR_RETURN(auto src, GetAllocationSliceForHlo(inst->operand(i)));
TF_ASSIGN_OR_RETURN(auto dst, GetAllocationSliceForHlo(inst, idx));
add_buffer(ShapeUtil::ElementsIn(src_shape), src,
src_shape.layout().memory_space(), dst,
dst_shape.layout().memory_space());
}
} else if (kind == Thunk::Kind::kRaggedAllToAll) {
// The RaggedAllToAll operation has 6 operands: input, output,
// input_offset, send_size, output_offset, recv_size. The `output`
// operand is aliased with the instruction result. All other
// operands are not aliased.
const Shape& input_shape = inst->operand(0)->shape();
TF_ASSIGN_OR_RETURN(auto input_buffer,
GetAllocationSliceForHlo(inst->operand(0)));
add_buffer(ShapeUtil::ElementsIn(input_shape), input_buffer,
input_shape.layout().memory_space(), input_buffer,
input_shape.layout().memory_space());

const Shape& output_shape = inst->operand(1)->shape();
const Shape& result_shape = inst->shape();
TF_ASSIGN_OR_RETURN(auto output_buffer,
GetAllocationSliceForHlo(inst->operand(1)));
TF_ASSIGN_OR_RETURN(auto result_buffer, GetAllocationSliceForHlo(inst));

add_buffer(ShapeUtil::ElementsIn(result_shape), output_buffer,
output_shape.layout().memory_space(), result_buffer,
result_shape.layout().memory_space());
TF_RETURN_IF_ERROR(
add_buffer(inst->operand(0), inst->operand(0), ShapeIndex({})));
TF_RETURN_IF_ERROR(add_buffer(inst->operand(1), inst, ShapeIndex({})));

for (int64_t i = 2; i < operand_count; i++) {
const Shape& shape = inst->operand(i)->shape();
TF_ASSIGN_OR_RETURN(auto slice,
GetAllocationSliceForHlo(inst->operand(i)));
add_buffer(ShapeUtil::ElementsIn(shape), slice,
shape.layout().memory_space(), slice,
shape.layout().memory_space());
TF_RETURN_IF_ERROR(
add_buffer(inst->operand(i), inst->operand(i), ShapeIndex({})));
}
} else {
// For other operations simply zip operands with results.
for (int64_t i = 0; i < operand_count; i++) {
ShapeIndex idx =
inst->shape().IsTuple() ? ShapeIndex({i}) : ShapeIndex({});
const Shape& src_shape = inst->operand(i)->shape();
const Shape& dst_shape = ShapeUtil::GetSubshape(inst->shape(), idx);
TF_ASSIGN_OR_RETURN(auto src, GetAllocationSliceForHlo(inst->operand(i)));
TF_ASSIGN_OR_RETURN(auto dst, GetAllocationSliceForHlo(inst, idx));
add_buffer(ShapeUtil::ElementsIn(src_shape), src,
src_shape.layout().memory_space(), dst,
dst_shape.layout().memory_space());

TF_RETURN_IF_ERROR(add_buffer(inst->operand(i), inst, idx));
}
}

Expand Down
Loading