Skip to content

Commit eb075be

Browse files
committed
fix syntax
1 parent 9136ac3 commit eb075be

30 files changed

+192
-266
lines changed

avalanche/benchmarks/classic/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,4 @@
1313
from .clear import *
1414
from .stream51 import *
1515
from .ex_model import *
16-
from .concon import *
16+
from .concon import *

avalanche/benchmarks/classic/concon.py

Lines changed: 37 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
from torchvision import transforms
77

88
from avalanche.benchmarks.utils.data import AvalancheDataset
9-
from avalanche.benchmarks.utils.classification_dataset import _as_taskaware_supervised_classification_dataset
9+
from avalanche.benchmarks.utils.classification_dataset import (
10+
_as_taskaware_supervised_classification_dataset,
11+
)
1012
from avalanche.benchmarks import benchmark_from_datasets, CLScenario
1113

1214
from avalanche.benchmarks.datasets.concon import ConConDataset
@@ -18,20 +20,14 @@
1820
_default_train_transform = transforms.Compose(
1921
[
2022
transforms.ToTensor(),
21-
transforms.Normalize(
22-
mean=[0.5, 0.5, 0.5],
23-
std=[0.5, 0.5, 0.5]
24-
)
23+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
2524
]
2625
)
2726

2827
_default_eval_transform = transforms.Compose(
2928
[
3029
transforms.ToTensor(),
31-
transforms.Normalize(
32-
mean=[0.5, 0.5, 0.5],
33-
std=[0.5, 0.5, 0.5]
34-
)
30+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
3531
]
3632
)
3733

@@ -55,7 +51,8 @@ def build_concon_scenario(
5551

5652
for i in range(1, len(list_train_dataset)):
5753
new_list_train_dataset[0] = new_list_train_dataset[0].concat(
58-
list_train_dataset[i])
54+
list_train_dataset[i]
55+
)
5956

6057
list_train_dataset = new_list_train_dataset
6158

@@ -64,14 +61,12 @@ def build_concon_scenario(
6461

6562
for i in range(1, len(list_test_dataset)):
6663
new_list_test_dataset[0] = new_list_test_dataset[0].concat(
67-
list_test_dataset[i])
64+
list_test_dataset[i]
65+
)
6866

6967
list_test_dataset = new_list_test_dataset
7068

71-
return benchmark_from_datasets(
72-
train=list_train_dataset,
73-
test=list_test_dataset
74-
)
69+
return benchmark_from_datasets(train=list_train_dataset, test=list_test_dataset)
7570

7671

