Skip to content

Commit 3e7e0fd

Browse files
committed
refactor: organize data forwarder interface
1 parent 742afa0 commit 3e7e0fd

File tree

6 files changed

+78
-66
lines changed

6 files changed

+78
-66
lines changed

cmd/main.go

+7-7
Original file line numberDiff line numberDiff line change
@@ -85,13 +85,13 @@ func parseOptionsToRules() []rules.Rule {
8585

8686
func startForwarding(r rules.Rule) error {
8787
fc := forward.NewForwardingConfig().WithSourceAddr(r.SourceAddr).
88-
WithDestinationAddr(r.DestinationAddr).WithProtocol(r.Protocol).
89-
WithUDPDataForwarder(forward.NewSimpleUDPDataForwarder().
90-
SetBufferSize(r.UDPBufferSize).SetDeadlineSecond(r.UDPTimeoutSecond))
91-
if r.BandwidthLimit == forward.DefaultTCPBandwidthLimit {
92-
fc.WithTCPDataForwarder(forward.NewSimpleTCPDataForwarder())
93-
} else {
94-
fc.WithTCPDataForwarder(forward.NewTrafficControlTCPDataForwarder().SetBandwidthLimit(r.BandwidthLimit))
88+
WithDestinationAddr(r.DestinationAddr).WithProtocol(r.Protocol)
89+
switch r.Protocol {
90+
case "tcp":
91+
fc.WithDataForwarder(forward.NewTCPDataForwarder().SetBandwidthLimit(r.BandwidthLimit))
92+
case "udp":
93+
fc.WithDataForwarder(forward.NewUDPDataForwarder().SetBandwidthLimit(r.BandwidthLimit).
94+
SetDeadlineSecond(r.UDPTimeoutSecond).SetBufferSize(r.UDPBufferSize))
9595
}
9696
return fc.StartPortForwarding()
9797
}

pkg/forward/constants.go

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
package forward
2+
3+
const DefaultBandwidthLimit uint64 = 0
4+
5+
const DefaultUDPBufferSize uint64 = 1024
6+
const DefaultUDPDeadlineSecond uint64 = 5

pkg/forward/forward.go

+8-14
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,10 @@ import (
88
)
99

1010
type ForwardingConfig struct {
11-
SourceAddr string
12-
DestinationAddr string
13-
Protocol string
14-
TCPDataForwarder TCPDataForwarder
15-
UDPDataForwarder UDPDataForwarder
11+
SourceAddr string
12+
DestinationAddr string
13+
Protocol string
14+
DataForwarder DataForwarder
1615
}
1716

1817
func NewForwardingConfig() *ForwardingConfig {
@@ -34,13 +33,8 @@ func (f *ForwardingConfig) WithProtocol(protocol string) *ForwardingConfig {
3433
return f
3534
}
3635

37-
func (f *ForwardingConfig) WithTCPDataForwarder(tcpDataForwarder TCPDataForwarder) *ForwardingConfig {
38-
f.TCPDataForwarder = tcpDataForwarder
39-
return f
40-
}
41-
42-
func (f *ForwardingConfig) WithUDPDataForwarder(udpDataForwarder UDPDataForwarder) *ForwardingConfig {
43-
f.UDPDataForwarder = udpDataForwarder
36+
func (f *ForwardingConfig) WithDataForwarder(dataForwarder DataForwarder) *ForwardingConfig {
37+
f.DataForwarder = dataForwarder
4438
return f
4539
}
4640

@@ -83,7 +77,7 @@ func (f *ForwardingConfig) startTCPPortForwarding() error {
8377
continue
8478
}
8579
go func() {
86-
f.TCPDataForwarder.Forward(localConn, remoteConn)
80+
f.DataForwarder.Forward(localConn, remoteConn)
8781
localConn.Close()
8882
log.Printf("TCP connection disconnected from %s\n", localConn.RemoteAddr())
8983
remoteConn.Close()
@@ -114,7 +108,7 @@ func (f *ForwardingConfig) startUDPPortForwarding() error {
114108
}
115109
defer remoteConn.Close()
116110

117-
f.UDPDataForwarder.Forward(*localConn, *remoteConn)
111+
f.DataForwarder.Forward(&*localConn, &*remoteConn)
118112

119113
return nil
120114
}

pkg/forward/tcp.go

+22-22
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,23 @@ import (
66
"net"
77
)
88

9-
type SimpleTCPDataForwarder struct{}
9+
type TCPDataForwarder struct {
10+
BandwidthLimit uint64
11+
}
12+
13+
func NewTCPDataForwarder() *TCPDataForwarder {
14+
return &TCPDataForwarder{BandwidthLimit: DefaultBandwidthLimit}
15+
}
16+
17+
func (f *TCPDataForwarder) Forward(sourceConn, destinationConn net.Conn) error {
18+
if f.BandwidthLimit != DefaultBandwidthLimit {
19+
return f.ForwardWithTrafficControl(sourceConn, destinationConn)
20+
} else {
21+
return f.ForwardWithNormal(sourceConn, destinationConn)
22+
}
23+
}
1024

11-
func (f *SimpleTCPDataForwarder) Forward(sourceConn, destinationConn net.Conn) error {
25+
func (f *TCPDataForwarder) ForwardWithNormal(sourceConn, destinationConn net.Conn) error {
1226
done := make(chan *ForwardingError, 2)
1327

1428
go func() {
@@ -31,26 +45,7 @@ func (f *SimpleTCPDataForwarder) Forward(sourceConn, destinationConn net.Conn) e
3145
return nil
3246
}
3347

34-
func NewSimpleTCPDataForwarder() *SimpleTCPDataForwarder {
35-
return &SimpleTCPDataForwarder{}
36-
}
37-
38-
func NewTrafficControlTCPDataForwarder() *TrafficControlTCPDataForwarder {
39-
return &TrafficControlTCPDataForwarder{BandwidthLimit: DefaultTCPBandwidthLimit}
40-
}
41-
42-
const DefaultTCPBandwidthLimit uint64 = 0
43-
44-
type TrafficControlTCPDataForwarder struct {
45-
BandwidthLimit uint64
46-
}
47-
48-
func (f *TrafficControlTCPDataForwarder) SetBandwidthLimit(bandwidthLimit uint64) *TrafficControlTCPDataForwarder {
49-
f.BandwidthLimit = bandwidthLimit
50-
return f
51-
}
52-
53-
func (f *TrafficControlTCPDataForwarder) Forward(sourceConn, destinationConn net.Conn) error {
48+
func (f *TCPDataForwarder) ForwardWithTrafficControl(sourceConn, destinationConn net.Conn) error {
5449
done := make(chan *ForwardingError, 2)
5550

5651
go func() {
@@ -76,3 +71,8 @@ func (f *TrafficControlTCPDataForwarder) Forward(sourceConn, destinationConn net
7671

7772
return nil
7873
}
74+
75+
func (f *TCPDataForwarder) SetBandwidthLimit(bandwidthLimit uint64) *TCPDataForwarder {
76+
f.BandwidthLimit = bandwidthLimit
77+
return f
78+
}

pkg/forward/types.go

+3-5
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,8 @@ package forward
22

33
import "net"
44

5-
type TCPDataForwarder interface {
5+
type DataForwarder interface {
66
Forward(sourceConn, destinationConn net.Conn) error
7-
}
8-
9-
type UDPDataForwarder interface {
10-
Forward(sourceConn, destinationConn net.UDPConn)
7+
ForwardWithNormal(sourceConn, destinationConn net.Conn) error
8+
ForwardWithTrafficControl(sourceConn, destinationConn net.Conn) error
119
}

pkg/forward/udp.go

+32-18
Original file line numberDiff line numberDiff line change
@@ -6,36 +6,31 @@ import (
66
"time"
77
)
88

9-
const DefaultUDPBufferSize uint64 = 1024
10-
const DefaultUDPDeadlineSecond uint64 = 5
11-
12-
type SimpleUDPDataForwarder struct {
9+
type UDPDataForwarder struct {
1310
BufferSize uint64
11+
BandwidthLimit uint64
1412
DeadlineSecond time.Duration
1513
}
1614

17-
func NewSimpleUDPDataForwarder() *SimpleUDPDataForwarder {
18-
return &SimpleUDPDataForwarder{
15+
func NewUDPDataForwarder() *UDPDataForwarder {
16+
return &UDPDataForwarder{
17+
BandwidthLimit: DefaultBandwidthLimit,
1918
BufferSize: DefaultUDPBufferSize,
2019
DeadlineSecond: time.Duration(DefaultUDPDeadlineSecond),
2120
}
2221
}
2322

24-
func (f *SimpleUDPDataForwarder) SetBufferSize(size uint64) *SimpleUDPDataForwarder {
25-
f.BufferSize = size
26-
return f
27-
}
28-
29-
func (f *SimpleUDPDataForwarder) SetDeadlineSecond(second uint64) *SimpleUDPDataForwarder {
30-
f.DeadlineSecond = time.Duration(second)
31-
return f
23+
func (f *UDPDataForwarder) Forward(sourceConn, destinationConn net.Conn) error {
24+
return f.ForwardWithNormal(sourceConn, destinationConn)
3225
}
3326

34-
func (f *SimpleUDPDataForwarder) Forward(sourceConn, destinationConn net.UDPConn) {
27+
func (f *UDPDataForwarder) ForwardWithNormal(sourceConn, destinationConn net.Conn) error {
28+
sourceUDPConn, _ := sourceConn.(*net.UDPConn)
29+
destinationUDPConn, _ := destinationConn.(*net.UDPConn)
3530
sourceConnBuffer := make([]byte, f.BufferSize)
3631
for {
3732
sourceConn.SetReadDeadline(time.Now().Add(f.DeadlineSecond * time.Second))
38-
n, sourceConnAddr, err := sourceConn.ReadFromUDP(sourceConnBuffer)
33+
n, sourceConnAddr, err := sourceUDPConn.ReadFromUDP(sourceConnBuffer)
3934
if err != nil {
4035
continue
4136
}
@@ -51,17 +46,36 @@ func (f *SimpleUDPDataForwarder) Forward(sourceConn, destinationConn net.UDPConn
5146

5247
destinationConnBuffer := make([]byte, f.BufferSize)
5348
destinationConn.SetReadDeadline(time.Now().Add(f.DeadlineSecond * time.Second))
54-
m, _, err := destinationConn.ReadFromUDP(destinationConnBuffer)
49+
m, _, err := destinationUDPConn.ReadFromUDP(destinationConnBuffer)
5550
var netErr net.Error
5651
if errors.As(err, &netErr) && netErr.Timeout() {
5752
return
5853
}
5954

60-
_, err = sourceConn.WriteToUDP(destinationConnBuffer[:m], sourceConnAddr)
55+
_, err = sourceUDPConn.WriteToUDP(destinationConnBuffer[:m], sourceConnAddr)
6156
if err != nil {
6257
return
6358
}
6459

6560
}(data, sourceConnAddr)
6661
}
6762
}
63+
64+
func (f *UDPDataForwarder) ForwardWithTrafficControl(sourceConn, destinationConn net.Conn) error {
65+
return f.ForwardWithNormal(sourceConn, destinationConn)
66+
}
67+
68+
func (f *UDPDataForwarder) SetBandwidthLimit(bandwidthLimit uint64) *UDPDataForwarder {
69+
f.BandwidthLimit = bandwidthLimit
70+
return f
71+
}
72+
73+
func (f *UDPDataForwarder) SetBufferSize(size uint64) *UDPDataForwarder {
74+
f.BufferSize = size
75+
return f
76+
}
77+
78+
func (f *UDPDataForwarder) SetDeadlineSecond(second uint64) *UDPDataForwarder {
79+
f.DeadlineSecond = time.Duration(second)
80+
return f
81+
}

0 commit comments

Comments
 (0)