Skip to content

Commit

Permalink
fix: 跳过ent proto已设置annotation的field
Browse files Browse the repository at this point in the history
  • Loading branch information
TBXark committed Nov 10, 2024
1 parent 4708e67 commit 3911546
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 29 deletions.
125 changes: 97 additions & 28 deletions contrib/ent-gen-proto/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ import (
"entgo.io/ent/entc/gen"
"entgo.io/ent/schema/field"
"flag"
"fmt"
"log"
"path/filepath"
"reflect"
"sort"
"strconv"
_ "unsafe"
)

Expand Down Expand Up @@ -67,38 +69,62 @@ func addAnnotationForNode(node *gen.Type, enumUseRawType bool, ignoreOptional bo
if node.Annotations == nil {
node.Annotations = make(map[string]interface{}, 1)
}
if node.Annotations[entproto.MessageAnnotation] == nil {
// If the node does not have the message annotation, add it.
node.Annotations[entproto.MessageAnnotation] = entproto.Message()
fieldID := 1
if node.ID.Annotations == nil {
node.ID.Annotations = make(map[string]interface{}, 1)
}
node.ID.Annotations[entproto.FieldAnnotation] = entproto.Field(fieldID)
sort.Slice(node.Fields, func(i, j int) bool {
if node.Fields[i].Position.MixedIn != node.Fields[j].Position.MixedIn {
// MixedIn fields should be at the end of the list.
return !node.Fields[i].Position.MixedIn
}
return node.Fields[i].Position.Index < node.Fields[j].Position.Index
})
if node.Annotations[entproto.MessageAnnotation] != nil {
return
}

for j := 0; j < len(node.Fields); j++ {
fd := node.Fields[j]
if fd.Annotations == nil {
fd.Annotations = make(map[string]interface{}, 1)
}
fieldID++
if fd.IsEnum() {
fixEnumType(fd, enumUseRawType)
}
fd.Type.Type = fixFieldType(fd, timeUseProtoType)
fd.Annotations[entproto.FieldAnnotation] = entproto.Field(fieldID)
if fd.Optional && ignoreOptional {
fd.Optional = false
maxExistNum := 0
existNums := map[int]struct{}{}
for _, fd := range node.Fields {
if fd.Annotations != nil {
if obj, exist := fd.Annotations[entproto.FieldAnnotation]; exist {
if dict, ok := obj.(map[string]interface{}); ok {
if num, hasNum := dict["Number"]; hasNum {
if numInt, err := convertToInt(num); err == nil {
existNums[numInt] = struct{}{}
if numInt > maxExistNum {
maxExistNum = numInt
}
}
}
}
}
}
}

idGenerator := &fileIDGenerator{exist: existNums}

// If the node does not have the message annotation, add it.
node.Annotations[entproto.MessageAnnotation] = entproto.Message()
if node.ID.Annotations == nil {
node.ID.Annotations = make(map[string]interface{}, 1)
}
node.ID.Annotations[entproto.FieldAnnotation] = entproto.Field(idGenerator.MustNext())
sort.Slice(node.Fields, func(i, j int) bool {
if node.Fields[i].Position.MixedIn != node.Fields[j].Position.MixedIn {
// MixedIn fields should be at the end of the list.
return !node.Fields[i].Position.MixedIn
}
return node.Fields[i].Position.Index < node.Fields[j].Position.Index
})

for j := 0; j < len(node.Fields); j++ {
fd := node.Fields[j]
if fd.Annotations == nil {
fd.Annotations = make(map[string]interface{}, 1)
}
if fd.IsEnum() {
fixEnumType(fd, enumUseRawType)
}
if fd.Annotations[entproto.FieldAnnotation] != nil {
continue
}
fd.Type.Type = fixFieldType(fd, timeUseProtoType)
fd.Annotations[entproto.FieldAnnotation] = entproto.Field(idGenerator.MustNext())
if fd.Optional && ignoreOptional {
fd.Optional = false
}
}
}

func fixEnumType(fd *gen.Field, enumUseRawType bool) {
Expand Down Expand Up @@ -196,3 +222,46 @@ var buildInTypeSlice = map[string]struct{}{
"[]string": {},
"[]bool": {},
}

type fileIDGenerator struct {
current int
exist map[int]struct{}
}

func (f *fileIDGenerator) Next() (int, error) {
f.current++
for {
if _, ok := f.exist[f.current]; ok {
f.current++
continue
}
if f.current > 536870911 {
return 0, fmt.Errorf("entproto: field number exceed the maximum value 536870911")
}
break
}
return f.current, nil
}

func (f *fileIDGenerator) MustNext() int {
num, err := f.Next()
if err != nil {
panic(err)
}
return num
}

func convertToInt(val any) (int, error) {
switch v := val.(type) {
case int:
return v, nil
case float64:
return int(v), nil
case float32:
return int(v), nil
case string:
return strconv.Atoi(v)
default:
return 0, fmt.Errorf("entproto: unable to convert %v to int", val)
}
}
6 changes: 5 additions & 1 deletion contrib/ent-gen-proto/schema/test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package schema

import (
"database/sql/driver"
"entgo.io/contrib/entproto"
"entgo.io/ent"
"entgo.io/ent/schema/field"
"github.com/google/uuid"
"google.golang.org/protobuf/types/descriptorpb"
"time"
)

Expand Down Expand Up @@ -79,7 +81,9 @@ func (User) Fields() []ent.Field {
field.String("phone").Optional().Default("").Comment("手机号").MaxLen(20),
field.Uint64("flags").Default(0).Comment("标记位"),
field.JSON("roles", []string{}).Optional().Comment("角色列表"),
field.JSON("extra", Extra{}).Optional().Comment("额外信息"),
field.JSON("extra", Extra{}).Optional().Comment("额外信息").Annotations(
entproto.Field(6, entproto.Type(descriptorpb.FieldDescriptorProto_TYPE_BYTES)),
),
field.UUID("uuid", uuid.UUID{}).Default(uuid.New).Comment("UUID"),
field.Enum("level").GoType(Level(0)),
field.Time("created_at").Immutable().Default(time.Now()).Comment("创建时间"),
Expand Down

0 comments on commit 3911546

Please sign in to comment.