Skip to content

Commit ed87c2a

Browse files
MateoLostanlenfe51
andauthored
Split detections (#539)
* compute cone during seq creation * round values * drop seq with cones * create alerts * add overlap from triangulation pr * new alerts strat * missing init * missings deps * use preset variable * updates routes * update poetry * update loc * error management * fix on seq case * use started_at and last_seens_at * clean output * missing READ * add test * fix style * test overlap * mypy * ruff on test overlap * import issue * fix style * fix deletions to respect fk * cast * recompyte alerts after seq annotation * fix alert update * style * style * adapt test * ruff on test * add tests on detections * increase test on seq * ruff * limit lat and lon * rename fonction * rename to sequence_azimuth * rename sequence camera azimuth * add AlertBase * adapt e2e * new headers * header fix * header fix * add triangulation test * drop dupicate * imprrove test for label_sequence * style * add seq delete test * fix(api): resolve merge conflict and fix router * style * test(api): add test coverage for organization deletion of alerts and alert sequences * put back old delete * style org * style test * fix test org * reduce complexity * style * use pose_id * style * adapt test * split seq based on bbox * adapt test * prevent UnboundLocalError * fix mypy client * complete client * style * ruff * rename camera_azimuth * fix test * client/ * fix test * reorder router * srtyle * mypy * fix merge error in unitest * ruff fix * check to use the pose azimuth * split bboxes * allow cam to get pose * allow camera to manage poses * add split detections test * increase test on client * increase coverage on create det * fix codacy warning adding if * increase coverage * simplify * bboxes to bbox in db field and endpoints read * added correct constraints on bbox and other_bboxes length --------- Co-authored-by: fe51 <[email protected]>
1 parent 345c9b2 commit ed87c2a

File tree

18 files changed

+1215
-176
lines changed

18 files changed

+1215
-176
lines changed

client/pyroclient/client.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
55

66
from enum import Enum
7-
from typing import Dict, List, Optional, Tuple
7+
from typing import Dict, List, Tuple
88
from urllib.parse import urljoin
99

1010
import requests
@@ -22,6 +22,7 @@ class ClientRoute(str, Enum):
2222
CAMERAS_HEARTBEAT = "cameras/heartbeat"
2323
CAMERAS_IMAGE = "cameras/image"
2424
CAMERAS_FETCH = "cameras/"
25+
CAMERAS_BY_ID = "cameras/{camera_id}"
2526
# POSES
2627
POSES_CREATE = "poses/"
2728
POSES_BY_ID = "poses/{pose_id}"
@@ -180,17 +181,36 @@ def create_pose(
180181
timeout=self.timeout,
181182
)
182183

183-
def patch_pose(
184+
def get_current_poses(self, camera_id: int | None = None) -> Response:
185+
"""Fetch poses for the authenticated camera.
186+
187+
For admin/agent tokens, provide camera_id to retrieve poses via the camera endpoint.
188+
"""
189+
if camera_id is not None:
190+
return requests.get(
191+
urljoin(self._route_prefix, ClientRoute.CAMERAS_BY_ID.format(camera_id=camera_id)),
192+
headers=self.headers,
193+
timeout=self.timeout,
194+
)
195+
return requests.get(
196+
urljoin(self._route_prefix, ClientRoute.POSES_CREATE),
197+
headers=self.headers,
198+
timeout=self.timeout,
199+
)
200+
201+
def update_pose(
184202
self,
185203
pose_id: int,
186204
azimuth: float | None = None,
187205
patrol_id: int | None = None,
188206
) -> Response:
189207
"""Update a pose
190208
191-
>>> api_client.patch_pose(pose_id=1, azimuth=90.0)
209+
>>> api_client.update_pose(pose_id=1, azimuth=90.0)
192210
"""
193-
payload = {}
211+
if azimuth is None and patrol_id is None:
212+
raise ValueError("Either azimuth or patrol_id must be provided")
213+
payload: Dict[str, float | int] = {}
194214
if azimuth is not None:
195215
payload["azimuth"] = azimuth
196216
if patrol_id is not None:
@@ -203,6 +223,9 @@ def patch_pose(
203223
timeout=self.timeout,
204224
)
205225

226+
# Backward compatibility alias
227+
patch_pose = update_pose
228+
206229
def delete_pose(self, pose_id: int) -> Response:
207230
"""Delete a pose
208231
@@ -296,34 +319,30 @@ def delete_occlusion_mask(self, mask_id: int) -> Response:
296319
def create_detection(
297320
self,
298321
media: bytes,
299-
azimuth: float,
300322
bboxes: List[Tuple[float, float, float, float, float]],
301-
pose_id: Optional[int] = None,
323+
pose_id: int,
302324
) -> Response:
303325
"""Notify the detection of a wildfire on the picture taken by a camera.
304326
305327
>>> from pyroclient import Client
306328
>>> api_client = Client("MY_CAM_TOKEN")
307329
>>> with open("path/to/my/file.ext", "rb") as f: data = f.read()
308-
>>> response = api_client.create_detection(data, azimuth=124.2, bboxes=[(.1,.1,.5,.8,.5)], pose_id=12)
330+
>>> response = api_client.create_detection(data, bboxes=[(.1,.1,.5,.8,.5)], pose_id=12)
309331
310332
Args:
311333
media: byte data of the picture
312-
azimuth: the azimuth of the camera when the picture was taken
313334
bboxes: list of tuples where each tuple is a relative coordinate in order xmin, ymin, xmax, ymax, conf
314-
pose_id: optional, pose_id of the detection
335+
pose_id: pose_id of the detection
315336
316337
Returns:
317338
HTTP response
318339
"""
319340
if not isinstance(bboxes, (list, tuple)) or len(bboxes) == 0 or len(bboxes) > 5:
320341
raise ValueError("bboxes must be a non-empty list of tuples with a maximum of 5 boxes")
321-
data = {
322-
"azimuth": azimuth,
342+
data: Dict[str, str] = {
323343
"bboxes": _dump_bbox_to_json(bboxes),
324344
}
325-
if pose_id is not None:
326-
data["pose_id"] = pose_id
345+
data["pose_id"] = str(pose_id)
327346
return requests.post(
328347
urljoin(self._route_prefix, ClientRoute.DETECTIONS_CREATE),
329348
headers=self.headers,

client/tests/conftest.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
from operator import itemgetter
23
from urllib.parse import urljoin
34

45
import pytest
@@ -26,7 +27,7 @@ def mock_img():
2627

2728

2829
@pytest.fixture(scope="session")
29-
def cam_token():
30+
def cam_setup():
3031
admin_headers = {"Authorization": f"Bearer {SUPERADMIN_TOKEN}"}
3132
payload = {
3233
"name": "pyro-camera-01",
@@ -44,11 +45,27 @@ def cam_token():
4445
payload = {"azimuth": 359, "patrol_id": 1, "camera_id": cam_id}
4546
response = requests.post(urljoin(API_URL, "poses"), json=payload, headers=admin_headers, timeout=5)
4647
assert response.status_code == 201
48+
pose_id = response.json()["id"]
4749

48-
# Create a cam token
49-
return requests.post(urljoin(API_URL, f"cameras/{cam_id}/token"), headers=admin_headers, timeout=5).json()[
50+
cam_token = requests.post(urljoin(API_URL, f"cameras/{cam_id}/token"), headers=admin_headers, timeout=5).json()[
5051
"access_token"
5152
]
53+
return {"token": cam_token, "pose_id": pose_id, "camera_id": cam_id}
54+
55+
56+
@pytest.fixture(scope="session")
57+
def cam_token(cam_setup):
58+
return itemgetter("token")(cam_setup)
59+
60+
61+
@pytest.fixture(scope="session")
62+
def cam_pose_id(cam_setup):
63+
return itemgetter("pose_id")(cam_setup)
64+
65+
66+
@pytest.fixture(scope="session")
67+
def cam_id(cam_setup):
68+
return itemgetter("camera_id")(cam_setup)
5269

5370

5471
@pytest.fixture(scope="session")

client/tests/test_client.py

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,35 @@ def test_client_constructor(token, host, timeout, expected_error):
2525
Client(token, host, timeout=timeout)
2626

2727

28+
def test_get_current_poses_camera(cam_token, cam_pose_id):
29+
cam_client = Client(cam_token, "http://localhost:5050", timeout=10)
30+
response = cam_client.get_current_poses()
31+
assert response.status_code == 200, response.__dict__
32+
poses = response.json()
33+
assert isinstance(poses, list)
34+
assert any(pose["id"] == cam_pose_id for pose in poses)
35+
36+
37+
def test_get_current_poses_admin(cam_id, cam_pose_id):
38+
admin_client = Client(pytest.admin_token, "http://localhost:5050", timeout=10)
39+
response = admin_client.get_current_poses(camera_id=cam_id)
40+
assert response.status_code == 200, response.__dict__
41+
payload = response.json()
42+
assert isinstance(payload.get("poses"), list)
43+
assert any(pose["id"] == cam_pose_id for pose in payload["poses"])
44+
45+
46+
def test_update_pose_camera(cam_token, cam_pose_id):
47+
cam_client = Client(cam_token, "http://localhost:5050", timeout=10)
48+
with pytest.raises(ValueError, match="Either azimuth or patrol_id must be provided"):
49+
cam_client.update_pose(cam_pose_id)
50+
response = cam_client.update_pose(cam_pose_id, azimuth=123.4)
51+
assert response.status_code == 200, response.__dict__
52+
assert response.json()["azimuth"] == 123.4
53+
54+
2855
@pytest.fixture(scope="session")
29-
def test_cam_workflow(cam_token, mock_img):
56+
def test_cam_workflow(cam_token, cam_pose_id, mock_img):
3057
cam_client = Client(cam_token, "http://localhost:5050", timeout=10)
3158
response = cam_client.heartbeat()
3259
assert response.status_code == 200
@@ -37,14 +64,18 @@ def test_cam_workflow(cam_token, mock_img):
3764
assert isinstance(response.json()["last_image"], str)
3865
# Check that adding bboxes works
3966
with pytest.raises(ValueError, match="bboxes must be a non-empty list of tuples"):
40-
cam_client.create_detection(mock_img, 123.2, None)
67+
cam_client.create_detection(mock_img, None, pose_id=cam_pose_id)
4168
with pytest.raises(ValueError, match="bboxes must be a non-empty list of tuples"):
42-
cam_client.create_detection(mock_img, 123.2, [])
43-
response = cam_client.create_detection(mock_img, 123.2, [(0, 0, 1.0, 0.9, 0.5)], pose_id=1)
69+
cam_client.create_detection(mock_img, [], pose_id=cam_pose_id)
70+
response = cam_client.create_detection(mock_img, [(0, 0, 1.0, 0.9, 0.5)], pose_id=cam_pose_id)
4471
assert response.status_code == 201, response.__dict__
45-
response = cam_client.create_detection(mock_img, 123.2, [(0, 0, 1.0, 0.9, 0.5), (0.2, 0.2, 0.7, 0.7, 0.8)])
72+
response = cam_client.create_detection(
73+
mock_img,
74+
[(0, 0, 1.0, 0.9, 0.5), (0.2, 0.2, 0.7, 0.7, 0.8)],
75+
pose_id=cam_pose_id,
76+
)
4677
assert response.status_code == 201, response.__dict__
47-
response = cam_client.create_detection(mock_img, 123.2, [(0, 0, 1.0, 0.9, 0.5)])
78+
response = cam_client.create_detection(mock_img, [(0, 0, 1.0, 0.9, 0.5)], pose_id=cam_pose_id)
4879
assert response.status_code == 201, response.__dict__
4980
return response.json()["id"]
5081

@@ -81,4 +112,4 @@ def test_user_workflow(test_cam_workflow, user_token):
81112
assert len(response.json()) == 1
82113
response = user_client.fetch_sequences_detections(response.json()[0]["id"])
83114
assert response.status_code == 200, response.__dict__
84-
assert len(response.json()) == 3
115+
assert len(response.json()) == 4

scripts/test_e2e.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ def main(args):
9595
"azimuth": 45,
9696
}
9797
pose_id = api_request("post", f"{args.endpoint}/poses/", agent_auth, payload)["id"]
98+
pose_azimuth = payload["azimuth"]
9899

99100
# Take a picture
100101
file_bytes = requests.get("https://pyronear.org/img/logo.png", timeout=5).content
@@ -149,7 +150,7 @@ def main(args):
149150
assert sequence["camera_id"] == cam_id
150151
assert sequence["started_at"] == response.json()["created_at"]
151152
assert sequence["last_seen_at"] > sequence["started_at"]
152-
assert sequence["camera_azimuth"] == response.json()["azimuth"]
153+
assert sequence["camera_azimuth"] == pose_azimuth
153154
# Fetch the latest sequence
154155
assert len(api_request("get", f"{args.endpoint}/sequences/unlabeled/latest", agent_auth)) == 1
155156
# Fetch from date

src/app/api/api_v1/endpoints/alerts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ async def _fetch_sequences_by_alert_ids(session: AsyncSession, alert_ids: List[i
3737
select(AlertSequence.alert_id, Sequence)
3838
.join(Sequence, cast(Any, Sequence.id == AlertSequence.sequence_id))
3939
.where(AlertSequence.alert_id.in_(alert_ids)) # type: ignore[attr-defined]
40-
.order_by(AlertSequence.alert_id, desc(cast(Any, Sequence.last_seen_at)))
40+
.order_by(cast(Any, AlertSequence.alert_id), desc(cast(Any, Sequence.last_seen_at)))
4141
)
4242
res = await session.exec(seq_stmt)
4343
for alert_id, sequence in res.all():

0 commit comments

Comments
 (0)