Skip to content

Commit aa716e4

Browse files
Merge pull request #1727 from OceanParcels/v/api
API changes: `particlefile.py` and other touchups
2 parents 9907c84 + f03f80e commit aa716e4

File tree

9 files changed

+186
-72
lines changed

9 files changed

+186
-72
lines changed

parcels/grid.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -279,19 +279,6 @@ class CStructuredGrid(Structure):
279279
)
280280
return self._cstruct
281281

282-
def lon_grid_to_target(self):
283-
if self.lon_remapping:
284-
self._lon = self.lon_remapping.to_target(self.lon)
285-
286-
def lon_grid_to_source(self):
287-
if self.lon_remapping:
288-
self._lon = self.lon_remapping.to_source(self.lon)
289-
290-
def lon_particle_to_target(self, lon):
291-
if self.lon_remapping:
292-
return self.lon_remapping.particle_to_target(lon)
293-
return lon
294-
295282
@deprecated_made_private # TODO: Remove 6 months after v3.1.0
296283
def check_zonal_periodic(self, *args, **kwargs):
297284
return self._check_zonal_periodic(*args, **kwargs)

parcels/kernel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ def check_fieldsets_in_kernels(self, pyfunc):
367367
)
368368
elif pyfunc is AdvectionAnalytical:
369369
if self.fieldset.particlefile is not None:
370-
self.fieldset.particlefile.analytical = True
370+
self.fieldset.particlefile._is_analytical = True
371371
if self._ptype.uses_jit:
372372
raise NotImplementedError("Analytical Advection only works in Scipy mode")
373373
if self._fieldset.U.interp_method != "cgrid_velocity":

parcels/particle.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,15 @@ class Variable:
2929
"""
3030

3131
def __init__(self, name, dtype=np.float32, initial=0, to_write: bool | Literal["once"] = True):
32-
self.name = name
32+
self._name = name
3333
self.dtype = dtype
3434
self.initial = initial
3535
self.to_write = to_write
3636

37+
@property
38+
def name(self):
39+
return self._name
40+
3741
def __get__(self, instance, cls):
3842
if instance is None:
3943
return self

parcels/particledata.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66

77
from parcels._compat import MPI, KMeans
8+
from parcels.tools._helpers import deprecated
89
from parcels.tools.statuscodes import StatusCode
910

1011

@@ -228,12 +229,15 @@ def __len__(self):
228229
"""Return the length, in terms of 'number of elements, of a ParticleData instance."""
229230
return self._ncount
230231

232+
@deprecated(
233+
"Use iter(...) instead, or just use the object in an iterator context (e.g. for p in particledata: ...)."
234+
) # TODO: Remove 6 months after v3.1.0 (or 9 months; doesn't contribute to code debt)
231235
def iterator(self):
232-
return ParticleDataIterator(self)
236+
return iter(self)
233237

234238
def __iter__(self):
235239
"""Return an Iterator that allows for forward iteration over the elements in the ParticleData (e.g. `for p in pset:`)."""
236-
return self.iterator()
240+
return ParticleDataIterator(self)
237241

238242
def __getitem__(self, index):
239243
"""Get a particle object from the ParticleData instance based on its index."""

parcels/particlefile.py

Lines changed: 106 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import parcels
1212
from parcels._compat import MPI
13+
from parcels.tools._helpers import deprecated, deprecated_made_private
1314
from parcels.tools.warnings import FileWarning
1415

