Skip to content

entproto: add option to specify Go package name #616

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions entproto/extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ type Extension struct {
entc.DefaultExtension
protoDir string
skipGenFile bool
goPkg string
}

// WithProtoDir sets the directory where the generated .proto files will be written.
Expand All @@ -74,6 +75,14 @@ func SkipGenFile() ExtensionOption {
}
}

// WithGoPkg sets the Go package to be used in the generated .proto files.
// By default, the Go package is derived from the last part of the proto package.
func WithGoPkg(pkg string) ExtensionOption {
return func(e *Extension) {
e.goPkg = pkg
}
}

// Hooks implements entc.Extension.
func (e *Extension) Hooks() []gen.Hook {
return []gen.Hook{e.hook()}
Expand Down Expand Up @@ -165,7 +174,7 @@ func (e *Extension) generate(g *gen.Graph) error {
return fmt.Errorf("entproto: failed generating generate.go file for %q: %w", protoFilePath, err)
}
toSchema := filepath.Join(toBase, "schema")
contents := protocGenerateGo(fd, toSchema)
contents := e.protocGenerateGo(fd, toSchema)
if err := os.WriteFile(genGoPath, []byte(contents), 0600); err != nil {
return fmt.Errorf("entproto: failed generating generate.go file for %q: %w", protoFilePath, err)
}
Expand All @@ -184,7 +193,7 @@ func fileExists(fpath string) bool {
return true
}

func protocGenerateGo(fd *desc.FileDescriptor, toSchemaDir string) string {
func (e *Extension) protocGenerateGo(fd *desc.FileDescriptor, toSchemaDir string) string {
levelsUp := len(strings.Split(fd.GetPackage(), "."))
toProtoBase := ""
for i := 0; i < levelsUp; i++ {
Expand All @@ -202,6 +211,12 @@ func protocGenerateGo(fd *desc.FileDescriptor, toSchemaDir string) string {
fd.GetName(),
}
goGen := fmt.Sprintf("//go:generate %s", strings.Join(protocCmd, " "))
goPkgName := extractLastFqnPart(fd.GetPackage())

// Use the provided Go package name if set, otherwise derive from the protobuf package
goPkgName := e.goPkg
if goPkgName == "" {
goPkgName = extractLastFqnPart(fd.GetPackage())
}

return fmt.Sprintf("package %s\n%s\n", goPkgName, goGen)
}
77 changes: 77 additions & 0 deletions entproto/extension_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package entproto

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestWithGoPkg(t *testing.T) {
tests := []struct {
name string
goPkg string
fdPackage string
wantPkgName string
}{
{
name: "Default behavior",
goPkg: "",
fdPackage: "example.service",
wantPkgName: "service",
},
{
name: "Custom package",
goPkg: "custompkg",
fdPackage: "example.service",
wantPkgName: "custompkg",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Test directly with the extractLastFqnPart helper function
// since we can't easily mock desc.FileDescriptor
var pkgName string
if tt.goPkg != "" {
pkgName = tt.goPkg
} else {
pkgName = extractLastFqnPart(tt.fdPackage)
}

require.Equal(t, tt.wantPkgName, pkgName,
"Expected package name to be %q, got %q", tt.wantPkgName, pkgName)
})
}
}

// TestExtensionGoPkg tests that the WithGoPkg option properly sets the goPkg field
func TestExtensionGoPkg(t *testing.T) {
customPkg := "myspecialpkg"
ext, err := NewExtension(WithGoPkg(customPkg))
require.NoError(t, err)

// Check that the goPkg field is set
require.Equal(t, customPkg, ext.goPkg,
"Expected goPkg to be %q, got %q", customPkg, ext.goPkg)
}

// TestExtractLastFqnPart tests the package extraction function
func TestExtractLastFqnPart(t *testing.T) {
tests := []struct {
input string
want string
}{
{"example.service", "service"},
{"service", "service"},
{"com.example.api.v1", "v1"},
{"", ""},
}

for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
got := extractLastFqnPart(tt.input)
require.Equal(t, tt.want, got,
"extractLastFqnPart(%q) = %q, want %q", tt.input, got, tt.want)
})
}
}