diff --git a/.gitignore b/.gitignore index 06f40703..55232367 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. +.vscode /apache-arrow-go.tar.gz /dev/release/apache-rat-*.jar /dev/release/filtered_rat.txt diff --git a/arrow/array/array.go b/arrow/array/array.go index 586d8765..6e281a43 100644 --- a/arrow/array/array.go +++ b/arrow/array/array.go @@ -160,6 +160,8 @@ func init() { arrow.TIME64: func(data arrow.ArrayData) arrow.Array { return NewTime64Data(data) }, arrow.INTERVAL_MONTHS: func(data arrow.ArrayData) arrow.Array { return NewMonthIntervalData(data) }, arrow.INTERVAL_DAY_TIME: func(data arrow.ArrayData) arrow.Array { return NewDayTimeIntervalData(data) }, + arrow.DECIMAL32: func(data arrow.ArrayData) arrow.Array { return NewDecimal32Data(data) }, + arrow.DECIMAL64: func(data arrow.ArrayData) arrow.Array { return NewDecimal64Data(data) }, arrow.DECIMAL128: func(data arrow.ArrayData) arrow.Array { return NewDecimal128Data(data) }, arrow.DECIMAL256: func(data arrow.ArrayData) arrow.Array { return NewDecimal256Data(data) }, arrow.LIST: func(data arrow.ArrayData) arrow.Array { return NewListData(data) }, diff --git a/arrow/array/array_test.go b/arrow/array/array_test.go index 203c62ea..9509e314 100644 --- a/arrow/array/array_test.go +++ b/arrow/array/array_test.go @@ -75,6 +75,8 @@ func TestMakeFromData(t *testing.T) { {name: "time64", d: &testDataType{arrow.TIME64}}, {name: "month_interval", d: arrow.FixedWidthTypes.MonthInterval}, {name: "day_time_interval", d: arrow.FixedWidthTypes.DayTimeInterval}, + {name: "decimal32", d: &testDataType{arrow.DECIMAL32}}, + {name: "decimal64", d: &testDataType{arrow.DECIMAL64}}, {name: "decimal128", d: &testDataType{arrow.DECIMAL128}}, {name: "decimal256", d: &testDataType{arrow.DECIMAL256}}, {name: "month_day_nano_interval", d: arrow.FixedWidthTypes.MonthDayNanoInterval}, diff --git a/arrow/array/builder.go b/arrow/array/builder.go index 108b6152..a2a40d48 100644 --- a/arrow/array/builder.go +++ b/arrow/array/builder.go @@ -313,6 +313,14 @@ func NewBuilder(mem memory.Allocator, dtype arrow.DataType) Builder { return NewDayTimeIntervalBuilder(mem) case arrow.INTERVAL_MONTH_DAY_NANO: return NewMonthDayNanoIntervalBuilder(mem) + case arrow.DECIMAL32: + if typ, ok := dtype.(*arrow.Decimal32Type); ok { + return NewDecimal32Builder(mem, typ) + } + case arrow.DECIMAL64: + if typ, ok := dtype.(*arrow.Decimal64Type); ok { + return NewDecimal64Builder(mem, typ) + } case arrow.DECIMAL128: if typ, ok := dtype.(*arrow.Decimal128Type); ok { return NewDecimal128Builder(mem, typ) diff --git a/arrow/array/compare.go b/arrow/array/compare.go index 4117880f..ad3a50b8 100644 --- a/arrow/array/compare.go +++ b/arrow/array/compare.go @@ -271,12 +271,18 @@ func Equal(left, right arrow.Array) bool { case *Float64: r := right.(*Float64) return arrayEqualFloat64(l, r) + case *Decimal32: + r := right.(*Decimal32) + return arrayEqualDecimal(l, r) + case *Decimal64: + r := right.(*Decimal64) + return arrayEqualDecimal(l, r) case *Decimal128: r := right.(*Decimal128) - return arrayEqualDecimal128(l, r) + return arrayEqualDecimal(l, r) case *Decimal256: r := right.(*Decimal256) - return arrayEqualDecimal256(l, r) + return arrayEqualDecimal(l, r) case *Date32: r := right.(*Date32) return arrayEqualDate32(l, r) @@ -527,12 +533,18 @@ func arrayApproxEqual(left, right arrow.Array, opt equalOption) bool { case *Float64: r := right.(*Float64) return arrayApproxEqualFloat64(l, r, opt) + case *Decimal32: + r := right.(*Decimal32) + return arrayEqualDecimal(l, r) + case *Decimal64: + r := right.(*Decimal64) + return arrayEqualDecimal(l, r) case *Decimal128: r := right.(*Decimal128) - return arrayEqualDecimal128(l, r) + return arrayEqualDecimal(l, r) case *Decimal256: r := right.(*Decimal256) - return arrayEqualDecimal256(l, r) + return arrayEqualDecimal(l, r) case *Date32: r := right.(*Date32) return arrayEqualDate32(l, r) diff --git a/arrow/array/decimal.go b/arrow/array/decimal.go new file mode 100644 index 00000000..1a9d61c1 --- /dev/null +++ b/arrow/array/decimal.go @@ -0,0 +1,432 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package array + +import ( + "bytes" + "fmt" + "reflect" + "strings" + "sync/atomic" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/bitutil" + "github.com/apache/arrow-go/v18/arrow/decimal" + "github.com/apache/arrow-go/v18/arrow/internal/debug" + "github.com/apache/arrow-go/v18/arrow/memory" + "github.com/apache/arrow-go/v18/internal/json" +) + +type baseDecimal[T interface { + decimal.DecimalTypes + decimal.Num[T] +}] struct { + array + + values []T +} + +func newDecimalData[T interface { + decimal.DecimalTypes + decimal.Num[T] +}](data arrow.ArrayData) *baseDecimal[T] { + a := &baseDecimal[T]{} + a.refCount = 1 + a.setData(data.(*Data)) + return a +} + +func (a *baseDecimal[T]) Value(i int) T { return a.values[i] } + +func (a *baseDecimal[T]) ValueStr(i int) string { + if a.IsNull(i) { + return NullValueStr + } + return a.GetOneForMarshal(i).(string) +} + +func (a *baseDecimal[T]) Values() []T { return a.values } + +func (a *baseDecimal[T]) String() string { + o := new(strings.Builder) + o.WriteString("[") + for i := 0; i < a.Len(); i++ { + if i > 0 { + fmt.Fprintf(o, " ") + } + switch { + case a.IsNull(i): + o.WriteString(NullValueStr) + default: + fmt.Fprintf(o, "%v", a.Value(i)) + } + } + o.WriteString("]") + return o.String() +} + +func (a *baseDecimal[T]) setData(data *Data) { + a.array.setData(data) + vals := data.buffers[1] + if vals != nil { + a.values = arrow.GetData[T](vals.Bytes()) + beg := a.array.data.offset + end := beg + a.array.data.length + a.values = a.values[beg:end] + } +} + +func (a *baseDecimal[T]) GetOneForMarshal(i int) any { + if a.IsNull(i) { + return nil + } + + typ := a.DataType().(arrow.DecimalType) + n, scale := a.Value(i), typ.GetScale() + return n.ToBigFloat(scale).Text('g', int(typ.GetPrecision())) +} + +func (a *baseDecimal[T]) MarshalJSON() ([]byte, error) { + vals := make([]any, a.Len()) + for i := 0; i < a.Len(); i++ { + vals[i] = a.GetOneForMarshal(i) + } + return json.Marshal(vals) +} + +func arrayEqualDecimal[T interface { + decimal.DecimalTypes + decimal.Num[T] +}](left, right *baseDecimal[T]) bool { + for i := 0; i < left.Len(); i++ { + if left.IsNull(i) { + continue + } + + if left.Value(i) != right.Value(i) { + return false + } + } + return true +} + +type Decimal32 = baseDecimal[decimal.Decimal32] + +func NewDecimal32Data(data arrow.ArrayData) *Decimal32 { + return newDecimalData[decimal.Decimal32](data) +} + +type Decimal64 = baseDecimal[decimal.Decimal64] + +func NewDecimal64Data(data arrow.ArrayData) *Decimal64 { + return newDecimalData[decimal.Decimal64](data) +} + +type Decimal128 = baseDecimal[decimal.Decimal128] + +func NewDecimal128Data(data arrow.ArrayData) *Decimal128 { + return newDecimalData[decimal.Decimal128](data) +} + +type Decimal256 = baseDecimal[decimal.Decimal256] + +func NewDecimal256Data(data arrow.ArrayData) *Decimal256 { + return newDecimalData[decimal.Decimal256](data) +} + +type Decimal32Builder = baseDecimalBuilder[decimal.Decimal32] +type Decimal64Builder = baseDecimalBuilder[decimal.Decimal64] +type Decimal128Builder struct { + *baseDecimalBuilder[decimal.Decimal128] +} + +func (b *Decimal128Builder) NewDecimal128Array() *Decimal128 { + return b.NewDecimalArray() +} + +type Decimal256Builder struct { + *baseDecimalBuilder[decimal.Decimal256] +} + +func (b *Decimal256Builder) NewDecimal256Array() *Decimal256 { + return b.NewDecimalArray() +} + +type baseDecimalBuilder[T interface { + decimal.DecimalTypes + decimal.Num[T] +}] struct { + builder + traits decimal.Traits[T] + + dtype arrow.DecimalType + data *memory.Buffer + rawData []T +} + +func newDecimalBuilder[T interface { + decimal.DecimalTypes + decimal.Num[T] +}, DT arrow.DecimalType](mem memory.Allocator, dtype DT) *baseDecimalBuilder[T] { + return &baseDecimalBuilder[T]{ + builder: builder{refCount: 1, mem: mem}, + dtype: dtype, + } +} + +func (b *baseDecimalBuilder[T]) Type() arrow.DataType { return b.dtype } + +func (b *baseDecimalBuilder[T]) Release() { + debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + + if atomic.AddInt64(&b.refCount, -1) == 0 { + if b.nullBitmap != nil { + b.nullBitmap.Release() + b.nullBitmap = nil + } + if b.data != nil { + b.data.Release() + b.data, b.rawData = nil, nil + } + } +} + +func (b *baseDecimalBuilder[T]) Append(v T) { + b.Reserve(1) + b.UnsafeAppend(v) +} + +func (b *baseDecimalBuilder[T]) UnsafeAppend(v T) { + bitutil.SetBit(b.nullBitmap.Bytes(), b.length) + b.rawData[b.length] = v + b.length++ +} + +func (b *baseDecimalBuilder[T]) AppendNull() { + b.Reserve(1) + b.UnsafeAppendBoolToBitmap(false) +} + +func (b *baseDecimalBuilder[T]) AppendNulls(n int) { + for i := 0; i < n; i++ { + b.AppendNull() + } +} + +func (b *baseDecimalBuilder[T]) AppendEmptyValue() { + var empty T + b.Append(empty) +} + +func (b *baseDecimalBuilder[T]) AppendEmptyValues(n int) { + for i := 0; i < n; i++ { + b.AppendEmptyValue() + } +} + +func (b *baseDecimalBuilder[T]) UnsafeAppendBoolToBitmap(isValid bool) { + if isValid { + bitutil.SetBit(b.nullBitmap.Bytes(), b.length) + } else { + b.nulls++ + } + b.length++ +} + +func (b *baseDecimalBuilder[T]) AppendValues(v []T, valid []bool) { + if len(v) != len(valid) && len(valid) != 0 { + panic("len(v) != len(valid) && len(valid) != 0") + } + + if len(v) == 0 { + return + } + + b.Reserve(len(v)) + if len(v) > 0 { + copy(b.rawData[b.length:], v) + } + b.builder.unsafeAppendBoolsToBitmap(valid, len(v)) +} + +func (b *baseDecimalBuilder[T]) init(capacity int) { + b.builder.init(capacity) + + b.data = memory.NewResizableBuffer(b.mem) + bytesN := int(reflect.TypeFor[T]().Size()) * capacity + b.data.Resize(bytesN) + b.rawData = arrow.GetData[T](b.data.Bytes()) +} + +func (b *baseDecimalBuilder[T]) Reserve(n int) { + b.builder.reserve(n, b.Resize) +} + +func (b *baseDecimalBuilder[T]) Resize(n int) { + nBuilder := n + if n < minBuilderCapacity { + n = minBuilderCapacity + } + + if b.capacity == 0 { + b.init(n) + } else { + b.builder.resize(nBuilder, b.init) + b.data.Resize(b.traits.BytesRequired(n)) + b.rawData = arrow.GetData[T](b.data.Bytes()) + } +} + +func (b *baseDecimalBuilder[T]) NewDecimalArray() (a *baseDecimal[T]) { + data := b.newData() + a = newDecimalData[T](data) + data.Release() + return +} + +func (b *baseDecimalBuilder[T]) NewArray() arrow.Array { + return b.NewDecimalArray() +} + +func (b *baseDecimalBuilder[T]) newData() (data *Data) { + bytesRequired := b.traits.BytesRequired(b.length) + if bytesRequired > 0 && bytesRequired < b.data.Len() { + // trim buffers + b.data.Resize(bytesRequired) + } + data = NewData(b.dtype, b.length, []*memory.Buffer{b.nullBitmap, b.data}, nil, b.nulls, 0) + b.reset() + + if b.data != nil { + b.data.Release() + b.data, b.rawData = nil, nil + } + + return +} + +func (b *baseDecimalBuilder[T]) AppendValueFromString(s string) error { + if s == NullValueStr { + b.AppendNull() + return nil + } + + val, err := b.traits.FromString(s, b.dtype.GetPrecision(), b.dtype.GetScale()) + if err != nil { + b.AppendNull() + return err + } + b.Append(val) + return nil +} + +func (b *baseDecimalBuilder[T]) UnmarshalOne(dec *json.Decoder) error { + t, err := dec.Token() + if err != nil { + return err + } + + var token T + switch v := t.(type) { + case float64: + token, err = b.traits.FromFloat64(v, b.dtype.GetPrecision(), b.dtype.GetScale()) + if err != nil { + return err + } + b.Append(token) + case string: + token, err = b.traits.FromString(v, b.dtype.GetPrecision(), b.dtype.GetScale()) + if err != nil { + return err + } + b.Append(token) + case json.Number: + token, err = b.traits.FromString(v.String(), b.dtype.GetPrecision(), b.dtype.GetScale()) + if err != nil { + return err + } + b.Append(token) + case nil: + b.AppendNull() + default: + return &json.UnmarshalTypeError{ + Value: fmt.Sprint(t), + Type: reflect.TypeFor[T](), + Offset: dec.InputOffset(), + } + } + + return nil +} + +func (b *baseDecimalBuilder[T]) Unmarshal(dec *json.Decoder) error { + for dec.More() { + if err := b.UnmarshalOne(dec); err != nil { + return err + } + } + return nil +} + +func (b *baseDecimalBuilder[T]) UnmarshalJSON(data []byte) error { + dec := json.NewDecoder(bytes.NewReader(data)) + t, err := dec.Token() + if err != nil { + return err + } + + if delim, ok := t.(json.Delim); !ok || delim != '[' { + return fmt.Errorf("decimal builder must unpack from json array, found %s", delim) + } + + return b.Unmarshal(dec) +} + +func NewDecimal32Builder(mem memory.Allocator, dtype *arrow.Decimal32Type) *Decimal32Builder { + b := newDecimalBuilder[decimal.Decimal32](mem, dtype) + b.traits = decimal.Dec32Traits + return b +} + +func NewDecimal64Builder(mem memory.Allocator, dtype *arrow.Decimal64Type) *Decimal64Builder { + b := newDecimalBuilder[decimal.Decimal64](mem, dtype) + b.traits = decimal.Dec64Traits + return b +} + +func NewDecimal128Builder(mem memory.Allocator, dtype *arrow.Decimal128Type) *Decimal128Builder { + b := newDecimalBuilder[decimal.Decimal128](mem, dtype) + b.traits = decimal.Dec128Traits + return &Decimal128Builder{b} +} + +func NewDecimal256Builder(mem memory.Allocator, dtype *arrow.Decimal256Type) *Decimal256Builder { + b := newDecimalBuilder[decimal.Decimal256](mem, dtype) + b.traits = decimal.Dec256Traits + return &Decimal256Builder{b} +} + +var ( + _ arrow.Array = (*Decimal32)(nil) + _ arrow.Array = (*Decimal64)(nil) + _ arrow.Array = (*Decimal128)(nil) + _ arrow.Array = (*Decimal256)(nil) + _ Builder = (*Decimal32Builder)(nil) + _ Builder = (*Decimal64Builder)(nil) + _ Builder = (*Decimal128Builder)(nil) + _ Builder = (*Decimal256Builder)(nil) +) diff --git a/arrow/array/decimal128.go b/arrow/array/decimal128.go deleted file mode 100644 index c5861dce..00000000 --- a/arrow/array/decimal128.go +++ /dev/null @@ -1,368 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package array - -import ( - "bytes" - "fmt" - "math/big" - "reflect" - "strings" - "sync/atomic" - - "github.com/apache/arrow-go/v18/arrow" - "github.com/apache/arrow-go/v18/arrow/bitutil" - "github.com/apache/arrow-go/v18/arrow/decimal128" - "github.com/apache/arrow-go/v18/arrow/internal/debug" - "github.com/apache/arrow-go/v18/arrow/memory" - "github.com/apache/arrow-go/v18/internal/json" -) - -// A type which represents an immutable sequence of 128-bit decimal values. -type Decimal128 struct { - array - - values []decimal128.Num -} - -func NewDecimal128Data(data arrow.ArrayData) *Decimal128 { - a := &Decimal128{} - a.refCount = 1 - a.setData(data.(*Data)) - return a -} - -func (a *Decimal128) Value(i int) decimal128.Num { return a.values[i] } - -func (a *Decimal128) ValueStr(i int) string { - if a.IsNull(i) { - return NullValueStr - } - return a.GetOneForMarshal(i).(string) -} - -func (a *Decimal128) Values() []decimal128.Num { return a.values } - -func (a *Decimal128) String() string { - o := new(strings.Builder) - o.WriteString("[") - for i := 0; i < a.Len(); i++ { - if i > 0 { - fmt.Fprintf(o, " ") - } - switch { - case a.IsNull(i): - o.WriteString(NullValueStr) - default: - fmt.Fprintf(o, "%v", a.Value(i)) - } - } - o.WriteString("]") - return o.String() -} - -func (a *Decimal128) setData(data *Data) { - a.array.setData(data) - vals := data.buffers[1] - if vals != nil { - a.values = arrow.Decimal128Traits.CastFromBytes(vals.Bytes()) - beg := a.array.data.offset - end := beg + a.array.data.length - a.values = a.values[beg:end] - } -} -func (a *Decimal128) GetOneForMarshal(i int) interface{} { - if a.IsNull(i) { - return nil - } - typ := a.DataType().(*arrow.Decimal128Type) - n := a.Value(i) - scale := typ.Scale - f := (&big.Float{}).SetInt(n.BigInt()) - if scale < 0 { - f.SetPrec(128).Mul(f, (&big.Float{}).SetInt(decimal128.GetScaleMultiplier(int(-scale)).BigInt())) - } else { - f.SetPrec(128).Quo(f, (&big.Float{}).SetInt(decimal128.GetScaleMultiplier(int(scale)).BigInt())) - } - return f.Text('g', int(typ.Precision)) -} - -// ["1.23", ] -func (a *Decimal128) MarshalJSON() ([]byte, error) { - vals := make([]interface{}, a.Len()) - for i := 0; i < a.Len(); i++ { - vals[i] = a.GetOneForMarshal(i) - } - return json.Marshal(vals) -} - -func arrayEqualDecimal128(left, right *Decimal128) bool { - for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { - continue - } - if left.Value(i) != right.Value(i) { - return false - } - } - return true -} - -type Decimal128Builder struct { - builder - - dtype *arrow.Decimal128Type - data *memory.Buffer - rawData []decimal128.Num -} - -func NewDecimal128Builder(mem memory.Allocator, dtype *arrow.Decimal128Type) *Decimal128Builder { - return &Decimal128Builder{ - builder: builder{refCount: 1, mem: mem}, - dtype: dtype, - } -} - -func (b *Decimal128Builder) Type() arrow.DataType { return b.dtype } - -// Release decreases the reference count by 1. -// When the reference count goes to zero, the memory is freed. -func (b *Decimal128Builder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") - - if atomic.AddInt64(&b.refCount, -1) == 0 { - if b.nullBitmap != nil { - b.nullBitmap.Release() - b.nullBitmap = nil - } - if b.data != nil { - b.data.Release() - b.data = nil - b.rawData = nil - } - } -} - -func (b *Decimal128Builder) Append(v decimal128.Num) { - b.Reserve(1) - b.UnsafeAppend(v) -} - -func (b *Decimal128Builder) UnsafeAppend(v decimal128.Num) { - bitutil.SetBit(b.nullBitmap.Bytes(), b.length) - b.rawData[b.length] = v - b.length++ -} - -func (b *Decimal128Builder) AppendNull() { - b.Reserve(1) - b.UnsafeAppendBoolToBitmap(false) -} - -func (b *Decimal128Builder) AppendNulls(n int) { - for i := 0; i < n; i++ { - b.AppendNull() - } -} - -func (b *Decimal128Builder) AppendEmptyValue() { - b.Append(decimal128.Num{}) -} - -func (b *Decimal128Builder) AppendEmptyValues(n int) { - for i := 0; i < n; i++ { - b.AppendEmptyValue() - } -} - -func (b *Decimal128Builder) UnsafeAppendBoolToBitmap(isValid bool) { - if isValid { - bitutil.SetBit(b.nullBitmap.Bytes(), b.length) - } else { - b.nulls++ - } - b.length++ -} - -// AppendValues will append the values in the v slice. The valid slice determines which values -// in v are valid (not null). The valid slice must either be empty or be equal in length to v. If empty, -// all values in v are appended and considered valid. -func (b *Decimal128Builder) AppendValues(v []decimal128.Num, valid []bool) { - if len(v) != len(valid) && len(valid) != 0 { - panic("len(v) != len(valid) && len(valid) != 0") - } - - if len(v) == 0 { - return - } - - b.Reserve(len(v)) - if len(v) > 0 { - arrow.Decimal128Traits.Copy(b.rawData[b.length:], v) - } - b.builder.unsafeAppendBoolsToBitmap(valid, len(v)) -} - -func (b *Decimal128Builder) init(capacity int) { - b.builder.init(capacity) - - b.data = memory.NewResizableBuffer(b.mem) - bytesN := arrow.Decimal128Traits.BytesRequired(capacity) - b.data.Resize(bytesN) - b.rawData = arrow.Decimal128Traits.CastFromBytes(b.data.Bytes()) -} - -// Reserve ensures there is enough space for appending n elements -// by checking the capacity and calling Resize if necessary. -func (b *Decimal128Builder) Reserve(n int) { - b.builder.reserve(n, b.Resize) -} - -// Resize adjusts the space allocated by b to n elements. If n is greater than b.Cap(), -// additional memory will be allocated. If n is smaller, the allocated memory may reduced. -func (b *Decimal128Builder) Resize(n int) { - nBuilder := n - if n < minBuilderCapacity { - n = minBuilderCapacity - } - - if b.capacity == 0 { - b.init(n) - } else { - b.builder.resize(nBuilder, b.init) - b.data.Resize(arrow.Decimal128Traits.BytesRequired(n)) - b.rawData = arrow.Decimal128Traits.CastFromBytes(b.data.Bytes()) - } -} - -// NewArray creates a Decimal128 array from the memory buffers used by the builder and resets the Decimal128Builder -// so it can be used to build a new array. -func (b *Decimal128Builder) NewArray() arrow.Array { - return b.NewDecimal128Array() -} - -// NewDecimal128Array creates a Decimal128 array from the memory buffers used by the builder and resets the Decimal128Builder -// so it can be used to build a new array. -func (b *Decimal128Builder) NewDecimal128Array() (a *Decimal128) { - data := b.newData() - a = NewDecimal128Data(data) - data.Release() - return -} - -func (b *Decimal128Builder) newData() (data *Data) { - bytesRequired := arrow.Decimal128Traits.BytesRequired(b.length) - if bytesRequired > 0 && bytesRequired < b.data.Len() { - // trim buffers - b.data.Resize(bytesRequired) - } - data = NewData(b.dtype, b.length, []*memory.Buffer{b.nullBitmap, b.data}, nil, b.nulls, 0) - b.reset() - - if b.data != nil { - b.data.Release() - b.data = nil - b.rawData = nil - } - - return -} - -func (b *Decimal128Builder) AppendValueFromString(s string) error { - if s == NullValueStr { - b.AppendNull() - return nil - } - val, err := decimal128.FromString(s, b.dtype.Precision, b.dtype.Scale) - if err != nil { - b.AppendNull() - return err - } - b.Append(val) - return nil -} - -func (b *Decimal128Builder) UnmarshalOne(dec *json.Decoder) error { - t, err := dec.Token() - if err != nil { - return err - } - - switch v := t.(type) { - case float64: - val, err := decimal128.FromFloat64(v, b.dtype.Precision, b.dtype.Scale) - if err != nil { - return err - } - b.Append(val) - case string: - val, err := decimal128.FromString(v, b.dtype.Precision, b.dtype.Scale) - if err != nil { - return err - } - b.Append(val) - case json.Number: - val, err := decimal128.FromString(v.String(), b.dtype.Precision, b.dtype.Scale) - if err != nil { - return err - } - b.Append(val) - case nil: - b.AppendNull() - return nil - default: - return &json.UnmarshalTypeError{ - Value: fmt.Sprint(t), - Type: reflect.TypeOf(decimal128.Num{}), - Offset: dec.InputOffset(), - } - } - - return nil -} - -func (b *Decimal128Builder) Unmarshal(dec *json.Decoder) error { - for dec.More() { - if err := b.UnmarshalOne(dec); err != nil { - return err - } - } - return nil -} - -// UnmarshalJSON will add the unmarshalled values to this builder. -// -// If the values are strings, they will get parsed with big.ParseFloat using -// a rounding mode of big.ToNearestAway currently. -func (b *Decimal128Builder) UnmarshalJSON(data []byte) error { - dec := json.NewDecoder(bytes.NewReader(data)) - t, err := dec.Token() - if err != nil { - return err - } - - if delim, ok := t.(json.Delim); !ok || delim != '[' { - return fmt.Errorf("decimal128 builder must unpack from json array, found %s", delim) - } - - return b.Unmarshal(dec) -} - -var ( - _ arrow.Array = (*Decimal128)(nil) - _ Builder = (*Decimal128Builder)(nil) -) diff --git a/arrow/array/decimal256.go b/arrow/array/decimal256.go deleted file mode 100644 index 7f30f897..00000000 --- a/arrow/array/decimal256.go +++ /dev/null @@ -1,368 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package array - -import ( - "bytes" - "fmt" - "math/big" - "reflect" - "strings" - "sync/atomic" - - "github.com/apache/arrow-go/v18/arrow" - "github.com/apache/arrow-go/v18/arrow/bitutil" - "github.com/apache/arrow-go/v18/arrow/decimal256" - "github.com/apache/arrow-go/v18/arrow/internal/debug" - "github.com/apache/arrow-go/v18/arrow/memory" - "github.com/apache/arrow-go/v18/internal/json" -) - -// Decimal256 is a type that represents an immutable sequence of 256-bit decimal values. -type Decimal256 struct { - array - - values []decimal256.Num -} - -func NewDecimal256Data(data arrow.ArrayData) *Decimal256 { - a := &Decimal256{} - a.refCount = 1 - a.setData(data.(*Data)) - return a -} - -func (a *Decimal256) Value(i int) decimal256.Num { return a.values[i] } - -func (a *Decimal256) ValueStr(i int) string { - if a.IsNull(i) { - return NullValueStr - } - return a.GetOneForMarshal(i).(string) -} - -func (a *Decimal256) Values() []decimal256.Num { return a.values } - -func (a *Decimal256) String() string { - o := new(strings.Builder) - o.WriteString("[") - for i := 0; i < a.Len(); i++ { - if i > 0 { - fmt.Fprintf(o, " ") - } - switch { - case a.IsNull(i): - o.WriteString(NullValueStr) - default: - fmt.Fprintf(o, "%v", a.Value(i)) - } - } - o.WriteString("]") - return o.String() -} - -func (a *Decimal256) setData(data *Data) { - a.array.setData(data) - vals := data.buffers[1] - if vals != nil { - a.values = arrow.Decimal256Traits.CastFromBytes(vals.Bytes()) - beg := a.array.data.offset - end := beg + a.array.data.length - a.values = a.values[beg:end] - } -} - -func (a *Decimal256) GetOneForMarshal(i int) interface{} { - if a.IsNull(i) { - return nil - } - typ := a.DataType().(*arrow.Decimal256Type) - n := a.Value(i) - scale := typ.Scale - f := (&big.Float{}).SetInt(n.BigInt()) - if scale < 0 { - f.SetPrec(256).Mul(f, (&big.Float{}).SetInt(decimal256.GetScaleMultiplier(int(-scale)).BigInt())) - } else { - f.SetPrec(256).Quo(f, (&big.Float{}).SetInt(decimal256.GetScaleMultiplier(int(scale)).BigInt())) - } - return f.Text('g', int(typ.Precision)) -} - -func (a *Decimal256) MarshalJSON() ([]byte, error) { - vals := make([]interface{}, a.Len()) - for i := 0; i < a.Len(); i++ { - vals[i] = a.GetOneForMarshal(i) - } - return json.Marshal(vals) -} - -func arrayEqualDecimal256(left, right *Decimal256) bool { - for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { - continue - } - if left.Value(i) != right.Value(i) { - return false - } - } - return true -} - -type Decimal256Builder struct { - builder - - dtype *arrow.Decimal256Type - data *memory.Buffer - rawData []decimal256.Num -} - -func NewDecimal256Builder(mem memory.Allocator, dtype *arrow.Decimal256Type) *Decimal256Builder { - return &Decimal256Builder{ - builder: builder{refCount: 1, mem: mem}, - dtype: dtype, - } -} - -// Release decreases the reference count by 1. -// When the reference count goes to zero, the memory is freed. -func (b *Decimal256Builder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") - - if atomic.AddInt64(&b.refCount, -1) == 0 { - if b.nullBitmap != nil { - b.nullBitmap.Release() - b.nullBitmap = nil - } - if b.data != nil { - b.data.Release() - b.data = nil - b.rawData = nil - } - } -} - -func (b *Decimal256Builder) Append(v decimal256.Num) { - b.Reserve(1) - b.UnsafeAppend(v) -} - -func (b *Decimal256Builder) UnsafeAppend(v decimal256.Num) { - bitutil.SetBit(b.nullBitmap.Bytes(), b.length) - b.rawData[b.length] = v - b.length++ -} - -func (b *Decimal256Builder) AppendNull() { - b.Reserve(1) - b.UnsafeAppendBoolToBitmap(false) -} - -func (b *Decimal256Builder) AppendNulls(n int) { - for i := 0; i < n; i++ { - b.AppendNull() - } -} - -func (b *Decimal256Builder) AppendEmptyValue() { - b.Append(decimal256.Num{}) -} - -func (b *Decimal256Builder) AppendEmptyValues(n int) { - for i := 0; i < n; i++ { - b.AppendEmptyValue() - } -} - -func (b *Decimal256Builder) Type() arrow.DataType { return b.dtype } - -func (b *Decimal256Builder) UnsafeAppendBoolToBitmap(isValid bool) { - if isValid { - bitutil.SetBit(b.nullBitmap.Bytes(), b.length) - } else { - b.nulls++ - } - b.length++ -} - -// AppendValues will append the values in the v slice. The valid slice determines which values -// in v are valid (not null). The valid slice must either be empty or be equal in length to v. If empty, -// all values in v are appended and considered valid. -func (b *Decimal256Builder) AppendValues(v []decimal256.Num, valid []bool) { - if len(v) != len(valid) && len(valid) != 0 { - panic("arrow/array: len(v) != len(valid) && len(valid) != 0") - } - - if len(v) == 0 { - return - } - - b.Reserve(len(v)) - if len(v) > 0 { - arrow.Decimal256Traits.Copy(b.rawData[b.length:], v) - } - b.builder.unsafeAppendBoolsToBitmap(valid, len(v)) -} - -func (b *Decimal256Builder) init(capacity int) { - b.builder.init(capacity) - - b.data = memory.NewResizableBuffer(b.mem) - bytesN := arrow.Decimal256Traits.BytesRequired(capacity) - b.data.Resize(bytesN) - b.rawData = arrow.Decimal256Traits.CastFromBytes(b.data.Bytes()) -} - -// Reserve ensures there is enough space for appending n elements -// by checking the capacity and calling Resize if necessary. -func (b *Decimal256Builder) Reserve(n int) { - b.builder.reserve(n, b.Resize) -} - -// Resize adjusts the space allocated by b to n elements. If n is greater than b.Cap(), -// additional memory will be allocated. If n is smaller, the allocated memory may reduced. -func (b *Decimal256Builder) Resize(n int) { - nBuilder := n - if n < minBuilderCapacity { - n = minBuilderCapacity - } - - if b.capacity == 0 { - b.init(n) - } else { - b.builder.resize(nBuilder, b.init) - b.data.Resize(arrow.Decimal256Traits.BytesRequired(n)) - b.rawData = arrow.Decimal256Traits.CastFromBytes(b.data.Bytes()) - } -} - -// NewArray creates a Decimal256 array from the memory buffers used by the builder and resets the Decimal256Builder -// so it can be used to build a new array. -func (b *Decimal256Builder) NewArray() arrow.Array { - return b.NewDecimal256Array() -} - -// NewDecimal256Array creates a Decimal256 array from the memory buffers used by the builder and resets the Decimal256Builder -// so it can be used to build a new array. -func (b *Decimal256Builder) NewDecimal256Array() (a *Decimal256) { - data := b.newData() - a = NewDecimal256Data(data) - data.Release() - return -} - -func (b *Decimal256Builder) newData() (data *Data) { - bytesRequired := arrow.Decimal256Traits.BytesRequired(b.length) - if bytesRequired > 0 && bytesRequired < b.data.Len() { - // trim buffers - b.data.Resize(bytesRequired) - } - data = NewData(b.dtype, b.length, []*memory.Buffer{b.nullBitmap, b.data}, nil, b.nulls, 0) - b.reset() - - if b.data != nil { - b.data.Release() - b.data = nil - b.rawData = nil - } - - return -} - -func (b *Decimal256Builder) AppendValueFromString(s string) error { - if s == NullValueStr { - b.AppendNull() - return nil - } - val, err := decimal256.FromString(s, b.dtype.Precision, b.dtype.Scale) - if err != nil { - b.AppendNull() - return err - } - b.Append(val) - return nil -} - -func (b *Decimal256Builder) UnmarshalOne(dec *json.Decoder) error { - t, err := dec.Token() - if err != nil { - return err - } - - switch v := t.(type) { - case float64: - val, err := decimal256.FromFloat64(v, b.dtype.Precision, b.dtype.Scale) - if err != nil { - return err - } - b.Append(val) - case string: - out, err := decimal256.FromString(v, b.dtype.Precision, b.dtype.Scale) - if err != nil { - return err - } - b.Append(out) - case json.Number: - out, err := decimal256.FromString(v.String(), b.dtype.Precision, b.dtype.Scale) - if err != nil { - return err - } - b.Append(out) - case nil: - b.AppendNull() - return nil - default: - return &json.UnmarshalTypeError{ - Value: fmt.Sprint(t), - Type: reflect.TypeOf(decimal256.Num{}), - Offset: dec.InputOffset(), - } - } - - return nil -} - -func (b *Decimal256Builder) Unmarshal(dec *json.Decoder) error { - for dec.More() { - if err := b.UnmarshalOne(dec); err != nil { - return err - } - } - return nil -} - -// UnmarshalJSON will add the unmarshalled values to this builder. -// -// If the values are strings, they will get parsed with big.ParseFloat using -// a rounding mode of big.ToNearestAway currently. -func (b *Decimal256Builder) UnmarshalJSON(data []byte) error { - dec := json.NewDecoder(bytes.NewReader(data)) - t, err := dec.Token() - if err != nil { - return err - } - - if delim, ok := t.(json.Delim); !ok || delim != '[' { - return fmt.Errorf("arrow/array: decimal256 builder must unpack from json array, found %s", delim) - } - - return b.Unmarshal(dec) -} - -var ( - _ arrow.Array = (*Decimal256)(nil) - _ Builder = (*Decimal256Builder)(nil) -) diff --git a/arrow/array/dictionary.go b/arrow/array/dictionary.go index 34f0f2b4..0c23934a 100644 --- a/arrow/array/dictionary.go +++ b/arrow/array/dictionary.go @@ -27,6 +27,7 @@ import ( "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/bitutil" + "github.com/apache/arrow-go/v18/arrow/decimal" "github.com/apache/arrow-go/v18/arrow/decimal128" "github.com/apache/arrow-go/v18/arrow/decimal256" "github.com/apache/arrow-go/v18/arrow/float16" @@ -392,7 +393,8 @@ func createMemoTable(mem memory.Allocator, dt arrow.DataType) (ret hashing.MemoT ret = hashing.NewFloat32MemoTable(0) case arrow.FLOAT64: ret = hashing.NewFloat64MemoTable(0) - case arrow.BINARY, arrow.FIXED_SIZE_BINARY, arrow.DECIMAL128, arrow.DECIMAL256, arrow.INTERVAL_DAY_TIME, arrow.INTERVAL_MONTH_DAY_NANO: + case arrow.BINARY, arrow.FIXED_SIZE_BINARY, arrow.DECIMAL32, arrow.DECIMAL64, + arrow.DECIMAL128, arrow.DECIMAL256, arrow.INTERVAL_DAY_TIME, arrow.INTERVAL_MONTH_DAY_NANO: ret = hashing.NewBinaryMemoTable(0, 0, NewBinaryBuilder(mem, arrow.BinaryTypes.Binary)) case arrow.STRING: ret = hashing.NewBinaryMemoTable(0, 0, NewBinaryBuilder(mem, arrow.BinaryTypes.String)) @@ -623,6 +625,22 @@ func NewDictionaryBuilderWithDict(mem memory.Allocator, dt *arrow.DictionaryType } } return ret + case arrow.DECIMAL32: + ret := &Decimal32DictionaryBuilder{bldr} + if init != nil { + if err = ret.InsertDictValues(init.(*Decimal32)); err != nil { + panic(err) + } + } + return ret + case arrow.DECIMAL64: + ret := &Decimal64DictionaryBuilder{bldr} + if init != nil { + if err = ret.InsertDictValues(init.(*Decimal64)); err != nil { + panic(err) + } + } + return ret case arrow.DECIMAL128: ret := &Decimal128DictionaryBuilder{bldr} if init != nil { @@ -906,6 +924,16 @@ func getvalFn(arr arrow.Array) func(i int) interface{} { return func(i int) interface{} { return typedarr.Value(i) } case *String: return func(i int) interface{} { return typedarr.Value(i) } + case *Decimal32: + return func(i int) interface{} { + val := typedarr.Value(i) + return (*(*[arrow.Decimal32SizeBytes]byte)(unsafe.Pointer(&val)))[:] + } + case *Decimal64: + return func(i int) interface{} { + val := typedarr.Value(i) + return (*(*[arrow.Decimal64SizeBytes]byte)(unsafe.Pointer(&val)))[:] + } case *Decimal128: return func(i int) interface{} { val := typedarr.Value(i) @@ -1394,6 +1422,42 @@ func (b *FixedSizeBinaryDictionaryBuilder) InsertDictValues(arr *FixedSizeBinary return } +type Decimal32DictionaryBuilder struct { + dictionaryBuilder +} + +func (b *Decimal32DictionaryBuilder) Append(v decimal.Decimal32) error { + return b.appendValue((*(*[arrow.Decimal32SizeBytes]byte)(unsafe.Pointer(&v)))[:]) +} +func (b *Decimal32DictionaryBuilder) InsertDictValues(arr *Decimal32) (err error) { + data := arrow.Decimal32Traits.CastToBytes(arr.values) + for len(data) > 0 { + if err = b.insertDictValue(data[:arrow.Decimal32SizeBytes]); err != nil { + break + } + data = data[arrow.Decimal32SizeBytes:] + } + return +} + +type Decimal64DictionaryBuilder struct { + dictionaryBuilder +} + +func (b *Decimal64DictionaryBuilder) Append(v decimal.Decimal64) error { + return b.appendValue((*(*[arrow.Decimal64SizeBytes]byte)(unsafe.Pointer(&v)))[:]) +} +func (b *Decimal64DictionaryBuilder) InsertDictValues(arr *Decimal64) (err error) { + data := arrow.Decimal64Traits.CastToBytes(arr.values) + for len(data) > 0 { + if err = b.insertDictValue(data[:arrow.Decimal64SizeBytes]); err != nil { + break + } + data = data[arrow.Decimal64SizeBytes:] + } + return +} + type Decimal128DictionaryBuilder struct { dictionaryBuilder } diff --git a/arrow/array/numeric.gen.go b/arrow/array/numeric.gen.go index c6c7b0bd..7e94fe5c 100644 --- a/arrow/array/numeric.gen.go +++ b/arrow/array/numeric.gen.go @@ -101,11 +101,13 @@ func (a *Int64) GetOneForMarshal(i int) interface{} { func (a *Int64) MarshalJSON() ([]byte, error) { vals := make([]interface{}, a.Len()) for i := 0; i < a.Len(); i++ { + if a.IsValid(i) { vals[i] = a.values[i] } else { vals[i] = nil } + } return json.Marshal(vals) @@ -196,11 +198,13 @@ func (a *Uint64) GetOneForMarshal(i int) interface{} { func (a *Uint64) MarshalJSON() ([]byte, error) { vals := make([]interface{}, a.Len()) for i := 0; i < a.Len(); i++ { + if a.IsValid(i) { vals[i] = a.values[i] } else { vals[i] = nil } + } return json.Marshal(vals) @@ -398,11 +402,13 @@ func (a *Int32) GetOneForMarshal(i int) interface{} { func (a *Int32) MarshalJSON() ([]byte, error) { vals := make([]interface{}, a.Len()) for i := 0; i < a.Len(); i++ { + if a.IsValid(i) { vals[i] = a.values[i] } else { vals[i] = nil } + } return json.Marshal(vals) @@ -493,11 +499,13 @@ func (a *Uint32) GetOneForMarshal(i int) interface{} { func (a *Uint32) MarshalJSON() ([]byte, error) { vals := make([]interface{}, a.Len()) for i := 0; i < a.Len(); i++ { + if a.IsValid(i) { vals[i] = a.values[i] } else { vals[i] = nil } + } return json.Marshal(vals) @@ -602,6 +610,7 @@ func (a *Float32) MarshalJSON() ([]byte, error) { default: vals[i] = f } + } return json.Marshal(vals) @@ -692,11 +701,13 @@ func (a *Int16) GetOneForMarshal(i int) interface{} { func (a *Int16) MarshalJSON() ([]byte, error) { vals := make([]interface{}, a.Len()) for i := 0; i < a.Len(); i++ { + if a.IsValid(i) { vals[i] = a.values[i] } else { vals[i] = nil } + } return json.Marshal(vals) @@ -787,11 +798,13 @@ func (a *Uint16) GetOneForMarshal(i int) interface{} { func (a *Uint16) MarshalJSON() ([]byte, error) { vals := make([]interface{}, a.Len()) for i := 0; i < a.Len(); i++ { + if a.IsValid(i) { vals[i] = a.values[i] } else { vals[i] = nil } + } return json.Marshal(vals) @@ -882,11 +895,13 @@ func (a *Int8) GetOneForMarshal(i int) interface{} { func (a *Int8) MarshalJSON() ([]byte, error) { vals := make([]interface{}, a.Len()) for i := 0; i < a.Len(); i++ { + if a.IsValid(i) { vals[i] = float64(a.values[i]) // prevent uint8 from being seen as binary data } else { vals[i] = nil } + } return json.Marshal(vals) @@ -977,11 +992,13 @@ func (a *Uint8) GetOneForMarshal(i int) interface{} { func (a *Uint8) MarshalJSON() ([]byte, error) { vals := make([]interface{}, a.Len()) for i := 0; i < a.Len(); i++ { + if a.IsValid(i) { vals[i] = float64(a.values[i]) // prevent uint8 from being seen as binary data } else { vals[i] = nil } + } return json.Marshal(vals) diff --git a/arrow/cdata/cdata.go b/arrow/cdata/cdata.go index 4688a1ef..d5748a35 100644 --- a/arrow/cdata/cdata.go +++ b/arrow/cdata/cdata.go @@ -254,12 +254,17 @@ func importSchema(schema *CArrowSchema) (ret arrow.Field, err error) { return ret, xerrors.Errorf("could not parse decimal scale in '%s': %s", f, err.Error()) } - if bitwidth == 128 { + switch bitwidth { + case 32: + dt = &arrow.Decimal32Type{Precision: int32(precision), Scale: int32(scale)} + case 64: + dt = &arrow.Decimal64Type{Precision: int32(precision), Scale: int32(scale)} + case 128: dt = &arrow.Decimal128Type{Precision: int32(precision), Scale: int32(scale)} - } else if bitwidth == 256 { + case 256: dt = &arrow.Decimal256Type{Precision: int32(precision), Scale: int32(scale)} - } else { - return ret, xerrors.Errorf("only decimal128 and decimal256 are supported, got '%s'", f) + default: + return ret, xerrors.Errorf("unsupported decimal bitwidth, got '%s'", f) } } diff --git a/arrow/cdata/cdata_exports.go b/arrow/cdata/cdata_exports.go index 4ed9d0e5..d3673481 100644 --- a/arrow/cdata/cdata_exports.go +++ b/arrow/cdata/cdata_exports.go @@ -154,6 +154,10 @@ func (exp *schemaExporter) exportFormat(dt arrow.DataType) string { return "g" case *arrow.FixedSizeBinaryType: return fmt.Sprintf("w:%d", dt.ByteWidth) + case *arrow.Decimal32Type: + return fmt.Sprintf("d:%d,%d,32", dt.Precision, dt.Scale) + case *arrow.Decimal64Type: + return fmt.Sprintf("d:%d,%d,64", dt.Precision, dt.Scale) case *arrow.Decimal128Type: return fmt.Sprintf("d:%d,%d", dt.Precision, dt.Scale) case *arrow.Decimal256Type: diff --git a/arrow/cdata/cdata_test.go b/arrow/cdata/cdata_test.go index 697a73b3..2a86ea62 100644 --- a/arrow/cdata/cdata_test.go +++ b/arrow/cdata/cdata_test.go @@ -153,7 +153,7 @@ func TestDecimalSchemaErrors(t *testing.T) { {"d:a,2,3", "could not parse decimal precision in 'd:a,2,3':"}, {"d:1,a,3", "could not parse decimal scale in 'd:1,a,3':"}, {"d:1,2,a", "could not parse decimal bitwidth in 'd:1,2,a':"}, - {"d:1,2,384", "only decimal128 and decimal256 are supported, got 'd:1,2,384'"}, + {"d:1,2,384", "unsupported decimal bitwidth, got 'd:1,2,384'"}, } for _, tt := range tests { diff --git a/arrow/datatype.go b/arrow/datatype.go index 2fba6550..95565859 100644 --- a/arrow/datatype.go +++ b/arrow/datatype.go @@ -107,7 +107,7 @@ const ( // parameters. DECIMAL128 - // DECIMAL256 is a precision and scale based decimal type, with 256 bit max. not yet implemented + // DECIMAL256 is a precision and scale based decimal type, with 256 bit max. DECIMAL256 // LIST is a list of some logical data type @@ -116,10 +116,10 @@ const ( // STRUCT of logical types STRUCT - // SPARSE_UNION of logical types. not yet implemented + // SPARSE_UNION of logical types SPARSE_UNION - // DENSE_UNION of logical types. not yet implemented + // DENSE_UNION of logical types DENSE_UNION // DICTIONARY aka Category type @@ -138,13 +138,13 @@ const ( // or nanoseconds. DURATION - // like STRING, but 64-bit offsets. not yet implemented + // like STRING, but 64-bit offsets LARGE_STRING - // like BINARY but with 64-bit offsets, not yet implemented + // like BINARY but with 64-bit offsets LARGE_BINARY - // like LIST but with 64-bit offsets. not yet implemented + // like LIST but with 64-bit offsets LARGE_LIST // calendar interval with three fields @@ -165,6 +165,12 @@ const ( // like LIST but with 64-bit offsets LARGE_LIST_VIEW + // Decimal value with 32-bit representation + DECIMAL32 + + // Decimal value with 64-bit representation + DECIMAL64 + // Alias to ensure we do not break any consumers DECIMAL = DECIMAL128 ) @@ -365,10 +371,10 @@ func IsLargeBinaryLike(t Type) bool { return false } -// IsFixedSizeBinary returns true for Decimal128/256 and FixedSizeBinary +// IsFixedSizeBinary returns true for Decimal32/64/128/256 and FixedSizeBinary func IsFixedSizeBinary(t Type) bool { switch t { - case DECIMAL128, DECIMAL256, FIXED_SIZE_BINARY: + case DECIMAL32, DECIMAL64, DECIMAL128, DECIMAL256, FIXED_SIZE_BINARY: return true } return false @@ -377,7 +383,7 @@ func IsFixedSizeBinary(t Type) bool { // IsDecimal returns true for Decimal128 and Decimal256 func IsDecimal(t Type) bool { switch t { - case DECIMAL128, DECIMAL256: + case DECIMAL32, DECIMAL64, DECIMAL128, DECIMAL256: return true } return false diff --git a/arrow/datatype_fixedwidth.go b/arrow/datatype_fixedwidth.go index 41c7b6f3..5928be3a 100644 --- a/arrow/datatype_fixedwidth.go +++ b/arrow/datatype_fixedwidth.go @@ -22,8 +22,9 @@ import ( "sync" "time" + "github.com/apache/arrow-go/v18/arrow/decimal" + "github.com/apache/arrow-go/v18/arrow/internal/debug" "github.com/apache/arrow-go/v18/internal/json" - "golang.org/x/xerrors" ) @@ -532,19 +533,103 @@ type DecimalType interface { DataType GetPrecision() int32 GetScale() int32 + BitWidth() int +} + +// NarrowestDecimalType constructs the smallest decimal type that can represent +// the requested precision. An error is returned if the requested precision +// cannot be represented (prec <= 0 || prec > 76). +// +// For reference: +// +// prec in [ 1, 9] => Decimal32Type +// prec in [10, 18] => Decimal64Type +// prec in [19, 38] => Decimal128Type +// prec in [39, 76] => Decimal256Type +func NarrowestDecimalType(prec, scale int32) (DecimalType, error) { + switch { + case prec <= 0: + return nil, fmt.Errorf("%w: precision must be > 0 for decimal types, got %d", + ErrInvalid, prec) + case prec <= int32(decimal.MaxPrecision[decimal.Decimal32]()): + return &Decimal32Type{Precision: prec, Scale: scale}, nil + case prec <= int32(decimal.MaxPrecision[decimal.Decimal64]()): + return &Decimal64Type{Precision: prec, Scale: scale}, nil + case prec <= int32(decimal.MaxPrecision[decimal.Decimal128]()): + return &Decimal128Type{Precision: prec, Scale: scale}, nil + case prec <= int32(decimal.MaxPrecision[decimal.Decimal256]()): + return &Decimal256Type{Precision: prec, Scale: scale}, nil + default: + return nil, fmt.Errorf("%w: invalid precision for decimal types, %d", + ErrInvalid, prec) + } } func NewDecimalType(id Type, prec, scale int32) (DecimalType, error) { switch id { + case DECIMAL32: + debug.Assert(prec <= int32(decimal.MaxPrecision[decimal.Decimal32]()), "invalid precision for decimal32") + return &Decimal32Type{Precision: prec, Scale: scale}, nil + case DECIMAL64: + debug.Assert(prec <= int32(decimal.MaxPrecision[decimal.Decimal64]()), "invalid precision for decimal64") + return &Decimal64Type{Precision: prec, Scale: scale}, nil case DECIMAL128: + debug.Assert(prec <= int32(decimal.MaxPrecision[decimal.Decimal128]()), "invalid precision for decimal128") return &Decimal128Type{Precision: prec, Scale: scale}, nil case DECIMAL256: + debug.Assert(prec <= int32(decimal.MaxPrecision[decimal.Decimal256]()), "invalid precision for decimal256") return &Decimal256Type{Precision: prec, Scale: scale}, nil default: - return nil, fmt.Errorf("%w: must use DECIMAL128 or DECIMAL256 to create a DecimalType", ErrInvalid) + return nil, fmt.Errorf("%w: must use one of the DECIMAL IDs to create a DecimalType", ErrInvalid) } } +// Decimal32Type represents a fixed-size 32-bit decimal type. +type Decimal32Type struct { + Precision int32 + Scale int32 +} + +func (*Decimal32Type) ID() Type { return DECIMAL32 } +func (*Decimal32Type) Name() string { return "decimal32" } +func (*Decimal32Type) BitWidth() int { return 32 } +func (*Decimal32Type) Bytes() int { return Decimal32SizeBytes } +func (t *Decimal32Type) String() string { + return fmt.Sprintf("%s(%d, %d)", t.Name(), t.Precision, t.Scale) +} +func (t *Decimal32Type) Fingerprint() string { + return fmt.Sprintf("%s[%d,%d,%d]", typeFingerprint(t), t.BitWidth(), t.Precision, t.Scale) +} +func (t *Decimal32Type) GetPrecision() int32 { return t.Precision } +func (t *Decimal32Type) GetScale() int32 { return t.Scale } + +func (Decimal32Type) Layout() DataTypeLayout { + return DataTypeLayout{Buffers: []BufferSpec{SpecBitmap(), SpecFixedWidth(Decimal32SizeBytes)}} +} + +// Decimal64Type represents a fixed-size 32-bit decimal type. +type Decimal64Type struct { + Precision int32 + Scale int32 +} + +func (*Decimal64Type) ID() Type { return DECIMAL64 } +func (*Decimal64Type) Name() string { return "decimal64" } +func (*Decimal64Type) BitWidth() int { return 64 } +func (*Decimal64Type) Bytes() int { return Decimal64SizeBytes } +func (t *Decimal64Type) String() string { + return fmt.Sprintf("%s(%d, %d)", t.Name(), t.Precision, t.Scale) +} +func (t *Decimal64Type) Fingerprint() string { + return fmt.Sprintf("%s[%d,%d,%d]", typeFingerprint(t), t.BitWidth(), t.Precision, t.Scale) +} +func (t *Decimal64Type) GetPrecision() int32 { return t.Precision } +func (t *Decimal64Type) GetScale() int32 { return t.Scale } + +func (Decimal64Type) Layout() DataTypeLayout { + return DataTypeLayout{Buffers: []BufferSpec{SpecBitmap(), SpecFixedWidth(Decimal64SizeBytes)}} +} + // Decimal128Type represents a fixed-size 128-bit decimal type. type Decimal128Type struct { Precision int32 diff --git a/arrow/datatype_fixedwidth_test.go b/arrow/datatype_fixedwidth_test.go index d60c6b17..bc899f34 100644 --- a/arrow/datatype_fixedwidth_test.go +++ b/arrow/datatype_fixedwidth_test.go @@ -23,6 +23,7 @@ import ( "github.com/apache/arrow-go/v18/arrow" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // TestTimeUnit_String verifies each time unit matches its string representation. @@ -43,6 +44,60 @@ func TestTimeUnit_String(t *testing.T) { } } +func TestDecimal32Type(t *testing.T) { + for _, tc := range []struct { + precision int32 + scale int32 + want string + }{ + {1, 9, "decimal32(1, 9)"}, + {9, 9, "decimal32(9, 9)"}, + {9, 1, "decimal32(9, 1)"}, + } { + t.Run(tc.want, func(t *testing.T) { + dt := arrow.Decimal32Type{Precision: tc.precision, Scale: tc.scale} + if got, want := dt.BitWidth(), 32; got != want { + t.Fatalf("invalid bitwidth: got=%d, want=%d", got, want) + } + + if got, want := dt.ID(), arrow.DECIMAL32; got != want { + t.Fatalf("invalid type ID: got=%v, want=%v", got, want) + } + + if got, want := dt.String(), tc.want; got != want { + t.Fatalf("invalid stringer: got=%q, want=%q", got, want) + } + }) + } +} + +func TestDecimal64Type(t *testing.T) { + for _, tc := range []struct { + precision int32 + scale int32 + want string + }{ + {1, 10, "decimal64(1, 10)"}, + {10, 10, "decimal64(10, 10)"}, + {10, 1, "decimal64(10, 1)"}, + } { + t.Run(tc.want, func(t *testing.T) { + dt := arrow.Decimal64Type{Precision: tc.precision, Scale: tc.scale} + if got, want := dt.BitWidth(), 64; got != want { + t.Fatalf("invalid bitwidth: got=%d, want=%d", got, want) + } + + if got, want := dt.ID(), arrow.DECIMAL64; got != want { + t.Fatalf("invalid type ID: got=%v, want=%v", got, want) + } + + if got, want := dt.String(), tc.want; got != want { + t.Fatalf("invalid stringer: got=%q, want=%q", got, want) + } + }) + } +} + func TestDecimal128Type(t *testing.T) { for _, tc := range []struct { precision int32 @@ -438,3 +493,36 @@ func TestDateFromTime(t *testing.T) { assert.EqualValues(t, wantD64, arrow.Date64FromTime(tm)) assert.EqualValues(t, wantD32, arrow.Date32FromTime(tm)) } + +func TestNarrowestDecimalType(t *testing.T) { + tests := []struct { + min, max int32 + expected arrow.Type + }{ + {1, 9, arrow.DECIMAL32}, + {10, 18, arrow.DECIMAL64}, + {19, 38, arrow.DECIMAL128}, + {39, 76, arrow.DECIMAL256}, + } + + for _, tt := range tests { + t.Run(tt.expected.String(), func(t *testing.T) { + for i := tt.min; i <= tt.max; i++ { + typ, err := arrow.NarrowestDecimalType(i, 5) + require.NoError(t, err) + + assert.Equal(t, i, typ.GetPrecision()) + assert.Equal(t, int32(5), typ.GetScale()) + assert.Equal(t, tt.expected, typ.ID()) + } + }) + } + + _, err := arrow.NarrowestDecimalType(-1, 5) + assert.Error(t, err) + assert.ErrorIs(t, err, arrow.ErrInvalid) + + _, err = arrow.NarrowestDecimalType(78, 5) + assert.Error(t, err) + assert.ErrorIs(t, err, arrow.ErrInvalid) +} diff --git a/arrow/decimal/decimal.go b/arrow/decimal/decimal.go new file mode 100644 index 00000000..098a4e0f --- /dev/null +++ b/arrow/decimal/decimal.go @@ -0,0 +1,473 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package decimal + +import ( + "errors" + "fmt" + "math" + "math/big" + "math/bits" + "unsafe" + + "github.com/apache/arrow-go/v18/arrow/decimal128" + "github.com/apache/arrow-go/v18/arrow/decimal256" + "github.com/apache/arrow-go/v18/arrow/internal/debug" +) + +// DecimalTypes is a generic constraint representing the implemented decimal types +// in this package, and a single point of update for future additions. Everything +// else is constrained by this. +type DecimalTypes interface { + Decimal32 | Decimal64 | Decimal128 | Decimal256 +} + +// Num is an interface that is useful for building generic types for all decimal +// type implementations. It presents all the methods and interfaces necessary to +// operate on the decimal objects without having to care about the bit width. +type Num[T DecimalTypes] interface { + Negate() T + Add(T) T + Sub(T) T + Mul(T) T + Div(T) (res, rem T) + Pow(T) T + + FitsInPrecision(int32) bool + Abs() T + Sign() int + Rescale(int32, int32) (T, error) + Cmp(T) int + + IncreaseScaleBy(int32) T + ReduceScaleBy(int32, bool) T + + ToFloat32(int32) float32 + ToFloat64(int32) float64 + ToBigFloat(int32) *big.Float + + ToString(int32) string +} + +type ( + Decimal32 int32 + Decimal64 int64 + Decimal128 = decimal128.Num + Decimal256 = decimal256.Num +) + +func MaxPrecision[T DecimalTypes]() int { + // max precision is computed by Floor(log10(2^(nbytes * 8 - 1) - 1)) + var z T + return int(math.Floor(math.Log10(math.Pow(2, float64(unsafe.Sizeof(z))*8-1) - 1))) +} + +func (d Decimal32) Negate() Decimal32 { return -d } +func (d Decimal64) Negate() Decimal64 { return -d } + +func (d Decimal32) Add(rhs Decimal32) Decimal32 { return d + rhs } +func (d Decimal64) Add(rhs Decimal64) Decimal64 { return d + rhs } + +func (d Decimal32) Sub(rhs Decimal32) Decimal32 { return d - rhs } +func (d Decimal64) Sub(rhs Decimal64) Decimal64 { return d - rhs } + +func (d Decimal32) Mul(rhs Decimal32) Decimal32 { return d * rhs } +func (d Decimal64) Mul(rhs Decimal64) Decimal64 { return d * rhs } + +func (d Decimal32) Div(rhs Decimal32) (res, rem Decimal32) { + return d / rhs, d % rhs +} + +func (d Decimal64) Div(rhs Decimal64) (res, rem Decimal64) { + return d / rhs, d % rhs +} + +// about 4-5x faster than using math.Pow which requires converting to float64 +// and back to integers +func intPow[T int32 | int64](base, exp T) T { + result := T(1) + for { + if exp&1 == 1 { + result *= base + } + exp >>= 1 + if exp == 0 { + break + } + base *= base + } + + return result +} + +func (d Decimal32) Pow(rhs Decimal32) Decimal32 { + return Decimal32(intPow(int32(d), int32(rhs))) +} + +func (d Decimal64) Pow(rhs Decimal64) Decimal64 { + return Decimal64(intPow(int64(d), int64(rhs))) +} + +func (d Decimal32) Sign() int { + if d == 0 { + return 0 + } + return int(1 | (d >> 31)) +} + +func (d Decimal64) Sign() int { + if d == 0 { + return 0 + } + return int(1 | (d >> 63)) +} + +func (n Decimal32) Abs() Decimal32 { + if n < 0 { + return -n + } + return n +} + +func (n Decimal64) Abs() Decimal64 { + if n < 0 { + return -n + } + return n +} + +func (n Decimal32) FitsInPrecision(prec int32) bool { + debug.Assert(prec > 0, "precision must be > 0") + debug.Assert(prec <= 9, "precision must be <= 9") + return n.Abs() < Decimal32(math.Pow10(int(prec))) +} + +func (n Decimal64) FitsInPrecision(prec int32) bool { + debug.Assert(prec > 0, "precision must be > 0") + debug.Assert(prec <= 18, "precision must be <= 18") + return n.Abs() < Decimal64(math.Pow10(int(prec))) +} + +func (n Decimal32) ToString(scale int32) string { + return n.ToBigFloat(scale).Text('f', int(scale)) +} + +func (n Decimal64) ToString(scale int32) string { + return n.ToBigFloat(scale).Text('f', int(scale)) +} + +var pt5 = big.NewFloat(0.5) + +func decimalFromString[T interface { + Decimal32 | Decimal64 + FitsInPrecision(int32) bool +}](v string, prec, scale int32) (n T, err error) { + var nbits = uint(unsafe.Sizeof(T(0))) * 8 + + var out *big.Float + out, _, err = big.ParseFloat(v, 10, nbits, big.ToNearestEven) + + if scale < 0 { + var tmp big.Int + val, _ := out.Int(&tmp) + if val.BitLen() > int(nbits) { + return n, fmt.Errorf("bitlen too large for decimal%d", nbits) + } + + n = T(val.Int64() / int64(math.Pow10(int(-scale)))) + } else { + var precInBits = uint(math.Round(float64(prec+scale+1)/math.Log10(2))) + 1 + + p := (&big.Float{}).SetInt(big.NewInt(int64(math.Pow10(int(scale))))) + out.SetPrec(precInBits).Mul(out, p) + if out.Signbit() { + out.Sub(out, pt5) + } else { + out.Add(out, pt5) + } + + var tmp big.Int + val, _ := out.Int(&tmp) + if val.BitLen() > int(nbits) { + return n, fmt.Errorf("bitlen too large for decimal%d", nbits) + } + n = T(val.Int64()) + } + + if !n.FitsInPrecision(prec) { + err = fmt.Errorf("val %v doesn't fit in precision %d", n, prec) + } + return +} + +func Decimal32FromString(v string, prec, scale int32) (n Decimal32, err error) { + return decimalFromString[Decimal32](v, prec, scale) +} + +func Decimal64FromString(v string, prec, scale int32) (n Decimal64, err error) { + return decimalFromString[Decimal64](v, prec, scale) +} + +func Decimal128FromString(v string, prec, scale int32) (n Decimal128, err error) { + return decimal128.FromString(v, prec, scale) +} + +func Decimal256FromString(v string, prec, scale int32) (n Decimal256, err error) { + return decimal256.FromString(v, prec, scale) +} + +func scalePositiveFloat64(v float64, prec, scale int32) (float64, error) { + v *= math.Pow10(int(scale)) + v = math.RoundToEven(v) + + maxabs := math.Pow10(int(prec)) + if v >= maxabs { + return 0, fmt.Errorf("cannot convert %f to decimal(precision=%d, scale=%d)", v, prec, scale) + } + return v, nil +} + +func fromPositiveFloat[T Decimal32 | Decimal64, F float32 | float64](v F, prec, scale int32) (T, error) { + if prec > int32(MaxPrecision[T]()) { + return T(0), fmt.Errorf("invalid precision %d for converting float to Decimal", prec) + } + + val, err := scalePositiveFloat64(float64(v), prec, scale) + if err != nil { + return T(0), err + } + + return T(F(val)), nil +} + +func Decimal32FromFloat[F float32 | float64](v F, prec, scale int32) (Decimal32, error) { + if v < 0 { + dec, err := fromPositiveFloat[Decimal32](-v, prec, scale) + if err != nil { + return dec, err + } + + return -dec, nil + } + + return fromPositiveFloat[Decimal32](v, prec, scale) +} + +func Decimal64FromFloat[F float32 | float64](v F, prec, scale int32) (Decimal64, error) { + if v < 0 { + dec, err := fromPositiveFloat[Decimal64](-v, prec, scale) + if err != nil { + return dec, err + } + + return -dec, nil + } + + return fromPositiveFloat[Decimal64](v, prec, scale) +} + +func Decimal128FromFloat(v float64, prec, scale int32) (Decimal128, error) { + return decimal128.FromFloat64(v, prec, scale) +} + +func Decimal256FromFloat(v float64, prec, scale int32) (Decimal256, error) { + return decimal256.FromFloat64(v, prec, scale) +} + +func (n Decimal32) ToFloat32(scale int32) float32 { + return float32(n.ToFloat64(scale)) +} + +func (n Decimal64) ToFloat32(scale int32) float32 { + return float32(n.ToFloat64(scale)) +} + +func (n Decimal32) ToFloat64(scale int32) float64 { + return float64(n) * math.Pow10(-int(scale)) +} + +func (n Decimal64) ToFloat64(scale int32) float64 { + return float64(n) * math.Pow10(-int(scale)) +} + +func (n Decimal32) ToBigFloat(scale int32) *big.Float { + f := (&big.Float{}).SetInt64(int64(n)) + if scale < 0 { + f.SetPrec(32).Mul(f, (&big.Float{}).SetInt64(intPow(10, -int64(scale)))) + } else { + f.SetPrec(32).Quo(f, (&big.Float{}).SetInt64(intPow(10, int64(scale)))) + } + return f +} + +func (n Decimal64) ToBigFloat(scale int32) *big.Float { + f := (&big.Float{}).SetInt64(int64(n)) + if scale < 0 { + f.SetPrec(64).Mul(f, (&big.Float{}).SetInt64(intPow(10, -int64(scale)))) + } else { + f.SetPrec(64).Quo(f, (&big.Float{}).SetInt64(intPow(10, int64(scale)))) + } + return f +} + +func cmpDec[T Decimal32 | Decimal64](lhs, rhs T) int { + switch { + case lhs > rhs: + return 1 + case lhs < rhs: + return -1 + } + return 0 +} + +func (n Decimal32) Cmp(other Decimal32) int { + return cmpDec(n, other) +} + +func (n Decimal64) Cmp(other Decimal64) int { + return cmpDec(n, other) +} + +func (n Decimal32) IncreaseScaleBy(increase int32) Decimal32 { + debug.Assert(increase >= 0, "invalid increase scale for decimal32") + debug.Assert(increase <= 9, "invalid increase scale for decimal32") + + return n * Decimal32(intPow(10, increase)) +} + +func (n Decimal64) IncreaseScaleBy(increase int32) Decimal64 { + debug.Assert(increase >= 0, "invalid increase scale for decimal64") + debug.Assert(increase <= 18, "invalid increase scale for decimal64") + + return n * Decimal64(intPow(10, int64(increase))) +} + +func reduceScale[T interface { + Decimal32 | Decimal64 + Abs() T +}](n T, reduce int32, round bool) T { + if reduce == 0 { + return n + } + + divisor := T(intPow(10, reduce)) + if !round { + return n / divisor + } + + quo, remainder := n/divisor, n%divisor + divisorHalf := divisor / 2 + if remainder.Abs() >= divisorHalf { + if n > 0 { + quo++ + } else { + quo-- + } + } + + return quo +} + +func (n Decimal32) ReduceScaleBy(reduce int32, round bool) Decimal32 { + debug.Assert(reduce >= 0, "invalid reduce scale for decimal32") + debug.Assert(reduce <= 9, "invalid reduce scale for decimal32") + + return reduceScale(n, reduce, round) +} + +func (n Decimal64) ReduceScaleBy(reduce int32, round bool) Decimal64 { + debug.Assert(reduce >= 0, "invalid reduce scale for decimal32") + debug.Assert(reduce <= 18, "invalid reduce scale for decimal32") + + return reduceScale(n, reduce, round) +} + +//lint:ignore U1000 function is being used, staticcheck seems to not follow generics +func (n Decimal32) rescaleWouldCauseDataLoss(deltaScale int32, multiplier Decimal32) (out Decimal32, loss bool) { + if deltaScale < 0 { + debug.Assert(multiplier != 0, "multiplier must not be zero") + quo, remainder := bits.Div32(0, uint32(n), uint32(multiplier)) + return Decimal32(quo), remainder != 0 + } + + overflow, result := bits.Mul32(uint32(n), uint32(multiplier)) + if overflow != 0 { + return Decimal32(result), true + } + + out = Decimal32(result) + return out, out < n +} + +//lint:ignore U1000 function is being used, staticcheck seems to not follow generics +func (n Decimal64) rescaleWouldCauseDataLoss(deltaScale int32, multiplier Decimal64) (out Decimal64, loss bool) { + if deltaScale < 0 { + debug.Assert(multiplier != 0, "multiplier must not be zero") + quo, remainder := bits.Div32(0, uint32(n), uint32(multiplier)) + return Decimal64(quo), remainder != 0 + } + + overflow, result := bits.Mul32(uint32(n), uint32(multiplier)) + if overflow != 0 { + return Decimal64(result), true + } + + out = Decimal64(result) + return out, out < n +} + +func rescale[T interface { + Decimal32 | Decimal64 + rescaleWouldCauseDataLoss(int32, T) (T, bool) + Sign() int +}](n T, originalScale, newScale int32) (out T, err error) { + if originalScale == newScale { + return n, nil + } + + deltaScale := newScale - originalScale + absDeltaScale := int32(math.Abs(float64(deltaScale))) + + sign := n.Sign() + if n < 0 { + n = -n + } + + multiplier := T(intPow(10, absDeltaScale)) + var wouldHaveLoss bool + out, wouldHaveLoss = n.rescaleWouldCauseDataLoss(deltaScale, multiplier) + if wouldHaveLoss { + err = errors.New("rescale data loss") + } + out *= T(sign) + return +} + +func (n Decimal32) Rescale(originalScale, newScale int32) (out Decimal32, err error) { + return rescale(n, originalScale, newScale) +} + +func (n Decimal64) Rescale(originalScale, newScale int32) (out Decimal64, err error) { + return rescale(n, originalScale, newScale) +} + +var ( + _ Num[Decimal32] = Decimal32(0) + _ Num[Decimal64] = Decimal64(0) + _ Num[Decimal128] = Decimal128{} + _ Num[Decimal256] = Decimal256{} +) diff --git a/arrow/decimal/decimal_test.go b/arrow/decimal/decimal_test.go new file mode 100644 index 00000000..9b4f04f5 --- /dev/null +++ b/arrow/decimal/decimal_test.go @@ -0,0 +1,470 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package decimal_test + +import ( + "fmt" + "math" + "math/big" + "strconv" + "testing" + + "github.com/apache/arrow-go/v18/arrow/decimal" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func ulps64(actual, expected float64) int64 { + ulp := math.Nextafter(actual, math.Inf(1)) - actual + return int64(math.Abs((expected - actual) / ulp)) +} + +func ulps32(actual, expected float32) int64 { + ulp := math.Nextafter32(actual, float32(math.Inf(1))) - actual + return int64(math.Abs(float64((expected - actual) / ulp))) +} + +func assertFloat32Approx(t *testing.T, x, y float32) bool { + const maxulps int64 = 4 + ulps := ulps32(x, y) + return assert.LessOrEqualf(t, ulps, maxulps, "%f not equal to %f (%d ulps)", x, y, ulps) +} + +func assertFloat64Approx(t *testing.T, x, y float64) bool { + const maxulps int64 = 4 + ulps := ulps64(x, y) + return assert.LessOrEqualf(t, ulps, maxulps, "%f not equal to %f (%d ulps)", x, y, ulps) +} + +func TestDecimalToReal(t *testing.T) { + tests := []struct { + decimalVal string + scale int32 + exp float64 + }{ + {"0", 0, 0}, + {"0", 10, 0.0}, + {"0", -10, 0.0}, + {"1", 0, 1.0}, + {"12345", 0, 12345.0}, + {"12345", 1, 1234.5}, + {"536870912", 0, math.Pow(2, 29)}, + } + + t.Run("float32", func(t *testing.T) { + checkDecimalToFloat := func(t *testing.T, str string, v float32, scale int32) { + n, err := decimal.Decimal32FromString(str, 9, 0) + require.NoError(t, err) + assert.Equalf(t, v, n.ToFloat32(scale), "Decimal Val: %s, Scale: %d", str, scale) + + n64, err := decimal.Decimal64FromString(str, 18, 0) + require.NoError(t, err) + assert.Equalf(t, v, n64.ToFloat32(scale), "Decimal Val: %s, Scale: %d", str, scale) + } + for _, tt := range tests { + t.Run(tt.decimalVal, func(t *testing.T) { + checkDecimalToFloat(t, tt.decimalVal, float32(tt.exp), tt.scale) + if tt.decimalVal != "0" { + checkDecimalToFloat(t, "-"+tt.decimalVal, float32(-tt.exp), tt.scale) + } + }) + } + + t.Run("large values", func(t *testing.T) { + checkApproxDecimaltoFloat := func(str string, v float32, scale int32) { + n, err := decimal.Decimal32FromString(str, 9, 0) + require.NoError(t, err) + assertFloat32Approx(t, v, n.ToFloat32(scale)) + } + + checkApproxDecimal64toFloat := func(str string, v float32, scale int32) { + n, err := decimal.Decimal64FromString(str, 9, 0) + require.NoError(t, err) + assertFloat32Approx(t, v, n.ToFloat32(scale)) + } + + // exact comparisons would succeed on most platforms, but not all power-of-ten + // factors are exactly representable in binary floating point, so we'll use + // approx and ensure that the values are within 4 ULP (unit of least precision) + for scale := int32(-9); scale <= 9; scale++ { + checkApproxDecimaltoFloat("1", float32(math.Pow10(-int(scale))), scale) + checkApproxDecimaltoFloat("123", float32(123)*float32(math.Pow10(-int(scale))), scale) + } + + for scale := int32(-18); scale <= 18; scale++ { + checkApproxDecimal64toFloat("1", float32(math.Pow10(-int(scale))), scale) + checkApproxDecimal64toFloat("123", float32(123)*float32(math.Pow10(-int(scale))), scale) + } + }) + }) + + t.Run("float64", func(t *testing.T) { + checkDecimalToFloat := func(t *testing.T, str string, v float64, scale int32) { + n, err := decimal.Decimal32FromString(str, 9, 0) + require.NoError(t, err) + assert.Equalf(t, v, n.ToFloat64(scale), "Decimal Val: %s, Scale: %d", str, scale) + + assert.Equalf(t, big.NewFloat(v).SetPrec(32), n.ToBigFloat(scale), + "Decimal Val: %s, Scale: %d", str, scale) + + n64, err := decimal.Decimal64FromString(str, 18, 0) + require.NoError(t, err) + assert.Equalf(t, v, n64.ToFloat64(scale), "Decimal Val: %s, Scale: %d", str, scale) + assert.Equalf(t, big.NewFloat(v).SetPrec(64), n64.ToBigFloat(scale), + "Decimal Val: %s, Scale: %d", str, scale) + } + for _, tt := range tests { + t.Run(tt.decimalVal, func(t *testing.T) { + checkDecimalToFloat(t, tt.decimalVal, tt.exp, tt.scale) + if tt.decimalVal != "0" { + checkDecimalToFloat(t, "-"+tt.decimalVal, -tt.exp, tt.scale) + } + }) + } + + t.Run("large values", func(t *testing.T) { + checkApproxDecimaltoFloat := func(str string, v float64, scale int32) { + n, err := decimal.Decimal32FromString(str, 9, 0) + require.NoError(t, err) + assertFloat64Approx(t, v, n.ToFloat64(scale)) + + assert.Equalf(t, big.NewFloat(v).SetPrec(32), n.ToBigFloat(scale), + "Decimal Val: %s, Scale: %d", str, scale) + } + + checkApproxDecimal64toFloat := func(str string, v float64, scale int32) { + n, err := decimal.Decimal64FromString(str, 9, 0) + require.NoError(t, err) + assertFloat64Approx(t, v, n.ToFloat64(scale)) + + bf, _ := n.ToBigFloat(scale).Float64() + assertFloat64Approx(t, v, bf) + } + + // exact comparisons would succeed on most platforms, but not all power-of-ten + // factors are exactly representable in binary floating point, so we'll use + // approx and ensure that the values are within 4 ULP (unit of least precision) + for scale := int32(-9); scale <= 9; scale++ { + checkApproxDecimaltoFloat("1", math.Pow10(-int(scale)), scale) + checkApproxDecimaltoFloat("123", float64(123)*math.Pow10(-int(scale)), scale) + } + + for scale := int32(-18); scale <= 18; scale++ { + checkApproxDecimal64toFloat("1", math.Pow10(-int(scale)), scale) + checkApproxDecimal64toFloat("123", float64(123)*math.Pow10(-int(scale)), scale) + } + }) + }) +} + +func TestDecimalFromFloat(t *testing.T) { + tests := []struct { + val float64 + precision, scale int32 + expected string + }{ + {0, 1, 0, "0"}, + {-0, 1, 0, "0"}, + {0, 9, 4, "0.0000"}, + {math.Copysign(0.0, -1), 9, 4, "0.0000"}, + {123, 7, 4, "123.0000"}, + {-123, 7, 4, "-123.0000"}, + {456.78, 7, 4, "456.7800"}, + {-456.78, 7, 4, "-456.7800"}, + {456.784, 5, 2, "456.78"}, + {-456.784, 5, 2, "-456.78"}, + {456.786, 5, 2, "456.79"}, + {-456.786, 5, 2, "-456.79"}, + {999.99, 5, 2, "999.99"}, + {-999.99, 5, 2, "-999.99"}, + {123, 9, 0, "123"}, + {-123, 9, 0, "-123"}, + {123.4, 9, 0, "123"}, + {-123.4, 9, 0, "-123"}, + {123.6, 9, 0, "124"}, + {-123.6, 9, 0, "-124"}, + } + + t.Run("float64", func(t *testing.T) { + for _, tt := range tests { + t.Run(tt.expected, func(t *testing.T) { + n, err := decimal.Decimal32FromFloat(tt.val, tt.precision, tt.scale) + require.NoError(t, err) + assert.Equal(t, tt.expected, fmt.Sprintf("%."+strconv.Itoa(int(tt.scale))+"f", n.ToFloat64(tt.scale))) + }) + } + + t.Run("large values", func(t *testing.T) { + // test entire float64 range + for scale := int32(-308); scale <= 308; scale++ { + val := math.Pow10(int(scale)) + n, err := decimal.Decimal64FromFloat(val, 1, -scale) + require.NoError(t, err) + assert.EqualValues(t, 1, n) + } + + for scale := int32(-307); scale <= 306; scale++ { + val := 123 * math.Pow10(int(scale)) + n, err := decimal.Decimal64FromFloat(val, 2, -scale-1) + require.NoError(t, err) + assert.EqualValues(t, 12, n) + n, err = decimal.Decimal64FromFloat(val, 3, -scale) + require.NoError(t, err) + assert.EqualValues(t, 123, n) + n, err = decimal.Decimal64FromFloat(val, 4, -scale+1) + require.NoError(t, err) + assert.EqualValues(t, 1230, n) + } + }) + }) + + t.Run("float32", func(t *testing.T) { + for _, tt := range tests { + t.Run(tt.expected, func(t *testing.T) { + n, err := decimal.Decimal32FromFloat(float32(tt.val), tt.precision, tt.scale) + require.NoError(t, err) + assert.Equal(t, tt.expected, fmt.Sprintf("%."+strconv.Itoa(int(tt.scale))+"f", n.ToFloat32(tt.scale))) + }) + } + + t.Run("large values", func(t *testing.T) { + // test entire float32 range + for scale := int32(-38); scale <= 38; scale++ { + val := float32(math.Pow10(int(scale))) + n, err := decimal.Decimal64FromFloat(val, 1, -scale) + require.NoError(t, err) + assert.EqualValues(t, 1, n) + } + + for scale := int32(-37); scale <= 36; scale++ { + val := 123 * float32(math.Pow10(int(scale))) + n, err := decimal.Decimal64FromFloat(val, 2, -scale-1) + require.NoError(t, err) + assert.EqualValues(t, 12, n) + n, err = decimal.Decimal64FromFloat(val, 3, -scale) + require.NoError(t, err) + assert.EqualValues(t, 123, n) + n, err = decimal.Decimal64FromFloat(val, 4, -scale+1) + require.NoError(t, err) + assert.EqualValues(t, 1230, n) + } + }) + }) +} + +func TestFromString(t *testing.T) { + tests := []struct { + s string + expected int64 + expectedScale int32 + }{ + {"12.3", 123, 1}, + {"0.00123", 123, 5}, + {"1.23e-8", 123, 10}, + {"-1.23E-8", -123, 10}, + {"1.23e+3", 1230, 0}, + {"-1.23E+3", -1230, 0}, + {"1.23e+5", 123000, 0}, + {"1.2345E+7", 12345000, 0}, + {"1.23e-8", 123, 10}, + {"-1.23E-8", -123, 10}, + {"0000000", 0, 0}, + {"000.0000", 0, 4}, + {".0000", 0, 5}, + {"1e1", 10, 0}, + {"+234.567", 234567, 3}, + {"1e-8", 1, 8}, + {"2112.33", 211233, 2}, + {"-2112.33", -211233, 2}, + {"12E2", 12, -2}, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("%s_%d", tt.s, tt.expectedScale), func(t *testing.T) { + n, err := decimal.Decimal32FromString(tt.s, 8, tt.expectedScale) + require.NoError(t, err) + + ex := decimal.Decimal32(tt.expected) + assert.Equal(t, ex, n) + + n64, err := decimal.Decimal64FromString(tt.s, 8, tt.expectedScale) + require.NoError(t, err) + + ex64 := decimal.Decimal64(tt.expected) + assert.Equal(t, ex64, n64) + }) + } +} + +func TestCmp(t *testing.T) { + for _, tc := range []struct { + n decimal.Decimal32 + rhs decimal.Decimal32 + want int + }{ + {decimal.Decimal32(2), decimal.Decimal32(1), 1}, + {decimal.Decimal32(-1), decimal.Decimal32(-2), 1}, + {decimal.Decimal32(2), decimal.Decimal32(3), -1}, + {decimal.Decimal32(-3), decimal.Decimal32(-2), -1}, + {decimal.Decimal32(2), decimal.Decimal32(2), 0}, + {decimal.Decimal32(-2), decimal.Decimal32(-2), 0}, + } { + t.Run("cmp", func(t *testing.T) { + n := tc.n.Cmp(tc.rhs) + if got, want := n, tc.want; got != want { + t.Fatalf("invalid value. got=%v, want=%v", got, want) + } + }) + } + + for _, tc := range []struct { + n decimal.Decimal64 + rhs decimal.Decimal64 + want int + }{ + {decimal.Decimal64(2), decimal.Decimal64(1), 1}, + {decimal.Decimal64(-1), decimal.Decimal64(-2), 1}, + {decimal.Decimal64(2), decimal.Decimal64(3), -1}, + {decimal.Decimal64(-3), decimal.Decimal64(-2), -1}, + {decimal.Decimal64(2), decimal.Decimal64(2), 0}, + {decimal.Decimal64(-2), decimal.Decimal64(-2), 0}, + } { + t.Run("cmp", func(t *testing.T) { + n := tc.n.Cmp(tc.rhs) + if got, want := n, tc.want; got != want { + t.Fatalf("invalid value. got=%v, want=%v", got, want) + } + }) + } +} + +func TestDecimalRescale(t *testing.T) { + tests := []struct { + orig, exp int32 + oldScale, newScale int32 + }{ + {111, 11100, 0, 2}, + {11100, 111, 2, 0}, + {500000, 5, 6, 1}, + {5, 500000, 1, 6}, + {-111, -11100, 0, 2}, + {-11100, -111, 2, 0}, + {555, 555, 2, 2}, + } + + for _, tt := range tests { + t.Run("decimal32", func(t *testing.T) { + out, err := decimal.Decimal32(tt.orig).Rescale(tt.oldScale, tt.newScale) + require.NoError(t, err) + assert.Equal(t, decimal.Decimal32(tt.exp), out) + }) + t.Run("decimal64", func(t *testing.T) { + out, err := decimal.Decimal64(tt.orig).Rescale(tt.oldScale, tt.newScale) + require.NoError(t, err) + assert.Equal(t, decimal.Decimal64(tt.exp), out) + }) + } + + _, err := decimal.Decimal32(555555).Rescale(6, 1) + assert.Error(t, err) + _, err = decimal.Decimal64(555555).Rescale(6, 1) + assert.Error(t, err) + + _, err = decimal.Decimal32(555555).Rescale(0, 5) + assert.ErrorContains(t, err, "rescale data loss") + _, err = decimal.Decimal64(555555).Rescale(0, 5) + assert.ErrorContains(t, err, "rescale data loss") +} + +func TestDecimalIncreaseScale(t *testing.T) { + assert.Equal(t, decimal.Decimal32(1234), decimal.Decimal32(1234).IncreaseScaleBy(0)) + assert.Equal(t, decimal.Decimal32(1234000), decimal.Decimal32(1234).IncreaseScaleBy(3)) + assert.Equal(t, decimal.Decimal32(-1234000), decimal.Decimal32(-1234).IncreaseScaleBy(3)) + + assert.Equal(t, decimal.Decimal64(1234), decimal.Decimal64(1234).IncreaseScaleBy(0)) + assert.Equal(t, decimal.Decimal64(1234000), decimal.Decimal64(1234).IncreaseScaleBy(3)) + assert.Equal(t, decimal.Decimal64(-1234000), decimal.Decimal64(-1234).IncreaseScaleBy(3)) +} + +func TestDecimalReduceScale(t *testing.T) { + tests := []struct { + value int32 + scale int32 + round bool + expected int32 + }{ + {123456, 0, false, 123456}, + {123456, 1, false, 12345}, + {123456, 1, true, 12346}, + {123451, 1, true, 12345}, + {123789, 2, true, 1238}, + {123749, 2, true, 1237}, + {123750, 2, true, 1238}, + {5, 1, true, 1}, + {0, 1, true, 0}, + } + + for _, tt := range tests { + assert.Equal(t, decimal.Decimal32(tt.expected), + decimal.Decimal32(tt.value).ReduceScaleBy(tt.scale, tt.round), "decimal32") + assert.Equal(t, decimal.Decimal32(tt.expected).Negate(), + decimal.Decimal32(tt.value).Negate().ReduceScaleBy(tt.scale, tt.round), "decimal32") + assert.Equal(t, decimal.Decimal64(tt.expected), + decimal.Decimal64(tt.value).ReduceScaleBy(tt.scale, tt.round), "decimal64") + assert.Equal(t, decimal.Decimal64(tt.expected).Negate(), + decimal.Decimal64(tt.value).Negate().ReduceScaleBy(tt.scale, tt.round), "decimal64") + } +} + +func TestDecimalBasics(t *testing.T) { + tests := []struct { + lhs, rhs int32 + }{ + {100, 3}, + {200, 3}, + {20100, 301}, + {-20100, 301}, + {20100, -301}, + {-20100, -301}, + } + + for _, tt := range tests { + assert.EqualValues(t, tt.lhs+tt.rhs, + decimal.Decimal32(tt.lhs).Add(decimal.Decimal32(tt.rhs))) + assert.EqualValues(t, tt.lhs+tt.rhs, + decimal.Decimal64(tt.lhs).Add(decimal.Decimal64(tt.rhs))) + + assert.EqualValues(t, tt.lhs-tt.rhs, + decimal.Decimal32(tt.lhs).Sub(decimal.Decimal32(tt.rhs))) + assert.EqualValues(t, tt.lhs-tt.rhs, + decimal.Decimal64(tt.lhs).Sub(decimal.Decimal64(tt.rhs))) + + assert.EqualValues(t, tt.lhs*tt.rhs, + decimal.Decimal32(tt.lhs).Mul(decimal.Decimal32(tt.rhs))) + assert.EqualValues(t, tt.lhs*tt.rhs, + decimal.Decimal64(tt.lhs).Mul(decimal.Decimal64(tt.rhs))) + + expdiv, expmod := tt.lhs/tt.rhs, tt.lhs%tt.rhs + div, mod := decimal.Decimal32(tt.lhs).Div(decimal.Decimal32(tt.rhs)) + assert.EqualValues(t, expdiv, div) + assert.EqualValues(t, expmod, mod) + + div64, mod64 := decimal.Decimal64(tt.lhs).Div(decimal.Decimal64(tt.rhs)) + assert.EqualValues(t, expdiv, div64) + assert.EqualValues(t, expmod, mod64) + } +} diff --git a/arrow/decimal/traits.go b/arrow/decimal/traits.go new file mode 100644 index 00000000..0ec0c315 --- /dev/null +++ b/arrow/decimal/traits.go @@ -0,0 +1,78 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package decimal + +// Traits is a convenience for building generic objects for operating on +// Decimal values to get around the limitations of Go generics. By providing this +// interface a generic object can handle producing the proper types to generate +// new decimal values. +type Traits[T DecimalTypes] interface { + BytesRequired(int) int + FromString(string, int32, int32) (T, error) + FromFloat64(float64, int32, int32) (T, error) +} + +var ( + Dec32Traits dec32Traits + Dec64Traits dec64Traits + Dec128Traits dec128Traits + Dec256Traits dec256Traits +) + +type ( + dec32Traits struct{} + dec64Traits struct{} + dec128Traits struct{} + dec256Traits struct{} +) + +func (dec32Traits) BytesRequired(n int) int { return 4 * n } +func (dec64Traits) BytesRequired(n int) int { return 8 * n } +func (dec128Traits) BytesRequired(n int) int { return 16 * n } +func (dec256Traits) BytesRequired(n int) int { return 32 * n } + +func (dec32Traits) FromString(v string, prec, scale int32) (Decimal32, error) { + return Decimal32FromString(v, prec, scale) +} + +func (dec64Traits) FromString(v string, prec, scale int32) (Decimal64, error) { + return Decimal64FromString(v, prec, scale) +} + +func (dec128Traits) FromString(v string, prec, scale int32) (Decimal128, error) { + return Decimal128FromString(v, prec, scale) +} + +func (dec256Traits) FromString(v string, prec, scale int32) (Decimal256, error) { + return Decimal256FromString(v, prec, scale) +} + +func (dec32Traits) FromFloat64(v float64, prec, scale int32) (Decimal32, error) { + return Decimal32FromFloat(v, prec, scale) +} + +func (dec64Traits) FromFloat64(v float64, prec, scale int32) (Decimal64, error) { + return Decimal64FromFloat(v, prec, scale) +} + +func (dec128Traits) FromFloat64(v float64, prec, scale int32) (Decimal128, error) { + return Decimal128FromFloat(v, prec, scale) +} + +func (dec256Traits) FromFloat64(v float64, prec, scale int32) (Decimal256, error) { + return Decimal256FromFloat(v, prec, scale) +} diff --git a/arrow/decimal128/decimal128.go b/arrow/decimal128/decimal128.go index 2e451c1c..660c4131 100644 --- a/arrow/decimal128/decimal128.go +++ b/arrow/decimal128/decimal128.go @@ -327,6 +327,16 @@ func (n Num) ToFloat64(scale int32) float64 { return n.tofloat64Positive(scale) } +func (n Num) ToBigFloat(scale int32) *big.Float { + f := (&big.Float{}).SetInt(n.BigInt()) + if scale < 0 { + f.SetPrec(128).Mul(f, (&big.Float{}).SetInt(scaleMultipliers[-scale].BigInt())) + } else { + f.SetPrec(128).Quo(f, (&big.Float{}).SetInt(scaleMultipliers[scale].BigInt())) + } + return f +} + // LowBits returns the low bits of the two's complement representation of the number. func (n Num) LowBits() uint64 { return n.lo } diff --git a/arrow/decimal256/decimal256.go b/arrow/decimal256/decimal256.go index 76b61853..82c52a65 100644 --- a/arrow/decimal256/decimal256.go +++ b/arrow/decimal256/decimal256.go @@ -339,6 +339,16 @@ func (n Num) ToFloat64(scale int32) float64 { return n.tofloat64Positive(scale) } +func (n Num) ToBigFloat(scale int32) *big.Float { + f := (&big.Float{}).SetInt(n.BigInt()) + if scale < 0 { + f.SetPrec(256).Mul(f, (&big.Float{}).SetInt(scaleMultipliers[-scale].BigInt())) + } else { + f.SetPrec(256).Quo(f, (&big.Float{}).SetInt(scaleMultipliers[scale].BigInt())) + } + return f +} + func (n Num) Sign() int { if n == (Num{}) { return 0 diff --git a/arrow/internal/arrjson/arrjson.go b/arrow/internal/arrjson/arrjson.go index 452809e2..2181ebdf 100644 --- a/arrow/internal/arrjson/arrjson.go +++ b/arrow/internal/arrjson/arrjson.go @@ -29,6 +29,7 @@ import ( "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/array" "github.com/apache/arrow-go/v18/arrow/bitutil" + "github.com/apache/arrow-go/v18/arrow/decimal" "github.com/apache/arrow-go/v18/arrow/decimal128" "github.com/apache/arrow-go/v18/arrow/decimal256" "github.com/apache/arrow-go/v18/arrow/float16" @@ -224,6 +225,10 @@ func typeToJSON(arrowType arrow.DataType) (json.RawMessage, error) { typ = listSizeJSON{"fixedsizelist", dt.Len()} case *arrow.FixedSizeBinaryType: typ = byteWidthJSON{"fixedsizebinary", dt.ByteWidth} + case *arrow.Decimal32Type: + typ = decimalJSON{"decimal", int(dt.Scale), int(dt.Precision), 32} + case *arrow.Decimal64Type: + typ = decimalJSON{"decimal", int(dt.Scale), int(dt.Precision), 64} case *arrow.Decimal128Type: typ = decimalJSON{"decimal", int(dt.Scale), int(dt.Precision), 128} case *arrow.Decimal256Type: @@ -491,6 +496,10 @@ func typeFromJSON(typ json.RawMessage, children []FieldWrapper) (arrowType arrow arrowType = &arrow.Decimal256Type{Precision: int32(t.Precision), Scale: int32(t.Scale)} case 128, 0: // default to 128 bits when missing arrowType = &arrow.Decimal128Type{Precision: int32(t.Precision), Scale: int32(t.Scale)} + case 64: + arrowType = &arrow.Decimal64Type{Precision: int32(t.Precision), Scale: int32(t.Scale)} + case 32: + arrowType = &arrow.Decimal32Type{Precision: int32(t.Precision), Scale: int32(t.Scale)} } case "union": t := unionJSON{} @@ -1295,6 +1304,22 @@ func arrayFromJSON(mem memory.Allocator, dt arrow.DataType, arr Array) arrow.Arr bldr.AppendValues(data, valids) return returnNewArrayData(bldr) + case *arrow.Decimal32Type: + bldr := array.NewDecimal32Builder(mem, dt) + defer bldr.Release() + data := decimal32FromJSON(arr.Data) + valids := validsFromJSON(arr.Valids) + bldr.AppendValues(data, valids) + return returnNewArrayData(bldr) + + case *arrow.Decimal64Type: + bldr := array.NewDecimal64Builder(mem, dt) + defer bldr.Release() + data := decimal64FromJSON(arr.Data) + valids := validsFromJSON(arr.Valids) + bldr.AppendValues(data, valids) + return returnNewArrayData(bldr) + case *arrow.Decimal128Type: bldr := array.NewDecimal128Builder(mem, dt) defer bldr.Release() @@ -1713,6 +1738,22 @@ func arrayToJSON(field arrow.Field, arr arrow.Array) Array { Valids: validsToJSON(arr), } + case *array.Decimal32: + return Array{ + Name: field.Name, + Count: arr.Len(), + Data: decimal32ToJSON(arr), + Valids: validsToJSON(arr), + } + + case *array.Decimal64: + return Array{ + Name: field.Name, + Count: arr.Len(), + Data: decimal64ToJSON(arr), + Valids: validsToJSON(arr), + } + case *array.Decimal128: return Array{ Name: field.Name, @@ -2038,6 +2079,47 @@ func f64ToJSON(arr *array.Float64) []interface{} { return o } +func decimal32ToJSON(arr *array.Decimal32) []interface{} { + o := make([]interface{}, arr.Len()) + for i := range o { + o[i] = arr.ValueStr(i) + } + return o +} + +func decimal32FromJSON(vs []interface{}) []decimal.Decimal32 { + var tmp big.Int + o := make([]decimal.Decimal32, len(vs)) + for i, v := range vs { + if err := tmp.UnmarshalJSON([]byte(v.(string))); err != nil { + panic(fmt.Errorf("could not convert %v (%T) to decimal32: %w", v, v, err)) + } + + o[i] = decimal.Decimal32(tmp.Int64()) + } + return o +} + +func decimal64ToJSON(arr *array.Decimal64) []interface{} { + o := make([]interface{}, arr.Len()) + for i := range o { + o[i] = arr.ValueStr(i) + } + return o +} + +func decimal64FromJSON(vs []interface{}) []decimal.Decimal64 { + var tmp big.Int + o := make([]decimal.Decimal64, len(vs)) + for i, v := range vs { + if err := tmp.UnmarshalJSON([]byte(v.(string))); err != nil { + panic(fmt.Errorf("could not convert %v (%T) to decimal64: %w", v, v, err)) + } + + o[i] = decimal.Decimal64(tmp.Int64()) + } + return o +} func decimal128ToJSON(arr *array.Decimal128) []interface{} { o := make([]interface{}, arr.Len()) for i := range o { @@ -2072,7 +2154,7 @@ func decimal256FromJSON(vs []interface{}) []decimal256.Num { o := make([]decimal256.Num, len(vs)) for i, v := range vs { if err := tmp.UnmarshalJSON([]byte(v.(string))); err != nil { - panic(fmt.Errorf("could not convert %v (%T) to decimal128: %w", v, v, err)) + panic(fmt.Errorf("could not convert %v (%T) to decimal256: %w", v, v, err)) } o[i] = decimal256.FromBigInt(&tmp) diff --git a/arrow/internal/flatbuf/Decimal.go b/arrow/internal/flatbuf/Decimal.go index 2fc9d5ad..234c3964 100644 --- a/arrow/internal/flatbuf/Decimal.go +++ b/arrow/internal/flatbuf/Decimal.go @@ -22,10 +22,10 @@ import ( flatbuffers "github.com/google/flatbuffers/go" ) -// / Exact decimal value represented as an integer value in two's -// / complement. Currently only 128-bit (16-byte) and 256-bit (32-byte) integers -// / are used. The representation uses the endianness indicated -// / in the Schema. +/// Exact decimal value represented as an integer value in two's +/// complement. Currently 32-bit (4-byte), 64-bit (8-byte), +/// 128-bit (16-byte) and 256-bit (32-byte) integers are used. +/// The representation uses the endianness indicated in the Schema. type Decimal struct { _tab flatbuffers.Table } @@ -46,7 +46,7 @@ func (rcv *Decimal) Table() flatbuffers.Table { return rcv._tab } -// / Total number of decimal digits +/// Total number of decimal digits func (rcv *Decimal) Precision() int32 { o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) if o != 0 { @@ -55,12 +55,12 @@ func (rcv *Decimal) Precision() int32 { return 0 } -// / Total number of decimal digits +/// Total number of decimal digits func (rcv *Decimal) MutatePrecision(n int32) bool { return rcv._tab.MutateInt32Slot(4, n) } -// / Number of digits after the decimal point "." +/// Number of digits after the decimal point "." func (rcv *Decimal) Scale() int32 { o := flatbuffers.UOffsetT(rcv._tab.Offset(6)) if o != 0 { @@ -69,13 +69,13 @@ func (rcv *Decimal) Scale() int32 { return 0 } -// / Number of digits after the decimal point "." +/// Number of digits after the decimal point "." func (rcv *Decimal) MutateScale(n int32) bool { return rcv._tab.MutateInt32Slot(6, n) } -// / Number of bits per value. The only accepted widths are 128 and 256. -// / We use bitWidth for consistency with Int::bitWidth. +/// Number of bits per value. The accepted widths are 32, 64, 128 and 256. +/// We use bitWidth for consistency with Int::bitWidth. func (rcv *Decimal) BitWidth() int32 { o := flatbuffers.UOffsetT(rcv._tab.Offset(8)) if o != 0 { @@ -84,8 +84,8 @@ func (rcv *Decimal) BitWidth() int32 { return 128 } -// / Number of bits per value. The only accepted widths are 128 and 256. -// / We use bitWidth for consistency with Int::bitWidth. +/// Number of bits per value. The accepted widths are 32, 64, 128 and 256. +/// We use bitWidth for consistency with Int::bitWidth. func (rcv *Decimal) MutateBitWidth(n int32) bool { return rcv._tab.MutateInt32Slot(8, n) } diff --git a/arrow/ipc/file_reader.go b/arrow/ipc/file_reader.go index d027db5e..2715831e 100644 --- a/arrow/ipc/file_reader.go +++ b/arrow/ipc/file_reader.go @@ -476,7 +476,7 @@ func (ctx *arrayLoaderContext) loadArray(dt arrow.DataType) arrow.ArrayData { *arrow.Int8Type, *arrow.Int16Type, *arrow.Int32Type, *arrow.Int64Type, *arrow.Uint8Type, *arrow.Uint16Type, *arrow.Uint32Type, *arrow.Uint64Type, *arrow.Float16Type, *arrow.Float32Type, *arrow.Float64Type, - *arrow.Decimal128Type, *arrow.Decimal256Type, + arrow.DecimalType, *arrow.Time32Type, *arrow.Time64Type, *arrow.TimestampType, *arrow.Date32Type, *arrow.Date64Type, diff --git a/arrow/ipc/metadata.go b/arrow/ipc/metadata.go index 228f271b..a5bf1877 100644 --- a/arrow/ipc/metadata.go +++ b/arrow/ipc/metadata.go @@ -281,20 +281,12 @@ func (fv *fieldVisitor) visit(field arrow.Field) { fv.dtype = flatbuf.TypeFloatingPoint fv.offset = floatToFB(fv.b, int32(dt.BitWidth())) - case *arrow.Decimal128Type: + case arrow.DecimalType: fv.dtype = flatbuf.TypeDecimal flatbuf.DecimalStart(fv.b) - flatbuf.DecimalAddPrecision(fv.b, dt.Precision) - flatbuf.DecimalAddScale(fv.b, dt.Scale) - flatbuf.DecimalAddBitWidth(fv.b, 128) - fv.offset = flatbuf.DecimalEnd(fv.b) - - case *arrow.Decimal256Type: - fv.dtype = flatbuf.TypeDecimal - flatbuf.DecimalStart(fv.b) - flatbuf.DecimalAddPrecision(fv.b, dt.Precision) - flatbuf.DecimalAddScale(fv.b, dt.Scale) - flatbuf.DecimalAddBitWidth(fv.b, 256) + flatbuf.DecimalAddPrecision(fv.b, dt.GetPrecision()) + flatbuf.DecimalAddScale(fv.b, dt.GetScale()) + flatbuf.DecimalAddBitWidth(fv.b, int32(dt.BitWidth())) fv.offset = flatbuf.DecimalEnd(fv.b) case *arrow.FixedSizeBinaryType: @@ -947,6 +939,10 @@ func floatToFB(b *flatbuffers.Builder, bw int32) flatbuffers.UOffsetT { func decimalFromFB(data flatbuf.Decimal) (arrow.DataType, error) { switch data.BitWidth() { + case 32: + return &arrow.Decimal32Type{Precision: data.Precision(), Scale: data.Scale()}, nil + case 64: + return &arrow.Decimal64Type{Precision: data.Precision(), Scale: data.Scale()}, nil case 128: return &arrow.Decimal128Type{Precision: data.Precision(), Scale: data.Scale()}, nil case 256: diff --git a/arrow/type_string.go b/arrow/type_string.go index ee3ccb7e..6e5a943d 100644 --- a/arrow/type_string.go +++ b/arrow/type_string.go @@ -51,11 +51,13 @@ func _() { _ = x[BINARY_VIEW-40] _ = x[LIST_VIEW-41] _ = x[LARGE_LIST_VIEW-42] + _ = x[DECIMAL32-43] + _ = x[DECIMAL64-44] } -const _Type_name = "NULLBOOLUINT8INT8UINT16INT16UINT32INT32UINT64INT64FLOAT16FLOAT32FLOAT64STRINGBINARYFIXED_SIZE_BINARYDATE32DATE64TIMESTAMPTIME32TIME64INTERVAL_MONTHSINTERVAL_DAY_TIMEDECIMAL128DECIMAL256LISTSTRUCTSPARSE_UNIONDENSE_UNIONDICTIONARYMAPEXTENSIONFIXED_SIZE_LISTDURATIONLARGE_STRINGLARGE_BINARYLARGE_LISTINTERVAL_MONTH_DAY_NANORUN_END_ENCODEDSTRING_VIEWBINARY_VIEWLIST_VIEWLARGE_LIST_VIEW" +const _Type_name = "NULLBOOLUINT8INT8UINT16INT16UINT32INT32UINT64INT64FLOAT16FLOAT32FLOAT64STRINGBINARYFIXED_SIZE_BINARYDATE32DATE64TIMESTAMPTIME32TIME64INTERVAL_MONTHSINTERVAL_DAY_TIMEDECIMAL128DECIMAL256LISTSTRUCTSPARSE_UNIONDENSE_UNIONDICTIONARYMAPEXTENSIONFIXED_SIZE_LISTDURATIONLARGE_STRINGLARGE_BINARYLARGE_LISTINTERVAL_MONTH_DAY_NANORUN_END_ENCODEDSTRING_VIEWBINARY_VIEWLIST_VIEWLARGE_LIST_VIEWDECIMAL32DECIMAL64" -var _Type_index = [...]uint16{0, 4, 8, 13, 17, 23, 28, 34, 39, 45, 50, 57, 64, 71, 77, 83, 100, 106, 112, 121, 127, 133, 148, 165, 175, 185, 189, 195, 207, 218, 228, 231, 240, 255, 263, 275, 287, 297, 320, 335, 346, 357, 366, 381} +var _Type_index = [...]uint16{0, 4, 8, 13, 17, 23, 28, 34, 39, 45, 50, 57, 64, 71, 77, 83, 100, 106, 112, 121, 127, 133, 148, 165, 175, 185, 189, 195, 207, 218, 228, 231, 240, 255, 263, 275, 287, 297, 320, 335, 346, 357, 366, 381, 390, 399} func (i Type) String() string { if i < 0 || i >= Type(len(_Type_index)-1) { diff --git a/arrow/type_traits.go b/arrow/type_traits.go index 87e2f065..7185ef25 100644 --- a/arrow/type_traits.go +++ b/arrow/type_traits.go @@ -20,8 +20,7 @@ import ( "reflect" "unsafe" - "github.com/apache/arrow-go/v18/arrow/decimal128" - "github.com/apache/arrow-go/v18/arrow/decimal256" + "github.com/apache/arrow-go/v18/arrow/decimal" "github.com/apache/arrow-go/v18/arrow/float16" "golang.org/x/exp/constraints" ) @@ -68,7 +67,7 @@ type NumericType interface { // as a bitmap and thus the buffer can't be just reinterpreted as a []bool type FixedWidthType interface { IntType | UintType | - FloatType | decimal128.Num | decimal256.Num | + FloatType | decimal.DecimalTypes | DayTimeInterval | MonthDayNanoInterval } diff --git a/arrow/type_traits_decimal128.go b/arrow/type_traits_decimal128.go index 860a7f12..6e416cd6 100644 --- a/arrow/type_traits_decimal128.go +++ b/arrow/type_traits_decimal128.go @@ -19,7 +19,7 @@ package arrow import ( "unsafe" - "github.com/apache/arrow-go/v18/arrow/decimal128" + "github.com/apache/arrow-go/v18/arrow/decimal" "github.com/apache/arrow-go/v18/arrow/endian" ) @@ -28,7 +28,7 @@ var Decimal128Traits decimal128Traits const ( // Decimal128SizeBytes specifies the number of bytes required to store a single decimal128 in memory - Decimal128SizeBytes = int(unsafe.Sizeof(decimal128.Num{})) + Decimal128SizeBytes = int(unsafe.Sizeof(decimal.Decimal128{})) ) type decimal128Traits struct{} @@ -37,7 +37,7 @@ type decimal128Traits struct{} func (decimal128Traits) BytesRequired(n int) int { return Decimal128SizeBytes * n } // PutValue -func (decimal128Traits) PutValue(b []byte, v decimal128.Num) { +func (decimal128Traits) PutValue(b []byte, v decimal.Decimal128) { endian.Native.PutUint64(b[:8], uint64(v.LowBits())) endian.Native.PutUint64(b[8:], uint64(v.HighBits())) } @@ -45,14 +45,14 @@ func (decimal128Traits) PutValue(b []byte, v decimal128.Num) { // CastFromBytes reinterprets the slice b to a slice of type uint16. // // NOTE: len(b) must be a multiple of Uint16SizeBytes. -func (decimal128Traits) CastFromBytes(b []byte) []decimal128.Num { - return GetData[decimal128.Num](b) +func (decimal128Traits) CastFromBytes(b []byte) []decimal.Decimal128 { + return GetData[decimal.Decimal128](b) } // CastToBytes reinterprets the slice b to a slice of bytes. -func (decimal128Traits) CastToBytes(b []decimal128.Num) []byte { +func (decimal128Traits) CastToBytes(b []decimal.Decimal128) []byte { return GetBytes(b) } // Copy copies src to dst. -func (decimal128Traits) Copy(dst, src []decimal128.Num) { copy(dst, src) } +func (decimal128Traits) Copy(dst, src []decimal.Decimal128) { copy(dst, src) } diff --git a/arrow/type_traits_decimal256.go b/arrow/type_traits_decimal256.go index f86bd2a6..b196c2e7 100644 --- a/arrow/type_traits_decimal256.go +++ b/arrow/type_traits_decimal256.go @@ -19,7 +19,7 @@ package arrow import ( "unsafe" - "github.com/apache/arrow-go/v18/arrow/decimal256" + "github.com/apache/arrow-go/v18/arrow/decimal" "github.com/apache/arrow-go/v18/arrow/endian" ) @@ -27,14 +27,14 @@ import ( var Decimal256Traits decimal256Traits const ( - Decimal256SizeBytes = int(unsafe.Sizeof(decimal256.Num{})) + Decimal256SizeBytes = int(unsafe.Sizeof(decimal.Decimal256{})) ) type decimal256Traits struct{} func (decimal256Traits) BytesRequired(n int) int { return Decimal256SizeBytes * n } -func (decimal256Traits) PutValue(b []byte, v decimal256.Num) { +func (decimal256Traits) PutValue(b []byte, v decimal.Decimal256) { for i, a := range v.Array() { start := i * 8 endian.Native.PutUint64(b[start:], a) @@ -42,12 +42,12 @@ func (decimal256Traits) PutValue(b []byte, v decimal256.Num) { } // CastFromBytes reinterprets the slice b to a slice of decimal256 -func (decimal256Traits) CastFromBytes(b []byte) []decimal256.Num { - return GetData[decimal256.Num](b) +func (decimal256Traits) CastFromBytes(b []byte) []decimal.Decimal256 { + return GetData[decimal.Decimal256](b) } -func (decimal256Traits) CastToBytes(b []decimal256.Num) []byte { +func (decimal256Traits) CastToBytes(b []decimal.Decimal256) []byte { return GetBytes(b) } -func (decimal256Traits) Copy(dst, src []decimal256.Num) { copy(dst, src) } +func (decimal256Traits) Copy(dst, src []decimal.Decimal256) { copy(dst, src) } diff --git a/arrow/type_traits_decimal32.go b/arrow/type_traits_decimal32.go new file mode 100644 index 00000000..ebca65f6 --- /dev/null +++ b/arrow/type_traits_decimal32.go @@ -0,0 +1,57 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package arrow + +import ( + "unsafe" + + "github.com/apache/arrow-go/v18/arrow/decimal" + "github.com/apache/arrow-go/v18/arrow/endian" +) + +// Decimal32 traits +var Decimal32Traits decimal32Traits + +const ( + // Decimal32SizeBytes specifies the number of bytes required to store a single decimal32 in memory + Decimal32SizeBytes = int(unsafe.Sizeof(decimal.Decimal32(0))) +) + +type decimal32Traits struct{} + +// BytesRequired returns the number of bytes required to store n elements in memory. +func (decimal32Traits) BytesRequired(n int) int { return Decimal32SizeBytes * n } + +// PutValue +func (decimal32Traits) PutValue(b []byte, v decimal.Decimal32) { + endian.Native.PutUint32(b[:4], uint32(v)) +} + +// CastFromBytes reinterprets the slice b to a slice of type uint16. +// +// NOTE: len(b) must be a multiple of Uint16SizeBytes. +func (decimal32Traits) CastFromBytes(b []byte) []decimal.Decimal32 { + return GetData[decimal.Decimal32](b) +} + +// CastToBytes reinterprets the slice b to a slice of bytes. +func (decimal32Traits) CastToBytes(b []decimal.Decimal32) []byte { + return GetBytes(b) +} + +// Copy copies src to dst. +func (decimal32Traits) Copy(dst, src []decimal.Decimal32) { copy(dst, src) } diff --git a/arrow/type_traits_decimal64.go b/arrow/type_traits_decimal64.go new file mode 100644 index 00000000..bd07883a --- /dev/null +++ b/arrow/type_traits_decimal64.go @@ -0,0 +1,57 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package arrow + +import ( + "unsafe" + + "github.com/apache/arrow-go/v18/arrow/decimal" + "github.com/apache/arrow-go/v18/arrow/endian" +) + +// Decimal64 traits +var Decimal64Traits decimal64Traits + +const ( + // Decimal64SizeBytes specifies the number of bytes required to store a single decimal64 in memory + Decimal64SizeBytes = int(unsafe.Sizeof(decimal.Decimal64(0))) +) + +type decimal64Traits struct{} + +// BytesRequired returns the number of bytes required to store n elements in memory. +func (decimal64Traits) BytesRequired(n int) int { return Decimal64SizeBytes * n } + +// PutValue +func (decimal64Traits) PutValue(b []byte, v decimal.Decimal64) { + endian.Native.PutUint64(b[:8], uint64(v)) +} + +// CastFromBytes reinterprets the slice b to a slice of type uint16. +// +// NOTE: len(b) must be a multiple of Uint16SizeBytes. +func (decimal64Traits) CastFromBytes(b []byte) []decimal.Decimal64 { + return GetData[decimal.Decimal64](b) +} + +// CastToBytes reinterprets the slice b to a slice of bytes. +func (decimal64Traits) CastToBytes(b []decimal.Decimal64) []byte { + return GetBytes(b) +} + +// Copy copies src to dst. +func (decimal64Traits) Copy(dst, src []decimal.Decimal64) { copy(dst, src) } diff --git a/arrow/type_traits_test.go b/arrow/type_traits_test.go index d86a67b1..93d98b95 100644 --- a/arrow/type_traits_test.go +++ b/arrow/type_traits_test.go @@ -23,8 +23,10 @@ import ( "testing" "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/decimal" "github.com/apache/arrow-go/v18/arrow/decimal128" "github.com/apache/arrow-go/v18/arrow/decimal256" + "github.com/apache/arrow-go/v18/arrow/float16" ) @@ -90,6 +92,94 @@ func TestFloat16Traits(t *testing.T) { } } +func TestDecimal32Traits(t *testing.T) { + const N = 10 + nbytes := arrow.Decimal32Traits.BytesRequired(N) + b1 := arrow.Decimal32Traits.CastToBytes([]decimal.Decimal32{ + decimal.Decimal32(0), + decimal.Decimal32(1), + decimal.Decimal32(2), + decimal.Decimal32(3), + decimal.Decimal32(4), + decimal.Decimal32(5), + decimal.Decimal32(6), + decimal.Decimal32(7), + decimal.Decimal32(8), + decimal.Decimal32(9), + }) + + b2 := make([]byte, nbytes) + for i := 0; i < N; i++ { + beg := i * arrow.Decimal32SizeBytes + end := (i + 1) * arrow.Decimal32SizeBytes + arrow.Decimal32Traits.PutValue(b2[beg:end], decimal.Decimal32(i)) + } + + if !reflect.DeepEqual(b1, b2) { + v1 := arrow.Decimal32Traits.CastFromBytes(b1) + v2 := arrow.Decimal32Traits.CastFromBytes(b2) + t.Fatalf("invalid values:\nb1=%v\nb2=%v\nv1=%v\nv2=%v\n", b1, b2, v1, v2) + } + + v1 := arrow.Decimal32Traits.CastFromBytes(b1) + for i, v := range v1 { + if got, want := v, decimal.Decimal32(i); got != want { + t.Fatalf("invalid value[%d]. got=%v, want=%v", i, got, want) + } + } + + v2 := make([]decimal.Decimal32, N) + arrow.Decimal32Traits.Copy(v2, v1) + + if !reflect.DeepEqual(v1, v2) { + t.Fatalf("invalid values:\nv1=%v\nv2=%v\n", v1, v2) + } +} + +func TestDecimal64Traits(t *testing.T) { + const N = 10 + nbytes := arrow.Decimal64Traits.BytesRequired(N) + b1 := arrow.Decimal64Traits.CastToBytes([]decimal.Decimal64{ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + }) + + b2 := make([]byte, nbytes) + for i := 0; i < N; i++ { + beg := i * arrow.Decimal64SizeBytes + end := (i + 1) * arrow.Decimal64SizeBytes + arrow.Decimal64Traits.PutValue(b2[beg:end], decimal.Decimal64(i)) + } + + if !reflect.DeepEqual(b1, b2) { + v1 := arrow.Decimal64Traits.CastFromBytes(b1) + v2 := arrow.Decimal64Traits.CastFromBytes(b2) + t.Fatalf("invalid values:\nb1=%v\nb2=%v\nv1=%v\nv2=%v\n", b1, b2, v1, v2) + } + + v1 := arrow.Decimal64Traits.CastFromBytes(b1) + for i, v := range v1 { + if got, want := v, decimal.Decimal64(i); got != want { + t.Fatalf("invalid value[%d]. got=%v, want=%v", i, got, want) + } + } + + v2 := make([]decimal.Decimal64, N) + arrow.Decimal64Traits.Copy(v2, v1) + + if !reflect.DeepEqual(v1, v2) { + t.Fatalf("invalid values:\nv1=%v\nv2=%v\n", v1, v2) + } +} + func TestDecimal128Traits(t *testing.T) { const N = 10 nbytes := arrow.Decimal128Traits.BytesRequired(N)