Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new CNN class for MNIST handling and update model usage in API #71

Closed
wants to merge 11 commits into from

Conversation

sweep-nightly[bot]
Copy link

@sweep-nightly sweep-nightly bot commented Oct 19, 2023

Description

This PR adds a new CNN class in src/cnn.py that handles the MNIST dataset. The CNN class is responsible for loading and preprocessing the data, defining the CNN architecture, and training the model. The existing code in src/main.py for loading and preprocessing the MNIST dataset and defining the PyTorch model has been removed and replaced with the new CNN class. Additionally, the usage of the model in src/api.py has been updated to use the new CNN model instead of the previous Net model.

Summary of Changes

  • Created a new file src/cnn.py to contain the new CNN class.
  • Imported necessary libraries in src/cnn.py for building the CNN model.
  • Defined a new class CNN in src/cnn.py that inherits from torch.nn.Module.
  • Implemented the __init__ method in CNN class to define the layers of the CNN.
  • Implemented the forward method in CNN class to perform the forward pass of the CNN.
  • Defined a load_data function in src/cnn.py to load and preprocess the MNIST dataset.
  • Defined a train function in src/cnn.py to train the CNN model on the MNIST dataset.
  • Added a main function in src/cnn.py to create an instance of the CNN class, load the data, and train the model.
  • Updated src/main.py to import the CNN class from src/cnn.py.
  • Removed the code in src/main.py for loading and preprocessing the MNIST dataset and defining the PyTorch model.
  • Added code in src/main.py to create an instance of the CNN class and call the train method.
  • Updated src/api.py to import the CNN class from src/cnn.py.
  • Replaced the usage of the previous Net model with the new CNN model in src/api.py.
  • Updated the path to the state dict file in src/api.py to match the location of the saved CNN model.

Please review and merge this PR to incorporate the changes.

Fixes #9.


🎉 Latest improvements to Sweep:

  • Sweep can now passively improve your repository! Check out Rules to learn more.

💡 To get Sweep to edit this pull request, you can:

  • Comment below, and Sweep can edit the entire PR
  • Comment on a file, Sweep will only modify the commented file
  • Edit the original issue to get Sweep to recreate the PR from scratch

@sweep-nightly
Copy link
Author

sweep-nightly bot commented Oct 19, 2023

Rollback Files For Sweep

  • Rollback changes to src/main.py
  • Rollback changes to src/cnn.py
  • Rollback changes to src/main.py
  • Rollback changes to src/main.py
  • Rollback changes to src/main.py
  • Rollback changes to src/api.py
  • Rollback changes to src/train.py

@sweep-nightly
Copy link
Author

sweep-nightly bot commented Oct 19, 2023

Apply Sweep Rules to your PR?

  • Apply: All docstrings and comments should be up to date.
  • Apply: Code should be properly formatted and indented.
  • Apply: Variable and function names should be descriptive and follow a consistent naming convention.
  • Apply: Imports should be organized and grouped together.
  • Apply: There should be no unused imports or variables.
  • Apply: Code should be properly commented and include docstrings for functions and classes.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Sweep: add a new cnn class that defines AND trains the cnn to handle mnist in cnn.py and import it to main.py
0 participants