GViG: Generative Visual Grounding using Prompt-based Language Modeling for Visual Question Answering
This repository contains the code and methodology used for our study presented at the WSDM 2023 Toloka VQA Challenge. We reformulated the GVQA task as a language modeling problem and employed prompt tuning with predictions from state-of-the-art VQA models. The proposed approach achieved third place on the challenge benchmark.
- pip 21.2.4
- python 3.10.12
- pytorch 1.12.1+cu113
- torchvision 0.13.1+cu113
- opencv-python 4.6.0.66
- pytorch-lightning 1.8.4
- timm 0.6.11
- pycocotools 2.0.7
- datasets 2.7.1
git clone https://github.com/IKMLab/GViG.git
make setup_env
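If `make setup_env` does not match your environment, the pinned versions listed above can also be installed by hand. The sketch below is only an illustrative alternative, not the official setup: it assumes a fresh Python 3.10 virtual environment and the CUDA 11.3 wheels named above, and it does not cover any additional dependencies that `make setup_env` may install (for example, the OFA requirements).

```bash
# Manual alternative to `make setup_env` (a sketch; assumes Python 3.10 + CUDA 11.3).
python3.10 -m venv .venv
source .venv/bin/activate
pip install pip==21.2.4
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 \
    --extra-index-url https://download.pytorch.org/whl/cu113
pip install opencv-python==4.6.0.66 pytorch-lightning==1.8.4 \
    timm==0.6.11 pycocotools==2.0.7 datasets==2.7.1
```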
make download_wsdm_vqa_dataset
This command will download the WSDM 2023 Toloka VQA dataset and place it in the `datasets` directory. Two folders will be created: `datasets/images` and `datasets/offical`. The `datasets/images` folder contains the images for the WSDM VQA dataset, and the `datasets/offical` folder contains the official splits. The `datasets/offical` folder contains the following files:
datasets/offical
├── train.csv
├── test_public.csv
├── test_private.csv
└── train_sample.csv
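As a quick sanity check that the download finished, you can list the two folders and peek at the training split. This is only a convenience sketch based on the layout shown above.

```bash
# Verify the layout created by `make download_wsdm_vqa_dataset`.
ls datasets/images | head              # downloaded images
ls datasets/offical                    # official splits: train / test_public / test_private / train_sample
head -n 3 datasets/offical/train.csv   # peek at the training split
```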
For more details, please refer to `datasets/README.md`.
Please go to `OFA/checkpoints.md` to download the pretrained models. The pretrained models should be placed in the `checkpoints` directory. For more details, please refer to `pretrained_weights/README.md`.
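The exact checkpoint file names depend on which OFA pretrained weights you choose in `OFA/checkpoints.md`, so the sketch below only illustrates the idea; `<downloaded_checkpoint>.pt` is a placeholder, not a real file name.

```bash
# Put the downloaded OFA weights where the training/evaluation scripts expect them.
mkdir -p checkpoints
mv ~/Downloads/<downloaded_checkpoint>.pt checkpoints/
```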
Below are the commands to train and run inference with the model on the WSDM 2023 Toloka VQA dataset. All commands are stored in the `scripts` directory, and all output files will be stored in the `results` directory.
Before training the prompt tuning model, we need to generate the prompt tuning data. The prompt tuning data is generated by the following steps:

- VQA model zero-shot inference on the WSDM VQA dataset: We use the pretrained VQA model to generate zero-shot inference results on the WSDM VQA dataset, which are stored in the `datasets/results/vqa/<model_arch>/<exp_tag>` directory. The zero-shot inference results are generated by the following command:

  make eval_vqa

  This command will execute the file `scripts/evaluate_multiple_vqa_exp.sh`, so you can modify the file to change the VQA model's `experiment tag`, `arch`, `checkpoint path`, `train prompt`, `val prompt`, `test prompt`, `test files`, `seed`, `beam`, `batch_size`, and more (see the sketch after this list).

- Generate the prompt tuning data: We use the zero-shot inference results to generate the prompt tuning data, which is stored in the `datasets/<raw_file_dir>` directory. The prompt tuning data is generated by the following command:

  make gen_prompt_data

  This command will execute the file `scripts/generate_prompt_data.sh`, so you can modify the file to change the `raw_file`, `answer files mapping`, `prompt name`, and `output file path`.
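The sketch below ties the two steps together. The variable names in the comments are placeholders for illustration only; the real names and values live inside `scripts/evaluate_multiple_vqa_exp.sh` and `scripts/generate_prompt_data.sh`, so check those files before editing.

```bash
# Step 1: zero-shot VQA inference.
# Edit scripts/evaluate_multiple_vqa_exp.sh first (experiment tag, arch,
# checkpoint path, train/val/test prompts, test files, seed, beam, batch_size, ...).
make eval_vqa
# Predictions are written to datasets/results/vqa/<model_arch>/<exp_tag>.

# Step 2: build the prompt tuning data from the step-1 predictions.
# Edit scripts/generate_prompt_data.sh first (raw_file, answer files mapping,
# prompt name, output file path).
make gen_prompt_data
# The prompt tuning data is written to datasets/<raw_file_dir>.
```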
After generating the prompt tuning data, we can train the VG model. The VG model is trained by the following command:
make train
This command will execute the file `scripts/train_single_vg_exp.sh`, so you can modify the file to change all of the variables. After training, the model will be saved in the `checkpoints/<model_arch>/<exp_tag>/train-P<prompt_name>/val-P<prompt_name>` directory. For more details, please refer to `checkpoints/README.md`.
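After `make train` finishes, it can be worth confirming that a checkpoint was actually written. The path below just mirrors the pattern described above; `<model_arch>`, `<exp_tag>`, and the prompt names depend on how you configured `scripts/train_single_vg_exp.sh`.

```bash
# Train the VG model (variables are set inside scripts/train_single_vg_exp.sh).
make train

# Confirm a checkpoint was saved; substitute the placeholders with your settings.
ls checkpoints/<model_arch>/<exp_tag>/train-P<prompt_name>/val-P<prompt_name>/
```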
After training the VG model, we can run inference on the test set. The inference results will be stored in the `results/<model_arch>/<exp_tag>/train-P<prompt_name>/val-P<prompt_name>` directory. The inference is executed by the following command:
make eval_vg
This command will execute the file `scripts/evaluate_single_vg_exp.sh`, so you can modify the file to change the `model arch`, `exp tag`, `checkpoint path`, `train prompt`, `val prompt`, `test prompt`, `test files`, `seed`, `beam`, `batch_size`, and more. For more details, please refer to `results/README.md`.
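The same kind of check applies to inference: after `make eval_vg`, the predictions should appear under the results path above. The placeholders again depend on how you configured `scripts/evaluate_single_vg_exp.sh`, and the exact output file names may differ.

```bash
# Run inference on the configured test files (e.g. the official test_public.csv /
# test_private.csv splits) and inspect the output directory.
make eval_vg
ls results/<model_arch>/<exp_tag>/train-P<prompt_name>/val-P<prompt_name>/
```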
That's all! Have fun! 🥳.