-
Notifications
You must be signed in to change notification settings - Fork 64
/
module.go
120 lines (105 loc) · 3.48 KB
/
module.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
package cu
// #include <cuda.h>
import "C"
import (
"unsafe"
"github.com/pkg/errors"
)
// Module is a handle to a loaded CUDA module; it wraps the driver's CUmodule.
type Module struct {
	mod C.CUmodule
}

// c exposes the raw C handle backing the module.
func (m Module) c() C.CUmodule {
	return m.mod
}
// Load loads the module stored in the file called name into the current context.
//
// The CUDA driver API does not lazily allocate the resources a module needs:
// if memory for the module's functions and data (constant and global) cannot
// be allocated, Load fails.
//
// The file may be a cubin file as output by nvcc, a PTX file (either output by
// nvcc or handwritten), or a fatbin file as output by nvcc from toolchain 4.0
// or later.
func Load(name string) (m Module, err error) {
	path := C.CString(name)
	// The C copy of the path is only needed for the duration of the driver call.
	defer C.free(unsafe.Pointer(path))

	err = result(C.cuModuleLoad(&m.mod, path))
	return
}
// LoadData loads a module from an in-memory image passed as a string.
// The image is copied into NUL-terminated C memory, which is released once the
// driver call returns.
func LoadData(image string) (m Module, err error) {
	buf := C.CString(image)
	defer C.free(unsafe.Pointer(buf))

	err = result(C.cuModuleLoadData(&m.mod, unsafe.Pointer(buf)))
	return
}
// LoadDataEx loads a module from an in-memory image passed as a string, like
// LoadData, but also forwards the given JIT options to the driver.
// The options are converted by encodeArguments into the count/keys/values
// triple that cuModuleLoadDataEx expects.
func LoadDataEx(image string, options ...JITOption) (m Module, err error) {
	buf := C.CString(image)
	defer C.free(unsafe.Pointer(buf))

	n, keys, vals := encodeArguments(options)
	err = result(C.cuModuleLoadDataEx(&m.mod, unsafe.Pointer(buf), n, keys, vals))
	return
}
// LoadFatBinary loads a module from an in-memory fat binary image passed as a
// string, using cuModuleLoadFatBinary. The image is copied into C memory for
// the duration of the call.
func LoadFatBinary(image string) (m Module, err error) {
	buf := C.CString(image)
	defer C.free(unsafe.Pointer(buf))

	err = result(C.cuModuleLoadFatBinary(&m.mod, unsafe.Pointer(buf)))
	return
}
// Function looks up the function called name in the module and returns a
// handle to it. If it is not found, the error NotFound is returned.
func (m Module) Function(name string) (fn Function, err error) {
	sym := C.CString(name)
	defer C.free(unsafe.Pointer(sym))

	err = result(C.cuModuleGetFunction(&fn.fn, m.mod, sym))
	return
}
// Global looks up the global variable called name in the module.
// It returns a pointer to the variable's memory on the device together with
// the size reported by the driver. On failure both values are zero.
func (m Module) Global(name string) (DevicePtr, int64, error) {
	sym := C.CString(name)
	defer C.free(unsafe.Pointer(sym))

	var (
		ptr C.CUdeviceptr
		sz  C.size_t
	)
	err := result(C.cuModuleGetGlobal(&ptr, &sz, m.mod, sym))
	if err != nil {
		return 0, 0, err
	}
	return DevicePtr(ptr), int64(sz), nil
}
// Load loads the module file called name, issuing the driver call through
// ctx.Do. Any error is wrapped with the message "LoadModule" and a zero
// Module is returned.
func (ctx *Ctx) Load(name string) (m Module, err error) {
	path := C.CString(name)
	defer C.free(unsafe.Pointer(path))

	var handle C.CUmodule
	load := func() error {
		return result(C.cuModuleLoad(&handle, path))
	}
	if err = ctx.Do(load); err != nil {
		return m, errors.Wrap(err, "LoadModule")
	}
	return Module{mod: handle}, nil
}
// ModuleFunction looks up the function called name in module m, issuing the
// driver call through ctx.Do. Any error is wrapped with the message
// "ModuleFunction" and a zero Function is returned.
func (ctx *Ctx) ModuleFunction(m Module, name string) (function Function, err error) {
	sym := C.CString(name)
	defer C.free(unsafe.Pointer(sym))

	var handle C.CUfunction
	lookup := func() error {
		return result(C.cuModuleGetFunction(&handle, m.mod, sym))
	}
	if err = ctx.Do(lookup); err != nil {
		return function, errors.Wrap(err, "ModuleFunction")
	}
	return Function{fn: handle}, nil
}
// ModuleGlobal looks up the global variable called name in module m, issuing
// the driver call through ctx.Do. It returns the device pointer to the global
// and its size as reported by cuModuleGetGlobal. Any error is wrapped with the
// message "ModuleGlobal"; in that case dptr and size are zero.
func (ctx *Ctx) ModuleGlobal(m Module, name string) (dptr DevicePtr, size int64, err error) {
	var d C.CUdeviceptr
	var s C.size_t
	cstr := C.CString(name)
	defer C.free(unsafe.Pointer(cstr))
	f := func() error { return result(C.cuModuleGetGlobal(&d, &s, m.mod, cstr)) }
	if err = ctx.Do(f); err != nil {
		err = errors.Wrap(err, "ModuleGlobal")
		return
	}
	// Both driver outputs must be copied into the named results; the device
	// pointer is the primary result of this lookup.
	dptr = DevicePtr(d)
	size = int64(s)
	return
}