Skip to content

Commit cdffb5b

Browse files
committed
Support async put for vineyard client.
Signed-off-by: Ye Cao <[email protected]>
1 parent 0f78867 commit cdffb5b

File tree

2 files changed

+79
-2
lines changed

2 files changed

+79
-2
lines changed

python/vineyard/core/client.py

+40-2
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ def __init__(
168168
session: int = None,
169169
username: str = None,
170170
password: str = None,
171+
max_workers: int = 8,
171172
config: str = None,
172173
):
173174
"""Connects to the vineyard IPC socket and RPC socket.
@@ -211,6 +212,8 @@ def __init__(
211212
is enabled.
212213
password: Optional, the required password of vineyardd when authentication
213214
is enabled.
215+
max_workers: Optional, the maximum number of threads that can be used to
216+
asynchronously put objects to vineyard. Default is 8.
214217
config: Optional, can either be a path to a YAML configuration file or
215218
a path to a directory containing the default config file
216219
`vineyard-config.yaml`. Also, the environment variable
@@ -290,6 +293,9 @@ def __init__(
290293
except VineyardException:
291294
continue
292295

296+
self._max_workers = max_workers
297+
self._put_thread_pool = None
298+
293299
self._spread = False
294300
self._compression = True
295301
if self._ipc_client is None and self._rpc_client is None:
@@ -347,6 +353,13 @@ def rpc_client(self) -> RPCClient:
347353
assert self._rpc_client is not None, "RPC client is not available."
348354
return self._rpc_client
349355

356+
@property
def put_thread_pool(self) -> ThreadPoolExecutor:
    """Return the executor used for asynchronous puts, creating it on first use.

    The executor is built with ``self._max_workers`` worker threads and is
    cached on ``self._put_thread_pool``, so later accesses return the same
    instance.
    """
    pool = self._put_thread_pool
    if pool is not None:
        return pool
    # First access: build the pool now and remember it for subsequent calls.
    pool = ThreadPoolExecutor(max_workers=self._max_workers)
    self._put_thread_pool = pool
    return pool
362+
350363
def has_ipc_client(self):
351364
return self._ipc_client is not None
352365

@@ -820,8 +833,7 @@ def get(
820833
):
821834
return get(self, object_id, name, resolver, fetch, **kwargs)
822835

823-
@_apply_docstring(put)
824-
def put(
836+
def _put_internal(
825837
self,
826838
value: Any,
827839
builder: Optional[BuilderContext] = None,
@@ -858,6 +870,32 @@ def put(
858870
self.compression = previous_compression_state
859871
return put(self, value, builder, persist, name, **kwargs)
860872

873+
@_apply_docstring(put)
def put(
    self,
    value: Any,
    builder: Optional[BuilderContext] = None,
    persist: bool = False,
    name: Optional[str] = None,
    as_async: bool = False,
    **kwargs,
):
    # Put `value` either synchronously (default) or on the client's put
    # thread pool.
    #
    # - as_async=False: delegates to `_put_internal` and returns its result.
    # - as_async=True: submits `_put_internal` to `self.put_thread_pool` and
    #   returns the resulting `concurrent.futures.Future`; call `.result()`
    #   on it to obtain the value or to re-raise a failure.
    #
    # No docstring is written here because the decorator above is applied
    # with the module-level `put` (presumably to copy its docstring onto this
    # method -- confirm against `_apply_docstring`).
    if not as_async:
        return self._put_internal(value, builder, persist, name, **kwargs)

    # Imported locally so the method stays self-contained regardless of the
    # module's top-level imports.
    import logging

    logger = logging.getLogger(__name__)

    def _default_callback(future):
        # Report the outcome through logging instead of printing to stdout,
        # so applications embedding the client control the verbosity. The
        # exception stays stored on the future for the caller to inspect.
        try:
            result = future.result()
        except Exception as e:  # pylint: disable=broad-except
            logger.warning("Failed to put object: %s", e)
        else:
            logger.debug("Successfully put object %s", result)

    # NOTE(review): `_put_internal` runs on a worker thread here; confirm it
    # is safe to run concurrently with other operations on this client.
    future = self.put_thread_pool.submit(
        self._put_internal, value, builder, persist, name, **kwargs
    )
    future.add_done_callback(_default_callback)
    return future
898+
861899
@contextlib.contextmanager
862900
def with_compression(self, enabled: bool = True):
863901
"""Disable compression for the following put operations."""

python/vineyard/core/tests/test_client.py

+39
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@
1919
import itertools
2020
import multiprocessing
2121
import random
22+
import time
2223
import traceback
2324
from concurrent.futures import ThreadPoolExecutor
25+
from threading import Thread
2426

2527
import numpy as np
2628

@@ -317,3 +319,40 @@ def test_memory_trim(vineyard_client):
317319

318320
# there might be some fragmentation overhead
319321
assert parse_shared_memory_usage() <= original_memory_usage + 2 * data_kbytes
322+
323+
324+
def test_async_put_and_get(vineyard_client):
    """Put objects asynchronously in one thread while another thread reads them.

    The producer submits ``object_nums`` named, persistent async puts and waits
    for every returned future, so a failed put fails the test. The consumer
    resolves each name with ``wait=True`` and checks the fetched array against
    the original data. Exceptions raised in either worker thread are collected
    and re-raised in the main thread; otherwise they would be lost, because an
    exception in a ``Thread`` target does not propagate to the caller.
    """
    data = np.ones((100, 100, 16))
    object_nums = 100
    errors = []

    def producer(vineyard_client):
        try:
            start_time = time.time()
            client = vineyard_client.fork()
            futures = [
                client.put(data, name="test" + str(i), as_async=True, persist=True)
                for i in range(object_nums)
            ]
            # Trailing synchronous put kept from the original test; presumably
            # it exercises mixing sync and async puts on the same client.
            client.put(data)
            # Wait for every async put so that any failure surfaces here.
            for future in futures:
                future.result()
            end_time = time.time()
            print("Producer time: ", end_time - start_time)
        except Exception as e:  # pylint: disable=broad-except
            errors.append(e)

    def consumer(vineyard_client):
        try:
            start_time = time.time()
            client = vineyard_client.fork()
            for i in range(object_nums):
                object_id = client.get_name(name="test" + str(i), wait=True)
                np.testing.assert_array_equal(client.get(object_id), data)
            end_time = time.time()
            print("Consumer time: ", end_time - start_time)
        except Exception as e:  # pylint: disable=broad-except
            errors.append(e)

    producer_thread = Thread(target=producer, args=(vineyard_client,))
    consumer_thread = Thread(target=consumer, args=(vineyard_client,))

    start_time = time.time()

    producer_thread.start()
    consumer_thread.start()

    producer_thread.join()
    consumer_thread.join()

    end_time = time.time()
    print("Total time: ", end_time - start_time)

    # Fail the test with the first error raised in a worker thread.
    if errors:
        raise errors[0]

0 commit comments

Comments
 (0)