From 1137789f75d1c1cdedfc976632817fb8e798f02b Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 5 Jun 2021 21:25:01 -0700 Subject: [PATCH] delete entries from the cache when the TTL expires --- README.md | 2 +- client.go | 26 ++++++++++++++++++++------ server.go | 4 +++- service.go | 13 +++++++------ service_test.go | 34 ++++++++++++++++++++++++++++++++++ 5 files changed, 65 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 6d0442ec..7bea963a 100644 --- a/README.md +++ b/README.md @@ -94,7 +94,7 @@ See what needs to be done and submit a pull request :) * [x] Browse / Lookup / Register services * [x] Multiple IPv6 / IPv4 addresses support * [x] Send multiple probes (exp. back-off) if no service answers (*) -* [ ] Timestamp entries for TTL checks +* [x] Timestamp entries for TTL checks * [ ] Compare new multicasts with already received services _Notes:_ diff --git a/client.go b/client.go index 270394ab..89ac49dc 100644 --- a/client.go +++ b/client.go @@ -177,6 +177,8 @@ func newClient(opts clientOpts) (*client, error) { }, nil } +var cleanupFreq = 10 * time.Second + // Start listeners and waits for the shutdown signal from exit channel func (c *client) mainloop(ctx context.Context, params *lookupParams) { // start listening for responses @@ -189,16 +191,28 @@ func (c *client) mainloop(ctx context.Context, params *lookupParams) { } // Iterate through channels from listeners goroutines - var entries, sentEntries map[string]*ServiceEntry - sentEntries = make(map[string]*ServiceEntry) + var entries map[string]*ServiceEntry + sentEntries := make(map[string]*ServiceEntry) + + ticker := time.NewTicker(cleanupFreq) + defer ticker.Stop() for { + var now time.Time select { case <-ctx.Done(): // Context expired. Notify subscriber that we are done here. params.done() c.shutdown() return + case t := <-ticker.C: + for k, e := range sentEntries { + if t.After(e.Expiry) { + delete(sentEntries, k) + } + } + continue case msg := <-msgCh: + now = time.Now() entries = make(map[string]*ServiceEntry) sections := append(msg.Answer, msg.Ns...) sections = append(sections, msg.Extra...) @@ -218,7 +232,7 @@ func (c *client) mainloop(ctx context.Context, params *lookupParams) { params.Service, params.Domain) } - entries[rr.Ptr].TTL = rr.Hdr.Ttl + entries[rr.Ptr].Expiry = now.Add(time.Duration(rr.Hdr.Ttl) * time.Second) case *dns.SRV: if params.ServiceInstanceName() != "" && params.ServiceInstanceName() != rr.Hdr.Name { continue @@ -233,7 +247,7 @@ func (c *client) mainloop(ctx context.Context, params *lookupParams) { } entries[rr.Hdr.Name].HostName = rr.Target entries[rr.Hdr.Name].Port = int(rr.Port) - entries[rr.Hdr.Name].TTL = rr.Hdr.Ttl + entries[rr.Hdr.Name].Expiry = now.Add(time.Duration(rr.Hdr.Ttl) * time.Second) case *dns.TXT: if params.ServiceInstanceName() != "" && params.ServiceInstanceName() != rr.Hdr.Name { continue @@ -247,7 +261,7 @@ func (c *client) mainloop(ctx context.Context, params *lookupParams) { params.Domain) } entries[rr.Hdr.Name].Text = rr.Txt - entries[rr.Hdr.Name].TTL = rr.Hdr.Ttl + entries[rr.Hdr.Name].Expiry = now.Add(time.Duration(rr.Hdr.Ttl) * time.Second) } } // Associate IPs in a second round as other fields should be filled by now. @@ -271,7 +285,7 @@ func (c *client) mainloop(ctx context.Context, params *lookupParams) { if len(entries) > 0 { for k, e := range entries { - if e.TTL == 0 { + if !e.Expiry.After(now) { delete(entries, k) delete(sentEntries, k) continue diff --git a/server.go b/server.go index 70fd11ac..a78c4d65 100644 --- a/server.go +++ b/server.go @@ -21,6 +21,8 @@ const ( multicastRepetitions = 2 ) +var defaultTTL uint32 = 3200 + // Register a service by given arguments. This call will take the system's hostname // and lookup IP by that hostname. func Register(instance, service, domain string, port int, text []string, ifaces []net.Interface) (*Server, error) { @@ -173,7 +175,7 @@ func newServer(ifaces []net.Interface) (*Server, error) { ipv4conn: ipv4conn, ipv6conn: ipv6conn, ifaces: ifaces, - ttl: 3200, + ttl: defaultTTL, shouldShutdown: make(chan struct{}), } diff --git a/service.go b/service.go index 6253c543..43bbf8aa 100644 --- a/service.go +++ b/service.go @@ -4,6 +4,7 @@ import ( "fmt" "net" "sync" + "time" ) // ServiceRecord contains the basic description of a service, which contains instance name, service type & domain @@ -103,12 +104,12 @@ func (l *lookupParams) disableProbing() { // used to answer multicast queries. type ServiceEntry struct { ServiceRecord - HostName string `json:"hostname"` // Host machine DNS name - Port int `json:"port"` // Service Port - Text []string `json:"text"` // Service info served as a TXT record - TTL uint32 `json:"ttl"` // TTL of the service record - AddrIPv4 []net.IP `json:"-"` // Host machine IPv4 address - AddrIPv6 []net.IP `json:"-"` // Host machine IPv6 address + HostName string `json:"hostname"` // Host machine DNS name + Port int `json:"port"` // Service Port + Text []string `json:"text"` // Service info served as a TXT record + Expiry time.Time `json:"expiry"` // Expiry of the service entry, will be converted to a TTL value + AddrIPv4 []net.IP `json:"-"` // Host machine IPv4 address + AddrIPv6 []net.IP `json:"-"` // Host machine IPv6 address } // NewServiceEntry constructs a ServiceEntry. diff --git a/service_test.go b/service_test.go index 2c5a23ed..8c79d498 100644 --- a/service_test.go +++ b/service_test.go @@ -163,4 +163,38 @@ func TestSubtype(t *testing.T) { t.Fatalf("Expected port is %d, but got %d", mdnsPort, result.Port) } }) + + t.Run("ttl", func(t *testing.T) { + origTTL := defaultTTL + origCleanupFreq := cleanupFreq + defer func() { + defaultTTL = origTTL + cleanupFreq = origCleanupFreq + }() + defaultTTL = 2 // 2 seconds + cleanupFreq = 100 * time.Millisecond + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + go startMDNS(ctx, mdnsPort, mdnsName, mdnsSubtype, mdnsDomain) + + entries := make(chan *ServiceEntry, 100) + resolver, err := NewResolver(nil) + if err != nil { + t.Fatalf("Expected create resolver success, but got %v", err) + } + if err := resolver.Browse(ctx, mdnsService, mdnsDomain, entries); err != nil { + t.Fatalf("Expected browse success, but got %v", err) + } + + <-ctx.Done() + if len(entries) != 2 { + t.Fatalf("Expected to have received 2 entries, but got %d", len(entries)) + } + res1 := <-entries + res2 := <-entries + if res1.ServiceInstanceName() != res2.ServiceInstanceName() { + t.Fatalf("expected the two entries to be identical") + } + }) }