Skip to content

Commit 67733ea

Browse files
committed
fix bugs in modeldict_cossim
1 parent 62c8795 commit 67733ea

File tree

5 files changed

+79
-16
lines changed

5 files changed

+79
-16
lines changed

flgo/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .utils.fflow import init, gen_task, gen_task_from_para, gen_benchmark_from_file, gen_decentralized_benchmark, gen_hierarchical_benchmark, convert_model,tune, run_in_parallel, multi_init_and_run
1+
from .utils.fflow import init, gen_task, gen_task_from_para, gen_benchmark_from_file, gen_decentralized_benchmark, gen_hierarchical_benchmark, convert_model,tune, run_in_parallel, module2fmodule, multi_init_and_run
22

33
communicator = None
44

flgo/algorithm/fednova.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
from flgo.utils import fmodule
77

88
class Server(BasicServer):
9+
def initialize(self, *args, **kwargs):
10+
self.sample_option = 'uniform'
11+
912
def iterate(self):
1013
self.selected_clients = self.sample()
1114
# training

flgo/utils/fflow.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -985,3 +985,11 @@ def init_global_module(self, object):
985985
else:
986986
raise NotImplementedError('The current version only support converting model for horizontalFL and DecentralizedFL.')
987987
return AnonymousModel()
988+
989+
def module2fmodule(Model):
990+
class TempFModule(Model, flgo.utils.fmodule.FModule):
991+
def __init__(self, *args, **kwargs):
992+
super(TempFModule, self).__init__(*args, **kwargs)
993+
return TempFModule
994+
995+

flgo/utils/fmodule.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -852,7 +852,7 @@ def _modeldict_cossim(md1, md2):
852852
l1 = torch.tensor(0.).to(md1[list(md1)[0]].device)
853853
l2 = torch.tensor(0.).to(md1[list(md1)[0]].device)
854854
for layer in md1.keys():
855-
if md1[layer] is None or md1[layer].requires_grad==False:
855+
if md1[layer] is None:
856856
continue
857857
res += (md1[layer].view(-1).dot(md2[layer].view(-1)))
858858
l1 += torch.sum(torch.pow(md1[layer], 2))

tutorial/2.1_Try_Your_Own_Algorithms - Local_Training.ipynb

