This project implements an Early Exit Convolutional Neural Network (EE-CNN) for efficient image classification on the CIFAR-10 dataset. The model uses reinforcement learning to dynamically decide when to exit the network early, optimizing the trade-off between computational efficiency and accuracy.
The system consists of two main components:
- Early Exit CNN: A deep neural network with multiple exit points, allowing predictions at different depths:
  - 4 exit points with increasing complexity
  - Enhanced feature extraction at each stage
  - Batch normalization and dropout for regularization
  - Progressive increase in channel dimensions (64→128→256→512)
- DQN Agent: A reinforcement learning agent that learns when to exit:
  - Makes exit decisions based on confidence scores
  - Balances accuracy and computational efficiency
  - Uses experience replay for stable training
  - Implements epsilon-greedy exploration
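The DQN agent's two stabilising mechanisms, experience replay and epsilon-greedy exploration, can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the code in src/models/dqn_agent.py: the names `ReplayBuffer`, `epsilon_greedy` and `decay_epsilon`, the two-action encoding (0 = continue, 1 = exit) and the decay constants are all hypothetical.

```python
import random
from collections import deque, namedtuple

# One stored interaction; a state could be the confidence scores seen so far (assumption).
Transition = namedtuple("Transition", ["state", "action", "reward", "next_state", "done"])


class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions (hypothetical helper)."""

    def __init__(self, capacity: int):
        # A deque with maxlen silently discards the oldest transition when full
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done) -> None:
        self.buffer.append(Transition(state, action, reward, next_state, done))

    def sample(self, batch_size: int, rng: random.Random) -> list:
        # Uniform random minibatch: breaks the correlation between consecutive steps
        return rng.sample(list(self.buffer), batch_size)

    def __len__(self) -> int:
        return len(self.buffer)


def epsilon_greedy(q_values, epsilon: float, rng: random.Random) -> int:
    """Random action with probability epsilon, otherwise the highest-Q action.
    Actions are indices into q_values, e.g. 0 = continue, 1 = exit (assumed encoding)."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda a: q_values[a])


def decay_epsilon(epsilon: float, eps_min: float = 0.05, decay: float = 0.995) -> float:
    """Multiplicative decay towards a floor, so exploration shrinks as training proceeds."""
    return max(eps_min, epsilon * decay)
```

In a training loop, each step would call `epsilon_greedy` on the agent's Q-values, `push` the resulting transition, and, once the buffer holds enough samples, train on a `sample` minibatch while `decay_epsilon` gradually shifts the agent from exploring to exploiting.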
Our model achieves strong performance metrics:
- Overall Accuracy: 88.77%
- Compute Savings: 39.9%
- Effectiveness Score: 35.38
Exit Point Distribution:
- Exit 1: 4.8%
- Exit 2: 37.5%
- Exit 3: 30.3%
- Exit 4: 27.4%
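If compute savings is measured against always running the full network (an assumption; the cost model behind the 39.9% figure is not stated), it follows directly from the exit distribution. With $p_k$ the fraction of images leaving at exit $k$, $c_k$ the cumulative cost of reaching exit $k$ (including any exit heads evaluated on the way), and $C_{\text{full}}$ the cost of the whole network:

```latex
\text{Compute savings} = 1 - \frac{\sum_{k=1}^{4} p_k \, c_k}{C_{\text{full}}},
\qquad \sum_{k=1}^{4} p_k = 1
```

With the distribution above, the 27.4% of images that reach exit 4 contribute essentially no savings, so the reported savings come from the 72.6% (4.8% + 37.5% + 30.3%) that stop at exits 1 to 3.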
Per-Class Performance:
- Best Classes: Ship (94.3%), Car (94.2%), Truck (93.7%)
- Most Challenging: Dog (78.4%), Bird (80.0%), Cat (80.3%)
Per-Class Exit Performance:
First Exit (≈67% accuracy):
- Strong: Truck (81.4%), Plane (75.4%), Horse (71.7%)
- Weak: Bird (40.7%), Cat (42.1%)
Second Exit (≈78% accuracy):
- Strong: Truck (92.4%), Ship (91.3%), Car (90.8%)
- Weak: Cat (70.9%), Bird (71.5%)
Third Exit (≈88% accuracy):
- Strong: Ship (94.5%), Car (94.5%), Plane (92.5%)
- Weak: Cat (79.3%), Dog (79.7%)
Final Exit (≈89% accuracy):
- Strong: Car (94.6%), Ship (94.4%), Truck (93.7%)
- Weak: Dog (78.4%), Bird (80.1%)
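Per-class, per-exit accuracies of this kind can be tabulated from the predictions of every exit head on the test set. A minimal sketch, assuming each record is an (exit_point, true_label, predicted_label) triple with CIFAR-10 integer labels; the function name and record layout are hypothetical, not the API of src/evaluation/evaluate.py:

```python
from collections import defaultdict

CLASSES = ["plane", "car", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck"]  # CIFAR-10 label order


def per_class_exit_accuracy(records):
    """records: iterable of (exit_point, true_label, predicted_label).
    Returns {exit_point: {class_name: accuracy_percent}}."""
    correct = defaultdict(lambda: defaultdict(int))
    total = defaultdict(lambda: defaultdict(int))
    for exit_point, label, pred in records:
        total[exit_point][label] += 1
        correct[exit_point][label] += int(pred == label)
    return {
        exit_point: {CLASSES[c]: 100.0 * correct[exit_point][c] / n
                     for c, n in sorted(by_class.items())}
        for exit_point, by_class in sorted(total.items())
    }
```

Sorting each exit's dictionary by accuracy then yields "strong" and "weak" lists like the ones above.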
Confidence Statistics:
- Exit 1: Mean=0.673, Std=0.211
- Exit 2: Mean=0.807, Std=0.206
- Exit 3: Mean=0.838, Std=0.184
- Exit 4: Mean=0.706, Std=0.181
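These statistics depend on how confidence is defined. A common choice, assumed here rather than taken from the repository, is the maximum softmax probability of an exit head. The sketch computes it and then the per-exit mean and population standard deviation; the function names and the input layout are hypothetical:

```python
import math
import statistics


def softmax(logits):
    m = max(logits)  # shift by the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def confidence(logits):
    """Max softmax probability, used here as the assumed per-exit confidence score."""
    return max(softmax(logits))


def confidence_stats(logits_per_exit):
    """logits_per_exit: {exit_point: [logits_image_1, logits_image_2, ...]}.
    Returns {exit_point: (mean, population_std)} of the confidences."""
    stats = {}
    for exit_point, batch in logits_per_exit.items():
        confs = [confidence(logits) for logits in batch]
        stats[exit_point] = (statistics.mean(confs), statistics.pstdev(confs))
    return stats
```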
Visualizations of the experimental results (figures): per-class performance and training progress.
src/
├── data/                    # Data loading
│   └── data_loader.py       # CIFAR-10 loaders (load_cifar10_data)
├── models/                  # Model architectures
│   ├── early_exit_cnn.py    # CNN implementation
│   ├── dqn_agent.py         # DQN agent
│   └── environment.py       # Training environment
├── training/                # Training implementations
│   ├── train_cnn.py         # CNN training
│   └── train_rl.py          # RL training
├── evaluation/              # Evaluation code
│   └── evaluate.py          # Evaluation metrics
├── inference/               # Inference implementation
│   └── inference.py         # Inference code
└── visualization/           # Visualization tools
    └── visualize.py         # Plotting functions
- Clone the repository:
git clone https://github.com/Shikha-code36/early-exit-cnn.git
cd early-exit-cnn
- Install dependencies:
pip install -r requirements.txt
- Train the CNN model:
from src.training.train_cnn import pretrain_cnn
from src.data.data_loader import load_cifar10_data
# Load data
train_loader, test_loader = load_cifar10_data(batch_size=128)
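# Pretrain the CNN; `model` must already be an instance of the early-exit network defined in src/models/early_exit_cnn.py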
losses, accuracies = pretrain_cnn(model, train_loader, num_epochs=50)
- Train the RL agent:
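from src.training.train_rl import train_rl_agent
# `model` is the pretrained CNN; `agent` and `env` must first be created from src/models/dqn_agent.py and src/models/environment.py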
rewards, exit_counts = train_rl_agent(
model,
agent,
env,
train_loader,
num_episodes=5000
)
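The reward that `train_rl_agent` optimises is not spelled out here. As an illustration only, a common shape for an early-exit reward pays for a correct prediction and charges for the compute spent, and the agent's Q-values are regressed towards the standard DQN temporal-difference target. The cost fractions, `cost_weight`, and both function names are assumptions:

```python
# Hypothetical fraction of full-network compute spent when stopping at each exit.
EXIT_COSTS = {1: 0.25, 2: 0.5, 3: 0.75, 4: 1.0}


def exit_reward(correct: bool, exit_point: int, cost_weight: float = 0.5) -> float:
    """+1 for a correct prediction, -1 for a wrong one, minus a compute penalty."""
    accuracy_term = 1.0 if correct else -1.0
    return accuracy_term - cost_weight * EXIT_COSTS[exit_point]


def td_target(reward: float, next_q_values, done: bool, gamma: float = 0.99) -> float:
    """DQN target r + gamma * max_a' Q(s', a'); only r once the episode has ended."""
    if done:
        return reward
    return reward + gamma * max(next_q_values)
```

With these numbers, a correct answer at exit 1 earns 0.875 versus 0.5 at exit 4, so the agent is pushed to stop early whenever the early prediction is likely to be right; raising `cost_weight` shifts the balance further towards compute savings.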
Run inference on new images:
from src.inference.inference import EarlyExitInference
# Initialize inference
inferencer = EarlyExitInference(model_path='models/')
# Process image
result = inferencer.process_image("path/to/image.jpg")
print(f"Prediction: {result['class']}")
print(f"Confidence: {result['confidence']:.2f}")
print(f"Exit Point: {result['exit_point']}")
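`EarlyExitInference` lets the trained agent pick the exit. For comparison, the simplest non-learned policy exits at the first head whose confidence clears a fixed threshold. This is a swapped-in baseline, not the repository's DQN-based logic; the sketch assumes confidence is the max softmax probability, uses the CIFAR-10 class order, and returns a dict with the same keys as `result` above. `threshold_exit` and the 0.9 threshold are hypothetical:

```python
import math

CLASSES = ["plane", "car", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck"]  # CIFAR-10 label order


def softmax(logits):
    m = max(logits)  # shift by the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def threshold_exit(exit_logits, threshold=0.9):
    """Walk the exits in order and stop at the first confident one.

    `exit_logits` yields one logit list per exit; if it is a lazy generator,
    stages after the chosen exit are never computed. The last exit is the
    fallback when no earlier exit reaches the threshold."""
    result = None
    for exit_point, logits in enumerate(exit_logits, start=1):
        probs = softmax(logits)
        conf = max(probs)
        result = {"class": CLASSES[probs.index(conf)],
                  "confidence": conf,
                  "exit_point": exit_point}
        if conf >= threshold:
            break
    return result
```

A fixed threshold ignores how much compute remains, which is one reason to learn the exit decision with an agent instead.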
Evaluate model performance:
from src.evaluation.evaluate import evaluate_model
metrics = evaluate_model(model, agent, test_loader)
print(f"Accuracy: {metrics['accuracy']:.2f}%")
print(f"Compute Saved: {metrics['compute_saved']:.2f}%")
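The two figures printed above can be derived from per-image outcomes. A minimal sketch, assuming each record is an (is_correct, exit_point) pair and that compute is charged with hypothetical per-exit cost fractions; the real evaluate.py may measure compute differently (for example in FLOPs), and `summarize` is a hypothetical name:

```python
# Hypothetical fraction of full-network compute spent when stopping at each exit.
EXIT_COSTS = {1: 0.25, 2: 0.5, 3: 0.75, 4: 1.0}


def summarize(records, exit_costs=EXIT_COSTS):
    """records: list of (is_correct, exit_point) pairs, one per test image.
    Returns accuracy (%), compute saved (%) relative to always running the
    full network, and the exit distribution (%)."""
    n = len(records)
    accuracy = 100.0 * sum(correct for correct, _ in records) / n
    mean_cost = sum(exit_costs[e] for _, e in records) / n
    distribution = {e: 100.0 * sum(1 for _, x in records if x == e) / n
                    for e in sorted(exit_costs)}
    return {"accuracy": accuracy,
            "compute_saved": 100.0 * (1.0 - mean_cost),
            "exit_distribution": distribution}
```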
The Early Exit CNN employs a progressive architecture:
- First Exit (64 channels):
  - Basic feature extraction
  - Early exit for simple cases
  - 68.4% accuracy for easy classes
- Second Exit (128 channels):
  - Intermediate processing
  - Improved feature representation
  - 86.8% accuracy for moderate cases
- Third Exit (256 channels):
  - Advanced feature processing
  - Enhanced classification capability
  - 92.9% accuracy for complex cases
- Final Exit (512 channels):
  - Deep feature extraction
  - Comprehensive classification
  - 92.2% accuracy for challenging cases
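To see why later exits are more expensive, count the weights in the convolutions that widen the feature maps. Assuming, purely for illustration, a single bias-carrying 3×3 convolution per stage following the 64→128→256→512 progression (the real stages also contain batch normalization, dropout and possibly more layers), the helper names below are hypothetical:

```python
def conv_params(in_ch: int, out_ch: int, k: int = 3, bias: bool = True) -> int:
    """Parameters of a k x k 2-D convolution: weights plus optional bias."""
    return in_ch * out_ch * k * k + (out_ch if bias else 0)


# RGB input followed by the four stage widths (one conv per stage is an assumption)
CHANNELS = [3, 64, 128, 256, 512]


def cumulative_stage_params(channels=CHANNELS):
    """Running total of conv parameters up to and including each stage."""
    totals, running = [], 0
    for in_ch, out_ch in zip(channels, channels[1:]):
        running += conv_params(in_ch, out_ch)
        totals.append(running)
    return totals
```

Under this assumption the cumulative counts are 1,792, 75,648, 370,816 and 1,550,976 parameters, so the last stage alone holds about three quarters of the weights. Parameter count is not the same as compute: pooling between stages shrinks the spatial resolution, which partly offsets the wider channels in terms of FLOPs.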
The project includes several visualization tools:
- Training Progress:
  - Loss curves
  - Accuracy per exit point
  - Exit distribution
- Analysis Tools:
  - Confidence distributions
  - Class-wise exit patterns
  - Performance heatmaps
Example visualization code:
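from src.visualization.visualize import plot_training_metrics
# `losses` and `accuracies` come from pretrain_cnn; `exit_dist` must hold the exit-point distribution (e.g. derived from the `exit_counts` returned by train_rl_agent)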
plot_training_metrics(
train_losses=losses,
accuracies_per_exit=accuracies,
exit_distributions=exit_dist
)
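The call above needs an exit distribution, and the class-wise exit patterns and heatmaps listed earlier need a class-by-exit table. A minimal sketch of both, assuming exit counts arrive as a {exit_point: count} dict and per-image records as (label, exit_point) pairs; the helper names are hypothetical:

```python
def to_percentages(counts):
    """Normalize raw exit counts {exit_point: count} into percentages."""
    total = sum(counts.values())
    return {k: 100.0 * v / total for k, v in sorted(counts.items())}


def class_exit_matrix(pairs, num_classes=10, num_exits=4):
    """Rows are classes, columns are exits; each row gives the percentage of
    that class's images leaving at each exit (a row of zeros if unseen)."""
    counts = [[0] * num_exits for _ in range(num_classes)]
    for label, exit_point in pairs:
        counts[label][exit_point - 1] += 1  # exits are numbered from 1
    return [[100.0 * c / sum(row) if sum(row) else 0.0 for c in row]
            for row in counts]
```

The matrix can be passed to any heatmap routine, with classes on one axis and exits on the other.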
Contributions are welcome! Please feel free to submit a Pull Request.
This project is licensed under the MIT License - see the LICENSE file for details.
If you use this code in your research, please cite:
@misc{early-exit-cnn-2025,
author = {Shikha Pandey},
title = {Early Exit CNN with RL-based Decision Making},
year = {2025},
publisher = {GitHub},
url = {https://github.com/Shikha-code36/early-exit-cnn}
}
- CIFAR-10 dataset
- PyTorch team for the deep learning framework
- Reinforcement learning community for DQN implementations