Skip to content

Commit d2499e1

Browse files
committed
add context to connection handlers
1 parent 3b36b7a commit d2499e1

File tree

10 files changed

+36
-14
lines changed

10 files changed

+36
-14
lines changed

internal/client/transport/tcp.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"time"
1111

1212
"github.com/musix/backhaul/internal/utils"
13+
"github.com/musix/backhaul/internal/utils/handlers"
1314
"github.com/musix/backhaul/internal/utils/network"
1415
"github.com/musix/backhaul/internal/web"
1516

@@ -389,5 +390,5 @@ func (c *TcpTransport) localDialer(tcpConn net.Conn, resolvedAddr string, port i
389390

390391
c.logger.Debugf("connected to local address %s successfully", resolvedAddr)
391392

392-
utils.TCPConnectionHandler(tcpConn, localConnection, c.logger, c.usageMonitor, port, c.config.Sniffer)
393+
handlers.TCPConnectionHandler(c.ctx, tcpConn, localConnection, c.logger, c.usageMonitor, port, c.config.Sniffer)
393394
}

internal/client/transport/tcpmux.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"time"
1111

1212
"github.com/musix/backhaul/internal/utils"
13+
"github.com/musix/backhaul/internal/utils/handlers"
1314
"github.com/musix/backhaul/internal/utils/network"
1415
"github.com/musix/backhaul/internal/web"
1516

@@ -410,5 +411,5 @@ func (c *TcpMuxTransport) localDialer(stream *smux.Stream, remoteAddr string) {
410411

411412
c.logger.Debugf("connected to local address %s successfully", remoteAddr)
412413

413-
utils.TCPConnectionHandler(stream, localConnection, c.logger, c.usageMonitor, int(port), c.config.Sniffer)
414+
handlers.TCPConnectionHandler(c.ctx, stream, localConnection, c.logger, c.usageMonitor, int(port), c.config.Sniffer)
414415
}

internal/client/transport/ws.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111

1212
"github.com/musix/backhaul/config"
1313
"github.com/musix/backhaul/internal/utils"
14+
"github.com/musix/backhaul/internal/utils/handlers"
1415
"github.com/musix/backhaul/internal/utils/network"
1516
"github.com/musix/backhaul/internal/web"
1617

@@ -358,5 +359,5 @@ func (c *WsTransport) localDialer(tunnelCon *websocket.Conn, remoteAddr string,
358359
}
359360
c.logger.Debugf("connected to local address %s successfully", remoteAddr)
360361

361-
utils.WSConnectionHandler(tunnelCon, localConnection, c.logger, c.usageMonitor, int(port), c.config.Sniffer)
362+
handlers.WSConnectionHandler(c.ctx, tunnelCon, localConnection, c.logger, c.usageMonitor, int(port), c.config.Sniffer)
362363
}

internal/client/transport/wsmux.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010

1111
"github.com/musix/backhaul/config"
1212
"github.com/musix/backhaul/internal/utils"
13+
"github.com/musix/backhaul/internal/utils/handlers"
1314
"github.com/musix/backhaul/internal/utils/network"
1415
"github.com/musix/backhaul/internal/web"
1516
"github.com/xtaci/smux"
@@ -377,5 +378,5 @@ func (c *WsMuxTransport) localDialer(stream *smux.Stream, remoteAddr string) {
377378

378379
c.logger.Debugf("connected to local address %s successfully", remoteAddr)
379380

380-
utils.TCPConnectionHandler(stream, localConnection, c.logger, c.usageMonitor, int(port), c.config.Sniffer)
381+
handlers.TCPConnectionHandler(c.ctx, stream, localConnection, c.logger, c.usageMonitor, int(port), c.config.Sniffer)
381382
}

internal/server/transport/tcp.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"time"
1212

1313
"github.com/musix/backhaul/internal/utils"
14+
"github.com/musix/backhaul/internal/utils/handlers"
1415
"github.com/musix/backhaul/internal/utils/network"
1516
"github.com/musix/backhaul/internal/web"
1617

@@ -552,7 +553,7 @@ func (s *TcpTransport) handleLoop() {
552553
}
553554

554555
// Handle data exchange between connections
555-
go utils.TCPConnectionHandler(localConn.conn, tunnelConn, s.logger, s.usageMonitor, localConn.conn.LocalAddr().(*net.TCPAddr).Port, s.config.Sniffer)
556+
go handlers.TCPConnectionHandler(s.ctx, localConn.conn, tunnelConn, s.logger, s.usageMonitor, localConn.conn.LocalAddr().(*net.TCPAddr).Port, s.config.Sniffer)
556557
break loop
557558

558559
}

internal/server/transport/tcpmux.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"time"
1313

1414
"github.com/musix/backhaul/internal/utils"
15+
"github.com/musix/backhaul/internal/utils/handlers"
1516
"github.com/musix/backhaul/internal/utils/network"
1617
"github.com/musix/backhaul/internal/web"
1718

