Skip to content

Commit ff685d3

Browse files
Shorten datatype conversions in enums.go with generics. (#321)
1 parent 02ede09 commit ff685d3

File tree

3 files changed

+75
-192
lines changed

3 files changed

+75
-192
lines changed

common.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package tiledb
22

33
import (
4-
"reflect"
54
"unsafe"
65
)
76

@@ -17,6 +16,5 @@ type scalarType interface {
1716

1817
// slicePtr gives you an unsafe pointer to the start of a slice.
1918
func slicePtr[T any](slc []T) unsafe.Pointer {
20-
hdr := (*reflect.SliceHeader)(unsafe.Pointer(&slc))
21-
return unsafe.Pointer(hdr.Data)
19+
return unsafe.Pointer(unsafe.SliceData(slc))
2220
}

enums.go

Lines changed: 65 additions & 189 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"fmt"
1414
"reflect"
1515
"strconv"
16+
"time"
1617
"unsafe"
1718
)
1819

@@ -265,240 +266,115 @@ func (d Datatype) Size() uint64 {
265266
func (d Datatype) MakeSlice(numElements uint64) (interface{}, unsafe.Pointer, error) {
266267
switch d {
267268
case TILEDB_INT8:
268-
slice := make([]int8, numElements)
269-
return slice, unsafe.Pointer(&slice[0]), nil
270-
269+
return makeSlice[int8](numElements)
271270
case TILEDB_INT16:
272-
slice := make([]int16, numElements)
273-
return slice, unsafe.Pointer(&slice[0]), nil
274-
271+
return makeSlice[int16](numElements)
275272
case TILEDB_INT32:
276-
slice := make([]int32, numElements)
277-
return slice, unsafe.Pointer(&slice[0]), nil
278-
273+
return makeSlice[int32](numElements)
279274
case TILEDB_INT64, TILEDB_DATETIME_YEAR, TILEDB_DATETIME_MONTH, TILEDB_DATETIME_WEEK, TILEDB_DATETIME_DAY, TILEDB_DATETIME_HR, TILEDB_DATETIME_MIN, TILEDB_DATETIME_SEC, TILEDB_DATETIME_MS, TILEDB_DATETIME_US, TILEDB_DATETIME_NS, TILEDB_DATETIME_PS, TILEDB_DATETIME_FS, TILEDB_DATETIME_AS, TILEDB_TIME_HR, TILEDB_TIME_MIN, TILEDB_TIME_SEC, TILEDB_TIME_MS, TILEDB_TIME_US, TILEDB_TIME_NS, TILEDB_TIME_PS, TILEDB_TIME_FS, TILEDB_TIME_AS:
280-
slice := make([]int64, numElements)
281-
return slice, unsafe.Pointer(&slice[0]), nil
282-
275+
return makeSlice[int64](numElements)
283276
case TILEDB_UINT8, TILEDB_CHAR, TILEDB_STRING_ASCII, TILEDB_STRING_UTF8, TILEDB_BLOB, TILEDB_GEOM_WKB, TILEDB_GEOM_WKT:
284-
slice := make([]uint8, numElements)
285-
return slice, unsafe.Pointer(&slice[0]), nil
286-
277+
return makeSlice[uint8](numElements)
287278
case TILEDB_UINT16, TILEDB_STRING_UTF16, TILEDB_STRING_UCS2:
288-
slice := make([]uint16, numElements)
289-
return slice, unsafe.Pointer(&slice[0]), nil
290-
279+
return makeSlice[uint16](numElements)
291280
case TILEDB_UINT32, TILEDB_STRING_UTF32, TILEDB_STRING_UCS4:
292-
slice := make([]uint32, numElements)
293-
return slice, unsafe.Pointer(&slice[0]), nil
294-
281+
return makeSlice[uint32](numElements)
295282
case TILEDB_UINT64:
296-
slice := make([]uint64, numElements)
297-
return slice, unsafe.Pointer(&slice[0]), nil
298-
283+
return makeSlice[uint64](numElements)
299284
case TILEDB_FLOAT32:
300-
slice := make([]float32, numElements)
301-
return slice, unsafe.Pointer(&slice[0]), nil
302-
285+
return makeSlice[float32](numElements)
303286
case TILEDB_FLOAT64:
304-
slice := make([]float64, numElements)
305-
return slice, unsafe.Pointer(&slice[0]), nil
306-
287+
return makeSlice[float64](numElements)
307288
case TILEDB_BOOL:
308-
slice := make([]bool, numElements)
309-
return slice, unsafe.Pointer(&slice[0]), nil
310-
289+
return makeSlice[bool](numElements)
311290
default:
312291
return nil, nil, fmt.Errorf("error making datatype slice; unrecognized datatype: %d", d)
313292
}
314293
}
315294

295+
// makeSlice makes a slice and returns it as well as a pointer to its start.
296+
// Its return type matches d.MakeSlice for convenience.
297+
func makeSlice[T any](numElements uint64) (any, unsafe.Pointer, error) {
298+
slice := make([]T, numElements)
299+
return slice, slicePtr(slice), nil
300+
}
301+
316302
// GetValue gets value stored in a void pointer for this data type.
317303
func (d Datatype) GetValue(valueNum uint, cvalue unsafe.Pointer) (interface{}, error) {
318304
switch d {
319305
case TILEDB_INT8:
320-
if cvalue == nil {
321-
return int8(0), nil
322-
}
323-
if valueNum > 1 {
324-
tmpValue := make([]int8, valueNum)
325-
tmpslice := (*[1 << 46]C.int8_t)(cvalue)[:valueNum:valueNum]
326-
for i, s := range tmpslice {
327-
tmpValue[i] = int8(s)
328-
}
329-
return tmpValue, nil
330-
}
331-
return *(*int8)(cvalue), nil
306+
return getValueInternal[int8](valueNum, cvalue)
332307
case TILEDB_INT16:
333-
if cvalue == nil {
334-
return int16(0), nil
335-
}
336-
if valueNum > 1 {
337-
tmpValue := make([]int16, valueNum)
338-
tmpslice := (*[1 << 46]C.int16_t)(cvalue)[:valueNum:valueNum]
339-
for i, s := range tmpslice {
340-
tmpValue[i] = int16(s)
341-
}
342-
return tmpValue, nil
343-
}
344-
return *(*int16)(cvalue), nil
308+
return getValueInternal[int16](valueNum, cvalue)
345309
case TILEDB_INT32:
346-
if cvalue == nil {
347-
return int32(0), nil
348-
}
349-
if valueNum > 1 {
350-
tmpValue := make([]int32, valueNum)
351-
tmpslice := (*[1 << 46]C.int32_t)(cvalue)[:valueNum:valueNum]
352-
for i, s := range tmpslice {
353-
tmpValue[i] = int32(s)
354-
}
355-
return tmpValue, nil
356-
}
357-
return *(*int32)(cvalue), nil
310+
return getValueInternal[int32](valueNum, cvalue)
358311
case TILEDB_INT64:
359-
if cvalue == nil {
360-
return int64(0), nil
361-
}
362-
if valueNum > 1 {
363-
tmpValue := make([]int64, valueNum)
364-
tmpslice := (*[1 << 46]C.int64_t)(cvalue)[:valueNum:valueNum]
365-
for i, s := range tmpslice {
366-
tmpValue[i] = int64(s)
367-
}
368-
return tmpValue, nil
369-
}
370-
return *(*int64)(cvalue), nil
312+
return getValueInternal[int64](valueNum, cvalue)
371313
case TILEDB_UINT8, TILEDB_BLOB, TILEDB_GEOM_WKB, TILEDB_GEOM_WKT:
372-
if cvalue == nil {
373-
return uint8(0), nil
374-
}
375-
if valueNum > 1 {
376-
tmpValue := make([]uint8, valueNum)
377-
tmpslice := (*[1 << 46]C.uint8_t)(cvalue)[:valueNum:valueNum]
378-
for i, s := range tmpslice {
379-
tmpValue[i] = uint8(s)
380-
}
381-
return tmpValue, nil
382-
}
383-
return *(*uint8)(cvalue), nil
314+
return getValueInternal[uint8](valueNum, cvalue)
384315
case TILEDB_UINT16:
385-
if cvalue == nil {
386-
return uint16(0), nil
387-
}
388-
if valueNum > 1 {
389-
tmpValue := make([]uint16, valueNum)
390-
tmpslice := (*[1 << 46]C.uint16_t)(cvalue)[:valueNum:valueNum]
391-
for i, s := range tmpslice {
392-
tmpValue[i] = uint16(s)
393-
}
394-
return tmpValue, nil
395-
}
396-
return *(*uint16)(cvalue), nil
316+
return getValueInternal[uint16](valueNum, cvalue)
397317
case TILEDB_UINT32:
398-
if cvalue == nil {
399-
return uint32(0), nil
400-
}
401-
if valueNum > 1 {
402-
tmpValue := make([]uint32, valueNum)
403-
tmpslice := (*[1 << 46]C.uint32_t)(cvalue)[:valueNum:valueNum]
404-
for i, s := range tmpslice {
405-
tmpValue[i] = uint32(s)
406-
}
407-
return tmpValue, nil
408-
}
409-
return *(*uint32)(cvalue), nil
318+
return getValueInternal[uint32](valueNum, cvalue)
410319
case TILEDB_UINT64:
411-
if cvalue == nil {
412-
return uint64(0), nil
413-
}
414-
if valueNum > 1 {
415-
tmpValue := make([]uint64, valueNum)
416-
tmpslice := (*[1 << 46]C.uint64_t)(cvalue)[:valueNum:valueNum]
417-
for i, s := range tmpslice {
418-
tmpValue[i] = uint64(s)
419-
}
420-
return tmpValue, nil
421-
}
422-
return *(*uint64)(cvalue), nil
320+
return getValueInternal[uint64](valueNum, cvalue)
423321
case TILEDB_FLOAT32:
424-
if cvalue == nil {
425-
return float32(0), nil
426-
}
427-
if valueNum > 1 {
428-
tmpValue := make([]float32, valueNum)
429-
tmpslice := (*[1 << 46]C.float)(cvalue)[:valueNum:valueNum]
430-
for i, s := range tmpslice {
431-
tmpValue[i] = float32(s)
432-
}
433-
return tmpValue, nil
434-
}
435-
return *(*float32)(cvalue), nil
322+
return getValueInternal[float32](valueNum, cvalue)
436323
case TILEDB_FLOAT64:
437-
if cvalue == nil {
438-
return float64(0), nil
439-
}
440-
if valueNum > 1 {
441-
tmpValue := make([]float64, valueNum)
442-
tmpslice := (*[1 << 46]C.double)(cvalue)[:valueNum:valueNum]
443-
for i, s := range tmpslice {
444-
tmpValue[i] = float64(s)
445-
}
446-
return tmpValue, nil
447-
}
448-
return *(*float64)(cvalue), nil
449-
case TILEDB_CHAR:
450-
if cvalue == nil || valueNum == 0 {
451-
return "", nil
452-
}
453-
tmpslice := (*[1 << 46]C.char)(cvalue)[:valueNum:valueNum]
454-
// TODO: Handle overflow from unsigned conversion
455-
return C.GoStringN(&tmpslice[0], C.int(valueNum))[0:valueNum], nil
456-
case TILEDB_STRING_ASCII:
457-
if cvalue == nil || valueNum == 0 {
458-
return "", nil
459-
}
460-
tmpslice := (*[1 << 46]C.char)(cvalue)[:valueNum:valueNum]
461-
// TODO: Handle overflow from unsigned conversion
462-
return C.GoStringN(&tmpslice[0], C.int(valueNum))[0:valueNum], nil
463-
case TILEDB_STRING_UTF8:
464-
if cvalue == nil || valueNum == 0 {
465-
return "", nil
466-
}
467-
tmpslice := (*[1 << 46]C.char)(cvalue)[:valueNum:valueNum]
468-
// TODO: Handle overflow from unsigned conversion
469-
return C.GoStringN(&tmpslice[0], C.int(valueNum))[0:valueNum], nil
324+
return getValueInternal[float64](valueNum, cvalue)
325+
case TILEDB_CHAR, TILEDB_STRING_ASCII, TILEDB_STRING_UTF8:
326+
return C.GoStringN((*C.char)(cvalue), C.int(valueNum)), nil
470327
case TILEDB_DATETIME_YEAR, TILEDB_DATETIME_MONTH, TILEDB_DATETIME_WEEK,
471328
TILEDB_DATETIME_DAY, TILEDB_DATETIME_HR, TILEDB_DATETIME_MIN,
472329
TILEDB_DATETIME_SEC, TILEDB_DATETIME_MS, TILEDB_DATETIME_US,
473330
TILEDB_DATETIME_NS, TILEDB_DATETIME_PS, TILEDB_DATETIME_FS,
474331
TILEDB_DATETIME_AS, TILEDB_TIME_HR, TILEDB_TIME_MIN, TILEDB_TIME_SEC, TILEDB_TIME_MS, TILEDB_TIME_US, TILEDB_TIME_NS, TILEDB_TIME_PS, TILEDB_TIME_FS, TILEDB_TIME_AS:
475332
if valueNum > 1 {
476-
return nil, fmt.Errorf("Unrecognized value type: %d", d)
477-
} else {
478-
if cvalue == nil {
479-
return int64(0), nil
480-
}
481-
var timestamp interface{} = *(*int16)(cvalue)
482-
return GetTimeFromTimestamp(d, timestamp.(int64)), nil
333+
return nil, fmt.Errorf("only 1 timestamp may be returned, not %d", d)
483334
}
335+
if cvalue == nil {
336+
return time.Time{}, nil
337+
}
338+
timestamp := *(*int64)(cvalue)
339+
return GetTimeFromTimestamp(d, timestamp), nil
484340
case TILEDB_BOOL:
341+
// We handle this differently to ensure that our bools are always in the
342+
// canonical form (true/1 or false/0).
485343
if cvalue == nil {
486344
return false, nil
487345
}
488-
if valueNum > 1 {
489-
tmpValue := make([]bool, valueNum)
490-
tmpslice := (*[1 << 46]C.int8_t)(cvalue)[:valueNum:valueNum]
491-
for i, s := range tmpslice {
492-
tmpValue[i] = s != 0
493-
}
494-
return tmpValue, nil
346+
bytes := unsafeSlice[byte](cvalue, valueNum)
347+
if valueNum == 1 {
348+
return bytes[0] != 0, nil
349+
}
350+
bools := make([]bool, valueNum)
351+
for i, b := range bytes {
352+
bools[i] = b != 0
495353
}
496-
return *(*int8)(cvalue), nil
354+
return bools, nil
497355
default:
498356
return nil, fmt.Errorf("Unrecognized value type: %d", d)
499357
}
500358
}
501359

360+
// getValueInternal handles the internals of Datatype.GetValue. It returns
361+
// `valueNum` Ts located at `ptr`. As a special case, if valueNum == 1,
362+
// it returns a T itself rather than a []T.
363+
func getValueInternal[T any](valueNum uint, ptr unsafe.Pointer) (any, error) {
364+
var singleValue T
365+
if ptr == nil {
366+
return singleValue, nil
367+
}
368+
if valueNum == 1 {
369+
singleValue = *(*T)(ptr)
370+
return singleValue, nil
371+
}
372+
out := make([]T, valueNum)
373+
inSlice := unsafeSlice[T](ptr, valueNum)
374+
copy(out, inSlice)
375+
return out, nil
376+
}
377+
502378
var tileDBInt, tileDBUint = intUintTypes() // The Datatypes of Go `int` and `uint`.
503379

504380
func intUintTypes() (Datatype, Datatype) {

memory.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,12 @@ func (bb byteBuffer) subSlice(sliceStart unsafe.Pointer, sliceBytes uintptr) []b
5555
startIdx := uintptr(sliceStart) - uintptr(bb.start())
5656
return bb[startIdx:sliceBytes]
5757
}
58+
59+
// unsafeSlice creates a slice pointing at the given memory.
60+
func unsafeSlice[T any](ptr unsafe.Pointer, length uint) []T {
61+
if ptr == nil {
62+
return nil
63+
}
64+
typedPtr := (*T)(ptr)
65+
return unsafe.Slice(typedPtr, length)
66+
}

0 commit comments

Comments
 (0)