Skip to content

Commit 6cc049f

Browse files
committed
Default integers to 32-bit precision
1 parent a0bea9a commit 6cc049f

23 files changed

+1095
-1084
lines changed

exla/lib/exla.ex

+5-5
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ defmodule EXLA do
220220
221221
iex> EXLA.jit(&Nx.add(&1, &1)).(Nx.tensor([1, 2, 3]))
222222
#Nx.Tensor<
223-
s64[3]
223+
s32[3]
224224
[2, 4, 6]
225225
>
226226
@@ -265,7 +265,7 @@ defmodule EXLA do
265265
266266
iex> EXLA.jit_apply(&Nx.add(&1, &1), [Nx.tensor([1, 2, 3])])
267267
#Nx.Tensor<
268-
s64[3]
268+
s32[3]
269269
[2, 4, 6]
270270
>
271271
@@ -278,10 +278,10 @@ defmodule EXLA do
278278
@doc """
279279
A shortcut for `Nx.Defn.compile/3` with the EXLA compiler.
280280
281-
iex> fun = EXLA.compile(&Nx.add(&1, &1), [Nx.template({3}, {:s, 64})])
281+
iex> fun = EXLA.compile(&Nx.add(&1, &1), [Nx.template({3}, {:s, 32})])
282282
iex> fun.(Nx.tensor([1, 2, 3]))
283283
#Nx.Tensor<
284-
s64[3]
284+
s32[3]
285285
[2, 4, 6]
286286
>
287287
@@ -328,7 +328,7 @@ defmodule EXLA do
328328
329329
Now let's invoke it:
330330
331-
stream = EXLA.stream(&Streamed.sum/2, [Nx.template({}, {:s, 64}), 0])
331+
stream = EXLA.stream(&Streamed.sum/2, [Nx.template({}, {:s, 32}), 0])
332332
333333
for i <- 1..5 do
334334
Nx.Stream.send(stream, i)

exla/test/exla/backend_test.exs

+10-10
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ defmodule EXLA.BackendTest do
3232

3333
test "Nx.to_binary/1" do
3434
t = Nx.tensor([1, 2, 3, 4], backend: EXLA.Backend)
35-
assert Nx.to_binary(t) == <<1::64-native, 2::64-native, 3::64-native, 4::64-native>>
36-
assert Nx.to_binary(t, limit: 2) == <<1::64-native, 2::64-native>>
37-
assert Nx.to_binary(t, limit: 6) == <<1::64-native, 2::64-native, 3::64-native, 4::64-native>>
35+
assert Nx.to_binary(t) == <<1::32-native, 2::32-native, 3::32-native, 4::32-native>>
36+
assert Nx.to_binary(t, limit: 2) == <<1::32-native, 2::32-native>>
37+
assert Nx.to_binary(t, limit: 6) == <<1::32-native, 2::32-native, 3::32-native, 4::32-native>>
3838
end
3939

4040
test "Nx.backend_transfer/1" do
@@ -44,7 +44,7 @@ defmodule EXLA.BackendTest do
4444
assert %EXLA.Backend{buffer: %EXLA.DeviceBuffer{}} = et.data
4545

4646
nt = Nx.backend_transfer(et)
47-
assert Nx.to_binary(nt) == <<1::64-native, 2::64-native, 3::64-native, 4::64-native>>
47+
assert Nx.to_binary(nt) == <<1::32-native, 2::32-native, 3::32-native, 4::32-native>>
4848

4949
assert_raise RuntimeError, ~r"called on deleted or donated buffer", fn ->
5050
Nx.backend_transfer(et)
@@ -63,7 +63,7 @@ defmodule EXLA.BackendTest do
6363
assert old_buffer == new_buffer
6464

6565
nt = Nx.backend_transfer(et)
66-
assert Nx.to_binary(nt) == <<1::64-native, 2::64-native, 3::64-native, 4::64-native>>
66+
assert Nx.to_binary(nt) == <<1::32-native, 2::32-native, 3::32-native, 4::32-native>>
6767

6868
assert_raise RuntimeError, ~r"called on deleted or donated buffer", fn ->
6969
Nx.backend_transfer(et)
@@ -83,10 +83,10 @@ defmodule EXLA.BackendTest do
8383
assert old_buffer != new_buffer
8484

8585
nt = Nx.backend_copy(et)
86-
assert Nx.to_binary(nt) == <<1::64-native, 2::64-native, 3::64-native, 4::64-native>>
86+
assert Nx.to_binary(nt) == <<1::32-native, 2::32-native, 3::32-native, 4::32-native>>
8787

8888
nt = Nx.backend_copy(et)
89-
assert Nx.to_binary(nt) == <<1::64-native, 2::64-native, 3::64-native, 4::64-native>>
89+
assert Nx.to_binary(nt) == <<1::32-native, 2::32-native, 3::32-native, 4::32-native>>
9090
end
9191

9292
test "different clients" do
@@ -102,10 +102,10 @@ defmodule EXLA.BackendTest do
102102
assert new_buffer.device_id == 0
103103

104104
nt = Nx.backend_copy(et)
105-
assert Nx.to_binary(nt) == <<1::64-native, 2::64-native, 3::64-native, 4::64-native>>
105+
assert Nx.to_binary(nt) == <<1::32-native, 2::32-native, 3::32-native, 4::32-native>>
106106

107107
nt = Nx.backend_copy(et)
108-
assert Nx.to_binary(nt) == <<1::64-native, 2::64-native, 3::64-native, 4::64-native>>
108+
assert Nx.to_binary(nt) == <<1::32-native, 2::32-native, 3::32-native, 4::32-native>>
109109
end
110110
end
111111

@@ -153,7 +153,7 @@ defmodule EXLA.BackendTest do
153153
assert inspect(t) ==
154154
"""
155155
#Nx.Tensor<
156-
s64[4]
156+
s32[4]
157157
[1, 2, 3, 4]
158158
>\
159159
"""

exla/test/exla/defn/expr_test.exs

+18-18
Original file line numberDiff line numberDiff line change
@@ -1807,7 +1807,7 @@ defmodule EXLA.Defn.ExprTest do
18071807
indices = Nx.tensor([[0]])
18081808
updates = Nx.tensor([1])
18091809

1810-
assert_equal(indexed_add(target, indices, updates), Nx.tensor([1], type: {:s, 64}))
1810+
assert_equal(indexed_add(target, indices, updates), Nx.tensor([1], type: {:s, 32}))
18111811

18121812
target = Nx.tensor([0])
18131813
indices = Nx.tensor([[0]])
@@ -1879,7 +1879,7 @@ defmodule EXLA.Defn.ExprTest do
18791879
indices = Nx.tensor([[0]])
18801880
updates = Nx.tensor([1])
18811881

1882-
assert_equal(indexed_put(target, indices, updates), Nx.tensor([1], type: {:s, 64}))
1882+
assert_equal(indexed_put(target, indices, updates), Nx.tensor([1], type: {:s, 32}))
18831883

18841884
target = Nx.tensor([0])
18851885
indices = Nx.tensor([[0]])
@@ -1963,7 +1963,7 @@ defmodule EXLA.Defn.ExprTest do
19631963
test "computes the sum across types" do
19641964
assert_equal(Nx.tensor([1, 2, 3]) |> sum(), Nx.tensor(6))
19651965
assert_equal(Nx.tensor([1, 2, 3], type: {:s, 8}) |> sum(), Nx.tensor(6))
1966-
assert_equal(Nx.tensor([1, 2, 3], type: {:u, 8}) |> sum(), Nx.tensor(6, type: {:u, 64}))
1966+
assert_equal(Nx.tensor([1, 2, 3], type: {:u, 8}) |> sum(), Nx.tensor(6, type: {:u, 32}))
19671967
assert_equal(Nx.tensor([1.0, 2.0, 3.0]) |> sum(), Nx.tensor(6.0))
19681968

19691969
assert_equal(
@@ -1986,9 +1986,9 @@ defmodule EXLA.Defn.ExprTest do
19861986
defn sum_equal(t), do: Nx.sum(Nx.equal(t, 1.0))
19871987

19881988
test "does not overflow" do
1989-
assert_equal(sum_equal(Nx.tensor(1)), Nx.tensor(1, type: {:u, 64}))
1990-
assert_equal(sum_equal(Nx.tensor([1, 1, 1])), Nx.tensor(3, type: {:u, 64}))
1991-
assert_equal(sum_equal(Nx.tensor([1, 2, 3])), Nx.tensor(1, type: {:u, 64}))
1989+
assert_equal(sum_equal(Nx.tensor(1)), Nx.tensor(1, type: {:u, 32}))
1990+
assert_equal(sum_equal(Nx.tensor([1, 1, 1])), Nx.tensor(3, type: {:u, 32}))
1991+
assert_equal(sum_equal(Nx.tensor([1, 2, 3])), Nx.tensor(1, type: {:u, 32}))
19921992
end
19931993

19941994
defn sum_keep(t), do: Nx.sum(t, keep_axes: true)
@@ -2011,7 +2011,7 @@ defmodule EXLA.Defn.ExprTest do
20112011
test "computes the product across types" do
20122012
assert_equal(Nx.tensor([1, 2, 3]) |> product(), Nx.tensor(6))
20132013
assert_equal(Nx.tensor([1, 2, 3], type: {:s, 8}) |> product(), Nx.tensor(6))
2014-
assert_equal(Nx.tensor([1, 2, 3], type: {:u, 8}) |> product(), Nx.tensor(6, type: {:u, 64}))
2014+
assert_equal(Nx.tensor([1, 2, 3], type: {:u, 8}) |> product(), Nx.tensor(6, type: {:u, 32}))
20152015
assert_equal(Nx.tensor([1.0, 2.0, 3.0]) |> product(), Nx.tensor(6.0))
20162016

20172017
assert_equal(
@@ -2034,9 +2034,9 @@ defmodule EXLA.Defn.ExprTest do
20342034
defn product_equal(t), do: Nx.product(Nx.equal(t, 1.0))
20352035

20362036
test "does not overflow" do
2037-
assert_equal(product_equal(Nx.tensor(1)), Nx.tensor(1, type: {:u, 64}))
2038-
assert_equal(product_equal(Nx.tensor([1, 1, 1])), Nx.tensor(1, type: {:u, 64}))
2039-
assert_equal(product_equal(Nx.tensor([1, 2, 3])), Nx.tensor(0, type: {:u, 64}))
2037+
assert_equal(product_equal(Nx.tensor(1)), Nx.tensor(1, type: {:u, 32}))
2038+
assert_equal(product_equal(Nx.tensor([1, 1, 1])), Nx.tensor(1, type: {:u, 32}))
2039+
assert_equal(product_equal(Nx.tensor([1, 2, 3])), Nx.tensor(0, type: {:u, 32}))
20402040
end
20412041

20422042
defn product_keep(t), do: Nx.product(t, keep_axes: true)
@@ -2416,12 +2416,12 @@ defmodule EXLA.Defn.ExprTest do
24162416
window_max2(Nx.tensor([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])),
24172417
Nx.tensor([
24182418
[
2419-
[-9_223_372_036_854_775_808, -9_223_372_036_854_775_808],
2420-
[-9_223_372_036_854_775_808, 6]
2419+
[-2_147_483_648, -2_147_483_648],
2420+
[-2_147_483_648, 6]
24212421
],
24222422
[
2423-
[-9_223_372_036_854_775_808, -9_223_372_036_854_775_808],
2424-
[-9_223_372_036_854_775_808, 6]
2423+
[-2_147_483_648, -2_147_483_648],
2424+
[-2_147_483_648, 6]
24252425
]
24262426
])
24272427
)
@@ -2482,12 +2482,12 @@ defmodule EXLA.Defn.ExprTest do
24822482
window_min2(Nx.tensor([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])),
24832483
Nx.tensor([
24842484
[
2485-
[9_223_372_036_854_775_807, 9_223_372_036_854_775_807],
2486-
[9_223_372_036_854_775_807, 3]
2485+
[2_147_483_647, 2_147_483_647],
2486+
[2_147_483_647, 3]
24872487
],
24882488
[
2489-
[9_223_372_036_854_775_807, 9_223_372_036_854_775_807],
2490-
[9_223_372_036_854_775_807, 3]
2489+
[2_147_483_647, 2_147_483_647],
2490+
[2_147_483_647, 3]
24912491
]
24922492
])
24932493
)

0 commit comments

Comments
 (0)