
Commit 7c1d969

Authored Jan 13, 2025
Set-up JuliaFormatter (#67)
* Set-up JuliaFormatter
* Implemented formatting
1 parent 4066bf2 commit 7c1d969

20 files changed: +869 −673 lines
 

.JuliaFormatter.toml

Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
+style = "blue"
+
+ignore = ["src/Wrapper.jl"]
+pipe_to_function_call = false
+whitespace_in_kwargs = true
+whitespace_typedefs = true
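
With this file at the repository root, JuliaFormatter picks the settings up automatically. A minimal sketch of running it locally (assumes JuliaFormatter is installed in the active environment):

    using JuliaFormatter

    # Format the whole repository in place, honouring .JuliaFormatter.toml;
    # returns true if every file was already correctly formatted.
    format(".")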

.github/workflows/Format.yml

Lines changed: 8 additions & 0 deletions

@@ -0,0 +1,8 @@
+name: Format suggestions
+on:
+  pull_request:
+jobs:
+  code-style:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: julia-actions/julia-format@v3
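
The julia-actions/julia-format action posts style suggestions on each pull request. A hedged sketch of mirroring that check locally before pushing (the action's exact invocation may differ):

    using JuliaFormatter

    # Dry-run check: don't rewrite any files, just verify they are formatted.
    format("."; overwrite = false) || error("Some files are not formatted")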

README.md

Lines changed: 3 additions & 0 deletions

@@ -1,5 +1,8 @@
 # Torch.jl
 
+[![Build Status](https://github.com/FluxML/Torch.jl/actions/workflows/CI.yaml/badge.svg?branch=master)](https://github.com/FluxML/Torch.jl/actions/workflows/CI.yaml?query=branch%3Amaster)
+[![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/JuliaDiff/BlueStyle)
+
 Sensible extensions for exposing torch in Julia.
 
 This package is aimed at providing the `Tensor` type, which offloads all computations over to [ATen](https://pytorch.org/cppdocs/), the foundational tensor library for PyTorch, written in C++.

deps/julia_wrapper_generator/generator.jl

Lines changed: 2 additions & 2 deletions

@@ -22,11 +22,11 @@ function rewrite!(e::Expr)
 end
 
 function rewrite!(e::Expr, ::Val{:function})
-  rewrite!(e.args[2], Val(e.args[2].head))
+    return rewrite!(e.args[2], Val(e.args[2].head))
 end
 
 function rewrite!(e::Expr, ::Val{:block})
-  e.args[1] = Expr(:macrocall, Symbol("@runtime_error_check"), nothing, e.args[1])
+    return e.args[1] = Expr(:macrocall, Symbol("@runtime_error_check"), nothing, e.args[1])
 end
 
 function rewrite!(dag::ExprDAG)
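
For context, these rewrites wrap each generated C binding so its status code is checked at runtime. A hypothetical before/after of one generated function (the real signatures live in src/Wrapper.jl; `at_dim` and its argument here are only illustrative):

    # Before rewrite!(dag):
    function at_dim(tensor)
        ccall((:at_dim, libtorch_c_api), Cint, (Ptr{Cvoid},), tensor)
    end

    # After rewrite!(dag): the body's first expression is wrapped in the macro.
    function at_dim(tensor)
        @runtime_error_check ccall((:at_dim, libtorch_c_api), Cint, (Ptr{Cvoid},), tensor)
    end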

deps/julia_wrapper_generator/generator.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [general]
 library_name = "libtorch_c_api"
-output_file_path = "../../src/wrapper.jl"
+output_file_path = "../../src/Wrapper.jl"
 prologue_file_path = "./prologue.jl"
 module_name = "Wrapper"
 jll_pkg_name = "TorchCAPI_jll"
deps/julia_wrapper_generator/prologue.jl

Lines changed: 10 additions & 10 deletions

@@ -1,15 +1,15 @@
 function get_error()
-  err = cglobal((:myerr, libtorch_c_api), Cstring) |> unsafe_load
-  unsafe_string(err)
+    err = cglobal((:myerr, libtorch_c_api), Cstring) |> unsafe_load
+    return unsafe_string(err)
 end
 
 macro runtime_error_check(ex)
-  quote
-    x = $ex
-    if x == 1
-      cs = get_error()
-      flush_error()
-      throw(cs)
-    end
-  end |> esc
+    return quote
+        x = $ex
+        if x == 1
+            cs = get_error()
+            flush_error()
+            throw(cs)
+        end
+    end |> esc
 end
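
In other words, the prologue turns the C API's nonzero status codes into Julia exceptions. Roughly, a wrapped call expands as follows (a sketch of the expansion; `at_dim` stands in for any generated binding):

    @runtime_error_check at_dim(t)

    # ... expands to approximately:
    x = at_dim(t)
    if x == 1
        cs = get_error()   # read the message from the C global `myerr`
        flush_error()
        throw(cs)
    end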

src/Torch.jl

Lines changed: 16 additions & 14 deletions

@@ -13,7 +13,7 @@ using FillArrays
 
 TURN_ON_LOGGING = false
 
-include("wrapper.jl")
+include("Wrapper.jl")
 
 using .Wrapper
 
@@ -32,23 +32,25 @@ include("statistics.jl")
 include("grads.jl")
 include("utils.jl")
 
-@init @require Flux="587475ba-b771-5e3f-ad9e-33799f191a9c" begin
-  using .Flux
+@init @require Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" begin
+    using .Flux
 
-  function (tbn::Flux.BatchNorm)(x::Tensor)
-    tbn.λ.(Torch.batchnorm(x, tbn.γ, tbn.β, tbn.μ, tbn.σ², 0, tbn.momentum, tbn.ϵ, 1))
-  end
+    function (tbn::Flux.BatchNorm)(x::Tensor)
+        return tbn.λ.(
+            Torch.batchnorm(x, tbn.γ, tbn.β, tbn.μ, tbn.σ², 0, tbn.momentum, tbn.ϵ, 1)
+        )
+    end
 
-  function Flux.Zygote.accum(t1::Tensor, t2::Tensor{T,N}) where {T,N}
-    ptr = Ref(Ptr{Cvoid}())
+    function Flux.Zygote.accum(t1::Tensor, t2::Tensor{T, N}) where {T, N}
+        ptr = Ref(Ptr{Cvoid}())
 
-    Torch.Wrapper.atg_add_(ptr, t1.ptr, t2.ptr)
-    Tensor{T,N}(ptr[], Torch.on(t1))
-  end
+        Torch.Wrapper.atg_add_(ptr, t1.ptr, t2.ptr)
+        return Tensor{T, N}(ptr[], Torch.on(t1))
+    end
 
-  eval(:(Flux.Zygote.@nograd Torch.Wrapper.at_copy_data))
-  eval(:(Flux.Zygote.@nograd Torch.Wrapper.at_dim))
-  torch(x) = Flux.fmap(to_tensor, x)
+    eval(:(Flux.Zygote.@nograd Torch.Wrapper.at_copy_data))
+    eval(:(Flux.Zygote.@nograd Torch.Wrapper.at_dim))
+    torch(x) = Flux.fmap(to_tensor, x)
 end
 
 end # module
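
Once Flux is loaded, the `torch` helper defined above maps `to_tensor` over a model's parameters. A hedged usage sketch (assumes a working libtorch setup; layer constructors and device handling may differ by Flux version):

    using Flux, Torch

    m = Chain(Dense(10, 5, relu), Dense(5, 2))
    tm = Torch.torch(m)   # Flux.fmap(to_tensor, m): parameters become Torch Tensors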
src/{wrapper.jl → Wrapper.jl}

File renamed without changes.

src/broadcast.jl

Lines changed: 19 additions & 20 deletions

@@ -9,23 +9,23 @@ using Base.Broadcast: broadcast_shape
 # Base.BroadcastStyle(::Type{Tensor}) = TensorStyle()
 
 for op in (:+, :-, :/)
-  @eval function broadcasted(::typeof($op), t1::Tensor, t2::Tensor)
-    $op(t1, t2)
-  end
+    @eval function broadcasted(::typeof($op), t1::Tensor, t2::Tensor)
+        return $op(t1, t2)
+    end
 end
 
 for op in (:+, :-)
-  @eval function broadcasted(::typeof($op), t1::Tensor, t2::TensorVector)
-    t_ = reshape(t2, -1, 1)
-    $op(t1, t_)
-  end
+    @eval function broadcasted(::typeof($op), t1::Tensor, t2::TensorVector)
+        t_ = reshape(t2, -1, 1)
+        return $op(t1, t_)
+    end
 end
 
-function broadcasted(::typeof(*), t1::Tensor{T,N}, t2::Tensor{T,M}) where {T,N,M}
-  ptr = Ref(Ptr{Cvoid}())
+function broadcasted(::typeof(*), t1::Tensor{T, N}, t2::Tensor{T, M}) where {T, N, M}
+    ptr = Ref(Ptr{Cvoid}())
 
-  atg_mul(ptr, t1.ptr, t2.ptr)
-  Tensor{T,max(N,M)}(ptr[], on(t1))
+    atg_mul(ptr, t1.ptr, t2.ptr)
+    return Tensor{T, max(N, M)}(ptr[], on(t1))
 end
 
 broadcasted(::typeof(NNlib.relu), t::Tensor) = NNlib.relu(t)
@@ -34,22 +34,21 @@ broadcasted(::typeof(identity), t::Tensor) = identity(t)
 broadcasted(::typeof(NNlib.sigmoid), t::Tensor) = NNlib.sigmoid(t)
 
 for op in (:+, :-, :*, :/)
-  @eval function broadcasted(::typeof($op), t::Tensor, args...)
-    $op(t, args...)
-  end
+    @eval function broadcasted(::typeof($op), t::Tensor, args...)
+        return $op(t, args...)
+    end
 end
 
 broadcasted(::typeof(sqrt), t::Tensor) = sqrt(t)
 
-function broadcasted(::typeof(copy), t::Tensor{T,N}) where {T,N}
-  t
+function broadcasted(::typeof(copy), t::Tensor{T, N}) where {T, N}
+    return t
 end
 
 @adjoint function broadcast(::typeof(NNlib.sigmoid), t::Tensor)
-
-  NNlib.sigmoid(t), Δ -> (∇sigmoid(Δ, t),)
+    return NNlib.sigmoid(t), Δ -> (∇sigmoid(Δ, t),)
 end
 
-@adjoint function broadcasted(::typeof(NNlib.relu), t::Tensor{T}) where T
-  relu(t), Δ -> (nothing, ∇leaky_relu(Δ, t, zero(T)),)
+@adjoint function broadcasted(::typeof(NNlib.relu), t::Tensor{T}) where {T}
+    return relu(t), Δ -> (nothing, ∇leaky_relu(Δ, t, zero(T)))
 end
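
These overloads bypass Julia's generic broadcasting machinery so elementwise operators on Tensors run as single ATen kernel calls rather than scalar loops. A sketch of the effect (assumes Torch.jl's `tensor` constructor; exact constructor and device handling may differ):

    t1 = Torch.tensor(rand(Float32, 3, 3))
    t2 = Torch.tensor(rand(Float32, 3, 3))

    # Dispatches to broadcasted(+, t1, t2), which forwards straight to
    # +(t1, t2), i.e. one fused ATen call instead of an elementwise loop.
    t3 = t1 .+ t2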
