-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrfc2308.py
153 lines (127 loc) · 4.47 KB
/
rfc2308.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""
Simulate RFC 2308 cache, resolver, and auth.
Cache exact match + negative caching.
"""
from pprint import pformat
import time
import dns.resolver
import dns.zone
class Cache(object):
    """
    Toy DNS cache keyed by exact owner name (no closest-encloser logic).

    Each value in ``self.storage`` is either:
      - an ``int``: absolute expiry time of a cached NXDOMAIN covering the
        whole name (RFC 2308 negative caching), or
      - a ``dict`` mapping RR type -> absolute expiry time.

    Time is simulated: ``self.now`` only moves when ``set_reltime`` is called.
    ``hit`` / ``miss`` count the outcomes of ``get_rrtype``.
    """

    def __init__(self):
        self.storage = {}
        self.hit = 0
        self.miss = 0
        # Origin of the simulated clock; expiry times are absolute on it.
        self.start = int(time.monotonic())
        self.now = self.start

    def __str__(self):
        return pformat({'hit': self.hit, 'miss': self.miss})

    def set_reltime(self, reltime):
        """
        Move the simulated clock to ``reltime`` seconds after ``self.start``.
        """
        self.now = self.start + reltime

    def put_name(self, name, ttl):
        """
        Cache NXDOMAIN for the whole ``name`` for ``ttl`` seconds.

        Any per-RR-type entries previously cached for ``name`` are replaced.
        """
        self.storage[name] = self.now + ttl

    def get_name(self, name):
        """
        Look up everything cached for ``name``.

        Returns:
            the dict of RR type -> expiry time if ``name`` has per-type entries.
        Raises:
            KeyError: if ``name`` is not in cache, or its cached NXDOMAIN expired.
            dns.resolver.NXDOMAIN: if a still-valid NXDOMAIN is cached for ``name``.
        """
        node = self.storage[name]
        if isinstance(node, int):
            # NXDOMAIN marker; valid up to and including its expiry time.
            expires = node
            if expires < self.now:
                raise KeyError('expired')
            raise dns.resolver.NXDOMAIN()
        return node

    def put_rrtype(self, name, rrtype, ttl):
        """
        Cache information for one RR type of ``name`` for ``ttl`` seconds.

        An expired NXDOMAIN marker for ``name`` is discarded first, so a name
        can be cached positively again once its negative entry has lapsed.

        Raises:
            ValueError: if a still-valid NXDOMAIN is cached for ``name``;
                a per-type entry would contradict it.
        """
        node = self.storage.get(name)
        if isinstance(node, int):
            if node >= self.now:
                # Explicit raise (not assert) so the check survives ``python -O``.
                raise ValueError('unsupported put operation')
            # Stale NXDOMAIN marker: drop it and start a fresh per-type dict.
            del self.storage[name]
        self.storage.setdefault(name, {})[rrtype] = self.now + ttl

    def get_rrtype(self, name, rrtype):
        """
        Look up ``rrtype`` at ``name`` and record a cache hit or miss.

        A still-valid cached NXDOMAIN for ``name`` counts as a hit for any
        ``rrtype``. Returns None on a hit.

        Raises:
            KeyError: on a miss (name / RR type not cached, or entry expired);
                ``self.miss`` is incremented before re-raising.
        """
        try:
            node = self.get_name(name)
            expires = node[rrtype]  # KeyError if this RR type is not cached
            if expires < self.now:
                raise KeyError('expired')
        except KeyError:  # not in cache
            self.miss += 1
            raise
        except dns.resolver.NXDOMAIN:  # cached NXDOMAIN is a valid answer
            pass
        self.hit += 1
class Resolver(object):
    """
    Caching resolver: answers from a ``Cache`` and, on a miss, queries
    ``auth`` and stores the answer.
    """

    def __init__(self, auth):
        # ``auth`` must provide ``query(name, rrtype) -> (rcode, answers)``,
        # as Authoritative.query in this file does.
        self.cache = Cache()
        self.auth = auth

    def set_reltime(self, reltime):
        """Move the simulated clock of the underlying cache."""
        self.cache.set_reltime(reltime)

    def lookup(self, name, rrtype):
        """
        Resolve (name, rrtype), consulting the cache first.

        On a hit, returns whatever ``self.cache.get_rrtype`` returns. On a
        miss, queries ``self.auth``, caches the answer and returns None.

        Raises:
            ValueError: if the authoritative answer has an rcode other than
                NOERROR or NXDOMAIN.
        """
        try:
            return self.cache.get_rrtype(name, rrtype)
        except KeyError:  # cache miss
            rcode, answers = self.auth.query(name, rrtype)
        if rcode == dns.rcode.NOERROR:
            self._store_noerror(answers)
        elif rcode == dns.rcode.NXDOMAIN:
            self._store_nxdomain(answers)
        else:
            raise ValueError('unexpected rcode %r' % (rcode,))

    def _store_noerror(self, answers):
        """
        Cache every (name, rrtype) of a NOERROR answer with its TTL.

        ``answers`` maps (name, rrtype) -> {"ttl": ttl}; NODATA answers use
        the same shape (with the negative TTL), so they are cached per type.
        """
        for (name, rrtype), data in answers.items():
            self.cache.put_rrtype(name, rrtype, data["ttl"])

    def _store_nxdomain(self, answers):
        """
        Cache an NXDOMAIN answer for the whole name.

        ``answers`` must hold exactly one entry keyed by (qname, ANY) whose
        "ttl" is the negative TTL.

        Raises:
            ValueError: if ``answers`` does not have that shape.
        """
        if len(answers) != 1:
            raise ValueError('NXDOMAIN answer must have exactly one entry')
        # Unpack the single entry directly instead of relying on loop
        # variables leaking out of a ``for ...: pass`` loop.
        [((name, rrtype), data)] = answers.items()
        if rrtype != dns.rdatatype.ANY:
            raise ValueError('NXDOMAIN answer must be keyed by rrtype ANY')
        self.cache.put_name(name, data["ttl"])
class Authoritative(object):
    """
    Authoritative server for one zone loaded from a master file.

    Answers carry only what the cache simulation needs: an rcode plus a
    {(owner, rrtype): {"ttl": ttl}} mapping, with no record data.
    """

    def __init__(self, rootdb):
        """
        Load the zone in ``rootdb`` with origin ``.`` and absolute names, and
        derive the negative-caching TTL from the apex SOA.
        """
        self.queries = 0  # number of query() calls served
        self.rootzone = dns.zone.from_file(rootdb, origin=dns.name.root, relativize=False)
        rootnode = self.rootzone[dns.name.root]
        soa_rrs = rootnode.find_rdataset(dns.rdataclass.IN, dns.rdatatype.SOA)
        soa = soa_rrs[0]
        # Negative TTL = min(SOA TTL, SOA MINIMUM field):
        # https://tools.ietf.org/html/rfc2308#section-5
        self.neg_ttl = min(soa_rrs.ttl, soa.minimum)

    def _gen_nxdomain(self, name):
        """
        Build an NXDOMAIN answer.

        The single key (qname, type ANY) is used to encode the negative TTL
        for the resolver.
        """
        answer = {(name, dns.rdatatype.ANY): {"ttl": self.neg_ttl}}
        return (dns.rcode.NXDOMAIN, answer)

    def _gen_nodata(self, name, rrtype):
        """
        Build a NODATA answer: NOERROR for (name, rrtype) with the negative TTL.
        """
        return self._gen_noerror(name, rrtype, self.neg_ttl)

    def _gen_noerror(self, name, rrtype, ttl):
        """Build a NOERROR answer for (name, rrtype) carrying ``ttl``."""
        answers = {(name, rrtype): {"ttl": ttl}}
        return (dns.rcode.NOERROR, answers)

    def query(self, name, rrtype):
        """
        Answer a query for (name, rrtype) from the zone.

        ``name`` must be an absolute dns.name.Name.

        Returns:
            (rcode, {(owner, rrtype): {"ttl": ttl}}), one entry, where:
            - NOERROR with the RRset's TTL if the RRset exists;
            - NOERROR with the negative TTL if the name exists but has no
              RRset of ``rrtype`` (NODATA);
            - NXDOMAIN keyed by (name, rdatatype.ANY) with the negative TTL
              if the zone has no node for ``name``.
        Raises:
            ValueError: if ``name`` is not absolute.
        """
        # Explicit raise (not assert): under ``python -O`` a relative name
        # would otherwise be looked up derelativized but answered/cached
        # under its relative form.
        if not name.is_absolute():
            raise ValueError('query name must be absolute')
        self.queries += 1
        try:
            node = self.rootzone[name]
        except KeyError:  # NXDOMAIN
            # NOTE(review): names existing only as ancestors of other owners
            # (empty non-terminals) presumably have no node and so get
            # NXDOMAIN here, whereas RFC 8020 expects NODATA — confirm this
            # simplification is intended.
            return self._gen_nxdomain(name)
        try:
            rrs = node.find_rdataset(dns.rdataclass.IN, rrtype)
        except KeyError:  # NODATA
            return self._gen_nodata(name, rrtype)
        return self._gen_noerror(name, rrtype, rrs.ttl)