Skip to content

Commit

Permalink
delete entries from the cache when the TTL expires
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Jun 8, 2021
1 parent 2c53f0f commit 5877433
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 17 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ See what needs to be done and submit a pull request :)
* [x] Browse / Lookup / Register services
* [x] Multiple IPv6 / IPv4 addresses support
* [x] Send multiple probes (exp. back-off) if no service answers (*)
* [ ] Timestamp entries for TTL checks
* [x] Timestamp entries for TTL checks
* [ ] Compare new multicasts with already received services

_Notes:_
Expand Down
53 changes: 44 additions & 9 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"net"
"strings"
"sync"
"time"

"github.com/cenkalti/backoff"
Expand Down Expand Up @@ -143,6 +144,9 @@ type client struct {
ipv4conn *ipv4.PacketConn
ipv6conn *ipv6.PacketConn
ifaces []net.Interface

mutex sync.Mutex
sentEntries map[string]*ServiceEntry
}

// Client structure constructor
Expand Down Expand Up @@ -177,6 +181,28 @@ func newClient(opts clientOpts) (*client, error) {
}, nil
}

var cleanupFreq = 10 * time.Second

// clean up entries whose TTL expired
func (c *client) cleanupSentEntries(ctx context.Context) {
ticker := time.NewTicker(cleanupFreq)
defer ticker.Stop()
for {
select {
case t := <-ticker.C:
c.mutex.Lock()
for k, e := range c.sentEntries {
if t.After(e.Expiry) {
delete(c.sentEntries, k)
}
}
c.mutex.Unlock()
case <-ctx.Done():
return
}
}
}

