forked from thuml/Transfer-Learning-Library
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathretinopathy.py
29 lines (22 loc) · 1.28 KB
/
retinopathy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
"""
@author: Junguang Jiang
@contact: [email protected]
"""
import os
from common.vision.datasets import ImageList
class Retinopathy(ImageList):
    """`Retinopathy <https://www.kaggle.com/c/diabetic-retinopathy-detection/data>`_ dataset \
    consists of image-label pairs with high-resolution retina images, and labels that indicate \
    the presence of Diabetic Retinopathy (DR) on a 0-4 scale (No DR, Mild, Moderate, Severe, \
    or Proliferative DR).

    .. note:: You need to download the source data manually into the `root` directory.
        Images for a split are looked up under ``root/<split>`` and the image list is read
        from ``root/image_list/<split>.txt``.

    Args:
        root (str): Root directory of dataset
        split (str): The dataset split, supports ``train``, or ``test``.
        download (bool, optional): Accepted but not used; this dataset is never downloaded
            automatically. Default: ``False``.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a \
            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
    """
    # Label index i corresponds to CLASSES[i] (0 = No DR ... 4 = Proliferative DR).
    CLASSES = ['No DR', 'Mild', 'Moderate', 'Severe', 'Proliferative DR']

    def __init__(self, root, split, download=False, **kwargs):
        # ``download`` is intentionally ignored: the data must already be present under ``root``.
        data_root = os.path.join(root, split)
        data_list_file = os.path.join(root, "image_list", "{}.txt".format(split))
        # Remaining keyword arguments (e.g. transform, target_transform) are forwarded to ImageList.
        super(Retinopathy, self).__init__(data_root, Retinopathy.CLASSES, data_list_file, **kwargs)