Skip to content

Commit

Permalink
Add is_nan and contains_nan tensor ops (#2088)
Browse files Browse the repository at this point in the history
* Add is_nan and contains_nan tensor ops

* Enable nan test for burn-candle

* Disabling tests due to #2089
  • Loading branch information
antimora authored Aug 6, 2024
1 parent 27d42cd commit cd848b1
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 5 deletions.
13 changes: 8 additions & 5 deletions burn-book/src/building-blocks/tensor.md
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
| `tensor.clamp(min, max)` | `torch.clamp(tensor, min=min, max=max)` |
| `tensor.clamp_max(max)` | `torch.clamp(tensor, max=max)` |
| `tensor.clamp_min(min)` | `torch.clamp(tensor, min=min)` |
| `tensor.contains_nan()` | N/A |
| `tensor.div(other)` or `tensor / other` | `tensor / other` |
| `tensor.div_scalar(scalar)` or `tensor / scalar` | `tensor / scalar` |
| `tensor.equal_elem(other)` | `tensor.eq(other)` |
Expand All @@ -199,6 +200,7 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
| `tensor.greater_equal(other)` | `tensor.ge(other)` |
| `tensor.greater_equal_elem(scalar)` | `tensor.ge(scalar)` |
| `tensor.is_close(other, atol, rtol)` | `torch.isclose(tensor, other, atol, rtol)` |
| `tensor.is_nan()` | `torch.isnan(tensor)` |
| `tensor.lower(other)` | `tensor.lt(other)` |
| `tensor.lower_elem(scalar)` | `tensor.lt(scalar)` |
| `tensor.lower_equal(other)` | `tensor.le(other)` |
Expand Down Expand Up @@ -304,12 +306,13 @@ Those operations are only available for `Bool` tensors.

### Quantization Operations

Those operations are only available for `Float` tensors on backends that implement quantization strategies.
Those operations are only available for `Float` tensors on backends that implement quantization
strategies.

| Burn API | PyTorch Equivalent |
| ------------------------------------ | ------------------------------- |
| `tensor.quantize(scheme, qparams)` | N/A |
| `tensor.dequantize()` | N/A |
| Burn API | PyTorch Equivalent |
| ---------------------------------- | ------------------ |
| `tensor.quantize(scheme, qparams)` | N/A |
| `tensor.dequantize()` | N/A |

## Activation Functions

Expand Down
1 change: 1 addition & 0 deletions crates/burn-candle/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ mod tests {
burn_tensor::testgen_flip!();
burn_tensor::testgen_argwhere_nonzero!();
burn_tensor::testgen_sign!();
burn_tensor::testgen_nan!();

// TODO: https://github.com/tracel-ai/burn/issues/1237
//
Expand Down
26 changes: 26 additions & 0 deletions crates/burn-tensor/src/tensor/api/numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,32 @@ where
// Assign the original tensor data to the appropriate slice of the padded tensor
padded_tensor.slice_assign(ranges, self)
}

/// Returns a new tensor with boolean elements indicating whether each element of the input is NaN.
///
/// # Returns
///
/// A boolean tensor where `true` indicates NaN and `false` indicates a non-NaN value.
pub fn is_nan(&self) -> Tensor<B, D, Bool> {
    // IEEE 754: NaN is the only value for which `x != x` holds, so an
    // element-wise self-inequality test marks exactly the NaN positions.
    let values = self.primitive.clone();
    K::not_equal(values.clone(), values)
}

/// Checks if the tensor contains any NaN values.
///
/// # Returns
///
/// A boolean tensor with a single element indicating whether the tensor contains any NaN values.
pub fn contains_nan(&self) -> Tensor<B, 1, Bool> {
// Summing the tensor will result in NaN if the tensor contains any NaN values
// This is faster than checking each element individually
// because it rolls up the NaN values into a single value
let sum = K::sum(self.primitive.clone());

// Check if the sum is NaN by comparing it to itself
K::not_equal(sum.clone(), sum)
}
}

impl<B, K> Tensor<B, 2, K>
Expand Down
1 change: 1 addition & 0 deletions crates/burn-tensor/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ macro_rules! testgen_all {
burn_tensor::testgen_topk!();
burn_tensor::testgen_remainder!();
burn_tensor::testgen_cartesian_grid!();
burn_tensor::testgen_nan!();

// test stats
burn_tensor::testgen_var!();
Expand Down
1 change: 1 addition & 0 deletions crates/burn-tensor/src/tests/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ mod matmul;
mod maxmin;
mod movedim;
mod mul;
mod nan;
mod narrow;
mod neg;
mod one_hot;
Expand Down
31 changes: 31 additions & 0 deletions crates/burn-tensor/src/tests/ops/nan.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#[burn_tensor_testgen::testgen(nan)]
mod tests {
    use super::*;
    use burn_tensor::{Int, Tensor, TensorData};

    #[test]
    #[ignore = "https://github.com/tracel-ai/burn/issues/2089"]
    fn is_nan() {
        // An all-finite tensor must produce an all-false mask.
        let clean = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
        let clean_expected =
            TestTensorBool::<2>::from([[false, false, false], [false, false, false]]);
        assert_eq!(clean_expected.into_data(), clean.is_nan().into_data());

        // NaN entries must be flagged `true`; every other position stays `false`.
        let poisoned = TestTensor::<2>::from([[0.0, f32::NAN, 2.0], [f32::NAN, 4.0, 5.0]]);
        let poisoned_expected =
            TestTensorBool::<2>::from([[false, true, false], [true, false, false]]);
        assert_eq!(poisoned_expected.into_data(), poisoned.is_nan().into_data());
    }

    #[test]
    #[ignore = "https://github.com/tracel-ai/burn/issues/2089"]
    fn contains_nan() {
        // No NaN anywhere -> single-element false.
        let clean = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
        assert!(!clean.contains_nan().into_scalar());

        // One NaN element -> single-element true.
        let poisoned = TestTensor::<2>::from([[0.0, f32::NAN, 2.0], [3.0, 4.0, 5.0]]);
        assert!(poisoned.contains_nan().into_scalar());
    }
}

0 comments on commit cd848b1

Please sign in to comment.