@@ -409,31 +409,19 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) {
409
409
410
410
WriteTimestamp (num_pending_dispatches_ * 2 );
411
411
412
- uint32_t entry_index = 0 ;
413
- std::vector<wgpu::BindGroupEntry> bind_group_entries ;
412
+ std::vector<WGPUBuffer> bind_buffers ;
413
+ bind_buffers. reserve (inputs. size () + outputs. size () + (uniform_buffer ? 1 : 0 )) ;
414
414
for (const auto & input : inputs) {
415
- bind_group_entries .push_back ({ nullptr , entry_index++, reinterpret_cast <WGPUBuffer>(const_cast <void *>(input.tensor ->DataRaw ()))} );
415
+ bind_buffers .push_back (reinterpret_cast <WGPUBuffer>(const_cast <void *>(input.tensor ->DataRaw ())));
416
416
}
417
417
for (const auto & output : outputs) {
418
- bind_group_entries .push_back ({ nullptr , entry_index++, reinterpret_cast <WGPUBuffer>(output.tensor ->MutableDataRaw ())} );
418
+ bind_buffers .push_back (reinterpret_cast <WGPUBuffer>(output.tensor ->MutableDataRaw ()));
419
419
}
420
420
if (uniform_buffer) {
421
- bind_group_entries .push_back ({ nullptr , entry_index++, uniform_buffer} );
421
+ bind_buffers .push_back (uniform_buffer);
422
422
}
423
423
424
- wgpu::BindGroupDescriptor bind_group_desc{};
425
- bind_group_desc.layout = program_artifact->compute_pipeline .GetBindGroupLayout (0 );
426
- bind_group_desc.entryCount = bind_group_entries.size ();
427
- bind_group_desc.entries = bind_group_entries.data ();
428
- bind_group_desc.label = program_artifact->name .c_str ();
429
-
430
- auto bind_group = Device ().CreateBindGroup (&bind_group_desc);
431
-
432
- // TODO support graph capture
433
-
434
- compute_pass_encoder.SetPipeline (program_artifact->compute_pipeline );
435
- compute_pass_encoder.SetBindGroup (0 , bind_group);
436
- compute_pass_encoder.DispatchWorkgroups (x, y, z);
424
+ LaunchComputePipeline (compute_pass_encoder, bind_buffers, *program_artifact, x, y, z);
437
425
438
426
if (uniform_buffer) {
439
427
buffer_mgr_->Release (uniform_buffer);
@@ -708,6 +696,35 @@ void WebGpuContext::OnRunEnd() {
708
696
#endif // ENABLE_PIX_FOR_WEBGPU_EP
709
697
}
710
698
699
+ void WebGpuContext::LaunchComputePipeline (const wgpu::ComputePassEncoder& compute_pass_encoder,
700
+ const std::vector<WGPUBuffer>& bind_buffers,
701
+ const ProgramArtifact& program_artifact,
702
+ uint32_t x, uint32_t y, uint32_t z) {
703
+ uint32_t entry_index = 0 ;
704
+ std::vector<WGPUBindGroupEntry> bind_group_entries;
705
+ for (WGPUBuffer buffer : bind_buffers) {
706
+ bind_group_entries.push_back ({nullptr , entry_index++, buffer, 0 , WGPU_WHOLE_SIZE, nullptr , nullptr });
707
+ }
708
+
709
+ WGPUBindGroupLayout bind_group_layout = program_artifact.compute_pipeline .GetBindGroupLayout (0 ).MoveToCHandle ();
710
+ WGPUBindGroupDescriptor bind_group_desc{};
711
+ bind_group_desc.layout = bind_group_layout;
712
+ bind_group_desc.entryCount = bind_group_entries.size ();
713
+ bind_group_desc.entries = bind_group_entries.data ();
714
+ bind_group_desc.label = {program_artifact.name .data (), program_artifact.name .length ()};
715
+
716
+ auto bind_group = wgpuDeviceCreateBindGroup (Device ().Get (), &bind_group_desc);
717
+
718
+ // TODO support graph capture
719
+
720
+ compute_pass_encoder.SetPipeline (program_artifact.compute_pipeline );
721
+ wgpuComputePassEncoderSetBindGroup (compute_pass_encoder.Get (), 0 , bind_group, 0 , nullptr );
722
+ compute_pass_encoder.DispatchWorkgroups (x, y, z);
723
+
724
+ wgpuBindGroupRelease (bind_group);
725
+ wgpuBindGroupLayoutRelease (bind_group_layout);
726
+ }
727
+
711
728
std::unordered_map<int32_t , WebGpuContextFactory::WebGpuContextInfo> WebGpuContextFactory::contexts_;
712
729
std::mutex WebGpuContextFactory::mutex_;
713
730
std::once_flag WebGpuContextFactory::init_default_flag_;
0 commit comments