|
1 | 1 | module MooncakeNNlibExt
|
2 | 2 |
|
3 |
| -using NNlib, Random, Mooncake |
| 3 | +using GPUArraysCore, NNlib, Random, Mooncake |
4 | 4 | using Base: IEEEFloat
|
5 | 5 | using NNlib: dropout
|
6 | 6 |
|
7 | 7 | using NNlib: conv, depthwiseconv
|
8 | 8 | import Mooncake: @from_rrule, DefaultCtx, MinimalCtx
|
9 | 9 |
|
| 10 | +# Array types which we test rules against, so are confident work. |
| 11 | +const SupportedArray{P,N} = Union{Array{P,N},AbstractGPUArray{P,N}} |
| 12 | + |
10 | 13 | @from_rrule(
|
11 |
| - MinimalCtx, Tuple{typeof(batched_mul),Array{P,3},Array{P,3}} where {P<:IEEEFloat}, |
| 14 | + MinimalCtx, |
| 15 | + Tuple{typeof(batched_mul),SupportedArray{P,3},SupportedArray{P,3}} where {P<:IEEEFloat}, |
12 | 16 | )
|
13 | 17 | @from_rrule(
|
14 |
| - MinimalCtx, Tuple{typeof(dropout),AbstractRNG,Array{P},P} where {P<:IEEEFloat}, true, |
| 18 | + MinimalCtx, |
| 19 | + Tuple{typeof(dropout),AbstractRNG,SupportedArray{P},P} where {P<:IEEEFloat}, |
| 20 | + true, |
15 | 21 | )
|
16 |
| -@from_rrule(MinimalCtx, Tuple{typeof(softmax),Array{<:IEEEFloat}}, true) |
17 |
| -@from_rrule(MinimalCtx, Tuple{typeof(logsoftmax),Array{<:IEEEFloat}}, true) |
18 |
| -@from_rrule(MinimalCtx, Tuple{typeof(logsumexp),Array{<:IEEEFloat}}, true) |
| 22 | +@from_rrule(MinimalCtx, Tuple{typeof(softmax),SupportedArray{<:IEEEFloat}}, true) |
| 23 | +@from_rrule(MinimalCtx, Tuple{typeof(logsoftmax),SupportedArray{<:IEEEFloat}}, true) |
| 24 | +@from_rrule(MinimalCtx, Tuple{typeof(logsumexp),SupportedArray{<:IEEEFloat}}, true) |
19 | 25 | @from_rrule(
|
20 |
| - MinimalCtx, Tuple{typeof(upsample_nearest),Array{<:IEEEFloat},NTuple{N,Int} where {N}}, |
| 26 | + MinimalCtx, |
| 27 | + Tuple{typeof(upsample_nearest),SupportedArray{<:IEEEFloat},NTuple{N,Int} where {N}}, |
21 | 28 | )
|
22 | 29 | @from_rrule(
|
23 | 30 | MinimalCtx,
|
24 |
| - Tuple{typeof(NNlib.fold),Array{<:IEEEFloat},NTuple{N,Int} where {N},DenseConvDims}, |
| 31 | + Tuple{ |
| 32 | + typeof(NNlib.fold),SupportedArray{<:IEEEFloat},NTuple{N,Int} where {N},DenseConvDims |
| 33 | + }, |
25 | 34 | )
|
26 |
| -@from_rrule(MinimalCtx, Tuple{typeof(NNlib.unfold),Array{<:IEEEFloat},DenseConvDims}) |
27 | 35 | @from_rrule(
|
28 |
| - MinimalCtx, Tuple{typeof(NNlib.scatter),Any,Array,Array{<:Union{Integer,Tuple}}}, true, |
| 36 | + MinimalCtx, Tuple{typeof(NNlib.unfold),SupportedArray{<:IEEEFloat},DenseConvDims} |
| 37 | +) |
| 38 | +@from_rrule( |
| 39 | + MinimalCtx, |
| 40 | + Tuple{typeof(NNlib.scatter),Any,SupportedArray,SupportedArray{<:Union{Integer,Tuple}}}, |
| 41 | + true, |
29 | 42 | )
|
30 | 43 | for conv in [:conv, :depthwiseconv]
|
31 | 44 | local ∇conv_data, ∇conv_filter = Symbol.(:∇, conv, [:_data, :_filter])
|
32 | 45 |
|
33 | 46 | @eval @from_rrule(
|
34 | 47 | MinimalCtx,
|
35 |
| - Tuple{typeof($conv),Array{P},Array{P},ConvDims} where {P<:IEEEFloat}, |
| 48 | + Tuple{ |
| 49 | + typeof($conv),SupportedArray{P},SupportedArray{P},ConvDims |
| 50 | + } where {P<:IEEEFloat}, |
36 | 51 | true,
|
37 | 52 | )
|
38 | 53 | @eval @from_rrule(
|
39 | 54 | MinimalCtx,
|
40 |
| - Tuple{typeof($∇conv_data),Array{P},Array{P},ConvDims} where {P<:IEEEFloat}, |
| 55 | + Tuple{ |
| 56 | + typeof($∇conv_data),SupportedArray{P},SupportedArray{P},ConvDims |
| 57 | + } where {P<:IEEEFloat}, |
41 | 58 | true,
|
42 | 59 | )
|
43 | 60 | end
|
44 | 61 | @from_rrule(
|
45 | 62 | MinimalCtx,
|
46 |
| - Tuple{typeof(∇conv_filter),Array{P},Array{P},ConvDims} where {P<:IEEEFloat}, |
| 63 | + Tuple{ |
| 64 | + typeof(∇conv_filter),SupportedArray{P},SupportedArray{P},ConvDims |
| 65 | + } where {P<:IEEEFloat}, |
47 | 66 | true,
|
48 | 67 | )
|
49 | 68 | for pool in [:maxpool, :meanpool]
|
50 |
| - @eval @from_rrule(MinimalCtx, Tuple{typeof($pool),Array{<:IEEEFloat},PoolDims}, true) |
| 69 | + @eval @from_rrule( |
| 70 | + MinimalCtx, Tuple{typeof($pool),SupportedArray{<:IEEEFloat},PoolDims}, true |
| 71 | + ) |
51 | 72 | end
|
52 |
| -@from_rrule(MinimalCtx, Tuple{typeof(pad_constant),Array,Any,Any}, true) |
| 73 | +@from_rrule(MinimalCtx, Tuple{typeof(pad_constant),SupportedArray,Any,Any}, true) |
53 | 74 |
|
54 | 75 | end
|
0 commit comments