@@ -137,7 +137,7 @@ def compute(self) -> int:
137
137
# _method: TaskFunc = pydantic.PrivateAttr()
138
138
_cache : tp .Any = pydantic .PrivateAttr (base .Sentinel ())
139
139
140
- def __getstate__ (self ) -> tp . Dict [str , tp .Any ]:
140
+ def __getstate__ (self ) -> dict [str , tp .Any ]:
141
141
out = super ().__getstate__ ()
142
142
out ["__pydantic_private__" ]["_cache" ] = base .Sentinel ()
143
143
return out
@@ -180,23 +180,27 @@ def clear_job(self) -> None:
180
180
(xpfolder / name ).unlink (missing_ok = True )
181
181
182
182
@contextlib .contextmanager
183
- def job_array (self , max_workers : int = 256 ) -> tp .Iterator [tp .List [tp .Any ]]:
183
+ def job_array (
184
+ self , max_workers : int = 256 , allow_empty : bool = False
185
+ ) -> tp .Iterator [list [tp .Any ]]:
184
186
"""Creates a list object to populate
185
187
The tasks in the list will be sent as a job array when exiting the context
186
188
187
189
Parameter
188
190
---------
189
191
max_workers: int
190
192
maximum number of jobs in the array that can be running at a given time
193
+ allow_empty: bool
194
+ if False, an exeption will be raised at the end of the context if the array is still empty
191
195
"""
192
196
executor = self .executor ()
193
- tasks : tp . List [tp .Any ] = []
197
+ tasks : list [tp .Any ] = []
194
198
yield tasks
195
- if not tasks :
199
+ if not tasks and not allow_empty :
196
200
raise RuntimeError (f"Nothing added to job array for { self .uid ()} " )
197
201
# verify unicity
198
202
uid_index : dict [str , int ] = {}
199
- infras : tp . List [TaskInfra ] = [getattr (t , self ._infra_name ) for t in tasks ]
203
+ infras : list [TaskInfra ] = [getattr (t , self ._infra_name ) for t in tasks ]
200
204
folder = self .uid_folder ()
201
205
for k , infra in enumerate (infras ):
202
206
uid = infra .uid ()
@@ -216,12 +220,12 @@ def job_array(self, max_workers: int = 256) -> tp.Iterator[tp.List[tp.Any]]:
216
220
self ._set_permissions (executor .folder )
217
221
name = self .uid ().split ("/" , maxsplit = 1 )[0 ]
218
222
# select jobs to run
219
- statuses : tp . Dict [Status , tp . List [TaskInfra ]] = collections .defaultdict (list )
223
+ statuses : dict [Status , list [TaskInfra ]] = collections .defaultdict (list )
220
224
for i in infras :
221
225
statuses [i .status ()].append (i )
222
226
i ._computed = True
223
227
missing = list (statuses ["not submitted" ])
224
- to_clear : tp . List [Status ] = []
228
+ to_clear : list [Status ] = []
225
229
if self ._effective_mode != "cached" :
226
230
to_clear .append ("failed" )
227
231
if self ._effective_mode == "force" :
@@ -496,7 +500,7 @@ def compute(self, y: int) -> int:
496
500
497
501
_array_executor : submitit .Executor | None = pydantic .PrivateAttr (None )
498
502
499
- def _exclude_from_cls_uid (self ) -> tp . List [str ]:
503
+ def _exclude_from_cls_uid (self ) -> list [str ]:
500
504
return ["." ] # not taken into accound for uid
501
505
502
506
# pylint: disable=unused-argument
0 commit comments