1
1
package forward
2
2
3
3
import (
4
- "errors"
4
+ "context"
5
+ "golang.org/x/time/rate"
5
6
"net"
6
7
"time"
7
8
)
@@ -21,7 +22,11 @@ func NewUDPDataForwarder() *UDPDataForwarder {
21
22
}
22
23
23
24
func (f * UDPDataForwarder ) Forward (sourceConn , destinationConn net.Conn ) error {
24
- return f .ForwardWithNormal (sourceConn , destinationConn )
25
+ if f .BandwidthLimit != DefaultBandwidthLimit {
26
+ return f .ForwardWithTrafficControl (sourceConn , destinationConn )
27
+ } else {
28
+ return f .ForwardWithNormal (sourceConn , destinationConn )
29
+ }
25
30
}
26
31
27
32
func (f * UDPDataForwarder ) ForwardWithNormal (sourceConn , destinationConn net.Conn ) error {
@@ -47,8 +52,7 @@ func (f *UDPDataForwarder) ForwardWithNormal(sourceConn, destinationConn net.Con
47
52
destinationConnBuffer := make ([]byte , f .BufferSize )
48
53
destinationConn .SetReadDeadline (time .Now ().Add (f .DeadlineSecond * time .Second ))
49
54
m , _ , err := destinationUDPConn .ReadFromUDP (destinationConnBuffer )
50
- var netErr net.Error
51
- if errors .As (err , & netErr ) && netErr .Timeout () {
55
+ if err != nil {
52
56
return
53
57
}
54
58
@@ -62,7 +66,52 @@ func (f *UDPDataForwarder) ForwardWithNormal(sourceConn, destinationConn net.Con
62
66
}
63
67
64
68
func (f * UDPDataForwarder ) ForwardWithTrafficControl (sourceConn , destinationConn net.Conn ) error {
65
- return f .ForwardWithNormal (sourceConn , destinationConn )
69
+ sourceUDPConn , _ := sourceConn .(* net.UDPConn )
70
+ destinationUDPConn , _ := destinationConn .(* net.UDPConn )
71
+
72
+ limiter := rate .NewLimiter (rate .Limit (f .BandwidthLimit * 1024 / 8 ), int (f .BandwidthLimit * 1024 / 8 ))
73
+
74
+ sourceConnBuffer := make ([]byte , f .BufferSize )
75
+ for {
76
+ sourceConn .SetReadDeadline (time .Now ().Add (f .DeadlineSecond * time .Second ))
77
+ n , sourceConnAddr , err := sourceUDPConn .ReadFromUDP (sourceConnBuffer )
78
+ if err != nil {
79
+ continue
80
+ }
81
+
82
+ data := make ([]byte , n )
83
+ copy (data , sourceConnBuffer [:n ])
84
+
85
+ go func (data []byte , sourceConnAddr * net.UDPAddr ) {
86
+ err := limiter .WaitN (context .Background (), n )
87
+ if err != nil {
88
+ return
89
+ }
90
+
91
+ _ , err = destinationConn .Write (data )
92
+ if err != nil {
93
+ return
94
+ }
95
+
96
+ destinationConnBuffer := make ([]byte , f .BufferSize )
97
+ destinationConn .SetReadDeadline (time .Now ().Add (f .DeadlineSecond * time .Second ))
98
+ m , _ , err := destinationUDPConn .ReadFromUDP (destinationConnBuffer )
99
+ if err != nil {
100
+ return
101
+ }
102
+
103
+ err = limiter .WaitN (context .Background (), m )
104
+ if err != nil {
105
+ return
106
+ }
107
+
108
+ _ , err = sourceUDPConn .WriteToUDP (destinationConnBuffer [:m ], sourceConnAddr )
109
+ if err != nil {
110
+ return
111
+ }
112
+
113
+ }(data , sourceConnAddr )
114
+ }
66
115
}
67
116
68
117
func (f * UDPDataForwarder ) SetBandwidthLimit (bandwidthLimit uint64 ) * UDPDataForwarder {
0 commit comments