Skip to content

Commit

Permalink
fix bug in MNIST dataset
Browse files: browse the repository at this point in the history
  • Loading branch information
fariedabuzaid committed Sep 19, 2023
1 parent bf18e9f commit 705d5ae
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions veriflow/experiments/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,16 +190,23 @@ def __getitem__(self, index: int):

class MnistDequantized(DequantizedDataset):
def __init__(self, dataloc: os.PathLike = None, train: bool = True, digit: T.Optional[int] = None, flatten=True):
rel_path = "MNIST/raw/train-images-idx3-ubyte" if train else "MNIST/raw/t10k-images-idx3-ubyte"
if train:
rel_path = "MNIST/raw/train-images-idx3-ubyte"
else:
rel_path = "MNIST/raw/t10k-images-idx3-ubyte"
path = os.path.join(dataloc, rel_path)
if not os.path.exists(path):
MNIST(path, train=train, download=True)
MNIST(dataloc, train=train, download=True)

# TODO: remove hardcoding of 3x3 downsampling
dataset = idx2numpy.convert_from_file(path)[:, ::3, ::3]
if flatten:
dataset = dataset.reshape(dataset.shape[0], -1)
if digit is not None:
rel_path = "MNIST/raw/train-labels-idx1-ubyte" if train else "MNIST/raw/t10k-labels-idx1-ubyte"
if train:
rel_path = "MNIST/raw/train-labels-idx1-ubyte"
else:
rel_path = "MNIST/raw/t10k-labels-idx1-ubyte"
path = os.path.join(dataloc, rel_path)
labels = idx2numpy.convert_from_file(path)
dataset = dataset[labels == digit]
Expand Down

0 comments on commit 705d5ae

Please sign in to comment.