// Start listeners and waits for the shutdown signal from exit channel
func (c *client) mainloop(ctx context.Context, params *lookupParams) {
// start listening for responses
Expand All @@ -189,16 +215,20 @@ func (c *client) mainloop(ctx context.Context, params *lookupParams) {
}

// Iterate through channels from listeners goroutines
var entries, sentEntries map[string]*ServiceEntry
sentEntries = make(map[string]*ServiceEntry)
var entries map[string]*ServiceEntry
c.sentEntries = make(map[string]*ServiceEntry)
go c.cleanupSentEntries(ctx)

for {
var now time.Time
select {
case <-ctx.Done():
// Context expired. Notify subscriber that we are done here.
params.done()
c.shutdown()
return
case msg := <-msgCh:
now = time.Now()
entries = make(map[string]*ServiceEntry)
sections := append(msg.Answer, msg.Ns...)
sections = append(sections, msg.Extra...)
Expand All @@ -218,7 +248,7 @@ func (c *client) mainloop(ctx context.Context, params *lookupParams) {
params.Service,
params.Domain)
}
entries[rr.Ptr].TTL = rr.Hdr.Ttl
entries[rr.Ptr].Expiry = now.Add(time.Duration(rr.Hdr.Ttl) * time.Second)
case *dns.SRV:
if params.ServiceInstanceName() != "" && params.ServiceInstanceName() != rr.Hdr.Name {
continue
Expand All @@ -233,7 +263,7 @@ func (c *client) mainloop(ctx context.Context, params *lookupParams) {
}
entries[rr.Hdr.Name].HostName = rr.Target
entries[rr.Hdr.Name].Port = int(rr.Port)
entries[rr.Hdr.Name].TTL = rr.Hdr.Ttl
entries[rr.Hdr.Name].Expiry = now.Add(time.Duration(rr.Hdr.Ttl) * time.Second)
case *dns.TXT:
if params.ServiceInstanceName() != "" && params.ServiceInstanceName() != rr.Hdr.Name {
continue
Expand All @@ -247,7 +277,7 @@ func (c *client) mainloop(ctx context.Context, params *lookupParams) {
params.Domain)
}
entries[rr.Hdr.Name].Text = rr.Txt
entries[rr.Hdr.Name].TTL = rr.Hdr.Ttl
entries[rr.Hdr.Name].Expiry = now.Add(time.Duration(rr.Hdr.Ttl) * time.Second)
}
}
// Associate IPs in a second round as other fields should be filled by now.
Expand All @@ -271,12 +301,15 @@ func (c *client) mainloop(ctx context.Context, params *lookupParams) {

if len(entries) > 0 {
for k, e := range entries {
if e.TTL == 0 {
c.mutex.Lock()
if !e.Expiry.After(now) {
delete(entries, k)
delete(sentEntries, k)
delete(c.sentEntries, k)
c.mutex.Unlock()
continue
}
if _, ok := sentEntries[k]; ok {
if _, ok := c.sentEntries[k]; ok {
c.mutex.Unlock()
continue
}

Expand All @@ -286,14 +319,16 @@ func (c *client) mainloop(ctx context.Context, params *lookupParams) {
// Require at least one resolved IP address for ServiceEntry
// TODO: wait some more time as chances are high both will arrive.
if len(e.AddrIPv4) == 0 && len(e.AddrIPv6) == 0 {
c.mutex.Unlock()
continue
}
}
// Submit entry to subscriber and cache it.
// This is also a point to possibly stop probing actively for a
// service entry.
c.sentEntries[k] = e
c.mutex.Unlock()
params.Entries <- e
sentEntries[k] = e
if !params.isBrowsing {
params.disableProbing()
}
Expand Down
4 changes: 3 additions & 1 deletion server.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ const (
multicastRepetitions = 2
)

var defaultTTL uint32 = 3200

// Register a service by given arguments. This call will take the system's hostname
// and lookup IP by that hostname.
func Register(instance, service, domain string, port int, text []string, ifaces []net.Interface) (*Server, error) {
Expand Down Expand Up @@ -173,7 +175,7 @@ func newServer(ifaces []net.Interface) (*Server, error) {
ipv4conn: ipv4conn,
ipv6conn: ipv6conn,
ifaces: ifaces,
ttl: 3200,
ttl: defaultTTL,
shouldShutdown: make(chan struct{}),
}

Expand Down
13 changes: 7 additions & 6 deletions service.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"net"
"sync"
"time"
)

// ServiceRecord contains the basic description of a service, which contains instance name, service type & domain
Expand Down Expand Up @@ -103,12 +104,12 @@ func (l *lookupParams) disableProbing() {
// used to answer multicast queries.
type ServiceEntry struct {
ServiceRecord
HostName string `json:"hostname"` // Host machine DNS name
Port int `json:"port"` // Service Port
Text []string `json:"text"` // Service info served as a TXT record
TTL uint32 `json:"ttl"` // TTL of the service record
AddrIPv4 []net.IP `json:"-"` // Host machine IPv4 address
AddrIPv6 []net.IP `json:"-"` // Host machine IPv6 address
HostName string `json:"hostname"` // Host machine DNS name
Port int `json:"port"` // Service Port
Text []string `json:"text"` // Service info served as a TXT record
Expiry time.Time `json:"expiry"` // Expiry of the service entry, will be converted to a TTL value
AddrIPv4 []net.IP `json:"-"` // Host machine IPv4 address
AddrIPv6 []net.IP `json:"-"` // Host machine IPv6 address
}

// NewServiceEntry constructs a ServiceEntry.
Expand Down
44 changes: 44 additions & 0 deletions service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,4 +184,48 @@ func TestSubtype(t *testing.T) {
t.Fatalf("Expected port is %d, but got %d", mdnsPort, expectedResult[0].Port)
}
})

t.Run("ttl", func(t *testing.T) {
origTTL := defaultTTL
origCleanupFreq := cleanupFreq
defer func() {
defaultTTL = origTTL
cleanupFreq = origCleanupFreq
}()
defaultTTL = 2 // 2 seconds
cleanupFreq = 100 * time.Millisecond

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
go startMDNS(ctx, mdnsPort, mdnsName, mdnsSubtype, mdnsDomain)

entries := make(chan *ServiceEntry)
var expectedResult []*ServiceEntry
go func() {
for {
select {
case s := <-entries:
expectedResult = append(expectedResult, s)
case <-ctx.Done():
return
}
}
}()

resolver, err := NewResolver(nil)
if err != nil {
t.Fatalf("Expected create resolver success, but got %v", err)
}
if err := resolver.Browse(ctx, mdnsService, mdnsDomain, entries); err != nil {
t.Fatalf("Expected browse success, but got %v", err)
}

<-ctx.Done()
if len(expectedResult) != 2 {
t.Fatalf("Expected to have received 2 entries, but got %d", len(expectedResult))
}
if expectedResult[0].ServiceInstanceName() != expectedResult[1].ServiceInstanceName() {
t.Fatalf("expected the two entries to be identical")
}
})
}

0 comments on commit 5877433

Please sign in to comment.