Skip to content

Commit

Permalink
Update cubecl (#2376)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored Oct 16, 2024
1 parent 3d77efc commit 353120e
Show file tree
Hide file tree
Showing 26 changed files with 339 additions and 457 deletions.
603 changes: 263 additions & 340 deletions Cargo.lock

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ zip = "2.2.0"

# Async handling
async-channel = "2.3"
pollster = "0.3"
futures-lite = { version = "2.3.0", default-features = false }

# Terminal UI
crossterm = "0.27.0"
Expand Down Expand Up @@ -152,11 +152,11 @@ tch = "0.15.0"
portable-atomic-util = { version = "0.2.2", features = ["alloc"] }

### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "fb47090f7a44952ae3e3b2b72f8c5a88d8af56fd" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "fb47090f7a44952ae3e3b2b72f8c5a88d8af56fd" }
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ed136e2385b17e36680589f8a6245926f430f59f" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ed136e2385b17e36680589f8a6245926f430f59f" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl" }
# cubecl-common = { path = "../cubecl/crates/cubecl-common" }
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
### For the release. ###
# cubecl = { version="0.2.0", default-features = false }
# cubecl-common = { version="0.2.0", default-features = false }
Expand Down
1 change: 1 addition & 0 deletions backend-comparison/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ candle-cpu = ["burn/candle"]
candle-cuda = ["burn/candle-cuda"]
candle-metal = ["burn/candle", "burn/metal"]
cuda-jit = ["burn/cuda-jit"]
cuda-jit-fusion = ["cuda-jit", "burn/fusion"]
default = ["burn/std", "burn/autodiff", "burn/wgpu", "burn/autotune"]
ndarray = ["burn/ndarray"]
ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"]
Expand Down
7 changes: 2 additions & 5 deletions backend-comparison/benches/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@ use burn::{
Distribution, Tensor,
},
};
use burn_common::{
benchmark::{run_benchmark, Benchmark},
sync_type::SyncType,
};
use burn_common::benchmark::{run_benchmark, Benchmark};

pub struct AutodiffOverheadBenchmark<B: AutodiffBackend> {
config: nn::LstmConfig,
Expand Down Expand Up @@ -50,7 +47,7 @@ impl<B: AutodiffBackend> Benchmark for AutodiffOverheadBenchmark<B> {
}

fn sync(&self) {
B::sync(&self.device, SyncType::Wait)
B::sync(&self.device)
}
}

Expand Down
7 changes: 2 additions & 5 deletions backend-comparison/benches/binary.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
use backend_comparison::persistence::save;
use burn::tensor::{backend::Backend, Distribution, Shape, Tensor};
use burn_common::{
benchmark::{run_benchmark, Benchmark},
sync_type::SyncType,
};
use burn_common::benchmark::{run_benchmark, Benchmark};

pub struct BinaryBenchmark<B: Backend, const D: usize> {
shape: Shape,
Expand Down Expand Up @@ -33,7 +30,7 @@ impl<B: Backend, const D: usize> Benchmark for BinaryBenchmark<B, D> {
}

fn sync(&self) {
B::sync(&self.device, SyncType::Wait);
B::sync(&self.device);
}
}

Expand Down
7 changes: 2 additions & 5 deletions backend-comparison/benches/conv2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@ use backend_comparison::persistence::save;
use burn::tensor::{
backend::Backend, module::conv2d, ops::ConvOptions, Distribution, Shape, Tensor,
};
use burn_common::{
benchmark::{run_benchmark, Benchmark},
sync_type::SyncType,
};
use burn_common::benchmark::{run_benchmark, Benchmark};

pub struct Conv2dBenchmark<B: Backend> {
input_shape: Shape,
Expand Down Expand Up @@ -51,7 +48,7 @@ impl<B: Backend> Benchmark for Conv2dBenchmark<B> {
}

fn sync(&self) {
B::sync(&self.device, SyncType::Wait)
B::sync(&self.device)
}
}

Expand Down
7 changes: 2 additions & 5 deletions backend-comparison/benches/conv3d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@ use backend_comparison::persistence::save;
use burn::tensor::{
backend::Backend, module::conv3d, ops::ConvOptions, Distribution, Shape, Tensor,
};
use burn_common::{
benchmark::{run_benchmark, Benchmark},
sync_type::SyncType,
};
use burn_common::benchmark::{run_benchmark, Benchmark};

pub struct Conv3dBenchmark<B: Backend> {
input_shape: Shape,
Expand Down Expand Up @@ -51,7 +48,7 @@ impl<B: Backend> Benchmark for Conv3dBenchmark<B> {
}

fn sync(&self) {
B::sync(&self.device, SyncType::Wait)
B::sync(&self.device)
}
}

