Skip to content

Commit

Permalink
Add more split tests
Browse files Browse the repository at this point in the history
  • Loading branch information
agelas committed Nov 15, 2024
1 parent d647a9e commit 54ef092
Showing 1 changed file with 71 additions and 0 deletions.
71 changes: 71 additions & 0 deletions crates/burn-tensor/src/tests/ops/split.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,75 @@ mod tests {
tensor.to_data().assert_eq(&expected[index], false);
}
}

#[test]
fn test_split_along_dim1() {
let device = Default::default();
let tensors = TestTensor::<2>::from_data([[0, 1, 2], [3, 4, 5]], &device);

let split_tensors = tensors.split(2, 1);
assert_eq!(split_tensors.len(), 2);

let expected = vec![
TensorData::from([[0, 1], [3, 4]]),
TensorData::from([[2], [5]]),
];

for (index, tensor) in split_tensors.iter().enumerate() {
tensor.to_data().assert_eq(&expected[index], false);
}
}

#[test]
fn test_split_split_size_larger_than_tensor_size() {
let device = Default::default();
let tensors = TestTensor::<1>::from_data([0, 1, 2, 3, 4], &device);

let split_tensors = tensors.split(10, 0);
assert_eq!(split_tensors.len(), 1);

let expected = vec![TensorData::from([0, 1, 2, 3, 4])];

for (index, tensor) in split_tensors.iter().enumerate() {
tensor.to_data().assert_eq(&expected[index], false);
}
}

#[test]
fn test_split_with_zero_split_size_zero_tensor_size() {
let device = Default::default();
let empty_array: [i32; 0] = [];
let tensors = TestTensor::<1>::from_data(empty_array, &device);

let split_tensors = tensors.split(0, 0);
assert_eq!(split_tensors.len(), 0);
}

#[test]
fn test_split_zero_sized_tensor() {
let device = Default::default();
let empty_array: [i32; 0] = [];
let tensors = TestTensor::<1>::from_data(empty_array, &device);

let split_tensors = tensors.split(1, 0);
assert_eq!(split_tensors.len(), 0);
}

#[test]
#[should_panic]
fn test_split_with_zero_split_size_non_zero_tensor() {
let device = Default::default();
let tensors = TestTensor::<1>::from_data([0, 1, 2, 3, 4], &device);

let _split_tensors = tensors.split(0, 0);
}

#[test]
#[should_panic]
fn test_split_invalid_dim() {
let device = Default::default();
let tensors = TestTensor::<1>::from_data([0, 1, 2], &device);

let _split_tensors = tensors.split(1, 2);
}
}

0 comments on commit 54ef092

Please sign in to comment.