Skip to content

Commit 9f6aa8b

Browse files
committed
release time handling on workload
1 parent c756d17 commit 9f6aa8b

File tree

4 files changed

+128
-32
lines changed

4 files changed

+128
-32
lines changed

configs/tpch_test.conf

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,9 @@
1515
--execution_mode=replay
1616
--replay_trace=tpch
1717

18+
# Release time config.
19+
--override_arrival_period=5
20+
--override_num_invocation=10
21+
1822
# TPCH flags
1923
--tpch_query_dag_spec=profiles/workload/tpch/queries.yaml

data/tpch_loader.py

Lines changed: 111 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
1-
import sys
1+
import random
2+
from pathlib import Path
23

34
from typing import Any, Dict, List, Optional
45
from pathlib import Path
56

6-
import absl #noqa: F401
7+
import absl
78
import numpy as np
89
import yaml
910

11+
from more_itertools import before_and_after
12+
1013
from utils import EventTime
1114
from workload import (
1215
Workload,
@@ -22,20 +25,40 @@
2225
from .base_workload_loader import BaseWorkloadLoader
2326

2427

28+
"""
29+
- [ ] Release policy based on workload
30+
- [ ] Fix current time setting
31+
- [ ] Configure deadline variance
32+
- [ ] Configure release policy
33+
"""
34+
35+
2536
class TpchLoader(BaseWorkloadLoader):
2637
"""Loads the TPCH trace from the provided file
27-
38+
2839
Args:
2940
path (`str`): Path to a YAML file specifying the TPC-H query DAGs
30-
_flags (`absl.flags`): The flags used to initialize the app, if any
41+
flags (`absl.flags`): The flags used to initialize the app, if any
3142
"""
32-
def __init__(self, path: str, _flags: Optional["absl.flags"] = None) -> None:
33-
if _flags:
34-
self._loop_timeout = _flags.loop_timeout
35-
self._workload_profile_path = _flags.workload_profile_path
43+
44+
def __init__(self, path: str, flags: "absl.flags") -> None:
45+
self._flags = flags
46+
self._rng_seed = flags.random_seed
47+
self._rng = random.Random(self._rng_seed)
48+
self._loop_timeout = flags.loop_timeout
49+
self._num_queries = flags.tpch_num_queries
50+
self._dataset_size = flags.tpch_dataset_size
51+
if flags.workload_profile_path:
52+
self._workload_profile_path = str(
53+
Path(flags.workload_profile_path) / f"{self._dataset_size}g"
54+
)
3655
else:
37-
self._loop_timeout = EventTime(time=sys.maxsize, unit=EventTime.Unit.US)
3856
self._workload_profile_path = "./profiles/workload/tpch/decima/2g"
57+
self._workload_update_interval = EventTime(10, EventTime.Unit.US)
58+
release_policy = self._get_release_policy()
59+
self._release_times = release_policy.get_release_times(
60+
completion_time=EventTime(self._flags.loop_timeout, EventTime.Unit.US)
61+
)
3962

4063
with open(path, "r") as f:
4164
workload_data = yaml.safe_load(f)
@@ -51,25 +74,72 @@ def __init__(self, path: str, _flags: Optional["absl.flags"] = None) -> None:
5174
)
5275
job_graphs[query_name] = job_graph
5376

54-
workload = Workload.from_job_graphs(job_graphs)
55-
workload.populate_task_graphs(completion_time=self._loop_timeout)
56-
self._workloads = iter([workload])
77+
self._job_graphs = job_graphs
78+
self._workload = Workload.empty(flags)
5779

80+
def _get_release_policy(self):
    """Construct the task release policy selected by the ``override_*`` flags.

    Reads ``self._flags.override_release_policy`` to pick one of the
    ``JobGraph.ReleasePolicy`` factory methods ("periodic", "fixed",
    "poisson", "gamma" or "fixed_gamma"), gathers the flag values that
    factory needs, and invokes it with a randomized start time drawn from
    ``[randomize_start_time_min, randomize_start_time_max]``.

    Returns:
        The constructed ``JobGraph.ReleasePolicy`` instance.

    Raises:
        NotImplementedError: If the requested policy name is not supported.
        ValueError: If any flag required by the chosen policy is unset.
    """
    policy_name = self._flags.override_release_policy
    if policy_name == "periodic":
        release_policy_args = {
            "period": EventTime(
                self._flags.override_arrival_period, EventTime.Unit.US
            ),
        }
    elif policy_name == "fixed":
        release_policy_args = {
            "period": EventTime(
                self._flags.override_arrival_period, EventTime.Unit.US
            ),
            "num_invocations": self._flags.override_num_invocation,
        }
    elif policy_name == "poisson":
        release_policy_args = {
            "rate": self._flags.override_poisson_arrival_rate,
            "num_invocations": self._flags.override_num_invocation,
        }
    elif policy_name == "gamma":
        release_policy_args = {
            "rate": self._flags.override_poisson_arrival_rate,
            "num_invocations": self._flags.override_num_invocation,
            "coefficient": self._flags.override_gamma_coefficient,
        }
    elif policy_name == "fixed_gamma":
        release_policy_args = {
            "variable_arrival_rate": self._flags.override_poisson_arrival_rate,
            "base_arrival_rate": self._flags.override_base_arrival_rate,
            "num_invocations": self._flags.override_num_invocation,
            "coefficient": self._flags.override_gamma_coefficient,
        }
    else:
        raise NotImplementedError(
            f"Release policy {policy_name} not implemented."
        )

    # Fail loudly on unset flags. The previous `assert` is silently stripped
    # under `python -O`; an explicit check always runs and names exactly
    # which arguments are missing.
    missing = [key for key, val in release_policy_args.items() if val is None]
    if missing:
        raise ValueError(
            f"Release policy '{policy_name}' is missing flag values for: "
            f"{missing}"
        )

    # Randomize the schedule's start, using self._rng (seeded with
    # flags.random_seed) so repeated runs with the same seed are reproducible.
    start_time = EventTime(
        time=self._rng.randint(
            self._flags.randomize_start_time_min,
            self._flags.randomize_start_time_max,
        ),
        unit=EventTime.Unit.US,
    )
    # The policy name doubles as the name of the ReleasePolicy factory
    # method (e.g. JobGraph.ReleasePolicy.periodic), so dispatch via getattr.
    return getattr(JobGraph.ReleasePolicy, policy_name)(
        start=start_time, rng_seed=self._rng_seed, **release_policy_args
    )
