1+ import sys
12import random
2- from pathlib import Path
33
44from typing import Any , Dict , List , Optional
55from pathlib import Path
88import numpy as np
99import yaml
1010
11- from more_itertools import before_and_after
12-
1311from utils import EventTime
1412from workload import (
1513 Workload ,
2523from .base_workload_loader import BaseWorkloadLoader
2624
2725
28- """
29- - [ ] Release policy based on workload
30- - [ ] Fix current time setting
31- - [ ] Configure deadline variance
32- - [ ] Configure release policy
33- """
34-
35-
3626class TpchLoader (BaseWorkloadLoader ):
3727 """Loads the TPCH trace from the provided file
3828
@@ -45,36 +35,46 @@ def __init__(self, path: str, flags: "absl.flags") -> None:
4535 self ._flags = flags
4636 self ._rng_seed = flags .random_seed
4737 self ._rng = random .Random (self ._rng_seed )
48- self ._loop_timeout = flags .loop_timeout
49- self ._num_queries = flags .tpch_num_queries
50- self ._dataset_size = flags .tpch_dataset_size
51- if flags .workload_profile_path :
52- self ._workload_profile_path = str (
53- Path (flags .workload_profile_path ) / f"{ self ._dataset_size } g"
54- )
38+ if flags .workload_update_interval > 0 :
39+ self ._workload_update_interval = flags .workload_update_interval
5540 else :
56- self ._workload_profile_path = "./profiles/workload/tpch/decima/2g"
57- self ._workload_update_interval = EventTime (10 , EventTime .Unit .US )
41+ self ._workload_update_interval = EventTime (sys .maxsize , EventTime .Unit .US )
5842 release_policy = self ._get_release_policy ()
5943 self ._release_times = release_policy .get_release_times (
6044 completion_time = EventTime (self ._flags .loop_timeout , EventTime .Unit .US )
6145 )
46+ self ._current_release_pointer = 0
47+
48+ # Set up query name to job graph mapping
6249
6350 with open (path , "r" ) as f :
6451 workload_data = yaml .safe_load (f )
6552
53+ if flags .workload_profile_path :
54+ workload_profile_path = str (
55+ Path (flags .workload_profile_path ) / f"{ flags .s .tpch_dataset_size } g"
56+ )
57+ else :
58+ workload_profile_path = "./profiles/workload/tpch/decima/2g"
59+
6660 job_graphs = {}
6761 for query in workload_data ["graphs" ]:
6862 query_name = query ["name" ]
6963 graph = query ["graph" ]
7064 job_graph = TpchLoader .make_job_graph (
7165 query_name = query_name ,
7266 graph = graph ,
73- profile_path = self ._workload_profile_path ,
67+ profile_path = workload_profile_path ,
68+ deadline_variance = (
69+ int (flags .min_deadline_variance ),
70+ int (flags .max_deadline_variance ),
71+ )
7472 )
7573 job_graphs [query_name ] = job_graph
7674
7775 self ._job_graphs = job_graphs
76+
77+ # Initialize workload
7878 self ._workload = Workload .empty (flags )
7979
8080 def _get_release_policy (self ):
@@ -134,12 +134,11 @@ def _get_release_policy(self):
134134
135135 @staticmethod
136136 def make_job_graph (
137- query_name : str , graph : List [Dict [str , Any ]], profile_path : str
137+ query_name : str , graph : List [Dict [str , Any ]], profile_path : str , deadline_variance = ( 0 , 0 ),
138138 ) -> JobGraph :
139139 job_graph = JobGraph (
140140 name = query_name ,
141- # TODO: make configurable
142- deadline_variance = (10 , 50 ),
141+ deadline_variance = deadline_variance ,
143142 )
144143
145144 query_num = int (query_name [1 :])
@@ -248,9 +247,24 @@ def pre_process_task_duration(task_duration):
248247 return stage_info
249248
250249 def get_next_workload (self , current_time : EventTime ) -> Optional [Workload ]:
251- if len (self ._release_times ) == 0 :
250+ to_release = []
251+ while (
252+ self ._current_release_pointer < len (self ._release_times )
253+ and self ._release_times [self ._current_release_pointer ]
254+ <= current_time + self ._workload_update_interval
255+ ):
256+ to_release .append (
257+ self ._release_times [self ._current_release_pointer ]
258+ )
259+ self ._current_release_pointer += 1
260+
261+ if (
262+ self ._current_release_pointer >= len (self ._release_times )
263+ and len (to_release ) == 0
264+ ):
265+ # Nothing left to release
252266 return None
253- to_release , self . _release_times = before_and_after ( lambda t : t <= current_time + self . _workload_update_interval , self . _release_times )
267+
254268 for t in to_release :
255269 query_num = self ._rng .randint (1 , len (self ._job_graphs ))
256270 query_name = f"Q{ query_num } "
@@ -260,7 +274,7 @@ def get_next_workload(self, current_time: EventTime) -> Optional[Workload]:
260274 _flags = self ._flags ,
261275 )
262276 self ._workload .add_task_graph (task_graph )
263- self . _release_times = list ( self . _release_times )
277+
264278 return self ._workload
265279
266280
0 commit comments