Commit f949335

skrawcz authored and elijahbenizzy committed
Adds lazy threadpool DAG parallelization
Taking inspiration from #1263, I implemented an adapter similar to how the async driver works. We can get away with this because we never cross a serde (serialization/deserialization) boundary, so Futures can be passed around directly. If you run the example DAG you'll see that:

1. it is parallelized as it should be
2. you can still use caching and the tracking adapter

Rough edges:

- haven't tested this extensively, but it seems to just work
- need to add tests, docs, etc.
1 parent 0ae54e8 commit f949335

File tree

4 files changed: +172 -0 lines changed
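Before the diffs, the core trick is worth sketching standalone: every node is submitted to a ThreadPoolExecutor as soon as it is encountered, with its upstream dependencies passed in as Futures, and each worker resolves those Futures before calling the real function. A minimal sketch of the pattern (not the plugin itself; the lambdas stand in for DAG nodes):

from concurrent.futures import Future, ThreadPoolExecutor


def _resolve_and_call(fn, **kwargs):
    # Block on any Future arguments inside the worker thread, then call fn.
    resolved = {k: v.result() if isinstance(v, Future) else v for k, v in kwargs.items()}
    return fn(**resolved)


executor = ThreadPoolExecutor(max_workers=4)
a = executor.submit(_resolve_and_call, lambda: "a")  # no dependencies
b = executor.submit(_resolve_and_call, lambda: "b")  # runs concurrently with a
c = executor.submit(_resolve_and_call, lambda a, b: a + b, a=a, b=b)  # blocks on a and b
print(c.result())  # prints "ab"

Note that workers block while waiting on upstream Futures, so max_workers effectively bounds how much of the DAG can be in flight at once.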
my_functions.py (the example DAG module)

@@ -0,0 +1,55 @@

import time


def a() -> str:
    print("a")
    time.sleep(3)
    return "a"


def b() -> str:
    print("b")
    time.sleep(3)
    return "b"


def c(a: str, b: str) -> str:
    print("c")
    time.sleep(3)
    return a + " " + b


def d() -> str:
    print("d")
    time.sleep(3)
    return "d"


def e(c: str, d: str) -> str:
    print("e")
    time.sleep(3)
    return c + " " + d


def z() -> str:
    print("z")
    time.sleep(3)
    return "z"


def y() -> str:
    print("y")
    time.sleep(3)
    return "y"


def x(z: str, y: str) -> str:
    print("x")
    time.sleep(3)
    return z + " " + y


def s(x: str, e: str) -> str:
    print("s")
    time.sleep(3)
    return x + " " + e
Example run script

@@ -0,0 +1,59 @@

import time

import my_functions

from hamilton import driver
from hamilton.plugins import h_threadpool

start = time.time()
adapter = h_threadpool.FutureAdapter()
dr = driver.Builder().with_modules(my_functions).with_adapters(adapter).build()
dr.display_all_functions("my_functions.png")
r = dr.execute(["s"])
print("got return from dr")
print(r)
print("Time taken with", time.time() - start)

from hamilton_sdk import adapters

tracker = adapters.HamiltonTracker(
    project_id=21,  # modify this as needed
    username="[email protected]",
    dag_name="with_caching",
    tags={"environment": "DEV", "cached": "False", "team": "MY_TEAM", "version": "1"},
)

start = time.time()
dr = (
    driver.Builder().with_modules(my_functions).with_adapters(tracker, adapter).with_cache().build()
)
dr.display_all_functions("a.png")
r = dr.execute(["s"])
print("got return from dr")
print(r)
print("Time taken with cold cache", time.time() - start)

tracker = adapters.HamiltonTracker(
    project_id=21,  # modify this as needed
    username="[email protected]",
    dag_name="with_caching",
    tags={"environment": "DEV", "cached": "True", "team": "MY_TEAM", "version": "1"},
)

start = time.time()
dr = (
    driver.Builder().with_modules(my_functions).with_adapters(tracker, adapter).with_cache().build()
)
dr.display_all_functions("a.png")
r = dr.execute(["s"])
print("got return from dr")
print(r)
print("Time taken with warm cache", time.time() - start)

start = time.time()
dr = driver.Builder().with_modules(my_functions).build()
dr.display_all_functions("a.png")
r = dr.execute(["s"])
print("got return from dr")
print(r)
print("Time taken without", time.time() - start)

hamilton/plugins/h_threadpool.py

@@ -0,0 +1,58 @@

from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Callable, Dict, Optional

from hamilton import registry

# Disable plugin autoloading before importing the rest of hamilton.
registry.disable_autoload()

from hamilton import lifecycle, node
from hamilton.lifecycle import base


def _new_fn(fn, **fn_kwargs):
    """Function that runs in the thread.

    It can recursively check for Futures because we don't have to worry about
    process serialization.

    :param fn: Function to run
    :param fn_kwargs: Keyword arguments to pass to the function
    """
    for k, v in fn_kwargs.items():
        if isinstance(v, Future):
            while isinstance(v, Future):
                v = v.result()
            fn_kwargs[k] = v
    # execute the function once all the futures are resolved
    return fn(**fn_kwargs)


class FutureAdapter(base.BaseDoRemoteExecute, lifecycle.ResultBuilder):
    def __init__(self, max_workers: Optional[int] = None):
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        # A ProcessPoolExecutor would hit the serde boundaries we're avoiding:
        # self.executor = ProcessPoolExecutor(max_workers=max_workers)

    def do_remote_execute(
        self,
        *,
        execute_lifecycle_for_node: Callable,
        node: node.Node,
        **kwargs: Dict[str, Any],
    ) -> Any:
        """Method that is called to implement correct remote execution of hooks.

        This makes sure that all the pre-node and post-node hooks get executed
        in the remote environment, which is necessary for some adapters. Node
        execution is called the same as before through "do_node_execute".

        :param node: Node that is being executed
        :param kwargs: Keyword arguments that are being passed into the node
        :param execute_lifecycle_for_node: Function executing lifecycle_hooks and lifecycle_methods
        """
        return self.executor.submit(_new_fn, execute_lifecycle_for_node, **kwargs)

    def build_result(self, **outputs: Any) -> Any:
        """Given a set of outputs, build the result.

        :param outputs: the outputs from the execution of the graph.
        :return: the result of the execution of the graph.
        """
        for k, v in outputs.items():
            if isinstance(v, Future):
                outputs[k] = v.result()
        return outputs
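The commit message notes that tests still need to be added; as a rough sketch of what one might assert (assuming pytest and the example my_functions module above; the test name and timing bound are hypothetical):

import time

import my_functions

from hamilton import driver
from hamilton.plugins import h_threadpool


def test_future_adapter_parallelizes_and_resolves():
    adapter = h_threadpool.FutureAdapter(max_workers=8)
    dr = driver.Builder().with_modules(my_functions).with_adapters(adapter).build()
    start = time.time()
    result = dr.execute(["s"])
    elapsed = time.time() - start
    # build_result must have resolved every Future into a concrete value.
    assert result["s"] == "z y a b d"
    # Nine 3-second nodes, but the critical path is four nodes deep (~12s),
    # so a parallel run should land well under the ~27s sequential cost.
    assert elapsed < 25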
