Skip to content

Commit 705d5ae

Browse files
committed
fix bug in MNIST dataset
1 parent bf18e9f commit 705d5ae

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

veriflow/experiments/datasets.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -190,16 +190,23 @@ def __getitem__(self, index: int):
190190

191191
class MnistDequantized(DequantizedDataset):
192192
def __init__(self, dataloc: os.PathLike = None, train: bool = True, digit: T.Optional[int] = None, flatten=True):
193-
rel_path = "MNIST/raw/train-images-idx3-ubyte" if train else "MNIST/raw/t10k-images-idx3-ubyte"
193+
if train:
194+
rel_path = "MNIST/raw/train-images-idx3-ubyte"
195+
else:
196+
rel_path = "MNIST/raw/t10k-images-idx3-ubyte"
194197
path = os.path.join(dataloc, rel_path)
195198
if not os.path.exists(path):
196-
MNIST(path, train=train, download=True)
199+
MNIST(dataloc, train=train, download=True)
200+
197201
# TODO: remove hardcoding of 3x3 downsampling
198202
dataset = idx2numpy.convert_from_file(path)[:, ::3, ::3]
199203
if flatten:
200204
dataset = dataset.reshape(dataset.shape[0], -1)
201205
if digit is not None:
202-
rel_path = "MNIST/raw/train-labels-idx1-ubyte" if train else "MNIST/raw/t10k-labels-idx1-ubyte"
206+
if train:
207+
rel_path = "MNIST/raw/train-labels-idx1-ubyte"
208+
else:
209+
rel_path = "MNIST/raw/t10k-labels-idx1-ubyte"
203210
path = os.path.join(dataloc, rel_path)
204211
labels = idx2numpy.convert_from_file(path)
205212
dataset = dataset[labels == digit]

0 commit comments

Comments
 (0)