66from torchvision import transforms
77
88from avalanche .benchmarks .utils .data import AvalancheDataset
9- from avalanche .benchmarks .utils .classification_dataset import _as_taskaware_supervised_classification_dataset
9+ from avalanche .benchmarks .utils .classification_dataset import (
10+ _as_taskaware_supervised_classification_dataset ,
11+ )
1012from avalanche .benchmarks import benchmark_from_datasets , CLScenario
1113
1214from avalanche .benchmarks .datasets .concon import ConConDataset
1820_default_train_transform = transforms .Compose (
1921 [
2022 transforms .ToTensor (),
21- transforms .Normalize (
22- mean = [0.5 , 0.5 , 0.5 ],
23- std = [0.5 , 0.5 , 0.5 ]
24- )
23+ transforms .Normalize (mean = [0.5 , 0.5 , 0.5 ], std = [0.5 , 0.5 , 0.5 ]),
2524 ]
2625)
2726
2827_default_eval_transform = transforms .Compose (
2928 [
3029 transforms .ToTensor (),
31- transforms .Normalize (
32- mean = [0.5 , 0.5 , 0.5 ],
33- std = [0.5 , 0.5 , 0.5 ]
34- )
30+ transforms .Normalize (mean = [0.5 , 0.5 , 0.5 ], std = [0.5 , 0.5 , 0.5 ]),
3531 ]
3632)
3733
@@ -55,7 +51,8 @@ def build_concon_scenario(
5551
5652 for i in range (1 , len (list_train_dataset )):
5753 new_list_train_dataset [0 ] = new_list_train_dataset [0 ].concat (
58- list_train_dataset [i ])
54+ list_train_dataset [i ]
55+ )
5956
6057 list_train_dataset = new_list_train_dataset
6158
@@ -64,14 +61,12 @@ def build_concon_scenario(
6461
6562 for i in range (1 , len (list_test_dataset )):
6663 new_list_test_dataset [0 ] = new_list_test_dataset [0 ].concat (
67- list_test_dataset [i ])
64+ list_test_dataset [i ]
65+ )
6866
6967 list_test_dataset = new_list_test_dataset
7068
71- return benchmark_from_datasets (
72- train = list_train_dataset ,
73- test = list_test_dataset
74- )
69+ return benchmark_from_datasets (train = list_train_dataset , test = list_test_dataset )
7570
7671
7772def ConConDisjoint (
@@ -85,7 +80,7 @@ def ConConDisjoint(
8580) -> CLScenario :
8681 """
8782 Creates a ConCon Disjoint benchmark.
88-
83+
8984 If the dataset is not present in the computer, this method will
9085 automatically download and store it.
9186
@@ -107,20 +102,20 @@ def ConConDisjoint(
107102
108103 :returns: The ConCon Disjoint benchmark.
109104 """
110- assert n_experiences == 3 or n_experiences == 1 , "n_experiences must be 1 or 3 for ConCon Disjoint"
105+ assert (
106+ n_experiences == 3 or n_experiences == 1
107+ ), "n_experiences must be 1 or 3 for ConCon Disjoint"
111108 list_train_dataset = []
112109 list_test_dataset = []
113110
114111 for i in range (3 ):
115112 train_dataset = ConConDataset ("disjoint" , i , root = dataset_root , train = True )
116113 test_dataset = ConConDataset ("disjoint" , i , root = dataset_root , train = False )
117114 train_dataset = _as_taskaware_supervised_classification_dataset (
118- train_dataset ,
119- transform = train_transform
115+ train_dataset , transform = train_transform
120116 )
121117 test_dataset = _as_taskaware_supervised_classification_dataset (
122- test_dataset ,
123- transform = eval_transform
118+ test_dataset , transform = eval_transform
124119 )
125120 list_train_dataset .append (train_dataset )
126121 list_test_dataset .append (test_dataset )
@@ -130,7 +125,7 @@ def ConConDisjoint(
130125 list_test_dataset ,
131126 seed = seed ,
132127 n_experiences = n_experiences ,
133- shuffle_order = shuffle_order
128+ shuffle_order = shuffle_order ,
134129 )
135130
136131
@@ -145,13 +140,13 @@ def ConConStrict(
145140) -> CLScenario :
146141 """
147142 Creates a ConCon Strict benchmark.
148-
143+
149144 If the dataset is not present in the computer, this method will
150145 automatically download and store it.
151146
152147 The returned benchmark will be a domain-incremental one, where each task
153148 is a different domain with different confounders. In this setting,
154- task-specific confounders may appear in other tasks as random features
149+ task-specific confounders may appear in other tasks as random features
155150 in both positive and negative samples.
156151
157152 The benchmark instance returned by this method will have two fields,
@@ -168,20 +163,20 @@ def ConConStrict(
168163
169164 :returns: The ConCon Strict benchmark.
170165 """
171- assert n_experiences == 3 or n_experiences == 1 , "n_experiences must be 1 or 3 for ConCon Disjoint"
166+ assert (
167+ n_experiences == 3 or n_experiences == 1
168+ ), "n_experiences must be 1 or 3 for ConCon Disjoint"
172169 list_train_dataset = []
173170 list_test_dataset = []
174171
175172 for i in range (3 ):
176173 train_dataset = ConConDataset ("strict" , i , root = dataset_root , train = True )
177174 test_dataset = ConConDataset ("strict" , i , root = dataset_root , train = False )
178175 train_dataset = _as_taskaware_supervised_classification_dataset (
179- train_dataset ,
180- transform = train_transform
176+ train_dataset , transform = train_transform
181177 )
182178 test_dataset = _as_taskaware_supervised_classification_dataset (
183- test_dataset ,
184- transform = eval_transform
179+ test_dataset , transform = eval_transform
185180 )
186181 list_train_dataset .append (train_dataset )
187182 list_test_dataset .append (test_dataset )
@@ -191,7 +186,7 @@ def ConConStrict(
191186 list_test_dataset ,
192187 seed = seed ,
193188 n_experiences = n_experiences ,
194- shuffle_order = shuffle_order
189+ shuffle_order = shuffle_order ,
195190 )
196191
197192
@@ -203,7 +198,7 @@ def ConConUnconfounded(
203198) -> CLScenario :
204199 """
205200 Creates a ConCon Unconfounded benchmark.
206-
201+
207202 If the dataset is not present in the computer, this method will
208203 automatically download and store it.
209204
@@ -214,35 +209,32 @@ def ConConUnconfounded(
214209 `train_stream` and `test_stream`, which can be iterated to obtain
215210 training and test :class:`Experience`. Each Experience contains the
216211 `dataset` and the associated task label.
217-
212+
218213 :param dataset_root: The root directory of the dataset.
219214 :param train_transform: The training transform to use.
220215 :param eval_transform: The evaluation transform to use.
221-
216+
222217 :returns: The ConCon Unconfounded benchmark.
223218 """
224219 train_dataset = []
225220 test_dataset = []
226221
227- train_dataset .append (ConConDataset (
228- "unconfounded" , 0 , root = dataset_root , train = True ))
229- test_dataset .append (ConConDataset (
230- "unconfounded" , 0 , root = dataset_root , train = False ))
222+ train_dataset .append (
223+ ConConDataset ("unconfounded" , 0 , root = dataset_root , train = True )
224+ )
225+ test_dataset .append (
226+ ConConDataset ("unconfounded" , 0 , root = dataset_root , train = False )
227+ )
231228
232229 train_dataset [0 ] = _as_taskaware_supervised_classification_dataset (
233- train_dataset [0 ],
234- transform = train_transform
230+ train_dataset [0 ], transform = train_transform
235231 )
236232
237233 test_dataset [0 ] = _as_taskaware_supervised_classification_dataset (
238- test_dataset [0 ],
239- transform = eval_transform
234+ test_dataset [0 ], transform = eval_transform
240235 )
241236
242- return benchmark_from_datasets (
243- train = train_dataset ,
244- test = test_dataset
245- )
237+ return benchmark_from_datasets (train = train_dataset , test = test_dataset )
246238
247239
248240__all__ = [
0 commit comments