7772
def ConConDisjoint(
@@ -85,7 +80,7 @@ def ConConDisjoint(
8580
) -> CLScenario:
8681
"""
8782
Creates a ConCon Disjoint benchmark.
88-
83+
8984
If the dataset is not present in the computer, this method will
9085
automatically download and store it.
9186
@@ -107,20 +102,20 @@ def ConConDisjoint(
107102
108103
:returns: The ConCon Disjoint benchmark.
109104
"""
110-
assert n_experiences == 3 or n_experiences == 1, "n_experiences must be 1 or 3 for ConCon Disjoint"
105+
assert (
106+
n_experiences == 3 or n_experiences == 1
107+
), "n_experiences must be 1 or 3 for ConCon Disjoint"
111108
list_train_dataset = []
112109
list_test_dataset = []
113110

114111
for i in range(3):
115112
train_dataset = ConConDataset("disjoint", i, root=dataset_root, train=True)
116113
test_dataset = ConConDataset("disjoint", i, root=dataset_root, train=False)
117114
train_dataset = _as_taskaware_supervised_classification_dataset(
118-
train_dataset,
119-
transform=train_transform
115+
train_dataset, transform=train_transform
120116
)
121117
test_dataset = _as_taskaware_supervised_classification_dataset(
122-
test_dataset,
123-
transform=eval_transform
118+
test_dataset, transform=eval_transform
124119
)
125120
list_train_dataset.append(train_dataset)
126121
list_test_dataset.append(test_dataset)
@@ -130,7 +125,7 @@ def ConConDisjoint(
130125
list_test_dataset,
131126
seed=seed,
132127
n_experiences=n_experiences,
133-
shuffle_order=shuffle_order
128+
shuffle_order=shuffle_order,
134129
)
135130

136131

@@ -145,13 +140,13 @@ def ConConStrict(
145140
) -> CLScenario:
146141
"""
147142
Creates a ConCon Strict benchmark.
148-
143+
149144
If the dataset is not present in the computer, this method will
150145
automatically download and store it.
151146
152147
The returned benchmark will be a domain-incremental one, where each task
153148
is a different domain with different confounders. In this setting,
154-
task-specific confounders may appear in other tasks as random features
149+
task-specific confounders may appear in other tasks as random features
155150
in both positive and negative samples.
156151
157152
The benchmark instance returned by this method will have two fields,
@@ -168,20 +163,20 @@ def ConConStrict(
168163
169164
:returns: The ConCon Strict benchmark.
170165
"""
171-
assert n_experiences == 3 or n_experiences == 1, "n_experiences must be 1 or 3 for ConCon Disjoint"
166+
assert (
167+
n_experiences == 3 or n_experiences == 1
168+
), "n_experiences must be 1 or 3 for ConCon Disjoint"
172169
list_train_dataset = []
173170
list_test_dataset = []
174171

175172
for i in range(3):
176173
train_dataset = ConConDataset("strict", i, root=dataset_root, train=True)
177174
test_dataset = ConConDataset("strict", i, root=dataset_root, train=False)
178175
train_dataset = _as_taskaware_supervised_classification_dataset(
179-
train_dataset,
180-
transform=train_transform
176+
train_dataset, transform=train_transform
181177
)
182178
test_dataset = _as_taskaware_supervised_classification_dataset(
183-
test_dataset,
184-
transform=eval_transform
179+
test_dataset, transform=eval_transform
185180
)
186181
list_train_dataset.append(train_dataset)
187182
list_test_dataset.append(test_dataset)
@@ -191,7 +186,7 @@ def ConConStrict(
191186
list_test_dataset,
192187
seed=seed,
193188
n_experiences=n_experiences,
194-
shuffle_order=shuffle_order
189+
shuffle_order=shuffle_order,
195190
)
196191

197192

@@ -203,7 +198,7 @@ def ConConUnconfounded(
203198
) -> CLScenario:
204199
"""
205200
Creates a ConCon Unconfounded benchmark.
206-
201+
207202
If the dataset is not present in the computer, this method will
208203
automatically download and store it.
209204
@@ -214,35 +209,32 @@ def ConConUnconfounded(
214209
`train_stream` and `test_stream`, which can be iterated to obtain
215210
training and test :class:`Experience`. Each Experience contains the
216211
`dataset` and the associated task label.
217-
212+
218213
:param dataset_root: The root directory of the dataset.
219214
:param train_transform: The training transform to use.
220215
:param eval_transform: The evaluation transform to use.
221-
216+
222217
:returns: The ConCon Unconfounded benchmark.
223218
"""
224219
train_dataset = []
225220
test_dataset = []
226221

227-
train_dataset.append(ConConDataset(
228-
"unconfounded", 0, root=dataset_root, train=True))
229-
test_dataset.append(ConConDataset(
230-
"unconfounded", 0, root=dataset_root, train=False))
222+
train_dataset.append(
223+
ConConDataset("unconfounded", 0, root=dataset_root, train=True)
224+
)
225+
test_dataset.append(
226+
ConConDataset("unconfounded", 0, root=dataset_root, train=False)
227+
)
231228

232229
train_dataset[0] = _as_taskaware_supervised_classification_dataset(
233-
train_dataset[0],
234-
transform=train_transform
230+
train_dataset[0], transform=train_transform
235231
)
236232

237233
test_dataset[0] = _as_taskaware_supervised_classification_dataset(
238-
test_dataset[0],
239-
transform=eval_transform
234+
test_dataset[0], transform=eval_transform
240235
)
241236

242-
return benchmark_from_datasets(
243-
train=train_dataset,
244-
test=test_dataset
245-
)
237+
return benchmark_from_datasets(train=train_dataset, test=test_dataset)
246238

247239

248240
__all__ = [

avalanche/benchmarks/datasets/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,4 @@
1212
from .inaturalist import *
1313
from .penn_fudan import *
1414
from .clear import *
15-
from .concon import *
15+
from .concon import *
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .concon import *
1+
from .concon import *

avalanche/benchmarks/datasets/concon/concon.py

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ class ConConDataset(SimpleDownloadableDataset):
2424
and negative samples.
2525
- Unconfounded: No task-specific confounders.
2626
27-
Reference:
28-
Busch, Florian Peter, et al. "Where is the Truth? The Risk of Getting Confounded in a Continual World."
27+
Reference:
28+
Busch, Florian Peter, et al. "Where is the Truth? The Risk of Getting Confounded in a Continual World."
2929
arXiv preprint arXiv:2402.06434 (2024).
3030
3131
Args:
@@ -38,52 +38,58 @@ class ConConDataset(SimpleDownloadableDataset):
3838
transform: A function/transform that takes in an PIL image and returns a transformed version.
3939
E.g, ``transforms.RandomCrop`` for data augmentation.
4040
"""
41-
41+
4242
urls = {
4343
"strict": "https://zenodo.org/records/10630482/files/case_strict_main.zip",
4444
"disjoint": "https://zenodo.org/records/10630482/files/case_disjoint_main.zip",
45-
"unconfounded": "https://zenodo.org/records/10630482/files/unconfounded.zip"
45+
"unconfounded": "https://zenodo.org/records/10630482/files/unconfounded.zip",
4646
}
4747

48-
def __init__(self,
49-
variant: str,
50-
scenario: int,
51-
root: Optional[Union[str, Path]] = None,
52-
train: bool = True,
53-
download: bool = True,
54-
transform = None,
55-
):
56-
assert variant in ["strict", "disjoint", "unconfounded"], "Invalid variant, must be one of 'strict', 'disjoint', 'unconf'"
57-
assert scenario in range(
58-
0, 3), "Invalid scenario, must be between 0 and 2"
59-
assert variant != "unconfounded" or scenario == 0, "Unconfounded scenario only has one variant"
48+
def __init__(
49+
self,
50+
variant: str,
51+
scenario: int,
52+
root: Optional[Union[str, Path]] = None,
53+
train: bool = True,
54+
download: bool = True,
55+
transform=None,
56+
):
57+
assert variant in [
58+
"strict",
59+
"disjoint",
60+
"unconfounded",
61+
], "Invalid variant, must be one of 'strict', 'disjoint', 'unconf'"
62+
assert scenario in range(0, 3), "Invalid scenario, must be between 0 and 2"
63+
assert (
64+
variant != "unconfounded" or scenario == 0
65+
), "Unconfounded scenario only has one variant"
6066

6167
if root is None:
6268
root = default_dataset_location("concon")
63-
69+
6470
self.root = Path(root)
65-
71+
6672
url = self.urls[variant]
67-
73+
6874
super(ConConDataset, self).__init__(
6975
self.root, url, None, download=download, verbose=True
7076
)
71-
77+
7278
if variant == "strict":
7379
self.variant = "case_strict_main"
7480
elif variant == "disjoint":
7581
self.variant = "case_disjoint_main"
7682
else:
7783
self.variant = variant
78-
84+
7985
self.scenario = scenario
8086
self.train = train
8187
self.transform = transform
8288
self._load_dataset()
83-
89+
8490
def _load_metadata(self) -> bool:
8591
root = self.root / self.variant
86-
92+
8793
if self.train:
8894
images_dir = root / "train"
8995
else:
@@ -98,7 +104,7 @@ def _load_metadata(self) -> bool:
98104
for image_path in class_dir.iterdir():
99105
self.image_paths.append(image_path)
100106
self.targets.append(class_id)
101-
107+
102108
return True
103109

104110
def __len__(self):
@@ -107,14 +113,14 @@ def __len__(self):
107113
def __getitem__(self, idx):
108114
image_path = self.image_paths[idx]
109115
image = Image.open(image_path).convert("RGB")
110-
116+
111117
if self.transform is not None:
112118
image = self.transform(image)
113-
119+
114120
target = self.targets[idx]
115121
return image, target
116-
117-
122+
123+
118124
if __name__ == "__main__":
119125
# this little example script can be used to visualize the first image
120126
# loaded from the dataset.

avalanche/benchmarks/scenarios/deprecated/classification_scenario.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -257,12 +257,10 @@ def __len__(self) -> int:
257257
return len(self._benchmark.streams[self._stream])
258258

259259
@overload
260-
def __getitem__(self, exp_id: int) -> Optional[Set[int]]:
261-
...
260+
def __getitem__(self, exp_id: int) -> Optional[Set[int]]: ...
262261

263262
@overload
264-
def __getitem__(self, exp_id: slice) -> Tuple[Optional[Set[int]], ...]:
265-
...
263+
def __getitem__(self, exp_id: slice) -> Tuple[Optional[Set[int]], ...]: ...
266264

267265
def __getitem__(self, exp_id: Union[int, slice]) -> LazyClassesInExpsRet:
268266
indexing_collate = _LazyClassesInClassificationExps._slice_collate

avalanche/benchmarks/scenarios/deprecated/dataset_scenario.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -184,17 +184,17 @@ def __init__(
184184
invoking the super constructor) to specialize the experience class.
185185
"""
186186

187-
self.experience_factory: Callable[
188-
[TCLStream, int], TDatasetExperience
189-
] = experience_factory
187+
self.experience_factory: Callable[[TCLStream, int], TDatasetExperience] = (
188+
experience_factory
189+
)
190190

191-
self.stream_factory: Callable[
192-
[str, TDatasetScenario], TCLStream
193-
] = stream_factory
191+
self.stream_factory: Callable[[str, TDatasetScenario], TCLStream] = (
192+
stream_factory
193+
)
194194

195-
self.stream_definitions: Dict[
196-
str, StreamDef[TCLDataset]
197-
] = DatasetScenario._check_stream_definitions(stream_definitions)
195+
self.stream_definitions: Dict[str, StreamDef[TCLDataset]] = (
196+
DatasetScenario._check_stream_definitions(stream_definitions)
197+
)
198198
"""
199199
A structure containing the definition of the streams.
200200
"""

0 commit comments

Comments
 (0)