diff --git a/abi.go b/abi.go new file mode 100644 index 00000000..12e4b52b --- /dev/null +++ b/abi.go @@ -0,0 +1,433 @@ +/*- + * + * Hedera Go SDK + * + * Copyright (C) 2020 - 2024 Hedera Hashgraph, LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use q file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package hedera + +import ( + "bytes" + "encoding/json" + "fmt" + "hash" + "io" + "regexp" + "strings" + "sync" + + "golang.org/x/crypto/sha3" +) + +// ABI represents the ethereum abi format +type ABI struct { + Constructor *Method + Methods map[string]*Method + MethodsBySignature map[string]*Method + Events map[string]*Event + Errors map[string]*Error +} + +func (a *ABI) GetMethod(name string) *Method { + m := a.Methods[name] + return m +} + +func (a *ABI) GetMethodBySignature(methodSignature string) *Method { + m := a.MethodsBySignature[methodSignature] + return m +} + +func (a *ABI) addError(e *Error) { + if len(a.Errors) == 0 { + a.Errors = map[string]*Error{} + } + a.Errors[e.Name] = e +} + +func (a *ABI) addEvent(e *Event) { + if len(a.Events) == 0 { + a.Events = map[string]*Event{} + } + name := overloadedName(e.Name, func(s string) bool { + _, ok := a.Events[s] + return ok + }) + a.Events[name] = e +} + +func (a *ABI) addMethod(m *Method) { + if len(a.Methods) == 0 { + a.Methods = map[string]*Method{} + } + if len(a.MethodsBySignature) == 0 { + a.MethodsBySignature = map[string]*Method{} + } + name := overloadedName(m.Name, func(s string) bool { + _, ok := a.Methods[s] + return ok + }) + a.Methods[name] = m + a.MethodsBySignature[m.Sig()] = m +} + +func overloadedName(rawName string, isAvail func(string) bool) string { + name := rawName + ok := isAvail(name) + for idx := 0; ok; idx++ { + name = fmt.Sprintf("%s%d", rawName, idx) + ok = isAvail(name) + } + return name +} + +// NewABI returns a parsed ABI struct +func NewABI(s string) (*ABI, error) { + return NewABIFromReader(bytes.NewReader([]byte(s))) +} + +// NewABIFromReader returns an ABI object from a reader +func NewABIFromReader(r io.Reader) (*ABI, error) { + var abi *ABI + dec := json.NewDecoder(r) + if err := dec.Decode(&abi); err != nil { + return nil, err + } + return abi, nil +} + +// UnmarshalJSON implements json.Unmarshaler interface +// nolint +func (a *ABI) UnmarshalJSON(data []byte) error { + var fields []struct { + Type string + Name string + Anonymous bool + StateMutability string + Inputs []*ArgumentStr + Outputs []*ArgumentStr + } + + if err := json.Unmarshal(data, &fields); err != nil { + return err + } + + for _, field := range fields { + switch field.Type { + case "constructor": + if a.Constructor != nil { + return fmt.Errorf("multiple constructor declaration") + } + input, err := NewTupleTypeFromArgs(field.Inputs) + if err != nil { + return err + } + a.Constructor = &Method{ + Inputs: input, + } + + case "function", "": + c := field.StateMutability == "view" || field.StateMutability == "pure" + + inputs, err := NewTupleTypeFromArgs(field.Inputs) + if err != nil { + return err + } + outputs, err := NewTupleTypeFromArgs(field.Outputs) + if err != nil { + return err + } + method := &Method{ + Name: field.Name, + Const: c, + Inputs: inputs, + Outputs: outputs, + } + a.addMethod(method) + + case "event": + input, err := NewTupleTypeFromArgs(field.Inputs) + if err != nil { + return err + } + event := &Event{ + Name: field.Name, + Anonymous: field.Anonymous, + Inputs: input, + } + a.addEvent(event) + + case "error": + input, err := NewTupleTypeFromArgs(field.Inputs) + if err != nil { + return err + } + errObj := &Error{ + Name: field.Name, + Inputs: input, + } + a.addError(errObj) + + case "fallback": + case "receive": + // do nothing + + default: + return fmt.Errorf("unknown field type '%s'", field.Type) + } + } + return nil +} + +// nolint +func NewABIFromList(humanReadableAbi []string) (*ABI, error) { + res := &ABI{} + for _, c := range humanReadableAbi { + if strings.HasPrefix(c, "constructor") { + typ, err := NewType("tuple" + strings.TrimPrefix(c, "constructor")) + if err != nil { + return nil, err + } + res.Constructor = &Method{ + Inputs: typ, + } + } else if strings.HasPrefix(c, "function ") { + method, err := NewMethod(c) + if err != nil { + return nil, err + } + res.addMethod(method) + } else if strings.HasPrefix(c, "event ") { + evnt, err := NewEvent(c) + if err != nil { + return nil, err + } + res.addEvent(evnt) + } else if strings.HasPrefix(c, "error ") { + errTyp, err := NewError(c) + if err != nil { + return nil, err + } + res.addError(errTyp) + } else { + return nil, fmt.Errorf("either event or function expected") + } + } + return res, nil +} + +// Method is a callable function in the contract +type Method struct { + Name string + Const bool + Inputs *Type + Outputs *Type +} + +// Sig returns the signature of the method +func (m *Method) Sig() string { + return buildSignature(m.Name, m.Inputs) +} + +// ID returns the id of the method +func (m *Method) ID() []byte { + k := acquireKeccak() + k.Write([]byte(m.Sig())) + dst := k.Sum(nil)[:4] + releaseKeccak(k) + return dst +} + +// Encode encodes the inputs with this function +func (m *Method) Encode(args interface{}) ([]byte, error) { + data, err := Encode(args, m.Inputs) + if err != nil { + return nil, err + } + data = append(m.ID(), data...) + return data, nil +} + +// Decode decodes the output with this function +func (m *Method) Decode(data []byte) (map[string]interface{}, error) { + if len(data) == 0 { + return nil, fmt.Errorf("empty response") + } + respInterface, err := Decode(m.Outputs, data) + if err != nil { + return nil, err + } + resp := respInterface.(map[string]interface{}) + return resp, nil +} + +func NewMethod(name string) (*Method, error) { + name, inputs, outputs, err := parseMethodSignature(name) + if err != nil { + return nil, err + } + m := &Method{Name: name, Inputs: inputs, Outputs: outputs} + return m, nil +} + +var ( + funcRegexpWithReturn = regexp.MustCompile(`(\w*)\s*\((.*)\)(.*)\s*returns\s*\((.*)\)`) + funcRegexpWithoutReturn = regexp.MustCompile(`(\w*)\s*\((.*)\)(.*)`) +) + +// Event is a triggered log mechanism +type Event struct { + Name string + Anonymous bool + Inputs *Type +} + +// NewEvent creates a new solidity event object using the signature +func NewEvent(name string) (*Event, error) { + name, typ, err := parseEventOrErrorSignature("event ", name) + if err != nil { + return nil, err + } + return NewEventFromType(name, typ), nil +} + +// NewEventFromType creates a new solidity event object using the name and type +func NewEventFromType(name string, typ *Type) *Event { + return &Event{Name: name, Inputs: typ} +} + +// Error is a solidity error object +type Error struct { + Name string + Inputs *Type +} + +// NewError creates a new solidity error object +func NewError(name string) (*Error, error) { + name, typ, err := parseEventOrErrorSignature("error ", name) + if err != nil { + return nil, err + } + return &Error{Name: name, Inputs: typ}, nil +} + +// ArgumentStr encodes a type object +type ArgumentStr struct { + Name string + Type string + Indexed bool + Components []*ArgumentStr + InternalType string +} + +var keccakPool = sync.Pool{ + New: func() interface{} { + return sha3.NewLegacyKeccak256() + }, +} + +func acquireKeccak() hash.Hash { + return keccakPool.Get().(hash.Hash) +} + +func releaseKeccak(k hash.Hash) { + k.Reset() + keccakPool.Put(k) +} + +type Log struct { + Removed bool + LogIndex uint64 + TransactionIndex uint64 + TransactionHash Hash + BlockHash Hash + BlockNumber uint64 + Address Address + Topics []Hash + Data []byte +} + +// nolint +func parseMethodSignature(name string) (string, *Type, *Type, error) { + name = strings.Replace(name, "\n", " ", -1) + name = strings.Replace(name, "\t", " ", -1) + + name = strings.TrimPrefix(name, "function ") + name = strings.TrimSpace(name) + + var funcName, inputArgs, outputArgs string + + if strings.Contains(name, "returns") { + matches := funcRegexpWithReturn.FindAllStringSubmatch(name, -1) + if len(matches) == 0 { + return "", nil, nil, fmt.Errorf("no matches found") + } + funcName = strings.TrimSpace(matches[0][1]) + inputArgs = strings.TrimSpace(matches[0][2]) + outputArgs = strings.TrimSpace(matches[0][4]) + } else { + matches := funcRegexpWithoutReturn.FindAllStringSubmatch(name, -1) + if len(matches) == 0 { + return "", nil, nil, fmt.Errorf("no matches found") + } + funcName = strings.TrimSpace(matches[0][1]) + inputArgs = strings.TrimSpace(matches[0][2]) + } + + input, err := NewType("tuple(" + inputArgs + ")") + if err != nil { + return "", nil, nil, err + } + output, err := NewType("tuple(" + outputArgs + ")") + if err != nil { + return "", nil, nil, err + } + return funcName, input, output, nil +} + +func buildSignature(name string, typ *Type) string { + types := make([]string, len(typ.tuple)) + for i, input := range typ.tuple { + // nolint + types[i] = strings.Replace(input.Elem.String(), "tuple", "", -1) + } + return fmt.Sprintf("%v(%v)", name, strings.Join(types, ",")) +} + +func parseEventOrErrorSignature(prefix string, name string) (string, *Type, error) { + if !strings.HasPrefix(name, prefix) { + return "", nil, fmt.Errorf("prefix '%s' not found", prefix) + } + name = strings.TrimPrefix(name, prefix) + + if !strings.HasSuffix(name, ")") { + return "", nil, fmt.Errorf("failed to parse input, expected 'name(types)'") + } + indx := strings.Index(name, "(") + if indx == -1 { + return "", nil, fmt.Errorf("failed to parse input, expected 'name(types)'") + } + + funcName, signature := name[:indx], name[indx:] + signature = "tuple" + signature + + typ, err := NewType(signature) + if err != nil { + return "", nil, err + } + return funcName, typ, nil +} diff --git a/abi_decode.go b/abi_decode.go new file mode 100644 index 00000000..0b66b82f --- /dev/null +++ b/abi_decode.go @@ -0,0 +1,422 @@ +/*- + * + * Hedera Go SDK + * + * Copyright (C) 2020 - 2024 Hedera Hashgraph, LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use q file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package hedera + +import ( + "encoding/binary" + "encoding/hex" + "fmt" + "math/big" + "reflect" + "strconv" + "strings" + + "github.com/mitchellh/mapstructure" +) + +// Decode decodes the input with a given type +func Decode(t *Type, input []byte) (interface{}, error) { + if len(input) == 0 { + return nil, fmt.Errorf("empty input") + } + val, _, err := decode(t, input) + return val, err +} + +// DecodeStruct decodes the input with a type to a struct +func DecodeStruct(t *Type, input []byte, out interface{}) error { + val, err := Decode(t, input) + if err != nil { + return err + } + + dc := &mapstructure.DecoderConfig{ + Result: out, + WeaklyTypedInput: true, + TagName: "abi", + } + ms, err := mapstructure.NewDecoder(dc) + if err != nil { + return err + } + if err = ms.Decode(val); err != nil { + return err + } + return nil +} + +func decode(t *Type, input []byte) (interface{}, []byte, error) { + var data []byte + var length int + var err error + + // safe check, input should be at least 32 bytes + if len(input) < 32 { + return nil, nil, fmt.Errorf("incorrect length") + } + + if t.isVariableInput() { + length, err = readLength(input) + if err != nil { + return nil, nil, err + } + } else { + data = input[:32] + } + + switch t.kind { + case KindTuple: + return decodeTuple(t, input) + + case KindSlice: + return decodeArraySlice(t, input[32:], length) + + case KindArray: + return decodeArraySlice(t, input, t.size) + } + + var val interface{} + switch t.kind { + case KindBool: + val, err = decodeBool(data) + + case KindInt, KindUInt: + val = readInteger(t, data) + + case KindString: + val = string(input[32 : 32+length]) + + case KindBytes: + val = input[32 : 32+length] + + case KindAddress: + val, err = readAddr(data) + + case KindFixedBytes: + val, err = readFixedBytes(t, data) + + case KindFunction: + val, err = readFunctionType(t, data) + + default: + return nil, nil, fmt.Errorf("decoding not available for type '%s'", t.kind) + } + + return val, input[32:], err +} + +var ( + maxUint256 = big.NewInt(0).Add( + big.NewInt(0).Exp(big.NewInt(2), big.NewInt(256), nil), + big.NewInt(-1)) + maxInt256 = big.NewInt(0).Add( + big.NewInt(0).Exp(big.NewInt(2), big.NewInt(255), nil), + big.NewInt(-1)) +) + +// Address is an Ethereum address +type Address [20]byte + +func min(i, j int) int { + if i < j { + return i + } + return j +} + +// BytesToAddress converts bytes to an address object +func BytesToAddress(b []byte) Address { + var a Address + + size := len(b) + min := min(size, 20) + + copy(a[20-min:], b[len(b)-min:]) + return a +} + +// Address implements the ethgo.Key interface Address method. +func (a Address) Address() Address { + return a +} + +// UnmarshalText implements the unmarshal interface +func (a *Address) UnmarshalText(b []byte) error { + return unmarshalTextByte(a[:], b, 20) +} + +// MarshalText implements the marshal interface +func (a Address) MarshalText() ([]byte, error) { + return []byte(a.String()), nil +} + +// Bytes returns the bytes of the Address +func (a Address) Bytes() []byte { + return a[:] +} + +func (a Address) String() string { + return a.checksumEncode() +} + +func unmarshalTextByte(dst, src []byte, size int) error { + str := string(src) + + str = strings.Trim(str, "\"") + if !strings.HasPrefix(str, "0x") { + return fmt.Errorf("0x prefix not found") + } + str = str[2:] + b, err := hex.DecodeString(str) + if err != nil { + return err + } + if len(b) != size { + return fmt.Errorf("length %d is not correct, expected %d", len(b), size) + } + copy(dst, b) + return nil +} + +func (a Address) checksumEncode() string { + address := strings.ToLower(hex.EncodeToString(a[:])) + + hash := hex.EncodeToString(Keccak256Hash([]byte(address)).Bytes()) + + ret := "0x" + for i := 0; i < len(address); i++ { + character := string(address[i]) + + num, _ := strconv.ParseInt(string(hash[i]), 16, 64) + if num > 7 { + ret += strings.ToUpper(character) + } else { + ret += character + } + } + + return ret +} + +func readAddr(b []byte) (Address, error) { + res := Address{} + if len(b) != 32 { + return res, fmt.Errorf("len is not correct") + } + copy(res[:], b[12:]) + return res, nil +} + +func readInteger(t *Type, b []byte) interface{} { + switch t.t.Kind() { + case reflect.Uint8: + return b[len(b)-1] + + case reflect.Uint16: + return binary.BigEndian.Uint16(b[len(b)-2:]) + + case reflect.Uint32: + return binary.BigEndian.Uint32(b[len(b)-4:]) + + case reflect.Uint64: + return binary.BigEndian.Uint64(b[len(b)-8:]) + + case reflect.Int8: + return int8(b[len(b)-1]) + + case reflect.Int16: + return int16(binary.BigEndian.Uint16(b[len(b)-2:])) + + case reflect.Int32: + return int32(binary.BigEndian.Uint32(b[len(b)-4:])) + + case reflect.Int64: + return int64(binary.BigEndian.Uint64(b[len(b)-8:])) + + default: + ret := new(big.Int).SetBytes(b) + if t.kind == KindUInt { + return ret + } + + if ret.Cmp(maxInt256) > 0 { + ret.Add(maxUint256, big.NewInt(0).Neg(ret)) + ret.Add(ret, big.NewInt(1)) + ret.Neg(ret) + } + return ret + } +} + +// nolint +func readFunctionType(t *Type, word []byte) ([24]byte, error) { + res := [24]byte{} + if !allZeros(word[24:32]) { + return res, fmt.Errorf("function type expects the last 8 bytes to be empty but found: %b", word[24:32]) + } + copy(res[:], word[0:24]) + return res, nil +} + +// nolint +func readFixedBytes(t *Type, word []byte) (interface{}, error) { + array := reflect.New(t.t).Elem() + reflect.Copy(array, reflect.ValueOf(word[0:t.size])) + return array.Interface(), nil +} + +func decodeTuple(t *Type, data []byte) (interface{}, []byte, error) { + res := make(map[string]interface{}) + + orig := data + origLen := len(orig) + for indx, arg := range t.tuple { + if len(data) < 32 { + return nil, nil, fmt.Errorf("incorrect length") + } + + entry := data + if arg.Elem.isDynamicType() { + offset, err := readOffset(data, origLen) + if err != nil { + return nil, nil, err + } + entry = orig[offset:] + } + + val, tail, err := decode(arg.Elem, entry) + if err != nil { + return nil, nil, err + } + + if !arg.Elem.isDynamicType() { + data = tail + } else { + data = data[32:] + } + + name := arg.Name + if name == "" { + name = strconv.Itoa(indx) + } + if _, ok := res[name]; !ok { + res[name] = val + } else { + return nil, nil, fmt.Errorf("tuple with repeated values") + } + } + return res, data, nil +} + +func decodeArraySlice(t *Type, data []byte, size int) (interface{}, []byte, error) { + if size < 0 { + return nil, nil, fmt.Errorf("size is lower than zero") + } + if 32*size > len(data) { + return nil, nil, fmt.Errorf("size is too big") + } + + var res reflect.Value + if t.kind == KindSlice { + res = reflect.MakeSlice(t.t, size, size) + } else if t.kind == KindArray { + res = reflect.New(t.t).Elem() + } + + orig := data + origLen := len(orig) + for indx := 0; indx < size; indx++ { + isDynamic := t.elem.isDynamicType() + + if len(data) < 32 { + return nil, nil, fmt.Errorf("incorrect length") + } + + entry := data + if isDynamic { + offset, err := readOffset(data, origLen) + if err != nil { + return nil, nil, err + } + entry = orig[offset:] + } + + val, tail, err := decode(t.elem, entry) + if err != nil { + return nil, nil, err + } + + if !isDynamic { + data = tail + } else { + data = data[32:] + } + res.Index(indx).Set(reflect.ValueOf(val)) + } + return res.Interface(), data, nil +} + +func decodeBool(data []byte) (interface{}, error) { + switch data[31] { + case 0: + return false, nil + case 1: + return true, nil + default: + return false, fmt.Errorf("bad boolean") + } +} + +func readOffset(data []byte, len int) (int, error) { + offsetBig := big.NewInt(0).SetBytes(data[0:32]) + if offsetBig.BitLen() > 63 { + return 0, fmt.Errorf("offset larger than int64: %v", offsetBig.Int64()) + } + offset := int(offsetBig.Int64()) + if offset > len { + return 0, fmt.Errorf("offset insufficient %v require %v", len, offset) + } + return offset, nil +} + +func readLength(data []byte) (int, error) { + lengthBig := big.NewInt(0).SetBytes(data[0:32]) + if lengthBig.BitLen() > 63 { + return 0, fmt.Errorf("length larger than int64: %v", lengthBig.Int64()) + } + length := int(lengthBig.Uint64()) + + // if we trim the length in the data there should be enough + // bytes to cover the length + if length > len(data)-32 { + return 0, fmt.Errorf("length insufficient %v require %v", len(data), length) + } + return length, nil +} + +func allZeros(b []byte) bool { + for _, i := range b { + if i != 0 { + return false + } + } + return true +} diff --git a/abi_decode_unit_test.go b/abi_decode_unit_test.go new file mode 100644 index 00000000..4beae8ae --- /dev/null +++ b/abi_decode_unit_test.go @@ -0,0 +1,43 @@ +//go:build all || unit +// +build all unit + +package hedera + +/*- + * + * Hedera Go SDK + * + * Copyright (C) 2020 - 2024 Hedera Hashgraph, LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDecodeBytesBound(t *testing.T) { + typ, _ := NewType("tuple(string)") + decodeTuple(typ, nil) // it should not panic +} + +func TestDecodeDynamicLengthOutOfBounds(t *testing.T) { + input := []byte("00000000000000000000000000000000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00 00000000000000000000000000") + typ, _ := NewType("tuple(bytes32, bytes, bytes)") + + _, err := Decode(typ, input) + require.Error(t, err) +} diff --git a/abi_encode.go b/abi_encode.go new file mode 100644 index 00000000..60bcf893 --- /dev/null +++ b/abi_encode.go @@ -0,0 +1,377 @@ +/*- + * + * Hedera Go SDK + * + * Copyright (C) 2020 - 2024 Hedera Hashgraph, LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use q file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package hedera + +import ( + "encoding/hex" + "fmt" + "math/big" + "reflect" + "strconv" + "strings" +) + +var ( + zero = big.NewInt(0) + one = big.NewInt(1) +) + +// Encode encodes a value +func Encode(v interface{}, t *Type) ([]byte, error) { + return encode(reflect.ValueOf(v), t) +} + +func encode(v reflect.Value, t *Type) ([]byte, error) { + if v.Kind() == reflect.Interface { + v = v.Elem() + } + + switch t.kind { + case KindSlice, KindArray: + return encodeSliceAndArray(v, t) + + case KindTuple: + return encodeTuple(v, t) + + case KindString: + return encodeString(v) + + case KindBool: + return encodeBool(v) + + case KindAddress: + return encodeAddress(v) + + case KindInt, KindUInt: + return encodeNum(v) + + case KindBytes: + return encodeBytes(v) + + case KindFixedBytes, KindFunction: + return encodeFixedBytes(v) + + default: + return nil, fmt.Errorf("encoding not available for type '%s'", t.kind) + } +} + +func encodeSliceAndArray(v reflect.Value, t *Type) ([]byte, error) { + if v.Kind() != reflect.Array && v.Kind() != reflect.Slice { + return nil, encodeErr(v, t.kind.String()) + } + + if v.Kind() == reflect.Array && t.kind != KindArray { + return nil, fmt.Errorf("expected array") + } else if v.Kind() == reflect.Slice && t.kind != KindSlice { + return nil, fmt.Errorf("expected slice") + } + + if t.kind == KindArray && t.size != v.Len() { + return nil, fmt.Errorf("array len incompatible") + } + + var ret, tail []byte + if t.isVariableInput() { + ret = append(ret, packNum(v.Len())...) + } + + offset := 0 + isDynamic := t.elem.isDynamicType() + if isDynamic { + offset = getTypeSize(t.elem) * v.Len() + } + + for i := 0; i < v.Len(); i++ { + val, err := encode(v.Index(i), t.elem) + if err != nil { + return nil, err + } + if !isDynamic { + ret = append(ret, val...) + } else { + ret = append(ret, packNum(offset)...) + offset += len(val) + tail = append(tail, val...) + } + } + return append(ret, tail...), nil +} + +func encodeTuple(v reflect.Value, t *Type) ([]byte, error) { + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + + var err error + isList := true + + switch v.Kind() { + case reflect.Slice, reflect.Array: + case reflect.Map: + isList = false + + case reflect.Struct: + isList = false + v, err = mapFromStruct(v) + if err != nil { + return nil, err + } + + default: + return nil, encodeErr(v, "tuple") + } + + if v.Len() < len(t.tuple) { + return nil, fmt.Errorf("expected at least the same length") + } + + offset := 0 + for _, elem := range t.tuple { + offset += getTypeSize(elem.Elem) + } + + var ret, tail []byte + var aux reflect.Value + + for i, elem := range t.tuple { + if isList { + aux = v.Index(i) + } else { + name := elem.Name + if name == "" { + name = strconv.Itoa(i) + } + aux = v.MapIndex(reflect.ValueOf(name)) + } + if aux.Kind() == reflect.Invalid { + return nil, fmt.Errorf("cannot get key %s", elem.Name) + } + + val, err := encode(aux, elem.Elem) + if err != nil { + return nil, err + } + if elem.Elem.isDynamicType() { + ret = append(ret, packNum(offset)...) + tail = append(tail, val...) + offset += len(val) + } else { + ret = append(ret, val...) + } + } + + return append(ret, tail...), nil +} + +func convertArrayToBytes(value reflect.Value) reflect.Value { + slice := reflect.MakeSlice(reflect.TypeOf([]byte{}), value.Len(), value.Len()) + reflect.Copy(slice, value) + return slice +} + +func encodeFixedBytes(v reflect.Value) ([]byte, error) { + if v.Kind() == reflect.Array { + v = convertArrayToBytes(v) + } + if v.Kind() == reflect.String { + value, err := decodeHex(v.String()) + if err != nil { + return nil, err + } + + v = reflect.ValueOf(value) + } + return rightPad(v.Bytes(), 32), nil +} + +func encodeAddress(v reflect.Value) ([]byte, error) { + if v.Kind() == reflect.Array { + v = convertArrayToBytes(v) + } + if v.Kind() == reflect.String { + var addr Address + if err := addr.UnmarshalText([]byte(v.String())); err != nil { + return nil, err + } + v = reflect.ValueOf(addr.Bytes()) + } + return leftPad(v.Bytes(), 32), nil +} + +func encodeBytes(v reflect.Value) ([]byte, error) { + if v.Kind() == reflect.Array { + v = convertArrayToBytes(v) + } + if v.Kind() == reflect.String { + value, err := decodeHex(v.String()) + if err != nil { + return nil, err + } + + v = reflect.ValueOf(value) + } + return packBytesSlice(v.Bytes(), v.Len()) +} + +func encodeString(v reflect.Value) ([]byte, error) { + if v.Kind() != reflect.String { + return nil, encodeErr(v, "string") + } + return packBytesSlice([]byte(v.String()), v.Len()) +} + +func packBytesSlice(buf []byte, l int) ([]byte, error) { + len, err := encodeNum(reflect.ValueOf(l)) + if err != nil { + return nil, err + } + return append(len, rightPad(buf, (l+31)/32*32)...), nil +} + +func packNum(offset int) []byte { + n, _ := encodeNum(reflect.ValueOf(offset)) + return n +} + +func encodeNum(v reflect.Value) ([]byte, error) { + switch v.Kind() { + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return toU256(new(big.Int).SetUint64(v.Uint())), nil + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return toU256(big.NewInt(v.Int())), nil + + case reflect.Ptr: + if v.Type() != bigIntT { + return nil, encodeErr(v.Elem(), "number") + } + return toU256(v.Interface().(*big.Int)), nil + + case reflect.Float64: + return encodeNum(reflect.ValueOf(int64(v.Float()))) + + case reflect.String: + n, ok := new(big.Int).SetString(v.String(), 10) + if !ok { + n, ok = new(big.Int).SetString(v.String()[2:], 16) + if !ok { + return nil, encodeErr(v, "number") + } + } + return encodeNum(reflect.ValueOf(n)) + + default: + return nil, encodeErr(v, "number") + } +} + +func encodeBool(v reflect.Value) ([]byte, error) { + if v.Kind() != reflect.Bool { + return nil, encodeErr(v, "bool") + } + if v.Bool() { + return leftPad(one.Bytes(), 32), nil + } + return leftPad(zero.Bytes(), 32), nil +} + +func encodeErr(v reflect.Value, t string) error { + return fmt.Errorf("failed to encode %s as %s", v.Kind().String(), t) +} + +// nolint +func mapFromStruct(v reflect.Value) (reflect.Value, error) { + res := map[string]interface{}{} + typ := v.Type() + for i := 0; i < v.NumField(); i++ { + f := typ.Field(i) + if f.PkgPath != "" { + continue + } + + tagValue := f.Tag.Get("abi") + if tagValue == "-" { + continue + } + + name := strings.ToLower(f.Name) + if tagValue != "" { + name = tagValue + } + if _, ok := res[name]; !ok { + res[name] = v.Field(i).Interface() + } + } + return reflect.ValueOf(res), nil +} + +var ( + tt256 = new(big.Int).Lsh(big.NewInt(1), 256) // 2 ** 256 + tt256m1 = new(big.Int).Sub(tt256, big.NewInt(1)) // 2 ** 256 - 1 +) + +// U256 converts a big Int into a 256bit EVM number. +func toU256(n *big.Int) []byte { + b := new(big.Int) + b = b.Set(n) + + if b.Sign() < 0 || b.BitLen() > 256 { + b.And(b, tt256m1) + } + + return leftPad(b.Bytes(), 32) +} + +func padBytes(b []byte, size int, left bool) []byte { + l := len(b) + if l == size { + return b + } + if l > size { + return b[l-size:] + } + tmp := make([]byte, size) + if left { + copy(tmp[size-l:], b) + } else { + copy(tmp, b) + } + return tmp +} + +// nolint +func leftPad(b []byte, size int) []byte { + return padBytes(b, size, true) +} + +func rightPad(b []byte, size int) []byte { + return padBytes(b, size, false) +} + +func decodeHex(str string) ([]byte, error) { + str = strings.TrimPrefix(str, "0x") + buf, err := hex.DecodeString(str) + if err != nil { + return nil, fmt.Errorf("could not decode hex: %v", err) + } + return buf, nil +} diff --git a/abi_encode_unit_test.go b/abi_encode_unit_test.go new file mode 100644 index 00000000..3ac062ff --- /dev/null +++ b/abi_encode_unit_test.go @@ -0,0 +1,413 @@ +//go:build all || unit +// +build all unit + +package hedera + +/*- + * + * Hedera Go SDK + * + * Copyright (C) 2020 - 2024 Hedera Hashgraph, LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +import ( + "fmt" + "math/big" + "reflect" + "testing" +) + +func mustDecodeHex(str string) []byte { + buf, err := decodeHex(str) + if err != nil { + panic(fmt.Errorf("could not decode hex: %v", err)) + } + return buf +} + +func TestEncoding(t *testing.T) { + cases := []struct { + Type string + Input interface{} + }{ + { + "uint40", + big.NewInt(50), + }, + { + "int256", + big.NewInt(2), + }, + { + "int256[]", + []*big.Int{big.NewInt(1), big.NewInt(2)}, + }, + { + "int256", + big.NewInt(-10), + }, + { + "bytes5", + [5]byte{0x1, 0x2, 0x3, 0x4, 0x5}, + }, + { + "bytes", + mustDecodeHex("0x12345678911121314151617181920211"), + }, + { + "string", + "foobar", + }, + { + "uint8[][2]", + [2][]uint8{{1}, {1}}, + }, + { + "address[]", + []Address{{1}, {2}}, + }, + { + "bytes10[]", + [][10]byte{ + {0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0x10}, + {0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0x10}, + }, + }, + { + "bytes[]", + [][]byte{ + mustDecodeHex("0x11"), + mustDecodeHex("0x22"), + }, + }, + { + "uint32[2][3][4]", + [4][3][2]uint32{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}, {{13, 14}, {15, 16}, {17, 18}}, {{19, 20}, {21, 22}, {23, 24}}}, + }, + { + "uint8[]", + []uint8{1, 2}, + }, + { + "string[]", + []string{"hello", "foobar"}, + }, + { + "string[2]", + [2]string{"hello", "foobar"}, + }, + { + "bytes32[][]", + [][][32]uint8{{{1}, {2}}, {{3}, {4}, {5}}}, + }, + { + "bytes32[][2]", + [2][][32]uint8{{{1}, {2}}, {{3}, {4}, {5}}}, + }, + { + "bytes32[3][2]", + [2][3][32]uint8{{{1}, {2}, {3}}, {{3}, {4}, {5}}}, + }, + { + "uint16[][2][]", + [][2][]uint16{ + {{0, 1}, {2, 3}}, + {{4, 5}, {6, 7}}, + }, + }, + { + "tuple(bytes[] a)", + map[string]interface{}{ + "a": [][]byte{{0xf0, 0xf0, 0xf0}, {0xf0, 0xf0, 0xf0}}, + }, + }, + { + "tuple(uint32[2][][] a)", + // `[{"type": "uint32[2][][]"}]`, + map[string]interface{}{ + "a": [][][2]uint32{{{uint32(1), uint32(200)}, {uint32(1), uint32(1000)}}, {{uint32(1), uint32(200)}, {uint32(1), uint32(1000)}}}, + }, + }, + { + "tuple(uint64[2] a)", + map[string]interface{}{ + "a": [2]uint64{1, 2}, + }, + }, + { + "tuple(uint32[2][3][4] a)", + map[string]interface{}{ + "a": [4][3][2]uint32{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}, {{13, 14}, {15, 16}, {17, 18}}, {{19, 20}, {21, 22}, {23, 24}}}, + }, + }, + { + "tuple(int32[] a)", + map[string]interface{}{ + "a": []int32{1, 2}, + }, + }, + { + "tuple(int32 a, int32 b)", + map[string]interface{}{ + "a": int32(1), + "b": int32(2), + }, + }, + { + "tuple(string a, int32 b)", + map[string]interface{}{ + "a": "Hello Worldxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + "b": int32(2), + }, + }, + { + "tuple(int32[2] a, int32[] b)", + map[string]interface{}{ + "a": [2]int32{1, 2}, + "b": []int32{4, 5, 6}, + }, + }, + { + // tuple with array slice + "tuple(address[] a)", + map[string]interface{}{ + "a": []Address{ + {0x1}, + }, + }, + }, + { + // First dynamic second static + "tuple(int32[] a, int32[2] b)", + map[string]interface{}{ + "a": []int32{1, 2, 3}, + "b": [2]int32{4, 5}, + }, + }, + { + // Both dynamic + "tuple(int32[] a, int32[] b)", + map[string]interface{}{ + "a": []int32{1, 2, 3}, + "b": []int32{4, 5, 6}, + }, + }, + { + "tuple(string a, int64 b)", + map[string]interface{}{ + "a": "hello World", + "b": int64(266), + }, + }, + { + // tuple array + "tuple(int32 a, int32 b)[2]", + [2]map[string]interface{}{ + { + "a": int32(1), + "b": int32(2), + }, + { + "a": int32(3), + "b": int32(4), + }, + }, + }, + + { + // tuple array with dynamic content + "tuple(int32[] a)[2]", + [2]map[string]interface{}{ + { + "a": []int32{1, 2, 3}, + }, + { + "a": []int32{4, 5, 6}, + }, + }, + }, + { + // tuple slice + "tuple(int32 a, int32[] b)[]", + []map[string]interface{}{ + { + "a": int32(1), + "b": []int32{2, 3}, + }, + { + "a": int32(4), + "b": []int32{5, 6}, + }, + }, + }, + { + // nested tuple + "tuple(tuple(int32 c, int32[] d) a, int32[] b)", + map[string]interface{}{ + "a": map[string]interface{}{ + "c": int32(5), + "d": []int32{3, 4}, + }, + "b": []int32{1, 2}, + }, + }, + { + "tuple(uint8[2] a, tuple(uint8 e, uint32 f)[2] b, uint16 c, uint64[2][1] d)", + map[string]interface{}{ + "a": [2]uint8{uint8(1), uint8(2)}, + "b": [2]map[string]interface{}{ + { + "e": uint8(10), + "f": uint32(11), + }, + { + "e": uint8(20), + "f": uint32(21), + }, + }, + "c": uint16(3), + "d": [1][2]uint64{{uint64(4), uint64(5)}}, + }, + }, + { + "tuple(uint16 a, uint16 b)[1][]", + [][1]map[string]interface{}{ + { + { + "a": uint16(1), + "b": uint16(2), + }, + }, + { + { + "a": uint16(3), + "b": uint16(4), + }, + }, + { + { + "a": uint16(5), + "b": uint16(6), + }, + }, + { + { + "a": uint16(7), + "b": uint16(8), + }, + }, + }, + }, + { + "tuple(uint64[][] a, tuple(uint8 a, uint32 b)[1] b, uint64 c)", + map[string]interface{}{ + "a": [][]uint64{ + {3, 4}, + }, + "b": [1]map[string]interface{}{ + { + "a": uint8(1), + "b": uint32(2), + }, + }, + "c": uint64(10), + }, + }, + } + + for _, c := range cases { + t.Run("", func(t *testing.T) { + t.Parallel() + + tt, err := NewType(c.Type) + if err != nil { + t.Fatal(err) + } + + if err := testEncodeDecode(t, tt, c.Input); err != nil { + t.Fatal(err) + } + }) + } +} + +func TestEncodeStruct(t *testing.T) { + typ, _ := NewType("tuple(address aa, uint256 b)") + + type Obj struct { + A Address `abi:"aa"` + B *big.Int + } + obj := Obj{ + A: Address{0x1}, + B: big.NewInt(1), + } + + encoded, err := typ.Encode(&obj) + if err != nil { + t.Fatal(err) + } + + var obj2 Obj + if err := typ.DecodeStruct(encoded, &obj2); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(obj, obj2) { + t.Fatal("bad") + } +} + +func TestEncodeStructCamcelCase(t *testing.T) { + typ, _ := NewType("tuple(address aA, uint256 b)") + + type Obj struct { + A Address `abi:"aA"` + B *big.Int + } + obj := Obj{ + A: Address{0x1}, + B: big.NewInt(1), + } + + encoded, err := typ.Encode(&obj) + if err != nil { + t.Fatal(err) + } + + var obj2 Obj + if err := typ.DecodeStruct(encoded, &obj2); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(obj, obj2) { + t.Fatal("bad") + } +} + +func testEncodeDecode(t *testing.T, tt *Type, input interface{}) error { + res1, err := Encode(input, tt) + if err != nil { + return err + } + res2, err := Decode(tt, res1) + if err != nil { + return err + } + + if !reflect.DeepEqual(res2, input) { + return fmt.Errorf("bad") + } + return nil +} diff --git a/abi_type.go b/abi_type.go new file mode 100644 index 00000000..001b3ce9 --- /dev/null +++ b/abi_type.go @@ -0,0 +1,727 @@ +/*- + * + * Hedera Go SDK + * + * Copyright (C) 2020 - 2024 Hedera Hashgraph, LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use q file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package hedera + +import ( + "fmt" + "math/big" + "reflect" + "regexp" + "strconv" + "strings" +) + +// batch of predefined reflect types +var ( + boolT = reflect.TypeOf(bool(false)) + uint8T = reflect.TypeOf(uint8(0)) + uint16T = reflect.TypeOf(uint16(0)) + uint32T = reflect.TypeOf(uint32(0)) + uint64T = reflect.TypeOf(uint64(0)) + int8T = reflect.TypeOf(int8(0)) + int16T = reflect.TypeOf(int16(0)) + int32T = reflect.TypeOf(int32(0)) + int64T = reflect.TypeOf(int64(0)) + addressT = reflect.TypeOf(Address{}) + stringT = reflect.TypeOf("") + dynamicBytesT = reflect.SliceOf(reflect.TypeOf(byte(0))) + functionT = reflect.ArrayOf(24, reflect.TypeOf(byte(0))) + tupleT = reflect.TypeOf(map[string]interface{}{}) + bigIntT = reflect.TypeOf(new(big.Int)) +) + +// AbiTypeKind represents the kind of abi type +type AbiTypeKind int + +const ( + // KindBool is a boolean + KindBool AbiTypeKind = iota + + // KindUInt is an uint + KindUInt + + // KindInt is an int + KindInt + + // KindString is a string + KindString + + // KindArray is an array + KindArray + + // KindSlice is a slice + KindSlice + + // KindAddress is an address + KindAddress + + // KindBytes is a bytes array + KindBytes + + // KindFixedBytes is a fixed bytes + KindFixedBytes + + // KindFixedPoint is a fixed point + KindFixedPoint + + // KindTuple is a tuple + KindTuple + + // KindFunction is a function + KindFunction +) + +func (k AbiTypeKind) String() string { + names := [...]string{ + "Bool", + "Uint", + "Int", + "String", + "Array", + "Slice", + "Address", + "Bytes", + "FixedBytes", + "FixedPoint", + "Tuple", + "Function", + } + + return names[k] +} + +// TupleElem is an element of a tuple +type TupleElem struct { + Name string + Elem *Type + Indexed bool +} + +// Type is an ABI type +type Type struct { + kind AbiTypeKind + size int + elem *Type + tuple []*TupleElem + t reflect.Type + itype string +} + +func NewTupleType(inputs []*TupleElem) *Type { + return &Type{ + kind: KindTuple, + tuple: inputs, + t: tupleT, + } +} + +func NewTupleTypeFromArgs(inputs []*ArgumentStr) (*Type, error) { + elems := []*TupleElem{} + for _, i := range inputs { + typ, err := NewTypeFromArgument(i) + if err != nil { + return nil, err + } + elems = append(elems, &TupleElem{ + Name: i.Name, + Elem: typ, + Indexed: i.Indexed, + }) + } + return NewTupleType(elems), nil +} + +// Decode decodes an object using this type +func (t *Type) Decode(input []byte) (interface{}, error) { + return Decode(t, input) +} + +// DecodeStruct decodes an object using this type to the out param +func (t *Type) DecodeStruct(input []byte, out interface{}) error { + return DecodeStruct(t, input, out) +} + +// InternalType returns the internal type +func (t *Type) InternalType() string { + return t.itype +} + +// Encode encodes an object using this type +func (t *Type) Encode(v interface{}) ([]byte, error) { + return Encode(v, t) +} + +// String returns the raw representation of the type +func (t *Type) String() string { + return t.Format(false) +} + +// nolint +func (t *Type) Format(includeArgs bool) string { + switch t.kind { + case KindTuple: + rawAux := []string{} + for _, i := range t.TupleElems() { + name := i.Elem.Format(includeArgs) + if i.Indexed { + name += " indexed" + } + if includeArgs { + if i.Name != "" { + name += " " + i.Name + } + } + rawAux = append(rawAux, name) + } + return fmt.Sprintf("tuple(%s)", strings.Join(rawAux, ",")) + + case KindArray: + return fmt.Sprintf("%s[%d]", t.elem.Format(includeArgs), t.size) + + case KindSlice: + return fmt.Sprintf("%s[]", t.elem.Format(includeArgs)) + + case KindBytes: + return "bytes" + + case KindFixedBytes: + return fmt.Sprintf("bytes%d", t.size) + + case KindString: + return "string" + + case KindBool: + return "bool" + + case KindAddress: + return "address" + + case KindFunction: + return "function" + + case KindUInt: + return fmt.Sprintf("uint%d", t.size) + + case KindInt: + return fmt.Sprintf("int%d", t.size) + + default: + panic(fmt.Errorf("BUG: abi type not found %s", t.kind.String())) + } +} + +// Elem returns the elem value for slice and arrays +func (t *Type) Elem() *Type { + return t.elem +} + +// Size returns the size of the type +func (t *Type) Size() int { + return t.size +} + +// TupleElems returns the elems of the tuple +func (t *Type) TupleElems() []*TupleElem { + return t.tuple +} + +// GoType returns the go type +func (t *Type) GoType() reflect.Type { + return t.t +} + +// Kind returns the kind of the type +func (t *Type) Kind() AbiTypeKind { + return t.kind +} + +func (t *Type) isVariableInput() bool { + return t.kind == KindSlice || t.kind == KindBytes || t.kind == KindString +} + +func (t *Type) isDynamicType() bool { + if t.kind == KindTuple { + for _, elem := range t.tuple { + if elem.Elem.isDynamicType() { + return true + } + } + return false + } + return t.kind == KindString || t.kind == KindBytes || t.kind == KindSlice || (t.kind == KindArray && t.elem.isDynamicType()) +} + +func parseType(arg *ArgumentStr) (string, error) { + if !strings.HasPrefix(arg.Type, "tuple") { + return arg.Type, nil + } + + if len(arg.Components) == 0 { + return "tuple()", nil + } + + // parse the arg components from the tuple + str := []string{} + for _, i := range arg.Components { + aux, err := parseType(i) + if err != nil { + return "", err + } + if i.Indexed { + str = append(str, aux+" indexed "+i.Name) + } else { + str = append(str, aux+" "+i.Name) + } + } + return fmt.Sprintf("tuple(%s)%s", strings.Join(str, ","), strings.TrimPrefix(arg.Type, "tuple")), nil +} + +// NewTypeFromArgument parses an abi type from an argument +func NewTypeFromArgument(arg *ArgumentStr) (*Type, error) { + str, err := parseType(arg) + if err != nil { + return nil, err + } + typ, err := NewType(str) + if err != nil { + return nil, err + } + + // fill-in the `internalType` field into the type elems + err = fillIn(typ, arg) + if err != nil { + return nil, err + } + + return typ, nil +} + +func fillIn(typ *Type, arg *ArgumentStr) error { + typ.itype = arg.InternalType + + if len(arg.Components) == 0 { + // no more items, nothing else to do + return nil + } + + // tuple types in the ABI with slices are represented as + // tuple()[] or tuple()[2]. Thus, there might be element in the components + // section of the abi but the next item not be a tuple. + for { + kind := typ.kind + if kind == KindTuple { + break + } + if kind != KindArray && kind != KindSlice { + // error + return fmt.Errorf("array or slice not found") + } + typ = typ.Elem() + } + + if len(arg.Components) != len(typ.tuple) { + // incorrect length + return fmt.Errorf("incorrect size") + } + + for indx, i := range arg.Components { + err := fillIn(typ.tuple[indx].Elem, i) + if err != nil { + return err + } + } + + return nil +} + +// NewType parses a type in string format +func NewType(s string) (*Type, error) { + l := newLexer(s) + l.nextToken() + + return readType(l) +} + +func getTypeSize(t *Type) int { + if t.kind == KindArray && !t.elem.isDynamicType() { + if t.elem.kind == KindArray || t.elem.kind == KindTuple { + return t.size * getTypeSize(t.elem) + } + return t.size * 32 + } else if t.kind == KindTuple && !t.isDynamicType() { + total := 0 + for _, elem := range t.tuple { + total += getTypeSize(elem.Elem) + } + return total + } + return 32 +} + +var typeRegexp = regexp.MustCompile("^([[:alpha:]]+)([[:digit:]]*)$") + +func expectedToken(t tokenType) error { + return fmt.Errorf("expected token %s", t.String()) +} + +func notExpectedToken(t tokenType) error { + return fmt.Errorf("token '%s' not expected", t.String()) +} + +// nolint +func readType(l *lexer) (*Type, error) { + var tt *Type + + tok := l.nextToken() + + isTuple := false + if tok.typ == tupleToken { + if l.nextToken().typ != lparenToken { + return nil, expectedToken(lparenToken) + } + isTuple = true + } else if tok.typ == lparenToken { + isTuple = true + } + if isTuple { + var next token + elems := []*TupleElem{} + for { + name := "" + indexed := false + + elem, err := readType(l) + if err != nil { + if l.current.typ == rparenToken && len(elems) == 0 { + // empty tuple 'tuple()' + break + } + return nil, fmt.Errorf("failed to decode type: %v", err) + } + + switch l.peek.typ { + case strToken: + l.nextToken() + name = l.current.literal + + case indexedToken: + l.nextToken() + indexed = true + if l.peek.typ == strToken { + l.nextToken() + name = l.current.literal + } + } + + elems = append(elems, &TupleElem{ + Name: name, + Elem: elem, + Indexed: indexed, + }) + + next = l.nextToken() + if next.typ == commaToken { + continue + } else if next.typ == rparenToken { + break + } else { + return nil, notExpectedToken(next.typ) + } + } + tt = &Type{kind: KindTuple, tuple: elems, t: tupleT} + } else if tok.typ != strToken { + return nil, expectedToken(strToken) + } else { + // Check normal types + elem, err := decodeSimpleType(tok.literal) + if err != nil { + return nil, err + } + tt = elem + } + + // check for arrays at the end of the type + for { + if l.peek.typ != lbracketToken { + break + } + + l.nextToken() + n := l.nextToken() + + var tAux *Type + if n.typ == rbracketToken { + tAux = &Type{kind: KindSlice, elem: tt, t: reflect.SliceOf(tt.t)} + } else if n.typ == numberToken { + size, err := strconv.ParseUint(n.literal, 10, 32) + if err != nil { + return nil, fmt.Errorf("failed to read array size '%s': %v", n.literal, err) + } + + tAux = &Type{kind: KindArray, elem: tt, size: int(size), t: reflect.ArrayOf(int(size), tt.t)} + if l.nextToken().typ != rbracketToken { + return nil, expectedToken(rbracketToken) + } + } else { + return nil, notExpectedToken(n.typ) + } + + tt = tAux + } + return tt, nil +} + +func decodeSimpleType(str string) (*Type, error) { + match := typeRegexp.FindStringSubmatch(str) + if len(match) == 0 { + return nil, fmt.Errorf("type format is incorrect. Expected 'type''bytes' but found '%s'", str) + } + match = match[1:] + + var err error + t := match[0] + + bytes := 0 + ok := false + + if bytesStr := match[1]; bytesStr != "" { + bytes, err = strconv.Atoi(bytesStr) + if err != nil { + return nil, fmt.Errorf("failed to parse bytes '%s': %v", bytesStr, err) + } + ok = true + } + + // int and uint without bytes default to 256, 'bytes' may + // have or not, the rest dont have bytes + if t == "int" || t == "uint" { + if !ok { + bytes = 256 + } + } else if t != "bytes" && ok { + return nil, fmt.Errorf("type %s does not expect bytes", t) + } + + switch t { + case "uint": + var k reflect.Type + switch bytes { + case 8: + k = uint8T + case 16: + k = uint16T + case 32: + k = uint32T + case 64: + k = uint64T + default: + if bytes%8 != 0 { + panic(fmt.Errorf("number of bytes has to be M mod 8")) + } + k = bigIntT + } + return &Type{kind: KindUInt, size: bytes, t: k}, nil + + case "int": + var k reflect.Type + switch bytes { + case 8: + k = int8T + case 16: + k = int16T + case 32: + k = int32T + case 64: + k = int64T + default: + if bytes%8 != 0 { + panic(fmt.Errorf("number of bytes has to be M mod 8")) + } + k = bigIntT + } + return &Type{kind: KindInt, size: bytes, t: k}, nil + + case "byte": + bytes = 1 + fallthrough + + case "bytes": + if bytes == 0 { + return &Type{kind: KindBytes, t: dynamicBytesT}, nil + } + return &Type{kind: KindFixedBytes, size: bytes, t: reflect.ArrayOf(bytes, reflect.TypeOf(byte(0)))}, nil + + case "string": + return &Type{kind: KindString, t: stringT}, nil + + case "bool": + return &Type{kind: KindBool, t: boolT}, nil + + case "address": + return &Type{kind: KindAddress, t: addressT, size: 20}, nil + + case "function": + return &Type{kind: KindFunction, size: 24, t: functionT}, nil + + default: + return nil, fmt.Errorf("unknown type '%s'", t) + } +} + +type tokenType int + +const ( + eofToken tokenType = iota + strToken + numberToken + tupleToken + lparenToken + rparenToken + lbracketToken + rbracketToken + commaToken + indexedToken + invalidToken +) + +func (t tokenType) String() string { + names := [...]string{ + "eof", + "string", + "number", + "tuple", + "(", + ")", + "[", + "]", + ",", + "indexed", + "", + } + return names[t] +} + +type token struct { + typ tokenType + literal string +} + +type lexer struct { + input string + current token + peek token + position int + readPosition int + ch byte +} + +func newLexer(input string) *lexer { + l := &lexer{input: input} + l.readChar() + return l +} + +func (l *lexer) readChar() { + if l.readPosition >= len(l.input) { + l.ch = 0 + } else { + l.ch = l.input[l.readPosition] + } + + l.position = l.readPosition + l.readPosition++ +} + +func (l *lexer) nextToken() token { + l.current = l.peek + l.peek = l.nextTokenImpl() + return l.current +} + +// nolint +func (l *lexer) nextTokenImpl() token { + var tok token + + // skip whitespace + for l.ch == ' ' || l.ch == '\t' || l.ch == '\n' || l.ch == '\r' { + l.readChar() + } + + switch l.ch { + case ',': + tok.typ = commaToken + case '(': + tok.typ = lparenToken + case ')': + tok.typ = rparenToken + case '[': + tok.typ = lbracketToken + case ']': + tok.typ = rbracketToken + case 0: + tok.typ = eofToken + default: + if isLetter(l.ch) { + tok.literal = l.readIdentifier() + if tok.literal == "tuple" { + tok.typ = tupleToken + } else if tok.literal == "indexed" { + tok.typ = indexedToken + } else { + tok.typ = strToken + } + + return tok + } else if isDigit(l.ch) { + return token{numberToken, l.readNumber()} + } else { + tok.typ = invalidToken + } + } + + l.readChar() + return tok +} + +func (l *lexer) readIdentifier() string { + pos := l.position + for isLetter(l.ch) || isDigit(l.ch) { + l.readChar() + } + + return l.input[pos:l.position] +} + +func (l *lexer) readNumber() string { + position := l.position + for isDigit(l.ch) { + l.readChar() + } + return l.input[position:l.position] +} + +func isDigit(ch byte) bool { + return '0' <= ch && ch <= '9' +} + +func isLetter(ch byte) bool { + return 'a' <= ch && ch <= 'z' || 'A' <= ch && ch <= 'Z' || ch == '_' +} diff --git a/abi_type_unit_test.go b/abi_type_unit_test.go new file mode 100644 index 00000000..453ce137 --- /dev/null +++ b/abi_type_unit_test.go @@ -0,0 +1,462 @@ +//go:build all || unit +// +build all unit + +package hedera + +/*- + * + * Hedera Go SDK + * + * Copyright (C) 2020 - 2024 Hedera Hashgraph, LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestType(t *testing.T) { + cases := []struct { + s string + a *ArgumentStr + t *Type + r string + err bool + }{ + { + s: "bool", + a: simpleType("bool"), + t: &Type{kind: KindBool, t: boolT}, + }, + { + s: "uint32", + a: simpleType("uint32"), + t: &Type{kind: KindUInt, size: 32, t: uint32T}, + }, + { + s: "int32", + a: simpleType("int32"), + t: &Type{kind: KindInt, size: 32, t: int32T}, + }, + { + s: "int32[]", + a: simpleType("int32[]"), + t: &Type{kind: KindSlice, t: reflect.SliceOf(int32T), elem: &Type{kind: KindInt, size: 32, t: int32T}}, + }, + { + s: "int", + a: simpleType("int"), + t: &Type{kind: KindInt, size: 256, t: bigIntT}, + r: "int256", + }, + { + s: "int[]", + a: simpleType("int[]"), + t: &Type{kind: KindSlice, t: reflect.SliceOf(bigIntT), elem: &Type{kind: KindInt, size: 256, t: bigIntT}}, + r: "int256[]", + }, + { + s: "bytes[2]", + a: simpleType("bytes[2]"), + t: &Type{ + kind: KindArray, + t: reflect.ArrayOf(2, dynamicBytesT), + size: 2, + elem: &Type{ + kind: KindBytes, + t: dynamicBytesT, + }, + }, + }, + { + s: "address[]", + a: simpleType("address[]"), + t: &Type{kind: KindSlice, t: reflect.SliceOf(addressT), elem: &Type{kind: KindAddress, size: 20, t: addressT}}, + }, + { + s: "string[]", + a: simpleType("string[]"), + t: &Type{ + kind: KindSlice, + t: reflect.SliceOf(stringT), + elem: &Type{ + kind: KindString, + t: stringT, + }, + }, + }, + { + s: "string[2]", + a: simpleType("string[2]"), + t: &Type{ + kind: KindArray, + size: 2, + t: reflect.ArrayOf(2, stringT), + elem: &Type{ + kind: KindString, + t: stringT, + }, + }, + }, + + { + s: "string[2][]", + a: simpleType("string[2][]"), + t: &Type{ + kind: KindSlice, + t: reflect.SliceOf(reflect.ArrayOf(2, stringT)), + elem: &Type{ + kind: KindArray, + size: 2, + t: reflect.ArrayOf(2, stringT), + elem: &Type{ + kind: KindString, + t: stringT, + }, + }, + }, + }, + { + s: "tuple(int64 indexed arg0)", + a: &ArgumentStr{ + Type: "tuple", + Components: []*ArgumentStr{ + { + Name: "arg0", + Type: "int64", + Indexed: true, + }, + }, + }, + t: &Type{ + kind: KindTuple, + t: tupleT, + tuple: []*TupleElem{ + { + Name: "arg0", + Elem: &Type{ + kind: KindInt, + size: 64, + t: int64T, + }, + Indexed: true, + }, + }, + }, + }, + { + s: "tuple(int64 arg_0)[2]", + a: &ArgumentStr{ + Type: "tuple[2]", + Components: []*ArgumentStr{ + { + Name: "arg_0", + Type: "int64", + }, + }, + }, + t: &Type{ + kind: KindArray, + size: 2, + t: reflect.ArrayOf(2, tupleT), + elem: &Type{ + kind: KindTuple, + t: tupleT, + tuple: []*TupleElem{ + { + Name: "arg_0", + Elem: &Type{ + kind: KindInt, + size: 64, + t: int64T, + }, + }, + }, + }, + }, + }, + { + s: "tuple(int64 a)[]", + a: &ArgumentStr{ + Type: "tuple[]", + Components: []*ArgumentStr{ + { + Name: "a", + Type: "int64", + }, + }, + }, + t: &Type{ + kind: KindSlice, + t: reflect.SliceOf(tupleT), + elem: &Type{ + kind: KindTuple, + t: tupleT, + tuple: []*TupleElem{ + { + Name: "a", + Elem: &Type{ + kind: KindInt, + size: 64, + t: int64T, + }, + }, + }, + }, + }, + }, + { + s: "tuple(int32 indexed arg0,tuple(int32 c) b_2)", + a: &ArgumentStr{ + Type: "tuple", + Components: []*ArgumentStr{ + { + Name: "arg0", + Type: "int32", + Indexed: true, + }, + { + Name: "b_2", + Type: "tuple", + Components: []*ArgumentStr{ + { + Name: "c", + Type: "int32", + }, + }, + }, + }, + }, + t: &Type{ + kind: KindTuple, + t: tupleT, + tuple: []*TupleElem{ + { + Name: "arg0", + Elem: &Type{ + kind: KindInt, + size: 32, + t: int32T, + }, + Indexed: true, + }, + { + Name: "b_2", + Elem: &Type{ + kind: KindTuple, + t: tupleT, + tuple: []*TupleElem{ + { + Name: "c", + Elem: &Type{ + kind: KindInt, + size: 32, + t: int32T, + }, + }, + }, + }, + }, + }, + }, + }, + { + s: "tuple()", + a: &ArgumentStr{ + Type: "tuple", + Components: []*ArgumentStr{}, + }, + t: &Type{ + kind: KindTuple, + t: tupleT, + tuple: []*TupleElem{}, + }, + }, + { + // hidden tuple token + s: "tuple((int32))", + a: &ArgumentStr{ + Type: "tuple", + Components: []*ArgumentStr{ + { + Type: "tuple", + Components: []*ArgumentStr{ + { + Type: "int32", + }, + }, + }, + }, + }, + t: &Type{ + kind: KindTuple, + t: tupleT, + tuple: []*TupleElem{ + { + Elem: &Type{ + kind: KindTuple, + t: tupleT, + tuple: []*TupleElem{ + { + Elem: &Type{ + kind: KindInt, + size: 32, + t: int32T, + }, + }, + }, + }, + }, + }, + }, + r: "tuple(tuple(int32))", + }, + { + s: "int[[", + err: true, + }, + { + s: "tuple[](a int32)", + err: true, + }, + { + s: "int32[a]", + err: true, + }, + { + s: "tuple(a int32", + err: true, + }, + { + s: "tuple(a int32,", + err: true, + }, + } + + for _, c := range cases { + t.Run("", func(t *testing.T) { + e0, err := NewType(c.s) + if err != nil && !c.err { + t.Fatal(err) + } + if err == nil && c.err { + t.Fatal("it should have failed") + } + + if !c.err { + // compare the string + expected := c.s + if c.r != "" { + expected = c.r + } + assert.Equal(t, expected, e0.Format(true)) + + e1, err := NewTypeFromArgument(c.a) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(c.t, e0) { + + // fmt.Println(c.t.t) + // fmt.Println(e0.t) + + t.Fatal("bad new type") + } + if !reflect.DeepEqual(c.t, e1) { + t.Fatal("bad") + } + } + }) + } +} + +func TestTypeArgumentInternalFields(t *testing.T) { + arg := &ArgumentStr{ + Type: "tuple", + Components: []*ArgumentStr{ + { + Type: "tuple[]", + Components: []*ArgumentStr{ + { + Type: "int32", + InternalType: "c", + }, + }, + InternalType: "b", + }, + }, + } + + res, err := NewTypeFromArgument(arg) + require.NoError(t, err) + + require.Equal(t, res.tuple[0].Elem.itype, "b") + require.Equal(t, res.tuple[0].Elem.elem.tuple[0].Elem.itype, "c") +} + +func TestSize(t *testing.T) { + cases := []struct { + Input string + Size int + }{ + { + "int32", 32, + }, + { + "int32[]", 32, + }, + { + "int32[2]", 32 * 2, + }, + { + "int32[2][2]", 32 * 2 * 2, + }, + { + "string", 32, + }, + { + "string[]", 32, + }, + { + "tuple(uint8 a, uint32 b)[1]", + 64, + }, + } + + for _, c := range cases { + t.Run("", func(t *testing.T) { + tt, err := NewType(c.Input) + if err != nil { + t.Fatal(err) + } + + size := getTypeSize(tt) + if size != c.Size { + t.Fatalf("expected size %d but found %d", c.Size, size) + } + }) + } +} + +func simpleType(s string) *ArgumentStr { + return &ArgumentStr{ + Type: s, + } +} diff --git a/abi_unit_test.go b/abi_unit_test.go new file mode 100644 index 00000000..8dee2957 --- /dev/null +++ b/abi_unit_test.go @@ -0,0 +1,413 @@ +//go:build all || unit +// +build all unit + +package hedera + +/*- + * + * Hedera Go SDK + * + * Copyright (C) 2020 - 2024 Hedera Hashgraph, LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAbi(t *testing.T) { + inputs, _ := NewType("tuple()") + outputs, _ := NewType("tuple()") + + methodOutput := &Method{ + Name: "abc", + Inputs: inputs, + Outputs: outputs, + } + + inputs, _ = NewType("tuple(address owner)") + outputs, _ = NewType("tuple(uint256 balance)") + balanceFunc := &Method{ + Name: "balanceOf", + Const: true, + Inputs: inputs, + Outputs: outputs, + } + + errorInput, _ := NewType("tuple(address indexed a)") + eventInput, _ := NewType("tuple(address indexed a)") + + cases := []struct { + Input string + Output *ABI + }{ + { + Input: `[ + { + "name": "abc", + "type": "function" + }, + { + "name": "cde", + "type": "event", + "inputs": [ + { + "indexed": true, + "name": "a", + "type": "address" + } + ] + }, + { + "name": "def", + "type": "error", + "inputs": [ + { + "indexed": true, + "name": "a", + "type": "address" + } + ] + }, + { + "type": "function", + "name": "balanceOf", + "constant": true, + "stateMutability": "view", + "payable": false, + "inputs": [ + { + "type": "address", + "name": "owner" + } + ], + "outputs": [ + { + "type": "uint256", + "name": "balance" + } + ] + } + ]`, + Output: &ABI{ + Events: map[string]*Event{ + "cde": { + Name: "cde", + Inputs: eventInput, + }, + }, + Methods: map[string]*Method{ + "abc": methodOutput, + "balanceOf": balanceFunc, + }, + MethodsBySignature: map[string]*Method{ + "abc()": methodOutput, + "balanceOf(address)": balanceFunc, + }, + Errors: map[string]*Error{ + "def": { + Name: "def", + Inputs: errorInput, + }, + }, + }, + }, + } + + for _, c := range cases { + t.Run("", func(t *testing.T) { + abi, err := NewABI(c.Input) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(abi, c.Output) { + t.Fatal("bad") + } + }) + } +} + +func TestAbiInternalType(t *testing.T) { + const abiStr = `[ + { + "inputs": [ + { + "components": [ + { + "internalType": "address", + "type": "address" + }, + { + "internalType": "uint256[4]", + "type": "uint256[4]" + } + ], + "internalType": "struct X", + "name": "newSet", + "type": "tuple[]" + }, + { + "internalType": "custom_address", + "name": "_to", + "type": "address" + } + ], + "outputs": [], + "name": "transfer", + "type": "function" + } + ]` + + abi, err := NewABI(abiStr) + require.NoError(t, err) + + typ := abi.GetMethod("transfer").Inputs + require.Equal(t, typ.tuple[0].Elem.InternalType(), "struct X") + require.Equal(t, typ.tuple[0].Elem.elem.tuple[0].Elem.InternalType(), "address") + require.Equal(t, typ.tuple[0].Elem.elem.tuple[1].Elem.InternalType(), "uint256[4]") + require.Equal(t, typ.tuple[1].Elem.InternalType(), "custom_address") +} + +func TestAbiPolymorphism(t *testing.T) { + // This ABI contains 2 "transfer" functions (polymorphism) + const polymorphicABI = `[ + { + "inputs": [ + { + "internalType": "address", + "name": "_to", + "type": "address" + }, + { + "internalType": "address", + "name": "_token", + "type": "address" + }, + { + "internalType": "uint256", + "name": "_amount", + "type": "uint256" + } + ], + "name": "transfer", + "outputs": [ + { + "internalType": "bool", + "name": "", + "type": "bool" + } + ], + "stateMutability": "nonpayable", + "type": "function" + }, + { + "inputs": [ + { + "internalType": "address", + "name": "_to", + "type": "address" + }, + { + "internalType": "uint256", + "name": "_amount", + "type": "uint256" + } + ], + "name": "transfer", + "outputs": [ + { + "internalType": "bool", + "name": "", + "type": "bool" + } + ], + "stateMutability": "nonpayable", + "type": "function" + } + ]` + + abi, err := NewABI(polymorphicABI) + if err != nil { + t.Fatal(err) + } + + assert.Len(t, abi.Methods, 2) + assert.Equal(t, abi.GetMethod("transfer").Sig(), "transfer(address,address,uint256)") + assert.Equal(t, abi.GetMethod("transfer0").Sig(), "transfer(address,uint256)") + assert.NotEmpty(t, abi.GetMethodBySignature("transfer(address,address,uint256)")) + assert.NotEmpty(t, abi.GetMethodBySignature("transfer(address,uint256)")) +} + +func TestAbiHumanReadable(t *testing.T) { + cases := []string{ + "constructor(string symbol, string name)", + "function transferFrom(address from, address to, uint256 value)", + "function balanceOf(address owner) view returns (uint256 balance)", + "function balanceOf() view returns ()", + "event Transfer(address indexed from, address indexed to, address value)", + "error InsufficientBalance(address owner, uint256 balance)", + "function addPerson(tuple(string name, uint16 age) person)", + "function addPeople(tuple(string name, uint16 age)[] person)", + "function getPerson(uint256 id) view returns (tuple(string name, uint16 age))", + "event PersonAdded(uint256 indexed id, tuple(string name, uint16 age) person)", + } + vv, err := NewABIFromList(cases) + assert.NoError(t, err) + + // make it nil to not compare it and avoid writing each method twice for the test + vv.MethodsBySignature = nil + + constructorInputs, _ := NewType("tuple(string symbol, string name)") + transferFromInputs, _ := NewType("tuple(address from, address to, uint256 value)") + transferFromOutputs, _ := NewType("tuple()") + balanceOfInputs, _ := NewType("tuple(address owner)") + balanceOfOutputs, _ := NewType("tuple(uint256 balance)") + balanceOf0Inputs, _ := NewType("tuple()") + balanceOf0Outputs, _ := NewType("tuple()") + addPersonInputs, _ := NewType("tuple(tuple(string name, uint16 age) person)") + addPersonOutputs, _ := NewType("tuple()") + addPeopleInputs, _ := NewType("tuple(tuple(string name, uint16 age)[] person)") + addPeopleOutputs, _ := NewType("tuple()") + getPersonInputs, _ := NewType("tuple(uint256 id)") + getPersonOutputs, _ := NewType("tuple(tuple(string name, uint16 age))") + transferEventInputs, _ := NewType("tuple(address indexed from, address indexed to, address value)") + personAddedEventInputs, _ := NewType("tuple(uint256 indexed id, tuple(string name, uint16 age) person)") + errorInputs, _ := NewType("tuple(address owner, uint256 balance)") + + expect := &ABI{ + Constructor: &Method{ + Inputs: constructorInputs, + }, + Methods: map[string]*Method{ + "transferFrom": { + Name: "transferFrom", + Inputs: transferFromInputs, + Outputs: transferFromOutputs, + }, + "balanceOf": { + Name: "balanceOf", + Inputs: balanceOfInputs, + Outputs: balanceOfOutputs, + }, + "balanceOf0": { + Name: "balanceOf", + Inputs: balanceOf0Inputs, + Outputs: balanceOf0Outputs, + }, + "addPerson": { + Name: "addPerson", + Inputs: addPersonInputs, + Outputs: addPersonOutputs, + }, + "addPeople": { + Name: "addPeople", + Inputs: addPeopleInputs, + Outputs: addPeopleOutputs, + }, + "getPerson": { + Name: "getPerson", + Inputs: getPersonInputs, + Outputs: getPersonOutputs, + }, + }, + Events: map[string]*Event{ + "Transfer": { + Name: "Transfer", + Inputs: transferEventInputs, + }, + "PersonAdded": { + Name: "PersonAdded", + Inputs: personAddedEventInputs, + }, + }, + Errors: map[string]*Error{ + "InsufficientBalance": { + Name: "InsufficientBalance", + Inputs: errorInputs, + }, + }, + } + assert.Equal(t, expect, vv) +} + +func TestAbiParseMethodSignature(t *testing.T) { + cases := []struct { + signature string + name string + input string + output string + }{ + { + // both input and output + signature: "function approve(address to) returns (address)", + name: "approve", + input: "tuple(address)", + output: "tuple(address)", + }, + { + // no input + signature: "function approve() returns (address)", + name: "approve", + input: "tuple()", + output: "tuple(address)", + }, + { + // no output + signature: "function approve(address)", + name: "approve", + input: "tuple(address)", + output: "tuple()", + }, + { + // multiline + signature: `function a( + uint256 b, + address[] c + ) + returns + ( + uint256[] d + )`, + name: "a", + input: "tuple(uint256,address[])", + output: "tuple(uint256[])", + }, + } + + for _, c := range cases { + name, input, output, err := parseMethodSignature(c.signature) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, name, c.name) + + if input != nil { + assert.Equal(t, c.input, input.String()) + } else { + assert.Equal(t, c.input, "") + } + + if input != nil { + assert.Equal(t, c.output, output.String()) + } else { + assert.Equal(t, c.output, "") + } + } +} diff --git a/contract_function_parameters.go b/contract_function_parameters.go index acc75cab..c5866126 100644 --- a/contract_function_parameters.go +++ b/contract_function_parameters.go @@ -32,13 +32,6 @@ import ( // Use the builder methods `Add()` to add a parameter. Not all solidity types // are supported out of the box, but the most common types are. The larger variants // of number types require the parameter to be `[]byte`. -// ``` -// AddUint88(math.PaddedBigBytes(n, 88 / 8)) -// ``` -// If you're using `Uint256` specifically you can opt into using -// ``` -// AddUin256(math.PaddedBigBytes(math.U256(n), 32)) -// ``` type ContractFunctionParameters struct { function ContractFunctionSelector arguments []Argument diff --git a/contract_function_parameters_e2e_test.go b/contract_function_parameters_e2e_test.go index b97c08a5..327eda79 100644 --- a/contract_function_parameters_e2e_test.go +++ b/contract_function_parameters_e2e_test.go @@ -31,8 +31,6 @@ import ( "sync" "testing" - "github.com/ethereum/go-ethereum/common" - "github.com/stretchr/testify/require" ) @@ -1662,7 +1660,7 @@ func TestStringArray(t *testing.T) { result, err := contractCal.Execute(env.Client) require.NoError(t, err) parsedResult, _ := result.GetResult("string[]") - strArr := parsedResult.([]interface{})[0].([]string) + strArr := parsedResult.([]string) require.Equal(t, value[0], strArr[0]) require.Equal(t, value[1], strArr[1]) @@ -1697,7 +1695,7 @@ func TestAddressArray(t *testing.T) { require.NoError(t, err) addArr, err := result.GetResult("address[]") require.NoError(t, err) - addresses := addArr.([]interface{})[0].([]common.Address) + addresses := addArr.([]Address) require.Equal(t, value[0], strings.TrimPrefix(addresses[0].String(), "0x")) require.Equal(t, value[1], strings.TrimPrefix(addresses[1].String(), "0x")) @@ -1746,7 +1744,7 @@ func TestBytesArray(t *testing.T) { require.NoError(t, err) bytesArrInterface, err := result.GetResult("bytes[]") require.NoError(t, err) - require.Equal(t, value, bytesArrInterface.([]interface{})[0]) + require.Equal(t, value, bytesArrInterface.([][]uint8)) } @@ -1787,8 +1785,8 @@ func TestBytes32Array(t *testing.T) { require.NoError(t, err) bytes32ArrInterface, err := result.GetResult("bytes32[]") require.NoError(t, err) - require.Equal(t, expected1, bytes32ArrInterface.([]interface{})[0].([][32]byte)[0]) - require.Equal(t, expected2, bytes32ArrInterface.([]interface{})[0].([][32]byte)[1]) + require.Equal(t, expected1, bytes32ArrInterface.([][32]byte)[0]) + require.Equal(t, expected2, bytes32ArrInterface.([][32]byte)[1]) } diff --git a/contract_function_result.go b/contract_function_result.go index fde37fb1..1f488d4c 100644 --- a/contract_function_result.go +++ b/contract_function_result.go @@ -24,9 +24,6 @@ import ( "encoding/binary" "fmt" "math/big" - "strings" - - "github.com/ethereum/go-ethereum/accounts/abi" "github.com/hashgraph/hedera-sdk-go/v2/proto/services" protobuf "google.golang.org/protobuf/proto" @@ -36,7 +33,8 @@ import ( // ContractFunctionResult is a struct which allows users to convert between solidity and Go types, and is typically // returned by `ContractCallQuery` and is present in the transaction records of `ContractExecuteTransaction`. // Use the methods `Get()` to get a parameter. Not all solidity types -// are supported out of the box, but the most common types are. +// are supported out of the box, but the most common types are. The larger variants +// of number types return just the bytes for the integer instead of converting to a big int type. // ``` // contractFunctionResult.GetUint256() // bInt := new(big.Int) @@ -437,15 +435,18 @@ func (result ContractFunctionResult) AsBytes() []byte { // to convert it into the appropriate go type. func (result ContractFunctionResult) GetResult(types string) (interface{}, error) { def := fmt.Sprintf(`[{ "name" : "method", "type": "function", "outputs": [{ "type": "%s" }]}]`, types) - abi, err := abi.JSON(strings.NewReader(def)) + abi := ABI{} + err := abi.UnmarshalJSON([]byte(def)) if err != nil { return nil, err } - parsedResult, err := abi.Unpack("method", result.ContractCallResult) + + parsedResult, err := abi.Methods["method"].Decode(result.ContractCallResult) + if err != nil { return nil, err } - return parsedResult, nil + return parsedResult["0"], nil } func extractInt64OrZero(pb *services.ContractFunctionResult) int64 { diff --git a/crypto.go b/crypto.go index 47edc22a..2ee63a4c 100644 --- a/crypto.go +++ b/crypto.go @@ -29,10 +29,12 @@ import ( "math/big" "strings" + "github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/crypto" "github.com/hashgraph/hedera-sdk-go/v2/proto/services" "github.com/pkg/errors" "golang.org/x/crypto/pbkdf2" + "golang.org/x/crypto/sha3" protobuf "google.golang.org/protobuf/proto" ) @@ -909,3 +911,21 @@ func (pk PublicKey) VerifyTransaction(tx TransactionInterface) bool { return false } + +func Keccak256Hash(data []byte) (h Hash) { + hash := sha3.NewLegacyKeccak256() + hash.Write(data) + copy(h[:], hash.Sum(nil)) + return h +} + +// Hash represents the 32 byte Keccak256 hash of arbitrary data. +type Hash [32]byte + +func (h Hash) Hex() string { return hexutil.Encode(h[:]) } + +func (h Hash) String() string { + return h.Hex() +} + +func (h Hash) Bytes() []byte { return h[:] } diff --git a/ethereum_transaction_e2e_test.go b/ethereum_transaction_e2e_test.go index 6527c10c..a1e5adac 100644 --- a/ethereum_transaction_e2e_test.go +++ b/ethereum_transaction_e2e_test.go @@ -32,7 +32,7 @@ import ( ) // Testing the signer nonce defined in HIP-844 -// This test should be reworked +// TODO This test should be reworked func TestIntegrationEthereumTransaction(t *testing.T) { // Skip this test because it is flaky with newest version of Local Node t.Skip() @@ -82,9 +82,6 @@ func TestIntegrationEthereumTransaction(t *testing.T) { contractID := *receipt.ContractID // Call data for the smart contract - - // signedBytes := ecdsaPrivateKey.Sign(bytesToSign) - // dummy signed data until test is revisitted rlp, err := hex.DecodeString("02f87082012a022f2f83018000947e3a9eaf9bcc39e2ffa38eb30bf7a93feacbc181880de0b6b3a764000083123456c001a0df48f2efd10421811de2bfb125ab75b2d3c44139c4642837fb1fccce911fd479a01aaf7ae92bee896651dfc9d99ae422a296bf5d9f1ca49b2d96d82b79eb112d66") require.NoError(t, err) @@ -94,10 +91,6 @@ func TestIntegrationEthereumTransaction(t *testing.T) { txDataBytes, err := txData.ToBytes() - // Add signature data to the RLP list for EthereumTransaction submition - // Populate rlp fields - - // 02 is the type of the transaction EIP1559 and should be concatenated to the RLP by service requirement resp, err = NewEthereumTransaction().SetEthereumData(txDataBytes).Execute(env.Client) require.NoError(t, err) diff --git a/go.mod b/go.mod index 5f6168fa..41ef96b5 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/cenkalti/backoff/v4 v4.3.0 github.com/ethereum/go-ethereum v1.13.15 github.com/json-iterator/go v1.1.12 + github.com/mitchellh/mapstructure v1.5.0 github.com/pkg/errors v0.9.1 github.com/rs/zerolog v1.33.0 github.com/stretchr/testify v1.8.4 diff --git a/go.sum b/go.sum index d94d01de..dc97bcbc 100644 --- a/go.sum +++ b/go.sum @@ -67,7 +67,6 @@ github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/holiman/uint256 v1.2.4 h1:jUc4Nk8fm9jZabQuqr2JzednajVmBpC+oiTiXZJEApU= @@ -91,6 +90,8 @@ github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D github.com/mattn/go-runewidth v0.0.13 h1:lTGmDsbAYt5DmK6OnoV7EuIF1wEIFAcxld6ypU4OSgU= github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo= github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 h1:jWpvCLoY8Z/e3VKvlsiIGKtc+UG6U5vzxaoagmhXfyg= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/mmcloughlin/addchain v0.4.0 h1:SobOdjm2xLj1KkXN5/n0xTIWyZA2+s99UCY1iPfkHRY= github.com/mmcloughlin/addchain v0.4.0/go.mod h1:A86O+tHqZLMNO4w6ZZ4FlVQEadcoqkyU72HC5wJ4RlU= github.com/mmcloughlin/profile v0.1.1/go.mod h1:IhHD7q1ooxgwTgjxQYkACGA77oFTDdFVejUS1/tS/qU=