
Commit 706a2f0

Support Nvidia-Cuda execution provider for wasi-nn onnx backend (#12044)
* Support Nvidia-Cuda execution provider for wasi-nn onnx backend
* Update about ONNX Runtime's fallback behavior
1 parent 91bf501 commit 706a2f0


5 files changed (+163, -9 lines)


Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default.

crates/wasi-nn/Cargo.toml

Lines changed: 4 additions & 0 deletions
@@ -70,6 +70,10 @@ openvino = ["dep:openvino"]
 onnx = ["dep:ort"]
 # Use prebuilt ONNX Runtime binaries from ort.
 onnx-download = ["onnx", "ort/download-binaries"]
+# CUDA execution provider for NVIDIA GPU support (requires CUDA toolkit)
+onnx-cuda = ["onnx", "ort/cuda"]
+# Enable tracing for ONNX Runtime
+ort-tracing = ["onnx", "ort/tracing"]
 # WinML is only available on Windows 10 1809 and later.
 winml = ["dep:windows"]
 # PyTorch is available on all platforms; requires Libtorch to be installed
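
Both new features imply `onnx` and simply forward to the corresponding `ort` features; the backend then selects the CUDA provider behind a compile-time gate. Below is a minimal sketch of that gating pattern with a hypothetical helper name (`select_providers` is for illustration only; the real logic is `configure_execution_providers` in `crates/wasi-nn/src/backend/onnx.rs`, shown further down):

```rust
use ort::execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch};
#[cfg(feature = "onnx-cuda")]
use ort::execution_providers::CUDAExecutionProvider;

// Hypothetical helper illustrating how the `onnx-cuda` feature gates the CUDA
// provider at compile time; without the feature, GPU requests resolve to the
// CPU provider.
fn select_providers(use_gpu: bool) -> Vec<ExecutionProviderDispatch> {
    if use_gpu {
        // Compiled only when `onnx-cuda` (and therefore `ort/cuda`) is enabled.
        #[cfg(feature = "onnx-cuda")]
        return vec![CUDAExecutionProvider::default().build()];
    }
    // Default and fallback path: the CPU execution provider.
    vec![CPUExecutionProvider::default().build()]
}

fn main() {
    let providers = select_providers(true);
    println!("selected {} execution provider(s)", providers.len());
}
```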

crates/wasi-nn/examples/classification-component-onnx/README.md

Lines changed: 83 additions & 7 deletions
@@ -3,35 +3,111 @@
 This example demonstrates how to use the `wasi-nn` crate to run a classification using the
 [ONNX Runtime](https://onnxruntime.ai/) backend from a WebAssembly component.
 
+It supports CPU and GPU (Nvidia CUDA) execution targets.
+
+**Note:**
+The GPU execution target currently only supports Nvidia CUDA (`onnx-cuda`) as the execution provider (EP).
+
 ## Build
+
 In this directory, run the following command to build the WebAssembly component:
 ```console
+# build component for target wasm32-wasip1
 cargo component build
+
+# build component for target wasm32-wasip2
+cargo component build --target wasm32-wasip2
 ```
 
+## Running the Example
+
 In the Wasmtime root directory, run the following command to build the Wasmtime CLI and run the WebAssembly component:
+
+### Building Wasmtime
+
+#### For CPU-only execution:
 ```sh
-# build wasmtime with component-model and WASI-NN with ONNX runtime support
 cargo build --features component-model,wasi-nn,wasmtime-wasi-nn/onnx-download
+```
+
+#### For GPU (Nvidia CUDA) support:
+```sh
+cargo build --features component-model,wasi-nn,wasmtime-wasi-nn/onnx-cuda,wasmtime-wasi-nn/onnx-download
+```
+
+### Running with Different Execution Targets
+
+The execution target is controlled by passing a single argument to the WASM module.
+
+Arguments:
+- No argument or `cpu` - Use CPU execution
+- `gpu` or `cuda` - Use GPU/CUDA execution
 
-# run the component with wasmtime
+#### CPU Execution (default):
+```sh
 ./target/debug/wasmtime run \
     -Snn \
     --dir ./crates/wasi-nn/examples/classification-component-onnx/fixture/::fixture \
-    ./crates/wasi-nn/examples/classification-component-onnx/target/wasm32-wasip1/debug/classification-component-onnx.wasm
+    ./crates/wasi-nn/examples/classification-component-onnx/target/wasm32-wasip2/debug/classification-component-onnx.wasm
 ```
 
-You should get the following output:
+#### GPU (CUDA) Execution:
+```sh
+# path to `libonnxruntime_providers_cuda.so` downloaded by `ort-sys`
+export LD_LIBRARY_PATH={wasmtime_workspace}/target/debug
+
+./target/debug/wasmtime run \
+    -Snn \
+    --dir ./crates/wasi-nn/examples/classification-component-onnx/fixture/::fixture \
+    ./crates/wasi-nn/examples/classification-component-onnx/target/wasm32-wasip2/debug/classification-component-onnx.wasm \
+    gpu
+
+```
+
+## Expected Output
+
+You should get output similar to:
 ```txt
+No execution target specified, defaulting to CPU
 Read ONNX model, size in bytes: 4956208
-Loaded graph into wasi-nn
+Loaded graph into wasi-nn with Cpu target
 Created wasi-nn execution context.
 Read ONNX Labels, # of labels: 1000
-Set input tensor
 Executed graph inference
-Getting inferencing output
 Retrieved output data with length: 4000
 Index: n02099601 golden retriever - Probability: 0.9948673
 Index: n02088094 Afghan hound, Afghan - Probability: 0.002528982
 Index: n02102318 cocker spaniel, English cocker spaniel, cocker - Probability: 0.0010986356
 ```
+
+When using the GPU target, the first line will indicate the selected execution target.
+You can monitor GPU usage with `watch -n 1 nvidia-smi`.
+
+To see trace logs from `wasmtime_wasi_nn` or `ort`, run Wasmtime with `WASMTIME_LOG` enabled, e.g.:
+
+```sh
+WASMTIME_LOG=wasmtime_wasi_nn=warn ./target/debug/wasmtime run ...
+WASMTIME_LOG=ort=warn ./target/debug/wasmtime run ...
+```
+
+## Prerequisites for GPU (CUDA) Support
+- NVIDIA GPU with CUDA support
+- CUDA Toolkit 12.x with cuDNN 9.x
+- Build Wasmtime with the `wasmtime-wasi-nn/onnx-cuda` feature
+
+## ONNX Runtime's Fallback Behavior
+
+If the GPU execution provider is requested (by passing `gpu`) but the device does not have a GPU or the necessary CUDA drivers are missing, ONNX Runtime will **silently fall back** to the CPU execution provider. The application will continue to run, but inference will happen on the CPU.
+
+To verify whether fallback is happening, you can enable ONNX Runtime logging:
+
+1. Build Wasmtime with the additional `wasmtime-wasi-nn/ort-tracing` feature:
+   ```sh
+   cargo build --features component-model,wasi-nn,wasmtime-wasi-nn/onnx-cuda,wasmtime-wasi-nn/ort-tracing
+   ```
+
+2. Run Wasmtime with `WASMTIME_LOG` enabled to see `ort` warnings:
+   ```sh
+   WASMTIME_LOG=ort=warn ./target/debug/wasmtime run ...
+   ```
+You should see a warning like: `No execution providers from session options registered successfully; may fall back to CPU.`
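
Beyond inspecting logs, a host embedding could probe provider availability up front and surface the fallback explicitly. Here is a minimal sketch, not part of this commit, assuming ort's `ExecutionProvider` trait exposes an `is_available()` check:

```rust
use ort::execution_providers::{CUDAExecutionProvider, ExecutionProvider};

fn main() {
    // Probe whether the CUDA provider can actually be registered on this
    // machine (GPU present, CUDA/cuDNN libraries resolvable). `is_available()`
    // is assumed from ort's ExecutionProvider trait; treat this as illustrative.
    let cuda_usable = CUDAExecutionProvider::default()
        .is_available()
        .unwrap_or(false);

    if cuda_usable {
        println!("CUDA execution provider is available; GPU inference should not fall back");
    } else {
        println!("CUDA unavailable; ONNX Runtime will fall back to the CPU provider");
    }
}
```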

crates/wasi-nn/examples/classification-component-onnx/src/main.rs

Lines changed: 34 additions & 2 deletions
@@ -17,14 +17,46 @@ use self::wasi::nn::{
     tensor::{Tensor, TensorData, TensorDimensions, TensorType},
 };
 
+/// Determine execution target from command-line argument
+/// Usage: wasm_module [cpu|gpu|cuda]
+fn get_execution_target() -> ExecutionTarget {
+    let args: Vec<String> = std::env::args().collect();
+
+    // First argument (index 0) is the program name, second (index 1) is the target
+    // Ignore any arguments after index 1
+    if args.len() >= 2 {
+        match args[1].to_lowercase().as_str() {
+            "gpu" | "cuda" => {
+                println!("Using GPU (CUDA) execution target from argument");
+                return ExecutionTarget::Gpu;
+            }
+            "cpu" => {
+                println!("Using CPU execution target from argument");
+                return ExecutionTarget::Cpu;
+            }
+            _ => {
+                println!("Unknown execution target '{}', defaulting to CPU", args[1]);
+            }
+        }
+    } else {
+        println!("No execution target specified, defaulting to CPU");
+        println!("Usage: <program> [cpu|gpu|cuda]");
+    }
+
+    ExecutionTarget::Cpu
+}
+
 fn main() {
     // Load the ONNX model - SqueezeNet 1.1-7
     // Full details: https://github.com/onnx/models/tree/main/vision/classification/squeezenet
     let model: GraphBuilder = fs::read("fixture/models/squeezenet1.1-7.onnx").unwrap();
     println!("Read ONNX model, size in bytes: {}", model.len());
 
-    let graph = load(&[model], GraphEncoding::Onnx, ExecutionTarget::Cpu).unwrap();
-    println!("Loaded graph into wasi-nn");
+    // Determine execution target
+    let execution_target = get_execution_target();
+
+    let graph = load(&[model], GraphEncoding::Onnx, execution_target).unwrap();
+    println!("Loaded graph into wasi-nn with {:?} target", execution_target);
 
     let exec_context = Graph::init_execution_context(&graph).unwrap();
     println!("Created wasi-nn execution context.");

crates/wasi-nn/src/backend/onnx.rs

Lines changed: 41 additions & 0 deletions
@@ -7,12 +7,17 @@ use crate::backend::{Id, read};
 use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor, TensorType};
 use crate::{ExecutionContext, Graph};
 use ort::{
+    execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch},
     inputs,
     session::{Input, Output},
     session::{Session, SessionInputValue, builder::GraphOptimizationLevel},
     tensor::TensorElementType,
     value::{Tensor as OrtTensor, ValueType},
 };
+
+#[cfg(feature = "onnx-cuda")]
+use ort::execution_providers::CUDAExecutionProvider;
+
 use std::path::Path;
 use std::sync::{Arc, Mutex};
 
@@ -31,7 +36,11 @@ impl BackendInner for OnnxBackend {
             return Err(BackendError::InvalidNumberOfBuilders(1, builders.len()));
         }
 
+        // Configure execution providers based on target
+        let execution_providers = configure_execution_providers(target)?;
+
         let session = Session::builder()?
+            .with_execution_providers(execution_providers)?
            .with_optimization_level(GraphOptimizationLevel::Level3)?
            .commit_from_memory(builders[0])?;
 
@@ -45,6 +54,38 @@ impl BackendInner for OnnxBackend {
     }
 }
 
+/// Configure execution providers based on the target
+fn configure_execution_providers(
+    target: ExecutionTarget,
+) -> Result<Vec<ExecutionProviderDispatch>, BackendError> {
+    match target {
+        ExecutionTarget::Cpu => {
+            // Use CPU execution provider with default configuration
+            tracing::debug!("Using CPU execution provider");
+            Ok(vec![CPUExecutionProvider::default().build()])
+        }
+        ExecutionTarget::Gpu => {
+            #[cfg(feature = "onnx-cuda")]
+            {
+                // Use CUDA execution provider for GPU acceleration
+                tracing::debug!("Using Nvidia GPU CUDA execution provider");
+                Ok(vec![CUDAExecutionProvider::default().build()])
+            }
+            #[cfg(not(feature = "onnx-cuda"))]
+            {
+                tracing::warn!("GPU CUDA execution provider is not enabled, falling back to CPU");
+                Ok(vec![CPUExecutionProvider::default().build()])
+            }
+        }
+        ExecutionTarget::Tpu => {
+            tracing::warn!(
+                "TPU execution target is not supported for ONNX backend yet, falling back to CPU"
+            );
+            Ok(vec![CPUExecutionProvider::default().build()])
+        }
+    }
+}
+
 impl BackendFromDir for OnnxBackend {
     fn load_from_dir(
         &mut self,
