Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions xla/backends/gpu/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ cc_library(
":collective_broadcast_thunk",
":collective_permute_thunk",
":collective_thunk",
":convolution_thunk",
":copy_thunk",
":custom_call_thunk",
":dynamic_slice_thunk",
Expand Down Expand Up @@ -3095,6 +3096,7 @@ xla_test(
":command_buffer_conversion_pass",
":command_buffer_thunk",
":conditional_thunk",
":convolution_thunk",
":copy_thunk",
":cudnn_thunk",
":custom_call_thunk",
Expand Down
49 changes: 49 additions & 0 deletions xla/backends/gpu/runtime/command_buffer_cmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1801,6 +1801,55 @@ CommandBufferCmd::BufferUseVector CublasLtCmd::buffers() const {
return buffer_usage;
}

//===----------------------------------------------------------------------===//
// ConvolutionCmd
//===----------------------------------------------------------------------===//

// Builds a command-buffer convolution command by snapshotting an existing
// ConvolutionThunk: the operand/result/scratch buffer slices and the
// GpuConvConfig are copied so the command no longer depends on the thunk's
// lifetime.
// NOTE(review): this reads ConvolutionThunk's private members directly --
// presumably ConvolutionCmd is declared a friend of ConvolutionThunk;
// confirm in convolution_thunk.h.
ConvolutionCmd::ConvolutionCmd(
    const ConvolutionThunk& thunk)
    : TracedCommandBufferCmd(CommandBufferCmdType::kConvolutionCmd),
      operand_buffers_(thunk.operand_buffers_),
      result_buffers_(thunk.result_buffers_),
      scratch_buffer_(thunk.scratch_buffer_),
      config_(thunk.config_) {}

// Eagerly populates the ConvRunner cache for `config_` on the initialization
// stream, so the first Record() does not pay runner construction cost while
// tracing. The value returned by GetOrCreate is intentionally discarded --
// only the cache side effect matters here. `state` is unused.
absl::Status ConvolutionCmd::Initialize(const Thunk::InitializeParams& params,
                                        StateManager& state) {
  // Populate the cache of GenericConvRunner for this config/stream.
  cache_.GetOrCreate(config_, params.stream);
  return absl::OkStatus();
}

// Records this convolution into `command_buffer` by tracing its execution:
// the callback replays the convolution on the tracing stream with the
// buffer slices and config captured at construction time.
absl::StatusOr<const se::CommandBuffer::Command*> ConvolutionCmd::Record(
    const Thunk::ExecuteParams& execute_params,
    const RecordParams& record_params, RecordAction record_action,
    se::CommandBuffer* command_buffer) {
  VLOG(5) << "ConvolutionCmd";

  auto trace_convolution = [&](se::Stream* stream) {
    return RunConvolutionOnStream(execute_params, operand_buffers_,
                                  result_buffers_, scratch_buffer_, config_,
                                  cache_, stream);
  };
  return RecordTracedCommand(execute_params, record_params,
                             std::move(record_action), command_buffer,
                             trace_convolution);
}

// Declares the memory this command touches so the scheduler can order it
// against other commands: operands are read; results and the scratch
// workspace are written.
CommandBufferCmd::BufferUseVector ConvolutionCmd::buffers() const {
  BufferUseVector uses;
  uses.reserve(operand_buffers_.size() + result_buffers_.size() + 1);
  for (const BufferAllocation::Slice& operand : operand_buffers_) {
    uses.push_back({operand, MemoryAccess::kRead});
  }
  for (const BufferAllocation::Slice& result : result_buffers_) {
    uses.push_back({result, MemoryAccess::kWrite});
  }
  uses.push_back({scratch_buffer_, MemoryAccess::kWrite});
  return uses;
}

//===----------------------------------------------------------------------===//
// CuDnnCmd
//===----------------------------------------------------------------------===//
Expand Down
30 changes: 30 additions & 0 deletions xla/backends/gpu/runtime/command_buffer_cmd.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ limitations under the License.
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "xla/backends/gpu/runtime/collective_permute_thunk.h"
#include "xla/backends/gpu/runtime/collective_thunk.h"
#include "xla/backends/gpu/runtime/convolution_thunk.h"
#include "xla/backends/gpu/runtime/copy_thunk.h"
#include "xla/backends/gpu/runtime/custom_call_thunk.h"
Expand Down Expand Up @@ -84,6 +85,7 @@ namespace xla::gpu {
V(kLaunchCmd, "LaunchCmd") \
V(kCustomKernelLaunchCmd, "CustomKernelLaunchCmd") \
V(kCublasLtCmd, "CublasLtCmd") \
V(kConvolutionCmd, "ConvolutionCmd") \
V(kCuDnnCmd, "CuDnnCmd") \
V(kGemmCmd, "GemmCmd") \
V(kMemcpyDeviceToDeviceCmd, "MemcpyDeviceToDeviceCmd") \
Expand Down Expand Up @@ -958,6 +960,34 @@ class CublasLtCmd : public TracedCommandBufferCmd, public CublasLtMatmulThunk {
bool IsNestedCommandBuffer() const final { return true; }
};

//===----------------------------------------------------------------------===//
// ConvolutionCmd
//===----------------------------------------------------------------------===//

// Command-buffer command that executes a GPU convolution by tracing it into
// a nested command buffer. It snapshots the buffer slices and GpuConvConfig
// of an existing ConvolutionThunk at construction time.
class ConvolutionCmd : public TracedCommandBufferCmd {
 public:
  // Copies operand/result/scratch slices and the convolution config from the
  // thunk. `explicit` prevents accidental implicit conversion from a thunk.
  explicit ConvolutionCmd(const ConvolutionThunk& conv_thunk);

  // Pre-populates the ConvRunner cache for this config so that the first
  // Record() does not construct the runner while tracing.
  absl::Status Initialize(const Thunk::InitializeParams& params,
                          StateManager& state) override;

  // Traces the convolution into `command_buffer`.
  absl::StatusOr<const se::CommandBuffer::Command*> Record(
      const Thunk::ExecuteParams& execute_params,
      const RecordParams& record_params, RecordAction record_action,
      se::CommandBuffer* command_buffer) override;

  // Operand slices are read; result and scratch slices are written.
  BufferUseVector buffers() const override;

  // Convolutions are recorded via tracing into a nested command buffer.
  bool IsNestedCommandBuffer() const final { return true; }

 private:
  std::vector<BufferAllocation::Slice> operand_buffers_;
  std::vector<BufferAllocation::Slice> result_buffers_;
  BufferAllocation::Slice scratch_buffer_;
  GpuConvConfig config_;
  ConvRunnerCache cache_;
};

//===----------------------------------------------------------------------===//
// CuDnnCmd
//===----------------------------------------------------------------------===//
Expand Down
6 changes: 6 additions & 0 deletions xla/backends/gpu/runtime/command_buffer_cmd_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,10 @@ static absl::StatusOr<Command> Convert(const CuDnnThunk& thunk) {
return std::make_unique<CuDnnCmd>(thunk.arguments(), thunk.graph());
}

// Wraps a ConvolutionThunk in its command-buffer equivalent.
static absl::StatusOr<Command> Convert(const ConvolutionThunk& thunk) {
  auto cmd = std::make_unique<ConvolutionCmd>(thunk);
  return cmd;
}

//===----------------------------------------------------------------------===//
static absl::StatusOr<Command> CopyMetadata(absl::StatusOr<Command> cmd,
const Thunk& thunk) {
Expand Down Expand Up @@ -315,6 +319,8 @@ static absl::Status AppendCommands(CommandBufferCmdSequence& cmd_sequence,
return append(Convert<WhileThunk>(thunk, options));
case Thunk::Kind::kCuDnn:
return append(Convert<CuDnnThunk>(thunk));
case Thunk::Kind::kConvolution:
return append(Convert<ConvolutionThunk>(thunk));
case Thunk::Kind::kDynamicSlice:
return append(Convert<DynamicSliceThunk>(thunk, options));

Expand Down
5 changes: 4 additions & 1 deletion xla/backends/gpu/runtime/command_buffer_conversion_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ CommandBufferConfig GetCommandBufferConfig(
DebugOptions::WHILE};
static constexpr auto kRequireTracing = {
DebugOptions::CUBLAS, DebugOptions::CUBLASLT, DebugOptions::CUDNN,
DebugOptions::CUSTOM_CALL, DebugOptions::COLLECTIVES};
DebugOptions::CUSTOM_CALL, DebugOptions::COLLECTIVES,
DebugOptions::CONVOLUTION};

auto erase = [&](absl::Span<const DebugOptions::CommandBufferCmdType> cmds) {
for (auto cmd : cmds) {
Expand Down Expand Up @@ -150,6 +151,8 @@ std::optional<DebugOptions::CommandBufferCmdType> GetCommandBufferCmdType(
return DebugOptions::COLLECTIVES;
case Thunk::kCuDnn:
return DebugOptions::CUDNN;
case Thunk::kConvolution:
return DebugOptions::CONVOLUTION;
case Thunk::kCustomCall:
return DebugOptions::CUSTOM_CALL;
case Thunk::kCublasLtMatmul:
Expand Down
96 changes: 96 additions & 0 deletions xla/backends/gpu/runtime/command_buffer_conversion_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ limitations under the License.
#include "xla/backends/gpu/runtime/collective_thunk.h"
#include "xla/backends/gpu/runtime/command_buffer_thunk.h"
#include "xla/backends/gpu/runtime/conditional_thunk.h"
#include "xla/backends/gpu/runtime/convolution_thunk.h"
#include "xla/backends/gpu/runtime/copy_thunk.h"
#include "xla/backends/gpu/runtime/cudnn_thunk.h"
#include "xla/backends/gpu/runtime/custom_call_thunk.h"
Expand Down Expand Up @@ -152,6 +153,59 @@ std::unique_ptr<GemmThunk> CreateGemmThunk(const BufferAllocation& alloc1) {
slice1, slice1, slice1, true);
}

// Builds a minimal forward-convolution thunk for conversion-pass tests:
// three operand and three result slices carved out of `alloc`, NCHW-style
// dimension numbers, and a 4x4 stride-1 window. The descriptor only needs
// to be consistent enough to construct the thunk; it is never executed.
std::unique_ptr<ConvolutionThunk> CreateConvolutionThunk(
    const BufferAllocation& alloc) {
  // NOTE(review): the returned executor was bound to an unused local; the
  // call is kept in case GpuExecutor() lazily initializes the GPU platform.
  // Confirm whether it can be dropped entirely.
  GpuExecutor();

  std::vector<BufferAllocation::Slice> operand_slices, result_slices;
  constexpr int kNumSlices = 3;
  for (int i = 0; i < kNumSlices; ++i) {
    operand_slices.emplace_back(&alloc, i * 16, 16);
    result_slices.emplace_back(&alloc, (i + kNumSlices) * 16, 16);
  }

  // NCHW input/output; kernel has feature dims first, then two spatial dims.
  ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(0);
  dnums.set_input_feature_dimension(1);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_input_spatial_dimensions(3);
  dnums.set_kernel_input_feature_dimension(0);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.add_kernel_spatial_dimensions(2);
  dnums.add_kernel_spatial_dimensions(3);
  dnums.set_output_batch_dimension(0);
  dnums.set_output_feature_dimension(1);
  dnums.add_output_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(3);

  // 4x4 window, stride 1, no base dilation, asymmetric window dilation.
  Window window;
  auto* dim0 = window.add_dimensions();
  auto* dim1 = window.add_dimensions();
  dim0->set_size(4);
  dim1->set_size(4);
  dim0->set_base_dilation(1);
  dim1->set_base_dilation(1);
  dim0->set_stride(1);
  dim1->set_stride(1);
  dim0->set_window_dilation(3);
  dim1->set_window_dilation(2);

  GpuConvDescriptor desc{
      .kind = CudnnConvKind::kForward,
      .backend_config = CudnnConvBackendConfig{},
      .operand0_shape = ShapeUtil::MakeShape(F32, {60, 38, 17, 13}),
      .operand1_shape = ShapeUtil::MakeShapeWithDenseLayout(
          F32, {38, 10, 4, 4}, {3, 2, 0, 1}),
      .result_shape = ShapeUtil::MakeShapeWithType<float>({64, 64, 64, 13}),
      .scratch_size = 128 * 1024,
      .window = window,
      .dnums = dnums,
      .feature_group_count = 1};

  // NOTE(review): the scratch slice aliases the last result slice; that is
  // harmless for a test that never runs the thunk, but worth confirming.
  auto thunk = ConvolutionThunk::Create(Thunk::ThunkInfo(), desc,
                                        operand_slices, result_slices,
                                        result_slices.back());
  TF_CHECK_OK(thunk.status());
  return std::move(thunk).value();
}

std::unique_ptr<CollectiveDoneThunk> CreateAllGatherDoneThunk(
Thunk* start_thunk) {
auto async_events =
Expand Down Expand Up @@ -315,6 +369,48 @@ TEST(CommandBufferConversionPassTest, PartiallyConvertsToCommandBufferThunk) {
EXPECT_THAT(thunks_in_command_buffer1, ThunkKindsAre(Thunk::kCopy));
}

// Verifies that copy, gemm and convolution thunks are all merged into a
// single command buffer when their command-buffer types are enabled.
TEST(CommandBufferConversionPassTest, ConvertConvolutionAndGemmThunks) {
  CommandBufferConversionPass pass{"test"};

  std::vector<std::unique_ptr<Thunk>> thunks;

  // Create a {CopyThunk, GemmThunk, ConvolutionThunk} sequence.
  BufferAllocation alloc0(0, 1024, 0);
  BufferAllocation alloc1(1, 2048, 0);
  thunks.push_back(CreateCopyThunk(alloc0));
  thunks.push_back(CreateGemmThunk(alloc1));
  thunks.push_back(CreateConvolutionThunk(alloc0));

  auto root_thunk =
      std::make_unique<SequentialThunk>(Thunk::ThunkInfo(), std::move(thunks));
  DebugOptions debug_options;

  // Enable FUSION, CONVOLUTION and CUBLAS so that all three thunks are
  // eligible for command-buffer conversion.
  debug_options.clear_xla_gpu_enable_command_buffer();
  debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::FUSION);
  debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CONVOLUTION);
  debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUBLAS);

  se::DeviceDescription device_info = TestGpuDeviceInfo::CudaOrRocmDeviceInfo();
  FakeErrorAllocator allocator;

  ASSERT_EQ(root_thunk->thunks().size(), 3);

  // The pass reports `true` because it rewrote the thunk sequence.
  ASSERT_THAT(pass.Run(root_thunk.get(), debug_options, /*hlo_module=*/nullptr,
                       device_info, allocator),
              IsOkAndHolds(true));

  // All three thunks collapse into one CommandBufferThunk, order preserved.
  ASSERT_EQ(root_thunk->thunks().size(), 1);
  const auto* command_buffer_thunk =
      static_cast<const CommandBufferThunk*>(root_thunk->thunks()[0].get());
  const auto& thunks_in_command_buffer =
      command_buffer_thunk->thunks()->thunks();
  EXPECT_THAT(thunks_in_command_buffer,
              ThunkKindsAre(Thunk::kCopy, Thunk::kGemm, Thunk::kConvolution));
}

TEST(CommandBufferConversionPassTest, ConvertsAsyncPairToCommandBuffer) {
std::vector<std::unique_ptr<Thunk>> thunks;
// Create a start thunk
Expand Down
70 changes: 38 additions & 32 deletions xla/backends/gpu/runtime/convolution_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,76 +69,82 @@ ConvolutionThunk::ConvolutionThunk(
descriptor_(std::move(descriptor)),
config_(std::move(config)) {}

GenericConvRunner& ConvolutionThunk::GetOrCreateRunner(
const stream_executor::Stream* stream, bool* runner_created) {
std::pair< RunConvOptions, bool > ConvRunnerCache::GetOrCreate(
const GpuConvConfig& config, const se::Stream* stream) {
absl::MutexLock lock(mu_);
auto it = runner_cache_.find(stream);
*runner_created = (it == runner_cache_.end());
if (*runner_created) {
it = runner_cache_
.insert({stream, std::make_unique<GenericConvRunner>(config_)})
.first;
auto [it, inserted] = cache_.emplace(stream->parent(),
std::unique_ptr<GenericConvRunner>{});
if (inserted) {
it->second = std::make_unique<GenericConvRunner>(config);
}
return *it->second;
return std::pair{ RunConvOptions{ nullptr, it->second.get() }, inserted };
}

absl::Status ConvolutionThunk::ExecuteOnStream(const ExecuteParams& params) {
absl::Status RunConvolutionOnStream(const Thunk::ExecuteParams& params,
const std::vector<BufferAllocation::Slice>& operand_buffers,
const std::vector<BufferAllocation::Slice>& result_buffers,
const BufferAllocation::Slice& scratch_buffer,
const GpuConvConfig& config, ConvRunnerCache& cache, se::Stream* stream) {

const auto& buffer_allocations = *params.buffer_allocations;

std::vector<se::DeviceMemoryBase> operand_se_buffers, result_se_buffers;
operand_se_buffers.reserve(operand_buffers_.size());
for (BufferAllocation::Slice buffer : operand_buffers_) {
operand_se_buffers.reserve(operand_buffers.size());

for (BufferAllocation::Slice buffer : operand_buffers) {
operand_se_buffers.push_back(buffer_allocations.GetDeviceAddress(buffer));
VLOG(5) << "operand buffer: " << buffer.ToString()
<< " addr: " << operand_se_buffers.back().opaque();
}

result_se_buffers.reserve(result_buffers_.size());
for (BufferAllocation::Slice buffer : result_buffers_) {
result_se_buffers.reserve(result_buffers.size());
for (BufferAllocation::Slice buffer : result_buffers) {
result_se_buffers.push_back(buffer_allocations.GetDeviceAddress(buffer));
VLOG(5) << "result buffer: " << buffer.ToString()
<< " addr: " << result_se_buffers.back().opaque();
}

se::DeviceMemoryBase scratch =
buffer_allocations.GetDeviceAddress(scratch_buffer_);
buffer_allocations.GetDeviceAddress(scratch_buffer);
VLOG(5) << "scratch buffer: " << scratch_buffer
<< " addr: " << scratch.opaque();

bool runner_created = false;
RunConvOptions opts;
opts.runner_cache = &GetOrCreateRunner(params.stream, &runner_created);

if (runner_created && params.stream->parent()
auto [opts, runner_created] = cache.GetOrCreate(config, stream);
if (runner_created && stream->parent()
->GetDeviceDescription()
.gpu_compute_capability()
.IsRocm()) {
TF_ASSIGN_OR_RETURN(
GpuConvParams conv_params,
GetGpuConvParams(config_, operand_se_buffers, result_se_buffers));
GetGpuConvParams(config, operand_se_buffers, result_se_buffers));

TF_ASSIGN_OR_RETURN(se::dnn::DataType input_type,
GetDNNDataTypeFromPrimitiveType(config_.input_type));
GetDNNDataTypeFromPrimitiveType(config.input_type));

TF_ASSIGN_OR_RETURN(se::dnn::DataType output_type,
GetDNNDataTypeFromPrimitiveType(config_.output_type));
GetDNNDataTypeFromPrimitiveType(config.output_type));

TF_ASSIGN_OR_RETURN(auto dnn,
se::dnn::internal::GetDnnFromStream(params.stream));
se::dnn::internal::GetDnnFromStream(stream));
se::OwningScratchAllocator<> scratch_allocator(
buffer_allocations.device_ordinal(),
buffer_allocations.memory_allocator());

std::vector<se::dnn::ProfileResult> profile_results;
dnn->GetMIOpenConvolveAlgorithms(
CudnnConvKindToProto(config_.kind), input_type, output_type,
params.stream, config_.input_descriptor, conv_params.input_buf,
config_.filter_descriptor, conv_params.filter_buf,
config_.output_descriptor, conv_params.output_buf, config_.conv_desc,
CudnnConvKindToProto(config.kind), input_type, output_type,
stream, config.input_descriptor, conv_params.input_buf,
config.filter_descriptor, conv_params.filter_buf,
config.output_descriptor, conv_params.output_buf, config.conv_desc,
&scratch_allocator, &profile_results);
}

TF_RETURN_IF_ERROR(RunGpuConv(config_, absl::MakeSpan(operand_se_buffers),
TF_RETURN_IF_ERROR(RunGpuConv(config, absl::MakeSpan(operand_se_buffers),
absl::MakeSpan(result_se_buffers), scratch,
params.stream, opts));
stream, opts));

// Note: Convolution has a tuple buffer as an output, but we don't need to
// populate it as no one should be reading from the tuple directly.
if (!params.stream->ok()) {
if (!stream->ok()) {
return Internal("ConvolutionThunk::ExecuteOnStream failed.");
}
return absl::OkStatus();
Expand Down
Loading