Skip to content

Commit

Permalink
Fix deadline usage for websocket
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Dec 1, 2023
1 parent ed3ddda commit c6074c7
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 23 deletions.
10 changes: 1 addition & 9 deletions common/dialer/detour.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"sync"

"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing/common/bufio/deadline"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
Expand Down Expand Up @@ -45,14 +44,7 @@ func (d *DetourDialer) DialContext(ctx context.Context, network string, destinat
if err != nil {
return nil, err
}
conn, err := dialer.DialContext(ctx, network, destination)
if err != nil {
return nil, err
}
if deadline.NeedAdditionalReadDeadline(conn) {
conn = deadline.NewConn(conn)
}
return conn, nil
return dialer.DialContext(ctx, network, destination)
}

func (d *DetourDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
Expand Down
26 changes: 23 additions & 3 deletions transport/v2raywebsocket/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ import (
"github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/bufio/deadline"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
Expand Down Expand Up @@ -87,18 +90,35 @@ func (c *Client) dialContext(ctx context.Context, requestURL *url.URL, headers h
return nil, err
}
}
conn.SetDeadline(time.Now().Add(C.TCPTimeout))
var deadlineConn net.Conn
if deadline.NeedAdditionalReadDeadline(conn) {
deadlineConn = deadline.NewConn(conn)
} else {
deadlineConn = conn
}
err = deadlineConn.SetDeadline(time.Now().Add(C.TCPTimeout))
if err != nil {
return nil, E.Cause(err, "set read deadline")
}
var protocols []string
if protocolHeader := headers.Get("Sec-WebSocket-Protocol"); protocolHeader != "" {
protocols = []string{protocolHeader}
headers.Del("Sec-WebSocket-Protocol")
}
reader, _, err := ws.Dialer{Header: ws.HandshakeHeaderHTTP(headers), Protocols: protocols}.Upgrade(conn, requestURL)
reader, _, err := ws.Dialer{Header: ws.HandshakeHeaderHTTP(headers), Protocols: protocols}.Upgrade(deadlineConn, requestURL)
conn.SetDeadline(time.Time{})
if err != nil {
return nil, err
}
return NewConn(conn, reader, nil, ws.StateClientSide), nil
if reader.Buffered() > 0 {
buffer := buf.NewSize(reader.Buffered())
_, err = buffer.ReadFullFrom(reader, buffer.Len())
if err != nil {
return nil, err
}
conn = bufio.NewCachedConn(conn, buffer)
}
return NewConn(conn, nil, ws.StateClientSide), nil
}

func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
Expand Down
11 changes: 2 additions & 9 deletions transport/v2raywebsocket/conn.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package v2raywebsocket

import (
"bufio"
"context"
"encoding/base64"
"io"
Expand All @@ -28,19 +27,13 @@ type WebsocketConn struct {
remoteAddr net.Addr
}

func NewConn(conn net.Conn, br *bufio.Reader, remoteAddr net.Addr, state ws.State) *WebsocketConn {
func NewConn(conn net.Conn, remoteAddr net.Addr, state ws.State) *WebsocketConn {
controlHandler := wsutil.ControlFrameHandler(conn, state)
var reader io.Reader
if br != nil && br.Buffered() > 0 {
reader = br
} else {
reader = conn
}
return &WebsocketConn{
Conn: conn,
state: state,
reader: &wsutil.Reader{
Source: reader,
Source: conn,
State: state,
SkipHeaderCheck: !debug.Enabled,
OnIntermediate: controlHandler,
Expand Down
4 changes: 2 additions & 2 deletions transport/v2raywebsocket/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,14 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
s.invalidRequest(writer, request, http.StatusBadRequest, E.Cause(err, "decode early data"))
return
}
wsConn, reader, _, err := ws.UpgradeHTTP(request, writer)
wsConn, _, _, err := ws.UpgradeHTTP(request, writer)
if err != nil {
s.invalidRequest(writer, request, 0, E.Cause(err, "upgrade websocket connection"))
return
}
var metadata M.Metadata
metadata.Source = sHttp.SourceAddress(request)
conn = NewConn(wsConn, reader.Reader, metadata.Source.TCPAddr(), ws.StateServerSide)
conn = NewConn(wsConn, metadata.Source.TCPAddr(), ws.StateServerSide)
if len(earlyData) > 0 {
conn = bufio.NewCachedConn(conn, buf.As(earlyData))
}
Expand Down

0 comments on commit c6074c7

Please sign in to comment.