Language-Aware Visual Explanations (LAVE): a combined textual and visual explanation framework for image classification.
Purushothaman Natarajan, Athira Nambiar
This repository presents LAVE, an explainer tool for ImageNet-based classifiers that combines SHAP (SHapley Additive exPlanations) and SAM (Segment Anything Model) for visual masks with VLM Llava for textual explanations. It lets users generate both visual and textual explanations without requiring any additional training.
LAVE offers an explainer for models based on the ImageNet dataset, allowing you to interpret model predictions through:
- SHAP-based explanations for visual and textual insights.
- Visual masks highlighting objects of interest in images, generated by SAM using SHAP coordinates.
- Textual explanations generated by VLM Llava for each visual mask.
- SHAP-Based Explanations: Provides visual and textual explanations based on SHAP values, requiring no additional training on ImageNet.
- SAM Integration: Uses SAM to generate visual masks from SHAP coordinates highlighting important regions.
- Textual Explanations: VLM Llava generates language-aware, coherent textual descriptions for the visual explanations.
- Modular Use: Supports both custom-trained models and pre-trained models loaded directly from TorchHub. A minimal end-to-end pipeline sketch is shown below.
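The repository's explainer scripts implement this pipeline end to end. For orientation only, the following is a minimal, illustrative sketch of the SHAP -> SAM -> Llava flow; the classifier (torchvision ResNet-50), the SAM checkpoint path, the llava-hf/llava-1.5-7b-hf checkpoint, and the prompt are placeholder assumptions, not the repository's exact choices.

import numpy as np
import shap
import torch
from PIL import Image
from segment_anything import sam_model_registry, SamPredictor
from torchvision import models, transforms
from transformers import AutoProcessor, LlavaForConditionalGeneration

# 1) Classify the image (placeholder classifier: torchvision ResNet-50).
model = models.resnet50(weights="IMAGENET1K_V2").eval()
image = Image.open("./images/test_image.jpg").convert("RGB").resize((224, 224))
x = transforms.ToTensor()(image).unsqueeze(0)                   # (1, 3, 224, 224)
pred_class = model(x).argmax(dim=1).item()

# 2) SHAP attributions for the top predicted class; the peak pixel becomes a point prompt.
explainer = shap.GradientExplainer(model, torch.zeros_like(x))  # trivial background
sv, _ = explainer.shap_values(x, ranked_outputs=1)              # top-1 class only
sv = sv[0] if isinstance(sv, list) else sv[..., 0]              # handle both SHAP return formats
heat = np.abs(sv[0]).sum(axis=0)                                # (224, 224) attribution map
row, col = np.unravel_index(heat.argmax(), heat.shape)

# 3) SAM mask seeded by the SHAP peak coordinate (checkpoint path is a placeholder).
sam = sam_model_registry["vit_b"](checkpoint="./sam_vit_b_01ec64.pth")
predictor = SamPredictor(sam)
predictor.set_image(np.array(image))
masks, _, _ = predictor.predict(
    point_coords=np.array([[col, row]]),                        # SAM expects (x, y)
    point_labels=np.array([1]),
    multimask_output=False,
)

# 4) Llava describes the masked region (checkpoint and prompt are placeholders).
masked = (np.array(image) * masks[0][..., None]).astype(np.uint8)
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
llava = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
prompt = "USER: <image>\nDescribe the highlighted object and explain why it supports class {}.\nASSISTANT:".format(pred_class)
inputs = processor(text=prompt, images=Image.fromarray(masked), return_tensors="pt")
out = llava.generate(**inputs, max_new_tokens=80)
print(processor.decode(out[0], skip_special_tokens=True))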
Clone the repository and set up the environment (the dependencies are also listed in requirements.txt):
git clone https://github.com/Purushothaman-natarajan/LAVE-Explainer.git
cd LAVE-Explainer
conda env create -f environment.yaml
conda activate LAVE-Explainer
You can use either of the following scripts, depending on your model type (a programmatic invocation example follows the list):
- Pre-trained Models: Use pre-trained_model_explainer.py to explain pre-trained models available on TorchHub.
python pre-trained_model_explainer.py --model_name <pretrained_model_name> --img_path <path_to_image>
- Custom-Trained Models: Use custom_model_explainer.py to explain models trained with the code in this repository.
python custom_model_explainer.py --model_path <path_to_custom_model> --img_path <path_to_image>
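Like the training utilities described later, the explainer scripts can also be launched programmatically via Python's subprocess module; the model name and file paths below are placeholders.

import subprocess

# Explain a TorchHub pre-trained model ("resnet50" and the image path are placeholders).
subprocess.run([
    "python", "pre-trained_model_explainer.py",
    "--model_name", "resnet50",
    "--img_path", "./images/test_image.jpg",
])

# Explain a custom-trained model (the checkpoint path is a placeholder).
subprocess.run([
    "python", "custom_model_explainer.py",
    "--model_path", "./models/vgg16_model.pth",
    "--img_path", "./images/test_image.jpg",
])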
This section covers how to use the provided scripts to train an image classification model on your own dataset using transfer learning.
The data_loader.py script loads and preprocesses the data, splits it into train, validation, and test sets, and optionally augments it.
Command Line Usage:
python data_loader.py --path <path_to_data> --target_folder <path_to_target_folder> --dim <dimension> --batch_size <batch_size> --num_workers <num_workers> [--augment_data]
Arguments:
- --path: Path to the data.
- --target_folder: Path to the target folder where processed data will be saved.
- --dim: Dimension for resizing the images.
- --batch_size: Batch size for data loading.
- --num_workers: Number of workers for data loading.
- --augment_data (optional): Flag to enable data augmentation.
Example:
python data_loader.py --path "./data" --target_folder "./processed_data" --dim 224 --batch_size 32 --num_workers 4 --augment_data
Dataset Structure:
Dataset (Raw)
├── class_name_1
│   └── *.jpg
├── class_name_2
│   └── *.jpg
├── class_name_3
│   └── *.jpg
└── class_name_4
    └── *.jpg
The train.py script trains models using transfer learning and saves them to the specified model directory.
Command Line Usage:
python train.py --base_model_names <model_names> --shape <shape> --data_path <data_path> --log_dir <log_dir> --model_dir <model_dir> --epochs <epochs> --optimizer <optimizer> --learning_rate <learning_rate> --batch_size <batch_size>
Arguments:
- --base_model_names: Comma-separated list of base model names (e.g., 'vgg16,alexnet').
- --shape: Image shape (size).
- --data_path: Path to the data.
- --log_dir: Path to the log directory.
- --model_dir: Path to the model directory.
- --epochs: Number of training epochs.
- --optimizer: Optimizer type ('adam' or 'sgd').
- --learning_rate: Learning rate for the optimizer.
- --batch_size: Batch size for training.
Example:
python train.py --base_model_names "vgg16,alexnet" --shape 224 --data_path "./data" --log_dir "./logs" --model_dir "./models" --epochs 100 --optimizer "adam" --learning_rate 0.001 --batch_size 32
The test.py script evaluates the trained models and stores the test logs.
Command Line Usage:
python test.py --data_path <data_path> --base_model_name <base_model_name> --model_path <model_path> --models_folder_path <models_folder_path> --log_dir <log_dir>
Arguments:
- --data_path: Path to the test data.
- --base_model_name: Name of the base model.
- --model_path (optional): Path to the specific model file.
- --models_folder_path (optional): Path to the folder containing models.
- --log_dir: Path to the log directory.
Example:
python test.py --data_path "./test_data" --base_model_name "vgg16" --model_path "./models/vgg16_model.pth" --log_dir "./logs"
The predict.py script makes predictions on new images using a trained model.
Command Line Usage:
python predict.py --model_path <model_path> --img_path <img_path> --train_dir <train_dir> --base_model_name <base_model_name>
Arguments:
- --model_path: Path to the model file.
- --img_path: Path to the image file.
- --train_dir: Path to the training dataset.
- --base_model_name: Name of the base model.
Example:
python predict.py --model_path "./models/vgg16_model.pth" --img_path "./images/test_image.jpg" --train_dir "./data/train" --base_model_name "vgg16"
You can also run these scripts programmatically using Python's subprocess module, which is especially useful for Jupyter Notebook users. Here is an example for each script:
import subprocess
# Run data_loader.py
subprocess.run([
"python", "data_loader.py",
"--path", "./data",
"--target_folder", "./processed_data",
"--dim", "224",
"--batch_size", "32",
"--num_workers", "4",
"--augment_data"
])
# Run train.py
subprocess.run([
"python", "train.py",
"--base_model_names", "vgg16,alexnet",
"--shape", "224",
"--data_path", "./data",
"--log_dir", "./logs",
"--model_dir", "./models",
"--epochs", "100",
"--optimizer", "adam",
"--learning_rate", "0.001",
"--batch_size", "32"
])
# Run test.py
subprocess.run([
"python", "test.py",
"--data_path", "./test_data",
"--base_model_name", "vgg16",
"--model_path", "./models/vgg16_model.pth",
"--log_dir", "./logs"
])
# Run predict.py
subprocess.run([
"python", "predict.py",
"--model_path", "./models/vgg16_model.pth",
"--img_path", "./images/test_image.jpg",
"--train_dir", "./data/train",
"--base_model_name", "vgg16"
])
- Python
- SHAP
- SAM (Segment Anything Model)
- VLM Llava
- PyTorch (for pre-trained models and custom model training)
- Other dependencies listed in requirements.txt
For any inquiries or feedback, please contact [[email protected] / [email protected]].
We would like to acknowledge the developers of SHAP, SAM, and VLM Llava for their invaluable open-source models, as well as our funder, DRDO, India, for supporting this work in the field of explainable AI.
If you use this repository, please cite the following paper:
@article{natarajan2024vale,
title={VALE: A Multimodal Visual and Language Explanation Framework for Image Classifiers using eXplainable AI and Language Models},
author={Natarajan, Purushothaman and Nambiar, Athira},
journal={arXiv preprint arXiv:2408.12808},
year={2024}
}