From 583e0b8cea0691213b04474a67ce720128d87f7e Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Fri, 13 Oct 2023 06:11:54 +0000 Subject: [PATCH] feat: add tests for data loading and preprocessing --- tests/test_main.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 tests/test_main.py diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 0000000..bc34cfb --- /dev/null +++ b/tests/test_main.py @@ -0,0 +1,35 @@ +import pytest +from unittest.mock import Mock, call +from PIL import Image +import torch +from torchvision import datasets, transforms +from torch.utils.data import DataLoader +import numpy as np +from src import main + +def test_data_loading(mocker): + """ + Test the data loading step in main.py. + This test checks that the datasets.MNIST and DataLoader functions are called with the correct arguments. + """ + mock_mnist = mocker.patch('torchvision.datasets.MNIST', return_value=[]) + mock_dataloader = mocker.patch('torch.utils.data.DataLoader', return_value=[]) + + main.load_data() + + mock_mnist.assert_called_once_with('.', download=True, train=True, transform=main.transform) + mock_dataloader.assert_called_once_with(mock_mnist.return_value, batch_size=64, shuffle=True) + +def test_data_preprocessing(mocker): + """ + Test the data preprocessing step in main.py. + This test checks that the transforms.Compose function is called with the correct arguments. + """ + mock_compose = mocker.patch('torchvision.transforms.Compose', return_value=[]) + + main.preprocess_data() + + mock_compose.assert_called_once_with([ + transforms.ToTensor(), + transforms.Normalize((0.5,), (0.5,)) + ])