In this project, I experimented with getting a Generative Adversarial Network (GAN) to work.
Setup: I used PyTorch together with PyTorch Lightning as the framework. The datasets used were:
- MNIST (handwritten digit dataset)
- IRIS Flower dataset
- A custom GHIBLI dataset (hand-drawn landscape images from Studio Ghibli)
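For context, here is a minimal sketch of how a dataset like MNIST can be loaded with torchvision; the batch size and normalization values are illustrative assumptions, not necessarily those used in the notebooks:

```python
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Normalize to [-1, 1], which matches a tanh-output generator.
transform = transforms.Compose([
    transforms.ToTensor(),                 # scales pixels to [0, 1]
    transforms.Normalize((0.5,), (0.5,)),  # shifts/scales to [-1, 1]
])
mnist = datasets.MNIST("data/", train=True, download=True, transform=transform)
train_loader = DataLoader(mnist, batch_size=64, shuffle=True)
```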
Goal: The objective was to generate images that resemble the dataset inputs. For instance, when training on the MNIST dataset, the aim was to generate images that look like handwritten digits.
Model Architectures: I explored different architectures and fine-tuned them. The models used included:
- Generative Adversarial Network (GAN): a basic architecture using a Multi-Layer Perceptron (MLP), with a Discriminator that attempts to detect fake images and a Generator that attempts to produce realistic-looking images to fool the Discriminator (see the sketch after this list).
- Deep Convolutional Generative Adversarial Network (DCGAN): uses convolutional layers (commonly used in image classification) to improve image realism.
- Conditional Deep Convolutional GAN (CDC-GAN): uses both the image and metadata (e.g., labels) to condition the generation process.
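As a reference for the basic GAN variant, here is a minimal MLP generator/discriminator pair. The layer sizes and latent dimension are illustrative assumptions, not necessarily what the notebooks use:

```python
import torch
import torch.nn as nn

class Generator(nn.Module):
    """Maps a latent noise vector z to a flattened 28x28 image."""
    def __init__(self, latent_dim=100, img_dim=28 * 28):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 256), nn.LeakyReLU(0.2),
            nn.Linear(256, 512), nn.LeakyReLU(0.2),
            nn.Linear(512, img_dim), nn.Tanh(),  # pixel values in [-1, 1]
        )

    def forward(self, z):
        return self.net(z)

class Discriminator(nn.Module):
    """Outputs the probability that a flattened image is real."""
    def __init__(self, img_dim=28 * 28):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(img_dim, 512), nn.LeakyReLU(0.2),
            nn.Linear(512, 256), nn.LeakyReLU(0.2),
            nn.Linear(256, 1), nn.Sigmoid(),
        )

    def forward(self, img):
        return self.net(img)

# Usage: sample noise, generate a batch of fakes, score them.
G, D = Generator(), Discriminator()
fake = G(torch.randn(16, 100))   # shape (16, 784)
score = D(fake)                  # shape (16, 1), "realness" per image
```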
Training Tricks & Techniques (a sketch of the first two follows this list):
- Soft labels
- Experience replay
- More tricks:
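To make the first two tricks concrete, here is a hedged sketch of a discriminator update that uses both. The buffer size, the 0.9 soft-label value, and the function and variable names are illustrative assumptions, not the notebooks' exact code:

```python
import random
import torch
import torch.nn.functional as F

replay_buffer = []   # stores detached fakes from past generator versions
BUFFER_SIZE = 1000   # assumed capacity

def discriminator_step(D, G, real_imgs, latent_dim=100):
    batch = real_imgs.size(0)
    fakes = G(torch.randn(batch, latent_dim)).detach()

    # Experience replay: keep old fakes around and mix some into the
    # current batch, so the discriminator does not forget earlier
    # generator failure modes.
    replay_buffer.extend(fakes)
    del replay_buffer[:-BUFFER_SIZE]
    n_replay = min(batch // 2, len(replay_buffer))
    if n_replay:
        replayed = torch.stack(random.sample(replay_buffer, n_replay))
        fake_batch = torch.cat([fakes[: batch - n_replay], replayed])
    else:
        fake_batch = fakes

    # Soft labels: target 0.9 for real images instead of a hard 1.0,
    # which keeps the discriminator from becoming overconfident.
    real_labels = torch.full((batch, 1), 0.9)
    fake_labels = torch.zeros(batch, 1)

    return (F.binary_cross_entropy(D(real_imgs), real_labels)
            + F.binary_cross_entropy(D(fake_batch), fake_labels))
```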
More Resources:
- Download the Jupyter Notebook files (`.ipynb`) you want to use from the `code` folder.
- Run them with Jupyter Notebook/Lab or Google Colab.
Notes:
- On Google Colab, all necessary packages should auto-install.
- On a local Jupyter Notebook, you’ll need to install required packages manually.
- Please read the Disclaimer below.
This project was created a while ago and the documentation is not very organized. As a result, there are multiple Jupyter Notebooks with vague naming, and differences are not always clearly explained.
When selecting a notebook, inspect both the filename and the code.
Some notebooks may require tinkering before they work. Here’s why:
- Some notebooks do not train from scratch (Epoch 0) but attempt to load from a checkpoint (e.g., Epoch 50 or 100). → To fix: comment out `resume_from_checkpoint=""` in the `pl.Trainer()` call (see the sketch after this list).
- Some models are trained using a custom GHIBLI dataset. → To use another dataset, modify the `train_dataloader` and `test_dataloader` functions. Be mindful of image sizes when doing so.
- The logger used is `comet_ml`, but I removed my API key. → Replace it with your own `comet_ml` API key or use a different logging method.
- There may be other issues I don’t currently remember.
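Putting the first and third fixes together, here is a minimal sketch of the relevant `pl.Trainer()` setup. The checkpoint path, epoch count, and API-key placeholder are hypothetical, and `resume_from_checkpoint` only exists in the older PyTorch Lightning versions this project was written against (it was removed in Lightning 2.0):

```python
import pytorch_lightning as pl
from pytorch_lightning.loggers import CometLogger

# Replace the placeholder with your own comet_ml API key,
# or swap in a different logger (e.g., TensorBoardLogger).
logger = CometLogger(api_key="YOUR_COMET_ML_API_KEY")

trainer = pl.Trainer(
    max_epochs=100,  # illustrative value
    logger=logger,
    # To train from scratch (Epoch 0), leave this commented out;
    # uncomment it only if the referenced checkpoint file exists:
    # resume_from_checkpoint="checkpoints/epoch=50.ckpt",
)
```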
I was able to generate convincing digit images in various resolutions, including 16×16, 32×32, 64×64, and 128×128 (via upscaled MNIST). However, results for the flower and GHIBLI datasets were less successful.
👉 Click here to see more results and loss graphs
Distributed under the MIT License.
See `LICENSE` for more information.