Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cmd/compile: devirtualize interface calls with type assertions #71711

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 90 additions & 5 deletions src/cmd/compile/internal/devirtualize/devirtualize.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,17 @@ func StaticCall(call *ir.CallExpr) {
}

sel := call.Fun.(*ir.SelectorExpr)
r := ir.StaticValue(sel.X)
if r.Op() != ir.OCONVIFACE {
typ := staticType(sel.X)
if typ == nil {
return
}
recv := r.(*ir.ConvExpr)

typ := recv.X.Type()
if typ.IsInterface() {
// Don't try to devirtualize calls that we statically know that would have failed at runtime.
// This can happen in such case: any(0).(interface {A()}).A(), this typechecks without
// any errors, but will cause a runtime panic. We statically know that int(0) does not
// implement that interface, thus we skip the devirtualization, as it is not possible
// to make a type assertion from interface{A()} to int (int does not implement interface{A()}).
if !typecheck.Implements(typ, sel.X.Type()) {
return
}

Expand Down Expand Up @@ -138,3 +141,85 @@ func StaticCall(call *ir.CallExpr) {
// Desugar OCALLMETH, if we created one (#57309).
typecheck.FixMethodCall(call)
}

func staticType(n ir.Node) *types.Type {
for {
switch n1 := n.(type) {
case *ir.ConvExpr:
if n1.Op() == ir.OCONVNOP || n1.Op() == ir.OCONVIFACE {
n = n1.X
continue
}
case *ir.InlinedCallExpr:
if n1.Op() == ir.OINLCALL {
n = n1.SingleResult()
continue
}
case *ir.ParenExpr:
n = n1.X
continue
case *ir.TypeAssertExpr:
n = n1.X
continue
}

n1 := staticValue(n)
if n1 == nil {
if n.Type().IsInterface() {
return nil
}
return n.Type()
}
n = n1
}
}

func staticValue(nn ir.Node) ir.Node {
if nn.Op() != ir.ONAME {
return nil
}

n := nn.(*ir.Name).Canonical()
if n.Class != ir.PAUTO {
return nil
}

defn := n.Defn
if defn == nil {
return nil
}

var rhs ir.Node
FindRHS:
switch defn.Op() {
case ir.OAS:
defn := defn.(*ir.AssignStmt)
rhs = defn.Y
case ir.OAS2:
defn := defn.(*ir.AssignListStmt)
for i, lhs := range defn.Lhs {
if lhs == n {
rhs = defn.Rhs[i]
break FindRHS
}
}
base.Fatalf("%v missing from LHS of %v", n, defn)
case ir.OAS2DOTTYPE:
defn := defn.(*ir.AssignListStmt)
if defn.Lhs[0] == n {
rhs = defn.Rhs[0]
}
default:
return nil
}

if rhs == nil {
base.Fatalf("RHS is nil: %v", defn)
}

if ir.Reassigned(n) {
return nil
}

return rhs
}
1 change: 1 addition & 0 deletions src/cmd/compile/internal/noder/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -2941,6 +2941,7 @@ func (r *reader) multiExpr() []ir.Node {
as.Def = true
for i := range results {
tmp := r.temp(pos, r.typ())
tmp.Defn = as
as.PtrInit().Append(ir.NewDecl(pos, ir.ODCL, tmp))
as.Lhs.Append(tmp)

Expand Down
14 changes: 14 additions & 0 deletions src/crypto/sha256/sha256_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -391,3 +391,17 @@ func BenchmarkHash1K(b *testing.B) {
func BenchmarkHash8K(b *testing.B) {
benchmarkSize(b, 8192)
}

func TestAllocatonsWithTypeAsserts(t *testing.T) {
cryptotest.SkipTestAllocations(t)
allocs := testing.AllocsPerRun(100, func() {
h := New()
h.Write([]byte{1, 2, 3})
marshaled, _ := h.(encoding.BinaryMarshaler).MarshalBinary()
marshaled, _ = h.(encoding.BinaryAppender).AppendBinary(marshaled[:0])
h.(encoding.BinaryUnmarshaler).UnmarshalBinary(marshaled)
})
if allocs != 0 {
t.Fatalf("allocs = %v; want = 0", allocs)
}
}
214 changes: 214 additions & 0 deletions test/escape_iface_with_devirt_type_assertions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
// errorcheck -0 -m

// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package escape

type M interface{ M() }

type A interface{ A() }

type C interface{ C() }

type Impl struct{}

func (*Impl) M() {} // ERROR "can inline"

func (*Impl) A() {} // ERROR "can inline"

type CImpl struct{}

func (CImpl) C() {} // ERROR "can inline"

func t() {
var a M = &Impl{} // ERROR "&Impl{} does not escape"

a.(M).M() // ERROR "devirtualizing a.\(M\).M" "inlining call"
a.(A).A() // ERROR "devirtualizing a.\(A\).A" "inlining call"
a.(*Impl).M() // ERROR "inlining call"
a.(*Impl).A() // ERROR "inlining call"

v := a.(M)
v.M() // ERROR "devirtualizing v.M" "inlining call"
v.(A).A() // ERROR "devirtualizing v.\(A\).A" "inlining call"
v.(*Impl).A() // ERROR "inlining call"
v.(*Impl).M() // ERROR "inlining call"

v2 := a.(A)
v2.A() // ERROR "devirtualizing v2.A" "inlining call"
v2.(M).M() // ERROR "devirtualizing v2.\(M\).M" "inlining call"
v2.(*Impl).A() // ERROR "inlining call"
v2.(*Impl).M() // ERROR "inlining call"

a.(M).(A).A() // ERROR "devirtualizing a.\(M\).\(A\).A" "inlining call"
a.(A).(M).M() // ERROR "devirtualizing a.\(A\).\(M\).M" "inlining call"

a.(M).(A).(*Impl).A() // ERROR "inlining call"
a.(A).(M).(*Impl).M() // ERROR "inlining call"

any(a).(M).M() // ERROR "devirtualizing" "inlining call"
any(a).(A).A() // ERROR "devirtualizing" "inlining call"
any(a).(M).(any).(A).A() // ERROR "devirtualizing" "inlining call"

c := any(a)
c.(A).A() // ERROR "devirtualizing" "inlining call"
c.(M).M() // ERROR "devirtualizing" "inlining call"

{
var a C = &CImpl{} // ERROR "does not escape"
a.(any).(C).C() // ERROR "devirtualizing" "inlining"
a.(any).(*CImpl).C() // ERROR "inlining"
}
}

func t2() {
{
var a M = &Impl{} // ERROR "does not escape"
if v, ok := a.(M); ok {
v.M() // ERROR "devirtualizing" "inlining call"
}
}
{
var a M = &Impl{} // ERROR "does not escape"
if v, ok := a.(A); ok {
v.A() // ERROR "devirtualizing" "inlining call"
}
}
{
var a M = &Impl{} // ERROR "does not escape"
v, ok := a.(M)
if ok {
v.M() // ERROR "devirtualizing" "inlining call"
}
}
{
var a M = &Impl{} // ERROR "does not escape"
v, ok := a.(A)
if ok {
v.A() // ERROR "devirtualizing" "inlining call"
}
}
{
var a M = &Impl{} // ERROR "does not escape"
v, ok := a.(*Impl)
if ok {
v.A() // ERROR "inlining"
v.M() // ERROR "inlining"
}
}
{
var a M = &Impl{} // ERROR "does not escape"
v, _ := a.(M)
v.M() // ERROR "devirtualizing" "inlining call"
}
{
var a M = &Impl{} // ERROR "does not escape"
v, _ := a.(A)
v.A() // ERROR "devirtualizing" "inlining call"
}
{
var a M = &Impl{} // ERROR "does not escape"
v, _ := a.(*Impl)
v.A() // ERROR "inlining"
v.M() // ERROR "inlining"
}
{
a := newM() // ERROR "does not escape" "inlining call"
callA(a) // ERROR "devirtualizing" "inlining call"
callIfA(a) // ERROR "devirtualizing" "inlining call"
}

{
var a M = &Impl{} // ERROR "does not escape"
// Note the !ok condition, devirtualizing here is fine.
if v, ok := a.(M); !ok {
v.M() // ERROR "devirtualizing" "inlining call"
}
}
}

func newM() M { // ERROR "can inline"
return &Impl{} // ERROR "escapes"
}

func callA(m M) { // ERROR "can inline" "leaking param"
m.(A).A()
}

func callIfA(m M) { // ERROR "can inline" "leaking param"
if v, ok := m.(A); ok {
v.A()
}
}

//go:noinline
func newImplNoInline() *Impl {
return &Impl{} // ERROR "escapes"
}

func t3() {
{
var a A = newImplNoInline()
if v, ok := a.(M); ok {
v.M() // ERROR "devirtualizing" "inlining call"
}
}
{
m := make(map[*Impl]struct{}) // ERROR "does not escape"
for v := range m {
var v A = v
if v, ok := v.(M); ok {
v.M() // ERROR "devirtualizing" "inlining call"
}
}
}
{
m := make(map[int]*Impl) // ERROR "does not escape"
for _, v := range m {
var v A = v
if v, ok := v.(M); ok {
v.M() // ERROR "devirtualizing" "inlining call"
}
}
}
{
m := make(map[int]*Impl) // ERROR "does not escape"
var v A = m[0]
if v, ok := v.(M); ok {
v.M() // ERROR "devirtualizing" "inlining call"
}
}
{
m := make(chan *Impl)
var v A = <-m
if v, ok := v.(M); ok {
v.M() // ERROR "devirtualizing" "inlining call"
}
}
}

//go:noinline
func testInvalidAsserts() {
any(0).(interface{ A() }).A() // ERROR "escapes"
{
var a M = &Impl{} // ERROR "escapes"
a.(C).C() // this will panic
a.(any).(C).C() // this will panic
}
{
var a C = &CImpl{} // ERROR "escapes"
a.(M).M() // this will panic
a.(any).(M).M() // this will panic
}
{
var a C = &CImpl{} // ERROR "does not escape"

// this will panic
a.(M).(*Impl).M() // ERROR "inlining"

// this will panic
a.(any).(M).(*Impl).M() // ERROR "inlining"
}
}