Skip to content

Commit

Permalink
global: use netip where possible now
Browse files Browse the repository at this point in the history
There are more places where we'll need to add it later, when Go 1.18
comes out with support for it in the "net" package. Also, allowedips
still uses slices internally, which might be suboptimal.

Signed-off-by: Jason A. Donenfeld <[email protected]>
  • Loading branch information
zx2c4 committed Nov 5, 2021
1 parent 851efb1 commit 0243978
Show file tree
Hide file tree
Showing 22 changed files with 239 additions and 280 deletions.
50 changes: 19 additions & 31 deletions conn/bind_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"unsafe"

"golang.org/x/sys/unix"
"golang.zx2c4.com/go118/netip"
)

type ipv4Source struct {
Expand Down Expand Up @@ -70,32 +71,30 @@ var _ Bind = (*LinuxSocketBind)(nil)

func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) {
var end LinuxSocketEndpoint
addr, err := parseEndpoint(s)
e, err := netip.ParseAddrPort(s)
if err != nil {
return nil, err
}

ipv4 := addr.IP.To4()
if ipv4 != nil {
if e.Addr().Is4() {
dst := end.dst4()
end.isV6 = false
dst.Port = addr.Port
copy(dst.Addr[:], ipv4)
dst.Port = int(e.Port())
dst.Addr = e.Addr().As4()
end.ClearSrc()
return &end, nil
}

ipv6 := addr.IP.To16()
if ipv6 != nil {
zone, err := zoneToUint32(addr.Zone)
if e.Addr().Is6() {
zone, err := zoneToUint32(e.Addr().Zone())
if err != nil {
return nil, err
}
dst := end.dst6()
end.isV6 = true
dst.Port = addr.Port
dst.Port = int(e.Port())
dst.ZoneId = zone
copy(dst.Addr[:], ipv6[:])
dst.Addr = e.Addr().As16()
end.ClearSrc()
return &end, nil
}
Expand Down Expand Up @@ -266,29 +265,19 @@ func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error {
}
}

func (end *LinuxSocketEndpoint) SrcIP() net.IP {
func (end *LinuxSocketEndpoint) SrcIP() netip.Addr {
if !end.isV6 {
return net.IPv4(
end.src4().Src[0],
end.src4().Src[1],
end.src4().Src[2],
end.src4().Src[3],
)
return netip.AddrFrom4(end.src4().Src)
} else {
return end.src6().src[:]
return netip.AddrFrom16(end.src6().src)
}
}

func (end *LinuxSocketEndpoint) DstIP() net.IP {
func (end *LinuxSocketEndpoint) DstIP() netip.Addr {
if !end.isV6 {
return net.IPv4(
end.dst4().Addr[0],
end.dst4().Addr[1],
end.dst4().Addr[2],
end.dst4().Addr[3],
)
return netip.AddrFrom4(end.src4().Src)
} else {
return end.dst6().Addr[:]
return netip.AddrFrom16(end.dst6().Addr)
}
}

Expand All @@ -305,14 +294,13 @@ func (end *LinuxSocketEndpoint) SrcToString() string {
}

func (end *LinuxSocketEndpoint) DstToString() string {
var udpAddr net.UDPAddr
udpAddr.IP = end.DstIP()
var port int
if !end.isV6 {
udpAddr.Port = end.dst4().Port
port = end.dst4().Port
} else {
udpAddr.Port = end.dst6().Port
port = end.dst6().Port
}
return udpAddr.String()
return netip.AddrPortFrom(end.DstIP(), uint16(port)).String()
}

func (end *LinuxSocketEndpoint) ClearDst() {
Expand Down
18 changes: 12 additions & 6 deletions conn/bind_std.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"net"
"sync"
"syscall"

"golang.zx2c4.com/go118/netip"
)

// StdNetBind is meant to be a temporary solution on platforms for which
Expand All @@ -32,18 +34,22 @@ var _ Bind = (*StdNetBind)(nil)
var _ Endpoint = (*StdNetEndpoint)(nil)

func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
addr, err := parseEndpoint(s)
return (*StdNetEndpoint)(addr), err
e, err := netip.ParseAddrPort(s)
return (*StdNetEndpoint)(&net.UDPAddr{
IP: e.Addr().AsSlice(),
Port: int(e.Port()),
Zone: e.Addr().Zone(),
}), err
}

func (*StdNetEndpoint) ClearSrc() {}

func (e *StdNetEndpoint) DstIP() net.IP {
return (*net.UDPAddr)(e).IP
func (e *StdNetEndpoint) DstIP() netip.Addr {
return netip.AddrFromSlice((*net.UDPAddr)(e).IP)
}

func (e *StdNetEndpoint) SrcIP() net.IP {
return nil // not supported
func (e *StdNetEndpoint) SrcIP() netip.Addr {
return netip.Addr{} // not supported
}

func (e *StdNetEndpoint) DstToBytes() []byte {
Expand Down
19 changes: 9 additions & 10 deletions conn/bind_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"unsafe"

"golang.org/x/sys/windows"
"golang.zx2c4.com/go118/netip"

"golang.zx2c4.com/wireguard/conn/winrio"
)
Expand Down Expand Up @@ -128,18 +129,18 @@ func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) {

func (*WinRingEndpoint) ClearSrc() {}

func (e *WinRingEndpoint) DstIP() net.IP {
func (e *WinRingEndpoint) DstIP() netip.Addr {
switch e.family {
case windows.AF_INET:
return append([]byte{}, e.data[2:6]...)
return netip.AddrFrom4(*(*[4]byte)(e.data[2:6]))
case windows.AF_INET6:
return append([]byte{}, e.data[6:22]...)
return netip.AddrFrom16(*(*[16]byte)(e.data[6:22]))
}
return nil
return netip.Addr{}
}

func (e *WinRingEndpoint) SrcIP() net.IP {
return nil // not supported
func (e *WinRingEndpoint) SrcIP() netip.Addr {
return netip.Addr{} // not supported
}

func (e *WinRingEndpoint) DstToBytes() []byte {
Expand All @@ -161,15 +162,13 @@ func (e *WinRingEndpoint) DstToBytes() []byte {
func (e *WinRingEndpoint) DstToString() string {
switch e.family {
case windows.AF_INET:
addr := net.UDPAddr{IP: e.data[2:6], Port: int(binary.BigEndian.Uint16(e.data[0:2]))}
return addr.String()
netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String()
case windows.AF_INET6:
var zone string
if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 {
zone = strconv.FormatUint(uint64(scope), 10)
}
addr := net.UDPAddr{IP: e.data[6:22], Zone: zone, Port: int(binary.BigEndian.Uint16(e.data[0:2]))}
return addr.String()
return netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)(e.data[6:22])).WithZone(zone), binary.BigEndian.Uint16(e.data[0:2])).String()
}
return ""
}
Expand Down
14 changes: 5 additions & 9 deletions conn/bindtest/bindtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import (
"math/rand"
"net"
"os"
"strconv"

"golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/conn"
)

Expand Down Expand Up @@ -61,9 +61,9 @@ func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d

func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} }

func (c ChannelEndpoint) DstIP() net.IP { return net.IPv4(127, 0, 0, 1) }
func (c ChannelEndpoint) DstIP() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) }

func (c ChannelEndpoint) SrcIP() net.IP { return nil }
func (c ChannelEndpoint) SrcIP() netip.Addr { return netip.Addr{} }

func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
c.closeSignal = make(chan bool)
Expand Down Expand Up @@ -119,13 +119,9 @@ func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error {
}

func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) {
_, port, err := net.SplitHostPort(s)
addr, err := netip.ParseAddrPort(s)
if err != nil {
return nil, err
}
i, err := strconv.ParseUint(port, 10, 16)
if err != nil {
return nil, err
}
return ChannelEndpoint(i), nil
return ChannelEndpoint(addr.Port()), nil
}
37 changes: 4 additions & 33 deletions conn/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ package conn
import (
"errors"
"fmt"
"net"
"reflect"
"runtime"
"strings"

"golang.zx2c4.com/go118/netip"
)

// A ReceiveFunc receives a single inbound packet from the network.
Expand Down Expand Up @@ -68,8 +69,8 @@ type Endpoint interface {
SrcToString() string // returns the local source address (ip:port)
DstToString() string // returns the destination address (ip:port)
DstToBytes() []byte // used for mac2 cookie calculations
DstIP() net.IP
SrcIP() net.IP
DstIP() netip.Addr
SrcIP() netip.Addr
}

var (
Expand Down Expand Up @@ -119,33 +120,3 @@ func (fn ReceiveFunc) PrettyName() string {
}
return name
}

func parseEndpoint(s string) (*net.UDPAddr, error) {
// ensure that the host is an IP address

host, _, err := net.SplitHostPort(s)
if err != nil {
return nil, err
}
if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 {
// Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just
// trying to make sure with a small sanity test that this is a real IP address and
// not something that's likely to incur DNS lookups.
host = host[:i]
}
if ip := net.ParseIP(host); ip == nil {
return nil, errors.New("Failed to parse IP address: " + host)
}

// parse address and port

addr, err := net.ResolveUDPAddr("udp", s)
if err != nil {
return nil, err
}
ip4 := addr.IP.To4()
if ip4 != nil {
addr.IP = ip4
}
return addr, err
}
Loading

0 comments on commit 0243978

Please sign in to comment.