
Commit e68b9ab

Refactor/jit cube/mask (#2075)

louisfd authored
Co-authored-by: louisfd <[email protected]>
1 parent 47d4139 commit e68b9ab

File tree

16 files changed: +280 −491 lines

Cargo.lock

Lines changed: 8 additions & 8 deletions
Some generated files are not rendered by default.

Cargo.toml

Lines changed: 4 additions & 2 deletions
@@ -67,7 +67,9 @@ rstest = "0.19.0"
 rusqlite = { version = "0.31.0" }
 rust-format = { version = "0.3.4" }
 sanitize-filename = "0.5.0"
-serde_bytes = { version = "0.11.15", default-features = false, features = ["alloc"] } # alloc for no_std
+serde_bytes = { version = "0.11.15", default-features = false, features = [
+    "alloc",
+] } # alloc for no_std
 serde_rusqlite = "0.35.0"
 serial_test = "3.1.1"
 spin = { version = "0.9.8", features = ["mutex", "spin_mutex"] }
@@ -148,5 +150,5 @@ cubecl-common = { tag = "v0.1.1", git = "https://github.com/tracel-ai/cubecl", d
 # cubecl-common = { path = "../cubecl/crates/cubecl-common" }
 
 [profile.dev]
-debug = 0 # Speed up compilation time and not necessary.
+debug = 0 # Speed up compilation time and not necessary.
 opt-level = 2

crates/burn-jit/src/fusion/elemwise/kernel.rs

Lines changed: 6 additions & 7 deletions
@@ -30,11 +30,10 @@ impl<R: JitRuntime> FusionKernelFactory<R> for ElementWiseKernelFactory<R> {
         outputs: &[&TensorDescription],
         stateful: bool,
     ) -> FusionKernel<R> {
-        let cube_dim_x = self.cube_dim.x;
-        let cube_dim_y = self.cube_dim.y;
-
-        assert_eq!(cube_dim_x, cube_dim_y, "The grid must be a square");
-        let cube_dim = cube_dim_x as usize;
+        assert_eq!(
+            self.cube_dim.x, self.cube_dim.y,
+            "The grid must be a square"
+        );
 
         let vectorize_4 = can_vectorize(handles_inputs, inputs, outputs, 4);
         let vectorize_2 = can_vectorize(handles_inputs, inputs, outputs, 2);
@@ -69,7 +68,7 @@ impl<R: JitRuntime> FusionKernelFactory<R> for ElementWiseKernelFactory<R> {
 
                 let reference_tensor = inputs[settings.mappings[0].pos_input];
                 let num_elems = calculate_num_elems_dyn_rank(&reference_tensor.shape);
-                let cube_count = calculate_cube_count_elemwise(num_elems / factor, cube_dim);
+                let cube_count = calculate_cube_count_elemwise(num_elems / factor, self.cube_dim);
                 let output_infos =
                     inplace_output2input
                         .iter()
@@ -96,7 +95,7 @@ impl<R: JitRuntime> FusionKernelFactory<R> for ElementWiseKernelFactory<R> {
             false => {
                 let reference_tensor = outputs[0];
                 let num_elems = calculate_num_elems_dyn_rank(&reference_tensor.shape);
-                let cube_count = calculate_cube_count_elemwise(num_elems / factor, cube_dim);
+                let cube_count = calculate_cube_count_elemwise(num_elems / factor, self.cube_dim);
                 let output_infos = outputs.iter().enumerate().map(|(pos, tensor)| {
                     let size = calculate_num_elems_dyn_rank(&tensor.shape)
                         * self.info.outputs[pos].elem_size::<R>();
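Note: the signature change visible here runs through the whole commit. `calculate_cube_count_elemwise` now takes the `CubeDim` that will actually be launched, instead of a bare `usize` working-group size (the old `SUBCUBE_DIM_APPROX` constant). A minimal sketch of the sizing rule, assuming cubecl's usual ceiling-division semantics; the helper below is illustrative only, not cubecl's implementation (the real API returns a `CubeCount`):

use cubecl::CubeDim;

// Illustrative stand-in: ceiling-divide the element count by the number of
// execution units per cube so every element is covered by at least one unit.
fn approx_cube_count(num_elems: usize, cube_dim: CubeDim) -> u32 {
    let units_per_cube = (cube_dim.x * cube_dim.y * cube_dim.z) as usize;
    ((num_elems + units_per_cube - 1) / units_per_cube) as u32
}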

crates/burn-jit/src/kernel/cast/base.rs

Lines changed: 6 additions & 8 deletions
@@ -1,8 +1,6 @@
 use crate::{kernel::Kernel, tensor::JitTensor, JitElement, JitRuntime};
 use cubecl::linalg::tensor::index_offset_with_layout;
-use cubecl::{
-    calculate_cube_count_elemwise, prelude::*, tensor_vectorization_factor, SUBCUBE_DIM_APPROX,
-};
+use cubecl::{calculate_cube_count_elemwise, prelude::*, tensor_vectorization_factor};
 use cubecl::{ir::KernelDefinition, KernelSettings};
 use std::any::TypeId;
 
@@ -46,10 +44,10 @@ pub fn cast<R: JitRuntime, EI: JitElement, EO: JitElement, const D: usize>(
         tensor_vectorization_factor(&[4, 2], &input.shape.dims, &input.strides, rank - 1);
 
     let num_elems: usize = input.shape.num_elements();
-    let cube_count = calculate_cube_count_elemwise(
-        num_elems / vectorization_factor as usize,
-        SUBCUBE_DIM_APPROX,
-    );
+
+    let cube_dim = CubeDim::default();
+    let cube_count =
+        calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim);
     let client = input.client.clone();
     let handle = client.empty(num_elems * core::mem::size_of::<EO>());
     let output =
@@ -58,7 +56,7 @@ pub fn cast<R: JitRuntime, EI: JitElement, EO: JitElement, const D: usize>(
     cast_element::launch::<EI::Primitive, EO::Primitive, R>(
         &client,
         cube_count,
-        CubeDim::default(),
+        cube_dim,
         TensorArg::vectorized(
            vectorization_factor,
            &input.handle,
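The same shape repeats at every launch site touched by this commit: bind `CubeDim::default()` once, derive the cube count from it, then pass both to the kernel. A hedged sketch of that call-site pattern; `my_kernel`, `client`, and `num_elems` are placeholder names, not identifiers from this file:

// Bind the launch dimension once and reuse it for sizing and launching,
// so the dispatch geometry cannot drift out of sync.
let cube_dim = CubeDim::default();
let cube_count = calculate_cube_count_elemwise(num_elems, cube_dim);

my_kernel::launch::<E::Primitive, R>(
    &client,
    cube_count,
    cube_dim, // the same value that was used to compute the count
    /* tensor arguments */
);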
crates/burn-jit/src/kernel/cast/bool_cast.rs

Lines changed: 21 additions & 90 deletions

@@ -1,12 +1,14 @@
-use crate::{kernel::Kernel, tensor::JitTensor, JitElement, JitRuntime};
-use cubecl::{
-    cpa,
-    frontend::TensorHandleRef,
-    ir::{Elem, Item, KernelDefinition, Scope, Variable, Visibility},
-    CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
-    OutputInfo,
-};
-use std::marker::PhantomData;
+use crate::{tensor::JitTensor, JitElement, JitRuntime};
+use cubecl::{calculate_cube_count_elemwise, prelude::*, CubeDim};
+
+#[cube(launch)]
+fn bool_cast_kernel<T: Numeric>(input: &Tensor<UInt>, output: &mut Tensor<T>) {
+    if input[ABSOLUTE_POS] >= UInt::new(1) {
+        output[ABSOLUTE_POS] = T::from_int(1);
+    } else {
+        output[ABSOLUTE_POS] = T::from_int(0);
+    }
+}
 
 /// Cast a bool tensor to the given element type.
 ///
@@ -17,7 +19,6 @@ use std::marker::PhantomData;
 pub fn bool_cast<R: JitRuntime, EO: JitElement, const D: usize>(
     tensor: JitTensor<R, u32, D>,
 ) -> JitTensor<R, EO, D> {
-    let kernel = BoolCastEagerKernel::<R, EO>::new();
     let num_elems = tensor.shape.num_elements();
     let buffer = tensor.client.empty(num_elems * core::mem::size_of::<EO>());
     let output = JitTensor::new_contiguous(
@@ -27,86 +28,16 @@ pub fn bool_cast<R: JitRuntime, EO: JitElement, const D: usize>(
         buffer,
     );
 
-    Execution::start(kernel, tensor.client)
-        .inputs(&[TensorHandleRef::<R>::new(
-            &tensor.handle,
-            &tensor.strides,
-            &tensor.shape.dims,
-        )])
-        .outputs(&[TensorHandleRef::new(
-            &output.handle,
-            &output.strides,
-            &output.shape.dims,
-        )])
-        .execute(CubeCountSettings::Output { pos: 0 });
-
-    output
-}
-
-pub(crate) struct BoolCastShader {
-    tensor: Variable,
-    output: Variable,
-}
-
-#[derive(new)]
-pub(crate) struct BoolCastEagerKernel<R: JitRuntime, EO: JitElement> {
-    _runtime: PhantomData<R>,
-    _elem_out: PhantomData<EO>,
-}
-
-impl<R: JitRuntime, EO: JitElement> Kernel for BoolCastEagerKernel<R, EO> {
-    fn define(&self) -> KernelDefinition {
-        let mut scope = Scope::root();
-        let item_input = Item::new(Elem::Bool);
-        let item_output = EO::cube_elem().into();
-
-        let tensor = Variable::GlobalInputArray {
-            id: 0,
-            item: item_input,
-        };
-        let output = Variable::GlobalOutputArray {
-            id: 0,
-            item: item_output,
-        };
-
-        BoolCastShader { tensor, output }.expand(&mut scope);
+    let cube_dim = CubeDim::default();
+    let cube_count = calculate_cube_count_elemwise(num_elems, cube_dim);
 
-        scope.write_global_custom(output);
-
-        let tensor = InputInfo::Array {
-            item: item_input,
-            visibility: Visibility::Read,
-        };
-
-        let out = OutputInfo::Array { item: item_output };
-
-        let info = KernelExpansion {
-            inputs: vec![tensor],
-            outputs: vec![out],
-            scope,
-        };
-
-        let settings = KernelSettings::default();
-        KernelIntegrator::new(info).integrate(settings)
-    }
-
-    fn id(&self) -> cubecl::KernelId {
-        cubecl::KernelId::new::<Self>()
-    }
-}
-
-impl BoolCastShader {
-    pub(crate) fn expand(self, scope: &mut Scope) {
-        let tensor = self.tensor;
-        let id = Variable::AbsolutePos;
-        let output = self.output;
+    bool_cast_kernel::launch::<EO::Primitive, R>(
+        &tensor.client,
+        cube_count,
+        cube_dim,
+        TensorArg::new(&tensor.handle, &tensor.strides, &tensor.shape.dims),
+        TensorArg::new(&output.handle, &output.strides, &output.shape.dims),
+    );
 
-        let represents_true = scope.create_local(Elem::Bool);
-        cpa!(scope, represents_true = tensor[id]);
-        cpa!(scope, if(represents_true).then(|scope|{
-            cpa!(scope, output[id] = 1);
-        }).else(|scope|{
-            cpa!(scope, output[id] = 0);
-        }));
-    }
+    output
 }
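Here the whole hand-assembled IR pipeline (`BoolCastShader`, `BoolCastEagerKernel`, the `cpa!` macro calls, and the `KernelIntegrator` plumbing) collapses into the single `#[cube(launch)]` function above: any stored value >= 1 becomes `T::from_int(1)`, everything else `T::from_int(0)`. A hypothetical call site, assuming a 2-D u32 tensor named `mask` obtained elsewhere (not from this file):

// Hypothetical usage: `mask` is a JitTensor<R, u32, 2> holding 0/1 values
// (burn-jit's bool storage); the result carries the same data as f32.
let floats: JitTensor<R, f32, 2> = bool_cast::<R, f32, 2>(mask);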

crates/burn-jit/src/kernel/comparison.rs

Lines changed: 11 additions & 11 deletions
@@ -3,7 +3,7 @@ use crate::{element::JitElement, tensor::JitTensor, JitRuntime};
 use burn_tensor::Shape;
 use cubecl::{
     calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*,
-    tensor_vectorization_factor, Runtime, SUBCUBE_DIM_APPROX,
+    tensor_vectorization_factor, Runtime,
 };
 
 #[cube]
@@ -139,17 +139,17 @@ pub(crate) fn launch_cmp<
     let shape_out = Shape::new(shape_out);
     let client = lhs.client.clone();
     let num_elems = shape_out.num_elements();
-    let cube_count = calculate_cube_count_elemwise(
-        num_elems / vectorization_factor as usize,
-        SUBCUBE_DIM_APPROX,
-    );
+
+    let cube_dim = CubeDim::default();
+    let cube_count =
+        calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim);
 
     let same_tensor_type = core::any::TypeId::of::<E>() == core::any::TypeId::of::<UInt>();
     if same_tensor_type && lhs.can_mut_broadcast(&rhs) {
         kernel_cmp::launch::<E::Primitive, O, R>(
             &client,
             cube_count,
-            CubeDim::default(),
+            cube_dim,
             TensorArg::vectorized(
                 vectorization_factor,
                 &lhs.handle,
@@ -244,17 +244,17 @@ pub(crate) fn launch_scalar_cmp<
         tensor_vectorization_factor(&[4, 2], &tensor.shape.dims, &tensor.strides, D - 1);
     let client = tensor.client.clone();
     let num_elems = tensor.shape.num_elements();
-    let cube_count = calculate_cube_count_elemwise(
-        num_elems / vectorization_factor as usize,
-        SUBCUBE_DIM_APPROX,
-    );
+
+    let cube_dim = CubeDim::default();
+    let cube_count =
+        calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim);
 
     let same_tensor_type = core::any::TypeId::of::<E>() == core::any::TypeId::of::<UInt>();
     if same_tensor_type && tensor.can_mut() {
         kernel_scalar_cmp::launch::<E::Primitive, O, R>(
             &client,
             cube_count,
-            CubeDim::default(),
+            cube_dim,
             TensorArg::vectorized(
                 vectorization_factor,
                 &tensor.handle,

crates/burn-jit/src/kernel/conv/conv2d.rs

Lines changed: 4 additions & 3 deletions
@@ -1,4 +1,4 @@
-use cubecl::{calculate_cube_count_elemwise, prelude::*, SUBCUBE_DIM_APPROX};
+use cubecl::{calculate_cube_count_elemwise, prelude::*};
 
 use burn_tensor::{
     ops::{conv::calculate_conv_output_size, ConvOptions},
@@ -161,12 +161,13 @@ pub(crate) fn conv2d<R: JitRuntime, E: FloatElement>(
     };
 
     let num_elems_output = output.shape.num_elements();
-    let cube_dim = calculate_cube_count_elemwise(num_elems_output, SUBCUBE_DIM_APPROX);
+    let cube_dim = CubeDim::default();
+    let cube_count = calculate_cube_count_elemwise(num_elems_output, cube_dim);
 
     conv2d_kernel::launch::<E::FloatPrimitive, R>(
         &input.client,
+        cube_count,
         cube_dim,
-        CubeDim::default(),
         TensorArg::new(&input.handle, &input.strides, &input.shape.dims),
         TensorArg::new(&weight.handle, &weight.strides, &weight.shape.dims),
         TensorArg::new(&bias.handle, &bias.strides, &bias.shape.dims),
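Besides dropping `SUBCUBE_DIM_APPROX`, this hunk untangles a misleading binding: the old code stored the result of `calculate_cube_count_elemwise` in a variable named `cube_dim` and passed it as the first launch argument, with a fresh `CubeDim::default()` as the second. The arguments were already in the right positions, so behavior is unchanged; the names now say what the values actually are.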

crates/burn-jit/src/kernel/conv/conv3d.rs

Lines changed: 6 additions & 3 deletions
@@ -1,4 +1,4 @@
-use cubecl::{calculate_cube_count_elemwise, prelude::*, SUBCUBE_DIM_APPROX};
+use cubecl::{calculate_cube_count_elemwise, prelude::*};
 
 use burn_tensor::{
     ops::{conv::calculate_conv_output_size, ConvOptions},
@@ -188,10 +188,13 @@ pub(crate) fn conv3d<R: JitRuntime, E: FloatElement>(
         }
     };
 
+    let cube_dim = CubeDim::default();
+    let cube_count = calculate_cube_count_elemwise(output.shape.num_elements(), cube_dim);
+
     conv3d_kernel::launch::<E::FloatPrimitive, R>(
         &input.client,
-        calculate_cube_count_elemwise(output.shape.num_elements(), SUBCUBE_DIM_APPROX),
-        CubeDim::default(),
+        cube_count,
+        cube_dim,
         TensorArg::new(&input.handle, &input.strides, &input.shape.dims),
         TensorArg::new(&weight.handle, &weight.strides, &weight.shape.dims),
        TensorArg::new(&bias.handle, &bias.strides, &bias.shape.dims),
