@@ -121,27 +121,55 @@ def __init__(
121121 self .epoch = epoch
122122 self .iteration = iteration
123123 self .strategy = strategy
124+ self .batch_sampler = None
124125
125126 def __iter__ (self ):
126127 return self ._get_iter ()
127128
129+ def _make_dataloader (self ) -> DataLoader :
130+ if self .batch_sampler is None :
131+ return self .dataloader
132+ else :
133+ return DataLoader (
134+ dataset = self .dataloader .dataset ,
135+ batch_sampler = self .batch_sampler ,
136+ num_workers = self .dataloader .num_workers ,
137+ collate_fn = self .dataloader .collate_fn ,
138+ pin_memory = self .dataloader .pin_memory ,
139+ timeout = self .dataloader .timeout ,
140+ worker_init_fn = self .dataloader .worker_init_fn ,
141+ multiprocessing_context = self .dataloader .multiprocessing_context ,
142+ generator = self .dataloader .generator ,
143+ prefetch_factor = self .dataloader .prefetch_factor ,
144+ persistent_workers = self .dataloader .persistent_workers ,
145+ pin_memory_device = self .dataloader .pin_memory_device ,
146+ )
147+
128148 def _get_iter (self ):
129149 finished = False
130150
131151 while not finished :
132- for batch in self .dataloader :
152+ count = 0
153+ dataloader = self ._make_dataloader ()
154+ for batch in dataloader :
133155 if finished :
134156 break
135157
136158 _logger .info (f"Training epoch { self .epoch } iteration { self .iteration } " )
137159
138160 yield batch
139-
161+
162+ count += 1
140163 if self .strategy .should_merge (self .epoch , self .iteration , False ):
141164 _logger .info (f"iteration { self .iteration } , start to merge" )
165+ assert dataloader .batch_sampler is not None
166+ if self .batch_sampler is None :
167+ self .batch_sampler = list (dataloader .batch_sampler )
168+ self .batch_sampler = self .batch_sampler [count :]
142169 finished = True
143170 self .iteration += 1
144-
171+
172+ self .batch_sampler = None
145173 if self .strategy .should_merge (self .epoch , self .iteration , True ):
146174 _logger .info (f"epoch { self .epoch } , start to merge" )
147175 finished = True
@@ -372,9 +400,14 @@ def map(
372400 epoch : int ,
373401 iteration : int ,
374402 ) -> Tuple [Dict [str , np .ndarray ], int , int ]:
375- self .learning .strategy .weight_to_params (
376- weight , self .learning .state_dict ()
377- )
403+ if len (weight ) > 0 :
404+ self .learning .strategy .weight_to_params (
405+ weight , self .learning .state_dict ()
406+ )
407+ else :
408+ weight = self .learning .strategy .params_to_weight (
409+ self .learning .state_dict ()
410+ )
378411 _logger .info (f"Round { self .round } training" )
379412 train_iter = TrainIterator (
380413 dataloader , epoch , iteration , self .learning .strategy
@@ -517,7 +550,10 @@ def reduce(
517550 self .learning .strategy .weight_to_params (
518551 weight , self .learning .state_dict ()
519552 )
520- return self .learning .state_dict ()
553+ res : Dict [str , Any ] = {"weight" : self .learning .state_dict ()}
554+ if metrics is not None :
555+ res ["metrics" ] = metrics
556+ return res
521557
522558 input_nodes : List [DataNode ] = [weight_node ]
523559 if metrics_node is not None :
@@ -541,9 +577,8 @@ def _build_graph(self) -> Tuple[List[delta.dataset.Dataset], List[GraphNode]]:
541577 iteration_node = InputGraphNode (
542578 name = "iteration" , location = DataLocation .CLIENT , default = 1
543579 )
544- weight_arr = self .strategy .params_to_weight (self .state_dict ())
545580 weight_node = InputGraphNode (
546- name = "weight_0" , location = DataLocation .SERVER , default = weight_arr
581+ name = "weight_0" , location = DataLocation .SERVER , default = np . empty ( 0 )
547582 )
548583 metrics_node = None
549584 inputs = [dataset_node , epoch_node , iteration_node , weight_node ]
0 commit comments