15
15
16
16
import abc
17
17
import collections
18
+ from collections .abc import Mapping
18
19
import concurrent .futures
19
20
import dataclasses
20
21
import io
21
22
import random
22
23
import sys
23
24
import threading
24
25
import time
25
- from typing import Any , Callable , Iterable , Iterator , Literal , Sequence , Tuple , Type , Union
26
+ from typing import Annotated , Any , Callable , Iterable , Iterator , Literal , Sequence , Tuple , Type , Union
26
27
27
28
from langfun .core import component
28
29
import pyglove as pg
@@ -143,43 +144,33 @@ def with_retry(
143
144
A function with the same signature of the input function, with the retry
144
145
capability.
145
146
"""
146
- rand = random if seed is None else random .Random (seed )
147
147
148
- def _func (* args , ** kwargs ) -> Any :
149
- def base_interval () -> int :
150
- if isinstance (retry_interval , tuple ):
151
- return rand .randint (retry_interval [0 ], retry_interval [1 ])
152
- else :
153
- assert isinstance (retry_interval , int )
154
- return retry_interval
155
-
156
- def next_wait_interval (attempt : int ) -> float :
157
- if not exponential_backoff :
158
- attempt = 1
159
- return min (max_retry_interval , base_interval () * (2 ** (attempt - 1 )))
160
-
161
- wait_intervals = []
162
- errors = []
163
- while True :
164
- with pg .catch_errors (retry_on_errors ) as error_context :
165
- return func (* args , ** kwargs )
148
+ def _func (* args , ** kwargs ):
149
+ job = Job (
150
+ func ,
151
+ args ,
152
+ kwargs ,
153
+ retry_on_errors = retry_on_errors ,
154
+ max_attempts = max_attempts ,
155
+ retry_interval = retry_interval ,
156
+ exponential_backoff = exponential_backoff ,
157
+ max_retry_interval = max_retry_interval ,
158
+ seed = seed ,
159
+ )
160
+ job ()
161
+ if job .error :
162
+ raise job .error
163
+ return job .result
166
164
167
- # Branch when errors are met for retry.
168
- errors .append (error_context .error )
169
- if len (errors ) < max_attempts :
170
- wait_interval = next_wait_interval (len (errors ))
171
- wait_intervals .append (wait_interval )
165
+ return _func
172
166
173
- pg .logging .warning (
174
- f'Calling { func !r} encountered { error_context .error !r} '
175
- f'(attempts={ len (errors )} ), retrying in { wait_interval } seconds...'
176
- )
177
167
178
- time .sleep (wait_interval )
179
- else :
180
- raise RetryError (func , errors , wait_intervals )
168
class RetryEntry(pg.Object):
  """Record of a single call attempt made during a retrying job execution.

  One entry is appended per attempt (successful or failed), so a job's
  retry history traces its full sequence of calls.
  """

  # Wall-clock seconds that this attempt's function call took.
  call_interval: float
  # The error raised by this attempt, or None if the attempt succeeded.
  error: BaseException | None = None
  # Seconds waited before this attempt started (0 for the first attempt).
  wait_interval: float = 0.
183
174
184
175
185
176
def concurrent_execute (
@@ -197,6 +188,7 @@ def concurrent_execute(
197
188
retry_interval : int | tuple [int , int ] = (5 , 60 ),
198
189
exponential_backoff : bool = True ,
199
190
max_retry_interval : int = 300 ,
191
+ return_jobs : bool = False ,
200
192
) -> list [Any ]:
201
193
"""Executes a function concurrently under current component context.
202
194
@@ -220,32 +212,52 @@ def concurrent_execute(
220
212
max_retry_interval: The max retry interval in seconds. This is useful when
221
213
the retry interval is exponential, to avoid the wait time to grow
222
214
exponentially.
215
+ return_jobs: If True, return a list of `Job` objects. Otherwise, return a
216
+ list of outputs.
223
217
224
218
Returns:
225
219
A list of ouputs. Each is the return value of `func` based on the input
226
220
value. Order is preserved.
227
221
"""
228
- if retry_on_errors is not None :
229
- func = with_retry (
230
- func ,
231
- retry_on_errors ,
232
- max_attempts = max_attempts ,
233
- retry_interval = retry_interval ,
234
- exponential_backoff = exponential_backoff ,
235
- max_retry_interval = max_retry_interval ,
222
+ jobs = []
223
+ for inputs in parallel_inputs :
224
+ jobs .append (
225
+ Job (
226
+ func ,
227
+ (inputs ,),
228
+ retry_on_errors = retry_on_errors ,
229
+ max_attempts = max_attempts ,
230
+ retry_interval = retry_interval ,
231
+ exponential_backoff = exponential_backoff ,
232
+ max_retry_interval = max_retry_interval ,
233
+ )
236
234
)
237
235
238
236
# NOTE(daiyip): when executor is not specified and max_worker is 1,
239
237
# we don't need to create a executor pool. Instead, the inputs will be
240
238
# processed by the user function in sequence within the current thread.
241
239
if executor is None and max_workers == 1 :
242
- return [func (i ) for i in parallel_inputs ]
240
+ for job in jobs :
241
+ job ()
242
+ if job .error :
243
+ raise job .error
244
+ return jobs if return_jobs else [job .result for job in jobs ]
243
245
244
246
shutdown_after_finish = executor is None
245
247
executor = _executor_pool .executor_from (executor , max_workers = max_workers )
246
248
247
249
try :
248
- return list (executor .map (with_context_access (func ), parallel_inputs ))
250
+ executed_jobs = list (
251
+ executor .map (
252
+ lambda job : job (), [with_context_access (job ) for job in jobs ]
253
+ )
254
+ )
255
+ for job in executed_jobs :
256
+ if job .error :
257
+ raise job .error
258
+ return (
259
+ executed_jobs if return_jobs else [job .result for job in executed_jobs ]
260
+ )
249
261
finally :
250
262
if shutdown_after_finish :
251
263
# Do not wait threads to finish if they are timed out.
@@ -257,9 +269,61 @@ class Job:
257
269
"""Thread pool job."""
258
270
259
271
func : Callable [[Any ], Any ]
260
- arg : Any
272
+ args : Sequence [Any ] = ()
273
+ kwargs : Mapping [str , Any ] = dataclasses .field (default_factory = dict )
274
+ _ : dataclasses .KW_ONLY
275
+
261
276
result : Any = pg .MISSING_VALUE
262
- error : BaseException | None = None
277
+ error : Annotated [
278
+ BaseException | None ,
279
+ 'The non-retryable error encountered during the job execution.' ,
280
+ ] = None
281
+ retry_entries : Annotated [
282
+ Sequence [RetryEntry ], 'Records of retry attempts.'
283
+ ] = dataclasses .field (default_factory = list )
284
+
285
+ retry_on_errors : Annotated [
286
+ Sequence [Type [BaseException ] | str ],
287
+ (
288
+ 'A sequence of exception types or tuples of exception type and error '
289
+ 'messages (described in regular expression) as the desired exception '
290
+ 'types to retry.'
291
+ ),
292
+ ] = ()
293
+ max_attempts : Annotated [
294
+ int , 'Max number of attempts if an error to retry is encountered.'
295
+ ] = 5
296
+ retry_interval : Annotated [
297
+ int | tuple [int , int ],
298
+ (
299
+ 'The (base) retry interval in seconds. If a tuple, the retry '
300
+ 'interval will be randomly chosen between the first and the second '
301
+ 'element of the tuple.'
302
+ ),
303
+ ] = (5 , 60 )
304
+ exponential_backoff : Annotated [
305
+ bool ,
306
+ (
307
+ 'If True, exponential wait time will be applied on top of the base '
308
+ 'retry interval.'
309
+ ),
310
+ ] = True
311
+ max_retry_interval : Annotated [
312
+ int ,
313
+ (
314
+ 'The max retry interval in seconds. This is useful when the retry '
315
+ 'interval is exponential, to avoid the wait time to grow '
316
+ 'exponentially.'
317
+ ),
318
+ ] = 300
319
+ seed : Annotated [
320
+ int | None ,
321
+ (
322
+ 'Random seed to generate retry interval. If None, the seed will be'
323
+ ' determined based on current time.'
324
+ ),
325
+ ] = None
326
+
263
327
timeit : pg .object_utils .TimeIt = dataclasses .field (
264
328
default_factory = lambda : pg .object_utils .TimeIt ('job' )
265
329
)
@@ -269,14 +333,70 @@ def elapse(self) -> float:
269
333
"""Returns the running time in seconds since the job get started."""
270
334
return self .timeit .elapse
271
335
272
- def __call__ (self ) -> Any :
336
  def _retry_call(self) -> 'Job':
    """Calls `self.func` on `self.args`/`self.kwargs` with retry on errors.

    Only errors matching `self.retry_on_errors` trigger a retry; any other
    exception propagates to the caller.

    Returns:
      `self`, with `result` and `retry_entries` populated on success.

    Raises:
      RetryError: If `max_attempts` attempts all failed with retryable
        errors.
    """
    # Use the shared `random` module unless a seed was given, in which case
    # a dedicated Random instance makes retry intervals reproducible.
    rand = random if self.seed is None else random.Random(self.seed)

    def base_interval() -> int:
      # A tuple means "pick a random interval within [low, high]" (inclusive).
      if isinstance(self.retry_interval, tuple):
        return rand.randint(*self.retry_interval)
      else:
        assert isinstance(self.retry_interval, int)
        return self.retry_interval

    def next_wait_interval(attempt: int) -> float:
      # Without exponential backoff every wait uses the base interval
      # (attempt forced to 1); with it, the base doubles per attempt,
      # capped at `max_retry_interval`.
      if not self.exponential_backoff:
        attempt = 1
      return min(
          self.max_retry_interval, base_interval() * (2 ** (attempt - 1))
      )

    retry_entries = []
    wait_interval = 0
    while True:
      # `pg.catch_errors` swallows only the retryable error types; other
      # exceptions escape this method immediately.
      with pg.catch_errors(self.retry_on_errors) as error_context:
        begin_time = time.time()
        self.result = self.func(*self.args, **self.kwargs)

      end_time = time.time()
      # Record every attempt, including the final successful one.
      retry_entries.append(RetryEntry(
          call_interval=end_time - begin_time,
          wait_interval=wait_interval,
          error=error_context.error,
      ))
      if error_context.error is None:
        self.retry_entries = retry_entries
        return self

      # Branch when errors are met for retry.
      if len(retry_entries) < self.max_attempts:
        wait_interval = next_wait_interval(len(retry_entries))

        pg.logging.warning(
            f'Calling {self.func!r} encountered {error_context.error!r} '
            f'(attempts={len(retry_entries)}), retrying in '
            f'{wait_interval} seconds...'
        )

        time.sleep(wait_interval)
      else:
        errors = [e.error for e in retry_entries]
        # First wait interval is 0.
        wait_intervals = [e.wait_interval for e in retry_entries[1:]]
        raise RetryError(self.func, errors, wait_intervals)
387
+
388
+ def __call__ (self ) -> 'Job' :
389
+ if getattr (self , '_has_call' , False ):
390
+ raise ValueError ('Job can only be called once.' )
391
+ self ._has_call = True
273
392
try :
274
393
with self .timeit :
275
- self .result = self .func (self .arg )
276
- return self .result
394
+ if self .retry_on_errors :
395
+ return self ._retry_call ()
396
+ self .result = self .func (* self .args , ** self .kwargs )
277
397
except BaseException as e : # pylint: disable=broad-exception-caught
278
398
self .error = e
279
- return e
399
+ return self
280
400
281
401
def mark_canceled (self , error : BaseException ) -> None :
282
402
"""Marks the job as canceled."""
@@ -537,7 +657,8 @@ def concurrent_map(
537
657
max_attempts : int = 5 ,
538
658
retry_interval : int | tuple [int , int ] = (5 , 60 ),
539
659
exponential_backoff : bool = True ,
540
- ) -> Iterator [tuple [Any , Any , BaseException | None ]]:
660
+ return_jobs : bool = False ,
661
+ ) -> Iterator [Any ]:
541
662
"""Maps inputs to outptus via func concurrently under current context.
542
663
543
664
Args:
@@ -580,9 +701,10 @@ def concurrent_map(
580
701
of the tuple.
581
702
exponential_backoff: If True, exponential wait time will be applied on top
582
703
of the base retry interval.
704
+ return_jobs: If True, the returned iterator will emit `Job` objects.
583
705
584
706
Yields:
585
- An iterator of (input, output, error).
707
+ An iterator of (input, output, error) or Job object .
586
708
587
709
Raises:
588
710
Exception: Errors that are not in `silence_on_errors` or `retry_on_errors`,
@@ -592,15 +714,6 @@ def concurrent_map(
592
714
"""
593
715
# Internal usage logging.
594
716
595
- if retry_on_errors :
596
- func = with_retry (
597
- func ,
598
- retry_on_errors ,
599
- max_attempts = max_attempts ,
600
- retry_interval = retry_interval ,
601
- exponential_backoff = exponential_backoff ,
602
- )
603
-
604
717
status_fn = status_fn or (lambda p : { # pylint: disable=g-long-lambda
605
718
'Succeeded' : '%.2f%% (%d/%d)' % (
606
719
p .success_rate * 100 , p .succeeded , p .completed ),
@@ -615,7 +728,14 @@ def concurrent_map(
615
728
pending_futures = []
616
729
total = 0
617
730
for inputs in parallel_inputs :
618
- job = Job (func , inputs )
731
+ job = Job (
732
+ func ,
733
+ (inputs ,),
734
+ retry_on_errors = retry_on_errors ,
735
+ max_attempts = max_attempts ,
736
+ retry_interval = retry_interval ,
737
+ exponential_backoff = exponential_backoff ,
738
+ )
619
739
future = executor .submit (
620
740
with_context_access (job ),
621
741
)
@@ -668,7 +788,7 @@ def update_progress_bar(progress: Progress) -> None:
668
788
silence_on_errors and isinstance (job .error , silence_on_errors )):
669
789
raise job .error # pylint: disable=g-doc-exception
670
790
671
- yield job . arg , job .result , job .error
791
+ yield job if return_jobs else job . args [ 0 ] , job .result , job .error
672
792
progress .update (job )
673
793
update_progress_bar (progress )
674
794
ProgressBar .refresh ()
@@ -689,7 +809,7 @@ def update_progress_bar(progress: Progress) -> None:
689
809
if job .error is not None and not (
690
810
silence_on_errors and isinstance (job .error , silence_on_errors )):
691
811
raise job .error # pylint: disable=g-doc-exception
692
- yield job . arg , job .result , job .error
812
+ yield job if return_jobs else job . args [ 0 ] , job .result , job .error
693
813
progress .update (job )
694
814
update_progress_bar (progress )
695
815
completed_batch .add (future )
@@ -712,7 +832,7 @@ def update_progress_bar(progress: Progress) -> None:
712
832
and isinstance (job .error , silence_on_errors )):
713
833
raise job .error # pylint: disable=g-doc-exception
714
834
715
- yield job .arg , job .result , job .error
835
+ yield job .args [ 0 ] , job .result , job .error
716
836
progress .update (job )
717
837
update_progress_bar (progress )
718
838
else :
0 commit comments