A neural-network image classifier for fast, accurate flower identification.
This repository contains the code and assets for a deep learning project that builds an image classifier using PyTorch and torchvision. The project includes scripts for training a model on an image dataset and making predictions on new images.
The project is organized into the following components:
- train.py: Script for training a deep learning model on an image dataset.
- predict.py: Script for making predictions using a trained model checkpoint.
- cat_to_name.json: JSON file mapping category indices to flower names.
- checkpoint.pth: Example trained model checkpoint.
- /flowers: Sample image dataset directory containing training and validation sets.
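The file cat_to_name.json translates the model's category labels into readable flower names. The sketch below shows one way to load and use such a mapping. It assumes the file is a flat JSON object keyed by label strings (e.g. {"1": "rose"}). The helper names `load_category_names` and `lookup_names` are hypothetical and are not taken from predict.py, and the two sample entries are made up.

```python
import json
import os
import tempfile


def load_category_names(path):
    """Load a JSON object mapping category labels to flower names.

    Assumption: the file is a flat object like {"1": "rose"}. JSON object
    keys are always strings, so labels are normalized to strings here.
    """
    with open(path, encoding="utf-8") as f:
        mapping = json.load(f)
    if not isinstance(mapping, dict):
        raise ValueError("expected a JSON object mapping labels to names")
    return {str(k): str(v) for k, v in mapping.items()}


def lookup_names(mapping, labels):
    """Translate labels to names, falling back to the label itself if unknown."""
    return [mapping.get(str(label), str(label)) for label in labels]


if __name__ == "__main__":
    # Hypothetical two-entry file standing in for the real cat_to_name.json.
    sample = {"1": "rose", "2": "tulip"}
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "cat_to_name.json")
        with open(path, "w", encoding="utf-8") as f:
            json.dump(sample, f)
        names = load_category_names(path)
    print(lookup_names(names, [2, "1", 99]))
```

Normalizing keys to strings lets callers pass either integer or string labels. The fallback keeps unmapped labels visible instead of raising a KeyError.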
To train the model on your own dataset, use the following command:
python train.py /path/to/dataset --arch vgg16 --hidden_units 512 --learning_rate 0.001 --epochs 10 --save_dir checkpoint.pth --gpu
Replace /path/to/dataset with the path to your image dataset.
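The sketch below shows how the training flags above could be parsed with Python's argparse. It is hypothetical. The flag names come from the example command, but the defaults, help strings, and the assumed train/validation subfolder layout under `data_dir` are assumptions. train.py's actual parser may differ.

```python
import argparse


def build_train_parser():
    """Hypothetical parser mirroring the flags in the train.py example.

    Defaults are assumptions copied from the example command, not train.py.
    """
    parser = argparse.ArgumentParser(description="Train a flower image classifier.")
    # Assumed to hold training and validation subfolders.
    parser.add_argument("data_dir", help="dataset root directory")
    parser.add_argument("--arch", default="vgg16", help="torchvision backbone name")
    parser.add_argument("--hidden_units", type=int, default=512)
    parser.add_argument("--learning_rate", type=float, default=0.001)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--save_dir", default="checkpoint.pth",
                        help="where to write the trained checkpoint")
    # store_true makes --gpu a bare switch, matching its use without a value.
    parser.add_argument("--gpu", action="store_true", help="train on CUDA if set")
    return parser


if __name__ == "__main__":
    args = build_train_parser().parse_args(
        ["flowers", "--hidden_units", "256", "--epochs", "5", "--gpu"]
    )
    print(args)
```

Declaring `type=int` and `type=float` makes argparse reject malformed values such as `--epochs ten` before any training starts.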
To make predictions on a new image, use the following command:
python predict.py /path/to/image checkpoint.pth --top_k 5 --category_names cat_to_name.json --gpu
Replace /path/to/image with the path to the image you want to classify.
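The `--top_k` option asks for the k most likely classes. The sketch below is a framework-free version of that step. It converts raw model scores to probabilities with a softmax, picks the k largest, and maps each output index to a dataset label and then to a flower name. The `idx_to_class` dictionary is an assumption, modeled as the inverse of a torchvision `ImageFolder`'s `class_to_idx`. In PyTorch the same step is usually written with `torch.softmax` and `Tensor.topk`. This is not predict.py's actual code, and the sample scores and names are hypothetical.

```python
import heapq
import math


def softmax(logits):
    """Numerically stable softmax over a list of raw scores."""
    m = max(logits)  # subtract the max so exp() cannot overflow
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_predictions(logits, idx_to_class, cat_to_name, k=5):
    """Return up to k (name, probability) pairs, most probable first.

    idx_to_class: model output position -> dataset label (assumed mapping).
    cat_to_name: dataset label -> flower name, as in cat_to_name.json.
    """
    probs = softmax(logits)
    k = min(k, len(probs))
    best = heapq.nlargest(k, range(len(probs)), key=probs.__getitem__)
    results = []
    for i in best:
        label = idx_to_class[i]
        # Fall back to the raw label if no name is known.
        results.append((cat_to_name.get(label, label), probs[i]))
    return results


if __name__ == "__main__":
    # Hypothetical model output for a 4-class problem.
    logits = [1.0, 3.0, 0.5, 2.0]
    idx_to_class = {0: "1", 1: "2", 2: "3", 3: "4"}
    cat_to_name = {"1": "rose", "2": "tulip", "3": "daisy", "4": "sunflower"}
    for name, p in top_k_predictions(logits, idx_to_class, cat_to_name, k=2):
        print(f"{name}: {p:.3f}")
```

The index-to-label step matters because the model's output order follows the sorted folder names seen during training, which need not match the numeric category labels.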
Make sure you have the required dependencies installed. You can install them using:
pip install -r requirements.txt
This project is licensed under the MIT License.
- This project was built using PyTorch and torchvision.
- The image dataset used for training is based on the Flower Recognition Dataset.