Skip to content

Commit 49e16b6

Browse files
authored
feature(tensor): Add unsqueeze_dim helper (#966)
1 parent 20e9066 commit 49e16b6

File tree

4 files changed

+85
-0
lines changed

4 files changed

+85
-0
lines changed

burn-book/src/building-blocks/tensor.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
4646
| `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` |
4747
| `tensor.squeeze(dim)` | `tensor.squeeze(dim)` |
4848
| `tensor.unsqueeze()` | `tensor.unsqueeze(0)` |
49+
| `tensor.unsqueeze_dim(dim)` | `tensor.unsqueeze(dim)` |
4950
| `tensor.slice(ranges)` | `tensor[(*ranges,)]` |
5051
| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` |
5152
| `tensor.device()` | `tensor.device` |

burn-tensor/src/tensor/api/base.rs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,40 @@ where
257257
self.reshape(shape)
258258
}
259259

260+
/// Creates a new tensor with a dimension of size one inserted at the specified position.
261+
///
262+
/// # Example
263+
///
264+
/// ```rust
265+
/// use burn_tensor::backend::Backend;
266+
/// use burn_tensor::{Tensor, Shape};
267+
///
268+
/// fn example<B: Backend>() {
269+
/// let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]));
270+
/// let tensor: Tensor<B, 3> = tensor.unsqueeze_dim(1);
271+
/// println!("{:?}", tensor.shape());
272+
/// // Shape { dims: [3, 1, 3] }
273+
/// }
274+
/// ```
275+
pub fn unsqueeze_dim<const D2: usize>(self, dim: usize) -> Tensor<B, D2, K> {
276+
check!(TensorCheck::unsqueeze_dim::<{ D }>(dim));
277+
278+
let mut dims = [1; D2];
279+
let shape = self.shape();
280+
281+
dims[0..dim].copy_from_slice(&shape.dims[0..dim]);
282+
283+
if dim < D {
284+
dims[dim] = 1;
285+
dims[(dim + 1)..].copy_from_slice(&shape.dims[dim..]);
286+
} else {
287+
dims[dim] = 1;
288+
}
289+
290+
let shape = Shape::new(dims);
291+
self.reshape(shape)
292+
}
293+
260294
/// Returns a tensor containing the elements selected from the given ranges.
261295
///
262296
/// # Panics

burn-tensor/src/tensor/api/check.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,21 @@ impl TensorCheck {
198198
check
199199
}
200200

201+
pub(crate) fn unsqueeze_dim<const D: usize>(dim: usize) -> Self {
202+
let mut check = Self::Ok;
203+
if dim > D {
204+
check = check.register(
205+
"Unsqueeze",
206+
TensorError::new(format!(
207+
"Can't unsqueeze at dimension {}, exceeds tensor dimensions (D={})",
208+
dim, D
209+
)),
210+
);
211+
}
212+
213+
check
214+
}
215+
201216
pub(crate) fn swap_dims<const D: usize>(dim1: usize, dim2: usize) -> Self {
202217
let mut check = Self::Ok;
203218

burn-tensor/src/tests/ops/squeeze.rs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,39 @@ mod tests {
3434
let tensor = Tensor::<TestBackend, 4>::ones(Shape::new([2, 3, 4, 5]));
3535
let squeezed_tensor: Tensor<TestBackend, 3> = tensor.squeeze(2);
3636
}
37+
38+
/// Test if the function can successfully unsqueeze the size 1 dimension at the specified position of a 3D tensor.
39+
#[test]
40+
fn should_unsqueeze_dim() {
41+
let tensor = Tensor::<TestBackend, 3>::ones(Shape::new([2, 4, 1]));
42+
let unsqueezed_tensor: Tensor<TestBackend, 4> = tensor.unsqueeze_dim(1);
43+
let expected_shape = Shape::new([2, 1, 4, 1]);
44+
assert_eq!(unsqueezed_tensor.shape(), expected_shape);
45+
}
46+
47+
/// Test if the function can successfully unsqueeze the first size 1 dimension of a 4D tensor.
48+
#[test]
49+
fn should_unsqueeze_dim_first() {
50+
let tensor = Tensor::<TestBackend, 4>::ones(Shape::new([2, 3, 4, 5]));
51+
let unsqueezed_tensor: Tensor<TestBackend, 5> = tensor.unsqueeze_dim(0);
52+
let expected_shape = Shape::new([1, 2, 3, 4, 5]);
53+
assert_eq!(unsqueezed_tensor.shape(), expected_shape);
54+
}
55+
56+
/// Test if the function can successfully unsqueeze the last size 1 dimension of a 4D tensor.
57+
#[test]
58+
fn should_unsqueeze_dim_last() {
59+
let tensor = Tensor::<TestBackend, 4>::ones(Shape::new([5, 4, 3, 2]));
60+
let unsqueezed_tensor: Tensor<TestBackend, 5> = tensor.unsqueeze_dim(4);
61+
let expected_shape = Shape::new([5, 4, 3, 2, 1]);
62+
assert_eq!(unsqueezed_tensor.shape(), expected_shape);
63+
}
64+
65+
/// Test if the function panics when the unsqueezed dimension is out of bounds.
66+
#[test]
67+
#[should_panic]
68+
fn should_unsqueeze_dim_panic() {
69+
let tensor = Tensor::<TestBackend, 4>::ones(Shape::new([2, 3, 4, 5]));
70+
let unsqueezed_tensor: Tensor<TestBackend, 5> = tensor.unsqueeze_dim(5);
71+
}
3772
}

0 commit comments

Comments
 (0)