Skip to content

Commit

Permalink
Cuda 12 (#69)
Browse files Browse the repository at this point in the history
* Update jit.go - comment out old JIT Target compute 20, 21

* Update params.go - move to the _func trick

Support the new cuda 12 structure (v2)

* Update params.go - define the function prototype

hides a warning

* Update api.go - hide deprecated api warnings
  • Loading branch information
neurlang authored May 20, 2024
1 parent 60c34ed commit 498bd6b
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 13 deletions.
6 changes: 5 additions & 1 deletion api.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package cu

// #include <cuda.h>
/*
#cgo CFLAGS: -Wno-deprecated-declarations
#include <cuda.h>
*/
import "C"
import "unsafe"

Expand Down
4 changes: 2 additions & 2 deletions jit.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ const (
// JITTarget11 JITTargetOption = C.CU_TARGET_COMPUTE_11
// JITTarget12 JITTargetOption = C.CU_TARGET_COMPUTE_12
// JITTarget13 JITTargetOption = C.CU_TARGET_COMPUTE_13
JITTarget20 JITTargetOption = C.CU_TARGET_COMPUTE_20
JITTarget21 JITTargetOption = C.CU_TARGET_COMPUTE_21
// JITTarget20 JITTargetOption = C.CU_TARGET_COMPUTE_20
// JITTarget21 JITTargetOption = C.CU_TARGET_COMPUTE_21
JITTarget30 JITTargetOption = C.CU_TARGET_COMPUTE_30
JITTarget32 JITTargetOption = C.CU_TARGET_COMPUTE_32
JITTarget35 JITTargetOption = C.CU_TARGET_COMPUTE_35
Expand Down
19 changes: 9 additions & 10 deletions params.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package cu
/*
#include <cuda.h>
void handleCUDACB(void* fn);
void CallHostFunc(void* fn){
handleCUDACB(fn);
};
Expand All @@ -28,16 +29,14 @@ func (p *KernelNodeParams) c() *C.CUDA_KERNEL_NODE_PARAMS {
// here anonymous initialization of struct fields is used because `func` is a keyword.
// see also: https://github.com/golang/go/issues/41968
retVal := &C.CUDA_KERNEL_NODE_PARAMS{
p.Func.fn,
C.uint(p.GridDimX),
C.uint(p.GridDimY),
C.uint(p.GridDimZ),
C.uint(p.BlockDimX),
C.uint(p.BlockDimY),
C.uint(p.BlockDimZ),
C.uint(p.SharedMemBytes),
nil,
nil,
_func: p.Func.fn,
gridDimX: C.uint(p.GridDimX),
gridDimY: C.uint(p.GridDimY),
gridDimZ: C.uint(p.GridDimZ),
blockDimX: C.uint(p.BlockDimX),
blockDimY: C.uint(p.BlockDimY),
blockDimZ: C.uint(p.BlockDimZ),
sharedMemBytes: C.uint(p.SharedMemBytes),
}
return retVal
}
Expand Down

0 comments on commit 498bd6b

Please sign in to comment.