Skip to content
This repository was archived by the owner on Mar 12, 2021. It is now read-only.

Commit c053b1d

Browse files
authored
Merge pull request #30 from MikeInnes/cudart
Use CUDAdrv instead of CUDArt
2 parents ecc09c9 + 05b68d0 commit c053b1d

File tree

7 files changed

+533
-519
lines changed

7 files changed

+533
-519
lines changed

REQUIRE

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
julia 0.5
2-
CUDArt 0.3.0
2+
CUDAdrv 0.5.1

src/CUBLAS.jl

+6-2
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@ module CUBLAS
1111

1212
importall Base.LinAlg.BLAS
1313

14-
using CUDArt
15-
using CUDArt.rt.cudaStream_t
14+
using CUDAdrv: OwnedPtr, CuArray, CuVector, CuMatrix
15+
16+
CuVecOrMat{T} = Union{CuVector{T},CuMatrix{T}}
1617

1718
const BlasChar = Char #import Base.LinAlg.BlasChar
1819
import Base.one
@@ -77,6 +78,9 @@ if isempty(libcublas)
7778
error("CUBLAS library cannot be found. Please make sure that CUDA is installed")
7879
end
7980

81+
# Typedef needed by libcublas
82+
const cudaStream_t = Ptr{Void}
83+
8084
include("libcublas.jl")
8185

8286
# setup cublas handle

src/blas.jl

+196-196
Large diffs are not rendered by default.

src/highlevel.jl

+35-38
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import Base.Operators.(*)
22

3-
import Base: scale!, scale, norm, vecdot
3+
import Base: scale!, norm, vecdot
44

55
import Base: A_mul_B!, At_mul_B, Ac_mul_B, A_mul_Bc, At_mul_Bt, Ac_mul_Bc, At_mul_Bt,
66
At_mul_B!, Ac_mul_B!, A_mul_Bc!, At_mul_Bt!, Ac_mul_Bc!, At_mul_Bt!
77

8-
cublas_size(t::Char, M::CudaVecOrMat) = (size(M, t=='N' ? 1:2), size(M, t=='N' ? 2:1))
8+
cublas_size(t::Char, M::CuVecOrMat) = (size(M, t=='N' ? 1:2), size(M, t=='N' ? 2:1))
99

