Skip to content

Commit

Permalink
refactor: merge fieldmap code into one package; (#75)
Browse files Browse the repository at this point in the history
fix: get() panic when no field is set
  • Loading branch information
AsterDY authored Sep 20, 2024
1 parent b8eee87 commit 3c00b4d
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 242 deletions.
55 changes: 38 additions & 17 deletions proto/utils.go → internal/util/fieldmap.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,20 @@
package proto
/**
* Copyright 2024 ByteDance Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package util

import (
"unsafe"
Expand All @@ -23,7 +39,7 @@ type FieldNameMap struct {
}

// Set sets the field descriptor for the given key
func (ft *FieldNameMap) Set(key string, field *FieldDescriptor) (exist bool) {
func (ft *FieldNameMap) Set(key string, field unsafe.Pointer) (exist bool) {
if len(key) > ft.maxKeyLength {
ft.maxKeyLength = len(key)
}
Expand All @@ -39,32 +55,37 @@ func (ft *FieldNameMap) Set(key string, field *FieldDescriptor) (exist bool) {
}

// Get gets the field descriptor for the given key
func (ft FieldNameMap) Get(k string) *FieldDescriptor {
func (ft FieldNameMap) Get(k string) unsafe.Pointer {
if ft.trie != nil {
return (*FieldDescriptor)(ft.trie.Get(k))
return (unsafe.Pointer)(ft.trie.Get(k))
} else if ft.hash != nil {
return (*FieldDescriptor)(ft.hash.Get(k))
return (unsafe.Pointer)(ft.hash.Get(k))
}
return nil
}

// All returns all field descriptors
func (ft FieldNameMap) All() []*FieldDescriptor {
return *(*[]*FieldDescriptor)(unsafe.Pointer(&ft.all))
func (ft FieldNameMap) All() []caching.Pair {
return ft.all
}

// Size returns the size of the map
func (ft FieldNameMap) Size() int {
if ft.hash != nil {
return ft.hash.Size()
} else {
} else if ft.trie != nil {
return ft.trie.Size()
}
return 0
}

// Build builds the map.
// It will try to build a trie tree if the dispersion of keys is higher enough (min).
func (ft *FieldNameMap) Build() {
if len(ft.all) == 0 {
return
}

var empty unsafe.Pointer

// statistics the distrubution for each position:
Expand Down Expand Up @@ -146,34 +167,34 @@ func (ft *FieldNameMap) Build() {
}

// FieldIDMap is a map from field id to field descriptor
type FieldNumberMap struct {
m []*FieldDescriptor
all []*FieldDescriptor
type FieldIDMap struct {
m []unsafe.Pointer
all []unsafe.Pointer
}

// All returns all field descriptors
func (fd FieldNumberMap) All() (ret []*FieldDescriptor) {
func (fd FieldIDMap) All() (ret []unsafe.Pointer) {
return fd.all
}

// Size returns the size of the map
func (fd FieldNumberMap) Size() int {
func (fd FieldIDMap) Size() int {
return len(fd.m)
}

// Get gets the field descriptor for the given id
func (fd FieldNumberMap) Get(id FieldNumber) *FieldDescriptor {
func (fd FieldIDMap) Get(id int32) unsafe.Pointer {
if int(id) >= len(fd.m) {
return nil
}
return fd.m[id]
}

// Set sets the field descriptor for the given id
func (fd *FieldNumberMap) Set(id FieldNumber, f *FieldDescriptor) {
func (fd *FieldIDMap) Set(id int32, f unsafe.Pointer) {
if int(id) >= len(fd.m) {
len := int(id) + 1
tmp := make([]*FieldDescriptor, len)
tmp := make([]unsafe.Pointer, len)
copy(tmp, fd.m)
fd.m = tmp
}
Expand All @@ -189,4 +210,4 @@ func (fd *FieldNumberMap) Set(id FieldNumber, f *FieldDescriptor) {
}
}
fd.m[id] = f
}
}
35 changes: 35 additions & 0 deletions internal/util/fieldmap_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/**
* Copyright 2024 ByteDance Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package util

import "testing"

func TestEmptyFieldMap(t *testing.T) {
// empty test
ids := FieldIDMap{}
if ids.Get(1) != nil {
t.Fatalf("expect nil")
}
names := FieldNameMap{}
if names.Get("a") != nil {
t.Fatalf("expect nil")
}
names.Build()
if names.Get("a") != nil {
t.Fatalf("expect nil")
}
}
12 changes: 7 additions & 5 deletions proto/descriptor.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package proto

import "github.com/cloudwego/dynamicgo/internal/util"

type TypeDescriptor struct {
baseId FieldNumber // for LIST/MAP to write field tag by baseId
typ Type
Expand Down Expand Up @@ -113,24 +115,24 @@ func (f *FieldDescriptor) IsList() bool {
type MessageDescriptor struct {
baseId FieldNumber
name string
ids FieldNumberMap
names FieldNameMap // store name and jsonName for FieldDescriptor
ids util.FieldIDMap
names util.FieldNameMap // store name and jsonName for FieldDescriptor
}

func (m *MessageDescriptor) Name() string {
return m.name
}

func (m *MessageDescriptor) ByJSONName(name string) *FieldDescriptor {
return m.names.Get(name)
return (*FieldDescriptor)(m.names.Get(name))
}

func (m *MessageDescriptor) ByName(name string) *FieldDescriptor {
return m.names.Get(name)
return (*FieldDescriptor)(m.names.Get(name))
}

func (m *MessageDescriptor) ByNumber(id FieldNumber) *FieldDescriptor {
return m.ids.Get(id)
return (*FieldDescriptor)(m.ids.Get(int32(id)))
}

func (m *MessageDescriptor) FieldsCount() int {
Expand Down
12 changes: 7 additions & 5 deletions proto/idl.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ import (
"context"
"errors"
"math"
"unsafe"

"github.com/cloudwego/dynamicgo/internal/util"
"github.com/cloudwego/dynamicgo/meta"
"github.com/jhump/protoreflect/desc"
"github.com/jhump/protoreflect/desc/protoparse"
Expand Down Expand Up @@ -171,8 +173,8 @@ func parseMessage(ctx context.Context, msgDesc *desc.MessageDescriptor, cache co
fields := msgDesc.GetFields()
md := &MessageDescriptor{
baseId: FieldNumber(math.MaxInt32),
ids: FieldNumberMap{},
names: FieldNameMap{},
ids: util.FieldIDMap{},
names: util.FieldNameMap{},
}

ty = &TypeDescriptor{
Expand Down Expand Up @@ -249,9 +251,9 @@ func parseMessage(ctx context.Context, msgDesc *desc.MessageDescriptor, cache co

// add fieldDescriptor to MessageDescriptor
// md.ids[FieldNumber(id)] = fieldDesc
md.ids.Set(FieldNumber(id), fieldDesc)
md.names.Set(name, fieldDesc)
md.names.Set(jsonName, fieldDesc)
md.ids.Set(int32(id), unsafe.Pointer(fieldDesc))
md.names.Set(name, unsafe.Pointer(fieldDesc))
md.names.Set(jsonName, unsafe.Pointer(fieldDesc))
}
md.names.Build()

Expand Down
15 changes: 9 additions & 6 deletions thrift/descriptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ package thrift

import (
"fmt"
"unsafe"

"github.com/cloudwego/dynamicgo/http"
"github.com/cloudwego/dynamicgo/internal/util"
"github.com/cloudwego/thriftgo/parser"
)

Expand Down Expand Up @@ -166,8 +168,8 @@ func (d TypeDescriptor) Struct() *StructDescriptor {
type StructDescriptor struct {
baseID FieldID
name string
ids FieldIDMap
names FieldNameMap
ids util.FieldIDMap
names util.FieldNameMap
requires RequiresBitmap
hmFields []*FieldDescriptor
annotations []parser.Annotation
Expand Down Expand Up @@ -212,12 +214,13 @@ func (s StructDescriptor) Name() string {

// Len returns the number of fields in the struct
func (s StructDescriptor) Len() int {
return len(s.ids.all)
return len(s.ids.All())
}

// Fields returns all fields in the struct
func (s StructDescriptor) Fields() []*FieldDescriptor {
return s.ids.All()
ret := s.ids.All()
return *(*[]*FieldDescriptor)(unsafe.Pointer(&ret))
}

// Fields returns requireness bitmap in the struct.
Expand All @@ -232,15 +235,15 @@ func (s StructDescriptor) Annotations() []parser.Annotation {

// FieldById finds the field by field id
func (s StructDescriptor) FieldById(id FieldID) *FieldDescriptor {
return s.ids.Get(id)
return (*FieldDescriptor)(s.ids.Get(int32(id)))
}

// FieldByName finds the field by key
//
// NOTICE: Options.MapFieldWay can influence the behavior of this method.
// ep: if Options.MapFieldWay is MapFieldWayName, then field names should be used as key.
func (s StructDescriptor) FieldByKey(k string) (field *FieldDescriptor) {
return s.names.Get(k)
return (*FieldDescriptor)(s.names.Get(k))
}

// FieldID is used to identify a field in a struct
Expand Down
Loading

0 comments on commit 3c00b4d

Please sign in to comment.