Skip to content

Commit 3901f7b

Browse files
committed
WIP: Rework to use PATCH /array/full
1 parent 9b88601 commit 3901f7b

File tree

7 files changed

+68
-71
lines changed

7 files changed

+68
-71
lines changed

tiled/adapters/zarr.py

Lines changed: 27 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -174,41 +174,40 @@ async def write_block(
174174
-------
175175
176176
"""
177-
if slice is not ...:
178-
raise NotImplementedError
179177
block_slice, shape = slice_and_shape_from_block_and_chunks(
180178
block, self.structure().chunks
181179
)
182180
self._array[block_slice] = data
183181

184-
async def append_block(
182+
async def patch(
185183
self,
186184
data: NDArray[Any],
187-
axis: int,
188-
) -> List[int]:
189-
"""
190-
191-
Parameters
192-
----------
193-
data :
194-
block :
195-
slice :
196-
197-
Returns
198-
-------
199-
200-
"""
201-
202-
new_shape = list(self._array.shape)
203-
new_shape[axis] += list(data.shape)[axis] # Extend along axis
204-
205-
# Resize the Zarr array to accommodate new data
206-
self._array.resize(tuple(new_shape))
207-
208-
# Append the new data to the resized array
209-
# Slicing to place data at the end
210-
self._array[-data.shape[0] :] = data # noqa: E203
211-
return new_shape
185+
slice: NDSlice,
186+
grow: bool = False,
187+
) -> Tuple[int, ...]:
188+
"""
189+
Write data into a slice of the array, maybe growing it.
190+
191+
If the specified slice does not fit into the array, and grow=True, the
192+
array will be resize (grown, never shrunk) to fit it. The new shape is
193+
returned.
194+
"""
195+
current_shape = self._array.shape
196+
new_shape = list(current_shape)
197+
for i, (s, dim) in enumerate(zip(slice, current_shape)):
198+
if isinstance(s, int):
199+
new_shape[i] = max(new_shape[i], s)
200+
elif isinstance(s, builtins.slice) and isinstance(s.stop, int):
201+
new_shape[i] = max(new_shape[i], s.stop)
202+
new_shape_tuple = tuple(new_shape)
203+
if new_shape_tuple != current_shape:
204+
if grow:
205+
# Resize the Zarr array to accommodate new data
206+
self._array.resize(new_shape_tuple)
207+
else:
208+
raise ValueError(f"Slice does not fit into array shape {current_shape}")
209+
self._array[slice] = data
210+
return new_shape_tuple
212211

213212

214213
if sys.version_info < (3, 9):

tiled/catalog/adapter.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1041,15 +1041,13 @@ async def write_block(self, *args, **kwargs):
10411041
(await self.get_adapter()).write_block, *args, **kwargs
10421042
)
10431043

1044-
async def append_block(self, *args, **kwargs):
1044+
async def patch(self, *args, **kwargs):
10451045
# assumes a single DataSource (currently only supporting zarr)
10461046
async with self.context.session() as db:
10471047
try:
10481048
new_shape = await ensure_awaitable(
1049-
(await self.get_adapter()).append_block, *args, **kwargs
1049+
(await self.get_adapter()).patch, *args, **kwargs
10501050
)
1051-
if new_shape is None:
1052-
raise ValueError("No new shape returned from append_block.")
10531051
node = await db.get(orm.Node, self.node.id)
10541052
data_source = node.data_sources[0]
10551053
structure_row = await db.get(orm.Structure, data_source.structure_id)

tiled/client/array.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
import dask
44
import dask.array
55
import numpy
6+
from numpy.typing import NDArray
67

8+
from ..type_aliases import NDSlice
79
from .base import BaseClient
810
from .utils import export_util, handle_error, params_from_slice
911

@@ -167,40 +169,39 @@ def write(self, array):
167169
)
168170
)
169171

170-
def write_block(self, array, block):
172+
def write_block(self, array, block, slice=...):
171173
handle_error(
172174
self.context.http_client.put(
173175
self.item["links"]["block"].format(*block),
174176
content=array.tobytes(),
175177
headers={"Content-Type": "application/octet-stream"},
178+
params=params_from_slice(slice),
176179
)
177180
)
178181

179-
def append_block(self, array: numpy.nd, axis):
182+
def patch(self, array: NDArray, slice: NDSlice, grow=False):
180183
"""
181-
Append a block to the array along the given axis. The block must have
182-
the same shape as the existing blocks along that axis.
183-
This method differs from `write_block` as it increases the size
184-
of the array along the given axis, while `write_block` overwrites
185-
the data in the block at the given index. This is useful for
186-
cases where you do not know ahead of time how many blocks you will
187-
eventually receive.
184+
Write data
188185
189186
Parameters
190187
----------
191188
array : array-like
192-
The block to append.
193-
axis : int
194-
The axis along which to append the block.
195-
196-
189+
The data to write
190+
slice : NDSlice
191+
Where to place this data in the array
192+
grow : bool
193+
Grow the array shape to fit the new slice, if necessary
197194
"""
198-
formatted_shape = ",".join(f"{value}" for value in array.shape)
195+
array_ = numpy.ascontiguousarray(array)
196+
params = params_from_slice(slice)
197+
params["shape"] = ",".join(map(str, array_.shape))
198+
params["grow"] = bool(grow)
199199
handle_error(
200200
self.context.http_client.patch(
201-
self.item["links"]["append"].format(formatted_shape, axis),
202-
content=array.tobytes(),
201+
self.item["links"]["full"],
202+
content=array_.tobytes(),
203203
headers={"Content-Type": "application/octet-stream"},
204+
params=params,
204205
)
205206
)
206207

tiled/server/dependencies.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,13 @@ def expected_shape(
165165
return tuple(map(int, expected_shape.split(",")))
166166

167167

168+
def shape_param(
169+
shape: str = Query(..., min_length=1, pattern="^[0-9]+(,[0-9]+)*$|^scalar$"),
170+
):
171+
"Specify and parse a shape parameter."
172+
return tuple(map(int, shape.split(",")))
173+
174+
168175
def np_style_slicer(indices: tuple):
169176
return indices[0] if len(indices) == 1 else slice_func(*indices)
170177

tiled/server/links.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ def links_for_array(structure_family, structure, base_url, path_str):
2020
block_template = ",".join(f"{{{index}}}" for index in range(len(structure.shape)))
2121
links["block"] = f"{base_url}/array/block/{path_str}?block={block_template}"
2222
links["full"] = f"{base_url}/array/full/{path_str}"
23-
links["append"] = f"{base_url}/array/append/{path_str}?shape={{0}}&axis={{1}}"
2423
return links
2524

2625

tiled/server/router.py

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from datetime import datetime, timedelta
77
from functools import partial
88
from pathlib import Path
9-
from typing import Any, List, Optional, Tuple
9+
from typing import Any, List, Optional
1010

1111
import anyio
1212
from fastapi import APIRouter, Body, Depends, HTTPException, Query, Request, Security
@@ -55,6 +55,7 @@
5555
get_query_registry,
5656
get_serialization_registry,
5757
get_validation_registry,
58+
shape_param,
5859
slice_,
5960
)
6061
from .file_response_with_range import FileResponseWithRange
@@ -1290,39 +1291,32 @@ async def put_array_block(
12901291
return json_or_msgpack(request, None)
12911292

12921293

1293-
@router.patch("/array/append/{path:path}")
1294-
async def patch_array_block(
1294+
@router.patch("/array/full/{path:path}")
1295+
async def patch_array_full(
12951296
request: Request,
1296-
shape: str,
1297-
axis: int,
1297+
slice=Depends(slice_),
1298+
shape=Depends(shape_param),
1299+
grow: bool = False,
12981300
entry=SecureEntry(
12991301
scopes=["write:data"],
1300-
structure_families={StructureFamily.array, StructureFamily.sparse},
1302+
structure_families={StructureFamily.array},
13011303
),
13021304
deserialization_registry=Depends(get_deserialization_registry),
13031305
):
1304-
if not hasattr(entry, "append_block"):
1306+
if slice is None:
1307+
slice = ...
1308+
if not hasattr(entry, "patch"):
13051309
raise HTTPException(
13061310
status_code=HTTP_405_METHOD_NOT_ALLOWED,
13071311
detail="This node cannot accept array data.",
13081312
)
13091313

13101314
dtype = entry.structure().data_type.to_numpy_dtype()
1311-
shape_tuple: Tuple[int, ...] = tuple(map(int, shape.split(",")))
13121315
body = await request.body()
13131316
media_type = request.headers["content-type"]
1314-
if entry.structure_family == "array":
1315-
# dtype = entry.structure().data_type.to_numpy_dtype()
1316-
# _, shape = slice_and_shape_from_block_and_chunks(
1317-
# block, entry.structure().chunks
1318-
# )
1319-
deserializer = deserialization_registry.dispatch("array", media_type)
1320-
data = await ensure_awaitable(deserializer, body, dtype, shape_tuple)
1321-
elif entry.structure_family == "sparse": # TODO: Handle sparse
1322-
raise NotImplementedError(entry.structure_family)
1323-
else:
1324-
raise NotImplementedError(entry.structure_family)
1325-
await ensure_awaitable(entry.append_block, data, axis)
1317+
deserializer = deserialization_registry.dispatch("array", media_type)
1318+
data = await ensure_awaitable(deserializer, body, dtype, shape)
1319+
await ensure_awaitable(entry.patch, data, slice, grow)
13261320
return json_or_msgpack(request, None)
13271321

13281322

tiled/server/schemas.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,6 @@ class ArrayLinks(pydantic.BaseModel):
210210
self: str
211211
full: str
212212
block: str
213-
append: str
214213

215214

216215
class AwkwardLinks(pydantic.BaseModel):

0 commit comments

Comments
 (0)