
This work focuses on structural pruning of a neural network based on weight sensitivities that cover the entire output range of the loss function.

mansooralodhi/nnpruning


Table of Contents

  • Introduction
  • Prerequisites
  • Perform Iterative Pruning
  • Model Deployment

Introduction

The repository is part of the master's thesis 'Interval-Adjoint Sensitivity Guided Neural Network Pruning'. The project aims to structurally prune a fully-connected multilayer perceptron using interval-adjoint sensitivity information. Each iteration computes the significance of each neuron, removes a fraction of the least significant neurons in each layer, and then retrains the model. Pruning continues until the stopping criterion is met. Finally, a pruned model with a suitable trade-off between complexity and efficiency is selected for deployment.
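
The loop below sketches this procedure in minimal form. It is illustrative only: every callable in it (compute_significance, prune_least_significant, retrain, should_stop) is a hypothetical stand-in passed in as a parameter, not the repository's actual API.

    def iterative_pruning(model, train_loader, compute_significance,
                          prune_least_significant, retrain, should_stop,
                          max_iterations=30):
        """Illustrative pruning loop; all callables are hypothetical stand-ins."""
        for t in range(max_iterations):
            # 1. Score every neuron by its interval-adjoint significance.
            significance = compute_significance(model, train_loader)
            # 2. Structurally remove a fraction of the least significant neurons.
            model = prune_least_significant(model, significance)
            # 3. Retrain the pruned model to recover accuracy.
            retrain(model, train_loader)
            # 4. Halt once the stopping criterion is met.
            if should_stop(model):
                break
        return model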

Dependencies: PyTorch, Torchvision, Matplotlib, thop

Prerequisites

  • Create a problem directory (e.g. src/mnist/) for MNIST data classification.
  • Define datasets and dataloaders in the file src/mnist/dataloaders.py
  • Download the data into the directory src/mnist/data/
  • Define neural network model in the file src/mnist/model.py
  • To do: add the hyperparameters of the neuron-decay function to a .yaml file
  • [Temporary] Generate the neuron-retention data from the file src/mnist/retain.py (a sketch of the schedule follows this list)
    • Define the initial number of neurons in each hidden layer as
      self.h1, self.h2 = 256, 64
    • Define the maximum number of iterations
      self.iterations = list(range(0,31))
    • Add the corresponding hidden-neuron headers to self.headers
    • Define the number of neurons to retain in each layer every iteration
      t1 = round(self.h1 * (self.q) ** t)
      N1 = t1 if t1 > 1 else 1
    • Compute the percentage of neurons retained across the hidden layers
      per = (N1 + N2) / (self.h1+self.h2) * 100
    • Append the number of neurons retained in each hidden layer to
      self.data.append([t, N1, N2, per])
    • To plot the neuron-decay function, adjust the functions plot_data_iterations() and plot_data_distribution()
  • Ensure the generated directory src/mnist/retain/ contains the files:
    • retain.txt
    • pruneNodesIteration.png
    • pruneNodesDistribution.png
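
For reference, the schedule above amounts to a geometric decay of each hidden layer's width. A minimal, self-contained sketch of what src/mnist/retain.py computes, assuming an example decay factor q = 0.85 (the real value of self.q is set in that file):

    # Minimal sketch of the neuron-retention schedule; q = 0.85 is an
    # assumed example value for self.q.
    h1, h2 = 256, 64                        # initial hidden-layer widths
    q = 0.85                                # per-iteration retention factor (assumed)
    data = []
    for t in range(0, 31):                  # pruning iterations
        t1 = round(h1 * q ** t)
        N1 = t1 if t1 > 1 else 1            # retain at least one neuron
        t2 = round(h2 * q ** t)
        N2 = t2 if t2 > 1 else 1
        per = (N1 + N2) / (h1 + h2) * 100   # percent of hidden neurons retained
        data.append([t, N1, N2, per])       # rows later written to retain.txt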


Perform Iterative Pruning

  • Locate if __name__ == '__main__' in the file src/iterativeprune.py
  • Define the neural network model
    model = SimpleNN()
  • Generate the pruning data, which produces the file src/mnist/retain/retain.txt
    data = GeneratePruneData().generate_data()
  • Initialize the iterative pruning with the relevant arguments (see the driver sketch after this list)
    experiment = IterativePruning(train_loader, test_loader, train, evaluate)
  • Define the directory to store experiment results
    my_dir = Path('mnist/sessions/2025-04-21/iterative_pruning/')
  • Define the begin_itr variable
    • 0: start from scratch
    • non-zero: resume from the given pruning iteration
  • Iterative pruning with retraining produces the directory structure below:
    mnist/sessions/2025-04-21/iterative_pruning/
    ├── 0/
    |   ├── prune_w_retrain/
    |   |    └── best_model.pth
    |   |    └── train_results.txt
    ├── 1/
    |   ├── prune_w_retrain/
    |   |    └── best_model.pth
    |   |    └── train_results.txt
    |   |    └── validation_results.txt
    |   ├── ...
    |   ...
    ...
    ├── i/
    |   ├── prune_w_retrain/
    |   |     └── best_model.pth
    |   |     └── train_results.txt
    |   |     └── validation_results.txt
    
  • Once the iterations halt, execute the file src/concatenate.py to generate:
    • mnist/sessions/2025-04-21/iterative_pruning/prune_w_retrain_results/
      • validations.txt
      • accuracies.png
      • losses.png

  • Based on the loss and accuracy graphs, select a pruned model and copy its path, e.g.
    src/mnist/sessions/2025-04-21/iterative_pruning/8/prune_w_retrain/best_model.pth
  • Locate if __name__=='__main__' in the file src/validator.py
  • Define the model
    model = SimpleNN()
  • Load pruned model
    load_pruned_model(model, pruned_model_file)
  • Evaluate the pruned model on the test data using
    evaluate(model, train_loader, test_loader)
  • Verify that these results match those shown in the graphs.
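
Assembled from the steps above, the __main__ block of src/iterativeprune.py is expected to look roughly like the sketch below. The constructor call, Path, and the begin_itr semantics come from this README; the import paths and the run() method name are assumptions.

    # Rough sketch of the driver in src/iterativeprune.py. Import paths and
    # the run() method name are assumptions; IterativePruning, GeneratePruneData,
    # train, and evaluate are assumed to be defined in or imported by this file.
    from pathlib import Path

    from mnist.model import SimpleNN                          # assumed import path
    from mnist.dataloaders import train_loader, test_loader   # assumed import path

    if __name__ == '__main__':
        model = SimpleNN()
        # Writes the retention schedule to src/mnist/retain/retain.txt.
        data = GeneratePruneData().generate_data()
        experiment = IterativePruning(train_loader, test_loader, train, evaluate)
        my_dir = Path('mnist/sessions/2025-04-21/iterative_pruning/')
        begin_itr = 0   # 0: start from scratch; non-zero: resume from that iteration
        experiment.run(model, data, my_dir, begin_itr)        # assumed method name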

Model Deployment

  • Copy the pruned model file
    src/mnist/sessions/2025-04-21/iterative_pruning/8/prune_w_retrain/best_model.pth
  • Copy the neural network model file:
    src/mnist/model.py
  • Copy the model loader file
    src/io.py
  • Execute the model in the three steps below (a complete sketch follows this list)
    • model = SimpleNN()
    • load_pruned_model(model, pruned_model_file)
    • model(input)
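
A minimal end-to-end sketch of these three steps is shown below. It assumes the copied src/io.py is renamed to model_io.py (a plain io.py clashes with Python's built-in io module) and that the model takes a flattened 28x28 MNIST input; both are assumptions, not repository facts.

    # Minimal inference sketch for the three deployment steps above.
    import torch

    from model import SimpleNN               # the copied src/mnist/model.py
    from model_io import load_pruned_model   # copied src/io.py, renamed to avoid
                                             # shadowing Python's built-in io module

    pruned_model_file = 'best_model.pth'     # the copied pruned-model file

    model = SimpleNN()
    load_pruned_model(model, pruned_model_file)
    model.eval()                             # switch to inference mode

    with torch.no_grad():
        x = torch.randn(1, 28 * 28)          # dummy input; assumed MNIST shape
        logits = model(x)
    print(logits.argmax(dim=1))              # predicted class index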