1516
__all__ = ["ParticleFile"]
@@ -46,31 +47,24 @@ class ParticleFile:
4647
ParticleFile object that can be used to write particle data to file
4748
"""
4849

49-
outputdt = None
50-
particleset = None
51-
parcels_mesh = None
52-
time_origin = None
53-
lonlatdepth_dtype = None
54-
5550
def __init__(self, name, particleset, outputdt=np.inf, chunks=None, create_new_zarrfile=True):
56-
self.outputdt = outputdt.total_seconds() if isinstance(outputdt, timedelta) else outputdt
57-
self.chunks = chunks
58-
self.particleset = particleset
59-
self.parcels_mesh = "spherical"
51+
self._outputdt = outputdt.total_seconds() if isinstance(outputdt, timedelta) else outputdt
52+
self._chunks = chunks
53+
self._particleset = particleset
54+
self._parcels_mesh = "spherical"
6055
if self.particleset.fieldset is not None:
61-
self.parcels_mesh = self.particleset.fieldset.gridset.grids[0].mesh
62-
self.time_origin = self.particleset.time_origin
56+
self._parcels_mesh = self.particleset.fieldset.gridset.grids[0].mesh
6357
self.lonlatdepth_dtype = self.particleset.particledata.lonlatdepth_dtype
64-
self.maxids = 0
65-
self.pids_written = {}
66-
self.create_new_zarrfile = create_new_zarrfile
67-
self.vars_to_write = {}
58+
self._maxids = 0
59+
self._pids_written = {}
60+
self._create_new_zarrfile = create_new_zarrfile
61+
self._vars_to_write = {}
6862
for var in self.particleset.particledata.ptype.variables:
6963
if var.to_write:
7064
self.vars_to_write[var.name] = var.dtype
71-
self.mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
65+
self._mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
7266
self.particleset.fieldset._particlefile = self
73-
self.analytical = False # Flag to indicate if ParticleFile is used for analytical trajectories
67+
self._is_analytical = False # Flag to indicate if ParticleFile is used for analytical trajectories
7468

7569
# Reset obs_written of each particle, in case new ParticleFile created for a ParticleSet
7670
particleset.particledata.setallvardata("obs_written", 0)
@@ -80,11 +74,11 @@ def __init__(self, name, particleset, outputdt=np.inf, chunks=None, create_new_z
8074
"Conventions": "CF-1.6/CF-1.7",
8175
"ncei_template_version": "NCEI_NetCDF_Trajectory_Template_v2.0",
8276
"parcels_version": parcels.__version__,
83-
"parcels_mesh": self.parcels_mesh,
77+
"parcels_mesh": self._parcels_mesh,
8478
}
8579

8680
# Create dictionary to translate datatypes and fill_values
87-
self.fill_value_map = {
81+
self._fill_value_map = {
8882
np.float16: np.nan,
8983
np.float32: np.nan,
9084
np.float64: np.nan,
@@ -103,23 +97,82 @@ def __init__(self, name, particleset, outputdt=np.inf, chunks=None, create_new_z
10397
# But we need to handle incompatibility with MPI mode for now:
10498
if MPI and MPI.COMM_WORLD.Get_size() > 1:
10599
raise ValueError("Currently, MPI mode is not compatible with directly passing a Zarr store.")
106-
self.fname = name
100+
fname = name
107101
else:
108102
extension = os.path.splitext(str(name))[1]
109103
if extension in [".nc", ".nc4"]:
110104
raise RuntimeError(
111105
"Output in NetCDF is not supported anymore. Use .zarr extension for ParticleFile name."
112106
)
113107
if MPI and MPI.COMM_WORLD.Get_size() > 1:
114-
self.fname = os.path.join(name, f"proc{self.mpi_rank:02d}.zarr")
108+
fname = os.path.join(name, f"proc{self._mpi_rank:02d}.zarr")
115109
if extension in [".zarr"]:
116110
warnings.warn(
117-
f"The ParticleFile name contains .zarr extension, but zarr files will be written per processor in MPI mode at {self.fname}",
111+
f"The ParticleFile name contains .zarr extension, but zarr files will be written per processor in MPI mode at {fname}",
118112
FileWarning,
119113
stacklevel=2,
120114
)
121115
else:
122-
self.fname = name if extension in [".zarr"] else f"{name}.zarr"
116+
fname = name if extension in [".zarr"] else f"{name}.zarr"
117+
self._fname = fname
118+
119+
@property
120+
def create_new_zarrfile(self):
121+
return self._create_new_zarrfile
122+
123+
@property
124+
def outputdt(self):
125+
return self._outputdt
126+
127+
@property
128+
def chunks(self):
129+
return self._chunks
130+
131+
@property
132+
def particleset(self):
133+
return self._particleset
134+
135+
@property
136+
def fname(self):
137+
return self._fname
138+
139+
@property
140+
def vars_to_write(self):
141+
return self._vars_to_write
142+
143+
@property
144+
def time_origin(self):
145+
return self.particleset.time_origin
146+
147+
@property
148+
@deprecated_made_private # TODO: Remove 6 months after v3.1.0
149+
def parcels_mesh(self):
150+
return self._parcels_mesh
151+
152+
@property
153+
@deprecated_made_private # TODO: Remove 6 months after v3.1.0
154+
def maxids(self):
155+
return self._maxids
156+
157+
@property
158+
@deprecated_made_private # TODO: Remove 6 months after v3.1.0
159+
def pids_written(self):
160+
return self._pids_written
161+
162+
@property
163+
@deprecated_made_private # TODO: Remove 6 months after v3.1.0
164+
def mpi_rank(self):
165+
return self._mpi_rank
166+
167+
@property
168+
@deprecated_made_private # TODO: Remove 6 months after v3.1.0
169+
def fill_value_map(self):
170+
return self._fill_value_map
171+
172+
@property
173+
@deprecated_made_private # TODO: Remove 6 months after v3.1.0
174+
def analytical(self):
175+
return self._is_analytical
123176

124177
def _create_variables_attribute_dict(self):
125178
"""Creates the dictionary with variable attributes.
@@ -133,7 +186,7 @@ def _create_variables_attribute_dict(self):
133186
"trajectory": {
134187
"long_name": "Unique identifier for each particle",
135188
"cf_role": "trajectory_id",
136-
"_FillValue": self.fill_value_map[np.int64],
189+
"_FillValue": self._fill_value_map[np.int64],
137190
},
138191
"time": {"long_name": "", "standard_name": "time", "units": "seconds", "axis": "T"},
139192
"lon": {"long_name": "", "standard_name": "longitude", "units": "degrees_east", "axis": "X"},
@@ -147,14 +200,17 @@ def _create_variables_attribute_dict(self):
147200
for vname in self.vars_to_write:
148201
if vname not in ["time", "lat", "lon", "depth", "id"]:
149202
attrs[vname] = {
150-
"_FillValue": self.fill_value_map[self.vars_to_write[vname]],
203+
"_FillValue": self._fill_value_map[self.vars_to_write[vname]],
151204
"long_name": "",
152205
"standard_name": vname,
153206
"units": "unknown",
154207
}
155208

