diff --git a/veriflow/experiments/datasets.py b/veriflow/experiments/datasets.py index b0bf119..2036230 100644 --- a/veriflow/experiments/datasets.py +++ b/veriflow/experiments/datasets.py @@ -190,16 +190,23 @@ def __getitem__(self, index: int): class MnistDequantized(DequantizedDataset): def __init__(self, dataloc: os.PathLike = None, train: bool = True, digit: T.Optional[int] = None, flatten=True): - rel_path = "MNIST/raw/train-images-idx3-ubyte" if train else "MNIST/raw/t10k-images-idx3-ubyte" + if train: + rel_path = "MNIST/raw/train-images-idx3-ubyte" + else: + rel_path = "MNIST/raw/t10k-images-idx3-ubyte" path = os.path.join(dataloc, rel_path) if not os.path.exists(path): - MNIST(path, train=train, download=True) + MNIST(dataloc, train=train, download=True) + # TODO: remove hardcoding of 3x3 downsampling dataset = idx2numpy.convert_from_file(path)[:, ::3, ::3] if flatten: dataset = dataset.reshape(dataset.shape[0], -1) if digit is not None: - rel_path = "MNIST/raw/train-labels-idx1-ubyte" if train else "MNIST/raw/t10k-labels-idx1-ubyte" + if train: + rel_path = "MNIST/raw/train-labels-idx1-ubyte" + else: + rel_path = "MNIST/raw/t10k-labels-idx1-ubyte" path = os.path.join(dataloc, rel_path) labels = idx2numpy.convert_from_file(path) dataset = dataset[labels == digit]