Skip to content

Commit 82f8e71

Browse files
authored
Merge pull request #110 from dnstapir/use_dnstapir_module
Use shared Python module
2 parents 707e560 + 1c1e225 commit 82f8e71

15 files changed

+409
-1283
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ repos:
1111
- id: check-yaml
1212
- repo: https://github.com/astral-sh/ruff-pre-commit
1313
# Ruff version.
14-
rev: v0.6.5
14+
rev: v0.7.2
1515
hooks:
1616
# Run the linter.
1717
- id: ruff

aggrec/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
__version__ = version("aggrec")
44

55
try:
6-
from aggrec.buildinfo import __commit__, __timestamp__
6+
from .buildinfo import __commit__, __timestamp__
77

88
__verbose_version__ = f"{__version__} ({__commit__})"
99
except ModuleNotFoundError:

aggrec/aggregates.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -238,9 +238,7 @@ async def create_aggregate(
238238
span = trace.get_current_span()
239239

240240
with tracer.start_as_current_span("http_request_verifier"):
241-
http_request_verifier = RequestVerifier(
242-
client_database=request.app.settings.clients_database, key_cache=request.app.key_cache
243-
)
241+
http_request_verifier = RequestVerifier(key_resolver=request.app.key_resolver)
244242
res = await http_request_verifier.verify(request)
245243

246244
creator = res.parameters.get("keyid")

aggrec/helpers.py

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,8 @@
1414
)
1515
from http_message_signatures.algorithms import signature_algorithms as supported_signature_algorithms
1616
from http_message_signatures.exceptions import InvalidSignature
17-
from pydantic import AnyHttpUrl, DirectoryPath
1817

19-
from .key_cache import KeyCache
20-
from .key_resolver import FileKeyResolver, UrlKeyResolver
18+
from dnstapir.key_resolver import KeyResolver, PublicKey
2119

2220
DEFAULT_SIGNATURE_ALGORITHM = algorithms.ECDSA_P256_SHA256
2321
HASH_ALGORITHMS = {"sha-256": hashlib.sha256, "sha-512": hashlib.sha512}
@@ -39,25 +37,22 @@ class ContentDigestMissing(ContentDigestException):
3937
pass
4038

4139

40+
class CustomHTTPSignatureKeyResolver(HTTPSignatureKeyResolver):
41+
def __init__(self, key_resolver: KeyResolver):
42+
self.key_resolver = key_resolver
43+
44+
def resolve_public_key(self, key_id: str) -> PublicKey:
45+
return self.key_resolver.resolve_public_key(key_id=key_id)
46+
47+
4248
class RequestVerifier:
4349
def __init__(
4450
self,
51+
key_resolver: KeyResolver,
4552
algorithm: HTTPSignatureAlgorithm | None = None,
46-
key_resolver: HTTPSignatureKeyResolver | None = None,
47-
client_database: AnyHttpUrl | DirectoryPath | None = None,
48-
key_cache: KeyCache | None = None,
4953
):
5054
self.algorithm = algorithm or DEFAULT_SIGNATURE_ALGORITHM
51-
if key_resolver:
52-
self.key_resolver = key_resolver
53-
elif client_database and (
54-
str(client_database).startswith("http://") or str(client_database).startswith("https://")
55-
):
56-
self.key_resolver = UrlKeyResolver(client_database_base_url=str(client_database), key_cache=key_cache)
57-
elif client_database:
58-
self.key_resolver = FileKeyResolver(client_database_directory=str(client_database), key_cache=key_cache)
59-
else:
60-
raise ValueError("No key resolver nor client database specified")
55+
self.http_key_resolver = CustomHTTPSignatureKeyResolver(key_resolver)
6156
self.logger = logging.getLogger(__name__).getChild(self.__class__.__name__)
6257

6358
async def verify_content_digest(self, result: VerifyResult, request: Request):
@@ -87,7 +82,7 @@ async def verify(self, request: Request) -> VerifyResult:
8782
signature_algorithm = supported_signature_algorithms[alg]
8883
verifier = HTTPMessageVerifier(
8984
signature_algorithm=signature_algorithm,
90-
key_resolver=self.key_resolver,
85+
key_resolver=self.http_key_resolver,
9186
)
9287
try:
9388
results = verifier.verify(request)

aggrec/key_cache.py

Lines changed: 0 additions & 83 deletions
This file was deleted.

aggrec/key_resolver.py

Lines changed: 0 additions & 72 deletions
This file was deleted.

aggrec/logging.py

Lines changed: 0 additions & 11 deletions
This file was deleted.

aggrec/server.py

