Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor/cube/vectorization #1781

Merged
merged 24 commits into from
May 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions crates/burn-cube-macros/src/analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,15 @@ impl CodeAnalysisBuilder {
}
syn::Expr::Break(_) => {}
syn::Expr::Paren(expr) => self.find_occurrences_in_expr(&expr.expr, depth),
syn::Expr::Array(expr) => {
for element in expr.elems.iter() {
match element {
syn::Expr::Lit(_) => {}
_ => todo!("Analysis: only array of literals is supported"),
}
}
}
syn::Expr::Reference(expr) => self.find_occurrences_in_expr(&expr.expr, depth),
_ => todo!("Analysis: unsupported expr {expr:?}"),
}
}
Expand Down
16 changes: 15 additions & 1 deletion crates/burn-cube-macros/src/codegen/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ use super::{
branch::{codegen_break, codegen_for_loop, codegen_if, codegen_loop, codegen_while_loop},
function::{codegen_call, codegen_closure, codegen_expr_method_call},
operation::codegen_binary,
variable::{codegen_assign, codegen_index, codegen_lit, codegen_local, codegen_path_rhs},
variable::{
codegen_array_lit, codegen_assign, codegen_index, codegen_lit, codegen_local,
codegen_path_rhs,
},
};

/// Codegen for a statement (generally one line)
Expand Down Expand Up @@ -59,6 +62,15 @@ pub(crate) fn codegen_expr_block(
codegen_block(&block.block, loop_level, variable_analyses)
}

pub(crate) fn codegen_ref(
reference: &syn::ExprReference,
loop_level: usize,
variable_analyses: &mut CodeAnalysis,
) -> TokenStream {
let inner = codegen_expr(&reference.expr, loop_level, variable_analyses);
quote::quote! { & #inner }
}

/// Codegen for expressions
/// There are many variants of expression, treated differently
pub(crate) fn codegen_expr(
Expand All @@ -84,6 +96,8 @@ pub(crate) fn codegen_expr(
syn::Expr::MethodCall(call) => codegen_expr_method_call(call),
syn::Expr::Index(index) => codegen_index(index, loop_level, variable_analyses),
syn::Expr::Paren(paren) => codegen_expr(&paren.expr, loop_level, variable_analyses),
syn::Expr::Array(array) => codegen_array_lit(array),
syn::Expr::Reference(reference) => codegen_ref(reference, loop_level, variable_analyses),
_ => panic!("Codegen: Unsupported {:?}", expr),
}
}
7 changes: 1 addition & 6 deletions crates/burn-cube-macros/src/codegen/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,7 @@ pub(crate) fn codegen_closure(
}

/// Codegen for a function call
/// Supports:
/// func()
/// func::<T>()
/// T::func()
///
/// Should map:
/// Maps
/// [A[::<...>]?::]^* func[::<...>] (args)
/// to
/// [A[::<...>]?::]^* func_expand[::<...>] (context, args)
Expand Down
13 changes: 13 additions & 0 deletions crates/burn-cube-macros/src/codegen/variable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,19 @@ pub(crate) fn codegen_lit(lit: &syn::ExprLit) -> TokenStream {
}
}

