Skip to content

Commit 6fbd852

Browse files
committed
Fix clamp min/max line size > 1 (#3078)
1 parent e6781ab commit 6fbd852

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

crates/burn-cubecl/src/kernel/clamp.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,11 @@ pub(crate) fn clamp<R: CubeRuntime, E: CubeElement>(
2525
type Options = Options<N>;
2626

2727
fn execute(input: Line<N>, options: &Self::Options) -> Line<N> {
28+
let line_size = input.size();
2829
Line::clamp(
2930
input,
30-
Line::new(options.min_value),
31-
Line::new(options.max_value),
31+
Line::empty(line_size).fill(options.min_value),
32+
Line::empty(line_size).fill(options.max_value),
3233
)
3334
}
3435
}

crates/burn-tensor/src/tests/ops/clamp.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,4 +70,15 @@ mod tests {
7070
.into_data()
7171
.assert_eq(&TensorData::from([[1, 1, 2], [3, 4, 4]]), false);
7272
}
73+
74+
#[test]
75+
fn clamp_min_max_vec_should_compile() {
76+
let input = TestTensor::<2>::ones([2, 4], &Default::default());
77+
let output = input.clamp(0., 0.5);
78+
79+
output.into_data().assert_eq(
80+
&TensorData::from([[0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5]]),
81+
false,
82+
);
83+
}
7384
}

0 commit comments

Comments
 (0)