Lines changed: 66 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,31 +3,47 @@
33
{
44
"cell_type": "markdown",
55
"id": "c48d75ca",
6-
"metadata": {},
6+
"metadata": {
7+
"pycharm": {
8+
"name": "#%% md\n"
9+
}
10+
},
711
"source": [
812
"### 介绍 \\ Introduction"
913
]
1014
},
1115
{
1216
"cell_type": "markdown",
1317
"id": "e5ee136c",
14-
"metadata": {},
18+
"metadata": {
19+
"pycharm": {
20+
"name": "#%% md\n"
21+
}
22+
},
1523
"source": [
1624
"这一节主要介绍如何使用FLGo来实现自己的idea。这里首先关注的是本地训练阶段有所变化的算法(如FedProx,MOON等),这类算法在联邦学习中是极为常见的,因为联邦学习中用户数据分布通常是non-IID的,导致本地训练容易越走越偏。下面从FedProx入手,讲解如何使用FLGo复现该算法。"
1725
]
1826
},
1927
{
2028
"cell_type": "markdown",
2129
"id": "088c1728",
22-
"metadata": {},
30+
"metadata": {
31+
"pycharm": {
32+
"name": "#%% md\n"
33+
}
34+
},
2335
"source": [
2436
"### FedProx简介 "
2537
]
2638
},
2739
{
2840
"cell_type": "markdown",
2941
"id": "4e264286",
30-
"metadata": {},
42+
"metadata": {
43+
"pycharm": {
44+
"name": "#%% md\n"
45+
}
46+
},
3147
"source": [
3248
"FedProx是Li Tian等人于2018年([论文链接](https://arxiv.org/abs/1812.06127))所提出的一种针对系统异构性鲁棒的联邦优化算法,发表于MLSys 2020上。它相较于FedAvg主要做出了两点改进:\n",
3349
"\n",
@@ -43,7 +59,11 @@
4359
{
4460
"cell_type": "markdown",
4561
"id": "b2df44b4",
46-
"metadata": {},
62+
"metadata": {
63+
"pycharm": {
64+
"name": "#%% md\n"
65+
}
66+
},
4767
"source": [
4868
"### 加入算法超参数\n",
4969
"所有的横向联邦中的超参数都在Server的initialize方法中加入,加入的方法是调用Server.init_algo_para方法,并传入超参数字典。例如,对于fedprox来说,超参数是mu,因此只要通过该方法传入{'mu':0.01},Server和Client就会被添加额外的属性Server.mu和Client.mu来访问该超参数,其中0.01是默认值。"
@@ -53,7 +73,11 @@
5373
"cell_type": "code",
5474
"execution_count": 1,
5575
"id": "08a33678",
56-
"metadata": {},
76+
"metadata": {
77+
"pycharm": {
78+
"name": "#%%\n"
79+
}
80+
},
5781
"outputs": [],
5882
"source": [
5983
"import flgo.algorithm.fedbase as fedbase\n",
@@ -72,7 +96,11 @@
7296
{
7397
"cell_type": "markdown",
7498
"id": "a04dd293",
75-
"metadata": {},
99+
"metadata": {
100+
"pycharm": {
101+
"name": "#%% md\n"
102+
}
103+
},
76104
"source": [
77105
"### 修改本地训练阶段"
78106
]
@@ -81,7 +109,11 @@
81109
"cell_type": "code",
82110
"execution_count": 2,
83111
"id": "6bf59686",
84-
"metadata": {},
112+
"metadata": {
113+
"pycharm": {
114+
"name": "#%%\n"
115+
}
116+
},
85117
"outputs": [],
86118
"source": [
87119
"class Client(fedbase.BasicClient):\n",
@@ -112,7 +144,11 @@
112144
{
113145
"cell_type": "markdown",
114146
"id": "e10e0698",
115-
"metadata": {},
147+
"metadata": {
148+
"pycharm": {
149+
"name": "#%% md\n"
150+
}
151+
},
116152
"source": [
117153
"### 构造算法fedprox\n",
118154
"在FLGo中,每个算法既可以用一个类表示,也可以用一个文件表示,唯一的要求是算法必须具备 algorithm_name.Server和algorithm_name.Client这两个属性(对于横向联邦是这俩)。"
@@ -122,7 +158,11 @@
122158
"cell_type": "code",
123159
"execution_count": 3,
124160
"id": "8cd43e21",
125-
"metadata": {},
161+
"metadata": {
162+
"pycharm": {
163+
"name": "#%%\n"
164+
}
165+
},
126166
"outputs": [],
127167
"source": [
128168
"class fedprox:\n",
@@ -133,7 +173,11 @@
133173
{
134174
"cell_type": "markdown",
135175
"id": "a3ea442a",
136-
"metadata": {},
176+
"metadata": {
177+
"pycharm": {
178+
"name": "#%% md\n"
179+
}
180+
},
137181
"source": [
138182
"### 测试fedprox"
139183
]
@@ -142,7 +186,11 @@
142186
"cell_type": "code",
143187
"execution_count": 4,
144188
"id": "0936a8f9",
145-
"metadata": {},
189+
"metadata": {
190+
"pycharm": {
191+
"name": "#%%\n"
192+
}
193+
},
146194
"outputs": [
147195
{
148196
"name": "stderr",
@@ -1004,7 +1052,11 @@
10041052
"cell_type": "code",
10051053
"execution_count": null,
10061054
"id": "00b4061b",
1007-
"metadata": {},
1055+
"metadata": {
1056+
"pycharm": {
1057+
"name": "#%%\n"
1058+
}
1059+
},
10081060
"outputs": [],
10091061
"source": []
10101062
}
@@ -1030,4 +1082,4 @@
10301082
},
10311083
"nbformat": 4,
10321084
"nbformat_minor": 5
1033-
}
1085+
}

0 commit comments

Comments
 (0)