/// Codegen for arrays of literals
pub(crate) fn codegen_array_lit(array: &syn::ExprArray) -> TokenStream {
let mut tokens = quote::quote! {};
for element in array.elems.iter() {
let token = match element {
syn::Expr::Lit(lit) => codegen_lit(lit),
_ => todo!("Codegen: Only arrays of literals are supported"),
};
tokens.extend(quote::quote! { #token, });
}
quote::quote! { [ #tokens ] }
}

/// Codegen for a local declaration (let ...)
/// Supports:
/// let x = ...
Expand Down
21 changes: 7 additions & 14 deletions crates/burn-cube/src/codegen/compilation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,7 @@ impl core::fmt::Display for CompilationSettings {
}

match self.vectorization {
Some(vectorization) => match vectorization {
Vectorization::Vec4 => f.write_str("v4"),
Vectorization::Vec3 => f.write_str("v3"),
Vectorization::Vec2 => f.write_str("v2"),
Vectorization::Scalar => f.write_str("v1"),
}?,
Some(vectorization) => f.write_fmt(format_args!("v{}", vectorization))?,
None => f.write_str("vn")?,
};

Expand Down Expand Up @@ -154,7 +149,7 @@ impl InputInfo {
item,
visibility: _,
} => *item,
InputInfo::Scalar { elem, size: _ } => Item::Scalar(*elem),
InputInfo::Scalar { elem, size: _ } => Item::new(*elem),
}
}
}
Expand Down Expand Up @@ -252,7 +247,7 @@ impl Compilation {
named.push((
"info".to_string(),
Binding {
item: Item::Scalar(Elem::UInt),
item: Item::new(Elem::UInt),
visibility: Visibility::Read,
location: Location::Storage,
size: None, // We avoid putting the length here since it will force a new kernel
Expand Down Expand Up @@ -300,7 +295,7 @@ impl Compilation {
self.named_bindings.push((
format!("scalars_{}", elem),
Binding {
item: Item::Scalar(elem),
item: Item::new(elem),
visibility: Visibility::Read,
location: Location::Storage,
size: Some(size),
Expand Down Expand Up @@ -440,11 +435,9 @@ impl Compilation {
}

fn bool_item(ty: Item) -> Item {
match ty {
Item::Vec4(elem) => Item::Vec4(bool_elem(elem)),
Item::Vec3(elem) => Item::Vec3(bool_elem(elem)),
Item::Vec2(elem) => Item::Vec2(bool_elem(elem)),
Item::Scalar(elem) => Item::Scalar(bool_elem(elem)),
Item {
elem: bool_elem(ty.elem),
vectorization: ty.vectorization,
}
}

Expand Down
2 changes: 1 addition & 1 deletion crates/burn-cube/src/codegen/dialect/branch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ impl RangeLoop {
func: F,
) {
let mut scope = parent_scope.child();
let index_ty = Item::Scalar(Elem::UInt);
let index_ty = Item::new(Elem::UInt);
let i = scope.create_local_undeclared(index_ty);

func(i, &mut scope);
Expand Down
47 changes: 21 additions & 26 deletions crates/burn-cube/src/codegen/dialect/procedure/assign.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use crate::codegen::dialect::{macros::cpa, Item, Scope, Variable, Vectorization};
use crate::{
branch::range,
codegen::dialect::{macros::cpa, Scope, Variable, Vectorization},
};
use serde::{Deserialize, Serialize};

/// Assign value to a variable based on a given condition.
Expand All @@ -19,14 +22,15 @@ impl ConditionalAssign {
let rhs = self.rhs;
let out = self.out;

let index_var = |scope: &mut Scope, var: Variable, index: usize| match var.item() {
Item::Scalar(_) => var,
_ => {
let out = scope.create_local(var.item().elem());
cpa!(scope, out = var[index]);
out
}
};
let index_var =
|scope: &mut Scope, var: Variable, index: usize| match var.item().vectorization == 1 {
true => var,
false => {
let out = scope.create_local(var.item().elem());
cpa!(scope, out = var[index]);
out
}
};

let mut assign_index = |index: usize| {
let cond = index_var(scope, cond, index);
Expand All @@ -42,29 +46,20 @@ impl ConditionalAssign {
}));
};

match out.item() {
Item::Vec4(_) => {
assign_index(0);
assign_index(1);
assign_index(2);
assign_index(3);
}
Item::Vec3(_) => {
assign_index(0);
assign_index(1);
assign_index(2);
}
Item::Vec2(_) => {
assign_index(0);
assign_index(1);
}
Item::Scalar(_) => {
let vectorization = out.item().vectorization;
match vectorization == 1 {
true => {
cpa!(scope, if (cond).then(|scope| {
cpa!(scope, out = lhs);
}).else(|scope| {
cpa!(scope, out = rhs);
}));
}
false => {
for i in range(0u32, vectorization as u32, true) {
assign_index(i);
}
}
};
}

Expand Down
8 changes: 4 additions & 4 deletions crates/burn-cube/src/codegen/dialect/procedure/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ impl CheckedIndex {
let lhs = self.lhs;
let rhs = self.rhs;
let out = self.out;
let array_len = scope.create_local(Item::Scalar(crate::dialect::Elem::UInt));
let inside_bound = scope.create_local(Item::Scalar(crate::dialect::Elem::Bool));
let array_len = scope.create_local(Item::new(crate::dialect::Elem::UInt));
let inside_bound = scope.create_local(Item::new(crate::dialect::Elem::Bool));

cpa!(scope, array_len = len(lhs));
cpa!(scope, inside_bound = rhs < array_len);
Expand Down Expand Up @@ -56,8 +56,8 @@ impl CheckedIndexAssign {
let lhs = self.lhs;
let rhs = self.rhs;
let out = self.out;
let array_len = scope.create_local(Item::Scalar(Elem::UInt));
let inside_bound = scope.create_local(Item::Scalar(Elem::Bool));
let array_len = scope.create_local(Item::new(Elem::UInt));
let inside_bound = scope.create_local(Item::new(Elem::Bool));

cpa!(scope, array_len = len(out));
cpa!(scope, inside_bound = lhs < array_len);
Expand Down
12 changes: 3 additions & 9 deletions crates/burn-cube/src/codegen/dialect/procedure/read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,17 +140,11 @@ impl IndexOffsetGlobalWithLayout {
#[allow(missing_docs)]
pub fn expand(self, scope: &mut Scope) {
let layout = self.layout;
let index_item_ty = Item::Scalar(Elem::UInt);
let index_item_ty = Item::new(Elem::UInt);
let offset_ref = self.position;
let zero: Variable = 0u32.into();
let vectorization_factor: Variable = match self.tensors[0].item() {
Item::Vec4(_) => 4u32,
Item::Vec3(_) => 3u32,
Item::Vec2(_) => 2u32,
Item::Scalar(_) => 1u32,
}
.into();

let vectorization_factor: u8 = self.tensors[0].item().vectorization;
let vectorization_factor: Variable = (vectorization_factor as u32).into();
for index in self.indexes.iter() {
cpa!(scope, index = zero);
}
Expand Down
8 changes: 3 additions & 5 deletions crates/burn-cube/src/codegen/dialect/scope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,11 +336,9 @@ impl Scope {
position: Variable,
) -> Variable {
let item_global = match item.elem() {
Elem::Bool => match item {
Item::Vec4(_) => Item::Vec4(Elem::UInt),
Item::Vec3(_) => Item::Vec3(Elem::UInt),
Item::Vec2(_) => Item::Vec2(Elem::UInt),
Item::Scalar(_) => Item::Scalar(Elem::UInt),
Elem::Bool => Item {
elem: Elem::UInt,
vectorization: item.vectorization,
},
_ => item,
};
Expand Down
34 changes: 21 additions & 13 deletions crates/burn-cube/src/codegen/dialect/shader.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::Scope;
use super::{Scope, Vectorization};
use crate::WORKGROUP_DEFAULT;
use serde::{Deserialize, Serialize};
use std::fmt::Display;
Expand Down Expand Up @@ -44,7 +44,7 @@ pub enum Elem {

impl From<Elem> for Item {
fn from(val: Elem) -> Self {
Item::Scalar(val)
Item::new(val)
}
}

Expand Down Expand Up @@ -81,22 +81,30 @@ impl Display for Elem {
}

#[derive(Debug, Clone, PartialEq, Eq, Copy, Serialize, Deserialize, Hash)]
#[allow(missing_docs)]
pub enum Item {
Vec4(Elem),
Vec3(Elem),
Vec2(Elem),
Scalar(Elem),
pub struct Item {
pub elem: Elem,
pub vectorization: Vectorization,
}

impl Item {
/// Fetch the elem of the item.
pub fn elem(&self) -> Elem {
match self {
Self::Vec4(elem) => *elem,
Self::Vec3(elem) => *elem,
Self::Vec2(elem) => *elem,
Self::Scalar(elem) => *elem,
self.elem
}

/// Create a new item without vectorization
pub fn new(elem: Elem) -> Self {
Self {
elem,
vectorization: 1,
}
}

/// Create a new item with vectorization
pub fn vectorized(elem: Elem, vectorization: Vectorization) -> Self {
Self {
elem,
vectorization,
}
}
}
Expand Down
42 changes: 21 additions & 21 deletions crates/burn-cube/src/codegen/dialect/variable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,30 +69,30 @@ impl Variable {
match self {
Variable::GlobalInputArray(_, item) => *item,
Variable::GlobalOutputArray(_, item) => *item,
Variable::GlobalScalar(_, elem) => Item::Scalar(*elem),
Variable::GlobalScalar(_, elem) => Item::new(*elem),
Variable::Local(_, item, _) => *item,
Variable::LocalScalar(_, elem, _) => Item::Scalar(*elem),
Variable::ConstantScalar(_, elem) => Item::Scalar(*elem),
Variable::LocalScalar(_, elem, _) => Item::new(*elem),
Variable::ConstantScalar(_, elem) => Item::new(*elem),
Variable::SharedMemory(_, item, _) => *item,
Variable::LocalArray(_, item, _, _) => *item,
Variable::Id => Item::Scalar(Elem::UInt),
Variable::Rank => Item::Scalar(Elem::UInt),
Variable::LocalInvocationIndex => Item::Scalar(Elem::UInt),
Variable::LocalInvocationIdX => Item::Scalar(Elem::UInt),
Variable::LocalInvocationIdY => Item::Scalar(Elem::UInt),
Variable::LocalInvocationIdZ => Item::Scalar(Elem::UInt),
Variable::WorkgroupIdX => Item::Scalar(Elem::UInt),
Variable::WorkgroupIdY => Item::Scalar(Elem::UInt),
Variable::WorkgroupIdZ => Item::Scalar(Elem::UInt),
Variable::GlobalInvocationIdX => Item::Scalar(Elem::UInt),
Variable::GlobalInvocationIdY => Item::Scalar(Elem::UInt),
Variable::GlobalInvocationIdZ => Item::Scalar(Elem::UInt),
Variable::WorkgroupSizeX => Item::Scalar(Elem::UInt),
Variable::WorkgroupSizeY => Item::Scalar(Elem::UInt),
Variable::WorkgroupSizeZ => Item::Scalar(Elem::UInt),
Variable::NumWorkgroupsX => Item::Scalar(Elem::UInt),
Variable::NumWorkgroupsY => Item::Scalar(Elem::UInt),
Variable::NumWorkgroupsZ => Item::Scalar(Elem::UInt),
Variable::Id => Item::new(Elem::UInt),
Variable::Rank => Item::new(Elem::UInt),
Variable::LocalInvocationIndex => Item::new(Elem::UInt),
Variable::LocalInvocationIdX => Item::new(Elem::UInt),
Variable::LocalInvocationIdY => Item::new(Elem::UInt),
Variable::LocalInvocationIdZ => Item::new(Elem::UInt),
Variable::WorkgroupIdX => Item::new(Elem::UInt),
Variable::WorkgroupIdY => Item::new(Elem::UInt),
Variable::WorkgroupIdZ => Item::new(Elem::UInt),
Variable::GlobalInvocationIdX => Item::new(Elem::UInt),
Variable::GlobalInvocationIdY => Item::new(Elem::UInt),
Variable::GlobalInvocationIdZ => Item::new(Elem::UInt),
Variable::WorkgroupSizeX => Item::new(Elem::UInt),
Variable::WorkgroupSizeY => Item::new(Elem::UInt),
Variable::WorkgroupSizeZ => Item::new(Elem::UInt),
Variable::NumWorkgroupsX => Item::new(Elem::UInt),
Variable::NumWorkgroupsY => Item::new(Elem::UInt),
Variable::NumWorkgroupsZ => Item::new(Elem::UInt),
}
}
}
Expand Down
Loading
Loading