Create and properly shut down a thread pool per transform iterator. #556

Open · wants to merge 1 commit into main
ml_metrics/_src/chainables/transform.py: 9 additions & 6 deletions

@@ -152,6 +152,7 @@ def __init__(
     self.agg_state = state
     self._with_agg_state = state and (with_agg_state or with_agg_result)
     self._with_agg_result = with_agg_result
+    self._thread_pool = None

     def iterate_fn(
         input_iterator: Iterable[tree.MapLikeTree | None] = (),
@@ -168,13 +169,18 @@ def iterate_fn(
     iterator = iter_utils.piter(
         iterate_fn,
         input_iterators=[self.input_iterator],
-        thread_pool=_get_thread_pool(),
+        thread_pool=self._get_thread_pool(),
         max_parallism=tree_fn.num_threads,
     )
     self._iterator = iter(iterator)
     self._prev_ticker = time.time()
     self.batch_index = 0

+  def _get_thread_pool(self) -> futures.ThreadPoolExecutor:
+    if self._thread_pool is None:
+      self._thread_pool = futures.ThreadPoolExecutor()
+    return self._thread_pool
+
   @property
   def name(self) -> str:
     return self._tree_fn.name
@@ -231,6 +237,8 @@ def __next__(self) -> _ValueT:
           f' {self.batch_index} batches, returned a type of'
           f' {type(returned)}.',
       )
+      if self._thread_pool is not None:
+        self._thread_pool.shutdown(wait=False)
       raise StopIteration(returned) if returned else e

   def __iter__(self):
@@ -256,11 +264,6 @@ def clear_cache():
   lazy_fns.clear_cache()


-@functools.lru_cache(maxsize=1)
-def _get_thread_pool():
-  return futures.ThreadPoolExecutor(thread_name_prefix='chainable_mt')
-
-
 @dataclasses.dataclass(frozen=True, kw_only=True)
 class CombinedTreeFn:
   """Combining multiple transforms into concrete functions.
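Taken together, this file's change replaces a process-wide cached executor with one pool per iterator: created lazily on first use, shut down as soon as the iterator is exhausted. Below is a self-contained sketch of that lifecycle; PooledIterator and its fields are illustrative stand-ins, not the module's actual classes.

from concurrent import futures


class PooledIterator:
  """Sketch: one lazily created thread pool per iterator instance."""

  def __init__(self, values):
    self._values = iter(values)
    self._thread_pool = None  # Not created until iteration needs it.

  def _get_thread_pool(self) -> futures.ThreadPoolExecutor:
    # Lazy init: the pool exists only once iteration actually starts.
    if self._thread_pool is None:
      self._thread_pool = futures.ThreadPoolExecutor()
    return self._thread_pool

  def __iter__(self):
    return self

  def __next__(self):
    try:
      return next(self._values)
    except StopIteration:
      # Mirrors the change in __next__ above: release the workers as soon
      # as the iterator is exhausted, without blocking the caller.
      if self._thread_pool is not None:
        self._thread_pool.shutdown(wait=False)
      raise

Using shutdown(wait=False) lets the final StopIteration propagate without blocking on worker teardown; idle workers exit on their own afterwards.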
ml_metrics/_src/chainables/transform_test.py: 6 additions & 4 deletions

@@ -199,11 +199,13 @@ def test_sharded_sequence_data_source_make(self):
   def test_sharded_sequence_data_source_multithread(self):
     ds = io.SequenceDataSource(range(3))
     num_threads = 2
-    p = transform.TreeTransform(num_threads=num_threads).data_source(
-        ds
-    )
+    p = transform.TreeTransform(num_threads=num_threads).data_source(ds)
     expected = [0, 1, 2]
-    self.assertEqual(expected, list(p.make()))
+    it = p.make().iterate()
+    self.assertEqual(expected, list(it))
+    assert it._thread_pool is not None
+    self.assertNotEmpty(it._thread_pool._threads)
+    self.assertTrue(all(not t.is_alive() for t in it._thread_pool._threads))

   def test_sharded_sequence_data_source_resume(self):
     ds = io.SequenceDataSource(range(3))
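The updated test drives the iterator to exhaustion and then reaches into two private attributes to confirm the pool both existed and was torn down. Here is a minimal standalone sketch of the ThreadPoolExecutor behavior it relies on; note that _threads is a private CPython detail, and wait=True is used here for determinism where the production code passes wait=False.

import time
from concurrent import futures

pool = futures.ThreadPoolExecutor(max_workers=2)
pool.submit(time.sleep, 0.01)  # Forces at least one worker thread to spawn.
pool.shutdown(wait=True)  # Joins the workers; after this none is alive.
assert pool._threads  # Private attribute, mirroring the test's access.
assert all(not t.is_alive() for t in pool._threads)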
ml_metrics/_src/utils/iter_utils.py: 19 additions & 18 deletions

@@ -356,8 +356,8 @@ class IteratorQueue(IterableQueue[_ValueT]):
   _dequeue_lock: threading.Condition
   # The lock here protects the access to all states below.
   _states_lock: threading.RLock
-  _enqueue_cnt: int
-  _dequeue_cnt: int
+  _enqueue_start: int
+  _enqueue_stop: int
   _parallelism: int
   _returned: list[Any]
   _exception: Exception | None
@@ -390,8 +390,8 @@ def __init__(
     self._exception = None
     self._exhausted = False
     self._parallelism = parallelism
-    self._enqueue_cnt = 0
-    self._dequeue_cnt = 0
+    self._enqueue_start = 0
+    self._enqueue_stop = 0
     self._run_enqueue = True
     self.ignore_error = ignore_error

@@ -419,7 +419,7 @@ def _set_exhausted(self):
         'chainable: "%s" dequeue exhausted with_exception=%s, remaining %d',
         self.name,
         self.exception is not None,
-        (self._enqueue_cnt - self._dequeue_cnt),
+        (self._enqueue_start - self._enqueue_stop),
     )
     self._exhausted = True
     self._dequeue_lock.notify_all()
@@ -439,7 +439,7 @@ def enqueue_done(self) -> bool:
     # If parallelism is not set, it means the no enqueuer has started yet.
     if not self._parallelism:
       return False
-    return self._enqueue_cnt == self._dequeue_cnt == self._parallelism
+    return self._enqueue_start == self._enqueue_stop == self._parallelism

   @property
   def returned(self) -> list[Any]:
@@ -563,18 +563,18 @@ def put(self, value: _ValueT) -> None:

   def _start_enqueue(self):
     with self._states_lock:
-      self._enqueue_cnt += 1
-      self._parallelism = max(self._parallelism, self._enqueue_cnt)
+      self._enqueue_start += 1
+      self._parallelism = max(self._parallelism, self._enqueue_start)

   def _stop_enqueue(self, *values):
     """Stops enqueueing and records the returned values."""
     with self._states_lock:
-      self._dequeue_cnt += 1
+      self._enqueue_stop += 1
       self._returned.extend(values)
       logging.debug(
           'chainable: "%s" enqueue stop, remaining %d, with_exception=%s',
           self.name,
-          self._dequeue_cnt,
+          (self._enqueue_start - self._enqueue_stop),
           self.exception is not None,
       )
       if self.enqueue_done:
@@ -586,7 +586,7 @@ def _stop_enqueue(self, *values):
   def stop_enqueue(self):
     self._run_enqueue = False
     with self._states_lock:
-      self._dequeue_cnt = self._enqueue_cnt
+      self._enqueue_stop = self._enqueue_start
     with self._enqueue_lock:
       self._enqueue_lock.notify_all()
     with self._dequeue_lock:
@@ -602,6 +602,7 @@ def enqueue_from_iterator(self, iterator: Iterable[_ValueT]):
         self.put(next(iterator))
       except StopIteration as e:
         self._stop_enqueue(*e.args)
+        logging.debug('chainable: "%s" enqueue stop', self.name)
         return
       except Exception as e:  # pylint: disable=broad-exception-caught
         e.add_note(f'Exception during enqueueing "{self.name}".')
@@ -759,9 +760,9 @@ def __aiter__(self):
 def _get_thread_pool(
     thread_pool: futures.ThreadPoolExecutor | None = None,
 ) -> futures.ThreadPoolExecutor:
-  if isinstance(thread_pool, futures.ThreadPoolExecutor):
-    return thread_pool
-  return futures.ThreadPoolExecutor(thread_name_prefix='piter')
+  if thread_pool is None:
+    thread_pool = futures.ThreadPoolExecutor(thread_name_prefix='piter')
+  return thread_pool


 def _get_iterate_fn(
@@ -815,9 +816,9 @@ def piter(
         name='piter_input_q',
         parallelism=len(input_iterators),
     )
-    pool = _get_thread_pool(thread_pool)
+    thread_pool = _get_thread_pool(thread_pool)
    for iterator in input_iterators:
-      pool.submit(input_iterable.enqueue_from_iterator, iterator)
+      thread_pool.submit(input_iterable.enqueue_from_iterator, iterator)
   if iterator_fn is None:
     assert input_iterable is not None
     return input_iterable
@@ -826,10 +827,10 @@ def piter(
       buffer_size, name='piter_output_q', parallelism=max_parallism
   )
   if max_parallism:
-    pool = _get_thread_pool(thread_pool)
+    thread_pool = _get_thread_pool(thread_pool)
     for _ in range(max_parallism):
       it = _get_iterate_fn(iterator_fn, input_iterable)
-      pool.submit(output_queue.enqueue_from_iterator, it)
+      thread_pool.submit(output_queue.enqueue_from_iterator, it)
     return output_queue
   # In process mode when max_parallism is 0.
   return _get_iterate_fn(iterator_fn, input_iterable)
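With _get_thread_pool now allocating a fallback pool only when the caller supplies none, pool ownership can sit with the caller, which is exactly what the tests below switch to. A hypothetical usage sketch in that spirit: piter, thread_pool, and max_parallism (sic, the library's spelling) are the real parameters from this file, while consume and source are illustrative.

from concurrent import futures

from ml_metrics._src.utils import iter_utils

source = iter(range(8))


def consume():
  # Each parallel worker drains the same shared iterator, as in the tests.
  yield from source


with futures.ThreadPoolExecutor(max_workers=2) as pool:
  pit = iter_utils.piter(consume, max_parallism=2, thread_pool=pool)
  results = sorted(pit)  # [0, 1, ..., 7], each element exactly once.
# Exiting the `with` block joins the pool's threads, so none leak.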
ml_metrics/_src/utils/iter_utils_test.py: 13 additions & 10 deletions

@@ -576,17 +576,19 @@ def test_piter_iterate_fn_only(self):
     def foo():
       yield from input_iter

-    pit = iter_utils.piter(foo, max_parallism=256)
-    actual = list(pit)
-    self.assertCountEqual(list(range(n)), actual)
+    with futures.ThreadPoolExecutor() as thread_pool:
+      pit = iter_utils.piter(foo, max_parallism=1, thread_pool=thread_pool)
+      actual = list(pit)
+      self.assertCountEqual(list(range(n)), actual)
     self.assertIsInstance(pit, iter_utils.IteratorQueue)

   def test_piter_multiple_iterators(self):
     n, m = 256, 2
     assert n % m == 0
     inputs = [range_with_return(m, 0.3) for _ in range(int(n / m))]
-    pit = iter_utils.piter(input_iterators=inputs)
-    actual = list(pit)
+    with futures.ThreadPoolExecutor() as thread_pool:
+      pit = iter_utils.piter(input_iterators=inputs, thread_pool=thread_pool)
+      actual = list(pit)
     expected = list(itt.chain(*[range(m) for _ in range(int(n / m))]))
     self.assertCountEqual(expected, actual)
     self.assertIsInstance(pit, iter_utils.IteratorQueue)
@@ -597,9 +599,9 @@ def test_piter_multiple_iterators_concurrent_dequeue(self):
     inputs = [range(m) for _ in range(int(n / m))]
     # This is to test when not all enqueuer are started before the dequeuer is,
     # which could cause premature StopIteration controled by `enqueue_done`.
-    pool = futures.ThreadPoolExecutor(max_workers=1)
-    pit = iter_utils.piter(input_iterators=inputs, thread_pool=pool)
-    actual = list(pit)
+    with futures.ThreadPoolExecutor(max_workers=1) as pool:
+      pit = iter_utils.piter(input_iterators=inputs, thread_pool=pool)
+      actual = list(pit)
     expected = list(itt.chain(*inputs))
     self.assertCountEqual(expected, actual)
     self.assertIsInstance(pit, iter_utils.IteratorQueue)
@@ -626,8 +628,9 @@ def foo(x):
       return x + 1

     n = 256
-    it = iter_utils.pmap(foo, range(n), max_parallism=256)
-    actual = list(it)
+    with futures.ThreadPoolExecutor() as pool:
+      it = iter_utils.pmap(foo, range(n), max_parallism=256, thread_pool=pool)
+      actual = list(it)
     self.assertCountEqual(list(range(1, n + 1)), actual)
     self.assertIsInstance(it, iter_utils.IteratorQueue)
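End to end, the user-visible contract after this PR mirrors test_sharded_sequence_data_source_multithread above: once a transform iterator is exhausted, its private pool's workers are gone. A sketch reusing that test's calls; the import paths are assumptions inferred from the test file, not confirmed public API.

# Import paths are assumed from the test file's usage of `io` and `transform`.
from ml_metrics._src.chainables import io
from ml_metrics._src.chainables import transform

ds = io.SequenceDataSource(range(3))
p = transform.TreeTransform(num_threads=2).data_source(ds)
it = p.make().iterate()
assert list(it) == [0, 1, 2]
# The exhausted iterator has already shut down its own pool:
assert it._thread_pool is not None
assert all(not t.is_alive() for t in it._thread_pool._threads)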