Skip to content

Commit

Permalink
fix: 不覆盖原有entproto.EnumAnnotation
Browse files Browse the repository at this point in the history
  • Loading branch information
TBXark committed Nov 11, 2024
1 parent 9a81066 commit a875b1a
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 22 deletions.
40 changes: 19 additions & 21 deletions contrib/ent-gen-proto/entgenproto/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,11 @@ import (
"github.com/mitchellh/mapstructure"
"google.golang.org/protobuf/types/descriptorpb"
"log"
"path/filepath"
"reflect"
"sort"
_ "unsafe"
)

//go:linkname generate entgo.io/contrib/entproto.(*Extension).generate
func generate(extension *entproto.Extension, g *gen.Graph) error

//go:linkname wktsPaths entgo.io/contrib/entproto.wktsPaths
var wktsPaths map[string]string

type Options struct {
SchemaPath string
ProtoDir string
Expand All @@ -42,14 +35,13 @@ type ProtoPackage struct {
Types []string
}

//go:linkname generate entgo.io/contrib/entproto.(*Extension).generate
func generate(extension *entproto.Extension, g *gen.Graph) error

func Generate(options *Options) {
injectProtoPackages(options.ProtoPackages)
abs, err := filepath.Abs(options.SchemaPath)
if err != nil {
log.Fatalf("entproto: failed getting absolute path: %v", err)
}
graph, err := entc.LoadGraph(options.SchemaPath, &gen.Config{
Target: filepath.Dir(abs),
Target: options.SchemaPath,
})
if err != nil {
log.Fatalf("entproto: failed loading ent graph: %v", err)
Expand Down Expand Up @@ -81,7 +73,7 @@ func addAnnotationForNode(node *gen.Type, options *Options) {
}
// If the node does not have the message annotation, add it.
node.Annotations[entproto.MessageAnnotation] = entproto.Message()
idGenerator := &fileIDGenerator{exist: extractExistFieldID(node)}
idGenerator := &fieldIDGenerator{exist: extractExistFieldID(node)}
sort.Slice(node.Fields, func(i, j int) bool {
if node.Fields[i].Position.MixedIn != node.Fields[j].Position.MixedIn {
// MixedIn fields should be at the end of the list.
Expand All @@ -95,7 +87,7 @@ func addAnnotationForNode(node *gen.Type, options *Options) {
}
}

func addAnnotationForField(fd *gen.Field, idGenerator *fileIDGenerator, options *Options) {
func addAnnotationForField(fd *gen.Field, idGenerator *fieldIDGenerator, options *Options) {
if fd.Annotations == nil {
fd.Annotations = make(map[string]interface{}, 1)
}
Expand All @@ -111,12 +103,12 @@ func addAnnotationForField(fd *gen.Field, idGenerator *fileIDGenerator, options
fixEnumType(fd, options.EnumUseRawType)
case field.TypeJSON:
if _, ok := entprotoSupportJSONType[fd.Type.RType.Ident]; !ok {
nt, opts := fixUnsupportedType(fd.Type.Type, options)
nt, opts := fixUnsupportedType(fd.Type.Type, options.UnsupportedProtoType)
fd.Type.Type = nt
fieldOptions = append(fieldOptions, opts...)
}
case field.TypeInvalid:
nt, opts := fixUnsupportedType(fd.Type.Type, options)
nt, opts := fixUnsupportedType(fd.Type.Type, options.UnsupportedProtoType)
fd.Type.Type = nt
fieldOptions = append(fieldOptions, opts...)
case field.TypeTime:
Expand Down Expand Up @@ -144,8 +136,8 @@ func addAnnotationForField(fd *gen.Field, idGenerator *fileIDGenerator, options
}
}

func fixUnsupportedType(t field.Type, options *Options) (field.Type, []entproto.FieldOption) {
switch options.UnsupportedProtoType {
func fixUnsupportedType(t field.Type, unsupportedProtoType string) (field.Type, []entproto.FieldOption) {
switch unsupportedProtoType {
case "google.protobuf.Any":
return t, []entproto.FieldOption{
entproto.Type(descriptorpb.FieldDescriptorProto_TYPE_MESSAGE),
Expand All @@ -165,6 +157,9 @@ func fixUnsupportedType(t field.Type, options *Options) (field.Type, []entproto.
}

func fixEnumType(fd *gen.Field, enumUseRawType bool) {
if fd.Annotations[entproto.EnumAnnotation] != nil {
return
}
if enumUseRawType {
if fd.HasGoType() {
fd.Type.Type = reflectKind2FieldType[fd.Type.RType.Kind]
Expand Down Expand Up @@ -241,12 +236,12 @@ var entprotoSupportJSONType = map[string]struct{}{
"[]string": {},
}

type fileIDGenerator struct {
type fieldIDGenerator struct {
current int
exist map[int]struct{}
}

func (f *fileIDGenerator) Next() (int, error) {
func (f *fieldIDGenerator) Next() (int, error) {
f.current++
for {
if _, ok := f.exist[f.current]; ok {
Expand All @@ -261,14 +256,17 @@ func (f *fileIDGenerator) Next() (int, error) {
return f.current, nil
}

func (f *fileIDGenerator) MustNext() int {
func (f *fieldIDGenerator) MustNext() int {
num, err := f.Next()
if err != nil {
panic(err)
}
return num
}

//go:linkname wktsPaths entgo.io/contrib/entproto.wktsPaths
var wktsPaths map[string]string

func injectProtoPackages(pkg []ProtoPackage) {
wktsPaths["google.protobuf.Any"] = "google/protobuf/any.proto"
wktsPaths["google.protobuf.Struct"] = "google/protobuf/struct.proto"
Expand Down
2 changes: 1 addition & 1 deletion contrib/ent-gen-proto/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func main() {
autoAddAnnotation = flag.Bool("auto_annotation", true, "auto add annotation to the schema")
enumUseRawType = flag.Bool("enum_raw_type", true, "use string for enum")

importProto = flag.String("import_proto", "google/protobuf/any.proto,google.protobuf,Any;", "import proto, format: path,package,type1,type2;")
importProto = flag.String("import_proto", "google/protobuf/any.proto,google.protobuf,Any;", "import proto, format: path1,package1,type1,type2;path2,package2,type3,type4;")

help = flag.Bool("help", false, "show help")
)
Expand Down

0 comments on commit a875b1a

Please sign in to comment.