Skip to content

Commit 2ece707

Browse files
committed
fix(entproto): allow specifying import path
Basically a merge of: ent#616
1 parent a1d02d8 commit 2ece707

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

entproto/extension.go

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ type Extension struct {
5858
entc.DefaultExtension
5959
protoDir string
6060
skipGenFile bool
61+
goPkg string // inspired from: https://github.com/ent/contrib/pull/616
6162
}
6263

6364
// WithProtoDir sets the directory where the generated .proto files will be written.
@@ -67,6 +68,13 @@ func WithProtoDir(dir string) ExtensionOption {
6768
}
6869
}
6970

71+
// WithProtoDir sets the directory where the generated .proto files will be written.
72+
func WithGoPkg(pkg string) ExtensionOption {
73+
return func(e *Extension) {
74+
e.protoDir = pkg
75+
}
76+
}
77+
7078
// SkipGenFile skips the generation of a generate.go file next to each .proto file.
7179
func SkipGenFile() ExtensionOption {
7280
return func(e *Extension) {
@@ -197,7 +205,7 @@ func (e *Extension) generate(g *gen.Graph) error {
197205
if err != nil {
198206
return err
199207
}
200-
contents := protocGenerateGo(fd, toSchema, toEnt, g.Config.Package)
208+
contents := e.protocGenerateGo(fd, toSchema, toEnt, g.Config.Package)
201209
if err := os.WriteFile(genGoPath, []byte(contents), 0600); err != nil {
202210
return fmt.Errorf("entproto: failed generating generate.go file for %q: %w", protoFilePath, err)
203211
}
@@ -216,7 +224,7 @@ func fileExists(fpath string) bool {
216224
return true
217225
}
218226

219-
func protocGenerateGo(fd *desc.FileDescriptor, toSchemaDir, entPath, entPackage string) string {
227+
func (e *Extension) protocGenerateGo(fd *desc.FileDescriptor, toSchemaDir, entPath, entPackage string) string {
220228
levelsUp := len(strings.Split(fd.GetPackage(), "."))
221229
toProtoBase := ""
222230
for i := 0; i < levelsUp; i++ {
@@ -235,5 +243,8 @@ func protocGenerateGo(fd *desc.FileDescriptor, toSchemaDir, entPath, entPackage
235243
}
236244
goGen := fmt.Sprintf("//go:generate %s", strings.Join(protocCmd, " "))
237245
goPkgName := extractLastFqnPart(fd.GetPackage())
246+
if e.goPkg != "" {
247+
goPkgName = e.goPkg
248+
}
238249
return fmt.Sprintf("package %s\n%s\n", goPkgName, goGen)
239250
}

0 commit comments

Comments
 (0)