Skip to content

Commit b1942ea

Browse files
committed
add shutdown rpc method
1 parent 48a9711 commit b1942ea

File tree

3 files changed

+38
-7
lines changed

3 files changed

+38
-7
lines changed

rpc/launch_tpch_queries.py

+21-5
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
from workload import JobGraph
1212
from utils import EventTime
1313
from data.tpch_loader import make_release_policy
14+
from rpc import erdos_scheduler_pb2
15+
from rpc import erdos_scheduler_pb2_grpc
16+
17+
import grpc
1418

1519

1620
def map_dataset_to_deadline(dataset_size):
@@ -49,13 +53,14 @@ def launch_query(query_number, args):
4953
# )
5054

5155
try:
52-
cmd = ' '.join(cmd)
56+
cmd = " ".join(cmd)
5357
print("Launching:", cmd)
54-
subprocess.Popen(
58+
p = subprocess.Popen(
5559
cmd,
5660
shell=True,
5761
)
5862
print("Query launched successfully.")
63+
return p
5964
except Exception as e:
6065
print(f"Error launching query: {e}")
6166

@@ -187,7 +192,9 @@ def main():
187192
default=1234,
188193
help="RNG seed for generating inter-arrival periods and picking DAGs (default: 1234)",
189194
)
190-
parser.add_argument("--queries", type=int, nargs='+', help="Launch specific queries")
195+
parser.add_argument(
196+
"--queries", type=int, nargs="+", help="Launch specific queries"
197+
)
191198

192199
args = parser.parse_args()
193200

@@ -197,7 +204,7 @@ def main():
197204
os.environ["TPCH_INPUT_DATA_DIR"] = str(args.tpch_spark_path.resolve() / "dbgen")
198205

199206
if args.queries:
200-
assert(len(args.queries) == args.num_queries)
207+
assert len(queries) == args.num_queries
201208

202209
rng = random.Random(args.rng_seed)
203210

@@ -206,6 +213,7 @@ def main():
206213
print("Release times:", release_times)
207214

208215
# Launch queries
216+
ps = []
209217
inter_arrival_times = [release_times[0].time]
210218
for i in range(len(release_times) - 1):
211219
inter_arrival_times.append(release_times[i + 1].time - release_times[i].time)
@@ -215,14 +223,22 @@ def main():
215223
query_number = args.queries[i]
216224
else:
217225
query_number = rng.randint(1, 22)
218-
launch_query(query_number, args)
226+
ps.append(launch_query(query_number, args))
219227
print(
220228
"Current time: ",
221229
time.strftime("%Y-%m-%d %H:%M:%S"),
222230
" launching query: ",
223231
query_number,
224232
)
225233

234+
for p in ps:
235+
p.wait()
236+
237+
channel = grpc.insecure_channel("localhost:50051")
238+
stub = erdos_scheduler_pb2_grpc.SchedulerServiceStub(channel)
239+
response = stub.Shutdown(erdos_scheduler_pb2.Empty())
240+
channel.close()
241+
226242

227243
if __name__ == "__main__":
228244
main()

rpc/protos/rpc/erdos_scheduler.proto

+4
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ service SchedulerService {
4747

4848
/// Notifies the Scheduler that a Task from a particular TaskGraph has completed.
4949
rpc NotifyTaskCompletion(NotifyTaskCompletionRequest) returns (NotifyTaskCompletionResponse) {}
50+
51+
rpc Shutdown(Empty) returns (Empty) {}
5052
}
5153

5254

@@ -201,3 +203,5 @@ message GetPlacementsResponse {
201203
string message = 3;
202204
bool terminate = 4; // terminate the task graph
203205
}
206+
207+
message Empty {}

rpc/service.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,9 @@ def canonical_task_id(self, stage_id: int):
136136

137137

138138
class Servicer(erdos_scheduler_pb2_grpc.SchedulerServiceServicer):
139-
def __init__(self) -> None:
139+
def __init__(self, server) -> None:
140+
self._server = server
141+
140142
# Override some flags
141143

142144
# Enable orchestrated mode
@@ -230,6 +232,7 @@ def __init__(self) -> None:
230232
self._registered_app_drivers = (
231233
{}
232234
) # Spark driver id differs from taskgraph name (application id)
235+
self._shutdown = False
233236
self._lock = threading.Lock()
234237

235238
super().__init__()
@@ -357,6 +360,10 @@ async def DeregisterDriver(self, request, context):
357360

358361
msg = f"[{stime}] Successfully de-registered driver with id {request.id} for task graph {task_graph_name}"
359362
self._logger.info(msg)
363+
364+
if len(self._registered_app_drivers) == 0 and self._shutdown:
365+
await self._server.stop(0)
366+
360367
return erdos_scheduler_pb2.DeregisterDriverResponse(
361368
success=True,
362369
message=msg,
@@ -755,6 +762,10 @@ async def NotifyTaskCompletion(self, request, context):
755762
message=msg,
756763
)
757764

765+
async def Shutdown(self, request, context):
766+
self._shutdown = True
767+
return erdos_scheduler_pb2.Empty()
768+
758769
async def _tick_simulator(self):
759770
while True:
760771
with self._lock:
@@ -819,7 +830,7 @@ def main(_argv):
819830
loop = asyncio.get_event_loop()
820831

821832
server = grpc.aio.server(futures.ThreadPoolExecutor(max_workers=FLAGS.max_workers))
822-
servicer = Servicer()
833+
servicer = Servicer(server)
823834
erdos_scheduler_pb2_grpc.add_SchedulerServiceServicer_to_server(servicer, server)
824835
server.add_insecure_port(f"[::]:{FLAGS.port}")
825836

0 commit comments

Comments
 (0)