
Commit 9d2a50b

Da-Huang authored and langfun authors committed

Track query retry so that we can count the time impact of the failed rpc calls

Also fuse retry logics into concurrent.Job
Add the retry entries to LMSamplingUsage.

PiperOrigin-RevId: 715098479

1 parent 8c1f445 commit 9d2a50b

8 files changed: +397 -105 lines

langfun/core/__init__.py (+1)
@@ -72,6 +72,7 @@
 
 # Concurrent execute a function with parallel inputs with inheriting current
 # context's defaults and overrides.
+from langfun.core.concurrent import RetryEntry
 from langfun.core.concurrent import concurrent_execute
 from langfun.core.concurrent import concurrent_map
 from langfun.core.concurrent import with_context_access
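
With this export, the new RetryEntry record becomes importable from langfun.core alongside the other concurrent helpers. A minimal sketch (not part of the commit), assuming the langfun package is installed with this change:

    # Hypothetical usage; RetryEntry is a pyglove Object with the three
    # fields added in concurrent.py below.
    from langfun.core import RetryEntry

    entry = RetryEntry(call_interval=0.8, error=None, wait_interval=0.0)
    print(entry.call_interval)  # Seconds spent inside the wrapped call.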

langfun/core/concurrent.py (+184 -64)
@@ -15,14 +15,15 @@
 
 import abc
 import collections
+from collections.abc import Mapping
 import concurrent.futures
 import dataclasses
 import io
 import random
 import sys
 import threading
 import time
-from typing import Any, Callable, Iterable, Iterator, Literal, Sequence, Tuple, Type, Union
+from typing import Annotated, Any, Callable, Iterable, Iterator, Literal, Sequence, Tuple, Type, Union
 
 from langfun.core import component
 import pyglove as pg
@@ -143,43 +144,33 @@ def with_retry(
     A function with the same signature of the input function, with the retry
     capability.
   """
-  rand = random if seed is None else random.Random(seed)
 
-  def _func(*args, **kwargs) -> Any:
-    def base_interval() -> int:
-      if isinstance(retry_interval, tuple):
-        return rand.randint(retry_interval[0], retry_interval[1])
-      else:
-        assert isinstance(retry_interval, int)
-        return retry_interval
-
-    def next_wait_interval(attempt: int) -> float:
-      if not exponential_backoff:
-        attempt = 1
-      return min(max_retry_interval, base_interval() * (2 ** (attempt - 1)))
-
-    wait_intervals = []
-    errors = []
-    while True:
-      with pg.catch_errors(retry_on_errors) as error_context:
-        return func(*args, **kwargs)
+  def _func(*args, **kwargs):
+    job = Job(
+        func,
+        args,
+        kwargs,
+        retry_on_errors=retry_on_errors,
+        max_attempts=max_attempts,
+        retry_interval=retry_interval,
+        exponential_backoff=exponential_backoff,
+        max_retry_interval=max_retry_interval,
+        seed=seed,
+    )
+    job()
+    if job.error:
+      raise job.error
+    return job.result
 
-      # Branch when errors are met for retry.
-      errors.append(error_context.error)
-      if len(errors) < max_attempts:
-        wait_interval = next_wait_interval(len(errors))
-        wait_intervals.append(wait_interval)
+  return _func
 
-        pg.logging.warning(
-            f'Calling {func!r} encountered {error_context.error!r} '
-            f'(attempts={len(errors)}), retrying in {wait_interval} seconds...'
-        )
 
-        time.sleep(wait_interval)
-      else:
-        raise RetryError(func, errors, wait_intervals)
+class RetryEntry(pg.Object):
+  """Retry entry."""
 
-  return _func
+  call_interval: float
+  error: BaseException | None = None
+  wait_interval: float = 0.
 
 
 def concurrent_execute(
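
with_retry keeps its public signature; each call now runs through a single-use Job that records one RetryEntry per attempt (see the Job changes further below). A hedged sketch of how the wrapper might be exercised, with illustrative function names only:

    import langfun.core.concurrent as lf_concurrent

    def flaky() -> str:
      # Always fails; stands in for an RPC that times out.
      raise ValueError('transient failure')

    safe_flaky = lf_concurrent.with_retry(
        flaky,
        (ValueError,),          # retry_on_errors
        max_attempts=2,
        retry_interval=1,
        exponential_backoff=False,
    )
    try:
      safe_flaky()
    except lf_concurrent.RetryError as e:
      print(e)  # Still raised once max_attempts is exhausted, as before.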
@@ -197,6 +188,7 @@ def concurrent_execute(
     retry_interval: int | tuple[int, int] = (5, 60),
     exponential_backoff: bool = True,
     max_retry_interval: int = 300,
+    return_jobs: bool = False,
 ) -> list[Any]:
   """Executes a function concurrently under current component context.
 
@@ -220,32 +212,52 @@
     max_retry_interval: The max retry interval in seconds. This is useful when
       the retry interval is exponential, to avoid the wait time to grow
       exponentially.
+    return_jobs: If True, return a list of `Job` objects. Otherwise, return a
+      list of outputs.
 
   Returns:
     A list of ouputs. Each is the return value of `func` based on the input
     value. Order is preserved.
   """
-  if retry_on_errors is not None:
-    func = with_retry(
-        func,
-        retry_on_errors,
-        max_attempts=max_attempts,
-        retry_interval=retry_interval,
-        exponential_backoff=exponential_backoff,
-        max_retry_interval=max_retry_interval,
+  jobs = []
+  for inputs in parallel_inputs:
+    jobs.append(
+        Job(
+            func,
+            (inputs,),
+            retry_on_errors=retry_on_errors,
+            max_attempts=max_attempts,
+            retry_interval=retry_interval,
+            exponential_backoff=exponential_backoff,
+            max_retry_interval=max_retry_interval,
+        )
     )
 
   # NOTE(daiyip): when executor is not specified and max_worker is 1,
   # we don't need to create a executor pool. Instead, the inputs will be
   # processed by the user function in sequence within the current thread.
   if executor is None and max_workers == 1:
-    return [func(i) for i in parallel_inputs]
+    for job in jobs:
+      job()
+      if job.error:
+        raise job.error
+    return jobs if return_jobs else [job.result for job in jobs]
 
   shutdown_after_finish = executor is None
   executor = _executor_pool.executor_from(executor, max_workers=max_workers)
 
   try:
-    return list(executor.map(with_context_access(func), parallel_inputs))
+    executed_jobs = list(
+        executor.map(
+            lambda job: job(), [with_context_access(job) for job in jobs]
+        )
+    )
+    for job in executed_jobs:
+      if job.error:
+        raise job.error
+    return (
+        executed_jobs if return_jobs else [job.result for job in executed_jobs]
+    )
   finally:
     if shutdown_after_finish:
      # Do not wait threads to finish if they are timed out.
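
A hedged usage sketch (not from the commit) of the new return_jobs flag: when set, concurrent_execute hands back the Job objects, so the time spent on retried attempts stays inspectable rather than being discarded. The function and inputs below are illustrative only:

    import langfun.core.concurrent as lf_concurrent

    def double(x: int) -> int:
      return x * 2

    jobs = lf_concurrent.concurrent_execute(
        double,
        [1, 2, 3],
        max_workers=2,
        retry_on_errors=(TimeoutError,),  # Enables per-attempt RetryEntry records.
        return_jobs=True,
    )
    for job in jobs:
      # One entry per attempt; error is None on the successful attempt.
      print(job.result, [e.call_interval for e in job.retry_entries])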
@@ -257,9 +269,61 @@ class Job:
   """Thread pool job."""
 
   func: Callable[[Any], Any]
-  arg: Any
+  args: Sequence[Any] = ()
+  kwargs: Mapping[str, Any] = dataclasses.field(default_factory=dict)
+  _: dataclasses.KW_ONLY
+
   result: Any = pg.MISSING_VALUE
-  error: BaseException | None = None
+  error: Annotated[
+      BaseException | None,
+      'The non-retryable error encountered during the job execution.',
+  ] = None
+  retry_entries: Annotated[
+      Sequence[RetryEntry], 'Records of retry attempts.'
+  ] = dataclasses.field(default_factory=list)
+
+  retry_on_errors: Annotated[
+      Sequence[Type[BaseException] | str],
+      (
+          'A sequence of exception types or tuples of exception type and error '
+          'messages (described in regular expression) as the desired exception '
+          'types to retry.'
+      ),
+  ] = ()
+  max_attempts: Annotated[
+      int, 'Max number of attempts if an error to retry is encountered.'
+  ] = 5
+  retry_interval: Annotated[
+      int | tuple[int, int],
+      (
+          'The (base) retry interval in seconds. If a tuple, the retry '
+          'interval will be randomly chosen between the first and the second '
+          'element of the tuple.'
+      ),
+  ] = (5, 60)
+  exponential_backoff: Annotated[
+      bool,
+      (
+          'If True, exponential wait time will be applied on top of the base '
+          'retry interval.'
+      ),
+  ] = True
+  max_retry_interval: Annotated[
+      int,
+      (
+          'The max retry interval in seconds. This is useful when the retry '
+          'interval is exponential, to avoid the wait time to grow '
+          'exponentially.'
+      ),
+  ] = 300
+  seed: Annotated[
+      int | None,
+      (
+          'Random seed to generate retry interval. If None, the seed will be'
+          ' determined based on current time.'
+      ),
+  ] = None
+
   timeit: pg.object_utils.TimeIt = dataclasses.field(
       default_factory=lambda: pg.object_utils.TimeIt('job')
   )
@@ -269,14 +333,70 @@ def elapse(self) -> float:
     """Returns the running time in seconds since the job get started."""
     return self.timeit.elapse
 
-  def __call__(self) -> Any:
+  def _retry_call(self) -> 'Job':
+    """Retries func call on args."""
+    rand = random if self.seed is None else random.Random(self.seed)
+
+    def base_interval() -> int:
+      if isinstance(self.retry_interval, tuple):
+        return rand.randint(*self.retry_interval)
+      else:
+        assert isinstance(self.retry_interval, int)
+        return self.retry_interval
+
+    def next_wait_interval(attempt: int) -> float:
+      if not self.exponential_backoff:
+        attempt = 1
+      return min(
+          self.max_retry_interval, base_interval() * (2 ** (attempt - 1))
+      )
+
+    retry_entries = []
+    wait_interval = 0
+    while True:
+      with pg.catch_errors(self.retry_on_errors) as error_context:
+        begin_time = time.time()
+        self.result = self.func(*self.args, **self.kwargs)
+
+      end_time = time.time()
+      retry_entries.append(RetryEntry(
+          call_interval=end_time - begin_time,
+          wait_interval=wait_interval,
+          error=error_context.error,
+      ))
+      if error_context.error is None:
+        self.retry_entries = retry_entries
+        return self
+
+      # Branch when errors are met for retry.
+      if len(retry_entries) < self.max_attempts:
+        wait_interval = next_wait_interval(len(retry_entries))
+
+        pg.logging.warning(
+            f'Calling {self.func!r} encountered {error_context.error!r} '
+            f'(attempts={len(retry_entries)}), retrying in '
+            f'{wait_interval} seconds...'
+        )
+
+        time.sleep(wait_interval)
+      else:
+        errors = [e.error for e in retry_entries]
+        # First wait interval is 0.
+        wait_intervals = [e.wait_interval for e in retry_entries[1:]]
+        raise RetryError(self.func, errors, wait_intervals)
+
+  def __call__(self) -> 'Job':
+    if getattr(self, '_has_call', False):
+      raise ValueError('Job can only be called once.')
+    self._has_call = True
     try:
       with self.timeit:
-        self.result = self.func(self.arg)
-        return self.result
+        if self.retry_on_errors:
+          return self._retry_call()
+        self.result = self.func(*self.args, **self.kwargs)
     except BaseException as e:  # pylint: disable=broad-exception-caught
       self.error = e
-      return e
+    return self
 
   def mark_canceled(self, error: BaseException) -> None:
     """Marks the job as canceled."""
@@ -537,7 +657,8 @@ def concurrent_map(
     max_attempts: int = 5,
     retry_interval: int | tuple[int, int] = (5, 60),
     exponential_backoff: bool = True,
-) -> Iterator[tuple[Any, Any, BaseException | None]]:
+    return_jobs: bool = False,
+) -> Iterator[Any]:
   """Maps inputs to outptus via func concurrently under current context.
 
   Args:
@@ -580,9 +701,10 @@
       of the tuple.
     exponential_backoff: If True, exponential wait time will be applied on top
       of the base retry interval.
+    return_jobs: If True, the returned iterator will emit `Job` objects.
 
   Yields:
-    An iterator of (input, output, error).
+    An iterator of (input, output, error) or Job object.
 
   Raises:
     Exception: Errors that are not in `silence_on_errors` or `retry_on_errors`,
@@ -592,15 +714,6 @@
   """
   # Internal usage logging.
 
-  if retry_on_errors:
-    func = with_retry(
-        func,
-        retry_on_errors,
-        max_attempts=max_attempts,
-        retry_interval=retry_interval,
-        exponential_backoff=exponential_backoff,
-    )
-
   status_fn = status_fn or (lambda p: {  # pylint: disable=g-long-lambda
       'Succeeded': '%.2f%% (%d/%d)' % (
           p.success_rate * 100, p.succeeded, p.completed),
@@ -615,7 +728,14 @@
   pending_futures = []
   total = 0
   for inputs in parallel_inputs:
-    job = Job(func, inputs)
+    job = Job(
+        func,
+        (inputs,),
+        retry_on_errors=retry_on_errors,
+        max_attempts=max_attempts,
+        retry_interval=retry_interval,
+        exponential_backoff=exponential_backoff,
+    )
     future = executor.submit(
         with_context_access(job),
     )
@@ -668,7 +788,7 @@ def update_progress_bar(progress: Progress) -> None:
             silence_on_errors and isinstance(job.error, silence_on_errors)):
          raise job.error  # pylint: disable=g-doc-exception
 
-        yield job.arg, job.result, job.error
+        yield job if return_jobs else job.args[0], job.result, job.error
         progress.update(job)
         update_progress_bar(progress)
         ProgressBar.refresh()
@@ -689,7 +809,7 @@
             if job.error is not None and not (
                 silence_on_errors and isinstance(job.error, silence_on_errors)):
              raise job.error  # pylint: disable=g-doc-exception
-            yield job.arg, job.result, job.error
+            yield job if return_jobs else job.args[0], job.result, job.error
             progress.update(job)
             update_progress_bar(progress)
             completed_batch.add(future)
@@ -712,7 +832,7 @@
               and isinstance(job.error, silence_on_errors)):
            raise job.error  # pylint: disable=g-doc-exception
 
-          yield job.arg, job.result, job.error
+          yield job.args[0], job.result, job.error
           progress.update(job)
           update_progress_bar(progress)
         else:
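
concurrent_map keeps yielding (input, output, error) tuples by default; the retry settings now travel with each Job instead of wrapping func up front, and return_jobs exposes the jobs themselves. A hedged sketch using the default tuple form (names are illustrative):

    import langfun.core.concurrent as lf_concurrent

    def square(x: int) -> int:
      return x * x

    for inp, out, err in lf_concurrent.concurrent_map(
        square,
        [1, 2, 3],
        max_workers=2,
        retry_on_errors=(TimeoutError,),
    ):
      # e.g. (2, 4, None); completion order may differ from input order.
      print(inp, out, err)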
