From 99840ba0f40d68ed4109ee1484567039eafb6360 Mon Sep 17 00:00:00 2001 From: beryll1um Date: Mon, 14 Oct 2024 02:15:29 +0200 Subject: [PATCH] Add high-level boundary management interface I rewrote the router's boundary management part to implement dynamic management from a high-level box interface. This also includes a number of changes I made in the process of rewriting some messy parts, such as the Outbound tree bottom-top starter. --- adapter/router.go | 18 +- box.go | 225 ++++----- box_outbound.go | 85 ---- common/process/searcher.go | 6 +- common/process/searcher_linux.go | 4 +- common/process/searcher_linux_shared.go | 2 +- experimental/libbox/config.go | 2 +- experimental/libbox/internal/procfs/procfs.go | 2 +- experimental/libbox/platform.go | 2 +- experimental/libbox/service.go | 6 +- inbound/default.go | 6 + route/router.go | 437 ++++++++++++------ route/router_outbound_starter.go | 60 +++ 13 files changed, 503 insertions(+), 352 deletions(-) delete mode 100644 box_outbound.go create mode 100644 route/router_outbound_starter.go diff --git a/adapter/router.go b/adapter/router.go index 619c1110cb..4abe38d8cb 100644 --- a/adapter/router.go +++ b/adapter/router.go @@ -18,14 +18,28 @@ import ( ) type Router interface { - Service + AddOutbound(outbound Outbound) error + AddInbound(inbound Inbound) error + + RemoveOutbound(tag string) error + RemoveInbound(tag string) error + PreStarter + + StartOutbounds() error + + Service + + StartInbounds() error + PostStarter + Cleanup() error + DefaultOutbound(network string) (Outbound, error) Outbounds() []Outbound Outbound(tag string) (Outbound, bool) - DefaultOutbound(network string) (Outbound, error) + Inbound(tag string) (Inbound, bool) FakeIPStore() FakeIPStore diff --git a/box.go b/box.go index 716b1b093c..46c5fec2fe 100644 --- a/box.go +++ b/box.go @@ -29,16 +29,16 @@ import ( var _ adapter.Service = (*Box)(nil) type Box struct { - createdAt time.Time - router adapter.Router - inbounds []adapter.Inbound - outbounds []adapter.Outbound - logFactory log.Factory - logger log.ContextLogger - preServices1 map[string]adapter.Service - preServices2 map[string]adapter.Service - postServices map[string]adapter.Service - done chan struct{} + createdAt time.Time + router adapter.Router + logFactory log.Factory + logger log.ContextLogger + preServices1 map[string]adapter.Service + preServices2 map[string]adapter.Service + postServices map[string]adapter.Service + platformInterface platform.Interface + ctx context.Context + done chan struct{} } type Options struct { @@ -97,57 +97,6 @@ func New(options Options) (*Box, error) { if err != nil { return nil, E.Cause(err, "parse route options") } - inbounds := make([]adapter.Inbound, 0, len(options.Inbounds)) - outbounds := make([]adapter.Outbound, 0, len(options.Outbounds)) - for i, inboundOptions := range options.Inbounds { - var in adapter.Inbound - var tag string - if inboundOptions.Tag != "" { - tag = inboundOptions.Tag - } else { - tag = F.ToString(i) - } - in, err = inbound.New( - ctx, - router, - logFactory.NewLogger(F.ToString("inbound/", inboundOptions.Type, "[", tag, "]")), - tag, - inboundOptions, - options.PlatformInterface, - ) - if err != nil { - return nil, E.Cause(err, "parse inbound[", i, "]") - } - inbounds = append(inbounds, in) - } - for i, outboundOptions := range options.Outbounds { - var out adapter.Outbound - var tag string - if outboundOptions.Tag != "" { - tag = outboundOptions.Tag - } else { - tag = F.ToString(i) - } - out, err = outbound.New( - ctx, - router, - logFactory.NewLogger(F.ToString("outbound/", outboundOptions.Type, "[", tag, "]")), - tag, - outboundOptions) - if err != nil { - return nil, E.Cause(err, "parse outbound[", i, "]") - } - outbounds = append(outbounds, out) - } - err = router.Initialize(inbounds, outbounds, func() adapter.Outbound { - out, oErr := outbound.New(ctx, router, logFactory.NewLogger("outbound/direct"), "direct", option.Outbound{Type: "direct", Tag: "default"}) - common.Must(oErr) - outbounds = append(outbounds, out) - return out - }) - if err != nil { - return nil, err - } if options.PlatformInterface != nil { err = options.PlatformInterface.Initialize(ctx, router) if err != nil { @@ -183,18 +132,35 @@ func New(options Options) (*Box, error) { router.SetV2RayServer(v2rayServer) preServices2["v2ray api"] = v2rayServer } - return &Box{ - router: router, - inbounds: inbounds, - outbounds: outbounds, - createdAt: createdAt, - logFactory: logFactory, - logger: logFactory.Logger(), - preServices1: preServices1, - preServices2: preServices2, - postServices: postServices, - done: make(chan struct{}), - }, nil + box := &Box{ + router: router, + createdAt: createdAt, + logFactory: logFactory, + logger: logFactory.Logger(), + preServices1: preServices1, + preServices2: preServices2, + postServices: postServices, + platformInterface: options.PlatformInterface, + ctx: ctx, + done: make(chan struct{}), + } + for i, outOpts := range options.Outbounds { + if outOpts.Tag == "" { + outOpts.Tag = F.ToString(i) + } + if err := box.AddOutbound(outOpts); err != nil { + return nil, E.Cause(err, "create outbound") + } + } + for i, inOpts := range options.Inbounds { + if inOpts.Tag == "" { + inOpts.Tag = F.ToString(i) + } + if err := box.AddInbound(inOpts); err != nil { + return nil, E.Cause(err, "create inbound") + } + } + return box, nil } func (s *Box) PreStart() error { @@ -263,12 +229,10 @@ func (s *Box) preStart() error { } } } - err = s.router.PreStart() - if err != nil { + if err := s.router.PreStart(); err != nil { return E.Cause(err, "pre-start router") } - err = s.startOutbounds() - if err != nil { + if err := s.router.StartOutbounds(); err != nil { return err } return s.router.Start() @@ -291,20 +255,10 @@ func (s *Box) start() error { return E.Cause(err, "start ", serviceName) } } - for i, in := range s.inbounds { - var tag string - if in.Tag() == "" { - tag = F.ToString(i) - } else { - tag = in.Tag() - } - err = in.Start() - if err != nil { - return E.Cause(err, "initialize inbound/", in.Type(), "[", tag, "]") - } + if err := s.router.StartInbounds(); err != nil { + return E.Cause(err, "start inbounds") } - err = s.postStart() - if err != nil { + if err = s.postStart(); err != nil { return err } return s.router.Cleanup() @@ -317,26 +271,8 @@ func (s *Box) postStart() error { return E.Cause(err, "start ", serviceName) } } - // TODO: reorganize ALL start order - for _, out := range s.outbounds { - if lateOutbound, isLateOutbound := out.(adapter.PostStarter); isLateOutbound { - err := lateOutbound.PostStart() - if err != nil { - return E.Cause(err, "post-start outbound/", out.Tag()) - } - } - } - err := s.router.PostStart() - if err != nil { - return err - } - for _, in := range s.inbounds { - if lateInbound, isLateInbound := in.(adapter.PostStarter); isLateInbound { - err = lateInbound.PostStart() - if err != nil { - return E.Cause(err, "post-start inbound/", in.Tag()) - } - } + if err := s.router.PostStart(); err != nil { + return E.Cause(err, "post-start") } return nil } @@ -357,20 +293,6 @@ func (s *Box) Close() error { }) monitor.Finish() } - for i, in := range s.inbounds { - monitor.Start("close inbound/", in.Type(), "[", i, "]") - errors = E.Append(errors, in.Close(), func(err error) error { - return E.Cause(err, "close inbound/", in.Type(), "[", i, "]") - }) - monitor.Finish() - } - for i, out := range s.outbounds { - monitor.Start("close outbound/", out.Type(), "[", i, "]") - errors = E.Append(errors, common.Close(out), func(err error) error { - return E.Cause(err, "close outbound/", out.Type(), "[", i, "]") - }) - monitor.Finish() - } monitor.Start("close router") if err := common.Close(s.router); err != nil { errors = E.Append(errors, err, func(err error) error { @@ -403,3 +325,58 @@ func (s *Box) Close() error { func (s *Box) Router() adapter.Router { return s.router } + +func (s *Box) AddOutbound(option option.Outbound) error { + if option.Tag == "" { + return E.New("empty tag") + } + out, err := outbound.New( + s.ctx, + s.router, + s.logFactory.NewLogger(F.ToString("outbound/", option.Type, "[", option.Tag, "]")), + option.Tag, + option, + ) + if err != nil { + return E.Cause(err, "parse addited outbound") + } + if err := s.router.AddOutbound(out); err != nil { + return E.Cause(err, "outbound/", option.Type, "[", option.Tag, "]") + } + return nil +} + +func (s *Box) AddInbound(option option.Inbound) error { + if option.Tag == "" { + return E.New("empty tag") + } + in, err := inbound.New( + s.ctx, + s.router, + s.logFactory.NewLogger(F.ToString("inbound/", option.Type, "[", option.Tag, "]")), + option.Tag, + option, + s.platformInterface, + ) + if err != nil { + return E.Cause(err, "parse addited inbound") + } + if err := s.router.AddInbound(in); err != nil { + return E.Cause(err, "inbound/", option.Type, "[", option.Tag, "]") + } + return nil +} + +func (s *Box) RemoveOutbound(tag string) error { + if err := s.router.RemoveOutbound(tag); err != nil { + return E.Cause(err, "outbound[", tag, "]") + } + return nil +} + +func (s *Box) RemoveInbound(tag string) error { + if err := s.router.RemoveInbound(tag); err != nil { + return E.Cause(err, "inbound[", tag, "]") + } + return nil +} diff --git a/box_outbound.go b/box_outbound.go deleted file mode 100644 index f03f3b7d41..0000000000 --- a/box_outbound.go +++ /dev/null @@ -1,85 +0,0 @@ -package box - -import ( - "strings" - - "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing-box/common/taskmonitor" - C "github.com/sagernet/sing-box/constant" - "github.com/sagernet/sing/common" - E "github.com/sagernet/sing/common/exceptions" - F "github.com/sagernet/sing/common/format" -) - -func (s *Box) startOutbounds() error { - monitor := taskmonitor.New(s.logger, C.StartTimeout) - outboundTags := make(map[adapter.Outbound]string) - outbounds := make(map[string]adapter.Outbound) - for i, outboundToStart := range s.outbounds { - var outboundTag string - if outboundToStart.Tag() == "" { - outboundTag = F.ToString(i) - } else { - outboundTag = outboundToStart.Tag() - } - if _, exists := outbounds[outboundTag]; exists { - return E.New("outbound tag ", outboundTag, " duplicated") - } - outboundTags[outboundToStart] = outboundTag - outbounds[outboundTag] = outboundToStart - } - started := make(map[string]bool) - for { - canContinue := false - startOne: - for _, outboundToStart := range s.outbounds { - outboundTag := outboundTags[outboundToStart] - if started[outboundTag] { - continue - } - dependencies := outboundToStart.Dependencies() - for _, dependency := range dependencies { - if !started[dependency] { - continue startOne - } - } - started[outboundTag] = true - canContinue = true - if starter, isStarter := outboundToStart.(interface { - Start() error - }); isStarter { - monitor.Start("initialize outbound/", outboundToStart.Type(), "[", outboundTag, "]") - err := starter.Start() - monitor.Finish() - if err != nil { - return E.Cause(err, "initialize outbound/", outboundToStart.Type(), "[", outboundTag, "]") - } - } - } - if len(started) == len(s.outbounds) { - break - } - if canContinue { - continue - } - currentOutbound := common.Find(s.outbounds, func(it adapter.Outbound) bool { - return !started[outboundTags[it]] - }) - var lintOutbound func(oTree []string, oCurrent adapter.Outbound) error - lintOutbound = func(oTree []string, oCurrent adapter.Outbound) error { - problemOutboundTag := common.Find(oCurrent.Dependencies(), func(it string) bool { - return !started[it] - }) - if common.Contains(oTree, problemOutboundTag) { - return E.New("circular outbound dependency: ", strings.Join(oTree, " -> "), " -> ", problemOutboundTag) - } - problemOutbound := outbounds[problemOutboundTag] - if problemOutbound == nil { - return E.New("dependency[", problemOutboundTag, "] not found for outbound[", outboundTags[oCurrent], "]") - } - return lintOutbound(append(oTree, problemOutboundTag), problemOutbound) - } - return lintOutbound([]string{outboundTags[currentOutbound]}, currentOutbound) - } - return nil -} diff --git a/common/process/searcher.go b/common/process/searcher.go index cee81068ca..97df88c940 100644 --- a/common/process/searcher.go +++ b/common/process/searcher.go @@ -12,7 +12,7 @@ import ( ) type Searcher interface { - FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*Info, error) + FindProcessInfo(ctx context.Context, network string, source netip.AddrPort) (*Info, error) } var ErrNotFound = E.New("process not found") @@ -29,8 +29,8 @@ type Info struct { UserId int32 } -func FindProcessInfo(searcher Searcher, ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*Info, error) { - info, err := searcher.FindProcessInfo(ctx, network, source, destination) +func FindProcessInfo(searcher Searcher, ctx context.Context, network string, source netip.AddrPort) (*Info, error) { + info, err := searcher.FindProcessInfo(ctx, network, source) if err != nil { return nil, err } diff --git a/common/process/searcher_linux.go b/common/process/searcher_linux.go index 39470205a4..037e377c08 100644 --- a/common/process/searcher_linux.go +++ b/common/process/searcher_linux.go @@ -19,8 +19,8 @@ func NewSearcher(config Config) (Searcher, error) { return &linuxSearcher{config.Logger}, nil } -func (s *linuxSearcher) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*Info, error) { - inode, uid, err := resolveSocketByNetlink(network, source, destination) +func (s *linuxSearcher) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort) (*Info, error) { + inode, uid, err := resolveSocketByNetlink(network, source) if err != nil { return nil, err } diff --git a/common/process/searcher_linux_shared.go b/common/process/searcher_linux_shared.go index e75b0b4f9d..220a58615f 100644 --- a/common/process/searcher_linux_shared.go +++ b/common/process/searcher_linux_shared.go @@ -36,7 +36,7 @@ const ( pathProc = "/proc" ) -func resolveSocketByNetlink(network string, source netip.AddrPort, destination netip.AddrPort) (inode, uid uint32, err error) { +func resolveSocketByNetlink(network string, source netip.AddrPort) (inode, uid uint32, err error) { var family uint8 var protocol uint8 diff --git a/experimental/libbox/config.go b/experimental/libbox/config.go index df8b6ee34e..c3d33f0899 100644 --- a/experimental/libbox/config.go +++ b/experimental/libbox/config.go @@ -94,7 +94,7 @@ func (s *platformInterfaceStub) ReadWIFIState() adapter.WIFIState { return adapter.WIFIState{} } -func (s *platformInterfaceStub) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*process.Info, error) { +func (s *platformInterfaceStub) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort) (*process.Info, error) { return nil, os.ErrInvalid } diff --git a/experimental/libbox/internal/procfs/procfs.go b/experimental/libbox/internal/procfs/procfs.go index 8c918a799f..2d8e11e2aa 100644 --- a/experimental/libbox/internal/procfs/procfs.go +++ b/experimental/libbox/internal/procfs/procfs.go @@ -30,7 +30,7 @@ func init() { } } -func ResolveSocketByProcSearch(network string, source, _ netip.AddrPort) int32 { +func ResolveSocketByProcSearch(network string, source netip.AddrPort) int32 { if netIndexOfLocal < 0 || netIndexOfUid < 0 { return -1 } diff --git a/experimental/libbox/platform.go b/experimental/libbox/platform.go index 8306012ac2..8cf150ea77 100644 --- a/experimental/libbox/platform.go +++ b/experimental/libbox/platform.go @@ -10,7 +10,7 @@ type PlatformInterface interface { OpenTun(options TunOptions) (int32, error) WriteLog(message string) UseProcFS() bool - FindConnectionOwner(ipProtocol int32, sourceAddress string, sourcePort int32, destinationAddress string, destinationPort int32) (int32, error) + FindConnectionOwner(ipProtocol int32, sourceAddress string, sourcePort int32) (int32, error) PackageNameByUid(uid int32) (string, error) UIDByPackageName(packageName string) (int32, error) UsePlatformDefaultInterfaceMonitor() bool diff --git a/experimental/libbox/service.go b/experimental/libbox/service.go index cdfae04ca8..1f2fe6cbd0 100644 --- a/experimental/libbox/service.go +++ b/experimental/libbox/service.go @@ -203,10 +203,10 @@ func (w *platformInterfaceWrapper) ReadWIFIState() adapter.WIFIState { return (adapter.WIFIState)(*wifiState) } -func (w *platformInterfaceWrapper) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*process.Info, error) { +func (w *platformInterfaceWrapper) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort) (*process.Info, error) { var uid int32 if w.useProcFS { - uid = procfs.ResolveSocketByProcSearch(network, source, destination) + uid = procfs.ResolveSocketByProcSearch(network, source) if uid == -1 { return nil, E.New("procfs: not found") } @@ -221,7 +221,7 @@ func (w *platformInterfaceWrapper) FindProcessInfo(ctx context.Context, network return nil, E.New("unknown network: ", network) } var err error - uid, err = w.iif.FindConnectionOwner(ipProtocol, source.Addr().String(), int32(source.Port()), destination.Addr().String(), int32(destination.Port())) + uid, err = w.iif.FindConnectionOwner(ipProtocol, source.Addr().String(), int32(source.Port())) if err != nil { return nil, err } diff --git a/inbound/default.go b/inbound/default.go index 44c580deb9..f8490dcad5 100644 --- a/inbound/default.go +++ b/inbound/default.go @@ -115,6 +115,12 @@ func (a *myInboundAdapter) Start() error { func (a *myInboundAdapter) Close() error { a.inShutdown.Store(true) + if a.tcpListener != nil { + a.logger.Info("tcp server closed at ", a.tcpListener.Addr()) + } + if a.udpConn != nil { + a.logger.Info("udp server closed at ", a.udpConn.LocalAddr()) + } var err error if a.systemProxy != nil && a.systemProxy.IsEnabled() { err = a.systemProxy.Disable() diff --git a/route/router.go b/route/router.go index c8fe94be5b..de6a012d37 100644 --- a/route/router.go +++ b/route/router.go @@ -10,6 +10,7 @@ import ( "os/user" "runtime" "strings" + "sync" "time" "github.com/sagernet/sing-box/adapter" @@ -50,11 +51,15 @@ import ( var _ adapter.Router = (*Router)(nil) type Router struct { - ctx context.Context - logger log.ContextLogger - dnsLogger log.ContextLogger + ctx context.Context + logger log.ContextLogger + dnsLogger log.ContextLogger + // Currently this is responsible for protecting inbound and outbound dynamic + // control. I'm not sure if it can be separated because I haven't delved + // into the logic yet to make sure they don't interfere with each other. + // To research, may improve performance on some high-load setups. + boundary sync.RWMutex inboundByTag map[string]adapter.Inbound - outbounds []adapter.Outbound outboundByTag map[string]adapter.Outbound rules []adapter.Rule defaultDetour string @@ -113,6 +118,7 @@ func NewRouter( ctx: ctx, logger: logFactory.NewLogger("router"), dnsLogger: logFactory.NewLogger("dns"), + inboundByTag: make(map[string]adapter.Inbound), outboundByTag: make(map[string]adapter.Outbound), rules: make([]adapter.Rule, 0, len(options.Rules)), dnsRules: make([]adapter.DNSRule, 0, len(dnsOptions.Rules)), @@ -373,76 +379,6 @@ func NewRouter( return router, nil } -func (r *Router) Initialize(inbounds []adapter.Inbound, outbounds []adapter.Outbound, defaultOutbound func() adapter.Outbound) error { - inboundByTag := make(map[string]adapter.Inbound) - for _, inbound := range inbounds { - inboundByTag[inbound.Tag()] = inbound - } - outboundByTag := make(map[string]adapter.Outbound) - for _, detour := range outbounds { - outboundByTag[detour.Tag()] = detour - } - var defaultOutboundForConnection adapter.Outbound - var defaultOutboundForPacketConnection adapter.Outbound - if r.defaultDetour != "" { - detour, loaded := outboundByTag[r.defaultDetour] - if !loaded { - return E.New("default detour not found: ", r.defaultDetour) - } - if common.Contains(detour.Network(), N.NetworkTCP) { - defaultOutboundForConnection = detour - } - if common.Contains(detour.Network(), N.NetworkUDP) { - defaultOutboundForPacketConnection = detour - } - } - if defaultOutboundForConnection == nil { - for _, detour := range outbounds { - if common.Contains(detour.Network(), N.NetworkTCP) { - defaultOutboundForConnection = detour - break - } - } - } - if defaultOutboundForPacketConnection == nil { - for _, detour := range outbounds { - if common.Contains(detour.Network(), N.NetworkUDP) { - defaultOutboundForPacketConnection = detour - break - } - } - } - if defaultOutboundForConnection == nil || defaultOutboundForPacketConnection == nil { - detour := defaultOutbound() - if defaultOutboundForConnection == nil { - defaultOutboundForConnection = detour - } - if defaultOutboundForPacketConnection == nil { - defaultOutboundForPacketConnection = detour - } - outbounds = append(outbounds, detour) - outboundByTag[detour.Tag()] = detour - } - r.inboundByTag = inboundByTag - r.outbounds = outbounds - r.defaultOutboundForConnection = defaultOutboundForConnection - r.defaultOutboundForPacketConnection = defaultOutboundForPacketConnection - r.outboundByTag = outboundByTag - for i, rule := range r.rules { - if _, loaded := outboundByTag[rule.Outbound()]; !loaded { - return E.New("outbound not found for rule[", i, "]: ", rule.Outbound()) - } - } - return nil -} - -func (r *Router) Outbounds() []adapter.Outbound { - if !r.started { - return nil - } - return r.outbounds -} - func (r *Router) PreStart() error { monitor := taskmonitor.New(r.logger, C.StartTimeout) if r.interfaceMonitor != nil { @@ -581,9 +517,191 @@ func (r *Router) Start() error { return nil } +func (r *Router) Cleanup() error { + for _, ruleSet := range r.ruleSetMap { + ruleSet.Cleanup() + } + runtime.GC() + return nil +} + +func (r *Router) AddOutbound(out adapter.Outbound) error { + r.boundary.Lock() + defer r.boundary.Unlock() + + if _, ok := r.outboundByTag[out.Tag()]; ok { + return E.New("duplication of tag") + } + + if r.defaultDetour == "" || r.defaultDetour == out.Tag() { + if r.defaultOutboundForConnection == nil { + if common.Contains(out.Network(), N.NetworkTCP) { + r.defaultOutboundForConnection = out + } + } + if r.defaultOutboundForPacketConnection == nil { + if common.Contains(out.Network(), N.NetworkUDP) { + r.defaultOutboundForPacketConnection = out + } + } + } + + if r.started { + monitor := taskmonitor.New(r.logger, C.StartTimeout) + monitor.Start("initialize outbound/", out.Type(), "[", out.Tag(), "]") + defer monitor.Finish() + + if startable, isStartable := out.(interface{ Start() error }); isStartable { + if err := startable.Start(); err != nil { + return E.Cause(err, "start") + } + } + + if err := postStartOutbound(out); err != nil { + return E.Cause(err, "post start") + } + } + + r.outboundByTag[out.Tag()] = out + return nil +} + +func (r *Router) AddInbound(in adapter.Inbound) error { + r.boundary.Lock() + defer r.boundary.Unlock() + + if _, ok := r.inboundByTag[in.Tag()]; ok { + return E.New("duplication of tag") + } + + if r.started { + monitor := taskmonitor.New(r.logger, C.StartTimeout) + monitor.Start("initialize inbound/", in.Type(), "[", in.Tag(), "]") + defer monitor.Finish() + + if err := in.Start(); err != nil { + return E.Cause(err, "start") + } + + if err := postStartInbound(in); err != nil { + return E.Cause(err, "post-start") + } + } + + r.inboundByTag[in.Tag()] = in + return nil +} + +func (r *Router) RemoveOutbound(tag string) error { + r.boundary.Lock() + defer r.boundary.Unlock() + + out, ok := r.outboundByTag[tag] + if !ok { + return E.New("unknown tag") + } + delete(r.outboundByTag, tag) + + if out == r.defaultOutboundForConnection { + r.defaultOutboundForConnection = nil + } + if out == r.defaultOutboundForPacketConnection { + r.defaultOutboundForPacketConnection = nil + } + if r.defaultDetour == "" { + for _, out := range r.outboundByTag { + if r.defaultOutboundForConnection == nil { + if common.Contains(out.Network(), N.NetworkTCP) { + r.defaultOutboundForConnection = out + } + if common.Contains(out.Network(), N.NetworkUDP) { + r.defaultOutboundForPacketConnection = out + } + if r.defaultOutboundForConnection != nil && r.defaultOutboundForPacketConnection != nil { + break + } + } + } + } + + if r.started { + if err := common.Close(out); err != nil { + return E.Cause(err, "close") + } + } + + return nil +} + +func (r *Router) RemoveInbound(tag string) error { + r.boundary.Lock() + defer r.boundary.Unlock() + + in, ok := r.inboundByTag[tag] + if !ok { + return E.New("unknown tag") + } + delete(r.inboundByTag, tag) + + if r.started { + if err := in.Close(); err != nil { + return E.Cause(err, "close") + } + } + + return nil +} + +func (r *Router) StartOutbounds() error { + monitor := taskmonitor.New(r.logger, C.StartTimeout) + startedTags := make(map[string]struct{}) + + for tag, out := range r.outboundByTag { + if err := (&OutboundStarter{ + outboundByTag: r.outboundByTag, + startedTags: startedTags, + monitor: monitor, + }).Start(tag, make(map[string]struct{})); err != nil { + return E.Cause(err, "start outbound/", out.Type(), "[", tag, "]") + } + } + + return nil +} + +func (r *Router) StartInbounds() error { + for tag, in := range r.inboundByTag { + if err := in.Start(); err != nil { + return E.Cause(err, "start inbound/", in.Type(), "[", tag, "]") + } + } + return nil +} + +func (r *Router) closeBounds(monitor *taskmonitor.Monitor) error { + r.boundary.Lock() + defer r.boundary.Unlock() + var err error + for tag, in := range r.inboundByTag { + monitor.Start("close inbound/", in.Type(), "[", tag, "]") + err = E.Append(err, in.Close(), func(err error) error { + return E.Cause(err, "close inbound/", in.Type(), "[", tag, "]") + }) + monitor.Finish() + } + for tag, out := range r.outboundByTag { + monitor.Start("close outbound/", out.Type(), "[", tag, "]") + err = E.Append(err, common.Close(out), func(err error) error { + return E.Cause(err, "close outbound/", out.Type(), "[", tag, "]") + }) + monitor.Finish() + } + return err +} + func (r *Router) Close() error { monitor := taskmonitor.New(r.logger, C.StopTimeout) - var err error + err := r.closeBounds(monitor) for i, rule := range r.rules { monitor.Start("close rule[", i, "]") err = E.Append(err, rule.Close(), func(err error) error { @@ -654,10 +772,35 @@ func (r *Router) Close() error { }) monitor.Finish() } + r.started = false return err } +func postStartOutbound(out adapter.Outbound) error { + if lateOutbound, isLateOutbound := out.(adapter.PostStarter); isLateOutbound { + if err := lateOutbound.PostStart(); err != nil { + return E.Cause(err, "outbound/", out.Type(), "[", out.Tag(), "]") + } + } + return nil +} + +func postStartInbound(in adapter.Inbound) error { + if lateInbound, isLateInbound := in.(adapter.PostStarter); isLateInbound { + if err := lateInbound.PostStart(); err != nil { + return E.Cause(err, "inbound/", in.Type(), "[", in.Tag(), "]") + } + } + return nil +} + func (r *Router) PostStart() error { + // TODO: reorganize ALL start order + for _, out := range r.outboundByTag { + if err := postStartOutbound(out); err != nil { + return err + } + } monitor := taskmonitor.New(r.logger, C.StopTimeout) if len(r.ruleSets) > 0 { monitor.Start("initialize rule-set") @@ -749,35 +892,58 @@ func (r *Router) PostStart() error { return E.Cause(err, "post start rule_set[", ruleSet.Name(), "]") } } - r.started = true - return nil -} - -func (r *Router) Cleanup() error { - for _, ruleSet := range r.ruleSetMap { - ruleSet.Cleanup() + for _, in := range r.inboundByTag { + if err := postStartInbound(in); err != nil { + return err + } } - runtime.GC() + r.started = true return nil } -func (r *Router) Outbound(tag string) (adapter.Outbound, bool) { - outbound, loaded := r.outboundByTag[tag] - return outbound, loaded -} - func (r *Router) DefaultOutbound(network string) (adapter.Outbound, error) { - if network == N.NetworkTCP { + r.boundary.RLock() + defer r.boundary.RUnlock() + switch network { + case N.NetworkTCP: if r.defaultOutboundForConnection == nil { return nil, E.New("missing default outbound for TCP connections") } return r.defaultOutboundForConnection, nil - } else { + case N.NetworkUDP: if r.defaultOutboundForPacketConnection == nil { return nil, E.New("missing default outbound for UDP connections") } return r.defaultOutboundForPacketConnection, nil } + return nil, E.New("wrong network type provided") +} + +func (r *Router) Outbounds() []adapter.Outbound { + if !r.started { + return nil + } + r.boundary.RLock() + defer r.boundary.RUnlock() + res := make([]adapter.Outbound, 0, len(r.outboundByTag)) + for _, out := range r.outboundByTag { + res = append(res, out) + } + return res +} + +func (r *Router) Outbound(tag string) (adapter.Outbound, bool) { + r.boundary.RLock() + defer r.boundary.RUnlock() + outbound, loaded := r.outboundByTag[tag] + return outbound, loaded +} + +func (r *Router) Inbound(tag string) (adapter.Inbound, bool) { + r.boundary.RLock() + defer r.boundary.RUnlock() + inbound, loaded := r.inboundByTag[tag] + return inbound, loaded } func (r *Router) FakeIPStore() adapter.FakeIPStore { @@ -802,8 +968,8 @@ func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata ad if metadata.LastInbound == metadata.InboundDetour { return E.New("routing loop on detour: ", metadata.InboundDetour) } - detour := r.inboundByTag[metadata.InboundDetour] - if detour == nil { + detour, ok := r.Inbound(metadata.InboundDetour) + if !ok { return E.New("inbound detour not found: ", metadata.InboundDetour) } injectable, isInjectable := detour.(adapter.InjectableInbound) @@ -908,15 +1074,27 @@ func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata ad } else if metadata.Destination.IsIPv6() { metadata.IPVersion = 6 } - ctx, matchedRule, detour, err := r.match(ctx, &metadata, r.defaultOutboundForConnection) - if err != nil { - return err + + rule, detour := r.ruleByMetadata(ctx, &metadata) + if rule == nil { + var err error + detour, err = r.DefaultOutbound(N.NetworkTCP) + if err != nil { + return E.New("missing supported outbound, closing packet connection") + } + } + if tag, loaded := outbound.TagFromContext(ctx); loaded { + if tag == detour.Tag() { + return E.New("connection loopback in outbound/", detour.Type(), "[", detour.Tag(), "]") + } } if !common.Contains(detour.Network(), N.NetworkTCP) { - return E.New("missing supported outbound, closing connection") + return E.New("missing support of network type by outbound, closing packet connection") } + ctx = outbound.ContextWithTag(ctx, detour.Tag()) + if r.clashServer != nil { - trackerConn, tracker := r.clashServer.RoutedConnection(ctx, conn, metadata, matchedRule) + trackerConn, tracker := r.clashServer.RoutedConnection(ctx, conn, metadata, rule) defer tracker.Leave() conn = trackerConn } @@ -936,8 +1114,8 @@ func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, m if metadata.LastInbound == metadata.InboundDetour { return E.New("routing loop on detour: ", metadata.InboundDetour) } - detour := r.inboundByTag[metadata.InboundDetour] - if detour == nil { + detour, ok := r.Inbound(metadata.InboundDetour) + if !ok { return E.New("inbound detour not found: ", metadata.InboundDetour) } injectable, isInjectable := detour.(adapter.InjectableInbound) @@ -1082,15 +1260,27 @@ func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, m } else if metadata.Destination.IsIPv6() { metadata.IPVersion = 6 } - ctx, matchedRule, detour, err := r.match(ctx, &metadata, r.defaultOutboundForPacketConnection) - if err != nil { - return err + + rule, detour := r.ruleByMetadata(ctx, &metadata) + if rule == nil { + var err error + detour, err = r.DefaultOutbound(N.NetworkUDP) + if err != nil { + return E.New("missing supported outbound, closing packet connection") + } + } + if tag, loaded := outbound.TagFromContext(ctx); loaded { + if tag == detour.Tag() { + return E.New("connection loopback in outbound/", detour.Type(), "[", detour.Tag(), "]") + } } if !common.Contains(detour.Network(), N.NetworkUDP) { - return E.New("missing supported outbound, closing packet connection") + return E.New("missing support of network type by outbound, closing packet connection") } + ctx = outbound.ContextWithTag(ctx, detour.Tag()) + if r.clashServer != nil { - trackerConn, tracker := r.clashServer.RoutedPacketConnection(ctx, conn, metadata, matchedRule) + trackerConn, tracker := r.clashServer.RoutedPacketConnection(ctx, conn, metadata, rule) defer tracker.Leave() conn = trackerConn } @@ -1105,26 +1295,9 @@ func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, m return detour.NewPacketConnection(ctx, conn, metadata) } -func (r *Router) match(ctx context.Context, metadata *adapter.InboundContext, defaultOutbound adapter.Outbound) (context.Context, adapter.Rule, adapter.Outbound, error) { - matchRule, matchOutbound := r.match0(ctx, metadata, defaultOutbound) - if contextOutbound, loaded := outbound.TagFromContext(ctx); loaded { - if contextOutbound == matchOutbound.Tag() { - return nil, nil, nil, E.New("connection loopback in outbound/", matchOutbound.Type(), "[", matchOutbound.Tag(), "]") - } - } - ctx = outbound.ContextWithTag(ctx, matchOutbound.Tag()) - return ctx, matchRule, matchOutbound, nil -} - -func (r *Router) match0(ctx context.Context, metadata *adapter.InboundContext, defaultOutbound adapter.Outbound) (adapter.Rule, adapter.Outbound) { +func (r *Router) processInfoByMetadata(ctx context.Context, metadata *adapter.InboundContext) *process.Info { if r.processSearcher != nil { - var originDestination netip.AddrPort - if metadata.OriginDestination.IsValid() { - originDestination = metadata.OriginDestination.AddrPort() - } else if metadata.Destination.IsIP() { - originDestination = metadata.Destination.AddrPort() - } - processInfo, err := process.FindProcessInfo(r.processSearcher, ctx, metadata.Network, metadata.Source.AddrPort(), originDestination) + processInfo, err := process.FindProcessInfo(r.processSearcher, ctx, metadata.Network, metadata.Source.AddrPort()) if err != nil { r.logger.InfoContext(ctx, "failed to search process: ", err) } else { @@ -1145,21 +1318,26 @@ func (r *Router) match0(ctx context.Context, metadata *adapter.InboundContext, d r.logger.InfoContext(ctx, "found user id: ", processInfo.UserId) } } - metadata.ProcessInfo = processInfo + return processInfo } } + return nil +} + +func (r *Router) ruleByMetadata(ctx context.Context, metadata *adapter.InboundContext) (adapter.Rule, adapter.Outbound) { + metadata.ProcessInfo = r.processInfoByMetadata(ctx, metadata) for i, rule := range r.rules { metadata.ResetRuleCache() if rule.Match(metadata) { detour := rule.Outbound() - r.logger.DebugContext(ctx, "match[", i, "] ", rule.String(), " => ", detour) + r.logger.DebugContext(ctx, "rule[", i, "] ", rule.String(), " => ", detour) if outbound, loaded := r.Outbound(detour); loaded { return rule, outbound } - r.logger.ErrorContext(ctx, "outbound not found: ", detour) + r.logger.ErrorContext(ctx, "not found outbound[", detour, "]") } } - return nil, defaultOutbound + return nil, nil } func (r *Router) InterfaceFinder() control.InterfaceFinder { @@ -1306,8 +1484,8 @@ func (r *Router) notifyNetworkUpdate(event int) { func (r *Router) ResetNetwork() error { conntrack.Close() - for _, outbound := range r.outbounds { - listener, isListener := outbound.(adapter.InterfaceUpdateListener) + for _, out := range r.Outbounds() { + listener, isListener := out.(adapter.InterfaceUpdateListener) if isListener { listener.InterfaceUpdated() } @@ -1316,6 +1494,7 @@ func (r *Router) ResetNetwork() error { for _, transport := range r.transports { transport.Reset() } + return nil } diff --git a/route/router_outbound_starter.go b/route/router_outbound_starter.go new file mode 100644 index 0000000000..72cd242fb7 --- /dev/null +++ b/route/router_outbound_starter.go @@ -0,0 +1,60 @@ +package route + +import ( + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/taskmonitor" + E "github.com/sagernet/sing/common/exceptions" +) + +type OutboundStarter struct { + outboundByTag map[string]adapter.Outbound + startedTags map[string]struct{} + monitor *taskmonitor.Monitor +} + +func (s *OutboundStarter) Start(tag string, pathIncludesTags map[string]struct{}) error { + adapter := s.outboundByTag[tag] + if adapter == nil { + return E.New("dependency[", tag, "] is not found") + } + + // The outbound may have been started by another subtree in the previous, + // we don't need to start it again. + if _, ok := s.startedTags[tag]; ok { + return nil + } + + // If we detected the repetition of the tags in scope of tree evaluation, + // the circular dependency is found, as it grows from bottom to top. + if _, ok := pathIncludesTags[tag]; ok { + return E.New("circular dependency related with outbound/", adapter.Type(), "[", tag, "]") + } + + // This required to be done only if that outbound isn't already started, + // because some dependencies may come to the same root, + // but they aren't circular. + pathIncludesTags[tag] = struct{}{} + + // Next, we are recursively starting all dependencies of the current + // outbound and repeating the cycle. + for _, dependencyTag := range adapter.Dependencies() { + if err := s.Start(dependencyTag, pathIncludesTags); err != nil { + return err + } + } + + // Anyway, it will be finished soon, nothing will happen if I'll include + // Startable interface typecasting too. + s.monitor.Start("initialize outbound/", adapter.Type(), "[", tag, "]") + defer s.monitor.Finish() + + // After the evaluation of entire tree let's begin to start all + // the outbounds! + if startable, isStartable := adapter.(interface{ Start() error }); isStartable { + if err := startable.Start(); err != nil { + return E.Cause(err, "initialize outbound/", adapter.Type(), "[", tag, "]") + } + } + + return nil +}