Skip to content

Commit

Permalink
Refactor/jit cube/mask (#2075)
Browse files Browse the repository at this point in the history
Co-authored-by: louisfd <[email protected]>
  • Loading branch information
louisfd and louisfd authored Jul 30, 2024
1 parent 47d4139 commit e68b9ab
Show file tree
Hide file tree
Showing 16 changed files with 280 additions and 491 deletions.
16 changes: 8 additions & 8 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ rstest = "0.19.0"
rusqlite = { version = "0.31.0" }
rust-format = { version = "0.3.4" }
sanitize-filename = "0.5.0"
serde_bytes = { version = "0.11.15", default-features = false, features = ["alloc"] } # alloc for no_std
serde_bytes = { version = "0.11.15", default-features = false, features = [
"alloc",
] } # alloc for no_std
serde_rusqlite = "0.35.0"
serial_test = "3.1.1"
spin = { version = "0.9.8", features = ["mutex", "spin_mutex"] }
Expand Down Expand Up @@ -148,5 +150,5 @@ cubecl-common = { tag = "v0.1.1", git = "https://github.com/tracel-ai/cubecl", d
# cubecl-common = { path = "../cubecl/crates/cubecl-common" }

[profile.dev]
debug = 0 # Speed up compilation time and not necessary.
debug = 0 # Speed up compilation time and not necessary.
opt-level = 2
13 changes: 6 additions & 7 deletions crates/burn-jit/src/fusion/elemwise/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,10 @@ impl<R: JitRuntime> FusionKernelFactory<R> for ElementWiseKernelFactory<R> {
outputs: &[&TensorDescription],
stateful: bool,
) -> FusionKernel<R> {
let cube_dim_x = self.cube_dim.x;
let cube_dim_y = self.cube_dim.y;

assert_eq!(cube_dim_x, cube_dim_y, "The grid must be a square");
let cube_dim = cube_dim_x as usize;
assert_eq!(
self.cube_dim.x, self.cube_dim.y,
"The grid must be a square"
);

let vectorize_4 = can_vectorize(handles_inputs, inputs, outputs, 4);
let vectorize_2 = can_vectorize(handles_inputs, inputs, outputs, 2);
Expand Down Expand Up @@ -69,7 +68,7 @@ impl<R: JitRuntime> FusionKernelFactory<R> for ElementWiseKernelFactory<R> {

let reference_tensor = inputs[settings.mappings[0].pos_input];
let num_elems = calculate_num_elems_dyn_rank(&reference_tensor.shape);
let cube_count = calculate_cube_count_elemwise(num_elems / factor, cube_dim);
let cube_count = calculate_cube_count_elemwise(num_elems / factor, self.cube_dim);
let output_infos =
inplace_output2input
.iter()
Expand All @@ -96,7 +95,7 @@ impl<R: JitRuntime> FusionKernelFactory<R> for ElementWiseKernelFactory<R> {
false => {
let reference_tensor = outputs[0];
let num_elems = calculate_num_elems_dyn_rank(&reference_tensor.shape);
let cube_count = calculate_cube_count_elemwise(num_elems / factor, cube_dim);
let cube_count = calculate_cube_count_elemwise(num_elems / factor, self.cube_dim);
let output_infos = outputs.iter().enumerate().map(|(pos, tensor)| {
let size = calculate_num_elems_dyn_rank(&tensor.shape)
* self.info.outputs[pos].elem_size::<R>();
Expand Down
14 changes: 6 additions & 8 deletions crates/burn-jit/src/kernel/cast/base.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use crate::{kernel::Kernel, tensor::JitTensor, JitElement, JitRuntime};
use cubecl::linalg::tensor::index_offset_with_layout;
use cubecl::{
calculate_cube_count_elemwise, prelude::*, tensor_vectorization_factor, SUBCUBE_DIM_APPROX,
};
use cubecl::{calculate_cube_count_elemwise, prelude::*, tensor_vectorization_factor};
use cubecl::{ir::KernelDefinition, KernelSettings};
use std::any::TypeId;

Expand Down Expand Up @@ -46,10 +44,10 @@ pub fn cast<R: JitRuntime, EI: JitElement, EO: JitElement, const D: usize>(
tensor_vectorization_factor(&[4, 2], &input.shape.dims, &input.strides, rank - 1);

let num_elems: usize = input.shape.num_elements();
let cube_count = calculate_cube_count_elemwise(
num_elems / vectorization_factor as usize,
SUBCUBE_DIM_APPROX,
);

let cube_dim = CubeDim::default();
let cube_count =
calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim);
let client = input.client.clone();
let handle = client.empty(num_elems * core::mem::size_of::<EO>());
let output =
Expand All @@ -58,7 +56,7 @@ pub fn cast<R: JitRuntime, EI: JitElement, EO: JitElement, const D: usize>(
cast_element::launch::<EI::Primitive, EO::Primitive, R>(
&client,
cube_count,
CubeDim::default(),
cube_dim,
TensorArg::vectorized(
vectorization_factor,
&input.handle,
Expand Down
111 changes: 21 additions & 90 deletions crates/burn-jit/src/kernel/cast/bool_cast.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
use crate::{kernel::Kernel, tensor::JitTensor, JitElement, JitRuntime};
use cubecl::{
cpa,
frontend::TensorHandleRef,
ir::{Elem, Item, KernelDefinition, Scope, Variable, Visibility},
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
OutputInfo,
};
use std::marker::PhantomData;
use crate::{tensor::JitTensor, JitElement, JitRuntime};
use cubecl::{calculate_cube_count_elemwise, prelude::*, CubeDim};

// Element-wise cast of a bool tensor (stored as `UInt`, where 0 = false and any
// nonzero value = true) into the numeric element type `T`, writing 0 or 1.
// NOTE(review): generated as a CubeCL launchable kernel via `#[cube(launch)]`;
// each invocation handles the single element at its global linear position.
#[cube(launch)]
fn bool_cast_kernel<T: Numeric>(input: &Tensor<UInt>, output: &mut Tensor<T>) {
    // Treat any value >= 1 as `true`; only exactly 0 maps to `false`.
    if input[ABSOLUTE_POS] >= UInt::new(1) {
        output[ABSOLUTE_POS] = T::from_int(1);
    } else {
        output[ABSOLUTE_POS] = T::from_int(0);
    }
}

/// Cast a bool tensor to the given element type.
///
Expand All @@ -17,7 +19,6 @@ use std::marker::PhantomData;
pub fn bool_cast<R: JitRuntime, EO: JitElement, const D: usize>(
tensor: JitTensor<R, u32, D>,
) -> JitTensor<R, EO, D> {
let kernel = BoolCastEagerKernel::<R, EO>::new();
let num_elems = tensor.shape.num_elements();
let buffer = tensor.client.empty(num_elems * core::mem::size_of::<EO>());
let output = JitTensor::new_contiguous(
Expand All @@ -27,86 +28,16 @@ pub fn bool_cast<R: JitRuntime, EO: JitElement, const D: usize>(
buffer,
);

Execution::start(kernel, tensor.client)
.inputs(&[TensorHandleRef::<R>::new(
&tensor.handle,
&tensor.strides,
&tensor.shape.dims,
)])
.outputs(&[TensorHandleRef::new(
&output.handle,
&output.strides,
&output.shape.dims,
)])
.execute(CubeCountSettings::Output { pos: 0 });

output
}

pub(crate) struct BoolCastShader {
tensor: Variable,
output: Variable,
}

#[derive(new)]
pub(crate) struct BoolCastEagerKernel<R: JitRuntime, EO: JitElement> {
_runtime: PhantomData<R>,
_elem_out: PhantomData<EO>,
}

impl<R: JitRuntime, EO: JitElement> Kernel for BoolCastEagerKernel<R, EO> {
fn define(&self) -> KernelDefinition {
let mut scope = Scope::root();
let item_input = Item::new(Elem::Bool);
let item_output = EO::cube_elem().into();

let tensor = Variable::GlobalInputArray {
id: 0,
item: item_input,
};
let output = Variable::GlobalOutputArray {
id: 0,
item: item_output,
};

BoolCastShader { tensor, output }.expand(&mut scope);
let cube_dim = CubeDim::default();
let cube_count = calculate_cube_count_elemwise(num_elems, cube_dim);

scope.write_global_custom(output);

let tensor = InputInfo::Array {
item: item_input,
visibility: Visibility::Read,
};

let out = OutputInfo::Array { item: item_output };

let info = KernelExpansion {
inputs: vec![tensor],
outputs: vec![out],
scope,
};

let settings = KernelSettings::default();
KernelIntegrator::new(info).integrate(settings)
}

fn id(&self) -> cubecl::KernelId {
cubecl::KernelId::new::<Self>()
}
}

impl BoolCastShader {
pub(crate) fn expand(self, scope: &mut Scope) {
let tensor = self.tensor;
let id = Variable::AbsolutePos;
let output = self.output;
bool_cast_kernel::launch::<EO::Primitive, R>(
&tensor.client,
cube_count,
cube_dim,
TensorArg::new(&tensor.handle, &tensor.strides, &tensor.shape.dims),
TensorArg::new(&output.handle, &output.strides, &output.shape.dims),
);

let represents_true = scope.create_local(Elem::Bool);
cpa!(scope, represents_true = tensor[id]);
cpa!(scope, if(represents_true).then(|scope|{
cpa!(scope, output[id] = 1);
}).else(|scope|{
cpa!(scope, output[id] = 0);
}));
}
output
}
22 changes: 11 additions & 11 deletions crates/burn-jit/src/kernel/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::{element::JitElement, tensor::JitTensor, JitRuntime};
use burn_tensor::Shape;
use cubecl::{
calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*,
tensor_vectorization_factor, Runtime, SUBCUBE_DIM_APPROX,
tensor_vectorization_factor, Runtime,
};

#[cube]
Expand Down Expand Up @@ -139,17 +139,17 @@ pub(crate) fn launch_cmp<
let shape_out = Shape::new(shape_out);
let client = lhs.client.clone();
let num_elems = shape_out.num_elements();
let cube_count = calculate_cube_count_elemwise(
num_elems / vectorization_factor as usize,
SUBCUBE_DIM_APPROX,
);

let cube_dim = CubeDim::default();
let cube_count =
calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim);

let same_tensor_type = core::any::TypeId::of::<E>() == core::any::TypeId::of::<UInt>();
if same_tensor_type && lhs.can_mut_broadcast(&rhs) {
kernel_cmp::launch::<E::Primitive, O, R>(
&client,
cube_count,
CubeDim::default(),
cube_dim,
TensorArg::vectorized(
vectorization_factor,
&lhs.handle,
Expand Down Expand Up @@ -244,17 +244,17 @@ pub(crate) fn launch_scalar_cmp<
tensor_vectorization_factor(&[4, 2], &tensor.shape.dims, &tensor.strides, D - 1);
let client = tensor.client.clone();
let num_elems = tensor.shape.num_elements();
let cube_count = calculate_cube_count_elemwise(
num_elems / vectorization_factor as usize,
SUBCUBE_DIM_APPROX,
);

let cube_dim = CubeDim::default();
let cube_count =
calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim);

let same_tensor_type = core::any::TypeId::of::<E>() == core::any::TypeId::of::<UInt>();
if same_tensor_type && tensor.can_mut() {
kernel_scalar_cmp::launch::<E::Primitive, O, R>(
&client,
cube_count,
CubeDim::default(),
cube_dim,
TensorArg::vectorized(
vectorization_factor,
&tensor.handle,
Expand Down
7 changes: 4 additions & 3 deletions crates/burn-jit/src/kernel/conv/conv2d.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use cubecl::{calculate_cube_count_elemwise, prelude::*, SUBCUBE_DIM_APPROX};
use cubecl::{calculate_cube_count_elemwise, prelude::*};

use burn_tensor::{
ops::{conv::calculate_conv_output_size, ConvOptions},
Expand Down Expand Up @@ -161,12 +161,13 @@ pub(crate) fn conv2d<R: JitRuntime, E: FloatElement>(
};

let num_elems_output = output.shape.num_elements();
let cube_dim = calculate_cube_count_elemwise(num_elems_output, SUBCUBE_DIM_APPROX);
let cube_dim = CubeDim::default();
let cube_count = calculate_cube_count_elemwise(num_elems_output, cube_dim);

conv2d_kernel::launch::<E::FloatPrimitive, R>(
&input.client,
cube_count,
cube_dim,
CubeDim::default(),
TensorArg::new(&input.handle, &input.strides, &input.shape.dims),
TensorArg::new(&weight.handle, &weight.strides, &weight.shape.dims),
TensorArg::new(&bias.handle, &bias.strides, &bias.shape.dims),
Expand Down
9 changes: 6 additions & 3 deletions crates/burn-jit/src/kernel/conv/conv3d.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use cubecl::{calculate_cube_count_elemwise, prelude::*, SUBCUBE_DIM_APPROX};
use cubecl::{calculate_cube_count_elemwise, prelude::*};

use burn_tensor::{
ops::{conv::calculate_conv_output_size, ConvOptions},
Expand Down Expand Up @@ -188,10 +188,13 @@ pub(crate) fn conv3d<R: JitRuntime, E: FloatElement>(
}
};

let cube_dim = CubeDim::default();
let cube_count = calculate_cube_count_elemwise(output.shape.num_elements(), cube_dim);

conv3d_kernel::launch::<E::FloatPrimitive, R>(
&input.client,
calculate_cube_count_elemwise(output.shape.num_elements(), SUBCUBE_DIM_APPROX),
CubeDim::default(),
cube_count,
cube_dim,
TensorArg::new(&input.handle, &input.strides, &input.shape.dims),
TensorArg::new(&weight.handle, &weight.strides, &weight.shape.dims),
TensorArg::new(&bias.handle, &bias.strides, &bias.shape.dims),
Expand Down
Loading

0 comments on commit e68b9ab

Please sign in to comment.