Skip to content

Commit cda0d14

Browse files
authored
[webgpu][dawn API optimization] workgroup dispatch (microsoft#24329)
### Description This PR is one of a series of changes for optimization of Dawn API usage. See microsoft#24281 Optimize the code for workgroup dispatch in the `WebGpuContext` class. The updated code prefers using the C-API instead of the C++ API for WebGPU. This is because the C++ API uses class `wgpu::Buffer`, which causes significant amount of calls to `wgpuBufferAddRef` and `wgpuBufferRelease` to ensure the lifecycle of the buffer is managed correctly. For this specific use case in ONNX Runtime (launch a compute shader program), using the C-API is more efficient.
1 parent 4dc0e35 commit cda0d14

File tree

2 files changed

+40
-18
lines changed

2 files changed

+40
-18
lines changed

onnxruntime/core/providers/webgpu/webgpu_context.cc

+35-18
Original file line numberDiff line numberDiff line change
@@ -409,31 +409,19 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) {
409409

410410
WriteTimestamp(num_pending_dispatches_ * 2);
411411

412-
uint32_t entry_index = 0;
413-
std::vector<wgpu::BindGroupEntry> bind_group_entries;
412+
std::vector<WGPUBuffer> bind_buffers;
413+
bind_buffers.reserve(inputs.size() + outputs.size() + (uniform_buffer ? 1 : 0));
414414
for (const auto& input : inputs) {
415-
bind_group_entries.push_back({nullptr, entry_index++, reinterpret_cast<WGPUBuffer>(const_cast<void*>(input.tensor->DataRaw()))});
415+
bind_buffers.push_back(reinterpret_cast<WGPUBuffer>(const_cast<void*>(input.tensor->DataRaw())));
416416
}
417417
for (const auto& output : outputs) {
418-
bind_group_entries.push_back({nullptr, entry_index++, reinterpret_cast<WGPUBuffer>(output.tensor->MutableDataRaw())});
418+
bind_buffers.push_back(reinterpret_cast<WGPUBuffer>(output.tensor->MutableDataRaw()));
419419
}
420420
if (uniform_buffer) {
421-
bind_group_entries.push_back({nullptr, entry_index++, uniform_buffer});
421+
bind_buffers.push_back(uniform_buffer);
422422
}
423423

424-
wgpu::BindGroupDescriptor bind_group_desc{};
425-
bind_group_desc.layout = program_artifact->compute_pipeline.GetBindGroupLayout(0);
426-
bind_group_desc.entryCount = bind_group_entries.size();
427-
bind_group_desc.entries = bind_group_entries.data();
428-
bind_group_desc.label = program_artifact->name.c_str();
429-
430-
auto bind_group = Device().CreateBindGroup(&bind_group_desc);
431-
432-
// TODO support graph capture
433-
434-
compute_pass_encoder.SetPipeline(program_artifact->compute_pipeline);
435-
compute_pass_encoder.SetBindGroup(0, bind_group);
436-
compute_pass_encoder.DispatchWorkgroups(x, y, z);
424+
LaunchComputePipeline(compute_pass_encoder, bind_buffers, *program_artifact, x, y, z);
437425

438426
if (uniform_buffer) {
439427
buffer_mgr_->Release(uniform_buffer);
@@ -708,6 +696,35 @@ void WebGpuContext::OnRunEnd() {
708696
#endif // ENABLE_PIX_FOR_WEBGPU_EP
709697
}
710698

699+
void WebGpuContext::LaunchComputePipeline(const wgpu::ComputePassEncoder& compute_pass_encoder,
700+
const std::vector<WGPUBuffer>& bind_buffers,
701+
const ProgramArtifact& program_artifact,
702+
uint32_t x, uint32_t y, uint32_t z) {
703+
uint32_t entry_index = 0;
704+
std::vector<WGPUBindGroupEntry> bind_group_entries;
705+
for (WGPUBuffer buffer : bind_buffers) {
706+
bind_group_entries.push_back({nullptr, entry_index++, buffer, 0, WGPU_WHOLE_SIZE, nullptr, nullptr});
707+
}
708+
709+
WGPUBindGroupLayout bind_group_layout = program_artifact.compute_pipeline.GetBindGroupLayout(0).MoveToCHandle();
710+
WGPUBindGroupDescriptor bind_group_desc{};
711+
bind_group_desc.layout = bind_group_layout;
712+
bind_group_desc.entryCount = bind_group_entries.size();
713+
bind_group_desc.entries = bind_group_entries.data();
714+
bind_group_desc.label = {program_artifact.name.data(), program_artifact.name.length()};
715+
716+
auto bind_group = wgpuDeviceCreateBindGroup(Device().Get(), &bind_group_desc);
717+
718+
// TODO support graph capture
719+
720+
compute_pass_encoder.SetPipeline(program_artifact.compute_pipeline);
721+
wgpuComputePassEncoderSetBindGroup(compute_pass_encoder.Get(), 0, bind_group, 0, nullptr);
722+
compute_pass_encoder.DispatchWorkgroups(x, y, z);
723+
724+
wgpuBindGroupRelease(bind_group);
725+
wgpuBindGroupLayoutRelease(bind_group_layout);
726+
}
727+
711728
std::unordered_map<int32_t, WebGpuContextFactory::WebGpuContextInfo> WebGpuContextFactory::contexts_;
712729
std::mutex WebGpuContextFactory::mutex_;
713730
std::once_flag WebGpuContextFactory::init_default_flag_;

onnxruntime/core/providers/webgpu/webgpu_context.h

+5
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,11 @@ class WebGpuContext final {
156156
: instance_{instance}, device_{device}, validation_mode_{validation_mode}, query_type_{TimestampQueryType::None}, preserve_device_{preserve_device} {}
157157
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WebGpuContext);
158158

159+
void LaunchComputePipeline(const wgpu::ComputePassEncoder& compute_pass_encoder,
160+
const std::vector<WGPUBuffer>& bind_buffers,
161+
const ProgramArtifact& program_artifact,
162+
uint32_t x, uint32_t y, uint32_t z);
163+
159164
std::vector<const char*> GetEnabledAdapterToggles() const;
160165
std::vector<const char*> GetEnabledDeviceToggles() const;
161166
std::vector<const char*> GetDisabledDeviceToggles() const;

0 commit comments

Comments
 (0)