From f17963ea611201e5aa9af755f99a3c1c9aeaaead Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Fri, 13 Sep 2024 10:36:10 -0400 Subject: [PATCH] GH-120: Add initial Decimal32/Decimal64 implementation (#121) Fix GH-120 ### Rationale for this change Widening the Decimal128/256 type to allow for bitwidths of 32 and 64 allows for more interoperability with other libraries and utilities which already support these types. This provides even more opportunities for zero-copy interactions between things such as libcudf and various databases. ### What changes are included in this PR? This PR contains the basic Go implementations for Decimal32/Decimal64 types, arrays, builders and scalars. It also includes the minimum necessary to get everything compiling and tests passing without also extending the acero kernels and parquet handling (both of which will be handled in follow-up PRs). ### Are these changes tested? Yes, tests were extended where applicable to add decimal32/decimal64 cases. --- .gitignore | 1 + arrow/array/array.go | 2 + arrow/array/array_test.go | 2 + arrow/array/builder.go | 8 + arrow/array/compare.go | 20 +- arrow/array/decimal.go | 432 +++++++++++++++++++++++++++ arrow/array/decimal128.go | 368 ----------------------- arrow/array/decimal256.go | 368 ----------------------- arrow/array/dictionary.go | 66 ++++- arrow/array/numeric.gen.go | 17 ++ arrow/cdata/cdata.go | 13 +- arrow/cdata/cdata_exports.go | 4 + arrow/cdata/cdata_test.go | 2 +- arrow/datatype.go | 24 +- arrow/datatype_fixedwidth.go | 89 +++++- arrow/datatype_fixedwidth_test.go | 88 ++++++ arrow/decimal/decimal.go | 473 ++++++++++++++++++++++++++++++ arrow/decimal/decimal_test.go | 470 +++++++++++++++++++++++++++++ arrow/decimal/traits.go | 78 +++++ arrow/decimal128/decimal128.go | 10 + arrow/decimal256/decimal256.go | 10 + arrow/internal/arrjson/arrjson.go | 84 +++++- arrow/internal/flatbuf/Decimal.go | 24 +- arrow/ipc/file_reader.go | 2 +- arrow/ipc/metadata.go | 20 +- arrow/type_string.go | 6 +- arrow/type_traits.go | 5 +- arrow/type_traits_decimal128.go | 14 +- arrow/type_traits_decimal256.go | 14 +- arrow/type_traits_decimal32.go | 57 ++++ arrow/type_traits_decimal64.go | 57 ++++ arrow/type_traits_test.go | 90 ++++++ 32 files changed, 2116 insertions(+), 802 deletions(-) create mode 100644 arrow/array/decimal.go delete mode 100644 arrow/array/decimal128.go delete mode 100644 arrow/array/decimal256.go create mode 100644 arrow/decimal/decimal.go create mode 100644 arrow/decimal/decimal_test.go create mode 100644 arrow/decimal/traits.go create mode 100644 arrow/type_traits_decimal32.go create mode 100644 arrow/type_traits_decimal64.go diff --git a/.gitignore b/.gitignore index 06f40703..55232367 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. +.vscode /apache-arrow-go.tar.gz /dev/release/apache-rat-*.jar /dev/release/filtered_rat.txt diff --git a/arrow/array/array.go b/arrow/array/array.go index 586d8765..6e281a43 100644 --- a/arrow/array/array.go +++ b/arrow/array/array.go @@ -160,6 +160,8 @@ func init() { arrow.TIME64: func(data arrow.ArrayData) arrow.Array { return NewTime64Data(data) }, arrow.INTERVAL_MONTHS: func(data arrow.ArrayData) arrow.Array { return NewMonthIntervalData(data) }, arrow.INTERVAL_DAY_TIME: func(data arrow.ArrayData) arrow.Array { return NewDayTimeIntervalData(data) }, + arrow.DECIMAL32: func(data arrow.ArrayData) arrow.Array { return NewDecimal32Data(data) }, + arrow.DECIMAL64: func(data arrow.ArrayData) arrow.Array { return NewDecimal64Data(data) }, arrow.DECIMAL128: func(data arrow.ArrayData) arrow.Array { return NewDecimal128Data(data) }, arrow.DECIMAL256: func(data arrow.ArrayData) arrow.Array { return NewDecimal256Data(data) }, arrow.LIST: func(data arrow.ArrayData) arrow.Array { return NewListData(data) }, diff --git a/arrow/array/array_test.go b/arrow/array/array_test.go index 203c62ea..9509e314 100644 --- a/arrow/array/array_test.go +++ b/arrow/array/array_test.go @@ -75,6 +75,8 @@ func TestMakeFromData(t *testing.T) { {name: "time64", d: &testDataType{arrow.TIME64}}, {name: "month_interval", d: arrow.FixedWidthTypes.MonthInterval}, {name: "day_time_interval", d: arrow.FixedWidthTypes.DayTimeInterval}, + {name: "decimal32", d: &testDataType{arrow.DECIMAL32}}, + {name: "decimal64", d: &testDataType{arrow.DECIMAL64}}, {name: "decimal128", d: &testDataType{arrow.DECIMAL128}}, {name: "decimal256", d: &testDataType{arrow.DECIMAL256}}, {name: "month_day_nano_interval", d: arrow.FixedWidthTypes.MonthDayNanoInterval}, diff --git a/arrow/array/builder.go b/arrow/array/builder.go index 108b6152..a2a40d48 100644 --- a/arrow/array/builder.go +++ b/arrow/array/builder.go @@ -313,6 +313,14 @@ func NewBuilder(mem memory.Allocator, dtype arrow.DataType) Builder { return NewDayTimeIntervalBuilder(mem) case arrow.INTERVAL_MONTH_DAY_NANO: return NewMonthDayNanoIntervalBuilder(mem) + case arrow.DECIMAL32: + if typ, ok := dtype.(*arrow.Decimal32Type); ok { + return NewDecimal32Builder(mem, typ) + } + case arrow.DECIMAL64: + if typ, ok := dtype.(*arrow.Decimal64Type); ok { + return NewDecimal64Builder(mem, typ) + } case arrow.DECIMAL128: if typ, ok := dtype.(*arrow.Decimal128Type); ok { return NewDecimal128Builder(mem, typ) diff --git a/arrow/array/compare.go b/arrow/array/compare.go index 4117880f..ad3a50b8 100644 --- a/arrow/array/compare.go +++ b/arrow/array/compare.go @@ -271,12 +271,18 @@ func Equal(left, right arrow.Array) bool { case *Float64: r := right.(*Float64) return arrayEqualFloat64(l, r) + case *Decimal32: + r := right.(*Decimal32) + return arrayEqualDecimal(l, r) + case *Decimal64: + r := right.(*Decimal64) + return arrayEqualDecimal(l, r) case *Decimal128: r := right.(*Decimal128) - return arrayEqualDecimal128(l, r) + return arrayEqualDecimal(l, r) case *Decimal256: r := right.(*Decimal256) - return arrayEqualDecimal256(l, r) + return arrayEqualDecimal(l, r) case *Date32: r := right.(*Date32) return arrayEqualDate32(l, r) @@ -527,12 +533,18 @@ func arrayApproxEqual(left, right arrow.Array, opt equalOption) bool { case *Float64: r := right.(*Float64) return arrayApproxEqualFloat64(l, r, opt) + case *Decimal32: + r := right.(*Decimal32) + return arrayEqualDecimal(l, r) + case *Decimal64: + r := right.(*Decimal64) + return arrayEqualDecimal(l, r) case *Decimal128: r := right.(*Decimal128) - return arrayEqualDecimal128(l, r) + return arrayEqualDecimal(l, r) case *Decimal256: r := right.(*Decimal256) - return arrayEqualDecimal256(l, r) + return arrayEqualDecimal(l, r) case *Date32: r := right.(*Date32) return arrayEqualDate32(l, r) diff --git a/arrow/array/decimal.go b/arrow/array/decimal.go new file mode 100644 index 00000000..1a9d61c1 --- /dev/null +++ b/arrow/array/decimal.go @@ -0,0 +1,432 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package array + +import ( + "bytes" + "fmt" + "reflect" + "strings" + "sync/atomic" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/bitutil" + "github.com/apache/arrow-go/v18/arrow/decimal" + "github.com/apache/arrow-go/v18/arrow/internal/debug" + "github.com/apache/arrow-go/v18/arrow/memory" + "github.com/apache/arrow-go/v18/internal/json" +) + +type baseDecimal[T interface { + decimal.DecimalTypes + decimal.Num[T] +}] struct { + array + + values []T +} + +func newDecimalData[T interface { + decimal.DecimalTypes + decimal.Num[T] +}](data arrow.ArrayData) *baseDecimal[T] { + a := &baseDecimal[T]{} + a.refCount = 1 + a.setData(data.(*Data)) + return a +} + +func (a *baseDecimal[T]) Value(i int) T { return a.values[i] } + +func (a *baseDecimal[T]) ValueStr(i int) string { + if a.IsNull(i) { + return NullValueStr + } + return a.GetOneForMarshal(i).(string) +} + +func (a *baseDecimal[T]) Values() []T { return a.values } + +func (a *baseDecimal[T]) String() string { + o := new(strings.Builder) + o.WriteString("[") + for i := 0; i < a.Len(); i++ { + if i > 0 { + fmt.Fprintf(o, " ") + } + switch { + case a.IsNull(i): + o.WriteString(NullValueStr) + default: + fmt.Fprintf(o, "%v", a.Value(i)) + } + } + o.WriteString("]") + return o.String() +} + +func (a *baseDecimal[T]) setData(data *Data) { + a.array.setData(data) + vals := data.buffers[1] + if vals != nil { + a.values = arrow.GetData[T](vals.Bytes()) + beg := a.array.data.offset + end := beg + a.array.data.length + a.values = a.values[beg:end] + } +} + +func (a *baseDecimal[T]) GetOneForMarshal(i int) any { + if a.IsNull(i) { + return nil + } + + typ := a.DataType().(arrow.DecimalType) + n, scale := a.Value(i), typ.GetScale() + return n.ToBigFloat(scale).Text('g', int(typ.GetPrecision())) +} + +func (a *baseDecimal[T]) MarshalJSON() ([]byte, error) { + vals := make([]any, a.Len()) + for i := 0; i < a.Len(); i++ { + vals[i] = a.GetOneForMarshal(i) + } + return json.Marshal(vals) +} + +func arrayEqualDecimal[T interface { + decimal.DecimalTypes + decimal.Num[T] +}](left, right *baseDecimal[T]) bool { + for i := 0; i < left.Len(); i++ { + if left.IsNull(i) { + continue + } + + if left.Value(i) != right.Value(i) { + return false + } + } + return true +} + +type Decimal32 = baseDecimal[decimal.Decimal32] + +func NewDecimal32Data(data arrow.ArrayData) *Decimal32 { + return newDecimalData[decimal.Decimal32](data) +} + +type Decimal64 = baseDecimal[decimal.Decimal64] + +func NewDecimal64Data(data arrow.ArrayData) *Decimal64 { + return newDecimalData[decimal.Decimal64](data) +} + +type Decimal128 = baseDecimal[decimal.Decimal128] + +func NewDecimal128Data(data arrow.ArrayData) *Decimal128 { + return newDecimalData[decimal.Decimal128](data) +} + +type Decimal256 = baseDecimal[decimal.Decimal256] + +func NewDecimal256Data(data arrow.ArrayData) *Decimal256 { + return newDecimalData[decimal.Decimal256](data) +} + +type Decimal32Builder = baseDecimalBuilder[decimal.Decimal32] +type Decimal64Builder = baseDecimalBuilder[decimal.Decimal64] +type Decimal128Builder struct { + *baseDecimalBuilder[decimal.Decimal128] +} + +func (b *Decimal128Builder) NewDecimal128Array() *Decimal128 { + return b.NewDecimalArray() +} + +type Decimal256Builder struct { + *baseDecimalBuilder[decimal.Decimal256] +} + +func (b *Decimal256Builder) NewDecimal256Array() *Decimal256 { + return b.NewDecimalArray() +} + +type baseDecimalBuilder[T interface { + decimal.DecimalTypes + decimal.Num[T] +}] struct { + builder + traits decimal.Traits[T] + + dtype arrow.DecimalType + data *memory.Buffer + rawData []T +} + +func newDecimalBuilder[T interface { + decimal.DecimalTypes + decimal.Num[T] +}, DT arrow.DecimalType](mem memory.Allocator, dtype DT) *baseDecimalBuilder[T] { + return &baseDecimalBuilder[T]{ + builder: builder{refCount: 1, mem: mem}, + dtype: dtype, + } +} + +func (b *baseDecimalBuilder[T]) Type() arrow.DataType { return b.dtype } + +func (b *baseDecimalBuilder[T]) Release() { + debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + + if atomic.AddInt64(&b.refCount, -1) == 0 { + if b.nullBitmap != nil { + b.nullBitmap.Release() + b.nullBitmap = nil + } + if b.data != nil { + b.data.Release() + b.data, b.rawData = nil, nil + } + } +} + +func (b *baseDecimalBuilder[T]) Append(v T) { + b.Reserve(1) + b.UnsafeAppend(v) +} + +func (b *baseDecimalBuilder[T]) UnsafeAppend(v T) { + bitutil.SetBit(b.nullBitmap.Bytes(), b.length) + b.rawData[b.length] = v + b.length++ +} + +func (b *baseDecimalBuilder[T]) AppendNull() { + b.Reserve(1) + b.UnsafeAppendBoolToBitmap(false) +} + +func (b *baseDecimalBuilder[T]) AppendNulls(n int) { + for i := 0; i < n; i++ { + b.AppendNull() + } +} + +func (b *baseDecimalBuilder[T]) AppendEmptyValue() { + var empty T + b.Append(empty) +} + +func (b *baseDecimalBuilder[T]) AppendEmptyValues(n int) { + for i := 0; i < n; i++ { + b.AppendEmptyValue() + } +} + +func (b *baseDecimalBuilder[T]) UnsafeAppendBoolToBitmap(isValid bool) { + if isValid { + bitutil.SetBit(b.nullBitmap.Bytes(), b.length) + } else { + b.nulls++ + } + b.length++ +} + +func (b *baseDecimalBuilder[T]) AppendValues(v []T, valid []bool) { + if len(v) != len(valid) && len(valid) != 0 { + panic("len(v) != len(valid) && len(valid) != 0") + } + + if len(v) == 0 { + return + } + + b.Reserve(len(v)) + if len(v) > 0 { + copy(b.rawData[b.length:], v) + } + b.builder.unsafeAppendBoolsToBitmap(valid, len(v)) +} + +func (b *baseDecimalBuilder[T]) init(capacity int) { + b.builder.init(capacity) + + b.data = memory.NewResizableBuffer(b.mem) + bytesN := int(reflect.TypeFor[T]().Size()) * capacity + b.data.Resize(bytesN) + b.rawData = arrow.GetData[T](b.data.Bytes()) +} + +func (b *baseDecimalBuilder[T]) Reserve(n int) { + b.builder.reserve(n, b.Resize) +} + +func (b *baseDecimalBuilder[T]) Resize(n int) { + nBuilder := n + if n < minBuilderCapacity { + n = minBuilderCapacity + } + + if b.capacity == 0 { + b.init(n) + } else { + b.builder.resize(nBuilder, b.init) + b.data.Resize(b.traits.BytesRequired(n)) + b.rawData = arrow.GetData[T](b.data.Bytes()) + } +} + +func (b *baseDecimalBuilder[T]) NewDecimalArray() (a *baseDecimal[T]) { + data := b.newData() + a = newDecimalData[T](data) + data.Release() + return +} + +func (b *baseDecimalBuilder[T]) NewArray() arrow.Array { + return b.NewDecimalArray() +} + +func (b *baseDecimalBuilder[T]) newData() (data *Data) { + bytesRequired := b.traits.BytesRequired(b.length) + if bytesRequired > 0 && bytesRequired < b.data.Len() { + // trim buffers + b.data.Resize(bytesRequired) + } + data = NewData(b.dtype, b.length, []*memory.Buffer{b.nullBitmap, b.data}, nil, b.nulls, 0) + b.reset() + + if b.data != nil { + b.data.Release() + b.data, b.rawData = nil, nil + } + + return +} + +func (b *baseDecimalBuilder[T]) AppendValueFromString(s string) error { + if s == NullValueStr { + b.AppendNull() + return nil + } + + val, err := b.traits.FromString(s, b.dtype.GetPrecision(), b.dtype.GetScale()) + if err != nil { + b.AppendNull() + return err + } + b.Append(val) + return nil +} + +func (b *baseDecimalBuilder[T]) UnmarshalOne(dec *json.Decoder) error { + t, err := dec.Token() + if err != nil { + return err + } + + var token T + switch v := t.(type) { + case float64: + token, err = b.traits.FromFloat64(v, b.dtype.GetPrecision(), b.dtype.GetScale()) + if err != nil { + return err + } + b.Append(token) + case string: + token, err = b.traits.FromString(v, b.dtype.GetPrecision(), b.dtype.GetScale()) + if err != nil { + return err + } + b.Append(token) + case json.Number: + token, err = b.traits.FromString(v.String(), b.dtype.GetPrecision(), b.dtype.GetScale()) + if err != nil { + return err + } + b.Append(token) + case nil: + b.AppendNull() + default: + return &json.UnmarshalTypeError{ + Value: fmt.Sprint(t), + Type: reflect.TypeFor[T](), + Offset: dec.InputOffset(), + } + } + + return nil +} + +func (b *baseDecimalBuilder[T]) Unmarshal(dec *json.Decoder) error { + for dec.More() { + if err := b.UnmarshalOne(dec); err != nil { + return err + } + } + return nil +} + +func (b *baseDecimalBuilder[T]) UnmarshalJSON(data []byte) error { + dec := json.NewDecoder(bytes.NewReader(data)) + t, err := dec.Token() + if err != nil { + return err + } + + if delim, ok := t.(json.Delim); !ok || delim != '[' { + return fmt.Errorf("decimal builder must unpack from json array, found %s", delim) + } + + return b.Unmarshal(dec) +} + +func NewDecimal32Builder(mem memory.Allocator, dtype *arrow.Decimal32Type) *Decimal32Builder { + b := newDecimalBuilder[decimal.Decimal32](mem, dtype) + b.traits = decimal.Dec32Traits + return b +} + +func NewDecimal64Builder(mem memory.Allocator, dtype *arrow.Decimal64Type) *Decimal64Builder { + b := newDecimalBuilder[decimal.Decimal64](mem, dtype) + b.traits = decimal.Dec64Traits + return b +} + +func NewDecimal128Builder(mem memory.Allocator, dtype *arrow.Decimal128Type) *Decimal128Builder { + b := newDecimalBuilder[decimal.Decimal128](mem, dtype) + b.traits = decimal.Dec128Traits + return &Decimal128Builder{b} +} + +func NewDecimal256Builder(mem memory.Allocator, dtype *arrow.Decimal256Type) *Decimal256Builder { + b := newDecimalBuilder[decimal.Decimal256](mem, dtype) + b.traits = decimal.Dec256Traits + return &Decimal256Builder{b} +} + +var ( + _ arrow.Array = (*Decimal32)(nil) + _ arrow.Array = (*Decimal64)(nil) + _ arrow.Array = (*Decimal128)(nil) + _ arrow.Array = (*Decimal256)(nil) + _ Builder = (*Decimal32Builder)(nil) + _ Builder = (*Decimal64Builder)(nil) + _ Builder = (*Decimal128Builder)(nil) + _ Builder = (*Decimal256Builder)(nil) +) diff --git a/arrow/array/decimal128.go b/arrow/array/decimal128.go deleted file mode 100644 index c5861dce..00000000 --- a/arrow/array/decimal128.go +++ /dev/null @@ -1,368 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package array - -import ( - "bytes" - "fmt" - "math/big" - "reflect" - "strings" - "sync/atomic" - - "github.com/apache/arrow-go/v18/arrow" - "github.com/apache/arrow-go/v18/arrow/bitutil" - "github.com/apache/arrow-go/v18/arrow/decimal128" - "github.com/apache/arrow-go/v18/arrow/internal/debug" - "github.com/apache/arrow-go/v18/arrow/memory" - "github.com/apache/arrow-go/v18/internal/json" -) - -// A type which represents an immutable sequence of 128-bit decimal values. -type Decimal128 struct { - array - - values []decimal128.Num -} - -func NewDecimal128Data(data arrow.ArrayData) *Decimal128 { - a := &Decimal128{} - a.refCount = 1 - a.setData(data.(*Data)) - return a -} - -func (a *Decimal128) Value(i int) decimal128.Num { return a.values[i] } - -func (a *Decimal128) ValueStr(i int) string { - if a.IsNull(i) { - return NullValueStr - } - return a.GetOneForMarshal(i).(string) -} - -func (a *Decimal128) Values() []decimal128.Num { return a.values } - -func (a *Decimal128) String() string { - o := new(strings.Builder) - o.WriteString("[") - for i := 0; i < a.Len(); i++ { - if i > 0 { - fmt.Fprintf(o, " ") - } - switch { - case a.IsNull(i): - o.WriteString(NullValueStr) - default: - fmt.Fprintf(o, "%v", a.Value(i)) - } - } - o.WriteString("]") - return o.String() -} - -func (a *Decimal128) setData(data *Data) { - a.array.setData(data) - vals := data.buffers[1] - if vals != nil { - a.values = arrow.Decimal128Traits.CastFromBytes(vals.Bytes()) - beg := a.array.data.offset - end := beg + a.array.data.length - a.values = a.values[beg:end] - } -} -func (a *Decimal128) GetOneForMarshal(i int) interface{} { - if a.IsNull(i) { - return nil - } - typ := a.DataType().(*arrow.Decimal128Type) - n := a.Value(i) - scale := typ.Scale - f := (&big.Float{}).SetInt(n.BigInt()) - if scale < 0 { - f.SetPrec(128).Mul(f, (&big.Float{}).SetInt(decimal128.GetScaleMultiplier(int(-scale)).BigInt())) - } else { - f.SetPrec(128).Quo(f, (&big.Float{}).SetInt(decimal128.GetScaleMultiplier(int(scale)).BigInt())) - } - return f.Text('g', int(typ.Precision)) -} - -// ["1.23", ] -func (a *Decimal128) MarshalJSON() ([]byte, error) { - vals := make([]interface{}, a.Len()) - for i := 0; i < a.Len(); i++ { - vals[i] = a.GetOneForMarshal(i) - } - return json.Marshal(vals) -} - -func arrayEqualDecimal128(left, right *Decimal128) bool { - for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { - continue - } - if left.Value(i) != right.Value(i) { - return false - } - } - return true -} - -type Decimal128Builder struct { - builder - - dtype *arrow.Decimal128Type - data *memory.Buffer - rawData []decimal128.Num -} - -func NewDecimal128Builder(mem memory.Allocator, dtype *arrow.Decimal128Type) *Decimal128Builder { - return &Decimal128Builder{ - builder: builder{refCount: 1, mem: mem}, - dtype: dtype, - } -} - -func (b *Decimal128Builder) Type() arrow.DataType { return b.dtype } - -// Release decreases the reference count by 1. -// When the reference count goes to zero, the memory is freed. -func (b *Decimal128Builder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") - - if atomic.AddInt64(&b.refCount, -1) == 0 { - if b.nullBitmap != nil { - b.nullBitmap.Release() - b.nullBitmap = nil - } - if b.data != nil { - b.data.Release() - b.data = nil - b.rawData = nil - } - } -} - -func (b *Decimal128Builder) Append(v decimal128.Num) { - b.Reserve(1) - b.UnsafeAppend(v) -} - -func (b *Decimal128Builder) UnsafeAppend(v decimal128.Num) { - bitutil.SetBit(b.nullBitmap.Bytes(), b.length) - b.rawData[b.length] = v - b.length++ -} - -func (b *Decimal128Builder) AppendNull() { - b.Reserve(1) - b.UnsafeAppendBoolToBitmap(false) -} - -func (b *Decimal128Builder) AppendNulls(n int) { - for i := 0; i < n; i++ { - b.AppendNull() - } -} - -func (b *Decimal128Builder) AppendEmptyValue() { - b.Append(decimal128.Num{}) -} - -func (b *Decimal128Builder) AppendEmptyValues(n int) { - for i := 0; i < n; i++ { - b.AppendEmptyValue() - } -} - -func (b *Decimal128Builder) UnsafeAppendBoolToBitmap(isValid bool) { - if isValid { - bitutil.SetBit(b.nullBitmap.Bytes(), b.length) - } else { - b.nulls++ - } - b.length++ -} - -// AppendValues will append the values in the v slice. The valid slice determines which values -// in v are valid (not null). The valid slice must either be empty or be equal in length to v. If empty, -// all values in v are appended and considered valid. -func (b *Decimal128Builder) AppendValues(v []decimal128.Num, valid []bool) { - if len(v) != len(valid) && len(valid) != 0 { - panic("len(v) != len(valid) && len(valid) != 0") - } - - if len(v) == 0 { - return - } - - b.Reserve(len(v)) - if len(v) > 0 { - arrow.Decimal128Traits.Copy(b.rawData[b.length:], v) - } - b.builder.unsafeAppendBoolsToBitmap(valid, len(v)) -} - -func (b *Decimal128Builder) init(capacity int) { - b.builder.init(capacity) - - b.data = memory.NewResizableBuffer(b.mem) - bytesN := arrow.Decimal128Traits.BytesRequired(capacity) - b.data.Resize(bytesN) - b.rawData = arrow.Decimal128Traits.CastFromBytes(b.data.Bytes()) -} - -// Reserve ensures there is enough space for appending n elements -// by checking the capacity and calling Resize if necessary. -func (b *Decimal128Builder) Reserve(n int) { - b.builder.reserve(n, b.Resize) -} - -// Resize adjusts the space allocated by b to n elements. If n is greater than b.Cap(), -// additional memory will be allocated. If n is smaller, the allocated memory may reduced. -func (b *Decimal128Builder) Resize(n int) { - nBuilder := n - if n < minBuilderCapacity { - n = minBuilderCapacity - } - - if b.capacity == 0 { - b.init(n) - } else { - b.builder.resize(nBuilder, b.init) - b.data.Resize(arrow.Decimal128Traits.BytesRequired(n)) - b.rawData = arrow.Decimal128Traits.CastFromBytes(b.data.Bytes()) - } -} - -// NewArray creates a Decimal128 array from the memory buffers used by the builder and resets the Decimal128Builder -// so it can be used to build a new array. -func (b *Decimal128Builder) NewArray() arrow.Array { - return b.NewDecimal128Array() -} - -// NewDecimal128Array creates a Decimal128 array from the memory buffers used by the builder and resets the Decimal128Builder -// so it can be used to build a new array. -func (b *Decimal128Builder) NewDecimal128Array() (a *Decimal128) { - data := b.newData() - a = NewDecimal128Data(data) - data.Release() - return -} - -func (b *Decimal128Builder) newData() (data *Data) { - bytesRequired := arrow.Decimal128Traits.BytesRequired(b.length) - if bytesRequired > 0 && bytesRequired < b.data.Len() { - // trim buffers - b.data.Resize(bytesRequired) - } - data = NewData(b.dtype, b.length, []*memory.Buffer{b.nullBitmap, b.data}, nil, b.nulls, 0) - b.reset() - - if b.data != nil { - b.data.Release() - b.data = nil - b.rawData = nil - } - - return -} - -func (b *Decimal128Builder) AppendValueFromString(s string) error { - if s == NullValueStr { - b.AppendNull() - return nil - } - val, err := decimal128.FromString(s, b.dtype.Precision, b.dtype.Scale) - if err != nil { - b.AppendNull() - return err - } - b.Append(val) - return nil -} - -func (b *Decimal128Builder) UnmarshalOne(dec *json.Decoder) error { - t, err := dec.Token() - if err != nil { - return err - } - - switch v := t.(type) { - case float64: - val, err := decimal128.FromFloat64(v, b.dtype.Precision, b.dtype.Scale) - if err != nil { - return err - } - b.Append(val) - case string: - val, err := decimal128.FromString(v, b.dtype.Precision, b.dtype.Scale) - if err != nil { - return err - } - b.Append(val) - case json.Number: - val, err := decimal128.FromString(v.String(), b.dtype.Precision, b.dtype.Scale) - if err != nil { - return err - } - b.Append(val) - case nil: - b.AppendNull() - return nil - default: - return &json.UnmarshalTypeError{ - Value: fmt.Sprint(t), - Type: reflect.TypeOf(decimal128.Num{}), - Offset: dec.InputOffset(), - } - } - - return nil -} - -func (b *Decimal128Builder) Unmarshal(dec *json.Decoder) error { - for dec.More() { - if err := b.UnmarshalOne(dec); err != nil { - return err - } - } - return nil -} - -// UnmarshalJSON will add the unmarshalled values to this builder. -// -// If the values are strings, they will get parsed with big.ParseFloat using -// a rounding mode of big.ToNearestAway currently. -func (b *Decimal128Builder) UnmarshalJSON(data []byte) error { - dec := json.NewDecoder(bytes.NewReader(data)) - t, err := dec.Token() - if err != nil { - return err - } - - if delim, ok := t.(json.Delim); !ok || delim != '[' { - return fmt.Errorf("decimal128 builder must unpack from json array, found %s", delim) - } - - return b.Unmarshal(dec) -} - -var ( - _ arrow.Array = (*Decimal128)(nil) - _ Builder = (*Decimal128Builder)(nil) -) diff --git a/arrow/array/decimal256.go b/arrow/array/decimal256.go deleted file mode 100644 index 7f30f897..00000000 --- a/arrow/array/decimal256.go +++ /dev/null @@ -1,368 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package array - -import ( - "bytes" - "fmt" - "math/big" - "reflect" - "strings" - "sync/atomic" - - "github.com/apache/arrow-go/v18/arrow" - "github.com/apache/arrow-go/v18/arrow/bitutil" - "github.com/apache/arrow-go/v18/arrow/decimal256" - "github.com/apache/arrow-go/v18/arrow/internal/debug" - "github.com/apache/arrow-go/v18/arrow/memory" - "github.com/apache/arrow-go/v18/internal/json" -) - -// Decimal256 is a type that represents an immutable sequence of 256-bit decimal values. -type Decimal256 struct { - array - - values []decimal256.Num -} - -func NewDecimal256Data(data arrow.ArrayData) *Decimal256 { - a := &Decimal256{} - a.refCount = 1 - a.setData(data.(*Data)) - return a -} - -func (a *Decimal256) Value(i int) decimal256.Num { return a.values[i] } - -func (a *Decimal256) ValueStr(i int) string { - if a.IsNull(i) { - return NullValueStr - } - return a.GetOneForMarshal(i).(string) -} - -func (a *Decimal256) Values() []decimal256.Num { return a.values } - -func (a *Decimal256) String() string { - o := new(strings.Builder) - o.WriteString("[") - for i := 0; i < a.Len(); i++ { - if i > 0 { - fmt.Fprintf(o, " ") - } - switch { - case a.IsNull(i): - o.WriteString(NullValueStr) - default: - fmt.Fprintf(o, "%v", a.Value(i)) - } - } - o.WriteString("]") - return o.String() -} - -func (a *Decimal256) setData(data *Data) { - a.array.setData(data) - vals := data.buffers[1] - if vals != nil { - a.values = arrow.Decimal256Traits.CastFromBytes(vals.Bytes()) - beg := a.array.data.offset - end := beg + a.array.data.length - a.values = a.values[beg:end] - } -} - -func (a *Decimal256) GetOneForMarshal(i int) interface{} { - if a.IsNull(i) { - return nil - } - typ := a.DataType().(*arrow.Decimal256Type) - n := a.Value(i) - scale := typ.Scale - f := (&big.Float{}).SetInt(n.BigInt()) - if scale < 0 { - f.SetPrec(256).Mul(f, (&big.Float{}).SetInt(decimal256.GetScaleMultiplier(int(-scale)).BigInt())) - } else { - f.SetPrec(256).Quo(f, (&big.Float{}).SetInt(decimal256.GetScaleMultiplier(int(scale)).BigInt())) - } - return f.Text('g', int(typ.Precision)) -} - -func (a *Decimal256) MarshalJSON() ([]byte, error) { - vals := make([]interface{}, a.Len()) - for i := 0; i < a.Len(); i++ { - vals[i] = a.GetOneForMarshal(i) - } - return json.Marshal(vals) -} - -func arrayEqualDecimal256(left, right *Decimal256) bool { - for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { - continue - } - if left.Value(i) != right.Value(i) { - return false - } - } - return true -} - -type Decimal256Builder struct { - builder - - dtype *arrow.Decimal256Type - data *memory.Buffer - rawData []decimal256.Num -} - -func NewDecimal256Builder(mem memory.Allocator, dtype *arrow.Decimal256Type) *Decimal256Builder { - return &Decimal256Builder{ - builder: builder{refCount: 1, mem: mem}, - dtype: dtype, - } -} - -// Release decreases the reference count by 1. -// When the reference count goes to zero, the memory is freed. -func (b *Decimal256Builder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") - - if atomic.AddInt64(&b.refCount, -1) == 0 { - if b.nullBitmap != nil { - b.nullBitmap.Release() - b.nullBitmap = nil - } - if b.data != nil { - b.data.Release() - b.data = nil - b.rawData = nil - } - } -} - -func (b *Decimal256Builder) Append(v decimal256.Num) { - b.Reserve(1) - b.UnsafeAppend(v) -} - -func (b *Decimal256Builder) UnsafeAppend(v decimal256.Num) { - bitutil.SetBit(b.nullBitmap.Bytes(), b.length) - b.rawData[b.length] = v - b.length++ -} - -func (b *Decimal256Builder) AppendNull() { - b.Reserve(1) - b.UnsafeAppendBoolToBitmap(false) -} - -func (b *Decimal256Builder) AppendNulls(n int) { - for i := 0; i < n; i++ { - b.AppendNull() - } -} - -func (b *Decimal256Builder) AppendEmptyValue() { - b.Append(decimal256.Num{}) -} - -func (b *Decimal256Builder) AppendEmptyValues(n int) { - for i := 0; i < n; i++ { - b.AppendEmptyValue() - } -} - -func (b *Decimal256Builder) Type() arrow.DataType { return b.dtype } - -func (b *Decimal256Builder) UnsafeAppendBoolToBitmap(isValid bool) { - if isValid { - bitutil.SetBit(b.nullBitmap.Bytes(), b.length) - } else { - b.nulls++ - } - b.length++ -} - -// AppendValues will append the values in the v slice. The valid slice determines which values -// in v are valid (not null). The valid slice must either be empty or be equal in length to v. If empty, -// all values in v are appended and considered valid. -func (b *Decimal256Builder) AppendValues(v []decimal256.Num, valid []bool) { - if len(v) != len(valid) && len(valid) != 0 { - panic("arrow/array: len(v) != len(valid) && len(valid) != 0") - } - - if len(v) == 0 { - return - } - - b.Reserve(len(v)) - if len(v) > 0 { - arrow.Decimal256Traits.Copy(b.rawData[b.length:], v) - } - b.builder.unsafeAppendBoolsToBitmap(valid, len(v)) -} - -func (b *Decimal256Builder) init(capacity int) { - b.builder.init(capacity) - - b.data = memory.NewResizableBuffer(b.mem) - bytesN := arrow.Decimal256Traits.BytesRequired(capacity) - b.data.Resize(bytesN) - b.rawData = arrow.Decimal256Traits.CastFromBytes(b.data.Bytes()) -} - -// Reserve ensures there is enough space for appending n elements -// by checking the capacity and calling Resize if necessary. -func (b *Decimal256Builder) Reserve(n int) { - b.builder.reserve(n, b.Resize) -} - -// Resize adjusts the space allocated by b to n elements. If n is greater than b.Cap(), -// additional memory will be allocated. If n is smaller, the allocated memory may reduced. -func (b *Decimal256Builder) Resize(n int) { - nBuilder := n - if n < minBuilderCapacity { - n = minBuilderCapacity - } - - if b.capacity == 0 { - b.init(n) - } else { - b.builder.resize(nBuilder, b.init) - b.data.Resize(arrow.Decimal256Traits.BytesRequired(n)) - b.rawData = arrow.Decimal256Traits.CastFromBytes(b.data.Bytes()) - } -} - -// NewArray creates a Decimal256 array from the memory buffers used by the builder and resets the Decimal256Builder -// so it can be used to build a new array. -func (b *Decimal256Builder) NewArray() arrow.Array { - return b.NewDecimal256Array() -} - -// NewDecimal256Array creates a Decimal256 array from the memory buffers used by the builder and resets the Decimal256Builder -// so it can be used to build a new array. -func (b *Decimal256Builder) NewDecimal256Array() (a *Decimal256) { - data := b.newData() - a = NewDecimal256Data(data) - data.Release() - return -} - -func (b *Decimal256Builder) newData() (data *Data) { - bytesRequired := arrow.Decimal256Traits.BytesRequired(b.length) - if bytesRequired > 0 && bytesRequired < b.data.Len() { - // trim buffers - b.data.Resize(bytesRequired) - } - data = NewData(b.dtype, b.length, []*memory.Buffer{b.nullBitmap, b.data}, nil, b.nulls, 0) - b.reset() - - if b.data != nil { - b.data.Release() - b.data = nil - b.rawData = nil - } - - return -} - -func (b *Decimal256Builder) AppendValueFromString(s string) error { - if s == NullValueStr { - b.AppendNull() - return nil - } - val, err := decimal256.FromString(s, b.dtype.Precision, b.dtype.Scale) - if err != nil { - b.AppendNull() - return err - } - b.Append(val) - return nil -} - -func (b *Decimal256Builder) UnmarshalOne(dec *json.Decoder) error { - t, err := dec.Token() - if err != nil { - return err - } - - switch v := t.(type) { - case float64: - val, err := decimal256.FromFloat64(v, b.dtype.Precision, b.dtype.Scale) - if err != nil { - return err - } - b.Append(val) - case string: - out, err := decimal256.FromString(v, b.dtype.Precision, b.dtype.Scale) - if err != nil { - return err - } - b.Append(out) - case json.Number: - out, err := decimal256.FromString(v.String(), b.dtype.Precision, b.dtype.Scale) - if err != nil { - return err - } - b.Append(out) - case nil: - b.AppendNull() - return nil - default: - return &json.UnmarshalTypeError{ - Value: fmt.Sprint(t), - Type: reflect.TypeOf(decimal256.Num{}), - Offset: dec.InputOffset(), - } - } - - return nil -} - -func (b *Decimal256Builder) Unmarshal(dec *json.Decoder) error { - for dec.More() { - if err := b.UnmarshalOne(dec); err != nil { - return err - } - } - return nil -} - -// UnmarshalJSON will add the unmarshalled values to this builder. -// -// If the values are strings, they will get parsed with big.ParseFloat using -// a rounding mode of big.ToNearestAway currently. -func (b *Decimal256Builder) UnmarshalJSON(data []byte) error { - dec := json.NewDecoder(bytes.NewReader(data)) - t, err := dec.Token() - if err != nil { - return err - } - - if delim, ok := t.(json.Delim); !ok || delim != '[' { - return fmt.Errorf("arrow/array: decimal256 builder must unpack from json array, found %s", delim) - } - - return b.Unmarshal(dec) -} - -var ( - _ arrow.Array = (*Decimal256)(nil) - _ Builder = (*Decimal256Builder)(nil) -) diff --git a/arrow/array/dictionary.go b/arrow/array/dictionary.go index 34f0f2b4..0c23934a 100644 --- a/arrow/array/dictionary.go +++ b/arrow/array/dictionary.go @@ -27,6 +27,7 @@ import ( "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/bitutil" + "github.com/apache/arrow-go/v18/arrow/decimal" "github.com/apache/arrow-go/v18/arrow/decimal128" "github.com/apache/arrow-go/v18/arrow/decimal256" "github.com/apache/arrow-go/v18/arrow/float16" @@ -392,7 +393,8 @@ func createMemoTable(mem memory.Allocator, dt arrow.DataType) (ret hashing.MemoT ret = hashing.NewFloat32MemoTable(0) case arrow.FLOAT64: ret = hashing.NewFloat64MemoTable(0) - case arrow.BINARY, arrow.FIXED_SIZE_BINARY, arrow.DECIMAL128, arrow.DECIMAL256, arrow.INTERVAL_DAY_TIME, arrow.INTERVAL_MONTH_DAY_NANO: + case arrow.BINARY, arrow.FIXED_SIZE_BINARY, arrow.DECIMAL32, arrow.DECIMAL64, + arrow.DECIMAL128, arrow.DECIMAL256, arrow.INTERVAL_DAY_TIME, arrow.INTERVAL_MONTH_DAY_NANO: ret = hashing.NewBinaryMemoTable(0, 0, NewBinaryBuilder(mem, arrow.BinaryTypes.Binary)) case arrow.STRING: ret = hashing.NewBinaryMemoTable(0, 0, NewBinaryBuilder(mem, arrow.BinaryTypes.String)) @@ -623,6 +625,22 @@ func NewDictionaryBuilderWithDict(mem memory.Allocator, dt *arrow.DictionaryType } } return ret + case arrow.DECIMAL32: + ret := &Decimal32DictionaryBuilder{bldr} + if init != nil { + if err = ret.InsertDictValues(init.(*Decimal32)); err != nil { + panic(err) + } + } + return ret + case arrow.DECIMAL64: + ret := &Decimal64DictionaryBuilder{bldr} + if init != nil { + if err = ret.InsertDictValues(init.(*Decimal64)); err != nil { + panic(err) + } + } + return ret case arrow.DECIMAL128: ret := &Decimal128DictionaryBuilder{bldr} if init != nil { @@ -906,6 +924,16 @@ func getvalFn(arr arrow.Array) func(i int) interface{} { return func(i int) interface{} { return typedarr.Value(i) } case *String: return func(i int) interface{} { return typedarr.Value(i) } + case *Decimal32: + return func(i int) interface{} { + val := typedarr.Value(i) + return (*(*[arrow.Decimal32SizeBytes]byte)(unsafe.Pointer(&val)))[:] + } + case *Decimal64: + return func(i int) interface{} { + val := typedarr.Value(i) + return (*(*[arrow.Decimal64SizeBytes]byte)(unsafe.Pointer(&val)))[:] + } case *Decimal128: return func(i int) interface{} { val := typedarr.Value(i) @@ -1394,6 +1422,42 @@ func (b *FixedSizeBinaryDictionaryBuilder) InsertDictValues(arr *FixedSizeBinary return } +type Decimal32DictionaryBuilder struct { + dictionaryBuilder +} + +func (b *Decimal32DictionaryBuilder) Append(v decimal.Decimal32) error { + return b.appendValue((*(*[arrow.Decimal32SizeBytes]byte)(unsafe.Pointer(&v)))[:]) +} +func (b *Decimal32DictionaryBuilder) InsertDictValues(arr *Decimal32) (err error) { + data := arrow.Decimal32Traits.CastToBytes(arr.values) + for len(data) > 0 { + if err = b.insertDictValue(data[:arrow.Decimal32SizeBytes]); err != nil { + break + } + data = data[arrow.Decimal32SizeBytes:] + } + return +} + +type Decimal64DictionaryBuilder struct { + dictionaryBuilder +} + +func (b *Decimal64DictionaryBuilder) Append(v decimal.Decimal64) error { + return b.appendValue((*(*[arrow.Decimal64SizeBytes]byte)(unsafe.Pointer(&v)))[:]) +} +func (b *Decimal64DictionaryBuilder) InsertDictValues(arr *Decimal64) (err error) { + data := arrow.Decimal64Traits.CastToBytes(arr.values) + for len(data) > 0 { + if err = b.insertDictValue(data[:arrow.Decimal64SizeBytes]); err != nil { + break + } + data = data[arrow.Decimal64SizeBytes:] + } + return +} + type Decimal128DictionaryBuilder struct { dictionaryBuilder } diff --git a/arrow/array/numeric.gen.go b/arrow/array/numeric.gen.go index c6c7b0bd..7e94fe5c 100644 --- a/arrow/array/numeric.gen.go +++ b/arrow/array/numeric.gen.go @@ -101,11 +101,13 @@ func (a *Int64) GetOneForMarshal(i int) interface{} { func (a *Int64) MarshalJSON() ([]byte, error) { vals := make([]interface{}, a.Len()) for i := 0; i < a.Len(); i++ { + if a.IsValid(i) { vals[i] = a.values[i] } else { vals[i] = nil } + } return json.Marshal(vals) @@ -196,11 +198,13 @@ func (a *Uint64) GetOneForMarshal(i int) interface{} { func (a *Uint64) MarshalJSON() ([]byte, error) { vals := make([]interface{}, a.Len()) for i := 0; i < a.Len(); i++ { + if a.IsValid(i) { vals[i] = a.values[i] } else { vals[i] = nil } + } return json.Marshal(vals) @@ -398,11 +402,13 @@ func (a *Int32) GetOneForMarshal(i int) interface{} { func (a *Int32) MarshalJSON() ([]byte, error) { vals := make([]interface{}, a.Len()) for i := 0; i < a.Len(); i++ { + if a.IsValid(i) { vals[i] = a.values[i] } else { vals[i] = nil } + } return json.Marshal(vals) @@ -493,11 +499,13 @@ func (a *Uint32) GetOneForMarshal(i int) interface{} { func (a *Uint32) MarshalJSON() ([]byte, error) { vals := make([]interface{}, a.Len()) for i := 0; i < a.Len(); i++ { + if a.IsValid(i) { vals[i] = a.values[i] } else { vals[i] = nil } + } return json.Marshal(vals) @@ -602,6 +610,7 @@ func (a *Float32) MarshalJSON() ([]byte, error) { default: vals[i] = f } + } return json.Marshal(vals) @@ -692,11 +701,13 @@ func (a *Int16) GetOneForMarshal(i int) interface{} { func (a *Int16) MarshalJSON() ([]byte, error) { vals := make([]interface{}, a.Len()) for i := 0; i < a.Len(); i++ { + if a.IsValid(i) { vals[i] = a.values[i] } else { vals[i] = nil } + } return json.Marshal(vals) @@ -787,11 +798,13 @@ func (a *Uint16) GetOneForMarshal(i int) interface{} { func (a *Uint16) MarshalJSON() ([]byte, error) { vals := make([]interface{}, a.Len()) for i := 0; i < a.Len(); i++ { + if a.IsValid(i) { vals[i] = a.values[i] } else { vals[i] = nil } + } return json.Marshal(vals) @@ -882,11 +895,13 @@ func (a *Int8) GetOneForMarshal(i int) interface{} { func (a *Int8) MarshalJSON() ([]byte, error) { vals := make([]interface{}, a.Len()) for i := 0; i < a.Len(); i++ { + if a.IsValid(i) { vals[i] = float64(a.values[i]) // prevent uint8 from being seen as binary data } else { vals[i] = nil } + } return json.Marshal(vals) @@ -977,11 +992,13 @@ func (a *Uint8) GetOneForMarshal(i int) interface{} { func (a *Uint8) MarshalJSON() ([]byte, error) { vals := make([]interface{}, a.Len()) for i := 0; i < a.Len(); i++ { + if a.IsValid(i) { vals[i] = float64(a.values[i]) // prevent uint8 from being seen as binary data } else { vals[i] = nil } + } return json.Marshal(vals) diff --git a/arrow/cdata/cdata.go b/arrow/cdata/cdata.go index 4688a1ef..d5748a35 100644 --- a/arrow/cdata/cdata.go +++ b/arrow/cdata/cdata.go @@ -254,12 +254,17 @@ func importSchema(schema *CArrowSchema) (ret arrow.Field, err error) { return ret, xerrors.Errorf("could not parse decimal scale in '%s': %s", f, err.Error()) } - if bitwidth == 128 { + switch bitwidth { + case 32: + dt = &arrow.Decimal32Type{Precision: int32(precision), Scale: int32(scale)} + case 64: + dt = &arrow.Decimal64Type{Precision: int32(precision), Scale: int32(scale)} + case 128: dt = &arrow.Decimal128Type{Precision: int32(precision), Scale: int32(scale)} - } else if bitwidth == 256 { + case 256: dt = &arrow.Decimal256Type{Precision: int32(precision), Scale: int32(scale)} - } else { - return ret, xerrors.Errorf("only decimal128 and decimal256 are supported, got '%s'", f) + default: + return ret, xerrors.Errorf("unsupported decimal bitwidth, got '%s'", f) } } diff --git a/arrow/cdata/cdata_exports.go b/arrow/cdata/cdata_exports.go index 4ed9d0e5..d3673481 100644 --- a/arrow/cdata/cdata_exports.go +++ b/arrow/cdata/cdata_exports.go @@ -154,6 +154,10 @@ func (exp *schemaExporter) exportFormat(dt arrow.DataType) string { return "g" case *arrow.FixedSizeBinaryType: return fmt.Sprintf("w:%d", dt.ByteWidth) + case *arrow.Decimal32Type: + return fmt.Sprintf("d:%d,%d,32", dt.Precision, dt.Scale) + case *arrow.Decimal64Type: + return fmt.Sprintf("d:%d,%d,64", dt.Precision, dt.Scale) case *arrow.Decimal128Type: return fmt.Sprintf("d:%d,%d", dt.Precision, dt.Scale) case *arrow.Decimal256Type: diff --git a/arrow/cdata/cdata_test.go b/arrow/cdata/cdata_test.go index 697a73b3..2a86ea62 100644 --- a/arrow/cdata/cdata_test.go +++ b/arrow/cdata/cdata_test.go @@ -153,7 +153,7 @@ func TestDecimalSchemaErrors(t *testing.T) { {"d:a,2,3", "could not parse decimal precision in 'd:a,2,3':"}, {"d:1,a,3", "could not parse decimal scale in 'd:1,a,3':"}, {"d:1,2,a", "could not parse decimal bitwidth in 'd:1,2,a':"}, - {"d:1,2,384", "only decimal128 and decimal256 are supported, got 'd:1,2,384'"}, + {"d:1,2,384", "unsupported decimal bitwidth, got 'd:1,2,384'"}, } for _, tt := range tests { diff --git a/arrow/datatype.go b/arrow/datatype.go index 2fba6550..95565859 100644 --- a/arrow/datatype.go +++ b/arrow/datatype.go @@ -107,7 +107,7 @@ const ( // parameters. DECIMAL128 - // DECIMAL256 is a precision and scale based decimal type, with 256 bit max. not yet implemented + // DECIMAL256 is a precision and scale based decimal type, with 256 bit max. DECIMAL256 // LIST is a list of some logical data type @@ -116,10 +116,10 @@ const ( // STRUCT of logical types STRUCT - // SPARSE_UNION of logical types. not yet implemented + // SPARSE_UNION of logical types SPARSE_UNION - // DENSE_UNION of logical types. not yet implemented + // DENSE_UNION of logical types DENSE_UNION // DICTIONARY aka Category type @@ -138,13 +138,13 @@ const ( // or nanoseconds. DURATION - // like STRING, but 64-bit offsets. not yet implemented + // like STRING, but 64-bit offsets LARGE_STRING - // like BINARY but with 64-bit offsets, not yet implemented + // like BINARY but with 64-bit offsets LARGE_BINARY - // like LIST but with 64-bit offsets. not yet implemented + // like LIST but with 64-bit offsets LARGE_LIST // calendar interval with three fields @@ -165,6 +165,12 @@ const ( // like LIST but with 64-bit offsets LARGE_LIST_VIEW + // Decimal value with 32-bit representation + DECIMAL32 + + // Decimal value with 64-bit representation + DECIMAL64 + // Alias to ensure we do not break any consumers DECIMAL = DECIMAL128 ) @@ -365,10 +371,10 @@ func IsLargeBinaryLike(t Type) bool { return false } -// IsFixedSizeBinary returns true for Decimal128/256 and FixedSizeBinary +// IsFixedSizeBinary returns true for Decimal32/64/128/256 and FixedSizeBinary func IsFixedSizeBinary(t Type) bool { switch t { - case DECIMAL128, DECIMAL256, FIXED_SIZE_BINARY: + case DECIMAL32, DECIMAL64, DECIMAL128, DECIMAL256, FIXED_SIZE_BINARY: return true } return false @@ -377,7 +383,7 @@ func IsFixedSizeBinary(t Type) bool { // IsDecimal returns true for Decimal128 and Decimal256 func IsDecimal(t Type) bool { switch t { - case DECIMAL128, DECIMAL256: + case DECIMAL32, DECIMAL64, DECIMAL128, DECIMAL256: return true } return false diff --git a/arrow/datatype_fixedwidth.go b/arrow/datatype_fixedwidth.go index 41c7b6f3..5928be3a 100644 --- a/arrow/datatype_fixedwidth.go +++ b/arrow/datatype_fixedwidth.go @@ -22,8 +22,9 @@ import ( "sync" "time" + "github.com/apache/arrow-go/v18/arrow/decimal" + "github.com/apache/arrow-go/v18/arrow/internal/debug" "github.com/apache/arrow-go/v18/internal/json" - "golang.org/x/xerrors" ) @@ -532,19 +533,103 @@ type DecimalType interface { DataType GetPrecision() int32 GetScale() int32 + BitWidth() int +} + +// NarrowestDecimalType constructs the smallest decimal type that can represent +// the requested precision. An error is returned if the requested precision +// cannot be represented (prec <= 0 || prec > 76). +// +// For reference: +// +// prec in [ 1, 9] => Decimal32Type +// prec in [10, 18] => Decimal64Type +// prec in [19, 38] => Decimal128Type +// prec in [39, 76] => Decimal256Type +func NarrowestDecimalType(prec, scale int32) (DecimalType, error) { + switch { + case prec <= 0: + return nil, fmt.Errorf("%w: precision must be > 0 for decimal types, got %d", + ErrInvalid, prec) + case prec <= int32(decimal.MaxPrecision[decimal.Decimal32]()): + return &Decimal32Type{Precision: prec, Scale: scale}, nil + case prec <= int32(decimal.MaxPrecision[decimal.Decimal64]()): + return &Decimal64Type{Precision: prec, Scale: scale}, nil + case prec <= int32(decimal.MaxPrecision[decimal.Decimal128]()): + return &Decimal128Type{Precision: prec, Scale: scale}, nil + case prec <= int32(decimal.MaxPrecision[decimal.Decimal256]()): + return &Decimal256Type{Precision: prec, Scale: scale}, nil + default: + return nil, fmt.Errorf("%w: invalid precision for decimal types, %d", + ErrInvalid, prec) + } } func NewDecimalType(id Type, prec, scale int32) (DecimalType, error) { switch id { + case DECIMAL32: + debug.Assert(prec <= int32(decimal.MaxPrecision[decimal.Decimal32]()), "invalid precision for decimal32") + return &Decimal32Type{Precision: prec, Scale: scale}, nil + case DECIMAL64: + debug.Assert(prec <= int32(decimal.MaxPrecision[decimal.Decimal64]()), "invalid precision for decimal64") + return &Decimal64Type{Precision: prec, Scale: scale}, nil case DECIMAL128: + debug.Assert(prec <= int32(decimal.MaxPrecision[decimal.Decimal128]()), "invalid precision for decimal128") return &Decimal128Type{Precision: prec, Scale: scale}, nil case DECIMAL256: + debug.Assert(prec <= int32(decimal.MaxPrecision[decimal.Decimal256]()), "invalid precision for decimal256") return &Decimal256Type{Precision: prec, Scale: scale}, nil default: - return nil, fmt.Errorf("%w: must use DECIMAL128 or DECIMAL256 to create a DecimalType", ErrInvalid) + return nil, fmt.Errorf("%w: must use one of the DECIMAL IDs to create a DecimalType", ErrInvalid) } } +// Decimal32Type represents a fixed-size 32-bit decimal type. +type Decimal32Type struct { + Precision int32 + Scale int32 +} + +func (*Decimal32Type) ID() Type { return DECIMAL32 } +func (*Decimal32Type) Name() string { return "decimal32" } +func (*Decimal32Type) BitWidth() int { return 32 } +func (*Decimal32Type) Bytes() int { return Decimal32SizeBytes } +func (t *Decimal32Type) String() string { + return fmt.Sprintf("%s(%d, %d)", t.Name(), t.Precision, t.Scale) +} +func (t *Decimal32Type) Fingerprint() string { + return fmt.Sprintf("%s[%d,%d,%d]", typeFingerprint(t), t.BitWidth(), t.Precision, t.Scale) +} +func (t *Decimal32Type) GetPrecision() int32 { return t.Precision } +func (t *Decimal32Type) GetScale() int32 { return t.Scale } + +func (Decimal32Type) Layout() DataTypeLayout { + return DataTypeLayout{Buffers: []BufferSpec{SpecBitmap(), SpecFixedWidth(Decimal32SizeBytes)}} +} + +// Decimal64Type represents a fixed-size 32-bit decimal type. +type Decimal64Type struct { + Precision int32 + Scale int32 +} + +func (*Decimal64Type) ID() Type { return DECIMAL64 } +func (*Decimal64Type) Name() string { return "decimal64" } +func (*Decimal64Type) BitWidth() int { return 64 } +func (*Decimal64Type) Bytes() int { return Decimal64SizeBytes } +func (t *Decimal64Type) String() string { + return fmt.Sprintf("%s(%d, %d)", t.Name(), t.Precision, t.Scale) +} +func (t *Decimal64Type) Fingerprint() string { + return fmt.Sprintf("%s[%d,%d,%d]", typeFingerprint(t), t.BitWidth(), t.Precision, t.Scale) +} +func (t *Decimal64Type) GetPrecision() int32 { return t.Precision } +func (t *Decimal64Type) GetScale() int32 { return t.Scale } + +func (Decimal64Type) Layout() DataTypeLayout { + return DataTypeLayout{Buffers: []BufferSpec{SpecBitmap(), SpecFixedWidth(Decimal64SizeBytes)}} +} + // Decimal128Type represents a fixed-size 128-bit decimal type. type Decimal128Type struct { Precision int32 diff --git a/arrow/datatype_fixedwidth_test.go b/arrow/datatype_fixedwidth_test.go index d60c6b17..bc899f34 100644 --- a/arrow/datatype_fixedwidth_test.go +++ b/arrow/datatype_fixedwidth_test.go @@ -23,6 +23,7 @@ import ( "github.com/apache/arrow-go/v18/arrow" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // TestTimeUnit_String verifies each time unit matches its string representation. @@ -43,6 +44,60 @@ func TestTimeUnit_String(t *testing.T) { } } +func TestDecimal32Type(t *testing.T) { + for _, tc := range []struct { + precision int32 + scale int32 + want string + }{ + {1, 9, "decimal32(1, 9)"}, + {9, 9, "decimal32(9, 9)"}, + {9, 1, "decimal32(9, 1)"}, + } { + t.Run(tc.want, func(t *testing.T) { + dt := arrow.Decimal32Type{Precision: tc.precision, Scale: tc.scale} + if got, want := dt.BitWidth(), 32; got != want { + t.Fatalf("invalid bitwidth: got=%d, want=%d", got, want) + } + + if got, want := dt.ID(), arrow.DECIMAL32; got != want { + t.Fatalf("invalid type ID: got=%v, want=%v", got, want) + } + + if got, want := dt.String(), tc.want; got != want { + t.Fatalf("invalid stringer: got=%q, want=%q", got, want) + } + }) + } +} + +func TestDecimal64Type(t *testing.T) { + for _, tc := range []struct { + precision int32 + scale int32 + want string + }{ + {1, 10, "decimal64(1, 10)"}, + {10, 10, "decimal64(10, 10)"}, + {10, 1, "decimal64(10, 1)"}, + } { + t.Run(tc.want, func(t *testing.T) { + dt := arrow.Decimal64Type{Precision: tc.precision, Scale: tc.scale} + if got, want := dt.BitWidth(), 64; got != want { + t.Fatalf("invalid bitwidth: got=%d, want=%d", got, want) + } + + if got, want := dt.ID(), arrow.DECIMAL64; got != want { + t.Fatalf("invalid type ID: got=%v, want=%v", got, want) + } + + if got, want := dt.String(), tc.want; got != want { + t.Fatalf("invalid stringer: got=%q, want=%q", got, want) + } + }) + } +} + func TestDecimal128Type(t *testing.T) { for _, tc := range []struct { precision int32 @@ -438,3 +493,36 @@ func TestDateFromTime(t *testing.T) { assert.EqualValues(t, wantD64, arrow.Date64FromTime(tm)) assert.EqualValues(t, wantD32, arrow.Date32FromTime(tm)) } + +func TestNarrowestDecimalType(t *testing.T) { + tests := []struct { + min, max int32 + expected arrow.Type + }{ + {1, 9, arrow.DECIMAL32}, + {10, 18, arrow.DECIMAL64}, + {19, 38, arrow.DECIMAL128}, + {39, 76, arrow.DECIMAL256}, + } + + for _, tt := range tests { + t.Run(tt.expected.String(), func(t *testing.T) { + for i := tt.min; i <= tt.max; i++ { + typ, err := arrow.NarrowestDecimalType(i, 5) + require.NoError(t, err) + + assert.Equal(t, i, typ.GetPrecision()) + assert.Equal(t, int32(5), typ.GetScale()) + assert.Equal(t, tt.expected, typ.ID()) + } + }) + } + + _, err := arrow.NarrowestDecimalType(-1, 5) + assert.Error(t, err) + assert.ErrorIs(t, err, arrow.ErrInvalid) + + _, err = arrow.NarrowestDecimalType(78, 5) + assert.Error(t, err) + assert.ErrorIs(t, err, arrow.ErrInvalid) +} diff --git a/arrow/decimal/decimal.go b/arrow/decimal/decimal.go new file mode 100644 index 00000000..098a4e0f --- /dev/null +++ b/arrow/decimal/decimal.go @@ -0,0 +1,473 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package decimal + +import ( + "errors" + "fmt" + "math" + "math/big" + "math/bits" + "unsafe" + + "github.com/apache/arrow-go/v18/arrow/decimal128" + "github.com/apache/arrow-go/v18/arrow/decimal256" + "github.com/apache/arrow-go/v18/arrow/internal/debug" +) + +// DecimalTypes is a generic constraint representing the implemented decimal types +// in this package, and a single point of update for future additions. Everything +// else is constrained by this. +type DecimalTypes interface { + Decimal32 | Decimal64 | Decimal128 | Decimal256 +} + +// Num is an interface that is useful for building generic types for all decimal +// type implementations. It presents all the methods and interfaces necessary to +// operate on the decimal objects without having to care about the bit width. +type Num[T DecimalTypes] interface { + Negate() T + Add(T) T + Sub(T) T + Mul(T) T + Div(T) (res, rem T) + Pow(T) T + + FitsInPrecision(int32) bool + Abs() T + Sign() int + Rescale(int32, int32) (T, error) + Cmp(T) int + + IncreaseScaleBy(int32) T + ReduceScaleBy(int32, bool) T + + ToFloat32(int32) float32 + ToFloat64(int32) float64 + ToBigFloat(int32) *big.Float + + ToString(int32) string +} + +type ( + Decimal32 int32 + Decimal64 int64 + Decimal128 = decimal128.Num + Decimal256 = decimal256.Num +) + +func MaxPrecision[T DecimalTypes]() int { + // max precision is computed by Floor(log10(2^(nbytes * 8 - 1) - 1)) + var z T + return int(math.Floor(math.Log10(math.Pow(2, float64(unsafe.Sizeof(z))*8-1) - 1))) +} + +func (d Decimal32) Negate() Decimal32 { return -d } +func (d Decimal64) Negate() Decimal64 { return -d } + +func (d Decimal32) Add(rhs Decimal32) Decimal32 { return d + rhs } +func (d Decimal64) Add(rhs Decimal64) Decimal64 { return d + rhs } + +func (d Decimal32) Sub(rhs Decimal32) Decimal32 { return d - rhs } +func (d Decimal64) Sub(rhs Decimal64) Decimal64 { return d - rhs } + +func (d Decimal32) Mul(rhs Decimal32) Decimal32 { return d * rhs } +func (d Decimal64) Mul(rhs Decimal64) Decimal64 { return d * rhs } + +func (d Decimal32) Div(rhs Decimal32) (res, rem Decimal32) { + return d / rhs, d % rhs +} + +func (d Decimal64) Div(rhs Decimal64) (res, rem Decimal64) { + return d / rhs, d % rhs +} + +// about 4-5x faster than using math.Pow which requires converting to float64 +// and back to integers +func intPow[T int32 | int64](base, exp T) T { + result := T(1) + for { + if exp&1 == 1 { + result *= base + } + exp >>= 1 + if exp == 0 { + break + } + base *= base + } + + return result +} + +func (d Decimal32) Pow(rhs Decimal32) Decimal32 { + return Decimal32(intPow(int32(d), int32(rhs))) +} + +func (d Decimal64) Pow(rhs Decimal64) Decimal64 { + return Decimal64(intPow(int64(d), int64(rhs))) +} + +func (d Decimal32) Sign() int { + if d == 0 { + return 0 + } + return int(1 | (d >> 31)) +} + +func (d Decimal64) Sign() int { + if d == 0 { + return 0 + } + return int(1 | (d >> 63)) +} + +func (n Decimal32) Abs() Decimal32 { + if n < 0 { + return -n + } + return n +} + +func (n Decimal64) Abs() Decimal64 { + if n < 0 { + return -n + } + return n +} + +func (n Decimal32) FitsInPrecision(prec int32) bool { + debug.Assert(prec > 0, "precision must be > 0") + debug.Assert(prec <= 9, "precision must be <= 9") + return n.Abs() < Decimal32(math.Pow10(int(prec))) +} + +func (n Decimal64) FitsInPrecision(prec int32) bool { + debug.Assert(prec > 0, "precision must be > 0") + debug.Assert(prec <= 18, "precision must be <= 18") + return n.Abs() < Decimal64(math.Pow10(int(prec))) +} + +func (n Decimal32) ToString(scale int32) string { + return n.ToBigFloat(scale).Text('f', int(scale)) +} + +func (n Decimal64) ToString(scale int32) string { + return n.ToBigFloat(scale).Text('f', int(scale)) +} + +var pt5 = big.NewFloat(0.5) + +func decimalFromString[T interface { + Decimal32 | Decimal64 + FitsInPrecision(int32) bool +}](v string, prec, scale int32) (n T, err error) { + var nbits = uint(unsafe.Sizeof(T(0))) * 8 + + var out *big.Float + out, _, err = big.ParseFloat(v, 10, nbits, big.ToNearestEven) + + if scale < 0 { + var tmp big.Int + val, _ := out.Int(&tmp) + if val.BitLen() > int(nbits) { + return n, fmt.Errorf("bitlen too large for decimal%d", nbits) + } + + n = T(val.Int64() / int64(math.Pow10(int(-scale)))) + } else { + var precInBits = uint(math.Round(float64(prec+scale+1)/math.Log10(2))) + 1 + + p := (&big.Float{}).SetInt(big.NewInt(int64(math.Pow10(int(scale))))) + out.SetPrec(precInBits).Mul(out, p) + if out.Signbit() { + out.Sub(out, pt5) + } else { + out.Add(out, pt5) + } + + var tmp big.Int + val, _ := out.Int(&tmp) + if val.BitLen() > int(nbits) { + return n, fmt.Errorf("bitlen too large for decimal%d", nbits) + } + n = T(val.Int64()) + } + + if !n.FitsInPrecision(prec) { + err = fmt.Errorf("val %v doesn't fit in precision %d", n, prec) + } + return +} + +func Decimal32FromString(v string, prec, scale int32) (n Decimal32, err error) { + return decimalFromString[Decimal32](v, prec, scale) +} + +func Decimal64FromString(v string, prec, scale int32) (n Decimal64, err error) { + return decimalFromString[Decimal64](v, prec, scale) +} + +func Decimal128FromString(v string, prec, scale int32) (n Decimal128, err error) { + return decimal128.FromString(v, prec, scale) +} + +func Decimal256FromString(v string, prec, scale int32) (n Decimal256, err error) { + return decimal256.FromString(v, prec, scale) +} + +func scalePositiveFloat64(v float64, prec, scale int32) (float64, error) { + v *= math.Pow10(int(scale)) + v = math.RoundToEven(v) + + maxabs := math.Pow10(int(prec)) + if v >= maxabs { + return 0, fmt.Errorf("cannot convert %f to decimal(precision=%d, scale=%d)", v, prec, scale) + } + return v, nil +} + +func fromPositiveFloat[T Decimal32 | Decimal64, F float32 | float64](v F, prec, scale int32) (T, error) { + if prec > int32(MaxPrecision[T]()) { + return T(0), fmt.Errorf("invalid precision %d for converting float to Decimal", prec) + } + + val, err := scalePositiveFloat64(float64(v), prec, scale) + if err != nil { + return T(0), err + } + + return T(F(val)), nil +} + +func Decimal32FromFloat[F float32 | float64](v F, prec, scale int32) (Decimal32, error) { + if v < 0 { + dec, err := fromPositiveFloat[Decimal32](-v, prec, scale) + if err != nil { + return dec, err + } + + return -dec, nil + } + + return fromPositiveFloat[Decimal32](v, prec, scale) +} + +func Decimal64FromFloat[F float32 | float64](v F, prec, scale int32) (Decimal64, error) { + if v < 0 { + dec, err := fromPositiveFloat[Decimal64](-v, prec, scale) + if err != nil { + return dec, err + } + + return -dec, nil + } + + return fromPositiveFloat[Decimal64](v, prec, scale) +} + +func Decimal128FromFloat(v float64, prec, scale int32) (Decimal128, error) { + return decimal128.FromFloat64(v, prec, scale) +} + +func Decimal256FromFloat(v float64, prec, scale int32) (Decimal256, error) { + return decimal256.FromFloat64(v, prec, scale) +} + +func (n Decimal32) ToFloat32(scale int32) float32 { + return float32(n.ToFloat64(scale)) +} + +func (n Decimal64) ToFloat32(scale int32) float32 { + return float32(n.ToFloat64(scale)) +} + +func (n Decimal32) ToFloat64(scale int32) float64 { + return float64(n) * math.Pow10(-int(scale)) +} + +func (n Decimal64) ToFloat64(scale int32) float64 { + return float64(n) * math.Pow10(-int(scale)) +} + +func (n Decimal32) ToBigFloat(scale int32) *big.Float { + f := (&big.Float{}).SetInt64(int64(n)) + if scale < 0 { + f.SetPrec(32).Mul(f, (&big.Float{}).SetInt64(intPow(10, -int64(scale)))) + } else { + f.SetPrec(32).Quo(f, (&big.Float{}).SetInt64(intPow(10, int64(scale)))) + } + return f +} + +func (n Decimal64) ToBigFloat(scale int32) *big.Float { + f := (&big.Float{}).SetInt64(int64(n)) + if scale < 0 { + f.SetPrec(64).Mul(f, (&big.Float{}).SetInt64(intPow(10, -int64(scale)))) + } else { + f.SetPrec(64).Quo(f, (&big.Float{}).SetInt64(intPow(10, int64(scale)))) + } + return f +} + +func cmpDec[T Decimal32 | Decimal64](lhs, rhs T) int { + switch { + case lhs > rhs: + return 1 + case lhs < rhs: + return -1 + } + return 0 +} + +func (n Decimal32) Cmp(other Decimal32) int { + return cmpDec(n, other) +} + +func (n Decimal64) Cmp(other Decimal64) int { + return cmpDec(n, other) +} + +func (n Decimal32) IncreaseScaleBy(increase int32) Decimal32 { + debug.Assert(increase >= 0, "invalid increase scale for decimal32") + debug.Assert(increase <= 9, "invalid increase scale for decimal32") + + return n * Decimal32(intPow(10, increase)) +} + +func (n Decimal64) IncreaseScaleBy(increase int32) Decimal64 { + debug.Assert(increase >= 0, "invalid increase scale for decimal64") + debug.Assert(increase <= 18, "invalid increase scale for decimal64") + + return n * Decimal64(intPow(10, int64(increase))) +} + +func reduceScale[T interface { + Decimal32 | Decimal64 + Abs() T +}](n T, reduce int32, round bool) T { + if reduce == 0 { + return n + } + + divisor := T(intPow(10, reduce)) + if !round { + return n / divisor + } + + quo, remainder := n/divisor, n%divisor + divisorHalf := divisor / 2 + if remainder.Abs() >= divisorHalf { + if n > 0 { + quo++ + } else { + quo-- + } + } + + return quo +} + +func (n Decimal32) ReduceScaleBy(reduce int32, round bool) Decimal32 { + debug.Assert(reduce >= 0, "invalid reduce scale for decimal32") + debug.Assert(reduce <= 9, "invalid reduce scale for decimal32") + + return reduceScale(n, reduce, round) +} + +func (n Decimal64) ReduceScaleBy(reduce int32, round bool) Decimal64 { + debug.Assert(reduce >= 0, "invalid reduce scale for decimal32") + debug.Assert(reduce <= 18, "invalid reduce scale for decimal32") + + return reduceScale(n, reduce, round) +} + +//lint:ignore U1000 function is being used, staticcheck seems to not follow generics +func (n Decimal32) rescaleWouldCauseDataLoss(deltaScale int32, multiplier Decimal32) (out Decimal32, loss bool) { + if deltaScale < 0 { + debug.Assert(multiplier != 0, "multiplier must not be zero") + quo, remainder := bits.Div32(0, uint32(n), uint32(multiplier)) + return Decimal32(quo), remainder != 0 + } + + overflow, result := bits.Mul32(uint32(n), uint32(multiplier)) + if overflow != 0 { + return Decimal32(result), true + } + + out = Decimal32(result) + return out, out < n +} + +//lint:ignore U1000 function is being used, staticcheck seems to not follow generics +func (n Decimal64) rescaleWouldCauseDataLoss(deltaScale int32, multiplier Decimal64) (out Decimal64, loss bool) { + if deltaScale < 0 { + debug.Assert(multiplier != 0, "multiplier must not be zero") + quo, remainder := bits.Div32(0, uint32(n), uint32(multiplier)) + return Decimal64(quo), remainder != 0 + } + + overflow, result := bits.Mul32(uint32(n), uint32(multiplier)) + if overflow != 0 { + return Decimal64(result), true + } + + out = Decimal64(result) + return out, out < n +} + +func rescale[T interface { + Decimal32 | Decimal64 + rescaleWouldCauseDataLoss(int32, T) (T, bool) + Sign() int +}](n T, originalScale, newScale int32) (out T, err error) { + if originalScale == newScale { + return n, nil + } + + deltaScale := newScale - originalScale + absDeltaScale := int32(math.Abs(float64(deltaScale))) + + sign := n.Sign() + if n < 0 { + n = -n + } + + multiplier := T(intPow(10, absDeltaScale)) + var wouldHaveLoss bool + out, wouldHaveLoss = n.rescaleWouldCauseDataLoss(deltaScale, multiplier) + if wouldHaveLoss { + err = errors.New("rescale data loss") + } + out *= T(sign) + return +} + +func (n Decimal32) Rescale(originalScale, newScale int32) (out Decimal32, err error) { + return rescale(n, originalScale, newScale) +} + +func (n Decimal64) Rescale(originalScale, newScale int32) (out Decimal64, err error) { + return rescale(n, originalScale, newScale) +} + +var ( + _ Num[Decimal32] = Decimal32(0) + _ Num[Decimal64] = Decimal64(0) + _ Num[Decimal128] = Decimal128{} + _ Num[Decimal256] = Decimal256{} +) diff --git a/arrow/decimal/decimal_test.go b/arrow/decimal/decimal_test.go new file mode 100644 index 00000000..9b4f04f5 --- /dev/null +++ b/arrow/decimal/decimal_test.go @@ -0,0 +1,470 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package decimal_test + +import ( + "fmt" + "math" + "math/big" + "strconv" + "testing" + + "github.com/apache/arrow-go/v18/arrow/decimal" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func ulps64(actual, expected float64) int64 { + ulp := math.Nextafter(actual, math.Inf(1)) - actual + return int64(math.Abs((expected - actual) / ulp)) +} + +func ulps32(actual, expected float32) int64 { + ulp := math.Nextafter32(actual, float32(math.Inf(1))) - actual + return int64(math.Abs(float64((expected - actual) / ulp))) +} + +func assertFloat32Approx(t *testing.T, x, y float32) bool { + const maxulps int64 = 4 + ulps := ulps32(x, y) + return assert.LessOrEqualf(t, ulps, maxulps, "%f not equal to %f (%d ulps)", x, y, ulps) +} + +func assertFloat64Approx(t *testing.T, x, y float64) bool { + const maxulps int64 = 4 + ulps := ulps64(x, y) + return assert.LessOrEqualf(t, ulps, maxulps, "%f not equal to %f (%d ulps)", x, y, ulps) +} + +func TestDecimalToReal(t *testing.T) { + tests := []struct { + decimalVal string + scale int32 + exp float64 + }{ + {"0", 0, 0}, + {"0", 10, 0.0}, + {"0", -10, 0.0}, + {"1", 0, 1.0}, + {"12345", 0, 12345.0}, + {"12345", 1, 1234.5}, + {"536870912", 0, math.Pow(2, 29)}, + } + + t.Run("float32", func(t *testing.T) { + checkDecimalToFloat := func(t *testing.T, str string, v float32, scale int32) { + n, err := decimal.Decimal32FromString(str, 9, 0) + require.NoError(t, err) + assert.Equalf(t, v, n.ToFloat32(scale), "Decimal Val: %s, Scale: %d", str, scale) + + n64, err := decimal.Decimal64FromString(str, 18, 0) + require.NoError(t, err) + assert.Equalf(t, v, n64.ToFloat32(scale), "Decimal Val: %s, Scale: %d", str, scale) + } + for _, tt := range tests { + t.Run(tt.decimalVal, func(t *testing.T) { + checkDecimalToFloat(t, tt.decimalVal, float32(tt.exp), tt.scale) + if tt.decimalVal != "0" { + checkDecimalToFloat(t, "-"+tt.decimalVal, float32(-tt.exp), tt.scale) + } + }) + } + + t.Run("large values", func(t *testing.T) { + checkApproxDecimaltoFloat := func(str string, v float32, scale int32) { + n, err := decimal.Decimal32FromString(str, 9, 0) + require.NoError(t, err) + assertFloat32Approx(t, v, n.ToFloat32(scale)) + } + + checkApproxDecimal64toFloat := func(str string, v float32, scale int32) { + n, err := decimal.Decimal64FromString(str, 9, 0) + require.NoError(t, err) + assertFloat32Approx(t, v, n.ToFloat32(scale)) + } + + // exact comparisons would succeed on most platforms, but not all power-of-ten + // factors are exactly representable in binary floating point, so we'll use + // approx and ensure that the values are within 4 ULP (unit of least precision) + for scale := int32(-9); scale <= 9; scale++ { + checkApproxDecimaltoFloat("1", float32(math.Pow10(-int(scale))), scale) + checkApproxDecimaltoFloat("123", float32(123)*float32(math.Pow10(-int(scale))), scale) + } + + for scale := int32(-18); scale <= 18; scale++ { + checkApproxDecimal64toFloat("1", float32(math.Pow10(-int(scale))), scale) + checkApproxDecimal64toFloat("123", float32(123)*float32(math.Pow10(-int(scale))), scale) + } + }) + }) + + t.Run("float64", func(t *testing.T) { + checkDecimalToFloat := func(t *testing.T, str string, v float64, scale int32) { + n, err := decimal.Decimal32FromString(str, 9, 0) + require.NoError(t, err) + assert.Equalf(t, v, n.ToFloat64(scale), "Decimal Val: %s, Scale: %d", str, scale) + + assert.Equalf(t, big.NewFloat(v).SetPrec(32), n.ToBigFloat(scale), + "Decimal Val: %s, Scale: %d", str, scale) + + n64, err := decimal.Decimal64FromString(str, 18, 0) + require.NoError(t, err) + assert.Equalf(t, v, n64.ToFloat64(scale), "Decimal Val: %s, Scale: %d", str, scale) + assert.Equalf(t, big.NewFloat(v).SetPrec(64), n64.ToBigFloat(scale), + "Decimal Val: %s, Scale: %d", str, scale) + } + for _, tt := range tests { + t.Run(tt.decimalVal, func(t *testing.T) { + checkDecimalToFloat(t, tt.decimalVal, tt.exp, tt.scale) + if tt.decimalVal != "0" { + checkDecimalToFloat(t, "-"+tt.decimalVal, -tt.exp, tt.scale) + } + }) + } + + t.Run("large values", func(t *testing.T) { + checkApproxDecimaltoFloat := func(str string, v float64, scale int32) { + n, err := decimal.Decimal32FromString(str, 9, 0) + require.NoError(t, err) + assertFloat64Approx(t, v, n.ToFloat64(scale)) + + assert.Equalf(t, big.NewFloat(v).SetPrec(32), n.ToBigFloat(scale), + "Decimal Val: %s, Scale: %d", str, scale) + } + + checkApproxDecimal64toFloat := func(str string, v float64, scale int32) { + n, err := decimal.Decimal64FromString(str, 9, 0) + require.NoError(t, err) + assertFloat64Approx(t, v, n.ToFloat64(scale)) + + bf, _ := n.ToBigFloat(scale).Float64() + assertFloat64Approx(t, v, bf) + } + + // exact comparisons would succeed on most platforms, but not all power-of-ten + // factors are exactly representable in binary floating point, so we'll use + // approx and ensure that the values are within 4 ULP (unit of least precision) + for scale := int32(-9); scale <= 9; scale++ { + checkApproxDecimaltoFloat("1", math.Pow10(-int(scale)), scale) + checkApproxDecimaltoFloat("123", float64(123)*math.Pow10(-int(scale)), scale) + } + + for scale := int32(-18); scale <= 18; scale++ { + checkApproxDecimal64toFloat("1", math.Pow10(-int(scale)), scale) + checkApproxDecimal64toFloat("123", float64(123)*math.Pow10(-int(scale)), scale) + } + }) + }) +} + +func TestDecimalFromFloat(t *testing.T) { + tests := []struct { + val float64 + precision, scale int32 + expected string + }{ + {0, 1, 0, "0"}, + {-0, 1, 0, "0"}, + {0, 9, 4, "0.0000"}, + {math.Copysign(0.0, -1), 9, 4, "0.0000"}, + {123, 7, 4, "123.0000"}, + {-123, 7, 4, "-123.0000"}, + {456.78, 7, 4, "456.7800"}, + {-456.78, 7, 4, "-456.7800"}, + {456.784, 5, 2, "456.78"}, + {-456.784, 5, 2, "-456.78"}, + {456.786, 5, 2, "456.79"}, + {-456.786, 5, 2, "-456.79"}, + {999.99, 5, 2, "999.99"}, + {-999.99, 5, 2, "-999.99"}, + {123, 9, 0, "123"}, + {-123, 9, 0, "-123"}, + {123.4, 9, 0, "123"}, + {-123.4, 9, 0, "-123"}, + {123.6, 9, 0, "124"}, + {-123.6, 9, 0, "-124"}, + } + + t.Run("float64", func(t *testing.T) { + for _, tt := range tests { + t.Run(tt.expected, func(t *testing.T) { + n, err := decimal.Decimal32FromFloat(tt.val, tt.precision, tt.scale) + require.NoError(t, err) + assert.Equal(t, tt.expected, fmt.Sprintf("%."+strconv.Itoa(int(tt.scale))+"f", n.ToFloat64(tt.scale))) + }) + } + + t.Run("large values", func(t *testing.T) { + // test entire float64 range + for scale := int32(-308); scale <= 308; scale++ { + val := math.Pow10(int(scale)) + n, err := decimal.Decimal64FromFloat(val, 1, -scale) + require.NoError(t, err) + assert.EqualValues(t, 1, n) + } + + for scale := int32(-307); scale <= 306; scale++ { + val := 123 * math.Pow10(int(scale)) + n, err := decimal.Decimal64FromFloat(val, 2, -scale-1) + require.NoError(t, err) + assert.EqualValues(t, 12, n) + n, err = decimal.Decimal64FromFloat(val, 3, -scale) + require.NoError(t, err) + assert.EqualValues(t, 123, n) + n, err = decimal.Decimal64FromFloat(val, 4, -scale+1) + require.NoError(t, err) + assert.EqualValues(t, 1230, n) + } + }) + }) + + t.Run("float32", func(t *testing.T) { + for _, tt := range tests { + t.Run(tt.expected, func(t *testing.T) { + n, err := decimal.Decimal32FromFloat(float32(tt.val), tt.precision, tt.scale) + require.NoError(t, err) + assert.Equal(t, tt.expected, fmt.Sprintf("%."+strconv.Itoa(int(tt.scale))+"f", n.ToFloat32(tt.scale))) + }) + } + + t.Run("large values", func(t *testing.T) { + // test entire float32 range + for scale := int32(-38); scale <= 38; scale++ { + val := float32(math.Pow10(int(scale))) + n, err := decimal.Decimal64FromFloat(val, 1, -scale) + require.NoError(t, err) + assert.EqualValues(t, 1, n) + } + + for scale := int32(-37); scale <= 36; scale++ { + val := 123 * float32(math.Pow10(int(scale))) + n, err := decimal.Decimal64FromFloat(val, 2, -scale-1) + require.NoError(t, err) + assert.EqualValues(t, 12, n) + n, err = decimal.Decimal64FromFloat(val, 3, -scale) + require.NoError(t, err) + assert.EqualValues(t, 123, n) + n, err = decimal.Decimal64FromFloat(val, 4, -scale+1) + require.NoError(t, err) + assert.EqualValues(t, 1230, n) + } + }) + }) +} + +func TestFromString(t *testing.T) { + tests := []struct { + s string + expected int64 + expectedScale int32 + }{ + {"12.3", 123, 1}, + {"0.00123", 123, 5}, + {"1.23e-8", 123, 10}, + {"-1.23E-8", -123, 10}, + {"1.23e+3", 1230, 0}, + {"-1.23E+3", -1230, 0}, + {"1.23e+5", 123000, 0}, + {"1.2345E+7", 12345000, 0}, + {"1.23e-8", 123, 10}, + {"-1.23E-8", -123, 10}, + {"0000000", 0, 0}, + {"000.0000", 0, 4}, + {".0000", 0, 5}, + {"1e1", 10, 0}, + {"+234.567", 234567, 3}, + {"1e-8", 1, 8}, + {"2112.33", 211233, 2}, + {"-2112.33", -211233, 2}, + {"12E2", 12, -2}, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("%s_%d", tt.s, tt.expectedScale), func(t *testing.T) { + n, err := decimal.Decimal32FromString(tt.s, 8, tt.expectedScale) + require.NoError(t, err) + + ex := decimal.Decimal32(tt.expected) + assert.Equal(t, ex, n) + + n64, err := decimal.Decimal64FromString(tt.s, 8, tt.expectedScale) + require.NoError(t, err) + + ex64 := decimal.Decimal64(tt.expected) + assert.Equal(t, ex64, n64) + }) + } +} + +func TestCmp(t *testing.T) { + for _, tc := range []struct { + n decimal.Decimal32 + rhs decimal.Decimal32 + want int + }{ + {decimal.Decimal32(2), decimal.Decimal32(1), 1}, + {decimal.Decimal32(-1), decimal.Decimal32(-2), 1}, + {decimal.Decimal32(2), decimal.Decimal32(3), -1}, + {decimal.Decimal32(-3), decimal.Decimal32(-2), -1}, + {decimal.Decimal32(2), decimal.Decimal32(2), 0}, + {decimal.Decimal32(-2), decimal.Decimal32(-2), 0}, + } { + t.Run("cmp", func(t *testing.T) { + n := tc.n.Cmp(tc.rhs) + if got, want := n, tc.want; got != want { + t.Fatalf("invalid value. got=%v, want=%v", got, want) + } + }) + } + + for _, tc := range []struct { + n decimal.Decimal64 + rhs decimal.Decimal64 + want int + }{ + {decimal.Decimal64(2), decimal.Decimal64(1), 1}, + {decimal.Decimal64(-1), decimal.Decimal64(-2), 1}, + {decimal.Decimal64(2), decimal.Decimal64(3), -1}, + {decimal.Decimal64(-3), decimal.Decimal64(-2), -1}, + {decimal.Decimal64(2), decimal.Decimal64(2), 0}, + {decimal.Decimal64(-2), decimal.Decimal64(-2), 0}, + } { + t.Run("cmp", func(t *testing.T) { + n := tc.n.Cmp(tc.rhs) + if got, want := n, tc.want; got != want { + t.Fatalf("invalid value. got=%v, want=%v", got, want) + } + }) + } +} + +func TestDecimalRescale(t *testing.T) { + tests := []struct { + orig, exp int32 + oldScale, newScale int32 + }{ + {111, 11100, 0, 2}, + {11100, 111, 2, 0}, + {500000, 5, 6, 1}, + {5, 500000, 1, 6}, + {-111, -11100, 0, 2}, + {-11100, -111, 2, 0}, + {555, 555, 2, 2}, + } + + for _, tt := range tests { + t.Run("decimal32", func(t *testing.T) { + out, err := decimal.Decimal32(tt.orig).Rescale(tt.oldScale, tt.newScale) + require.NoError(t, err) + assert.Equal(t, decimal.Decimal32(tt.exp), out) + }) + t.Run("decimal64", func(t *testing.T) { + out, err := decimal.Decimal64(tt.orig).Rescale(tt.oldScale, tt.newScale) + require.NoError(t, err) + assert.Equal(t, decimal.Decimal64(tt.exp), out) + }) + } + + _, err := decimal.Decimal32(555555).Rescale(6, 1) + assert.Error(t, err) + _, err = decimal.Decimal64(555555).Rescale(6, 1) + assert.Error(t, err) + + _, err = decimal.Decimal32(555555).Rescale(0, 5) + assert.ErrorContains(t, err, "rescale data loss") + _, err = decimal.Decimal64(555555).Rescale(0, 5) + assert.ErrorContains(t, err, "rescale data loss") +} + +func TestDecimalIncreaseScale(t *testing.T) { + assert.Equal(t, decimal.Decimal32(1234), decimal.Decimal32(1234).IncreaseScaleBy(0)) + assert.Equal(t, decimal.Decimal32(1234000), decimal.Decimal32(1234).IncreaseScaleBy(3)) + assert.Equal(t, decimal.Decimal32(-1234000), decimal.Decimal32(-1234).IncreaseScaleBy(3)) + + assert.Equal(t, decimal.Decimal64(1234), decimal.Decimal64(1234).IncreaseScaleBy(0)) + assert.Equal(t, decimal.Decimal64(1234000), decimal.Decimal64(1234).IncreaseScaleBy(3)) + assert.Equal(t, decimal.Decimal64(-1234000), decimal.Decimal64(-1234).IncreaseScaleBy(3)) +} + +func TestDecimalReduceScale(t *testing.T) { + tests := []struct { + value int32 + scale int32 + round bool + expected int32 + }{ + {123456, 0, false, 123456}, + {123456, 1, false, 12345}, + {123456, 1, true, 12346}, + {123451, 1, true, 12345}, + {123789, 2, true, 1238}, + {123749, 2, true, 1237}, + {123750, 2, true, 1238}, + {5, 1, true, 1}, + {0, 1, true, 0}, + } + + for _, tt := range tests { + assert.Equal(t, decimal.Decimal32(tt.expected), + decimal.Decimal32(tt.value).ReduceScaleBy(tt.scale, tt.round), "decimal32") + assert.Equal(t, decimal.Decimal32(tt.expected).Negate(), + decimal.Decimal32(tt.value).Negate().ReduceScaleBy(tt.scale, tt.round), "decimal32") + assert.Equal(t, decimal.Decimal64(tt.expected), + decimal.Decimal64(tt.value).ReduceScaleBy(tt.scale, tt.round), "decimal64") + assert.Equal(t, decimal.Decimal64(tt.expected).Negate(), + decimal.Decimal64(tt.value).Negate().ReduceScaleBy(tt.scale, tt.round), "decimal64") + } +} + +func TestDecimalBasics(t *testing.T) { + tests := []struct { + lhs, rhs int32 + }{ + {100, 3}, + {200, 3}, + {20100, 301}, + {-20100, 301}, + {20100, -301}, + {-20100, -301}, + } + + for _, tt := range tests { + assert.EqualValues(t, tt.lhs+tt.rhs, + decimal.Decimal32(tt.lhs).Add(decimal.Decimal32(tt.rhs))) + assert.EqualValues(t, tt.lhs+tt.rhs, + decimal.Decimal64(tt.lhs).Add(decimal.Decimal64(tt.rhs))) + + assert.EqualValues(t, tt.lhs-tt.rhs, + decimal.Decimal32(tt.lhs).Sub(decimal.Decimal32(tt.rhs))) + assert.EqualValues(t, tt.lhs-tt.rhs, + decimal.Decimal64(tt.lhs).Sub(decimal.Decimal64(tt.rhs))) + + assert.EqualValues(t, tt.lhs*tt.rhs, + decimal.Decimal32(tt.lhs).Mul(decimal.Decimal32(tt.rhs))) + assert.EqualValues(t, tt.lhs*tt.rhs, + decimal.Decimal64(tt.lhs).Mul(decimal.Decimal64(tt.rhs))) + + expdiv, expmod := tt.lhs/tt.rhs, tt.lhs%tt.rhs + div, mod := decimal.Decimal32(tt.lhs).Div(decimal.Decimal32(tt.rhs)) + assert.EqualValues(t, expdiv, div) + assert.EqualValues(t, expmod, mod) + + div64, mod64 := decimal.Decimal64(tt.lhs).Div(decimal.Decimal64(tt.rhs)) + assert.EqualValues(t, expdiv, div64) + assert.EqualValues(t, expmod, mod64) + } +} diff --git a/arrow/decimal/traits.go b/arrow/decimal/traits.go new file mode 100644 index 00000000..0ec0c315 --- /dev/null +++ b/arrow/decimal/traits.go @@ -0,0 +1,78 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package decimal + +// Traits is a convenience for building generic objects for operating on +// Decimal values to get around the limitations of Go generics. By providing this +// interface a generic object can handle producing the proper types to generate +// new decimal values. +type Traits[T DecimalTypes] interface { + BytesRequired(int) int + FromString(string, int32, int32) (T, error) + FromFloat64(float64, int32, int32) (T, error) +} + +var ( + Dec32Traits dec32Traits + Dec64Traits dec64Traits + Dec128Traits dec128Traits + Dec256Traits dec256Traits +) + +type ( + dec32Traits struct{} + dec64Traits struct{} + dec128Traits struct{} + dec256Traits struct{} +) + +func (dec32Traits) BytesRequired(n int) int { return 4 * n } +func (dec64Traits) BytesRequired(n int) int { return 8 * n } +func (dec128Traits) BytesRequired(n int) int { return 16 * n } +func (dec256Traits) BytesRequired(n int) int { return 32 * n } + +func (dec32Traits) FromString(v string, prec, scale int32) (Decimal32, error) { + return Decimal32FromString(v, prec, scale) +} + +func (dec64Traits) FromString(v string, prec, scale int32) (Decimal64, error) { + return Decimal64FromString(v, prec, scale) +} + +func (dec128Traits) FromString(v string, prec, scale int32) (Decimal128, error) { + return Decimal128FromString(v, prec, scale) +} + +func (dec256Traits) FromString(v string, prec, scale int32) (Decimal256, error) { + return Decimal256FromString(v, prec, scale) +} + +func (dec32Traits) FromFloat64(v float64, prec, scale int32) (Decimal32, error) { + return Decimal32FromFloat(v, prec, scale) +} + +func (dec64Traits) FromFloat64(v float64, prec, scale int32) (Decimal64, error) { + return Decimal64FromFloat(v, prec, scale) +} + +func (dec128Traits) FromFloat64(v float64, prec, scale int32) (Decimal128, error) { + return Decimal128FromFloat(v, prec, scale) +} + +func (dec256Traits) FromFloat64(v float64, prec, scale int32) (Decimal256, error) { + return Decimal256FromFloat(v, prec, scale) +} diff --git a/arrow/decimal128/decimal128.go b/arrow/decimal128/decimal128.go index 2e451c1c..660c4131 100644 --- a/arrow/decimal128/decimal128.go +++ b/arrow/decimal128/decimal128.go @@ -327,6 +327,16 @@ func (n Num) ToFloat64(scale int32) float64 { return n.tofloat64Positive(scale) } +func (n Num) ToBigFloat(scale int32) *big.Float { + f := (&big.Float{}).SetInt(n.BigInt()) + if scale < 0 { + f.SetPrec(128).Mul(f, (&big.Float{}).SetInt(scaleMultipliers[-scale].BigInt())) + } else { + f.SetPrec(128).Quo(f, (&big.Float{}).SetInt(scaleMultipliers[scale].BigInt())) + } + return f +} + // LowBits returns the low bits of the two's complement representation of the number. func (n Num) LowBits() uint64 { return n.lo } diff --git a/arrow/decimal256/decimal256.go b/arrow/decimal256/decimal256.go index 76b61853..82c52a65 100644 --- a/arrow/decimal256/decimal256.go +++ b/arrow/decimal256/decimal256.go @@ -339,6 +339,16 @@ func (n Num) ToFloat64(scale int32) float64 { return n.tofloat64Positive(scale) } +func (n Num) ToBigFloat(scale int32) *big.Float { + f := (&big.Float{}).SetInt(n.BigInt()) + if scale < 0 { + f.SetPrec(256).Mul(f, (&big.Float{}).SetInt(scaleMultipliers[-scale].BigInt())) + } else { + f.SetPrec(256).Quo(f, (&big.Float{}).SetInt(scaleMultipliers[scale].BigInt())) + } + return f +} + func (n Num) Sign() int { if n == (Num{}) { return 0 diff --git a/arrow/internal/arrjson/arrjson.go b/arrow/internal/arrjson/arrjson.go index 452809e2..2181ebdf 100644 --- a/arrow/internal/arrjson/arrjson.go +++ b/arrow/internal/arrjson/arrjson.go @@ -29,6 +29,7 @@ import ( "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/array" "github.com/apache/arrow-go/v18/arrow/bitutil" + "github.com/apache/arrow-go/v18/arrow/decimal" "github.com/apache/arrow-go/v18/arrow/decimal128" "github.com/apache/arrow-go/v18/arrow/decimal256" "github.com/apache/arrow-go/v18/arrow/float16" @@ -224,6 +225,10 @@ func typeToJSON(arrowType arrow.DataType) (json.RawMessage, error) { typ = listSizeJSON{"fixedsizelist", dt.Len()} case *arrow.FixedSizeBinaryType: typ = byteWidthJSON{"fixedsizebinary", dt.ByteWidth} + case *arrow.Decimal32Type: + typ = decimalJSON{"decimal", int(dt.Scale), int(dt.Precision), 32} + case *arrow.Decimal64Type: + typ = decimalJSON{"decimal", int(dt.Scale), int(dt.Precision), 64} case *arrow.Decimal128Type: typ = decimalJSON{"decimal", int(dt.Scale), int(dt.Precision), 128} case *arrow.Decimal256Type: @@ -491,6 +496,10 @@ func typeFromJSON(typ json.RawMessage, children []FieldWrapper) (arrowType arrow arrowType = &arrow.Decimal256Type{Precision: int32(t.Precision), Scale: int32(t.Scale)} case 128, 0: // default to 128 bits when missing arrowType = &arrow.Decimal128Type{Precision: int32(t.Precision), Scale: int32(t.Scale)} + case 64: + arrowType = &arrow.Decimal64Type{Precision: int32(t.Precision), Scale: int32(t.Scale)} + case 32: + arrowType = &arrow.Decimal32Type{Precision: int32(t.Precision), Scale: int32(t.Scale)} } case "union": t := unionJSON{} @@ -1295,6 +1304,22 @@ func arrayFromJSON(mem memory.Allocator, dt arrow.DataType, arr Array) arrow.Arr bldr.AppendValues(data, valids) return returnNewArrayData(bldr) + case *arrow.Decimal32Type: + bldr := array.NewDecimal32Builder(mem, dt) + defer bldr.Release() + data := decimal32FromJSON(arr.Data) + valids := validsFromJSON(arr.Valids) + bldr.AppendValues(data, valids) + return returnNewArrayData(bldr) + + case *arrow.Decimal64Type: + bldr := array.NewDecimal64Builder(mem, dt) + defer bldr.Release() + data := decimal64FromJSON(arr.Data) + valids := validsFromJSON(arr.Valids) + bldr.AppendValues(data, valids) + return returnNewArrayData(bldr) + case *arrow.Decimal128Type: bldr := array.NewDecimal128Builder(mem, dt) defer bldr.Release() @@ -1713,6 +1738,22 @@ func arrayToJSON(field arrow.Field, arr arrow.Array) Array { Valids: validsToJSON(arr), } + case *array.Decimal32: + return Array{ + Name: field.Name, + Count: arr.Len(), + Data: decimal32ToJSON(arr), + Valids: validsToJSON(arr), + } + + case *array.Decimal64: + return Array{ + Name: field.Name, + Count: arr.Len(), + Data: decimal64ToJSON(arr), + Valids: validsToJSON(arr), + } + case *array.Decimal128: return Array{ Name: field.Name, @@ -2038,6 +2079,47 @@ func f64ToJSON(arr *array.Float64) []interface{} { return o } +func decimal32ToJSON(arr *array.Decimal32) []interface{} { + o := make([]interface{}, arr.Len()) + for i := range o { + o[i] = arr.ValueStr(i) + } + return o +} + +func decimal32FromJSON(vs []interface{}) []decimal.Decimal32 { + var tmp big.Int + o := make([]decimal.Decimal32, len(vs)) + for i, v := range vs { + if err := tmp.UnmarshalJSON([]byte(v.(string))); err != nil { + panic(fmt.Errorf("could not convert %v (%T) to decimal32: %w", v, v, err)) + } + + o[i] = decimal.Decimal32(tmp.Int64()) + } + return o +} + +func decimal64ToJSON(arr *array.Decimal64) []interface{} { + o := make([]interface{}, arr.Len()) + for i := range o { + o[i] = arr.ValueStr(i) + } + return o +} + +func decimal64FromJSON(vs []interface{}) []decimal.Decimal64 { + var tmp big.Int + o := make([]decimal.Decimal64, len(vs)) + for i, v := range vs { + if err := tmp.UnmarshalJSON([]byte(v.(string))); err != nil { + panic(fmt.Errorf("could not convert %v (%T) to decimal64: %w", v, v, err)) + } + + o[i] = decimal.Decimal64(tmp.Int64()) + } + return o +} func decimal128ToJSON(arr *array.Decimal128) []interface{} { o := make([]interface{}, arr.Len()) for i := range o { @@ -2072,7 +2154,7 @@ func decimal256FromJSON(vs []interface{}) []decimal256.Num { o := make([]decimal256.Num, len(vs)) for i, v := range vs { if err := tmp.UnmarshalJSON([]byte(v.(string))); err != nil { - panic(fmt.Errorf("could not convert %v (%T) to decimal128: %w", v, v, err)) + panic(fmt.Errorf("could not convert %v (%T) to decimal256: %w", v, v, err)) } o[i] = decimal256.FromBigInt(&tmp) diff --git a/arrow/internal/flatbuf/Decimal.go b/arrow/internal/flatbuf/Decimal.go index 2fc9d5ad..234c3964 100644 --- a/arrow/internal/flatbuf/Decimal.go +++ b/arrow/internal/flatbuf/Decimal.go @@ -22,10 +22,10 @@ import ( flatbuffers "github.com/google/flatbuffers/go" ) -// / Exact decimal value represented as an integer value in two's -// / complement. Currently only 128-bit (16-byte) and 256-bit (32-byte) integers -// / are used. The representation uses the endianness indicated -// / in the Schema. +/// Exact decimal value represented as an integer value in two's +/// complement. Currently 32-bit (4-byte), 64-bit (8-byte), +/// 128-bit (16-byte) and 256-bit (32-byte) integers are used. +/// The representation uses the endianness indicated in the Schema. type Decimal struct { _tab flatbuffers.Table } @@ -46,7 +46,7 @@ func (rcv *Decimal) Table() flatbuffers.Table { return rcv._tab } -// / Total number of decimal digits +/// Total number of decimal digits func (rcv *Decimal) Precision() int32 { o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) if o != 0 { @@ -55,12 +55,12 @@ func (rcv *Decimal) Precision() int32 { return 0 } -// / Total number of decimal digits +/// Total number of decimal digits func (rcv *Decimal) MutatePrecision(n int32) bool { return rcv._tab.MutateInt32Slot(4, n) } -// / Number of digits after the decimal point "." +/// Number of digits after the decimal point "." func (rcv *Decimal) Scale() int32 { o := flatbuffers.UOffsetT(rcv._tab.Offset(6)) if o != 0 { @@ -69,13 +69,13 @@ func (rcv *Decimal) Scale() int32 { return 0 } -// / Number of digits after the decimal point "." +/// Number of digits after the decimal point "." func (rcv *Decimal) MutateScale(n int32) bool { return rcv._tab.MutateInt32Slot(6, n) } -// / Number of bits per value. The only accepted widths are 128 and 256. -// / We use bitWidth for consistency with Int::bitWidth. +/// Number of bits per value. The accepted widths are 32, 64, 128 and 256. +/// We use bitWidth for consistency with Int::bitWidth. func (rcv *Decimal) BitWidth() int32 { o := flatbuffers.UOffsetT(rcv._tab.Offset(8)) if o != 0 { @@ -84,8 +84,8 @@ func (rcv *Decimal) BitWidth() int32 { return 128 } -// / Number of bits per value. The only accepted widths are 128 and 256. -// / We use bitWidth for consistency with Int::bitWidth. +/// Number of bits per value. The accepted widths are 32, 64, 128 and 256. +/// We use bitWidth for consistency with Int::bitWidth. func (rcv *Decimal) MutateBitWidth(n int32) bool { return rcv._tab.MutateInt32Slot(8, n) } diff --git a/arrow/ipc/file_reader.go b/arrow/ipc/file_reader.go index d027db5e..2715831e 100644 --- a/arrow/ipc/file_reader.go +++ b/arrow/ipc/file_reader.go @@ -476,7 +476,7 @@ func (ctx *arrayLoaderContext) loadArray(dt arrow.DataType) arrow.ArrayData { *arrow.Int8Type, *arrow.Int16Type, *arrow.Int32Type, *arrow.Int64Type, *arrow.Uint8Type, *arrow.Uint16Type, *arrow.Uint32Type, *arrow.Uint64Type, *arrow.Float16Type, *arrow.Float32Type, *arrow.Float64Type, - *arrow.Decimal128Type, *arrow.Decimal256Type, + arrow.DecimalType, *arrow.Time32Type, *arrow.Time64Type, *arrow.TimestampType, *arrow.Date32Type, *arrow.Date64Type, diff --git a/arrow/ipc/metadata.go b/arrow/ipc/metadata.go index 228f271b..a5bf1877 100644 --- a/arrow/ipc/metadata.go +++ b/arrow/ipc/metadata.go @@ -281,20 +281,12 @@ func (fv *fieldVisitor) visit(field arrow.Field) { fv.dtype = flatbuf.TypeFloatingPoint fv.offset = floatToFB(fv.b, int32(dt.BitWidth())) - case *arrow.Decimal128Type: + case arrow.DecimalType: fv.dtype = flatbuf.TypeDecimal flatbuf.DecimalStart(fv.b) - flatbuf.DecimalAddPrecision(fv.b, dt.Precision) - flatbuf.DecimalAddScale(fv.b, dt.Scale) - flatbuf.DecimalAddBitWidth(fv.b, 128) - fv.offset = flatbuf.DecimalEnd(fv.b) - - case *arrow.Decimal256Type: - fv.dtype = flatbuf.TypeDecimal - flatbuf.DecimalStart(fv.b) - flatbuf.DecimalAddPrecision(fv.b, dt.Precision) - flatbuf.DecimalAddScale(fv.b, dt.Scale) - flatbuf.DecimalAddBitWidth(fv.b, 256) + flatbuf.DecimalAddPrecision(fv.b, dt.GetPrecision()) + flatbuf.DecimalAddScale(fv.b, dt.GetScale()) + flatbuf.DecimalAddBitWidth(fv.b, int32(dt.BitWidth())) fv.offset = flatbuf.DecimalEnd(fv.b) case *arrow.FixedSizeBinaryType: @@ -947,6 +939,10 @@ func floatToFB(b *flatbuffers.Builder, bw int32) flatbuffers.UOffsetT { func decimalFromFB(data flatbuf.Decimal) (arrow.DataType, error) { switch data.BitWidth() { + case 32: + return &arrow.Decimal32Type{Precision: data.Precision(), Scale: data.Scale()}, nil + case 64: + return &arrow.Decimal64Type{Precision: data.Precision(), Scale: data.Scale()}, nil case 128: return &arrow.Decimal128Type{Precision: data.Precision(), Scale: data.Scale()}, nil case 256: diff --git a/arrow/type_string.go b/arrow/type_string.go index ee3ccb7e..6e5a943d 100644 --- a/arrow/type_string.go +++ b/arrow/type_string.go @@ -51,11 +51,13 @@ func _() { _ = x[BINARY_VIEW-40] _ = x[LIST_VIEW-41] _ = x[LARGE_LIST_VIEW-42] + _ = x[DECIMAL32-43] + _ = x[DECIMAL64-44] } -const _Type_name = "NULLBOOLUINT8INT8UINT16INT16UINT32INT32UINT64INT64FLOAT16FLOAT32FLOAT64STRINGBINARYFIXED_SIZE_BINARYDATE32DATE64TIMESTAMPTIME32TIME64INTERVAL_MONTHSINTERVAL_DAY_TIMEDECIMAL128DECIMAL256LISTSTRUCTSPARSE_UNIONDENSE_UNIONDICTIONARYMAPEXTENSIONFIXED_SIZE_LISTDURATIONLARGE_STRINGLARGE_BINARYLARGE_LISTINTERVAL_MONTH_DAY_NANORUN_END_ENCODEDSTRING_VIEWBINARY_VIEWLIST_VIEWLARGE_LIST_VIEW" +const _Type_name = "NULLBOOLUINT8INT8UINT16INT16UINT32INT32UINT64INT64FLOAT16FLOAT32FLOAT64STRINGBINARYFIXED_SIZE_BINARYDATE32DATE64TIMESTAMPTIME32TIME64INTERVAL_MONTHSINTERVAL_DAY_TIMEDECIMAL128DECIMAL256LISTSTRUCTSPARSE_UNIONDENSE_UNIONDICTIONARYMAPEXTENSIONFIXED_SIZE_LISTDURATIONLARGE_STRINGLARGE_BINARYLARGE_LISTINTERVAL_MONTH_DAY_NANORUN_END_ENCODEDSTRING_VIEWBINARY_VIEWLIST_VIEWLARGE_LIST_VIEWDECIMAL32DECIMAL64" -var _Type_index = [...]uint16{0, 4, 8, 13, 17, 23, 28, 34, 39, 45, 50, 57, 64, 71, 77, 83, 100, 106, 112, 121, 127, 133, 148, 165, 175, 185, 189, 195, 207, 218, 228, 231, 240, 255, 263, 275, 287, 297, 320, 335, 346, 357, 366, 381} +var _Type_index = [...]uint16{0, 4, 8, 13, 17, 23, 28, 34, 39, 45, 50, 57, 64, 71, 77, 83, 100, 106, 112, 121, 127, 133, 148, 165, 175, 185, 189, 195, 207, 218, 228, 231, 240, 255, 263, 275, 287, 297, 320, 335, 346, 357, 366, 381, 390, 399} func (i Type) String() string { if i < 0 || i >= Type(len(_Type_index)-1) { diff --git a/arrow/type_traits.go b/arrow/type_traits.go index 87e2f065..7185ef25 100644 --- a/arrow/type_traits.go +++ b/arrow/type_traits.go @@ -20,8 +20,7 @@ import ( "reflect" "unsafe" - "github.com/apache/arrow-go/v18/arrow/decimal128" - "github.com/apache/arrow-go/v18/arrow/decimal256" + "github.com/apache/arrow-go/v18/arrow/decimal" "github.com/apache/arrow-go/v18/arrow/float16" "golang.org/x/exp/constraints" ) @@ -68,7 +67,7 @@ type NumericType interface { // as a bitmap and thus the buffer can't be just reinterpreted as a []bool type FixedWidthType interface { IntType | UintType | - FloatType | decimal128.Num | decimal256.Num | + FloatType | decimal.DecimalTypes | DayTimeInterval | MonthDayNanoInterval } diff --git a/arrow/type_traits_decimal128.go b/arrow/type_traits_decimal128.go index 860a7f12..6e416cd6 100644 --- a/arrow/type_traits_decimal128.go +++ b/arrow/type_traits_decimal128.go @@ -19,7 +19,7 @@ package arrow import ( "unsafe" - "github.com/apache/arrow-go/v18/arrow/decimal128" + "github.com/apache/arrow-go/v18/arrow/decimal" "github.com/apache/arrow-go/v18/arrow/endian" ) @@ -28,7 +28,7 @@ var Decimal128Traits decimal128Traits const ( // Decimal128SizeBytes specifies the number of bytes required to store a single decimal128 in memory - Decimal128SizeBytes = int(unsafe.Sizeof(decimal128.Num{})) + Decimal128SizeBytes = int(unsafe.Sizeof(decimal.Decimal128{})) ) type decimal128Traits struct{} @@ -37,7 +37,7 @@ type decimal128Traits struct{} func (decimal128Traits) BytesRequired(n int) int { return Decimal128SizeBytes * n } // PutValue -func (decimal128Traits) PutValue(b []byte, v decimal128.Num) { +func (decimal128Traits) PutValue(b []byte, v decimal.Decimal128) { endian.Native.PutUint64(b[:8], uint64(v.LowBits())) endian.Native.PutUint64(b[8:], uint64(v.HighBits())) } @@ -45,14 +45,14 @@ func (decimal128Traits) PutValue(b []byte, v decimal128.Num) { // CastFromBytes reinterprets the slice b to a slice of type uint16. // // NOTE: len(b) must be a multiple of Uint16SizeBytes. -func (decimal128Traits) CastFromBytes(b []byte) []decimal128.Num { - return GetData[decimal128.Num](b) +func (decimal128Traits) CastFromBytes(b []byte) []decimal.Decimal128 { + return GetData[decimal.Decimal128](b) } // CastToBytes reinterprets the slice b to a slice of bytes. -func (decimal128Traits) CastToBytes(b []decimal128.Num) []byte { +func (decimal128Traits) CastToBytes(b []decimal.Decimal128) []byte { return GetBytes(b) } // Copy copies src to dst. -func (decimal128Traits) Copy(dst, src []decimal128.Num) { copy(dst, src) } +func (decimal128Traits) Copy(dst, src []decimal.Decimal128) { copy(dst, src) } diff --git a/arrow/type_traits_decimal256.go b/arrow/type_traits_decimal256.go index f86bd2a6..b196c2e7 100644 --- a/arrow/type_traits_decimal256.go +++ b/arrow/type_traits_decimal256.go @@ -19,7 +19,7 @@ package arrow import ( "unsafe" - "github.com/apache/arrow-go/v18/arrow/decimal256" + "github.com/apache/arrow-go/v18/arrow/decimal" "github.com/apache/arrow-go/v18/arrow/endian" ) @@ -27,14 +27,14 @@ import ( var Decimal256Traits decimal256Traits const ( - Decimal256SizeBytes = int(unsafe.Sizeof(decimal256.Num{})) + Decimal256SizeBytes = int(unsafe.Sizeof(decimal.Decimal256{})) ) type decimal256Traits struct{} func (decimal256Traits) BytesRequired(n int) int { return Decimal256SizeBytes * n } -func (decimal256Traits) PutValue(b []byte, v decimal256.Num) { +func (decimal256Traits) PutValue(b []byte, v decimal.Decimal256) { for i, a := range v.Array() { start := i * 8 endian.Native.PutUint64(b[start:], a) @@ -42,12 +42,12 @@ func (decimal256Traits) PutValue(b []byte, v decimal256.Num) { } // CastFromBytes reinterprets the slice b to a slice of decimal256 -func (decimal256Traits) CastFromBytes(b []byte) []decimal256.Num { - return GetData[decimal256.Num](b) +func (decimal256Traits) CastFromBytes(b []byte) []decimal.Decimal256 { + return GetData[decimal.Decimal256](b) } -func (decimal256Traits) CastToBytes(b []decimal256.Num) []byte { +func (decimal256Traits) CastToBytes(b []decimal.Decimal256) []byte { return GetBytes(b) } -func (decimal256Traits) Copy(dst, src []decimal256.Num) { copy(dst, src) } +func (decimal256Traits) Copy(dst, src []decimal.Decimal256) { copy(dst, src) } diff --git a/arrow/type_traits_decimal32.go b/arrow/type_traits_decimal32.go new file mode 100644 index 00000000..ebca65f6 --- /dev/null +++ b/arrow/type_traits_decimal32.go @@ -0,0 +1,57 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package arrow + +import ( + "unsafe" + + "github.com/apache/arrow-go/v18/arrow/decimal" + "github.com/apache/arrow-go/v18/arrow/endian" +) + +// Decimal32 traits +var Decimal32Traits decimal32Traits + +const ( + // Decimal32SizeBytes specifies the number of bytes required to store a single decimal32 in memory + Decimal32SizeBytes = int(unsafe.Sizeof(decimal.Decimal32(0))) +) + +type decimal32Traits struct{} + +// BytesRequired returns the number of bytes required to store n elements in memory. +func (decimal32Traits) BytesRequired(n int) int { return Decimal32SizeBytes * n } + +// PutValue +func (decimal32Traits) PutValue(b []byte, v decimal.Decimal32) { + endian.Native.PutUint32(b[:4], uint32(v)) +} + +// CastFromBytes reinterprets the slice b to a slice of type uint16. +// +// NOTE: len(b) must be a multiple of Uint16SizeBytes. +func (decimal32Traits) CastFromBytes(b []byte) []decimal.Decimal32 { + return GetData[decimal.Decimal32](b) +} + +// CastToBytes reinterprets the slice b to a slice of bytes. +func (decimal32Traits) CastToBytes(b []decimal.Decimal32) []byte { + return GetBytes(b) +} + +// Copy copies src to dst. +func (decimal32Traits) Copy(dst, src []decimal.Decimal32) { copy(dst, src) } diff --git a/arrow/type_traits_decimal64.go b/arrow/type_traits_decimal64.go new file mode 100644 index 00000000..bd07883a --- /dev/null +++ b/arrow/type_traits_decimal64.go @@ -0,0 +1,57 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package arrow + +import ( + "unsafe" + + "github.com/apache/arrow-go/v18/arrow/decimal" + "github.com/apache/arrow-go/v18/arrow/endian" +) + +// Decimal64 traits +var Decimal64Traits decimal64Traits + +const ( + // Decimal64SizeBytes specifies the number of bytes required to store a single decimal64 in memory + Decimal64SizeBytes = int(unsafe.Sizeof(decimal.Decimal64(0))) +) + +type decimal64Traits struct{} + +// BytesRequired returns the number of bytes required to store n elements in memory. +func (decimal64Traits) BytesRequired(n int) int { return Decimal64SizeBytes * n } + +// PutValue +func (decimal64Traits) PutValue(b []byte, v decimal.Decimal64) { + endian.Native.PutUint64(b[:8], uint64(v)) +} + +// CastFromBytes reinterprets the slice b to a slice of type uint16. +// +// NOTE: len(b) must be a multiple of Uint16SizeBytes. +func (decimal64Traits) CastFromBytes(b []byte) []decimal.Decimal64 { + return GetData[decimal.Decimal64](b) +} + +// CastToBytes reinterprets the slice b to a slice of bytes. +func (decimal64Traits) CastToBytes(b []decimal.Decimal64) []byte { + return GetBytes(b) +} + +// Copy copies src to dst. +func (decimal64Traits) Copy(dst, src []decimal.Decimal64) { copy(dst, src) } diff --git a/arrow/type_traits_test.go b/arrow/type_traits_test.go index d86a67b1..93d98b95 100644 --- a/arrow/type_traits_test.go +++ b/arrow/type_traits_test.go @@ -23,8 +23,10 @@ import ( "testing" "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/decimal" "github.com/apache/arrow-go/v18/arrow/decimal128" "github.com/apache/arrow-go/v18/arrow/decimal256" + "github.com/apache/arrow-go/v18/arrow/float16" ) @@ -90,6 +92,94 @@ func TestFloat16Traits(t *testing.T) { } } +func TestDecimal32Traits(t *testing.T) { + const N = 10 + nbytes := arrow.Decimal32Traits.BytesRequired(N) + b1 := arrow.Decimal32Traits.CastToBytes([]decimal.Decimal32{ + decimal.Decimal32(0), + decimal.Decimal32(1), + decimal.Decimal32(2), + decimal.Decimal32(3), + decimal.Decimal32(4), + decimal.Decimal32(5), + decimal.Decimal32(6), + decimal.Decimal32(7), + decimal.Decimal32(8), + decimal.Decimal32(9), + }) + + b2 := make([]byte, nbytes) + for i := 0; i < N; i++ { + beg := i * arrow.Decimal32SizeBytes + end := (i + 1) * arrow.Decimal32SizeBytes + arrow.Decimal32Traits.PutValue(b2[beg:end], decimal.Decimal32(i)) + } + + if !reflect.DeepEqual(b1, b2) { + v1 := arrow.Decimal32Traits.CastFromBytes(b1) + v2 := arrow.Decimal32Traits.CastFromBytes(b2) + t.Fatalf("invalid values:\nb1=%v\nb2=%v\nv1=%v\nv2=%v\n", b1, b2, v1, v2) + } + + v1 := arrow.Decimal32Traits.CastFromBytes(b1) + for i, v := range v1 { + if got, want := v, decimal.Decimal32(i); got != want { + t.Fatalf("invalid value[%d]. got=%v, want=%v", i, got, want) + } + } + + v2 := make([]decimal.Decimal32, N) + arrow.Decimal32Traits.Copy(v2, v1) + + if !reflect.DeepEqual(v1, v2) { + t.Fatalf("invalid values:\nv1=%v\nv2=%v\n", v1, v2) + } +} + +func TestDecimal64Traits(t *testing.T) { + const N = 10 + nbytes := arrow.Decimal64Traits.BytesRequired(N) + b1 := arrow.Decimal64Traits.CastToBytes([]decimal.Decimal64{ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + }) + + b2 := make([]byte, nbytes) + for i := 0; i < N; i++ { + beg := i * arrow.Decimal64SizeBytes + end := (i + 1) * arrow.Decimal64SizeBytes + arrow.Decimal64Traits.PutValue(b2[beg:end], decimal.Decimal64(i)) + } + + if !reflect.DeepEqual(b1, b2) { + v1 := arrow.Decimal64Traits.CastFromBytes(b1) + v2 := arrow.Decimal64Traits.CastFromBytes(b2) + t.Fatalf("invalid values:\nb1=%v\nb2=%v\nv1=%v\nv2=%v\n", b1, b2, v1, v2) + } + + v1 := arrow.Decimal64Traits.CastFromBytes(b1) + for i, v := range v1 { + if got, want := v, decimal.Decimal64(i); got != want { + t.Fatalf("invalid value[%d]. got=%v, want=%v", i, got, want) + } + } + + v2 := make([]decimal.Decimal64, N) + arrow.Decimal64Traits.Copy(v2, v1) + + if !reflect.DeepEqual(v1, v2) { + t.Fatalf("invalid values:\nv1=%v\nv2=%v\n", v1, v2) + } +} + func TestDecimal128Traits(t *testing.T) { const N = 10 nbytes := arrow.Decimal128Traits.BytesRequired(N)