Skip to content

Commit 6f5f7fd

Browse files
committed
feat: Add source port rule
1 parent 3c45991 commit 6f5f7fd

File tree

7 files changed

+55
-9
lines changed

7 files changed

+55
-9
lines changed

adapter/metadata.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ type Metadata struct {
2121
ConnectionID string
2222
ServiceName string
2323
SniffedProtocol Protocol
24-
SourceIP netip.Addr
24+
SourceAddress netip.AddrPort
2525
DestinationHostname string
2626
DestinationPort uint16
2727
Minecraft *MinecraftMetadata

route/router.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,11 +142,11 @@ func (r *Router) HandleConnection(conn net.Conn, metadata *adapter.Metadata) {
142142
}
143143
r.access.RUnlock()
144144
err = bufio.CopyConn(destinationConn, cachedConn)
145+
logger := r.logger.Warn().Str("id", metadata.ConnectionID).Str("outbound", outbound.Name())
145146
if err != nil {
146-
r.logger.Warn().Str("id", metadata.ConnectionID).Str("outbound", outbound.Name()).Err(err).Msg("Handled connection")
147-
} else {
148-
r.logger.Info().Str("id", metadata.ConnectionID).Str("outbound", outbound.Name()).Msg("Handled connection")
147+
logger = logger.Err(err)
149148
}
149+
logger.Msg("Handled outbound connection")
150150
cachedConn.Close()
151151
return
152152
}

route/rule.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ func NewRule(logger *log.Logger, config *config.Rule, listMap map[string]set.Str
3939
return NewSourceIPVersionRule(config)
4040
case "SourceIP":
4141
return NewSourceIPRule(config, listMap)
42+
case "SourcePort":
43+
return NewSourcePortRule(config)
4244
case "MinecraftHostname":
4345
return NewMinecraftHostnameRule(config, listMap)
4446
case "MinecraftPlayerName":

route/rule_ip.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ func (r *RuleSourceIP) Config() *config.Rule {
8383
}
8484

8585
func (r *RuleSourceIP) Match(metadata *adapter.Metadata) (match bool) {
86-
match = r.set.Contains(metadata.SourceIP.WithZone(""))
86+
match = r.set.Contains(metadata.SourceAddress.Addr().WithZone(""))
8787
if r.config.Invert {
8888
match = !match
8989
}

route/rule_ipversion.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ func (r *RuleSourceIPVersion) Config() *config.Rule {
3232
}
3333

3434
func (r *RuleSourceIPVersion) Match(metadata *adapter.Metadata) (match bool) {
35-
if metadata.SourceIP.Is4() {
35+
if metadata.SourceAddress.Addr().Is4() {
3636
match = r.version == 4
3737
} else {
3838
match = r.version == 6

route/rule_port.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package route
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
7+
"github.com/layou233/zbproxy/v3/adapter"
8+
"github.com/layou233/zbproxy/v3/config"
9+
)
10+
11+
type RuleSourcePort struct {
12+
config *config.Rule
13+
ports map[uint16]struct{}
14+
}
15+
16+
var _ Rule = (*RuleSourcePort)(nil)
17+
18+
func NewSourcePortRule(config *config.Rule) (*RuleSourcePort, error) {
19+
var ports []uint16
20+
err := json.Unmarshal(config.Parameter, &ports)
21+
if err != nil {
22+
return nil, fmt.Errorf("bad port list [%v]: %w", config.Parameter, err)
23+
}
24+
portsMap := make(map[uint16]struct{}, len(ports))
25+
for _, port := range ports {
26+
portsMap[port] = struct{}{}
27+
}
28+
return &RuleSourcePort{
29+
config: config,
30+
ports: portsMap,
31+
}, nil
32+
}
33+
34+
func (r *RuleSourcePort) Config() *config.Rule {
35+
return r.config
36+
}
37+
38+
func (r *RuleSourcePort) Match(metadata *adapter.Metadata) (match bool) {
39+
_, match = r.ports[metadata.SourceAddress.Port()]
40+
if r.config.Invert {
41+
match = !match
42+
}
43+
return
44+
}

service/service.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ func (s *Service) listenLoop() {
5050
if err != nil {
5151
return
5252
}
53-
ip := conn.RemoteAddr().(*net.TCPAddr).IP
54-
ipString := ip.String()
53+
tcpAddress := conn.RemoteAddr().(*net.TCPAddr)
54+
ipString := tcpAddress.IP.String()
5555
if s.ipAccessLists != nil &&
5656
!access.Check(s.ipAccessLists, s.config.IPAccess.Mode, ipString) {
5757
conn.SetLinger(0)
@@ -63,7 +63,7 @@ func (s *Service) listenLoop() {
6363
ServiceName: s.config.Name,
6464
DestinationHostname: s.config.TargetAddress,
6565
DestinationPort: s.config.TargetPort,
66-
SourceIP: common.MustOK(netip.AddrFromSlice(ip)).Unmap(),
66+
SourceAddress: netip.AddrPortFrom(common.MustOK(netip.AddrFromSlice(tcpAddress.IP)).Unmap(), uint16(tcpAddress.Port)),
6767
}
6868
metadata.GenerateID()
6969
s.logger.Info().Str("id", metadata.ConnectionID).Str("service", s.config.Name).

0 commit comments

Comments
 (0)