Skip to content

Commit dec43b2

Browse files
authored
Merge pull request #7716 from freedomofpress/7682-stricter-typing
feat(`api2`): validate and freeze dataclasses
2 parents 42cdb9e + 6ce810b commit dec43b2

File tree

2 files changed

+153
-46
lines changed

2 files changed

+153
-46
lines changed

securedrop/journalist_app/api2/types.py

Lines changed: 116 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,22 @@
22
from enum import IntEnum, StrEnum, auto
33
from typing import (
44
Any,
5+
Callable,
6+
Iterable,
57
List,
68
Mapping,
79
NewType,
810
Optional,
911
Set,
1012
Tuple,
1113
)
14+
from uuid import UUID
1215

1316
Record = NewType("Record", dict[str, Any])
1417
Version = NewType("Version", str)
1518

19+
VERSION_LEN = 64 # hex digits
20+
1621

1722
# NB. Ideally we'd have a generic UUID[T], but the semantics don't change
1823
# before mypy 1.12, which is incompatible with our use elsewhere of sqlmypy.
@@ -63,7 +68,7 @@ class Index:
6368
journalists: dict[JournalistUUID, Version] = field(default_factory=dict)
6469

6570

66-
@dataclass
71+
@dataclass(frozen=True)
6772
class Target:
6873
"""Base class for `<Resource>Target` dataclasses, to make their union usable
6974
at runtime. Subclass at least with:
@@ -74,63 +79,133 @@ class Target:
7479

7580
version: Version
7681

82+
def __post_init__(self) -> None:
83+
version = str(self.version)
7784

78-
@dataclass
85+
if len(version) != VERSION_LEN:
86+
raise ValueError(f"version must have {VERSION_LEN} hex digits")
87+
88+
try:
89+
int(version, 16)
90+
except ValueError:
91+
raise ValueError("version must be hex-encoded")
92+
93+
94+
@dataclass(frozen=True)
7995
class SourceTarget(Target):
8096
source_uuid: SourceUUID
8197

98+
def __post_init__(self) -> None:
99+
super().__post_init__()
100+
if not self.source_uuid:
101+
raise ValueError("source_uuid must be non-empty")
82102

83-
@dataclass
103+
try:
104+
UUID(str(self.source_uuid))
105+
except ValueError:
106+
raise ValueError(f"invalid source UUID: {self.source_uuid}")
107+
108+
109+
@dataclass(frozen=True)
84110
class ItemTarget(Target):
85111
item_uuid: ItemUUID
86112

113+
def __post_init__(self) -> None:
114+
super().__post_init__()
115+
if not self.item_uuid:
116+
raise ValueError("item_uuid must be non-empty")
87117

88-
@dataclass
118+
try:
119+
UUID(str(self.item_uuid))
120+
except ValueError:
121+
raise ValueError(f"invalid item UUID: {self.item_uuid}")
122+
123+
124+
@dataclass(frozen=True)
89125
class EventData:
90126
"""
91127
Base class for `<EventType>Data dataclasses, to make their union usable at runtime.
92128
For non-empty events, subclass and add to `EVENT_DATA_TYPES`.
93129
"""
94130

95131

96-
@dataclass
132+
@dataclass(frozen=True)
97133
class ReplySentData(EventData):
98134
uuid: ReplyUUID
99135
reply: str
100136

137+
def __post_init__(self) -> None:
138+
try:
139+
UUID(str(self.uuid))
140+
except ValueError:
141+
raise ValueError(f"invalid reply UUID: {self.uuid}")
142+
143+
if not self.reply:
144+
raise ValueError("reply must be a non-empty string")
145+
101146

102147
EVENT_DATA_TYPES = {EventType.REPLY_SENT: ReplySentData}
103148

104149

105-
@dataclass
150+
@dataclass(frozen=True)
106151
class Event:
107152
id: EventID
108-
target: Target | Mapping
153+
target: Target | Mapping[str, Any]
109154
type: EventType
110-
data: Optional[EventData | Mapping] = None
155+
data: Optional[EventData | Mapping[str, Any]] = None
111156

112157
def __post_init__(self) -> None:
113-
if not isinstance(self.type, EventType):
114-
self.type = EventType(self.type) # strict enum
158+
# ID must be usable as an int (for snowflake ordering; see section
159+
# "Snowflake IDs" in `API2.md`):
160+
if not str(self.id).isdigit():
161+
raise ValueError(f"event ID must be an integer string: {self.id}")
115162

116-
if not isinstance(self.target, Target):
117-
if "source_uuid" in self.target:
118-
self.target = SourceTarget(**self.target)
119-
elif "item_uuid" in self.target:
120-
self.target = ItemTarget(**self.target)
163+
# Normalize type:
164+
if not isinstance(self.type, EventType):
165+
object.__setattr__(self, "type", EventType(self.type))
166+
167+
# Normalize target:
168+
target = self.target
169+
if not isinstance(target, Target):
170+
if not isinstance(target, Mapping):
171+
raise TypeError(f"invalid event target: {target!r}")
172+
173+
if "source_uuid" in target:
174+
target = SourceTarget(**target)
175+
elif "item_uuid" in target:
176+
target = ItemTarget(**target)
121177
else:
122-
raise TypeError(f"invalid event target: {self.target}")
178+
raise TypeError(f"invalid event target: {target}")
179+
180+
object.__setattr__(self, "target", target)
123181

124-
if not isinstance(self.data, EventData) and self.data and self.type in EVENT_DATA_TYPES:
182+
# Normalize data:
183+
data = self.data
184+
if data is None:
185+
return
186+
187+
# If it's already a `EventData` dataclass, validate it:
188+
if isinstance(data, EventData):
189+
expected = EVENT_DATA_TYPES.get(self.type)
190+
if expected is not None and not isinstance(data, expected):
191+
raise TypeError(f"invalid event data for type {self.type}")
192+
return
193+
194+
# If it's a mapping for an event type that expects data, instantiate an
195+
# `EventType` dataclass:
196+
if isinstance(data, Mapping) and self.type in EVENT_DATA_TYPES:
125197
try:
126-
self.data = EVENT_DATA_TYPES[self.type](**self.data)
198+
data_obj = EVENT_DATA_TYPES[self.type](**data)
127199
except TypeError:
128200
raise TypeError(f"invalid event data for type {self.type}")
201+
object.__setattr__(self, "data", data_obj)
202+
203+
# Otherwise, discard it.
129204
else:
130-
self.data = None
205+
object.__setattr__(self, "data", None)
131206

132207

133-
@dataclass
208+
@dataclass(frozen=True)
134209
class EventResult:
135210
event_id: EventID
136211
status: EventStatus
@@ -140,7 +215,7 @@ class EventResult:
140215
items: dict[ItemUUID, Optional[Record]] = field(default_factory=dict)
141216

142217

143-
@dataclass
218+
@dataclass(frozen=True)
144219
class BatchRequest:
145220
# Source metadata:
146221
sources: Set[SourceUUID] = field(default_factory=set)
@@ -150,7 +225,26 @@ class BatchRequest:
150225
journalists: Set[JournalistUUID] = field(default_factory=set)
151226

152227
# Events submitted by the client:
153-
events: List[Event] = field(default_factory=list)
228+
events: List[Event | Mapping[str, Any]] = field(default_factory=list)
229+
230+
def __post_init__(self) -> None:
231+
def _normalize_uuids(raw: Iterable[Any], wrap: Callable) -> set:
232+
try:
233+
return {wrap(x) for x in raw}
234+
except TypeError:
235+
raise TypeError("expected an iterable")
236+
237+
object.__setattr__(self, "sources", _normalize_uuids(self.sources, SourceUUID))
238+
object.__setattr__(self, "items", _normalize_uuids(self.items, ItemUUID))
239+
object.__setattr__(self, "journalists", _normalize_uuids(self.journalists, JournalistUUID))
240+
241+
normalized_events: list[Event | Mapping[str, Any]] = []
242+
for e in self.events:
243+
if isinstance(e, (Event, Mapping)):
244+
normalized_events.append(e)
245+
else:
246+
raise TypeError("BatchRequest.events must contain Event or Mapping instances")
247+
object.__setattr__(self, "events", normalized_events)
154248

155249

156250
@dataclass

securedrop/tests/test_journalist_api2.py

Lines changed: 37 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from flask_sqlalchemy import get_debug_queries
1313
from journalist_app import api2, create_app
1414
from journalist_app.api2.shared import json_version
15-
from journalist_app.api2.types import Event, EventType, ItemTarget, SourceTarget
15+
from journalist_app.api2.types import VERSION_LEN, Event, EventType, ItemTarget, SourceTarget
1616
from models import Reply, Source, SourceStar, Submission, db
1717
from sqlalchemy.orm.exc import MultipleResultsFound
1818
from tests.factories import SecureDropConfigFactory
@@ -550,15 +550,18 @@ def test_api2_item_deleted(
550550
assert Reply.query.filter(Reply.uuid == event.target.item_uuid).one_or_none() is None
551551

552552
# Try to delete something that doesn't exist:
553-
event.id = "345678"
554-
event.target.item_uuid = "does not exist"
553+
nonexistent_event = Event(
554+
id="345678",
555+
target=ItemTarget(item_uuid=str(uuid.uuid4()), version=reply_version),
556+
type=EventType.ITEM_DELETED,
557+
)
555558
response = app.post(
556559
url_for("api2.data"),
557-
json={"events": [asdict(event)]},
560+
json={"events": [asdict(nonexistent_event)]},
558561
headers=get_api_headers(journalist_api_token),
559562
)
560-
assert response.json["events"][event.id] == [410, None]
561-
assert event.target.item_uuid not in response.json["items"]
563+
assert response.json["events"][nonexistent_event.id] == [410, None]
564+
assert nonexistent_event.target.item_uuid not in response.json["items"]
562565

563566

564567
def test_api2_source_deleted(
@@ -574,7 +577,7 @@ def test_api2_source_deleted(
574577
# Try deleting the source with the wrong version
575578
event = Event(
576579
id="394758",
577-
target=SourceTarget(source_uuid=source_uuid, version="wrong-version"),
580+
target=SourceTarget(source_uuid=source_uuid, version="a" * VERSION_LEN),
578581
type=EventType.SOURCE_DELETED,
579582
)
580583
response = app.post(
@@ -624,14 +627,17 @@ def test_api2_source_deleted(
624627
assert Source.query.filter(Source.uuid == source_uuid).one_or_none() is None
625628

626629
# Try to delete a source that doesn't exist
627-
event.id = "234567"
628-
event.target.source_uuid = "does-not-exist"
630+
nonexistent_event = Event(
631+
id="234567",
632+
target=SourceTarget(source_uuid=str(uuid.uuid4()), version=source_version),
633+
type=EventType.SOURCE_DELETED,
634+
)
629635
response = app.post(
630636
url_for("api2.data"),
631-
json={"events": [asdict(event)]},
637+
json={"events": [asdict(nonexistent_event)]},
632638
headers=get_api_headers(journalist_api_token),
633639
)
634-
assert response.json["events"][event.id] == [410, None]
640+
assert response.json["events"][nonexistent_event.id] == [410, None]
635641
assert "does-not-exist" not in response.json["sources"]
636642

637643

@@ -653,7 +659,7 @@ def test_api2_source_conversation_deleted(
653659
# (intentionally not fetching the correct version)
654660
event = Event(
655661
id="498567",
656-
target=SourceTarget(source_uuid=source_uuid, version="wrong-version"),
662+
target=SourceTarget(source_uuid=source_uuid, version="a" * VERSION_LEN),
657663
type=EventType.SOURCE_CONVERSATION_DELETED,
658664
)
659665
response = app.post(
@@ -842,16 +848,19 @@ def test_api2_item_seen(
842848
updated_submission = Submission.query.filter(Submission.uuid == submission_uuid).one()
843849
assert updated_submission.downloaded is True
844850

845-
# Try with invalid item UUID
846-
event.id = "234567"
847-
event.target.item_uuid = "invalid-uuid"
851+
# Try to mark seen an item that doesn't exist
852+
no_such_item_event = Event(
853+
id="234567",
854+
target=ItemTarget(item_uuid=str(uuid.uuid4()), version=item_version),
855+
type=EventType.ITEM_SEEN,
856+
)
848857
response = app.post(
849858
url_for("api2.data"),
850-
json={"events": [asdict(event)]},
859+
json={"events": [asdict(no_such_item_event)]},
851860
headers=get_api_headers(journalist_api_token),
852861
)
853-
assert response.json["events"][event.id][0] == 404
854-
assert "could not find item" in response.json["events"][event.id][1]
862+
assert response.json["events"][no_such_item_event.id][0] == 404
863+
assert "could not find item" in response.json["events"][no_such_item_event.id][1]
855864

856865

857866
def test_api2_idempotence_period(journalist_app):
@@ -924,7 +933,7 @@ def test_api2_source_conversation_deleted_resubmission(
924933
# 1. Submit with the wrong version --> Conflict (409).
925934
event = Event(
926935
id="600100",
927-
target=SourceTarget(source_uuid=source_uuid, version="wrong-version"),
936+
target=SourceTarget(source_uuid=source_uuid, version="a" * VERSION_LEN),
928937
type=EventType.SOURCE_CONVERSATION_DELETED,
929938
)
930939
res1 = app.post(
@@ -955,14 +964,18 @@ def test_api2_source_conversation_deleted_resubmission(
955964
assert index.status_code == 200
956965
correct_version = index.json["sources"][source_uuid]
957966

958-
event.target = SourceTarget(source_uuid=source_uuid, version=correct_version)
967+
corrected_event = Event(
968+
id=event.id,
969+
target=SourceTarget(source_uuid=source_uuid, version=correct_version),
970+
type=event.type,
971+
)
959972
res2 = app.post(
960973
url_for("api2.data"),
961-
json={"events": [asdict(event)]},
974+
json={"events": [asdict(corrected_event)]},
962975
headers=get_api_headers(journalist_api_token),
963976
)
964977
assert res2.status_code == 200
965-
assert res2.json["events"][event.id] == [200, None]
978+
assert res2.json["events"][corrected_event.id] == [200, None]
966979

967980
# Confirm that items are returned as deleted.
968981
assert res2.json["sources"][source_uuid] is not None
@@ -980,11 +993,11 @@ def test_api2_source_conversation_deleted_resubmission(
980993
# 3. Resubmit the same event again --> Already Reported (208).
981994
res3 = app.post(
982995
url_for("api2.data"),
983-
json={"events": [asdict(event)]},
996+
json={"events": [asdict(corrected_event)]},
984997
headers=get_api_headers(journalist_api_token),
985998
)
986999
assert res3.status_code == 200
987-
assert res3.json["events"][event.id][0] == 208
1000+
assert res3.json["events"][corrected_event.id][0] == 208
9881001

9891002

9901003
def test_api2_reply_sent_then_requested_item_is_deduped(

0 commit comments

Comments
 (0)