1010
###########
1111
#
@@ -16,13 +16,12 @@ cublas_size(t::Char, M::CudaVecOrMat) = (size(M, t=='N' ? 1:2), size(M, t=='N' ?
1616
#######
1717
# SCAL
1818
#######
19-
scale!{T<:CublasFloat}(x::CudaArray{T}, k::Number) = CUBLAS.scal!(length(x), k, x, 1)
20-
scale{T<:CublasFloat}(x::CudaArray{T}, k::Number) = CUBLAS.scal!(length(x), k, copy(x), 1)
19+
scale!{T<:CublasFloat}(x::CuArray{T}, k::Number) = CUBLAS.scal!(length(x), k, x, 1)
2120

2221
#######
2322
# DOT
2423
#######
25-
function dot{T <: CublasFloat, TI<:Integer}(x::CudaVector{T}, rx::Union{UnitRange{TI},Range{TI}}, y::CudaVector{T}, ry::Union{UnitRange{TI},Range{TI}})
24+
function dot{T <: CublasFloat, TI<:Integer}(x::CuVector{T}, rx::Union{UnitRange{TI},Range{TI}}, y::CuVector{T}, ry::Union{UnitRange{TI},Range{TI}})
2625
if length(rx) != length(ry)
2726
throw(DimensionMismatch("length of rx, $(length(rx)), does not equal length of ry, $(length(ry))"))
2827
end
@@ -35,17 +34,17 @@ function dot{T <: CublasFloat, TI<:Integer}(x::CudaVector{T}, rx::Union{UnitRang
3534
dot(length(rx), pointer(x)+(first(rx)-1)*sizeof(T), step(rx), pointer(y)+(first(ry)-1)*sizeof(T), step(ry))
3635
end
3736

38-
At_mul_B{T<:CublasReal}(x::CudaVector{T}, y::CudaVector{T}) = [CUBLAS.dot(x, y)]
39-
At_mul_B{T<:CublasComplex}(x::CudaVector{T}, y::CudaVector{T}) = [CUBLAS.dotu(x, y)]
40-
Ac_mul_B{T<:CublasComplex}(x::CudaVector{T}, y::CudaVector{T}) = [CUBLAS.dotc(x, y)]
37+
At_mul_B{T<:CublasReal}(x::CuVector{T}, y::CuVector{T}) = [CUBLAS.dot(x, y)]
38+
At_mul_B{T<:CublasComplex}(x::CuVector{T}, y::CuVector{T}) = [CUBLAS.dotu(x, y)]
39+
Ac_mul_B{T<:CublasComplex}(x::CuVector{T}, y::CuVector{T}) = [CUBLAS.dotc(x, y)]
4140

42-
vecdot{T<:CublasReal}(x::CudaVector{T}, y::CudaVector{T}) = dot(x, y)
43-
vecdot{T<:CublasComplex}(x::CudaVector{T}, y::CudaVector{T}) = dotc(x, y)
41+
vecdot{T<:CublasReal}(x::CuVector{T}, y::CuVector{T}) = dot(x, y)
42+
vecdot{T<:CublasComplex}(x::CuVector{T}, y::CuVector{T}) = dotc(x, y)
4443

4544
#######
4645
# NRM2
4746
#######
48-
norm(x::CudaArray) = nrm2(x)
47+
norm(x::CuArray) = nrm2(x)
4948

5049

5150
############
@@ -58,7 +57,7 @@ norm(x::CudaArray) = nrm2(x)
5857
#########
5958
# GEMV
6059
##########
61-
function gemv_wrapper!{T<:CublasFloat}(y::CudaVector{T}, tA::Char, A::CudaMatrix{T}, x::CudaVector{T},
60+
function gemv_wrapper!{T<:CublasFloat}(y::CuVector{T}, tA::Char, A::CuMatrix{T}, x::CuVector{T},
6261
alpha = one(T), beta = zero(T))
6362
mA, nA = cublas_size(tA, A)
6463
if nA != length(x)
@@ -76,20 +75,20 @@ function gemv_wrapper!{T<:CublasFloat}(y::CudaVector{T}, tA::Char, A::CudaMatrix
7675
gemv!(tA, alpha, A, x, beta, y)
7776
end
7877

79-
A_mul_B!{T<:CublasFloat}(y::CudaVector{T}, A::CudaMatrix{T}, x::CudaVector{T}) = gemv_wrapper!(y, 'N', A, x)
80-
At_mul_B!{T<:CublasFloat}(y::CudaVector{T}, A::CudaMatrix{T}, x::CudaVector{T}) = gemv_wrapper!(y, 'T', A, x)
81-
Ac_mul_B!{T<:CublasFloat}(y::CudaVector{T}, A::CudaMatrix{T}, x::CudaVector{T}) = gemv_wrapper!(y, 'T', A, x)
82-
Ac_mul_B!{T<:CublasComplex}(y::CudaVector{T}, A::CudaMatrix{T}, x::CudaVector{T}) = gemv_wrapper!(y, 'C', A, x)
78+
A_mul_B!{T<:CublasFloat}(y::CuVector{T}, A::CuMatrix{T}, x::CuVector{T}) = gemv_wrapper!(y, 'N', A, x)
79+
At_mul_B!{T<:CublasFloat}(y::CuVector{T}, A::CuMatrix{T}, x::CuVector{T}) = gemv_wrapper!(y, 'T', A, x)
80+
Ac_mul_B!{T<:CublasFloat}(y::CuVector{T}, A::CuMatrix{T}, x::CuVector{T}) = gemv_wrapper!(y, 'T', A, x)
81+
Ac_mul_B!{T<:CublasComplex}(y::CuVector{T}, A::CuMatrix{T}, x::CuVector{T}) = gemv_wrapper!(y, 'C', A, x)
8382

84-
function (*){T<:CublasFloat}(A::CudaMatrix{T}, x::CudaVector{T})
83+
function (*){T<:CublasFloat}(A::CuMatrix{T}, x::CuVector{T})
8584
A_mul_B!(similar(x, T, size(A,1)), A, x)
8685
end
8786

88-
function At_mul_B{T<:CublasFloat}(A::CudaMatrix{T}, x::CudaVector{T})
87+
function At_mul_B{T<:CublasFloat}(A::CuMatrix{T}, x::CuVector{T})
8988
At_mul_B!(similar(x, T, size(A,2)), A, x)
9089
end
9190

92-
function Ac_mul_B{T<:CublasFloat}(A::CudaMatrix{T}, x::CudaVector{T})
91+
function Ac_mul_B{T<:CublasFloat}(A::CuMatrix{T}, x::CuVector{T})
9392
Ac_mul_B!(similar(x, T, size(A,2)), A, x)
9493
end
9594

@@ -103,9 +102,9 @@ end
103102
########
104103
# GEMM
105104
########
106-
function gemm_wrapper!{T <: CublasFloat}(C::CudaVecOrMat{T}, tA::Char, tB::Char,
107-
A::CudaVecOrMat{T},
108-
B::CudaVecOrMat{T},
105+
function gemm_wrapper!{T <: CublasFloat}(C::CuVecOrMat{T}, tA::Char, tB::Char,
106+
A::CuVecOrMat{T},
107+
B::CuVecOrMat{T},
109108
alpha = one(T),
110109
beta = zero(T))
111110
mA, nA = cublas_size(tA, A)
@@ -130,51 +129,49 @@ function gemm_wrapper!{T <: CublasFloat}(C::CudaVecOrMat{T}, tA::Char, tB::Char,
130129
end
131130

132131
# Mutating
133-
A_mul_B!{T <: CublasFloat}(C::CudaMatrix{T}, A::CudaMatrix{T}, B::CudaMatrix{T}) = gemm_wrapper!(C, 'N', 'N', A, B)
134-
At_mul_B!(C::CudaMatrix, A::CudaMatrix, B::CudaMatrix) = gemm_wrapper!(C, 'T', 'N', A, B)
135-
At_mul_Bt!(C::CudaMatrix, A::CudaMatrix, B::CudaMatrix) = gemm_wrapper!(C, 'T', 'T', A, B)
136-
Ac_mul_B!{T<:CublasReal}(C::CudaMatrix{T}, A::CudaMatrix{T}, B::CudaMatrix{T}) = At_mul_B!(C, A, B)
137-
Ac_mul_B!(C::CudaMatrix, A::CudaMatrix, B::CudaMatrix) = gemm_wrapper!(C, 'C', 'N', A, B)
132+
A_mul_B!{T <: CublasFloat}(C::CuMatrix{T}, A::CuMatrix{T}, B::CuMatrix{T}) = gemm_wrapper!(C, 'N', 'N', A, B)
133+
At_mul_B!(C::CuMatrix, A::CuMatrix, B::CuMatrix) = gemm_wrapper!(C, 'T', 'N', A, B)
134+
At_mul_Bt!(C::CuMatrix, A::CuMatrix, B::CuMatrix) = gemm_wrapper!(C, 'T', 'T', A, B)
135+
Ac_mul_B!{T<:CublasReal}(C::CuMatrix{T}, A::CuMatrix{T}, B::CuMatrix{T}) = At_mul_B!(C, A, B)
136+
Ac_mul_B!(C::CuMatrix, A::CuMatrix, B::CuMatrix) = gemm_wrapper!(C, 'C', 'N', A, B)
138137

139-
function A_mul_B!{T}(C::CudaMatrix{T}, A::CudaVecOrMat{T}, B::CudaVecOrMat{T})
138+
function A_mul_B!{T}(C::CuMatrix{T}, A::CuVecOrMat{T}, B::CuVecOrMat{T})
140139
gemm_wrapper!(C, 'N', 'N', A, B)
141140
end
142141

143142
# Non mutating
144143

145144
# A_mul_Bx
146-
function (*){T <: CublasFloat}(A::CudaMatrix{T}, B::CudaMatrix{T})
145+
function (*){T <: CublasFloat}(A::CuMatrix{T}, B::CuMatrix{T})
147146
A_mul_B!(similar(B, T,(size(A,1), size(B,2))), A, B)
148147
end
149148

150-
function A_mul_Bt{T}(A::CudaMatrix{T}, B::CudaMatrix{T})
149+
function A_mul_Bt{T}(A::CuMatrix{T}, B::CuMatrix{T})
151150
A_mul_Bt!(similar(B, T, (size(A,1), size(B,1))), A, B)
152151
end
153152

154-
function A_mul_Bc{T}(A::CudaMatrix{T}, B::CudaMatrix{T})
153+
function A_mul_Bc{T}(A::CuMatrix{T}, B::CuMatrix{T})
155154
A_mul_Bc!(similar(B, T,(size(A,1),size(B,1))),A, B)
156155
end
157156

158157
# At_mul_Bx
159-
function At_mul_B{T}(A::CudaMatrix{T}, B::CudaMatrix{T})
158+
function At_mul_B{T}(A::CuMatrix{T}, B::CuMatrix{T})
160159
At_mul_B!(similar(B, T, (size(A,2), size(B,2))), A, B)
161160
end
162161

163-
function At_mul_Bt{T}(A::CudaMatrix{T}, B::CudaMatrix{T})
162+
function At_mul_Bt{T}(A::CuMatrix{T}, B::CuMatrix{T})
164163
At_mul_Bt!(similar(B, T, (size(A,2), size(B,1))), A, B)
165164
end
166165

167166
# Ac_mul_Bx
168-
function Ac_mul_B{T}(A::CudaMatrix{T}, B::CudaMatrix{T})
167+
function Ac_mul_B{T}(A::CuMatrix{T}, B::CuMatrix{T})
169168
Ac_mul_B!(similar(B, T, (size(A,2), size(B,2))), A, B)
170169
end
171170

172-
function Ac_mul_Bt{T,S}(A::CudaMatrix{T}, B::CudaMatrix{S})
171+
function Ac_mul_Bt{T,S}(A::CuMatrix{T}, B::CuMatrix{S})
173172
Ac_mul_Bt(similar(B, T, (size(A,2), size(B,1))), A, B)
174173
end
175174

176-
function Ac_mul_Bc{T,S}(A::CudaMatrix{T}, B::CudaMatrix{S})
175+
function Ac_mul_Bc{T,S}(A::CuMatrix{T}, B::CuMatrix{S})
177176
Ac_mul_Bc!(similar(B, T, (size(A,2), size(B,1))), A, B)
178177
end
179-
180-

src/libcublas.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -773,4 +773,4 @@ try
773773
catch exception
774774
Base.show_backtrace(STDOUT, backtrace());
775775
println();
776-
end
776+
end

src/libcublas_types.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -100,4 +100,4 @@ try
100100
catch exception
101101
Base.show_backtrace(STDOUT, backtrace());
102102
println();
103-
end
103+
end

0 commit comments

Comments
 (0)