Skip to content

Commit

Permalink
delete entries from the cache when the TTL expires
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Jul 5, 2021
1 parent 1df3d26 commit 1137789
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 14 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ See what needs to be done and submit a pull request :)
* [x] Browse / Lookup / Register services
* [x] Multiple IPv6 / IPv4 addresses support
* [x] Send multiple probes (exp. back-off) if no service answers (*)
* [ ] Timestamp entries for TTL checks
* [x] Timestamp entries for TTL checks
* [ ] Compare new multicasts with already received services

_Notes:_
Expand Down
26 changes: 20 additions & 6 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ func newClient(opts clientOpts) (*client, error) {
}, nil
}

var cleanupFreq = 10 * time.Second

// Start listeners and waits for the shutdown signal from exit channel
func (c *client) mainloop(ctx context.Context, params *lookupParams) {
// start listening for responses
Expand All @@ -189,16 +191,28 @@ func (c *client) mainloop(ctx context.Context, params *lookupParams) {
}

// Iterate through channels from listeners goroutines
var entries, sentEntries map[string]*ServiceEntry
sentEntries = make(map[string]*ServiceEntry)
var entries map[string]*ServiceEntry
sentEntries := make(map[string]*ServiceEntry)

ticker := time.NewTicker(cleanupFreq)
defer ticker.Stop()
for {
var now time.Time
select {
case <-ctx.Done():
// Context expired. Notify subscriber that we are done here.
params.done()
c.shutdown()
return
case t := <-ticker.C:
for k, e := range sentEntries {
if t.After(e.Expiry) {
delete(sentEntries, k)
}
}
continue
case msg := <-msgCh:
now = time.Now()
entries = make(map[string]*ServiceEntry)
sections := append(msg.Answer, msg.Ns...)
sections = append(sections, msg.Extra...)
Expand All @@ -218,7 +232,7 @@ func (c *client) mainloop(ctx context.Context, params *lookupParams) {
params.Service,
params.Domain)
}
entries[rr.Ptr].TTL = rr.Hdr.Ttl
entries[rr.Ptr].Expiry = now.Add(time.Duration(rr.Hdr.Ttl) * time.Second)
case *dns.SRV:
if params.ServiceInstanceName() != "" && params.ServiceInstanceName() != rr.Hdr.Name {
continue
Expand All @@ -233,7 +247,7 @@ func (c *client) mainloop(ctx context.Context, params *lookupParams) {
}
entries[rr.Hdr.Name].HostName = rr.Target
entries[rr.Hdr.Name].Port = int(rr.Port)
entries[rr.Hdr.Name].TTL = rr.Hdr.Ttl
entries[rr.Hdr.Name].Expiry = now.Add(time.Duration(rr.Hdr.Ttl) * time.Second)
case *dns.TXT:
if params.ServiceInstanceName() != "" && params.ServiceInstanceName() != rr.Hdr.Name {
continue
Expand All @@ -247,7 +261,7 @@ func (c *client) mainloop(ctx context.Context, params *lookupParams) {
params.Domain)
}
entries[rr.Hdr.Name].Text = rr.Txt
entries[rr.Hdr.Name].TTL = rr.Hdr.Ttl
entries[rr.Hdr.Name].Expiry = now.Add(time.Duration(rr.Hdr.Ttl) * time.Second)
}
}
// Associate IPs in a second round as other fields should be filled by now.
Expand All @@ -271,7 +285,7 @@ func (c *client) mainloop(ctx context.Context, params *lookupParams) {

if len(entries) > 0 {
for k, e := range entries {
if e.TTL == 0 {
if !e.Expiry.After(now) {
delete(entries, k)
delete(sentEntries, k)
continue
Expand Down
4 changes: 3 additions & 1 deletion server.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ const (
multicastRepetitions = 2
)

var defaultTTL uint32 = 3200

// Register a service by given arguments. This call will take the system's hostname
// and lookup IP by that hostname.
func Register(instance, service, domain string, port int, text []string, ifaces []net.Interface) (*Server, error) {
Expand Down Expand Up @@ -173,7 +175,7 @@ func newServer(ifaces []net.Interface) (*Server, error) {
ipv4conn: ipv4conn,
ipv6conn: ipv6conn,
ifaces: ifaces,
ttl: 3200,
ttl: defaultTTL,
shouldShutdown: make(chan struct{}),
}

Expand Down
13 changes: 7 additions & 6 deletions service.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"net"
"sync"
"time"
)

// ServiceRecord contains the basic description of a service, which contains instance name, service type & domain
Expand Down Expand Up @@ -103,12 +104,12 @@ func (l *lookupParams) disableProbing() {
// used to answer multicast queries.
type ServiceEntry struct {
ServiceRecord
HostName string `json:"hostname"` // Host machine DNS name
Port int `json:"port"` // Service Port
Text []string `json:"text"` // Service info served as a TXT record
TTL uint32 `json:"ttl"` // TTL of the service record
AddrIPv4 []net.IP `json:"-"` // Host machine IPv4 address
AddrIPv6 []net.IP `json:"-"` // Host machine IPv6 address
HostName string `json:"hostname"` // Host machine DNS name
Port int `json:"port"` // Service Port
Text []string `json:"text"` // Service info served as a TXT record
Expiry time.Time `json:"expiry"` // Expiry of the service entry, will be converted to a TTL value
AddrIPv4 []net.IP `json:"-"` // Host machine IPv4 address
AddrIPv6 []net.IP `json:"-"` // Host machine IPv6 address
}

// NewServiceEntry constructs a ServiceEntry.
Expand Down
34 changes: 34 additions & 0 deletions service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,4 +163,38 @@ func TestSubtype(t *testing.T) {
t.Fatalf("Expected port is %d, but got %d", mdnsPort, result.Port)
}
})

t.Run("ttl", func(t *testing.T) {
origTTL := defaultTTL
origCleanupFreq := cleanupFreq
defer func() {
defaultTTL = origTTL
cleanupFreq = origCleanupFreq
}()
defaultTTL = 2 // 2 seconds
cleanupFreq = 100 * time.Millisecond

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
go startMDNS(ctx, mdnsPort, mdnsName, mdnsSubtype, mdnsDomain)

entries := make(chan *ServiceEntry, 100)
resolver, err := NewResolver(nil)
if err != nil {
t.Fatalf("Expected create resolver success, but got %v", err)
}
if err := resolver.Browse(ctx, mdnsService, mdnsDomain, entries); err != nil {
t.Fatalf("Expected browse success, but got %v", err)
}

<-ctx.Done()
if len(entries) != 2 {
t.Fatalf("Expected to have received 2 entries, but got %d", len(entries))
}
res1 := <-entries
res2 := <-entries
if res1.ServiceInstanceName() != res2.ServiceInstanceName() {
t.Fatalf("expected the two entries to be identical")
}
})
}

0 comments on commit 1137789

Please sign in to comment.