From 533c14ac0ead521764b8466632eab852e0e891ac Mon Sep 17 00:00:00 2001
From: Gene Huang
Date: Thu, 13 Feb 2025 18:59:22 -0800
Subject: [PATCH] Create and properly shut down a thread pool per transform iterator.

PiperOrigin-RevId: 726723702
---
 ml_metrics/_src/chainables/transform.py      | 15 ++++----
 ml_metrics/_src/chainables/transform_test.py | 10 +++---
 ml_metrics/_src/utils/iter_utils.py          | 37 ++++++++++----------
 ml_metrics/_src/utils/iter_utils_test.py     | 23 ++++++------
 4 files changed, 47 insertions(+), 38 deletions(-)

diff --git a/ml_metrics/_src/chainables/transform.py b/ml_metrics/_src/chainables/transform.py
index abe775ad..1a1262ae 100644
--- a/ml_metrics/_src/chainables/transform.py
+++ b/ml_metrics/_src/chainables/transform.py
@@ -152,6 +152,7 @@ def __init__(
     self.agg_state = state
     self._with_agg_state = state and (with_agg_state or with_agg_result)
     self._with_agg_result = with_agg_result
+    self._thread_pool = None

     def iterate_fn(
         input_iterator: Iterable[tree.MapLikeTree | None] = (),
@@ -168,13 +169,18 @@ def iterate_fn(
     iterator = iter_utils.piter(
         iterate_fn,
         input_iterators=[self.input_iterator],
-        thread_pool=_get_thread_pool(),
+        thread_pool=self._get_thread_pool(),
         max_parallism=tree_fn.num_threads,
     )
     self._iterator = iter(iterator)
     self._prev_ticker = time.time()
     self.batch_index = 0

+  def _get_thread_pool(self) -> futures.ThreadPoolExecutor:
+    if self._thread_pool is None:
+      self._thread_pool = futures.ThreadPoolExecutor()
+    return self._thread_pool
+
   @property
   def name(self) -> str:
     return self._tree_fn.name
@@ -231,6 +237,8 @@ def __next__(self) -> _ValueT:
           f' {self.batch_index} batches, returned a type of'
           f' {type(returned)}.',
       )
+      if self._thread_pool is not None:
+        self._thread_pool.shutdown(wait=False)
       raise StopIteration(returned) if returned else e

   def __iter__(self):
@@ -256,11 +264,6 @@ def clear_cache():
   lazy_fns.clear_cache()


-@functools.lru_cache(maxsize=1)
-def _get_thread_pool():
-  return futures.ThreadPoolExecutor(thread_name_prefix='chainable_mt')
-
-
 @dataclasses.dataclass(frozen=True, kw_only=True)
 class CombinedTreeFn:
   """Combining multiple transforms into concrete functions.

diff --git a/ml_metrics/_src/chainables/transform_test.py b/ml_metrics/_src/chainables/transform_test.py
index 2a529473..1206f975 100644
--- a/ml_metrics/_src/chainables/transform_test.py
+++ b/ml_metrics/_src/chainables/transform_test.py
@@ -199,11 +199,13 @@ def test_sharded_sequence_data_source_make(self):
   def test_sharded_sequence_data_source_multithread(self):
     ds = io.SequenceDataSource(range(3))
     num_threads = 2
-    p = transform.TreeTransform(num_threads=num_threads).data_source(
-        ds
-    )
+    p = transform.TreeTransform(num_threads=num_threads).data_source(ds)
     expected = [0, 1, 2]
-    self.assertEqual(expected, list(p.make()))
+    it = p.make().iterate()
+    self.assertEqual(expected, list(it))
+    assert it._thread_pool is not None
+    self.assertNotEmpty(it._thread_pool._threads)
+    self.assertTrue(all(not t.is_alive() for t in it._thread_pool._threads))

   def test_sharded_sequence_data_source_resume(self):
     ds = io.SequenceDataSource(range(3))
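The transform.py hunks above replace the shared, lru_cache'd module pool with one executor per iterator: created lazily by _get_thread_pool() on first use, then shut down with wait=False when __next__ hits StopIteration. Below is a minimal standalone sketch of that lifecycle; _PooledIterator is a hypothetical stand-in for the real iterator class, and the pool's actual work (feeding piter) is elided so the create/shutdown path stays visible.

    from concurrent import futures


    class _PooledIterator:
      """Hypothetical iterator that owns one executor per iteration."""

      def __init__(self, iterable):
        self._it = iter(iterable)
        self._thread_pool = None  # no threads until first use

      def _get_thread_pool(self):
        # Lazy creation: iterators that are never consumed start no threads.
        if self._thread_pool is None:
          self._thread_pool = futures.ThreadPoolExecutor()
        return self._thread_pool

      def __iter__(self):
        return self

      def __next__(self):
        self._get_thread_pool()  # the real code hands this pool to piter()
        try:
          return next(self._it)
        except StopIteration:
          # wait=False releases the workers without blocking the consumer.
          if self._thread_pool is not None:
            self._thread_pool.shutdown(wait=False)
          raise

Because each iterator now owns its pool, exhausting one pipeline can no longer leave threads cached for, or stolen by, an unrelated pipeline.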
diff --git a/ml_metrics/_src/utils/iter_utils.py b/ml_metrics/_src/utils/iter_utils.py
index 4949884d..e9aec4a4 100644
--- a/ml_metrics/_src/utils/iter_utils.py
+++ b/ml_metrics/_src/utils/iter_utils.py
@@ -356,8 +356,8 @@ class IteratorQueue(IterableQueue[_ValueT]):
   _dequeue_lock: threading.Condition
   # The lock here protects the access to all states below.
   _states_lock: threading.RLock
-  _enqueue_cnt: int
-  _dequeue_cnt: int
+  _enqueue_start: int
+  _enqueue_stop: int
   _parallelism: int
   _returned: list[Any]
   _exception: Exception | None
@@ -390,8 +390,8 @@ def __init__(
     self._exception = None
     self._exhausted = False
     self._parallelism = parallelism
-    self._enqueue_cnt = 0
-    self._dequeue_cnt = 0
+    self._enqueue_start = 0
+    self._enqueue_stop = 0
     self._run_enqueue = True
     self.ignore_error = ignore_error

@@ -419,7 +419,7 @@ def _set_exhausted(self):
         'chainable: "%s" dequeue exhausted with_exception=%s, remaining %d',
         self.name,
         self.exception is not None,
-        (self._enqueue_cnt - self._dequeue_cnt),
+        (self._enqueue_start - self._enqueue_stop),
     )
     self._exhausted = True
     self._dequeue_lock.notify_all()
@@ -439,7 +439,7 @@ def enqueue_done(self) -> bool:
     # If parallelism is not set, it means no enqueuer has started yet.
     if not self._parallelism:
       return False
-    return self._enqueue_cnt == self._dequeue_cnt == self._parallelism
+    return self._enqueue_start == self._enqueue_stop == self._parallelism

   @property
   def returned(self) -> list[Any]:
@@ -563,18 +563,18 @@ def put(self, value: _ValueT) -> None:

   def _start_enqueue(self):
     with self._states_lock:
-      self._enqueue_cnt += 1
-      self._parallelism = max(self._parallelism, self._enqueue_cnt)
+      self._enqueue_start += 1
+      self._parallelism = max(self._parallelism, self._enqueue_start)

   def _stop_enqueue(self, *values):
     """Stops enqueueing and records the returned values."""
     with self._states_lock:
-      self._dequeue_cnt += 1
+      self._enqueue_stop += 1
       self._returned.extend(values)
       logging.debug(
           'chainable: "%s" enqueue stop, remaining %d, with_exception=%s',
           self.name,
-          self._dequeue_cnt,
+          (self._enqueue_start - self._enqueue_stop),
           self.exception is not None,
       )
       if self.enqueue_done:
@@ -586,7 +586,7 @@ def stop_enqueue(self):
     self._run_enqueue = False
     with self._states_lock:
-      self._dequeue_cnt = self._enqueue_cnt
+      self._enqueue_stop = self._enqueue_start
     with self._enqueue_lock:
       self._enqueue_lock.notify_all()
     with self._dequeue_lock:
       self._dequeue_lock.notify_all()
@@ -602,6 +602,7 @@ def enqueue_from_iterator(self, iterator: Iterable[_ValueT]):
         self.put(next(iterator))
       except StopIteration as e:
         self._stop_enqueue(*e.args)
+        logging.debug('chainable: "%s" enqueue stop', self.name)
         return
       except Exception as e:  # pylint: disable=broad-exception-caught
         e.add_note(f'Exception during enqueueing "{self.name}".')
@@ -759,9 +760,9 @@ def __aiter__(self):
 def _get_thread_pool(
     thread_pool: futures.ThreadPoolExecutor | None = None,
 ) -> futures.ThreadPoolExecutor:
-  if isinstance(thread_pool, futures.ThreadPoolExecutor):
-    return thread_pool
-  return futures.ThreadPoolExecutor(thread_name_prefix='piter')
+  if thread_pool is None:
+    thread_pool = futures.ThreadPoolExecutor(thread_name_prefix='piter')
+  return thread_pool


 def _get_iterate_fn(
@@ -815,9 +816,9 @@ def piter(
         name='piter_input_q',
         parallelism=len(input_iterators),
     )
-    pool = _get_thread_pool(thread_pool)
+    thread_pool = _get_thread_pool(thread_pool)
     for iterator in input_iterators:
-      pool.submit(input_iterable.enqueue_from_iterator, iterator)
+      thread_pool.submit(input_iterable.enqueue_from_iterator, iterator)
   if iterator_fn is None:
     assert input_iterable is not None
     return input_iterable
@@ -826,10 +827,10 @@ def piter(
   output_queue = IteratorQueue(
       buffer_size, name='piter_output_q', parallelism=max_parallism
   )
   if max_parallism:
-    pool = _get_thread_pool(thread_pool)
+    thread_pool = _get_thread_pool(thread_pool)
     for _ in range(max_parallism):
       it = _get_iterate_fn(iterator_fn, input_iterable)
-      pool.submit(output_queue.enqueue_from_iterator, it)
+      thread_pool.submit(output_queue.enqueue_from_iterator, it)
     return output_queue
   # In process mode when max_parallism is 0.
   return _get_iterate_fn(iterator_fn, input_iterable)
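The rename from _enqueue_cnt/_dequeue_cnt to _enqueue_start/_enqueue_stop makes the bookkeeping read as producers-started versus producers-finished; enqueue_done then holds exactly when every registered producer has stopped, and the "remaining" figure in the debug logs becomes their difference rather than a raw counter. A condensed standalone sketch of that invariant, with names mirroring the diff and the queue machinery omitted:

    import threading


    class _EnqueueBookkeeping:
      """Sketch of IteratorQueue's producer accounting after the rename."""

      def __init__(self):
        self._states_lock = threading.RLock()
        self._enqueue_start = 0  # producers that began enqueueing
        self._enqueue_stop = 0   # producers that finished or were stopped
        self._parallelism = 0    # high-water mark of registered producers

      def _start_enqueue(self):
        with self._states_lock:
          self._enqueue_start += 1
          self._parallelism = max(self._parallelism, self._enqueue_start)

      def _stop_enqueue(self):
        with self._states_lock:
          self._enqueue_stop += 1

      @property
      def enqueue_done(self):
        # If parallelism is not set, no enqueuer has started yet.
        if not self._parallelism:
          return False
        return self._enqueue_start == self._enqueue_stop == self._parallelism

      @property
      def remaining(self):
        # What the patched debug logs report as 'remaining %d'.
        return self._enqueue_start - self._enqueue_stop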
diff --git a/ml_metrics/_src/utils/iter_utils_test.py b/ml_metrics/_src/utils/iter_utils_test.py
index bd2c248e..67d2d3ea 100644
--- a/ml_metrics/_src/utils/iter_utils_test.py
+++ b/ml_metrics/_src/utils/iter_utils_test.py
@@ -576,17 +576,19 @@ def test_piter_iterate_fn_only(self):
     def foo():
       yield from input_iter

-    pit = iter_utils.piter(foo, max_parallism=256)
-    actual = list(pit)
-    self.assertCountEqual(list(range(n)), actual)
+    with futures.ThreadPoolExecutor() as thread_pool:
+      pit = iter_utils.piter(foo, max_parallism=1, thread_pool=thread_pool)
+      actual = list(pit)
+      self.assertCountEqual(list(range(n)), actual)
     self.assertIsInstance(pit, iter_utils.IteratorQueue)

   def test_piter_multiple_iterators(self):
     n, m = 256, 2
     assert n % m == 0
     inputs = [range_with_return(m, 0.3) for _ in range(int(n / m))]
-    pit = iter_utils.piter(input_iterators=inputs)
-    actual = list(pit)
+    with futures.ThreadPoolExecutor() as thread_pool:
+      pit = iter_utils.piter(input_iterators=inputs, thread_pool=thread_pool)
+      actual = list(pit)
     expected = list(itt.chain(*[range(m) for _ in range(int(n / m))]))
     self.assertCountEqual(expected, actual)
     self.assertIsInstance(pit, iter_utils.IteratorQueue)
@@ -597,9 +599,9 @@ def test_piter_multiple_iterators_concurrent_dequeue(self):
     inputs = [range(m) for _ in range(int(n / m))]
     # This is to test when not all enqueuers are started before the dequeuer is,
     # which could cause premature StopIteration controlled by `enqueue_done`.
-    pool = futures.ThreadPoolExecutor(max_workers=1)
-    pit = iter_utils.piter(input_iterators=inputs, thread_pool=pool)
-    actual = list(pit)
+    with futures.ThreadPoolExecutor(max_workers=1) as pool:
+      pit = iter_utils.piter(input_iterators=inputs, thread_pool=pool)
+      actual = list(pit)
     expected = list(itt.chain(*inputs))
     self.assertCountEqual(expected, actual)
     self.assertIsInstance(pit, iter_utils.IteratorQueue)
@@ -626,8 +628,9 @@ def foo(x):
       return x + 1

     n = 256
-    it = iter_utils.pmap(foo, range(n), max_parallism=256)
-    actual = list(it)
+    with futures.ThreadPoolExecutor() as pool:
+      it = iter_utils.pmap(foo, range(n), max_parallism=256, thread_pool=pool)
+      actual = list(it)
     self.assertCountEqual(list(range(1, n + 1)), actual)
     self.assertIsInstance(it, iter_utils.IteratorQueue)
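With _get_thread_pool no longer manufacturing a replacement for a supplied executor, a caller that passes thread_pool owns its lifetime, which is why the updated tests wrap piter in a with block. An illustrative use under that contract, mirroring test_piter_multiple_iterators and assuming the import path shown in the diffs:

    from concurrent import futures

    from ml_metrics._src.utils import iter_utils

    inputs = [range(4) for _ in range(8)]
    with futures.ThreadPoolExecutor() as pool:
      pit = iter_utils.piter(input_iterators=inputs, thread_pool=pool)
      merged = list(pit)  # drain inside the block so enqueuers can finish
    # Exiting the block joins the workers; no cached module-level pool lingers.
    assert sorted(merged) == sorted(x for r in inputs for x in r)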