Lines changed: 16 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -5,55 +5,22 @@
55
import aiomqtt
66
import boto3
77
import mongoengine
8-
import redis
98
import uvicorn
109
from fastapi import FastAPI
1110
from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware
1211

1312
import aggrec.aggregates
1413
import aggrec.extras
14+
from dnstapir.key_cache import key_cache_from_settings
15+
from dnstapir.key_resolver import key_resolver_from_client_database
16+
from dnstapir.logging import configure_json_logging
17+
from dnstapir.opentelemetry import configure_opentelemetry
1518

1619
from . import OPENAPI_METADATA, __verbose_version__
17-
from .key_cache import CombinedKeyCache, KeyCache, MemoryKeyCache, RedisKeyCache
18-
from .logging import JsonFormatter # noqa
1920
from .settings import Settings
20-
from .telemetry import configure_opentelemetry
2121

2222
logger = logging.getLogger(__name__)
2323

24-
LOGGING_RECORD_CUSTOM_FORMAT = {
25-
"time": "asctime",
26-
# "Created": "created",
27-
# "RelativeCreated": "relativeCreated",
28-
"name": "name",
29-
# "Levelno": "levelno",
30-
"levelname": "levelname",
31-
"process": "process",
32-
"thread": "thread",
33-
# "threadName": "threadName",
34-
# "Pathname": "pathname",
35-
# "Filename": "filename",
36-
# "Module": "module",
37-
# "Lineno": "lineno",
38-
# "FuncName": "funcName",
39-
"message": "message",
40-
}
41-
42-
LOGGING_CONFIG_JSON = {
43-
"version": 1,
44-
"disable_existing_loggers": False,
45-
"formatters": {
46-
"json": {
47-
"class": "aggrec.logging.JsonFormatter",
48-
"format": LOGGING_RECORD_CUSTOM_FORMAT,
49-
},
50-
},
51-
"handlers": {
52-
"json": {"class": "logging.StreamHandler", "formatter": "json"},
53-
},
54-
"root": {"handlers": ["json"], "level": "DEBUG"},
55-
}
56-
5724

5825
class AggrecServer(FastAPI):
5926
def __init__(self, settings: Settings):
@@ -63,26 +30,18 @@ def __init__(self, settings: Settings):
6330
self.add_middleware(ProxyHeadersMiddleware)
6431
self.include_router(aggrec.aggregates.router)
6532
self.include_router(aggrec.extras.router)
66-
configure_opentelemetry(
67-
self,
68-
service_name="aggrec",
69-
spans_endpoint=str(settings.otlp.spans_endpoint),
70-
metrics_endpoint=str(settings.otlp.metrics_endpoint),
71-
insecure=settings.otlp.insecure,
33+
if self.settings.otlp:
34+
configure_opentelemetry(
35+
service_name="aggrec",
36+
settings=self.settings.otlp,
37+
fastapi_app=self,
38+
)
39+
else:
40+
self.logger.info("Configured without OpenTelemetry")
41+
key_cache = key_cache_from_settings(self.settings.key_cache) if self.settings.key_cache else None
42+
self.key_resolver = key_resolver_from_client_database(
43+
client_database=str(self.settings.clients_database), key_cache=key_cache
7244
)
73-
self.key_cache: KeyCache | None = None
74-
if self.settings.key_cache:
75-
memory_key_cache = MemoryKeyCache(size=self.settings.key_cache.size, ttl=self.settings.key_cache.ttl)
76-
if redis_settings := self.settings.key_cache.redis:
77-
redis_client = redis.StrictRedis(host=redis_settings.host, port=redis_settings.port)
78-
self.logger.debug("Using REDIS at %s:%d", redis_settings.host, redis_settings.port)
79-
redis_key_cache = RedisKeyCache(redis_client=redis_client, ttl=self.settings.key_cache.ttl)
80-
if self.settings.key_cache.size:
81-
self.key_cache = CombinedKeyCache([memory_key_cache, redis_key_cache])
82-
else:
83-
self.key_cache = redis_key_cache
84-
elif self.settings.key_cache.size:
85-
self.key_cache = memory_key_cache
8645

8746
@staticmethod
8847
def connect_mongodb(settings: Settings):
@@ -143,8 +102,7 @@ def main() -> None:
143102
print(f"Aggregate Receiver version {__verbose_version__}")
144103
return
145104

146-
logging_config = LOGGING_CONFIG_JSON
147-
logging.config.dictConfig(logging_config)
105+
logging_config = configure_json_logging()
148106

149107
if args.debug:
150108
logging.basicConfig(level=logging.DEBUG)

0 commit comments

Comments
 (0)