Skip to content

Commit c13f17e

Browse files
committed
gopls/internal/golang: add extract interface code action
1 parent 7240af8 commit c13f17e

File tree

5 files changed

+380
-23
lines changed

5 files changed

+380
-23
lines changed

gopls/internal/golang/codeaction.go

+42-23
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,13 @@ func CodeActions(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle,
8080
}
8181
}
8282

83+
pkg, pgf, err := NarrowestPackageForFile(ctx, snapshot, fh.URI())
84+
if err != nil {
85+
return nil, err
86+
}
87+
8388
if want[protocol.RefactorExtract] {
84-
extractions, err := getExtractCodeActions(pgf, rng, snapshot.Options())
89+
extractions, err := getExtractCodeActions(pkg, pgf, rng, snapshot.Options())
8590
if err != nil {
8691
return nil, err
8792
}
@@ -179,20 +184,18 @@ func fixedByImportFix(fix *imports.ImportFix, diagnostics []protocol.Diagnostic)
179184
}
180185

181186
// getExtractCodeActions returns any refactor.extract code actions for the selection.
182-
func getExtractCodeActions(pgf *parsego.File, rng protocol.Range, options *settings.Options) ([]protocol.CodeAction, error) {
183-
if rng.Start == rng.End {
184-
return nil, nil
185-
}
186-
187+
func getExtractCodeActions(pkg *cache.Package, pgf *parsego.File, rng protocol.Range, options *settings.Options) ([]protocol.CodeAction, error) {
187188
start, end, err := pgf.RangePos(rng)
188189
if err != nil {
189190
return nil, err
190191
}
192+
191193
puri := pgf.URI
192194
var commands []protocol.Command
193-
if _, ok, methodOk, _ := CanExtractFunction(pgf.Tok, start, end, pgf.Src, pgf.File); ok {
194-
cmd, err := command.NewApplyFixCommand("Extract function", command.ApplyFixArgs{
195-
Fix: fixExtractFunction,
195+
196+
if _, _, ok, _ := CanExtractInterface(pkg, start, end, pgf.File); ok {
197+
cmd, err := command.NewApplyFixCommand("Extract interface", command.ApplyFixArgs{
198+
Fix: fixExtractInterface,
196199
URI: puri,
197200
Range: rng,
198201
ResolveEdits: supportsResolveEdits(options),
@@ -201,9 +204,12 @@ func getExtractCodeActions(pgf *parsego.File, rng protocol.Range, options *setti
201204
return nil, err
202205
}
203206
commands = append(commands, cmd)
204-
if methodOk {
205-
cmd, err := command.NewApplyFixCommand("Extract method", command.ApplyFixArgs{
206-
Fix: fixExtractMethod,
207+
}
208+
209+
if rng.Start != rng.End {
210+
if _, ok, methodOk, _ := CanExtractFunction(pgf.Tok, start, end, pgf.Src, pgf.File); ok {
211+
cmd, err := command.NewApplyFixCommand("Extract function", command.ApplyFixArgs{
212+
Fix: fixExtractFunction,
207213
URI: puri,
208214
Range: rng,
209215
ResolveEdits: supportsResolveEdits(options),
@@ -212,20 +218,33 @@ func getExtractCodeActions(pgf *parsego.File, rng protocol.Range, options *setti
212218
return nil, err
213219
}
214220
commands = append(commands, cmd)
221+
if methodOk {
222+
cmd, err := command.NewApplyFixCommand("Extract method", command.ApplyFixArgs{
223+
Fix: fixExtractMethod,
224+
URI: puri,
225+
Range: rng,
226+
ResolveEdits: supportsResolveEdits(options),
227+
})
228+
if err != nil {
229+
return nil, err
230+
}
231+
commands = append(commands, cmd)
232+
}
215233
}
216-
}
217-
if _, _, ok, _ := CanExtractVariable(start, end, pgf.File); ok {
218-
cmd, err := command.NewApplyFixCommand("Extract variable", command.ApplyFixArgs{
219-
Fix: fixExtractVariable,
220-
URI: puri,
221-
Range: rng,
222-
ResolveEdits: supportsResolveEdits(options),
223-
})
224-
if err != nil {
225-
return nil, err
234+
if _, _, ok, _ := CanExtractVariable(start, end, pgf.File); ok {
235+
cmd, err := command.NewApplyFixCommand("Extract variable", command.ApplyFixArgs{
236+
Fix: fixExtractVariable,
237+
URI: puri,
238+
Range: rng,
239+
ResolveEdits: supportsResolveEdits(options),
240+
})
241+
if err != nil {
242+
return nil, err
243+
}
244+
commands = append(commands, cmd)
226245
}
227-
commands = append(commands, cmd)
228246
}
247+
229248
var actions []protocol.CodeAction
230249
for i := range commands {
231250
actions = append(actions, newCodeAction(commands[i].Title, protocol.RefactorExtract, &commands[i], nil, options))

gopls/internal/golang/extract.go

+34
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818

1919
"golang.org/x/tools/go/analysis"
2020
"golang.org/x/tools/go/ast/astutil"
21+
"golang.org/x/tools/gopls/internal/cache"
2122
"golang.org/x/tools/gopls/internal/util/bug"
2223
"golang.org/x/tools/gopls/internal/util/safetoken"
2324
"golang.org/x/tools/internal/analysisinternal"
@@ -127,6 +128,39 @@ func CanExtractVariable(start, end token.Pos, file *ast.File) (ast.Expr, []ast.N
127128
return nil, nil, false, fmt.Errorf("cannot extract an %T to a variable", expr)
128129
}
129130

131+
// CanExtractInterface reports whether the code in the given position is for a
132+
// type which can be represented as an interface.
133+
func CanExtractInterface(pkg *cache.Package, start, end token.Pos, file *ast.File) (ast.Expr, []ast.Node, bool, error) {
134+
path, _ := astutil.PathEnclosingInterval(file, start, end)
135+
if len(path) == 0 {
136+
return nil, nil, false, fmt.Errorf("no path enclosing interval")
137+
}
138+
139+
node := path[0]
140+
expr, ok := node.(ast.Expr)
141+
if !ok {
142+
return nil, nil, false, fmt.Errorf("node is not an expression")
143+
}
144+
145+
switch e := expr.(type) {
146+
case *ast.Ident:
147+
o, ok := pkg.GetTypesInfo().ObjectOf(e).(*types.TypeName)
148+
if !ok {
149+
return nil, nil, false, fmt.Errorf("cannot extract a %T to a variable", expr)
150+
}
151+
152+
if _, ok := o.Type().(*types.Basic); ok {
153+
return nil, nil, false, fmt.Errorf("cannot extract a basic type to an interface")
154+
}
155+
156+
return expr, path, true, nil
157+
case *ast.StarExpr, *ast.SelectorExpr:
158+
return expr, path, true, nil
159+
default:
160+
return nil, nil, false, fmt.Errorf("cannot extract a %T to an interface", expr)
161+
}
162+
}
163+
130164
// Calculate indentation for insertion.
131165
// When inserting lines of code, we must ensure that the lines have consistent
132166
// formatting (i.e. the proper indentation). To do so, we observe the indentation on the

gopls/internal/golang/fix.go

+141
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,17 @@
55
package golang
66

77
import (
8+
"bytes"
89
"context"
10+
"errors"
911
"fmt"
1012
"go/ast"
1113
"go/token"
1214
"go/types"
15+
"slices"
1316

1417
"golang.org/x/tools/go/analysis"
18+
"golang.org/x/tools/go/ast/astutil"
1519
"golang.org/x/tools/gopls/internal/analysis/embeddirective"
1620
"golang.org/x/tools/gopls/internal/analysis/fillstruct"
1721
"golang.org/x/tools/gopls/internal/analysis/stubmethods"
@@ -22,6 +26,7 @@ import (
2226
"golang.org/x/tools/gopls/internal/file"
2327
"golang.org/x/tools/gopls/internal/protocol"
2428
"golang.org/x/tools/gopls/internal/util/bug"
29+
"golang.org/x/tools/gopls/internal/util/safetoken"
2530
"golang.org/x/tools/internal/imports"
2631
)
2732

@@ -61,6 +66,7 @@ func singleFile(fixer1 singleFileFixer) fixer {
6166
const (
6267
fixExtractVariable = "extract_variable"
6368
fixExtractFunction = "extract_function"
69+
fixExtractInterface = "extract_interface"
6470
fixExtractMethod = "extract_method"
6571
fixInlineCall = "inline_call"
6672
fixInvertIfCondition = "invert_if_condition"
@@ -110,6 +116,7 @@ func ApplyFix(ctx context.Context, fix string, snapshot *cache.Snapshot, fh file
110116

111117
// Ad-hoc fixers: these are used when the command is
112118
// constructed directly by logic in server/code_action.
119+
fixExtractInterface: extractInterface,
113120
fixExtractFunction: singleFile(extractFunction),
114121
fixExtractMethod: singleFile(extractMethod),
115122
fixExtractVariable: singleFile(extractVariable),
@@ -138,6 +145,140 @@ func ApplyFix(ctx context.Context, fix string, snapshot *cache.Snapshot, fh file
138145
return suggestedFixToEdits(ctx, snapshot, fixFset, suggestion)
139146
}
140147

148+
func extractInterface(ctx context.Context, snapshot *cache.Snapshot, pkg *cache.Package, pgf *parsego.File, start, end token.Pos) (*token.FileSet, *analysis.SuggestedFix, error) {
149+
path, _ := astutil.PathEnclosingInterval(pgf.File, start, end)
150+
151+
var field *ast.Field
152+
var decl ast.Decl
153+
for _, node := range path {
154+
if f, ok := node.(*ast.Field); ok {
155+
field = f
156+
continue
157+
}
158+
159+
// Record the node that starts the declaration of the type that contains
160+
// the field we are creating the interface for.
161+
if d, ok := node.(ast.Decl); ok {
162+
decl = d
163+
break // we have both the field and the declaration
164+
}
165+
}
166+
167+
if field == nil || decl == nil {
168+
return nil, nil, nil
169+
}
170+
171+
p := safetoken.StartPosition(pkg.FileSet(), field.Pos())
172+
pos := protocol.Position{
173+
Line: uint32(p.Line - 1), // Line is zero-based
174+
Character: uint32(p.Column - 1), // Character is zero-based
175+
}
176+
177+
fh, err := snapshot.ReadFile(ctx, pgf.URI)
178+
if err != nil {
179+
return nil, nil, err
180+
}
181+
182+
refs, err := references(ctx, snapshot, fh, pos, false)
183+
if err != nil {
184+
return nil, nil, err
185+
}
186+
187+
type method struct {
188+
signature *types.Signature
189+
name string
190+
}
191+
192+
var methods []method
193+
for _, ref := range refs {
194+
locPkg, locPgf, err := NarrowestPackageForFile(ctx, snapshot, ref.location.URI)
195+
if err != nil {
196+
return nil, nil, err
197+
}
198+
199+
_, end, err := locPgf.RangePos(ref.location.Range)
200+
if err != nil {
201+
return nil, nil, err
202+
}
203+
204+
// We are interested in the method call, so we need the node after the dot
205+
rangeEnd := end + token.Pos(len("."))
206+
path, _ := astutil.PathEnclosingInterval(locPgf.File, rangeEnd, rangeEnd)
207+
id, ok := path[0].(*ast.Ident)
208+
if !ok {
209+
continue
210+
}
211+
212+
obj := locPkg.GetTypesInfo().ObjectOf(id)
213+
if obj == nil {
214+
continue
215+
}
216+
217+
sig, ok := obj.Type().(*types.Signature)
218+
if !ok {
219+
return nil, nil, errors.New("cannot extract interface with non-method accesses")
220+
}
221+
222+
fc := method{signature: sig, name: obj.Name()}
223+
if !slices.Contains(methods, fc) {
224+
methods = append(methods, fc)
225+
}
226+
}
227+
228+
interfaceName := "I" + pkg.GetTypesInfo().ObjectOf(field.Names[0]).Name()
229+
var buf bytes.Buffer
230+
buf.WriteString("\ntype ")
231+
buf.WriteString(interfaceName)
232+
buf.WriteString(" interface {\n")
233+
for _, fc := range methods {
234+
buf.WriteString("\t")
235+
buf.WriteString(fc.name)
236+
types.WriteSignature(&buf, fc.signature, relativeTo(pkg.GetTypes()))
237+
buf.WriteByte('\n')
238+
}
239+
buf.WriteByte('}')
240+
buf.WriteByte('\n')
241+
242+
interfacePos := decl.Pos() - 1
243+
// Move the interface above the documentation comment if the type declaration
244+
// includes one.
245+
switch d := decl.(type) {
246+
case *ast.GenDecl:
247+
if d.Doc != nil {
248+
interfacePos = d.Doc.Pos() - 1
249+
}
250+
case *ast.FuncDecl:
251+
if d.Doc != nil {
252+
interfacePos = d.Doc.Pos() - 1
253+
}
254+
}
255+
256+
return pkg.FileSet(), &analysis.SuggestedFix{
257+
Message: "Extract interface",
258+
TextEdits: []analysis.TextEdit{{
259+
Pos: interfacePos,
260+
End: interfacePos,
261+
NewText: buf.Bytes(),
262+
}, {
263+
Pos: field.Type.Pos(),
264+
End: field.Type.End(),
265+
NewText: []byte(interfaceName),
266+
}},
267+
}, nil
268+
}
269+
270+
func relativeTo(pkg *types.Package) types.Qualifier {
271+
if pkg == nil {
272+
return nil
273+
}
274+
return func(other *types.Package) string {
275+
if pkg == other {
276+
return "" // same package; unqualified
277+
}
278+
return other.Name()
279+
}
280+
}
281+
141282
// suggestedFixToEdits converts the suggestion's edits from analysis form into protocol form.
142283
func suggestedFixToEdits(ctx context.Context, snapshot *cache.Snapshot, fset *token.FileSet, suggestion *analysis.SuggestedFix) ([]protocol.TextDocumentEdit, error) {
143284
editsPerFile := map[protocol.DocumentURI]*protocol.TextDocumentEdit{}

0 commit comments

Comments
 (0)