|
6 | 6 | from datetime import datetime, timedelta |
7 | 7 | from functools import partial |
8 | 8 | from pathlib import Path |
9 | | -from typing import Any, List, Optional, Tuple |
| 9 | +from typing import Any, List, Optional |
10 | 10 |
|
11 | 11 | import anyio |
12 | 12 | from fastapi import APIRouter, Body, Depends, HTTPException, Query, Request, Security |
|
55 | 55 | get_query_registry, |
56 | 56 | get_serialization_registry, |
57 | 57 | get_validation_registry, |
| 58 | + shape_param, |
58 | 59 | slice_, |
59 | 60 | ) |
60 | 61 | from .file_response_with_range import FileResponseWithRange |
@@ -1290,39 +1291,32 @@ async def put_array_block( |
1290 | 1291 | return json_or_msgpack(request, None) |
1291 | 1292 |
|
1292 | 1293 |
|
1293 | | -@router.patch("/array/append/{path:path}") |
1294 | | -async def patch_array_block( |
| 1294 | +@router.patch("/array/full/{path:path}") |
| 1295 | +async def patch_array_full( |
1295 | 1296 | request: Request, |
1296 | | - shape: str, |
1297 | | - axis: int, |
| 1297 | + slice=Depends(slice_), |
| 1298 | + shape=Depends(shape_param), |
| 1299 | + grow: bool = False, |
1298 | 1300 | entry=SecureEntry( |
1299 | 1301 | scopes=["write:data"], |
1300 | | - structure_families={StructureFamily.array, StructureFamily.sparse}, |
| 1302 | + structure_families={StructureFamily.array}, |
1301 | 1303 | ), |
1302 | 1304 | deserialization_registry=Depends(get_deserialization_registry), |
1303 | 1305 | ): |
1304 | | - if not hasattr(entry, "append_block"): |
| 1306 | + if slice is None: |
| 1307 | + slice = ... |
| 1308 | + if not hasattr(entry, "patch"): |
1305 | 1309 | raise HTTPException( |
1306 | 1310 | status_code=HTTP_405_METHOD_NOT_ALLOWED, |
1307 | 1311 | detail="This node cannot accept array data.", |
1308 | 1312 | ) |
1309 | 1313 |
|
1310 | 1314 | dtype = entry.structure().data_type.to_numpy_dtype() |
1311 | | - shape_tuple: Tuple[int, ...] = tuple(map(int, shape.split(","))) |
1312 | 1315 | body = await request.body() |
1313 | 1316 | media_type = request.headers["content-type"] |
1314 | | - if entry.structure_family == "array": |
1315 | | - # dtype = entry.structure().data_type.to_numpy_dtype() |
1316 | | - # _, shape = slice_and_shape_from_block_and_chunks( |
1317 | | - # block, entry.structure().chunks |
1318 | | - # ) |
1319 | | - deserializer = deserialization_registry.dispatch("array", media_type) |
1320 | | - data = await ensure_awaitable(deserializer, body, dtype, shape_tuple) |
1321 | | - elif entry.structure_family == "sparse": # TODO: Handle sparse |
1322 | | - raise NotImplementedError(entry.structure_family) |
1323 | | - else: |
1324 | | - raise NotImplementedError(entry.structure_family) |
1325 | | - await ensure_awaitable(entry.append_block, data, axis) |
| 1317 | + deserializer = deserialization_registry.dispatch("array", media_type) |
| 1318 | + data = await ensure_awaitable(deserializer, body, dtype, shape) |
| 1319 | + await ensure_awaitable(entry.patch, data, slice, grow) |
1326 | 1320 | return json_or_msgpack(request, None) |
1327 | 1321 |
|
1328 | 1322 |
|
|
0 commit comments