Expand Down
7 changes: 2 additions & 5 deletions backend-comparison/benches/conv_transpose2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@ use burn::tensor::{
backend::Backend, module::conv_transpose2d, ops::ConvTransposeOptions, Distribution, Shape,
Tensor,
};
use burn_common::{
benchmark::{run_benchmark, Benchmark},
sync_type::SyncType,
};
use burn_common::benchmark::{run_benchmark, Benchmark};

pub struct ConvTranspose2dBenchmark<B: Backend> {
input_shape: Shape,
Expand Down Expand Up @@ -52,7 +49,7 @@ impl<B: Backend> Benchmark for ConvTranspose2dBenchmark<B> {
}

fn sync(&self) {
B::sync(&self.device, SyncType::Wait)
B::sync(&self.device)
}
}

Expand Down
7 changes: 2 additions & 5 deletions backend-comparison/benches/conv_transpose3d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@ use burn::tensor::{
backend::Backend, module::conv_transpose3d, ops::ConvTransposeOptions, Distribution, Shape,
Tensor,
};
use burn_common::{
benchmark::{run_benchmark, Benchmark},
sync_type::SyncType,
};
use burn_common::benchmark::{run_benchmark, Benchmark};

pub struct ConvTranspose3dBenchmark<B: Backend> {
input_shape: Shape,
Expand Down Expand Up @@ -52,7 +49,7 @@ impl<B: Backend> Benchmark for ConvTranspose3dBenchmark<B> {
}

fn sync(&self) {
B::sync(&self.device, SyncType::Wait)
B::sync(&self.device)
}
}

Expand Down
3 changes: 1 addition & 2 deletions backend-comparison/benches/custom_gelu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use backend_comparison::persistence::save;
use burn::backend::Autodiff;
use burn::tensor::{backend::Backend, Distribution, Shape, Tensor};
use burn_common::benchmark::{run_benchmark, Benchmark};
use burn_common::sync_type::SyncType;
use core::f64::consts::SQRT_2;
use derive_new::new;

Expand Down Expand Up @@ -69,7 +68,7 @@ impl<B: Backend, const D: usize> Benchmark for CustomGeluBenchmark<B, D> {
}

fn sync(&self) {
B::sync(&self.device, SyncType::Wait)
B::sync(&self.device)
}

fn num_samples(&self) -> usize {
Expand Down
9 changes: 3 additions & 6 deletions backend-comparison/benches/data.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
use backend_comparison::persistence::save;
use burn::tensor::{backend::Backend, Distribution, Shape, Tensor, TensorData};
use burn_common::{
benchmark::{run_benchmark, Benchmark},
sync_type::SyncType,
};
use burn_common::benchmark::{run_benchmark, Benchmark};
use derive_new::new;

#[derive(new)]
Expand Down Expand Up @@ -32,7 +29,7 @@ impl<B: Backend, const D: usize> Benchmark for ToDataBenchmark<B, D> {
}

fn sync(&self) {
B::sync(&self.device, SyncType::Wait)
B::sync(&self.device)
}
}

Expand Down Expand Up @@ -69,7 +66,7 @@ impl<B: Backend, const D: usize> Benchmark for FromDataBenchmark<B, D> {
}

fn sync(&self) {
B::sync(&self.device, SyncType::Wait)
B::sync(&self.device)
}
}

Expand Down
3 changes: 1 addition & 2 deletions backend-comparison/benches/load_record.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use burn::tensor::backend::Backend;
use burn::tensor::Device;
use burn::{config::Config, module::Module, nn};
use burn_common::benchmark::{run_benchmark, Benchmark};
use burn_common::sync_type::SyncType;
use derive_new::new;

#[derive(Module, Debug)]
Expand Down Expand Up @@ -94,7 +93,7 @@ impl<B: Backend> Benchmark for LoadRecordBenchmark<B> {
}

fn sync(&self) {
B::sync(&self.device, SyncType::Wait)
B::sync(&self.device)
}
}

Expand Down
7 changes: 2 additions & 5 deletions backend-comparison/benches/matmul.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
use backend_comparison::persistence::save;
use burn::tensor::{backend::Backend, Distribution, Shape, Tensor};
use burn_common::{
benchmark::{run_benchmark, Benchmark},
sync_type::SyncType,
};
use burn_common::benchmark::{run_benchmark, Benchmark};
use derive_new::new;

#[derive(new)]
Expand Down Expand Up @@ -40,7 +37,7 @@ impl<B: Backend, const D: usize> Benchmark for MatmulBenchmark<B, D> {
}

fn sync(&self) {
B::sync(&self.device, SyncType::Wait)
B::sync(&self.device)
}
}

Expand Down
7 changes: 2 additions & 5 deletions backend-comparison/benches/max_pool2d.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
use backend_comparison::persistence::save;
use burn::tensor::{backend::Backend, module::max_pool2d, Distribution, Shape, Tensor};
use burn_common::{
benchmark::{run_benchmark, Benchmark},
sync_type::SyncType,
};
use burn_common::benchmark::{run_benchmark, Benchmark};

