@@ -1807,7 +1807,7 @@ defmodule EXLA.Defn.ExprTest do
1807
1807
indices = Nx . tensor ( [ [ 0 ] ] )
1808
1808
updates = Nx . tensor ( [ 1 ] )
1809
1809
1810
- assert_equal ( indexed_add ( target , indices , updates ) , Nx . tensor ( [ 1 ] , type: { :s , 64 } ) )
1810
+ assert_equal ( indexed_add ( target , indices , updates ) , Nx . tensor ( [ 1 ] , type: { :s , 32 } ) )
1811
1811
1812
1812
target = Nx . tensor ( [ 0 ] )
1813
1813
indices = Nx . tensor ( [ [ 0 ] ] )
@@ -1879,7 +1879,7 @@ defmodule EXLA.Defn.ExprTest do
1879
1879
indices = Nx . tensor ( [ [ 0 ] ] )
1880
1880
updates = Nx . tensor ( [ 1 ] )
1881
1881
1882
- assert_equal ( indexed_put ( target , indices , updates ) , Nx . tensor ( [ 1 ] , type: { :s , 64 } ) )
1882
+ assert_equal ( indexed_put ( target , indices , updates ) , Nx . tensor ( [ 1 ] , type: { :s , 32 } ) )
1883
1883
1884
1884
target = Nx . tensor ( [ 0 ] )
1885
1885
indices = Nx . tensor ( [ [ 0 ] ] )
@@ -1963,7 +1963,7 @@ defmodule EXLA.Defn.ExprTest do
1963
1963
test "computes the sum across types" do
1964
1964
assert_equal ( Nx . tensor ( [ 1 , 2 , 3 ] ) |> sum ( ) , Nx . tensor ( 6 ) )
1965
1965
assert_equal ( Nx . tensor ( [ 1 , 2 , 3 ] , type: { :s , 8 } ) |> sum ( ) , Nx . tensor ( 6 ) )
1966
- assert_equal ( Nx . tensor ( [ 1 , 2 , 3 ] , type: { :u , 8 } ) |> sum ( ) , Nx . tensor ( 6 , type: { :u , 64 } ) )
1966
+ assert_equal ( Nx . tensor ( [ 1 , 2 , 3 ] , type: { :u , 8 } ) |> sum ( ) , Nx . tensor ( 6 , type: { :u , 32 } ) )
1967
1967
assert_equal ( Nx . tensor ( [ 1.0 , 2.0 , 3.0 ] ) |> sum ( ) , Nx . tensor ( 6.0 ) )
1968
1968
1969
1969
assert_equal (
@@ -1986,9 +1986,9 @@ defmodule EXLA.Defn.ExprTest do
1986
1986
defn sum_equal ( t ) , do: Nx . sum ( Nx . equal ( t , 1.0 ) )
1987
1987
1988
1988
test "does not overflow" do
1989
- assert_equal ( sum_equal ( Nx . tensor ( 1 ) ) , Nx . tensor ( 1 , type: { :u , 64 } ) )
1990
- assert_equal ( sum_equal ( Nx . tensor ( [ 1 , 1 , 1 ] ) ) , Nx . tensor ( 3 , type: { :u , 64 } ) )
1991
- assert_equal ( sum_equal ( Nx . tensor ( [ 1 , 2 , 3 ] ) ) , Nx . tensor ( 1 , type: { :u , 64 } ) )
1989
+ assert_equal ( sum_equal ( Nx . tensor ( 1 ) ) , Nx . tensor ( 1 , type: { :u , 32 } ) )
1990
+ assert_equal ( sum_equal ( Nx . tensor ( [ 1 , 1 , 1 ] ) ) , Nx . tensor ( 3 , type: { :u , 32 } ) )
1991
+ assert_equal ( sum_equal ( Nx . tensor ( [ 1 , 2 , 3 ] ) ) , Nx . tensor ( 1 , type: { :u , 32 } ) )
1992
1992
end
1993
1993
1994
1994
defn sum_keep ( t ) , do: Nx . sum ( t , keep_axes: true )
@@ -2011,7 +2011,7 @@ defmodule EXLA.Defn.ExprTest do
2011
2011
test "computes the product across types" do
2012
2012
assert_equal ( Nx . tensor ( [ 1 , 2 , 3 ] ) |> product ( ) , Nx . tensor ( 6 ) )
2013
2013
assert_equal ( Nx . tensor ( [ 1 , 2 , 3 ] , type: { :s , 8 } ) |> product ( ) , Nx . tensor ( 6 ) )
2014
- assert_equal ( Nx . tensor ( [ 1 , 2 , 3 ] , type: { :u , 8 } ) |> product ( ) , Nx . tensor ( 6 , type: { :u , 64 } ) )
2014
+ assert_equal ( Nx . tensor ( [ 1 , 2 , 3 ] , type: { :u , 8 } ) |> product ( ) , Nx . tensor ( 6 , type: { :u , 32 } ) )
2015
2015
assert_equal ( Nx . tensor ( [ 1.0 , 2.0 , 3.0 ] ) |> product ( ) , Nx . tensor ( 6.0 ) )
2016
2016
2017
2017
assert_equal (
@@ -2034,9 +2034,9 @@ defmodule EXLA.Defn.ExprTest do
2034
2034
defn product_equal ( t ) , do: Nx . product ( Nx . equal ( t , 1.0 ) )
2035
2035
2036
2036
test "does not overflow" do
2037
- assert_equal ( product_equal ( Nx . tensor ( 1 ) ) , Nx . tensor ( 1 , type: { :u , 64 } ) )
2038
- assert_equal ( product_equal ( Nx . tensor ( [ 1 , 1 , 1 ] ) ) , Nx . tensor ( 1 , type: { :u , 64 } ) )
2039
- assert_equal ( product_equal ( Nx . tensor ( [ 1 , 2 , 3 ] ) ) , Nx . tensor ( 0 , type: { :u , 64 } ) )
2037
+ assert_equal ( product_equal ( Nx . tensor ( 1 ) ) , Nx . tensor ( 1 , type: { :u , 32 } ) )
2038
+ assert_equal ( product_equal ( Nx . tensor ( [ 1 , 1 , 1 ] ) ) , Nx . tensor ( 1 , type: { :u , 32 } ) )
2039
+ assert_equal ( product_equal ( Nx . tensor ( [ 1 , 2 , 3 ] ) ) , Nx . tensor ( 0 , type: { :u , 32 } ) )
2040
2040
end
2041
2041
2042
2042
defn product_keep ( t ) , do: Nx . product ( t , keep_axes: true )
@@ -2416,12 +2416,12 @@ defmodule EXLA.Defn.ExprTest do
2416
2416
window_max2 ( Nx . tensor ( [ [ [ 1 , 2 , 3 ] , [ 4 , 5 , 6 ] ] , [ [ 1 , 2 , 3 ] , [ 4 , 5 , 6 ] ] ] ) ) ,
2417
2417
Nx . tensor ( [
2418
2418
[
2419
- [ - 9_223_372_036_854_775_808 , - 9_223_372_036_854_775_808 ] ,
2420
- [ - 9_223_372_036_854_775_808 , 6 ]
2419
+ [ - 2_147_483_648 , - 2_147_483_648 ] ,
2420
+ [ - 2_147_483_648 , 6 ]
2421
2421
] ,
2422
2422
[
2423
- [ - 9_223_372_036_854_775_808 , - 9_223_372_036_854_775_808 ] ,
2424
- [ - 9_223_372_036_854_775_808 , 6 ]
2423
+ [ - 2_147_483_648 , - 2_147_483_648 ] ,
2424
+ [ - 2_147_483_648 , 6 ]
2425
2425
]
2426
2426
] )
2427
2427
)
@@ -2482,12 +2482,12 @@ defmodule EXLA.Defn.ExprTest do
2482
2482
window_min2 ( Nx . tensor ( [ [ [ 1 , 2 , 3 ] , [ 4 , 5 , 6 ] ] , [ [ 1 , 2 , 3 ] , [ 4 , 5 , 6 ] ] ] ) ) ,
2483
2483
Nx . tensor ( [
2484
2484
[
2485
- [ 9_223_372_036_854_775_807 , 9_223_372_036_854_775_807 ] ,
2486
- [ 9_223_372_036_854_775_807 , 3 ]
2485
+ [ 2_147_483_647 , 2_147_483_647 ] ,
2486
+ [ 2_147_483_647 , 3 ]
2487
2487
] ,
2488
2488
[
2489
- [ 9_223_372_036_854_775_807 , 9_223_372_036_854_775_807 ] ,
2490
- [ 9_223_372_036_854_775_807 , 3 ]
2489
+ [ 2_147_483_647 , 2_147_483_647 ] ,
2490
+ [ 2_147_483_647 , 3 ]
2491
2491
]
2492
2492
] )
2493
2493
)
0 commit comments