Skip to content

Commit 9c81ec4

Browse files
committed
Allow using goImport as source for moqs
1 parent 0bf2e8a commit 9c81ec4

File tree

2 files changed

+125
-24
lines changed

2 files changed

+125
-24
lines changed

internal/registry/registry.go

+12
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@ func (r *Registry) AddImport(pkg *types.Package) *Package {
9292
return nil
9393
}
9494

95+
if strings.Contains(pkg.Path(), "moq_force_local_interface_mode_") {
96+
path = pkg.Imports()[0].Path() // moq_force_local_interface_mode only have 1 import dependency and is the one we are looking for
97+
}
98+
9599
if imprt, ok := r.imports[path]; ok {
96100
return imprt
97101
}
@@ -111,6 +115,14 @@ func (r *Registry) AddImport(pkg *types.Package) *Package {
111115
func (r Registry) Imports() []*Package {
112116
imports := make([]*Package, 0, len(r.imports))
113117
for _, imprt := range r.imports {
118+
if strings.Contains(imprt.Path(), "moq_force_local_interface_mode_") {
119+
imports = append(imports, &Package{
120+
pkg: imprt.pkg.Imports()[0], // moq_force_local_interface_mode only have 1 dependency
121+
Alias: imprt.pkg.Imports()[0].Name(),
122+
})
123+
124+
continue
125+
}
114126
imports = append(imports, imprt)
115127
}
116128
sort.Slice(imports, func(i, j int) bool {

main.go

+113-24
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,13 @@ import (
66
"flag"
77
"fmt"
88
"io"
9-
"io/ioutil"
109
"os"
1110
"path/filepath"
11+
"text/template"
1212

1313
"github.com/matryer/moq/pkg/moq"
1414
)
1515

16-
// Version is the command version, injected at build time.
1716
var Version string = "dev"
1817

1918
type userFlags struct {
@@ -64,50 +63,140 @@ func main() {
6463
}
6564

6665
func run(flags userFlags) error {
67-
if len(flags.args) < 2 {
66+
var (
67+
buf bytes.Buffer
68+
out io.Writer = os.Stdout
69+
srcDir string
70+
)
71+
72+
if NotEnoughArguments(flags) {
6873
return errors.New("not enough arguments")
6974
}
7075

71-
if flags.remove && flags.outFile != "" {
72-
if err := os.Remove(flags.outFile); err != nil {
73-
if !errors.Is(err, os.ErrNotExist) {
74-
return err
75-
}
76+
if ShouldRemoveFile(flags) {
77+
if err := RemoveFile(flags.outFile); err != nil && !os.IsNotExist(err) {
78+
return err
7679
}
7780
}
7881

79-
var buf bytes.Buffer
80-
var out io.Writer = os.Stdout
81-
if flags.outFile != "" {
82+
if ShouldWriteToFile(flags) {
8283
out = &buf
8384
}
8485

85-
srcDir, args := flags.args[0], flags.args[1:]
86-
m, err := moq.New(moq.Config{
87-
SrcDir: srcDir,
88-
PkgName: flags.pkgName,
89-
Formatter: flags.formatter,
90-
StubImpl: flags.stubImpl,
91-
SkipEnsure: flags.skipEnsure,
92-
WithResets: flags.withResets,
93-
})
86+
srcDir, cleanUp, err := getSourceDirectory(flags)
9487
if err != nil {
9588
return err
9689
}
90+
defer cleanUp()
9791

98-
if err = m.Mock(out, args...); err != nil {
92+
m, err := CreateMoq(flags, srcDir)
93+
if err != nil {
94+
return err
95+
}
96+
97+
if err := m.Mock(out, flags.args[1:]...); err != nil {
9998
return err
10099
}
101100

102101
if flags.outFile == "" {
103102
return nil
104103
}
105104

106-
// create the file
107-
err = os.MkdirAll(filepath.Dir(flags.outFile), 0o750)
105+
if err := CreateOutputFile(flags.outFile, buf.Bytes()); err != nil {
106+
return err
107+
}
108+
109+
return nil
110+
}
111+
112+
func getSourceDirectory(flags userFlags) (string, func(), error) {
113+
if DirectoryExists(flags.args[0]) {
114+
return flags.args[0], func() {}, nil
115+
}
116+
117+
pwd, err := os.Getwd()
118+
if err != nil {
119+
return "", func() {}, err
120+
}
121+
122+
tempDir, err := os.MkdirTemp(pwd, "moq_force_local_interface_mode_")
123+
if err != nil {
124+
return "", func() {}, err
125+
}
126+
127+
if err := GenerateMoqForceLocalInterface(flags, tempDir); err != nil {
128+
return "", func() {}, err
129+
}
130+
131+
return tempDir, func() { os.RemoveAll(tempDir) }, nil
132+
133+
}
134+
135+
func NotEnoughArguments(flags userFlags) bool {
136+
return len(flags.args) < 2
137+
}
138+
139+
func ShouldRemoveFile(flags userFlags) bool {
140+
return flags.remove && flags.outFile != ""
141+
}
142+
143+
func RemoveFile(filePath string) error {
144+
return os.Remove(filePath)
145+
}
146+
147+
func ShouldWriteToFile(flags userFlags) bool {
148+
return flags.outFile != ""
149+
}
150+
151+
func DirectoryExists(directoryPath string) bool {
152+
_, err := os.Stat(directoryPath)
153+
return err == nil
154+
}
155+
156+
func GenerateMoqForceLocalInterface(flags userFlags, tempDir string) error {
157+
tmpl, err := template.New("force_local_interface").Parse(moqForceLocalInterface)
108158
if err != nil {
109159
return err
110160
}
111161

112-
return ioutil.WriteFile(flags.outFile, buf.Bytes(), 0o600)
162+
var buf bytes.Buffer
163+
164+
err = tmpl.Execute(&buf, map[string]interface{}{
165+
"SrcPkgQualifier": filepath.Base(flags.args[0]),
166+
"Import": flags.args[0],
167+
"InterfaceName": flags.args[1],
168+
})
169+
if err != nil {
170+
return err
171+
}
172+
173+
return os.WriteFile(filepath.Join(tempDir, "moq_force_local_interface.go"), buf.Bytes(), 0600)
174+
}
175+
176+
func CreateMoq(flags userFlags, srcDir string) (*moq.Mocker, error) {
177+
return moq.New(moq.Config{
178+
SrcDir: srcDir,
179+
PkgName: flags.pkgName,
180+
Formatter: flags.formatter,
181+
StubImpl: flags.stubImpl,
182+
SkipEnsure: flags.skipEnsure,
183+
WithResets: flags.withResets,
184+
})
185+
}
186+
187+
func CreateOutputFile(filePath string, data []byte) error {
188+
if err := os.MkdirAll(filepath.Dir(filePath), 0750); err != nil {
189+
return err
190+
}
191+
return os.WriteFile(filePath, data, 0600)
113192
}
193+
194+
const moqForceLocalInterface = `// Code generated by moq; DO NOT EDIT
195+
// github.com/matryer/moq
196+
197+
package {{.SrcPkgQualifier}}
198+
199+
import "{{.Import}}"
200+
201+
type {{.InterfaceName}} {{$.SrcPkgQualifier}}.{{.InterfaceName}}
202+
`

0 commit comments

Comments
 (0)