Skip to content

Commit a485b59

Browse files
authored
Fix CUB200 Download (#406)
* Replace GDrive link by Zenodo. * Update CHANGELOG. * Fix pypi dependency in CI.
1 parent 632317f commit a485b59

File tree

3 files changed

+19
-7
lines changed

3 files changed

+19
-7
lines changed

.github/workflows/python_unittest.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ jobs:
5656
python3 --version
5757
python3 -m pip install -U pip
5858
pip3 install cython pybind11
59+
pip3 install scipy==1.10.1
5960
pip3 install torch==${{ matrix.pytorch }}
6061
pip3 install torchvision==${{ matrix.torchvision }}
6162
pip3 install chardet==3.0.4 # can be remove when fix in: https://github.com/aio-libs/aiohttp/issues/5366

CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2828
* MAML Toy example. (@[Theo Morales](https://github.com/DubiousCactus))
2929
* Example for `detach_module`. ([Nimish Sanghi](https://github.com/nsanghi))
3030
* Loading duplicate FGVC Aircraft images.
31-
* Move vision datasets to Zenodo. (mini-ImageNet, tiered-ImageNet, FC100, CIFAR-FS)
31+
* Move vision datasets to Zenodo. (mini-ImageNet, tiered-ImageNet, FC100, CIFAR-FS, CUB200)
3232
* mini-ImageNet targets are now ints (not np.float64).
3333
* Swap family for variants in FGVCAircraft, as in MetaDataset.
3434

learn2learn/vision/datasets/cu_birds200.py

+17-6
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,15 @@
55
import torch
66

77
from PIL import Image
8-
from learn2learn.data.utils import download_file_from_google_drive
8+
from learn2learn.data.utils import (
9+
download_file_from_google_drive,
10+
download_file,
11+
)
912

1013
DATA_DIR = 'cubirds200'
1114
DATA_FILENAME = 'CUB_200_2011.tgz'
1215
ARCHIVE_ID = '1hbzc_P1FuxMkcabkgn9ZKinBwW683j45'
16+
ZENODO_URL = 'https://zenodo.org/record/8000562/files/CUB_200_2011.tgz'
1317

1418
SPLITS = {
1519
'train': [
@@ -353,11 +357,18 @@ def download(self):
353357
os.makedirs(data_path, exist_ok=True)
354358
tar_path = os.path.join(data_path, DATA_FILENAME)
355359
print('Downloading CUBirds200 dataset. (1.1Gb)')
356-
download_file_from_google_drive(ARCHIVE_ID, tar_path)
357-
tar_file = tarfile.open(tar_path)
358-
tar_file.extractall(data_path)
359-
tar_file.close()
360-
os.remove(tar_path)
360+
try:
361+
download_file(ZENODO_URL, tar_path)
362+
tar_file = tarfile.open(tar_path)
363+
tar_file.extractall(data_path)
364+
tar_file.close()
365+
os.remove(tar_path)
366+
except Exception:
367+
download_file_from_google_drive(ARCHIVE_ID, tar_path)
368+
tar_file = tarfile.open(tar_path)
369+
tar_file.extractall(data_path)
370+
tar_file.close()
371+
os.remove(tar_path)
361372

362373
def load_data(self, mode='train'):
363374
classes = sum(SPLITS.values(), []) if mode == 'all' else SPLITS[mode]

0 commit comments

Comments
 (0)