1- import sys
1+ import random
2+ from pathlib import Path
23
34from typing import Any , Dict , List , Optional
45from pathlib import Path
56
6- import absl #noqa: F401
7+ import absl
78import numpy as np
89import yaml
910
11+ from more_itertools import before_and_after
12+
1013from utils import EventTime
1114from workload import (
1215 Workload ,
2225from .base_workload_loader import BaseWorkloadLoader
2326
2427
28+ """
29+ - [ ] Release policy based on workload
30+ - [ ] Fix current time setting
31+ - [ ] Configure deadline variance
32+ - [ ] Configure release policy
33+ """
34+
35+
2536class TpchLoader (BaseWorkloadLoader ):
2637 """Loads the TPCH trace from the provided file
27-
38+
2839 Args:
2940 path (`str`): Path to a YAML file specifying the TPC-H query DAGs
30- _flags (`absl.flags`): The flags used to initialize the app, if any
41+ flags (`absl.flags`): The flags used to initialize the app, if any
3142 """
32- def __init__ (self , path : str , _flags : Optional ["absl.flags" ] = None ) -> None :
33- if _flags :
34- self ._loop_timeout = _flags .loop_timeout
35- self ._workload_profile_path = _flags .workload_profile_path
43+
44+ def __init__ (self , path : str , flags : "absl.flags" ) -> None :
45+ self ._flags = flags
46+ self ._rng_seed = flags .random_seed
47+ self ._rng = random .Random (self ._rng_seed )
48+ self ._loop_timeout = flags .loop_timeout
49+ self ._num_queries = flags .tpch_num_queries
50+ self ._dataset_size = flags .tpch_dataset_size
51+ if flags .workload_profile_path :
52+ self ._workload_profile_path = str (
53+ Path (flags .workload_profile_path ) / f"{ self ._dataset_size } g"
54+ )
3655 else :
37- self ._loop_timeout = EventTime (time = sys .maxsize , unit = EventTime .Unit .US )
3856 self ._workload_profile_path = "./profiles/workload/tpch/decima/2g"
57+ self ._workload_update_interval = EventTime (10 , EventTime .Unit .US )
58+ release_policy = self ._get_release_policy ()
59+ self ._release_times = release_policy .get_release_times (
60+ completion_time = EventTime (self ._flags .loop_timeout , EventTime .Unit .US )
61+ )
3962
4063 with open (path , "r" ) as f :
4164 workload_data = yaml .safe_load (f )
@@ -51,25 +74,72 @@ def __init__(self, path: str, _flags: Optional["absl.flags"] = None) -> None:
5174 )
5275 job_graphs [query_name ] = job_graph
5376
54- workload = Workload .from_job_graphs (job_graphs )
55- workload .populate_task_graphs (completion_time = self ._loop_timeout )
56- self ._workloads = iter ([workload ])
77+ self ._job_graphs = job_graphs
78+ self ._workload = Workload .empty (flags )
5779
80+ def _get_release_policy (self ):
81+ release_policy_args = {}
82+ if self ._flags .override_release_policy == "periodic" :
83+ release_policy_args = {
84+ "period" : EventTime (
85+ self ._flags .override_arrival_period , EventTime .Unit .US
86+ ),
87+ }
88+ elif self ._flags .override_release_policy == "fixed" :
89+ release_policy_args = {
90+ "period" : EventTime (
91+ self ._flags .override_arrival_period , EventTime .Unit .US
92+ ),
93+ "num_invocations" : self ._flags .override_num_invocation ,
94+ }
95+ elif self ._flags .override_release_policy == "poisson" :
96+ release_policy_args = {
97+ "rate" : self ._flags .override_poisson_arrival_rate ,
98+ "num_invocations" : self ._flags .override_num_invocation ,
99+ }
100+ elif self ._flags .override_release_policy == "gamma" :
101+ release_policy_args = {
102+ "rate" : self ._flags .override_poisson_arrival_rate ,
103+ "num_invocations" : self ._flags .override_num_invocation ,
104+ "coefficient" : self ._flags .override_gamma_coefficient ,
105+ }
106+ elif self ._flags .override_release_policy == "fixed_gamma" :
107+ release_policy_args = {
108+ "variable_arrival_rate" : self ._flags .override_poisson_arrival_rate ,
109+ "base_arrival_rate" : self ._flags .override_base_arrival_rate ,
110+ "num_invocations" : self ._flags .override_num_invocation ,
111+ "coefficient" : self ._flags .override_gamma_coefficient ,
112+ }
113+ else :
114+ raise NotImplementedError (
115+ f"Release policy { self ._flags .override_release_policy } not implemented."
116+ )
58117
59- @staticmethod
60- def make_job_graph (query_name : str , graph : List [Dict [str , Any ]], profile_path : str ) -> JobGraph :
61- job_graph = JobGraph (
62- name = query_name ,
118+ # Check that none of the arg values are None
119+ assert all ([val is not None for val in release_policy_args .values ()])
63120
64- # TODO: make configurable
65- release_policy = JobGraph . ReleasePolicy . fixed (
66- period = EventTime ( 30 , EventTime . Unit . US ),
67- num_invocations = 10 ,
68- start = EventTime ( 0 , EventTime . Unit . US ) ,
121+ # Construct the release policy
122+ start_time = EventTime (
123+ time = self . _rng . randint (
124+ self . _flags . randomize_start_time_min ,
125+ self . _flags . randomize_start_time_max ,
69126 ),
127+ unit = EventTime .Unit .US ,
128+ )
129+ release_policy = getattr (
130+ JobGraph .ReleasePolicy , self ._flags .override_release_policy
131+ )(start = start_time , rng_seed = self ._rng_seed , ** release_policy_args )
132+
133+ return release_policy
70134
135+ @staticmethod
136+ def make_job_graph (
137+ query_name : str , graph : List [Dict [str , Any ]], profile_path : str
138+ ) -> JobGraph :
139+ job_graph = JobGraph (
140+ name = query_name ,
71141 # TODO: make configurable
72- deadline_variance = (0 , 0 ),
142+ deadline_variance = (10 , 50 ),
73143 )
74144
75145 query_num = int (query_name [1 :])
@@ -100,9 +170,10 @@ def make_job_graph(query_name: str, graph: List[Dict[str, Any]], profile_path: s
100170
101171 return job_graph
102172
103-
104173 @staticmethod
105- def load_query_profile (profiler_data : Dict [int , Dict [str , Any ]], query_name : str , node_name : str ) -> WorkProfile :
174+ def load_query_profile (
175+ profiler_data : Dict [int , Dict [str , Any ]], query_name : str , node_name : str
176+ ) -> WorkProfile :
106177 profile = profiler_data [int (node_name )]
107178 resources = Resources (
108179 resource_vector = {
@@ -122,9 +193,10 @@ def load_query_profile(profiler_data: Dict[int, Dict[str, Any]], query_name: str
122193 execution_strategies = execution_strategies ,
123194 )
124195
125-
126196 @staticmethod
127- def get_profiler_data_for_query (profile_path : str , query_num : int ) -> Dict [int , Dict [str , Any ]]:
197+ def get_profiler_data_for_query (
198+ profile_path : str , query_num : int
199+ ) -> Dict [int , Dict [str , Any ]]:
128200 def pre_process_task_duration (task_duration ):
129201 # remove fresh durations from first wave
130202 clean_first_wave = {}
@@ -152,7 +224,6 @@ def pre_process_task_duration(task_duration):
152224 for n in range (num_nodes ):
153225 task_duration = task_durations [n ]
154226 e = next (iter (task_duration ["first_wave" ]))
155- # NOTE: somehow only picks the first element {2: [n_tasks_in_ms]}
156227
157228 num_tasks = len (task_duration ["first_wave" ][e ]) + len (
158229 task_duration ["rest_wave" ][e ]
@@ -176,12 +247,21 @@ def pre_process_task_duration(task_duration):
176247
177248 return stage_info
178249
179-
180250 def get_next_workload (self , current_time : EventTime ) -> Optional [Workload ]:
181- try :
182- return next (self ._workloads )
183- except StopIteration :
251+ if len (self ._release_times ) == 0 :
184252 return None
253+ to_release , self ._release_times = before_and_after (lambda t : t <= current_time + self ._workload_update_interval , self ._release_times )
254+ for t in to_release :
255+ query_num = self ._rng .randint (1 , len (self ._job_graphs ))
256+ query_name = f"Q{ query_num } "
257+ job_graph = self ._job_graphs [query_name ]
258+ task_graph = job_graph .get_next_task_graph (
259+ start_time = t ,
260+ _flags = self ._flags ,
261+ )
262+ self ._workload .add_task_graph (task_graph )
263+ self ._release_times = list (self ._release_times )
264+ return self ._workload
185265
186266
187267class SetWithCount (object ):
@@ -208,4 +288,3 @@ def remove(self, item):
208288 self .set [item ] -= 1
209289 if self .set [item ] == 0 :
210290 del self .set [item ]
211-
0 commit comments