Skip to content

Commit f0bc5e0

Browse files
committed
Allow using goImport as source for moqs
1 parent 0bf2e8a commit f0bc5e0

File tree

2 files changed

+132
-24
lines changed

2 files changed

+132
-24
lines changed

internal/registry/registry.go

+12
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@ func (r *Registry) AddImport(pkg *types.Package) *Package {
9292
return nil
9393
}
9494

95+
if strings.Contains(pkg.Path(), "moq_force_local_interface_mode_") {
96+
path = pkg.Imports()[0].Path() // moq_force_local_interface_mode only have 1 import dependency and is the one we are looking for
97+
}
98+
9599
if imprt, ok := r.imports[path]; ok {
96100
return imprt
97101
}
@@ -111,6 +115,14 @@ func (r *Registry) AddImport(pkg *types.Package) *Package {
111115
func (r Registry) Imports() []*Package {
112116
imports := make([]*Package, 0, len(r.imports))
113117
for _, imprt := range r.imports {
118+
if strings.Contains(imprt.Path(), "moq_force_local_interface_mode_") {
119+
imports = append(imports, &Package{
120+
pkg: imprt.pkg.Imports()[0], // moq_force_local_interface_mode only have 1 dependency
121+
Alias: imprt.pkg.Imports()[0].Name(),
122+
})
123+
124+
continue
125+
}
114126
imports = append(imports, imprt)
115127
}
116128
sort.Slice(imports, func(i, j int) bool {

main.go

+120-24
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@ import (
66
"flag"
77
"fmt"
88
"io"
9-
"io/ioutil"
109
"os"
10+
"os/exec"
1111
"path/filepath"
12+
"text/template"
1213

1314
"github.com/matryer/moq/pkg/moq"
1415
)
1516

16-
// Version is the command version, injected at build time.
1717
var Version string = "dev"
1818

1919
type userFlags struct {
@@ -64,50 +64,146 @@ func main() {
6464
}
6565

6666
func run(flags userFlags) error {
67-
if len(flags.args) < 2 {
67+
var (
68+
buf bytes.Buffer
69+
out io.Writer = os.Stdout
70+
srcDir string
71+
)
72+
73+
if NotEnoughArguments(flags) {
6874
return errors.New("not enough arguments")
6975
}
7076

71-
if flags.remove && flags.outFile != "" {
72-
if err := os.Remove(flags.outFile); err != nil {
73-
if !errors.Is(err, os.ErrNotExist) {
74-
return err
75-
}
77+
if ShouldRemoveFile(flags) {
78+
if err := RemoveFile(flags.outFile); err != nil && !os.IsNotExist(err) {
79+
return err
7680
}
7781
}
7882

79-
var buf bytes.Buffer
80-
var out io.Writer = os.Stdout
81-
if flags.outFile != "" {
83+
if ShouldWriteToFile(flags) {
8284
out = &buf
8385
}
8486

85-
srcDir, args := flags.args[0], flags.args[1:]
86-
m, err := moq.New(moq.Config{
87-
SrcDir: srcDir,
88-
PkgName: flags.pkgName,
89-
Formatter: flags.formatter,
90-
StubImpl: flags.stubImpl,
91-
SkipEnsure: flags.skipEnsure,
92-
WithResets: flags.withResets,
93-
})
87+
srcDir, cleanUp, err := getSourceDirectory(flags)
9488
if err != nil {
9589
return err
9690
}
91+
defer cleanUp()
9792

98-
if err = m.Mock(out, args...); err != nil {
93+
m, err := CreateMoq(flags, srcDir)
94+
if err != nil {
95+
return err
96+
}
97+
98+
if err := m.Mock(out, flags.args[1:]...); err != nil {
9999
return err
100100
}
101101

102102
if flags.outFile == "" {
103103
return nil
104104
}
105105

106-
// create the file
107-
err = os.MkdirAll(filepath.Dir(flags.outFile), 0o750)
106+
if err := CreateOutputFile(flags.outFile, buf.Bytes()); err != nil {
107+
return err
108+
}
109+
110+
return nil
111+
}
112+
113+
func getSourceDirectory(flags userFlags) (string, func(), error) {
114+
if DirectoryExists(flags.args[0]) {
115+
return flags.args[0], func() {}, nil
116+
}
117+
118+
cmd := exec.Command("go", "list", flag.Args()[0])
119+
output, err := cmd.CombinedOutput()
120+
if err != nil {
121+
return "", func() {}, fmt.Errorf("%s", output)
122+
}
123+
124+
pwd, err := os.Getwd()
125+
if err != nil {
126+
return "", func() {}, err
127+
}
128+
129+
tempDir, err := os.MkdirTemp(pwd, "moq_force_local_interface_mode_")
130+
if err != nil {
131+
return "", func() {}, err
132+
}
133+
134+
if err := GenerateMoqForceLocalInterface(flags, tempDir); err != nil {
135+
return "", func() {}, err
136+
}
137+
138+
return tempDir, func() { os.RemoveAll(tempDir) }, nil
139+
140+
}
141+
142+
func NotEnoughArguments(flags userFlags) bool {
143+
return len(flags.args) < 2
144+
}
145+
146+
func ShouldRemoveFile(flags userFlags) bool {
147+
return flags.remove && flags.outFile != ""
148+
}
149+
150+
func RemoveFile(filePath string) error {
151+
return os.Remove(filePath)
152+
}
153+
154+
func ShouldWriteToFile(flags userFlags) bool {
155+
return flags.outFile != ""
156+
}
157+
158+
func DirectoryExists(directoryPath string) bool {
159+
_, err := os.Stat(directoryPath)
160+
return err == nil
161+
}
162+
163+
func GenerateMoqForceLocalInterface(flags userFlags, tempDir string) error {
164+
tmpl, err := template.New("force_local_interface").Parse(moqForceLocalInterface)
165+
if err != nil {
166+
return err
167+
}
168+
169+
var buf bytes.Buffer
170+
171+
err = tmpl.Execute(&buf, map[string]interface{}{
172+
"SrcPkgQualifier": filepath.Base(flags.args[0]),
173+
"Import": flags.args[0],
174+
"InterfaceName": flags.args[1],
175+
})
108176
if err != nil {
109177
return err
110178
}
111179

112-
return ioutil.WriteFile(flags.outFile, buf.Bytes(), 0o600)
180+
return os.WriteFile(filepath.Join(tempDir, "moq_force_local_interface.go"), buf.Bytes(), 0600)
113181
}
182+
183+
func CreateMoq(flags userFlags, srcDir string) (*moq.Mocker, error) {
184+
return moq.New(moq.Config{
185+
SrcDir: srcDir,
186+
PkgName: flags.pkgName,
187+
Formatter: flags.formatter,
188+
StubImpl: flags.stubImpl,
189+
SkipEnsure: flags.skipEnsure,
190+
WithResets: flags.withResets,
191+
})
192+
}
193+
194+
func CreateOutputFile(filePath string, data []byte) error {
195+
if err := os.MkdirAll(filepath.Dir(filePath), 0750); err != nil {
196+
return err
197+
}
198+
return os.WriteFile(filePath, data, 0600)
199+
}
200+
201+
const moqForceLocalInterface = `// Code generated by moq; DO NOT EDIT
202+
// github.com/matryer/moq
203+
204+
package {{.SrcPkgQualifier}}
205+
206+
import "{{.Import}}"
207+
208+
type {{.InterfaceName}} {{$.SrcPkgQualifier}}.{{.InterfaceName}}
209+
`

0 commit comments

Comments
 (0)