Skip to content

Commit d6ed963

Browse files
committed
fix existing bugs
1 parent 5feb2ca commit d6ed963

File tree

5 files changed

+34
-27
lines changed

5 files changed

+34
-27
lines changed

flgo/algorithm/afl.py

Lines changed: 2 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -47,7 +47,7 @@ def project(self, p):
4747
res.append(max(p[i] + lmbd, 0))
4848
return res
4949

50-
def global_test(self, dataflag='valid'):
50+
def global_test(self, flag='valid'):
5151
"""
5252
Validate accuracies and losses on clients' local datasets
5353
:param
@@ -57,7 +57,7 @@ def global_test(self, dataflag='valid'):
5757
"""
5858
all_metrics = collections.defaultdict(list)
5959
for c in self.clients:
60-
client_metrics = c.test(self.result_model, dataflag)
60+
client_metrics = c.test(self.result_model, flag)
6161
for met_name, met_val in client_metrics.items():
6262
all_metrics[met_name].append(met_val)
6363
return all_metrics

flgo/algorithm/fedbase.py

Lines changed: 6 additions & 5 deletions
Original file line number · Diff line number · Diff line change
@@ -318,7 +318,7 @@ def aggregate(self, models: list, *args, **kwargs):
318318
p = [pk/sump for pk in p]
319319
return fmodule._model_sum([model_k * pk for model_k, pk in zip(models, p)])
320320

321-
def global_test(self, dataflag='valid'):
321+
def global_test(self, flag='valid'):
322322
"""
323323
Validate accuracies and losses on clients' local datasets
324324
:param
@@ -328,7 +328,7 @@ def global_test(self, dataflag='valid'):
328328
"""
329329
all_metrics = collections.defaultdict(list)
330330
for c in self.clients:
331-
client_metrics = c.test(self.model, dataflag)
331+
client_metrics = c.test(self.model, flag)
332332
for met_name, met_val in client_metrics.items():
333333
all_metrics[met_name].append(met_val)
334334
return all_metrics
@@ -345,7 +345,7 @@ def test(self, model=None, flag='test'):
345345
data = self.test_data if flag=='test' else self.valid_data
346346
if data is None: return {}
347347
else:
348-
return self.calculator.test(model, self.test_data, batch_size = self.option['test_batch_size'])
348+
return self.calculator.test(model, data, batch_size = self.option['test_batch_size'], num_workers = self.option['num_workers'], pin_memory = self.option['pin_memory'])
349349

350350
def init_algo_para(self, algo_para: dict):
351351
"""
@@ -442,6 +442,7 @@ def __init__(self, option={}):
442442
# server
443443
self.server = None
444444
# actions of different message type
445+
self.option = option
445446
self.actions = {0: self.reply}
446447

447448
def initialize(self):
@@ -481,7 +482,7 @@ def test(self, model, dataflag='valid'):
481482
"""
482483
dataset = self.train_data if dataflag=='train' else self.valid_data
483484
if dataset is not None:
484-
return self.calculator.test(model, dataset, self.test_batch_size)
485+
return self.calculator.test(model, dataset, self.test_batch_size, self.option['num_workers'])
485486
else:
486487
return {}
487488

@@ -610,7 +611,7 @@ def get_batch_data(self):
610611
try:
611612
batch_data = next(self.data_loader)
612613
except Exception as e:
613-
self.data_loader = iter(self.calculator.get_dataloader(self.train_data, batch_size=self.batch_size, num_workers=self.loader_num_workers))
614+
self.data_loader = iter(self.calculator.get_dataloader(self.train_data, batch_size=self.batch_size, num_workers=self.loader_num_workers, pin_memory=self.option['pin_memory']))
614615
batch_data = next(self.data_loader)
615616
# clear local DataLoader when finishing local training
616617
self.current_steps = (self.current_steps+1) % self.num_steps

flgo/benchmark/toolkits/cv/horizontal/image_classification.py

Lines changed: 16 additions & 17 deletions
Original file line number · Diff line number · Diff line change
@@ -82,37 +82,36 @@ def compute_loss(self, model, data):
8282
return {'loss': loss}
8383

8484
@torch.no_grad()
85-
def test(self, model, dataset, batch_size=64, num_workers=0):
85+
def test(self, model, dataset, batch_size=64, num_workers=0, pin_memory=False):
8686
"""
8787
Metric = [mean_accuracy, mean_loss]
8888
:param model:
8989
:param dataset:
9090
:param batch_size:
9191
:return: [mean_accuracy, mean_loss]
9292
"""
93-
with torch.no_grad():
94-
model.eval()
95-
if batch_size==-1:batch_size=len(dataset)
96-
data_loader = self.get_dataloader(dataset, batch_size=batch_size, num_workers=num_workers)
97-
total_loss = 0.0
98-
num_correct = 0
99-
for batch_id, batch_data in enumerate(data_loader):
100-
batch_data = self.to_device(batch_data)
101-
outputs = model(batch_data[0])
102-
batch_mean_loss = self.criterion(outputs, batch_data[-1]).item()
103-
y_pred = outputs.data.max(1, keepdim=True)[1]
104-
correct = y_pred.eq(batch_data[-1].data.view_as(y_pred)).long().cpu().sum()
105-
num_correct += correct.item()
106-
total_loss += batch_mean_loss * len(batch_data[-1])
93+
model.eval()
94+
if batch_size==-1:batch_size=len(dataset)
95+
data_loader = self.get_dataloader(dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
96+
total_loss = 0.0
97+
num_correct = 0
98+
for batch_id, batch_data in enumerate(data_loader):
99+
batch_data = self.to_device(batch_data)
100+
outputs = model(batch_data[0])
101+
batch_mean_loss = self.criterion(outputs, batch_data[-1]).item()
102+
y_pred = outputs.data.max(1, keepdim=True)[1]
103+
correct = y_pred.eq(batch_data[-1].data.view_as(y_pred)).long().cpu().sum()
104+
num_correct += correct.item()
105+
total_loss += batch_mean_loss * len(batch_data[-1])
107106
return {'accuracy': 1.0*num_correct/len(dataset), 'loss':total_loss/len(dataset)}
108107

109108
def to_device(self, data):
110109
return data[0].to(self.device), data[1].to(self.device)
111110

112-
def get_dataloader(self, dataset, batch_size=64, shuffle=True, num_workers=0):
111+
def get_dataloader(self, dataset, batch_size=64, shuffle=True, num_workers=0, pin_memory=False, drop_last=False):
113112
if self.DataLoader == None:
114113
raise NotImplementedError("DataLoader Not Found.")
115-
return self.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
114+
return self.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory, drop_last=drop_last)
116115