70134

135+
@staticmethod
136+
def make_job_graph(
137+
query_name: str, graph: List[Dict[str, Any]], profile_path: str
138+
) -> JobGraph:
139+
job_graph = JobGraph(
140+
name=query_name,
71141
# TODO: make configurable
72-
deadline_variance=(0,0),
142+
deadline_variance=(10, 50),
73143
)
74144

75145
query_num = int(query_name[1:])
@@ -100,9 +170,10 @@ def make_job_graph(query_name: str, graph: List[Dict[str, Any]], profile_path: s
100170

101171
return job_graph
102172

103-
104173
@staticmethod
105-
def load_query_profile(profiler_data: Dict[int, Dict[str, Any]], query_name: str, node_name: str) -> WorkProfile:
174+
def load_query_profile(
175+
profiler_data: Dict[int, Dict[str, Any]], query_name: str, node_name: str
176+
) -> WorkProfile:
106177
profile = profiler_data[int(node_name)]
107178
resources = Resources(
108179
resource_vector={
@@ -122,9 +193,10 @@ def load_query_profile(profiler_data: Dict[int, Dict[str, Any]], query_name: str
122193
execution_strategies=execution_strategies,
123194
)
124195

125-
126196
@staticmethod
127-
def get_profiler_data_for_query(profile_path: str, query_num: int) -> Dict[int, Dict[str, Any]]:
197+
def get_profiler_data_for_query(
198+
profile_path: str, query_num: int
199+
) -> Dict[int, Dict[str, Any]]:
128200
def pre_process_task_duration(task_duration):
129201
# remove fresh durations from first wave
130202
clean_first_wave = {}
@@ -152,7 +224,6 @@ def pre_process_task_duration(task_duration):
152224
for n in range(num_nodes):
153225
task_duration = task_durations[n]
154226
e = next(iter(task_duration["first_wave"]))
155-
# NOTE: somehow only picks the first element {2: [n_tasks_in_ms]}
156227

157228
num_tasks = len(task_duration["first_wave"][e]) + len(
158229
task_duration["rest_wave"][e]
@@ -176,12 +247,21 @@ def pre_process_task_duration(task_duration):
176247

177248
return stage_info
178249

179-
180250
def get_next_workload(self, current_time: EventTime) -> Optional[Workload]:
    """Release all task graphs whose release time falls in the next interval.

    Splits the pending release times at
    ``current_time + self._workload_update_interval``; for each time at or
    before the cutoff, picks a TPC-H query DAG uniformly at random, creates
    its next task graph, and adds it to the shared workload.

    Args:
        current_time: The current simulator time.

    Returns:
        The shared ``Workload`` (mutated in place with any newly released
        task graphs), or ``None`` once every release time has been consumed.
    """
    if not self._release_times:
        return None
    # Hoist the loop-invariant cutoff out of the predicate lambda.
    cutoff = current_time + self._workload_update_interval
    to_release, remaining = before_and_after(
        lambda t: t <= cutoff, self._release_times
    )
    for release_time in to_release:
        # Uniformly sample a query; assumes the YAML spec names the DAGs
        # Q1..Qn — TODO(review): confirm against the query spec file.
        query_name = f"Q{self._rng.randint(1, len(self._job_graphs))}"
        task_graph = self._job_graphs[query_name].get_next_task_graph(
            start_time=release_time,
            _flags=self._flags,
        )
        self._workload.add_task_graph(task_graph)
    # `before_and_after` returns a lazy iterator for the remainder that is
    # only valid after `to_release` has been fully consumed (done above), so
    # materialize it back into a list for the next call's emptiness check.
    self._release_times = list(remaining)
    return self._workload
185265

186266

187267
class SetWithCount(object):
@@ -208,4 +288,3 @@ def remove(self, item):
208288
self.set[item] -= 1
209289
if self.set[item] == 0:
210290
del self.set[item]
211-

main.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,17 @@
137137
"./profiles/workload/tpch/queries.yaml",
138138
"Path to a YAML file specifying the TPC-H query DAGs",
139139
)
140+
flags.DEFINE_integer(
141+
"tpch_num_queries",
142+
50,
143+
"Number of TPC-H queries to run",
144+
)
145+
flags.DEFINE_enum(
146+
"tpch_dataset_size",
147+
"50",
148+
["2", "50", "100", "250", "500"],
149+
"Size of the TPC-H dataset to use",
150+
)
140151

141152
# AlibabaLoader related flags.
142153
flags.DEFINE_integer(
@@ -644,6 +655,7 @@ def main(args):
644655
elif FLAGS.replay_trace == "tpch":
645656
workload_loader = TpchLoader(
646657
path=FLAGS.tpch_query_dag_spec,
658+
flags=FLAGS,
647659
)
648660
else:
649661
raise NotImplementedError(

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ cplex
1414
pre-commit
1515
black
1616
isort
17+
more-itertools

0 commit comments

Comments
 (0)