Commit c6cc492

p2p backend: heart beat (#223)
The Python gRPC server doesn't offer an easy way to check whether a client has dropped. This makes it hard to manage endpoints under the flame channel concept and can lead to deadlock: an endpoint (acting as a gRPC server) waits for data from an endpoint that has already dropped. As a workaround, a heartbeat is sent periodically; if no heartbeat arrives within a certain duration, the client is assumed to have dropped. The gRPC server then cleans up the resources allocated for that endpoint, which prevents the deadlock.
1 parent 9af0480 commit c6cc492
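
The patch encodes the heartbeat as an ordinary msg_pb2.Data message with a sentinel field combination: empty channel_name, seqno of -1, and eom set to True, so the receiving side can tell it apart from real payloads and skip normal handling. Below is a minimal sketch of that convention; the Data stand-in and the make_heart_beat/is_heart_beat helper names are illustrative only and are not part of the commit, which builds the message inline in _unicast_task and checks the fields inline in send_data.

    # Sketch of the heartbeat convention used by this commit, with a small
    # stand-in for the generated msg_pb2.Data message so the example is
    # self-contained; the real patch uses the protobuf class directly.
    from dataclasses import dataclass

    @dataclass
    class Data:  # stand-in for msg_pb2.Data
        end_id: str
        channel_name: str
        payload: bytes
        seqno: int
        eom: bool

    def make_heart_beat(end_id: str) -> Data:
        # A heartbeat carries no channel and no payload; the sentinel
        # combination is channel_name == "", seqno == -1, eom == True.
        return Data(end_id=end_id, channel_name="", payload=b'', seqno=-1, eom=True)

    def is_heart_beat(msg: Data) -> bool:
        # send_data() in the patch uses the negation of this check
        # (msg.seqno != -1 or msg.eom is False or msg.channel_name != "")
        # to decide whether a message still needs normal handling.
        return msg.seqno == -1 and msg.eom is True and msg.channel_name == ""

    assert is_heart_beat(make_heart_beat("end-123"))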

File tree

1 file changed (+104, -36 lines)

  • lib/python/flame/backend/p2p.py


lib/python/flame/backend/p2p.py

Lines changed: 104 additions & 36 deletions
@@ -18,6 +18,7 @@
 import asyncio
 import logging
 import socket
+import time
 from typing import AsyncIterable, Iterable, Tuple

 import grpc
@@ -34,7 +35,9 @@
 logger = logging.getLogger(__name__)

 ENDPOINT_TOKEN_LEN = 2
-HEART_BEAT_DURATION = 30
+HEART_BEAT_DURATION = 30  # for metaserver
+QUEUE_WAIT_TIME = 10  # 10 second
+EXTRA_WAIT_TIME = QUEUE_WAIT_TIME / 2


 class BackendServicer(msg_pb2_grpc.BackendRouteServicer):
@@ -53,11 +56,12 @@ async def send_data(self, req_iter: AsyncIterable[msg_pb2.Data],
                     unused_context) -> msg_pb2.BackendID:
         """Implement a method to handle send_data request stream."""
         # From server perspective, the server receives data from client.
-        logger.warn(f"address of req_iter = {hex(id(req_iter))}")
-        logger.warn(f"address of context = {hex(id(unused_context))}")
-
         async for msg in req_iter:
-            await self.p2pbe._handle_data(msg)
+            self.p2pbe._set_heart_beat(msg.end_id)
+            # if the message is not a heart beat message,
+            # the message needs to be processed.
+            if msg.seqno != -1 or msg.eom is False or msg.channel_name != "":
+                await self.p2pbe._handle_data(msg)

         return msg_pb2.BackendID(end_id=self.p2pbe._id)

@@ -67,6 +71,7 @@ async def recv_data(self, req: msg_pb2.BackendID,
         # From server perspective, the server sends data to client.
         dck_task = asyncio.create_task(self._dummy_context_keeper(context))
         self.p2pbe._set_writer(req.end_id, context)
+        self.p2pbe._set_heart_beat(req.end_id)

         await dck_task

@@ -101,18 +106,15 @@ def __init__(self):
         self._backend = None

         self._endpoints = {}
-        self.end_to_rwop = {}
         self._channels = {}
+        self._livecheck = {}

         with background_thread_loop() as loop:
             self._loop = loop

         async def _init_loop_stuff():
             self._eventq = asyncio.Queue()

-            coro = self._monitor_end_termination()
-            _ = asyncio.create_task(coro)
-
             coro = self._setup_server()
             _ = asyncio.create_task(coro)

@@ -124,20 +126,6 @@ async def _init_loop_stuff():

         self._initialized = True

-    async def _monitor_end_termination(self):
-        # TODO: handle how to monitor grpc channel status
-        # while True:
-        #     for end_id, (reader, _) in list(self._endpoints.items()):
-        #         if not reader.at_eof():
-        #             continue
-
-        #         # connection is closed
-        #         await self._eventq.put((BackendEvent.DISCONNECT, end_id))
-        #         await self._close(end_id)
-
-        #     await asyncio.sleep(1)
-        pass
-
     async def _setup_server(self):
         server = grpc.aio.server()
         msg_pb2_grpc.add_BackendRouteServicer_to_server(
@@ -222,7 +210,7 @@ async def _register_channel(self, channel) -> None:
             raise SystemError('registration failure')

         for endpoint in meta_resp.endpoints:
-            logger.info(f"endpoint: {endpoint}")
+            logger.debug(f"connecting to endpoint: {endpoint}")
             await self._connect_and_notify(endpoint, channel.name())

         while True:
@@ -249,7 +237,7 @@ async def notify(self, channel_name, notify_type, stub, grpc_ch) -> bool:
         try:
             resp = await stub.notify_end(msg)
         except grpc.aio.AioRpcError:
-            logger.warn("can't proceed as grpc channel is unavailable")
+            logger.debug("can't proceed as grpc channel is unavailable")
             return False

         logger.debug(f"resp = {resp}")
@@ -370,33 +358,56 @@ async def _broadcast_task(self, channel):
                     await self.send_chunks(end_id, channel.name(), data)
                 except Exception as ex:
                     ex_name = type(ex).__name__
-                    logger.warn(f"An exception of type {ex_name} occurred")
+                    logger.debug(f"An exception of type {ex_name} occurred")

-                    await self._eventq.put((BackendEvent.DISCONNECT, end_id))
-                    del self._endpoints[end_id]
+                    self._cleanup_end(end_id)
             txq.task_done()

     async def _unicast_task(self, channel, end_id):
         txq = channel.get_txq(end_id)

         while True:
-            data = await txq.get()
+            try:
+                data = await asyncio.wait_for(txq.get(), QUEUE_WAIT_TIME)
+            except asyncio.TimeoutError:
+                if end_id not in self._endpoints:
+                    logger.debug(f"end_id {end_id} not in _endpoints")
+                    break
+
+                _, _, clt_writer, _ = self._endpoints[end_id]
+                if clt_writer is None:
+                    continue
+
+                def heart_beat():
+                    # the condition for heart beat message:
+                    #     channel_name = ""
+                    #     seqno = -1
+                    #     eom = True
+                    msg = msg_pb2.Data(end_id=self._id,
+                                       channel_name="",
+                                       payload=b'',
+                                       seqno=-1,
+                                       eom=True)
+
+                    yield msg
+
+                await clt_writer.send_data(heart_beat())
+                continue

             try:
                 await self.send_chunks(end_id, channel.name(), data)
             except Exception as ex:
                 ex_name = type(ex).__name__
-                logger.warn(f"An exception of type {ex_name} occurred")
+                logger.debug(f"An exception of type {ex_name} occurred")

-                await self._eventq.put((BackendEvent.DISCONNECT, end_id))
-                del self._endpoints[end_id]
+                self._cleanup_end(end_id)
                 txq.task_done()
                 # This break ends a tx_task for end_id
                 break

             txq.task_done()

-        logger.warn(f"unicast task for {end_id} terminated")
+        logger.debug(f"unicast task for {end_id} terminated")

     async def send_chunks(self, other: str, ch_name: str, data: bytes) -> None:
         """Send data chunks to an end."""
@@ -446,7 +457,7 @@ async def _rx_task(self, end_id: str, reader) -> None:
             try:
                 msg = await reader.read()
             except grpc.aio.AioRpcError:
-                logger.info(f"AioRpcError occurred for {end_id}")
+                logger.debug(f"AioRpcError occurred for {end_id}")
                 break

             if msg == grpc.aio.EOF:
@@ -456,7 +467,64 @@ async def _rx_task(self, end_id: str, reader) -> None:

         # grpc channel is unavailable
         # so, clean up an entry for end_id from _endpoints dict
+        self._cleanup_end(end_id)
+
+        logger.debug(f"cleaned up {end_id} info from _endpoints")
+
+    async def _cleanup_end(self, end_id):
         await self._eventq.put((BackendEvent.DISCONNECT, end_id))
-        del self._endpoints[end_id]
+        if end_id in self._endpoints:
+            del self._endpoints[end_id]
+        if end_id in self._livecheck:
+            self._livecheck[end_id].cancel()
+            del self._livecheck[end_id]
+
+    def _set_heart_beat(self, end_id) -> None:
+        logger.debug(f"heart beat data message for {end_id}")
+        if end_id not in self._livecheck:
+            timeout = QUEUE_WAIT_TIME + 5
+            self._livecheck[end_id] = LiveChecker(self, end_id, timeout)
+
+        self._livecheck[end_id].reset()
+
+
+class LiveChecker:
+    """LiveChecker class."""
+
+    def __init__(self, p2pbe, end_id, timeout) -> None:
+        """Initialize an instance."""
+        self._p2pbe = p2pbe
+        self._end_id = end_id
+        self._timeout = timeout
+
+        self._task = None
+        self._last_reset = time.time()
+
+    async def _check(self):
+        await asyncio.sleep(self._timeout)
+        await self._p2pbe._cleanup_end(self._end_id)
+        logger.debug(f"live check timeout occured for {self._end_id}")
+
+    def cancel(self) -> None:
+        """Cancel a task."""
+        if self._task is None or self._task.cancelled():
+            return
+
+        self._task.cancel()
+        logger.debug(f"cancelled task for {self._end_id}")
+
+    def reset(self) -> None:
+        """Reset a task."""
+        now = time.time()
+        if now - self._last_reset < EXTRA_WAIT_TIME / 2:
+            # this is to prevent too frequent reset
+            logger.debug("this is to prevent too frequent reset")
+            return
+
+        self._last_reset = now
+
+        self.cancel()
+
+        self._task = asyncio.ensure_future(self._check())

-        logger.info(f"cleaned up {end_id} info from _endpoints")
+        logger.debug(f"set future for {self._end_id}")
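
The LiveChecker class added at the end of the diff is essentially a resettable watchdog: every message (heartbeat or data) received from an endpoint triggers reset(), which cancels the pending timeout task and schedules a new one; if nothing arrives before the timeout expires, _check() runs the backend's _cleanup_end() for that endpoint. The standalone sketch below shows the same pattern with a plain callback in place of the backend object; the Watchdog name and the numbers are illustrative, not part of the commit.

    # A minimal, self-contained sketch of the watchdog pattern behind
    # LiveChecker; names and timeouts here are illustrative only.
    import asyncio
    import time


    class Watchdog:
        """Fire `on_timeout` unless reset() is called again within `timeout`."""

        def __init__(self, timeout: float, on_timeout) -> None:
            self._timeout = timeout
            self._on_timeout = on_timeout
            self._task = None
            self._last_reset = 0.0  # allow the first reset() to arm immediately

        async def _check(self) -> None:
            await asyncio.sleep(self._timeout)
            await self._on_timeout()

        def cancel(self) -> None:
            if self._task is None or self._task.cancelled():
                return
            self._task.cancel()

        def reset(self) -> None:
            # Rate-limit resets so a burst of messages doesn't churn tasks,
            # mirroring the EXTRA_WAIT_TIME / 2 guard in LiveChecker.reset().
            now = time.time()
            if now - self._last_reset < 1.0:
                return
            self._last_reset = now
            self.cancel()
            self._task = asyncio.ensure_future(self._check())


    async def main() -> None:
        async def on_timeout() -> None:
            print("no heartbeat received in time; cleaning up endpoint state")

        wd = Watchdog(timeout=2.0, on_timeout=on_timeout)
        wd.reset()              # arm the watchdog (normally done per message)
        await asyncio.sleep(3)  # no further resets, so on_timeout fires


    asyncio.run(main())

In the commit itself, the reset happens in _set_heart_beat() on every incoming message, and the receiver-side timeout (QUEUE_WAIT_TIME + 5) is longer than the sender's heartbeat interval (QUEUE_WAIT_TIME), so a healthy peer always resets the checker before it fires.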