117116
class GeneralGenerator(BasicTaskGenerator):
118117
def __init__(self, benchmark, rawdata_path):

flgo/experiment/analyzer.py

Lines changed: 8 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -101,8 +101,14 @@ def set_communication_round(self):
101101

102102
def set_client_id(self):
103103
with open(os.path.join(self.task, 'info')) as inf:
104-
num_clients = json.load(inf)['num_clients']
105-
self.data['client_id'] = [cid for cid in range(int(num_clients))]
104+
task_info = json.load(inf)
105+
if 'num_clients' in task_info.keys():
106+
N = int(task_info['num_clients'])
107+
elif 'num_parties' in task_info.keys():
108+
N = int(task_info['num_parties'])
109+
else:
110+
N = 0
111+
self.data['client_id'] = [cid for cid in range(N)]
106112

107113
def set_legend(self, legend_with = []):
108114
if len(legend_with)==0: self.data['label'] = []

flgo/utils/fflow.py

Lines changed: 2 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -34,7 +34,7 @@
3434
sample_list=['uniform', 'md', 'full', 'uniform_available', 'md_available', 'full_available']
3535
agg_list=['uniform', 'weighted_scale', 'weighted_com']
3636
optimizer_list=['SGD', 'Adam', 'RMSprop', 'Adagrad']
37-
default_option_dict = {'pretrain': '', 'sample': 'md', 'aggregate': 'uniform', 'num_rounds': 20, 'proportion': 0.2, 'learning_rate_decay': 0.998, 'lr_scheduler': -1, 'early_stop': -1, 'num_epochs': 5, 'num_steps': -1, 'learning_rate': 0.1, 'batch_size': 64.0, 'optimizer': 'SGD', 'momentum': 0, 'weight_decay': 0, 'algo_para': [], 'train_holdout': 0.1, 'test_holdout': 0.0, 'seed': 0, 'gpu': [], 'server_with_cpu': False, 'num_parallels': 1, 'num_workers': 0, 'test_batch_size': 512, 'simulator': 'default_simulator', 'availability': 'IDL', 'connectivity': 'IDL', 'completeness': 'IDL', 'responsiveness': 'IDL', 'logger': 'basic_logger', 'log_level': 'INFO', 'log_file': False, 'no_log_console': False, 'no_overwrite': False, 'eval_interval': 1}
37+
default_option_dict = {'pretrain': '', 'sample': 'md', 'aggregate': 'uniform', 'num_rounds': 20, 'proportion': 0.2, 'learning_rate_decay': 0.998, 'lr_scheduler': -1, 'early_stop': -1, 'num_epochs': 5, 'num_steps': -1, 'learning_rate': 0.1, 'batch_size': 64.0, 'optimizer': 'SGD', 'momentum': 0, 'weight_decay': 0, 'algo_para': [], 'train_holdout': 0.1, 'test_holdout': 0.0, 'seed': 0, 'gpu': [], 'server_with_cpu': False, 'num_parallels': 1, 'num_workers': 0, 'pin_memory':False,'test_batch_size': 512, 'simulator': 'default_simulator', 'availability': 'IDL', 'connectivity': 'IDL', 'completeness': 'IDL', 'responsiveness': 'IDL', 'logger': 'basic_logger', 'log_level': 'INFO', 'log_file': False, 'no_log_console': False, 'no_overwrite': False, 'eval_interval': 1}
3838

3939
class GlobalVariable:
4040
"""this class is to create a buffer space for sharing variables across different parties for each runner respectively in a single machine"""
@@ -99,6 +99,7 @@ def read_option_from_command():
9999
parser.add_argument('--server_with_cpu', help='seed for random initialization;', action="store_true", default=False)
100100
parser.add_argument('--num_parallels', help="the number of parallels in the clients computing session", type=int, default=1)
101101
parser.add_argument('--num_workers', help='the number of workers of DataLoader', type=int, default=0)
102+
parser.add_argument('--pin_memory', help='pin_memory of DataLoader', action="store_true", default=False)
102103
parser.add_argument('--test_batch_size', help='the batch_size used in testing phase;', type=int, default=512)
103104

104105
"""Simulator Options"""

0 commit comments

Comments (0)