Skip to content

Commit

Permalink
feat: ent-gen-proto支持自定义时间类型
Browse files Browse the repository at this point in the history
  • Loading branch information
TBXark committed Nov 9, 2024
1 parent 08abc7a commit 0ec2983
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 20 deletions.
1 change: 1 addition & 0 deletions contrib/ent-gen-proto/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
proto
66 changes: 46 additions & 20 deletions contrib/ent-gen-proto/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,21 @@ func main() {
var (
schemaPath = flag.String("path", "./schema", "path to schema directory")
protoDir = flag.String("proto", "./proto", "path to proto directory")
ignoreOptional = flag.Bool("ignore-optional", true, "ignore optional, use zero value instead")
autoAddAnnotation = flag.Bool("auto-annotation", true, "auto add annotation to the schema")
enumUseRawType = flag.Bool("enum-raw-type", true, "use string for enum")
ignoreOptional = flag.Bool("ignore_optional", true, "ignore optional, use zero value instead")
autoAddAnnotation = flag.Bool("auto_annotation", true, "auto add annotation to the schema")
enumUseRawType = flag.Bool("enum_raw_type", true, "use string for enum")
timeUseProtoType = flag.String("time_proto_type", "google.protobuf.Timestamp", "use proto type for time.Time, one of int64, string, google.protobuf.Timestamp")
help = flag.Bool("help", false, "show help")
)
flag.Parse()
if *help {
flag.PrintDefaults()
return
}
runProtoGen(*schemaPath, *protoDir, *ignoreOptional, *autoAddAnnotation, *enumUseRawType)
runProtoGen(*schemaPath, *protoDir, *ignoreOptional, *autoAddAnnotation, *enumUseRawType, *timeUseProtoType)
}

func runProtoGen(schemaPath string, protoDir string, ignoreOptional, autoAddAnnotation, enumUseRawType bool) {
func runProtoGen(schemaPath string, protoDir string, ignoreOptional, autoAddAnnotation, enumUseRawType bool, timeUseProtoType string) {
abs, err := filepath.Abs(schemaPath)
if err != nil {
log.Fatalf("entproto: failed getting absolute path: %v", err)
Expand All @@ -46,7 +47,7 @@ func runProtoGen(schemaPath string, protoDir string, ignoreOptional, autoAddAnno
}
if autoAddAnnotation {
for i := 0; i < len(graph.Nodes); i++ {
addAnnotationForNode(graph.Nodes[i], enumUseRawType, ignoreOptional)
addAnnotationForNode(graph.Nodes[i], enumUseRawType, ignoreOptional, timeUseProtoType)
}
}
extension, err := entproto.NewExtension(
Expand All @@ -62,7 +63,7 @@ func runProtoGen(schemaPath string, protoDir string, ignoreOptional, autoAddAnno
}
}

func addAnnotationForNode(node *gen.Type, enumUseRawType bool, ignoreOptional bool) {
func addAnnotationForNode(node *gen.Type, enumUseRawType bool, ignoreOptional bool, timeUseProtoType string) {
if node.Annotations == nil {
node.Annotations = make(map[string]interface{}, 1)
}
Expand All @@ -89,20 +90,9 @@ func addAnnotationForNode(node *gen.Type, enumUseRawType bool, ignoreOptional bo
}
fieldID++
if fd.IsEnum() {
if enumUseRawType {
if fd.HasGoType() {
fd.Type.Type = reflectKind2FieldType[fd.Type.RType.Kind]
} else {
fd.Type.Type = field.TypeString
}
} else {
enums := make(map[string]int32, len(fd.Enums))
for index, enum := range fd.Enums {
enums[enum.Value] = int32(index) + 1
}
fd.Annotations[entproto.EnumAnnotation] = entproto.Enum(enums, entproto.OmitFieldPrefix())
}
fixEnumType(fd, enumUseRawType)
}
fd.Type.Type = fixFieldType(fd.Type.Type, timeUseProtoType)
fd.Annotations[entproto.FieldAnnotation] = entproto.Field(fieldID)
if fd.Optional && ignoreOptional {
fd.Optional = false
Expand All @@ -111,6 +101,42 @@ func addAnnotationForNode(node *gen.Type, enumUseRawType bool, ignoreOptional bo
}
}

func fixEnumType(fd *gen.Field, enumUseRawType bool) {
if enumUseRawType {
if fd.HasGoType() {
fd.Type.Type = reflectKind2FieldType[fd.Type.RType.Kind]
} else {
fd.Type.Type = field.TypeString
}
} else {
enums := make(map[string]int32, len(fd.Enums))
for index, enum := range fd.Enums {
enums[enum.Value] = int32(index) + 1
}
fd.Annotations[entproto.EnumAnnotation] = entproto.Enum(enums, entproto.OmitFieldPrefix())
}
}

func fixFieldType(t field.Type, timeType string) field.Type {
switch t {
case field.TypeJSON, field.TypeOther, field.TypeInvalid:
return field.TypeBytes // JSON and Other types are mapped to bytes.
case field.TypeUUID:
return field.TypeString
case field.TypeTime:
switch timeType {
case "int64":
return field.TypeInt64
case "string":
return field.TypeString
default:
return field.TypeTime
}
default:
return t
}
}

var reflectKind2FieldType = map[reflect.Kind]field.Type{
reflect.Bool: field.TypeBool,
reflect.Int: field.TypeInt,
Expand Down
57 changes: 57 additions & 0 deletions contrib/ent-gen-proto/schema/test.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,69 @@
package schema

import (
"database/sql/driver"
"entgo.io/ent"
"entgo.io/ent/schema/field"
"github.com/google/uuid"
"time"
)

type Extra struct {
Key string `json:"key"`
Vals string `json:"vals"`
}

type Level int

const (
Unknown Level = iota
Low
High
)

func (p Level) String() string {
switch p {
case Low:
return "LOW"
case High:
return "HIGH"
default:
return "UNKNOWN"
}
}

// Values provides list valid values for Enum.
func (Level) Values() []string {
return []string{Unknown.String(), Low.String(), High.String()}
}

// Value provides the DB a string from int.
func (p Level) Value() (driver.Value, error) {
return p.String(), nil
}

// Scan tells our code how to read the enum into our type.
func (p *Level) Scan(val any) error {
var s string
switch v := val.(type) {
case nil:
return nil
case string:
s = v
case []uint8:
s = string(v)
}
switch s {
case "LOW":
*p = Low
case "HIGH":
*p = High
default:
*p = Unknown
}
return nil
}

type User struct {
ent.Schema
}
Expand All @@ -23,5 +77,8 @@ func (User) Fields() []ent.Field {
field.String("phone").Optional().Default("").Comment("手机号").MaxLen(20),
field.Uint64("flags").Default(0).Comment("标记位"),
field.JSON("extra", Extra{}).Optional().Comment("额外信息"),
field.UUID("uuid", uuid.UUID{}).Default(uuid.New).Comment("UUID"),
field.Enum("level").GoType(Level(0)),
field.Time("created_at").Immutable().Default(time.Now()).Comment("创建时间"),
}
}

0 comments on commit 0ec2983

Please sign in to comment.