pub struct MaxPool2dBenchmark<B: Backend> {
shape: Shape,
Expand Down Expand Up @@ -40,7 +37,7 @@ impl<B: Backend> Benchmark for MaxPool2dBenchmark<B> {
}

fn sync(&self) {
B::sync(&self.device, SyncType::Wait)
B::sync(&self.device)
}
}

Expand Down
3 changes: 1 addition & 2 deletions backend-comparison/benches/resnet.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use backend_comparison::persistence::save;
use burn::tensor::{backend::Backend, Distribution, Shape, Tensor};
use burn_common::benchmark::{run_benchmark, Benchmark};
use cubecl::client::SyncType;

// Files retrieved during build to avoid reimplementing ResNet for benchmarks
mod block {
Expand Down Expand Up @@ -42,7 +41,7 @@ impl<B: Backend> Benchmark for ResNetBenchmark<B> {
}

fn sync(&self) {
B::sync(&self.device, SyncType::Wait)
B::sync(&self.device)
}
}

Expand Down
7 changes: 2 additions & 5 deletions backend-comparison/benches/unary.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
use backend_comparison::persistence::save;
use burn::tensor::{backend::Backend, Distribution, Shape, Tensor};
use burn_common::{
benchmark::{run_benchmark, Benchmark},
sync_type::SyncType,
};
use burn_common::benchmark::{run_benchmark, Benchmark};
use derive_new::new;

#[derive(new)]
Expand Down Expand Up @@ -33,7 +30,7 @@ impl<B: Backend, const D: usize> Benchmark for UnaryBenchmark<B, D> {
}

fn sync(&self) {
B::sync(&self.device, SyncType::Wait)
B::sync(&self.device)
}
}

Expand Down
2 changes: 2 additions & 0 deletions backend-comparison/src/burnbenchapp/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ enum BackendValues {
WgpuFusion,
#[strum(to_string = "cuda-jit")]
CudaJit,
#[strum(to_string = "cuda-jit-fusion")]
CudaJitFusion,
}

#[derive(Debug, Clone, PartialEq, Eq, ValueEnum, Display, EnumIter)]
Expand Down
2 changes: 2 additions & 0 deletions backend-comparison/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ macro_rules! bench_on_backend {
let feature_name = "wgpu-fusion";
#[cfg(feature = "cuda-jit")]
let feature_name = "cuda-jit";
#[cfg(feature = "cuda-jit-fusion")]
let feature_name = "cuda-jit-fusion";

#[cfg(feature = "wgpu")]
{
Expand Down
5 changes: 2 additions & 3 deletions crates/burn-autodiff/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use crate::{
tensor::AutodiffTensor,
AutodiffBridge,
};
use burn_common::sync_type::SyncType;
use burn_tensor::{
backend::{AutodiffBackend, Backend},
ops::{BoolTensor, IntTensor, QuantizedTensor},
Expand Down Expand Up @@ -50,8 +49,8 @@ impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
B::seed(seed)
}

fn sync(device: &B::Device, sync_type: SyncType) {
B::sync(device, sync_type)
fn sync(device: &B::Device) {
B::sync(device)
}
}

Expand Down
35 changes: 15 additions & 20 deletions crates/burn-candle/src/backend.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::marker::PhantomData;

use burn_tensor::{
backend::{Backend, DeviceId, DeviceOps, SyncType},
backend::{Backend, DeviceId, DeviceOps},
quantization::{QTensorPrimitive, QuantizationStrategy},
Device,
};
Expand Down Expand Up @@ -187,25 +187,20 @@ impl<F: FloatCandleElement, I: IntCandleElement> Backend for Candle<F, I> {
panic!("Manual seed not supported by Candle. ")
}

fn sync(device: &Device<Self>, sync_type: SyncType) {
match sync_type {
SyncType::Wait => {
let device: candle_core::Device = (device.clone()).into();

match device {
candle_core::Device::Cpu => (),
candle_core::Device::Cuda(device) => {
#[cfg(feature = "cuda")]
device.synchronize().unwrap();
}
candle_core::Device::Metal(device) => {
// For some reason, device.wait_until_completed() does not seem to work,
// and neither does writing and reading a value with into_data
panic!("Device synchronization unavailable with Metal device on Candle backend")
}
}
fn sync(device: &Device<Self>) {
let device: candle_core::Device = (device.clone()).into();

match device {
candle_core::Device::Cpu => (),
candle_core::Device::Cuda(device) => {
#[cfg(feature = "cuda")]
device.synchronize().unwrap();
}
SyncType::Flush => (), // Nothhing to flush.
};
candle_core::Device::Metal(device) => {
// For some reason, device.wait_until_completed() does not seem to work,
// and neither does writing and reading a value with into_data
panic!("Device synchronization unavailable with Metal device on Candle backend")
}
}
}
}
Loading

0 comments on commit 353120e

Please sign in to comment.