From 69008e86eadef4fbf4f70bff3782c63c525f2332 Mon Sep 17 00:00:00 2001
From: Seunghyun Hwang <lesomnus@gmail.com>
Date: Sat, 11 May 2024 00:00:37 +0900
Subject: [PATCH] support separate package

---
 entproto/cmd/protoc-gen-entgrpc/converter.go  |  4 +-
 entproto/cmd/protoc-gen-entgrpc/main.go       | 57 +++++++++++++------
 .../protoc-gen-entgrpc/template/enums.tmpl    |  6 +-
 .../template/method_get.tmpl                  |  9 ++-
 .../template/method_list.tmpl                 |  9 ++-
 .../template/method_mutate.tmpl               |  7 ++-
 .../protoc-gen-entgrpc/template/service.tmpl  |  6 +-
 .../protoc-gen-entgrpc/template/to_proto.tmpl | 19 +++++--
 8 files changed, 77 insertions(+), 40 deletions(-)

diff --git a/entproto/cmd/protoc-gen-entgrpc/converter.go b/entproto/cmd/protoc-gen-entgrpc/converter.go
index 1f12078b7..d3adc8e8b 100644
--- a/entproto/cmd/protoc-gen-entgrpc/converter.go
+++ b/entproto/cmd/protoc-gen-entgrpc/converter.go
@@ -65,7 +65,7 @@ func (g *serviceGenerator) newConverter(fld *entproto.FieldMappingDescriptor) (*
 	case dpb.FieldDescriptorProto_TYPE_ENUM:
 		enumName := fld.PbFieldDescriptor.GetEnumType().GetName()
 		method := fmt.Sprintf("toProto%s_%s", g.EntType.Name, enumName)
-		out.ToProtoConstructor = g.File.GoImportPath.Ident(method)
+		out.ToProtoConstructor = g.EntgrpcPackage.Ident(method)
 	case dpb.FieldDescriptorProto_TYPE_MESSAGE:
 		if fld.IsEdgeField {
 			if err := basicTypeConversion(fld.EdgeIDPbStructFieldDesc(), fld.EntEdge.Type.ID, out); err != nil {
@@ -110,7 +110,7 @@ func (g *serviceGenerator) newConverter(fld *entproto.FieldMappingDescriptor) (*
 	case efld.IsEnum():
 		enumName := fld.PbFieldDescriptor.GetEnumType().GetName()
 		method := fmt.Sprintf("toEnt%s_%s", g.EntType.Name, enumName)
-		out.ToEntConstructor = g.File.GoImportPath.Ident(method)
+		out.ToEntConstructor = g.EntgrpcPackage.Ident(method)
 	case efld.IsJSON():
 		switch efld.Type.Ident {
 		case "[]string":
diff --git a/entproto/cmd/protoc-gen-entgrpc/main.go b/entproto/cmd/protoc-gen-entgrpc/main.go
index 5e98698ed..ee39e7646 100644
--- a/entproto/cmd/protoc-gen-entgrpc/main.go
+++ b/entproto/cmd/protoc-gen-entgrpc/main.go
@@ -30,19 +30,29 @@ import (
 )
 
 var (
-	entSchemaPath *string
-	snake         = gen.Funcs["snake"].(func(string) string)
-	status        = protogen.GoImportPath("google.golang.org/grpc/status")
-	codes         = protogen.GoImportPath("google.golang.org/grpc/codes")
+	entgrpcPackage *string
+	entSchemaPath  *string
+	entPackagePath *string
+
+	snake  = gen.Funcs["snake"].(func(string) string)
+	status = protogen.GoImportPath("google.golang.org/grpc/status")
+	codes  = protogen.GoImportPath("google.golang.org/grpc/codes")
 )
 
 func main() {
 	var flags flag.FlagSet
+	entgrpcPackage = flags.String("package", "", "package path to be generated")
 	entSchemaPath = flags.String("schema_path", "", "ent schema path")
+	entPackagePath = flags.String("entity_package", "", "ent entity package path")
 	protogen.Options{
 		ParamFunc: flags.Set,
 	}.Run(func(plg *protogen.Plugin) error {
-		g, err := entc.LoadGraph(*entSchemaPath, &gen.Config{})
+		conf := gen.Config{}
+		if entPackagePath != nil {
+			conf.Package = *entPackagePath
+		}
+
+		g, err := entc.LoadGraph(*entSchemaPath, &conf)
 		if err != nil {
 			return err
 		}
@@ -99,19 +109,30 @@ func newServiceGenerator(plugin *protogen.Plugin, file *protogen.File, graph *ge
 	if err != nil {
 		return nil, err
 	}
+
+	entgrpcImportPath := file.GoImportPath
+	entgrpcPackageName := file.GoPackageName
+	if entgrpcPackage != nil {
+		entgrpcImportPath = protogen.GoImportPath(*entgrpcPackage)
+		entgrpcPackageName = protogen.GoPackageName(path.Base(*entgrpcPackage))
+	}
+
 	filename := file.GeneratedFilenamePrefix + "_" + snake(service.GoName) + ".go"
-	g := plugin.NewGeneratedFile(filename, file.GoImportPath)
+	g := plugin.NewGeneratedFile(filename, entgrpcImportPath)
 	fieldMap, err := adapter.FieldMap(typ.Name)
 	if err != nil {
 		return nil, err
 	}
+
 	return &serviceGenerator{
-		GeneratedFile: g,
-		EntPackage:    protogen.GoImportPath(graph.Config.Package),
-		File:          file,
-		Service:       service,
-		EntType:       typ,
-		FieldMap:      fieldMap,
+		GeneratedFile:      g,
+		EntgrpcPackageName: entgrpcPackageName,
+		EntgrpcPackage:     entgrpcImportPath,
+		EntPackage:         protogen.GoImportPath(graph.Config.Package),
+		File:               file,
+		Service:            service,
+		EntType:            typ,
+		FieldMap:           fieldMap,
 	}, nil
 }
 
@@ -161,11 +182,13 @@ func (g *serviceGenerator) generate() error {
 type (
 	serviceGenerator struct {
 		*protogen.GeneratedFile
-		EntPackage protogen.GoImportPath
-		File       *protogen.File
-		Service    *protogen.Service
-		EntType    *gen.Type
-		FieldMap   entproto.FieldMap
+		EntgrpcPackageName protogen.GoPackageName
+		EntgrpcPackage     protogen.GoImportPath
+		EntPackage         protogen.GoImportPath
+		File               *protogen.File
+		Service            *protogen.Service
+		EntType            *gen.Type
+		FieldMap           entproto.FieldMap
 	}
 	methodInput struct {
 		G      *serviceGenerator
diff --git a/entproto/cmd/protoc-gen-entgrpc/template/enums.tmpl b/entproto/cmd/protoc-gen-entgrpc/template/enums.tmpl
index 40a2c3d6b..7831bab7f 100644
--- a/entproto/cmd/protoc-gen-entgrpc/template/enums.tmpl
+++ b/entproto/cmd/protoc-gen-entgrpc/template/enums.tmpl
@@ -14,14 +14,14 @@
         }
 
         func toProto{{ $pbEnumIdent.GoName }} (e {{ ident $entEnumIdent }}) {{ ident $pbEnumIdent }} {
-            if v, ok := {{ $pbEnumIdent.GoName }}_value[{{ qualify "strings" "ToUpper" }}({{ if not $omitPrefix }}"{{ $enumFieldPrefix }}" + {{ end }}protoIdentNormalize{{ $pbEnumIdent.GoName }}(string(e)))]; ok {
+            if v, ok := {{ ident $pbEnumIdent }}_value[{{ qualify "strings" "ToUpper" }}({{ if not $omitPrefix }}"{{ $enumFieldPrefix }}" + {{ end }}protoIdentNormalize{{ $pbEnumIdent.GoName }}(string(e)))]; ok {
                 return {{ $pbEnumIdent | ident }}(v)
             }
             return {{ $pbEnumIdent | ident }}(0)
         }
 
         func toEnt{{ $pbEnumIdent.GoName }}(e {{ ident $pbEnumIdent }}) {{ ident $entEnumIdent  }} {
-            if v, ok := {{ $pbEnumIdent.GoName }}_name[int32(e)]; ok {
+            if v, ok := {{ ident $pbEnumIdent }}_name[int32(e)]; ok {
                 entVal := map[string]string{
                 {{- range .EntField.Enums }}
                     "{{ if not $omitPrefix }}{{ $enumFieldPrefix }}{{ end }}{{ protoIdentNormalize .Value }}": "{{ .Value }}",
@@ -32,4 +32,4 @@
             return ""
         }
     {{ end}}
-{{ end }}
\ No newline at end of file
+{{ end }}
diff --git a/entproto/cmd/protoc-gen-entgrpc/template/method_get.tmpl b/entproto/cmd/protoc-gen-entgrpc/template/method_get.tmpl
index 4191e83e1..533829f0a 100644
--- a/entproto/cmd/protoc-gen-entgrpc/template/method_get.tmpl
+++ b/entproto/cmd/protoc-gen-entgrpc/template/method_get.tmpl
@@ -1,5 +1,6 @@
 {{- /*gotype: entgo.io/contrib/entproto/cmd/protoc-gen-entgrpc.methodInput*/ -}}
 {{ define "method_get" }}
+    {{ $importPath := .G.File.GoImportPath }}
     {{- $idField := .G.FieldMap.ID -}}
     {{- $varName := $idField.EntField.Name -}}
     {{- $inputName := .Method.Input.GoIdent.GoName -}}
@@ -9,9 +10,11 @@
     )
     {{- template "field_to_ent" dict "Field" $idField "VarName" $idField.EntField.Name "Ident" (print "req.Get" $idField.PbStructField "()") }}
     switch req.GetView() {
-        case {{ $inputName }}_VIEW_UNSPECIFIED, {{ $inputName }}_BASIC:
+        case {{ $importPath.Ident ( print $inputName "_VIEW_UNSPECIFIED" ) | ident }}:
+            fallthrough
+        case {{ $importPath.Ident ( print $inputName "_BASIC" ) | ident }}:
             get, err = svc.client.{{ .G.EntType.Name }}.Get(ctx, {{ $varName }})
-        case {{ $inputName }}_WITH_EDGE_IDS:
+        case {{ $importPath.Ident ( print $inputName "_WITH_EDGE_IDS" ) | ident }}:
             get, err = svc.client.{{ .G.EntType.Name }}.Query().
             Where({{ qualify (print (unquote .G.EntPackage.String) "/" .G.EntType.Package) "ID" }}({{ $varName }})).
             {{ range .G.FieldMap.Edges }}
@@ -32,4 +35,4 @@
         default:
             return nil, {{ statusErrf "Internal" "internal error: %s" "err" }}
     }
-{{ end }}
\ No newline at end of file
+{{ end }}
diff --git a/entproto/cmd/protoc-gen-entgrpc/template/method_list.tmpl b/entproto/cmd/protoc-gen-entgrpc/template/method_list.tmpl
index 053233f89..79764208e 100644
--- a/entproto/cmd/protoc-gen-entgrpc/template/method_list.tmpl
+++ b/entproto/cmd/protoc-gen-entgrpc/template/method_list.tmpl
@@ -1,5 +1,6 @@
 {{- /*gotype: entgo.io/contrib/entproto/cmd/protoc-gen-entgrpc.methodInput*/ -}}
 {{ define "method_list" }}
+    {{ $importPath := .G.File.GoImportPath }}
     {{- $inputName := .Method.Input.GoIdent.GoName -}}
     var (
         err error
@@ -40,9 +41,11 @@
             Where({{ qualify (print (unquote .G.EntPackage.String) "/" .G.EntType.Package) "IDLTE" }}(pageToken))
     }
     switch req.GetView() {
-    case {{ $inputName }}_VIEW_UNSPECIFIED, {{ $inputName }}_BASIC:
+    case {{ $importPath.Ident ( print $inputName "_VIEW_UNSPECIFIED" ) | ident }}:
+        fallthrough
+    case {{ $importPath.Ident ( print $inputName "_BASIC" ) | ident }}:
         entList, err = listQuery.All(ctx)
-    case {{ $inputName }}_WITH_EDGE_IDS:
+    case {{ $importPath.Ident ( print $inputName "_WITH_EDGE_IDS" ) | ident }}:
         entList, err = listQuery.
             {{ range .G.FieldMap.Edges }}
                 {{- $et := .EntEdge.Type -}}
@@ -64,7 +67,7 @@
         if err != nil {
             return nil, {{ statusErrf "Internal" "internal error: %s" "err" }}
         }
-        return &List{{ .G.EntType.Name }}Response{
+        return &{{ $importPath.Ident ( print "List" .G.EntType.Name "Response" ) | ident }}{
             {{ .G.EntType.Name }}List: protoList,
             NextPageToken: nextPageToken,
         }, nil
diff --git a/entproto/cmd/protoc-gen-entgrpc/template/method_mutate.tmpl b/entproto/cmd/protoc-gen-entgrpc/template/method_mutate.tmpl
index b68f8a86f..9a3ac8330 100644
--- a/entproto/cmd/protoc-gen-entgrpc/template/method_mutate.tmpl
+++ b/entproto/cmd/protoc-gen-entgrpc/template/method_mutate.tmpl
@@ -38,10 +38,11 @@
 
 {{ define "create_builder_func" }}
     {{- $entType  := .Method.G.EntType.Name -}}
+    {{- $pbType := .Method.G.File.GoImportPath.Ident $entType | ident -}}
     {{- $inputVar := camel $entType -}}
-    {{- $outputType := printf "%s%s" $entType "Create" -}}
+    {{- $outputType := .Method.G.EntPackage.Ident ( printf "%s%s" $entType "Create" ) | ident  -}}
 
-    func (svc *{{ .ServiceName }}) createBuilder({{ $inputVar }} *{{ $entType }}) (*ent.{{ $outputType }}, error) {
+    func (svc *{{ .ServiceName }}) createBuilder({{ $inputVar }} *{{ $pbType }}) (*{{ $outputType }}, error) {
         m := svc.client.{{ $entType }}.Create()
         {{- template "mutate_helper" .Method -}}
         return m, nil
@@ -85,4 +86,4 @@
             }
         {{- end }}
     {{- end }}
-{{ end }}
\ No newline at end of file
+{{ end }}
diff --git a/entproto/cmd/protoc-gen-entgrpc/template/service.tmpl b/entproto/cmd/protoc-gen-entgrpc/template/service.tmpl
index d4b916672..aef1645e7 100644
--- a/entproto/cmd/protoc-gen-entgrpc/template/service.tmpl
+++ b/entproto/cmd/protoc-gen-entgrpc/template/service.tmpl
@@ -1,12 +1,12 @@
 {{- /*gotype: entgo.io/contrib/entproto/cmd/protoc-gen-entgrpc.serviceGenerator*/ -}}
 {{ define "service" }}
 // Code generated by protoc-gen-entgrpc. DO NOT EDIT.
-package {{ .File.GoPackageName }}
+package {{ .EntgrpcPackageName }}
 
 // {{ .Service.GoName }} implements {{ .Service.GoName }}Server
 type {{ .Service.GoName }} struct {
-    client *{{ .EntPackage.Ident "Client" | ident }}
-    Unimplemented{{ .Service.GoName }}Server
+    client *{{ .EntPackage.Ident "Client" | ident }}    
+    {{ .File.GoImportPath.Ident (print "Unimplemented" .Service.GoName "Server") | ident }}
 }
 
 // New{{ .Service.GoName }} returns a new {{ .Service.GoName }}
diff --git a/entproto/cmd/protoc-gen-entgrpc/template/to_proto.tmpl b/entproto/cmd/protoc-gen-entgrpc/template/to_proto.tmpl
index 64e093280..ac79d80ae 100644
--- a/entproto/cmd/protoc-gen-entgrpc/template/to_proto.tmpl
+++ b/entproto/cmd/protoc-gen-entgrpc/template/to_proto.tmpl
@@ -1,8 +1,11 @@
 {{- /*gotype: entgo.io/contrib/entproto/cmd/protoc-gen-entgrpc.serviceGenerator*/ -}}
 {{ define "to_proto_func" }}
+    {{- $importPath := .File.GoImportPath -}}
+    {{- $pbTypeName := $importPath.Ident .EntType.Name | ident -}}
+
     // toProto{{ .EntType.Name }} transforms the ent type to the pb type
-    func toProto{{ .EntType.Name }}(e *{{ .EntPackage.Ident .EntType.Name | ident }}) (*{{ .EntType.Name }}, error) {
-        v := &{{ .EntType.Name }}{}
+    func toProto{{ .EntType.Name }}(e *{{ .EntPackage.Ident .EntType.Name | ident }}) (*{{ $pbTypeName }}, error) {
+        v := &{{ $pbTypeName }}{}
         {{- range .FieldMap.Fields }}
             {{- $varName := .EntField.BuilderField -}}
             {{- $f := print "e." .EntField.StructField -}}
@@ -17,20 +20,21 @@
             {{- end }}
         {{- end }}
         {{- range .FieldMap.Edges }}
+            {{ $edgeTypeName := $importPath.Ident .EntEdge.Type.Name | ident }}
             {{- $varName := camel .EntEdge.Type.ID.StructField -}}
             {{- $id := print "edg." .EntEdge.Type.ID.StructField -}}
             {{- $name := .EntEdge.StructField -}}
             {{- if .EntEdge.Unique }}
                 if edg := e.Edges.{{ $name }}; edg != nil {
                     {{- template "field_to_proto" dict "Field" . "VarName" $varName "Ident" $id }}
-                    v.{{ .PbStructField }} = &{{ .EntEdge.Type.Name }}{
+                    v.{{ .PbStructField }} = &{{ $edgeTypeName }}{
                         {{ .EdgeIDPbStructField }}: {{ $varName }},
                     }
                 }
             {{- else }}
                 for _, edg := range e.Edges.{{ $name }} {
                     {{- template "field_to_proto" dict "Field" . "VarName" $varName "Ident" $id }}
-                    v.{{ .PbStructField }} = append(v.{{ .PbStructField }}, &{{ .EntEdge.Type.Name }}{
+                    v.{{ .PbStructField }} = append(v.{{ .PbStructField }}, &{{ $edgeTypeName }}{
                         {{ .EdgeIDPbStructField }}: {{ $varName }},
                     })
                 }
@@ -41,9 +45,12 @@
 {{ end }}
 
 {{ define "to_proto_list_func" }}
+    {{- $importPath := .File.GoImportPath -}}
+    {{- $pbTypeName := $importPath.Ident .EntType.Name | ident -}}
+
     // toProto{{ .EntType.Name }}List transforms a list of ent type to a list of pb type
-    func toProto{{ .EntType.Name }}List(e []*{{ .EntPackage.Ident .EntType.Name | ident }}) ([]*{{ .EntType.Name }}, error) {
-        var pbList []*{{ .EntType.Name }}
+    func toProto{{ .EntType.Name }}List(e []*{{ .EntPackage.Ident .EntType.Name | ident }}) ([]*{{ $pbTypeName }}, error) {
+        var pbList []*{{ $pbTypeName }}
         for _, entEntity := range e {
             pbEntity, err := toProto{{ .EntType.Name }}(entEntity)
             if err != nil {