@@ -604,7 +605,7 @@ func (s *TcpMuxTransport) handleSession(session *smux.Session) {
604605

605606
// Handle data exchange between connections
606607
go func() {
607-
utils.TCPConnectionHandler(stream, incomingConn.conn, s.logger, s.usageMonitor, incomingConn.conn.LocalAddr().(*net.TCPAddr).Port, s.config.Sniffer)
608+
handlers.TCPConnectionHandler(s.ctx, stream, incomingConn.conn, s.logger, s.usageMonitor, incomingConn.conn.LocalAddr().(*net.TCPAddr).Port, s.config.Sniffer)
608609
atomic.AddInt32(&s.streamCounter, -1)
609610
<-counter // read signal from the channel
610611
}()

internal/server/transport/ws.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313

1414
"github.com/musix/backhaul/config"
1515
"github.com/musix/backhaul/internal/utils"
16+
"github.com/musix/backhaul/internal/utils/handlers"
1617
"github.com/musix/backhaul/internal/web"
1718

1819
"github.com/gorilla/websocket"
@@ -509,7 +510,7 @@ func (s *WsTransport) handleLoop() {
509510
continue loop
510511
}
511512
// Handle data exchange between connections
512-
go utils.WSConnectionHandler(tunnelConnection.conn, localConn.conn, s.logger, s.usageMonitor, localConn.conn.LocalAddr().(*net.TCPAddr).Port, s.config.Sniffer)
513+
go handlers.WSConnectionHandler(s.ctx, tunnelConnection.conn, localConn.conn, s.logger, s.usageMonitor, localConn.conn.LocalAddr().(*net.TCPAddr).Port, s.config.Sniffer)
513514
break loop
514515
}
515516
}

internal/server/transport/wsmux.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414

1515
"github.com/musix/backhaul/config" // for mode
1616
"github.com/musix/backhaul/internal/utils"
17+
"github.com/musix/backhaul/internal/utils/handlers"
1718
"github.com/musix/backhaul/internal/web"
1819
"github.com/xtaci/smux"
1920

@@ -566,7 +567,7 @@ func (s *WsMuxTransport) handleSession(session *smux.Session) {
566567

567568
// Handle data exchange between connections
568569
go func() {
569-
utils.TCPConnectionHandler(stream, incomingConn.conn, s.logger, s.usageMonitor, incomingConn.conn.LocalAddr().(*net.TCPAddr).Port, s.config.Sniffer)
570+
handlers.TCPConnectionHandler(s.ctx, stream, incomingConn.conn, s.logger, s.usageMonitor, incomingConn.conn.LocalAddr().(*net.TCPAddr).Port, s.config.Sniffer)
570571
atomic.AddInt32(&s.streamCounter, -1)
571572
<-counter // read signal from the channel
572573
}()
Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
package utils
1+
package handlers
22

33
import (
4+
"context"
45
"errors"
56
"io"
67
"net"
@@ -9,7 +10,7 @@ import (
910
"github.com/sirupsen/logrus"
1011
)
1112

12-
func TCPConnectionHandler(from net.Conn, to net.Conn, logger *logrus.Logger, usage *web.Usage, remotePort int, sniffer bool) {
13+
func TCPConnectionHandler(ctx context.Context, from net.Conn, to net.Conn, logger *logrus.Logger, usage *web.Usage, remotePort int, sniffer bool) {
1314
done := make(chan struct{})
1415

1516
go func() {
@@ -19,7 +20,13 @@ func TCPConnectionHandler(from net.Conn, to net.Conn, logger *logrus.Logger, usa
1920

2021
transferData(to, from, logger, usage, remotePort, sniffer)
2122

22-
<-done
23+
select {
24+
case <-ctx.Done():
25+
from.Close()
26+
to.Close()
27+
return
28+
case <-done:
29+
}
2330
}
2431

2532
// Using direct Read and Write for transferring data
Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
package utils
1+
package handlers
22

33
import (
4+
"context"
45
"errors"
56
"io"
67
"net"
@@ -11,7 +12,7 @@ import (
1112
)
1213

1314
// WebSocketToTCPConnectionHandler handles data transfer between a WebSocket and a TCP connection
14-
func WSConnectionHandler(wsConn *websocket.Conn, tcpConn net.Conn, logger *logrus.Logger, usage *web.Usage, remotePort int, sniffer bool) {
15+
func WSConnectionHandler(ctx context.Context, wsConn *websocket.Conn, tcpConn net.Conn, logger *logrus.Logger, usage *web.Usage, remotePort int, sniffer bool) {
1516
done := make(chan struct{})
1617

1718
go func() {
@@ -21,7 +22,13 @@ func WSConnectionHandler(wsConn *websocket.Conn, tcpConn net.Conn, logger *logru
2122

2223
transferTCPToWebSocket(tcpConn, wsConn, logger, usage, remotePort, sniffer)
2324

24-
<-done
25+
select {
26+
case <-ctx.Done():
27+
wsConn.Close()
28+
tcpConn.Close()
29+
return
30+
case <-done:
31+
}
2532
}
2633

2734
// transferWebSocketToTCP transfers data from a WebSocket connection to a TCP connection

0 commit comments

Comments
 (0)