156209
return attrs
157210

211+
@deprecated(
212+
"ParticleFile.metadata is a dictionary. Use `ParticleFile.metadata['key'] = ...` or other dictionary methods instead."
213+
) # TODO: Remove 6 months after v3.1.0
158214
def add_metadata(self, name, message):
159215
"""Add metadata to :class:`parcels.particleset.ParticleSet`.
160216
@@ -175,21 +231,25 @@ def _convert_varout_name(self, var):
175231
else:
176232
return var
177233

178-
def write_once(self, var):
234+
@deprecated_made_private # TODO: Remove 6 months after v3.1.0
235+
def write_once(self, *args, **kwargs):
236+
return self._write_once(*args, **kwargs)
237+
238+
def _write_once(self, var):
179239
return self.particleset.particledata.ptype[var].to_write == "once"
180240

181241
def _extend_zarr_dims(self, Z, store, dtype, axis):
182242
if axis == 1:
183-
a = np.full((Z.shape[0], self.chunks[1]), self.fill_value_map[dtype], dtype=dtype)
243+
a = np.full((Z.shape[0], self.chunks[1]), self._fill_value_map[dtype], dtype=dtype)
184244
obs = zarr.group(store=store, overwrite=False)["obs"]
185245
if len(obs) == Z.shape[1]:
186246
obs.append(np.arange(self.chunks[1]) + obs[-1] + 1)
187247
else:
188-
extra_trajs = self.maxids - Z.shape[0]
248+
extra_trajs = self._maxids - Z.shape[0]
189249
if len(Z.shape) == 2:
190-
a = np.full((extra_trajs, Z.shape[1]), self.fill_value_map[dtype], dtype=dtype)
250+
a = np.full((extra_trajs, Z.shape[1]), self._fill_value_map[dtype], dtype=dtype)
191251
else:
192-
a = np.full((extra_trajs,), self.fill_value_map[dtype], dtype=dtype)
252+
a = np.full((extra_trajs,), self._fill_value_map[dtype], dtype=dtype)
193253
Z.append(a, axis=axis)
194254
zarr.consolidate_metadata(store)
195255

@@ -221,11 +281,11 @@ def write(self, pset, time, indices=None):
221281

222282
if len(indices_to_write) > 0:
223283
pids = pset.particledata.getvardata("id", indices_to_write)
224-
to_add = sorted(set(pids) - set(self.pids_written.keys()))
284+
to_add = sorted(set(pids) - set(self._pids_written.keys()))
225285
for i, pid in enumerate(to_add):
226-
self.pids_written[pid] = self.maxids + i
227-
ids = np.array([self.pids_written[p] for p in pids], dtype=int)
228-
self.maxids = len(self.pids_written)
286+
self._pids_written[pid] = self._maxids + i
287+
ids = np.array([self._pids_written[p] for p in pids], dtype=int)
288+
self._maxids = len(self._pids_written)
229289

230290
once_ids = np.where(pset.particledata.getvardata("obs_written", indices_to_write) == 0)[0]
231291
if len(once_ids) > 0:
@@ -234,7 +294,7 @@ def write(self, pset, time, indices=None):
234294

235295
if self.create_new_zarrfile:
236296
if self.chunks is None:
237-
self.chunks = (len(ids), 1)
297+
self._chunks = (len(ids), 1)
238298
if pset._repeatpclass is not None and self.chunks[0] < 1e4:
239299
warnings.warn(
240300
f"ParticleFile chunks are set to {self.chunks}, but this may lead to "
@@ -243,37 +303,37 @@ def write(self, pset, time, indices=None):
243303
FileWarning,
244304
stacklevel=2,
245305
)
246-
if (self.maxids > len(ids)) or (self.maxids > self.chunks[0]):
247-
arrsize = (self.maxids, self.chunks[1])
306+
if (self._maxids > len(ids)) or (self._maxids > self.chunks[0]):
307+
arrsize = (self._maxids, self.chunks[1])
248308
else:
249309
arrsize = (len(ids), self.chunks[1])
250310
ds = xr.Dataset(
251311
attrs=self.metadata,
252312
coords={"trajectory": ("trajectory", pids), "obs": ("obs", np.arange(arrsize[1], dtype=np.int32))},
253313
)
254314
attrs = self._create_variables_attribute_dict()
255-
obs = np.zeros((self.maxids), dtype=np.int32)
315+
obs = np.zeros((self._maxids), dtype=np.int32)
256316
for var in self.vars_to_write:
257317
varout = self._convert_varout_name(var)
258318
if varout not in ["trajectory"]: # because 'trajectory' is written as coordinate
259-
if self.write_once(var):
319+
if self._write_once(var):
260320
data = np.full(
261321
(arrsize[0],),
262-
self.fill_value_map[self.vars_to_write[var]],
322+
self._fill_value_map[self.vars_to_write[var]],
263323
dtype=self.vars_to_write[var],
264324
)
265325
data[ids_once] = pset.particledata.getvardata(var, indices_to_write_once)
266326
dims = ["trajectory"]
267327
else:
268328
data = np.full(
269-
arrsize, self.fill_value_map[self.vars_to_write[var]], dtype=self.vars_to_write[var]
329+
arrsize, self._fill_value_map[self.vars_to_write[var]], dtype=self.vars_to_write[var]
270330
)
271331
data[ids, 0] = pset.particledata.getvardata(var, indices_to_write)
272332
dims = ["trajectory", "obs"]
273333
ds[varout] = xr.DataArray(data=data, dims=dims, attrs=attrs[varout])
274-
ds[varout].encoding["chunks"] = self.chunks[0] if self.write_once(var) else self.chunks
334+
ds[varout].encoding["chunks"] = self.chunks[0] if self._write_once(var) else self.chunks
275335
ds.to_zarr(self.fname, mode="w")
276-
self.create_new_zarrfile = False
336+
self._create_new_zarrfile = False
277337
else:
278338
# Either use the store that was provided directly or create a DirectoryStore:
279339
if issubclass(type(self.fname), zarr.storage.Store):
@@ -284,9 +344,9 @@ def write(self, pset, time, indices=None):
284344
obs = pset.particledata.getvardata("obs_written", indices_to_write)
285345
for var in self.vars_to_write:
286346
varout = self._convert_varout_name(var)
287-
if self.maxids > Z[varout].shape[0]:
347+
if self._maxids > Z[varout].shape[0]:
288348
self._extend_zarr_dims(Z[varout], store, dtype=self.vars_to_write[var], axis=0)
289-
if self.write_once(var):
349+
if self._write_once(var):
290350
if len(once_ids) > 0:
291351
Z[varout].vindex[ids_once] = pset.particledata.getvardata(var, indices_to_write_once)
292352
else:

0 commit comments

Comments
 (0)