[Paper (ECVA)] [Project page] [Supplementary material] [Video]
This is the official PyTorch implementation of "Layout-Corrector: Alleviating Layout Sticking Phenomenon in Discrete Diffusion Model (ECCV2024)".
Shoma Iwai¹, Atsuki Osanai², Shunsuke Kitada², and Shinichiro Omachi¹
¹ Tohoku University, ² LY Corporation
Setup details
The environment is built on Python 3.10.
- Run the container for development:
make build
make run
--- in container ---
cd /app
pip3 install -e .
Please download the starter kit (GDrive, 952 MB) and unzip it:
unzip ./layout_corrector_starter_kit.zip
For more details about the starter kit, please refer to the README.md file included in the zip.
The PubLayNet, Rico, and Crello datasets are included in our starter kit. For custom datasets, see docs/custom_dataset.md.
Training details
bash bin/train.sh <DATASET> <EXPERIMENT_NAME> <ADDITIONAL_ARGS>
For example,
bash bin/train.sh rico25 layoutdm seed=0,1,2
bash bin/corrector_train.sh <DATASET> <EXPERIMENT_NAME> <DIFFUSION_JOB_DIR> <ADDITIONAL_ARGS>
- <EXPERIMENT_NAME>: Filename of the .yaml file in src/trainer/trainer/config/experiment.
- <DIFFUSION_JOB_DIR>: Job directory of the pre-trained diffusion model (e.g., LayoutDM).
- <ADDITIONAL_ARGS>: Optional (e.g., seed=0,1,2).
For example,
bash bin/corrector_train.sh rico25 layout_corrector download/pretrained_weights/rico25/layoutdm seed=0,1,2 training.epochs=20
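When several corrector trainings are launched from a script, the command above can be assembled programmatically. The following is a minimal sketch: the helper name `corrector_train_cmd` is hypothetical and not part of this repository, and it only builds the argument list without running anything.

```python
import shlex


def corrector_train_cmd(dataset, experiment_name, diffusion_job_dir, *additional_args):
    """Build the argv list for bin/corrector_train.sh (hypothetical helper)."""
    # Positional order follows the usage line:
    # <DATASET> <EXPERIMENT_NAME> <DIFFUSION_JOB_DIR> <ADDITIONAL_ARGS>
    return [
        "bash",
        "bin/corrector_train.sh",
        dataset,
        experiment_name,
        diffusion_job_dir,
        *additional_args,
    ]


cmd = corrector_train_cmd(
    "rico25",
    "layout_corrector",
    "download/pretrained_weights/rico25/layoutdm",
    "seed=0,1,2",
    "training.epochs=20",
)
# Print the shell-quoted command; pass `cmd` to subprocess.run(cmd, check=True) to launch it.
print(shlex.join(cmd))
```

Keeping the command as a list avoids quoting problems if a job directory ever contains spaces.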
Testing details
You can try a quick demo using notebooks/demo.ipynb.
To generate layouts and calculate metrics, run the following commands:
python bin/test_eval.py <DIFFUSION_JOB_DIR> <DATASET> -t <TIMESTEPS> [-d <DEVICE_ID>]
For example,
python bin/test_eval.py tmp/jobs/rico25/layoutdm_jobdir rico25 -t 100
python ./bin/corrector_test_eval.py <JOB_DIR> <DATASET> -t <NUM_TIMESTEPS> [-d <DEVICE_ID>]
- <JOB_DIR>: Path to the pre-trained corrector job_dir (not that of a generator).
For example,
python ./bin/corrector_test_eval.py ./download/pretrained_weights/rico25/layout_corrector rico25 -t 100
- To specify threshold-masking, add --corrector_mask_mode thresh --corrector_mask_threshold <THRESHOLD>:
python ./bin/corrector_test_eval.py ./download/pretrained_weights/rico25/layout_corrector rico25 -t 100 --corrector_mask_mode thresh --corrector_mask_threshold 0.7
- To specify Top-K-masking, add --corrector_mask_mode topk:
python ./bin/corrector_test_eval.py ./download/pretrained_weights/rico25/layout_corrector rico25 -t 100 --corrector_mask_mode topk
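The difference between the two masking modes can be illustrated with a small sketch. It assumes the corrector outputs one score per token and that a lower score means the token is more likely to be wrong. That direction, the function name `select_tokens_to_mask`, and the parameter `k` are assumptions made for illustration, not the repository's API.

```python
def select_tokens_to_mask(scores, mode="thresh", threshold=0.7, k=1):
    """Return sorted indices of tokens to reset to MASK (illustrative only).

    Assumption: a lower score means the token is more likely to be wrong.
    """
    if mode == "thresh":
        # Threshold masking: reset every token whose score is below the threshold.
        return [i for i, s in enumerate(scores) if s < threshold]
    if mode == "topk":
        # Top-K masking: reset the k lowest-scoring tokens, whatever their values.
        ranked = sorted(range(len(scores)), key=lambda i: scores[i])
        return sorted(ranked[:k])
    raise ValueError(f"unknown mode: {mode}")


scores = [0.95, 0.40, 0.85, 0.65, 0.99]
print(select_tokens_to_mask(scores, mode="thresh", threshold=0.7))  # [1, 3]
print(select_tokens_to_mask(scores, mode="topk", k=1))  # [1]
```

Under these assumptions, threshold masking resets a variable number of tokens per layout, while Top-K masking always resets a fixed number.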
Preliminary Experiment details
Test Corrupted Token-Correction Capability of LayoutDM and its conjunction with Layout-Corrector.
python bin/test_token_correction.py <LAYOUTDM_JOB_DIR> [--start_timesteps <TIMESTEP1> <TIMESTEP2> ...] [--mask] [--num_replace <NUM_REPLACE_TOKENS>] [--save_dir <SAVE_DIR>]
- --start_timesteps: The timesteps from which generation starts (default: [10]).
- --mask: If given, the randomly selected tokens are replaced with MASK.
- --num_replace: The number of tokens to be replaced (default: 1).
- --save_dir: A directory where the result is saved (default: token_correction_results).
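The corruption step controlled by `--mask` and `--num_replace` can be pictured as follows. This is a rough sketch, not the script's code: the `MASK` id, the vocabulary size, and the behavior without `--mask` (replacement with a different random token) are assumptions.

```python
import random

MASK = -1  # hypothetical MASK token id, chosen only for this example


def corrupt_tokens(tokens, num_replace=1, use_mask=False, vocab_size=10, rng=None):
    """Replace `num_replace` randomly chosen tokens; return (corrupted, positions)."""
    if rng is None:
        rng = random.Random()
    positions = sorted(rng.sample(range(len(tokens)), num_replace))
    corrupted = list(tokens)
    for pos in positions:
        if use_mask:
            corrupted[pos] = MASK
        else:
            # Assumption: without masking, draw a token that differs from the original.
            candidates = [t for t in range(vocab_size) if t != tokens[pos]]
            corrupted[pos] = rng.choice(candidates)
    return corrupted, positions


tokens = [3, 1, 4, 1, 5]
corrupted, positions = corrupt_tokens(tokens, num_replace=2, use_mask=True, rng=random.Random(0))
print(corrupted, positions)
```

Returning the corrupted positions alongside the sequence makes it straightforward to score restoration on exactly those tokens afterwards.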
The result JSON includes two metrics.
- token_wise: The token-wise accuracy of restoring the corrupted tokens to the ground truth.
- full: The accuracy when a sample counts as correct only if all of its corrupted tokens are restored.
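The difference between the two metrics can be made concrete with a short sketch. The function and argument names are hypothetical, and the script's own aggregation may differ; the example only assumes that `token_wise` averages over corrupted tokens while `full` averages over samples.

```python
def restoration_accuracy(predictions, ground_truths, corrupted_positions):
    """Compute (token_wise, full) restoration accuracy over a batch of samples."""
    restored = total = fully_restored = 0
    for pred, gt, positions in zip(predictions, ground_truths, corrupted_positions):
        # One boolean per corrupted token: was it restored to the ground truth?
        hits = [pred[p] == gt[p] for p in positions]
        restored += sum(hits)
        total += len(hits)
        # A sample is fully restored only if every corrupted token is correct.
        fully_restored += all(hits)
    return restored / total, fully_restored / len(predictions)


predictions = [[1, 2, 3, 4], [5, 6, 7, 8]]
ground_truths = [[1, 2, 3, 4], [5, 9, 7, 8]]
corrupted_positions = [[0, 2], [1, 3]]
print(restoration_accuracy(predictions, ground_truths, corrupted_positions))  # (0.75, 0.5)
```

Here three of the four corrupted tokens are restored, but only one of the two samples has all of its corrupted tokens restored, so `full` is lower than `token_wise`.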
To compare the token correction capability for different schedules, as in Fig.2 (b) of our paper, please see layoutdm_token_correction.ipynb.
Test Corrupted Token Detection Accuracy of Layout-Corrector.
python bin/test_error_token_detection.py <LAYOUTDM_JOB_DIR> --corr_job_dir <CORRECTOR_JOB_DIR> --corr_timesteps <CORR_TIMESTEP1> <CORR_TIMESTEP2> ... [--num_replace <NUM_REPLACE_TOKENS>] [--save_dir <SAVE_DIR>]
- --corr_job_dir: Corrector job directory where ckpt is included. If not given, the evaluation runs only for LayoutDM.
- --corr_timesteps: The timesteps at which the corrector is applied (default: [10]).
- --num_replace: The number of tokens to be replaced (default: 1).
- --save_dir: A directory where the result is saved (default: error_token_detection_results).
The result JSON includes two metrics at each timestep.
- token_wise_acc: The token-wise accuracy of detecting the corrupted tokens.
- full_acc: The accuracy when a sample counts as correct only if all of its corrupted tokens are detected.
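For detection, the same two-level scoring can be sketched with sets of token positions. This is an illustration under stated assumptions: the function name is hypothetical, the inputs are per-sample sets of flagged and truly corrupted positions, and false positives are ignored for simplicity, which the actual script may handle differently.

```python
def detection_accuracy(detected, corrupted):
    """Compute (token_wise_acc, full_acc) from per-sample sets of token positions."""
    found = total = all_found = 0
    for det, cor in zip(detected, corrupted):
        # Count corrupted positions that the detector flagged.
        found += len(det & cor)
        total += len(cor)
        # The sample counts as fully detected only if every corrupted position is flagged.
        all_found += cor <= det
    return found / total, all_found / len(corrupted)


detected = [{2}, {0, 3}, {1}]
corrupted = [{2}, {0, 5}, {4}]
token_wise_acc, full_acc = detection_accuracy(detected, corrupted)
print(token_wise_acc, full_acc)
```

In this toy batch, two of the four corrupted tokens are detected, and one of the three samples has all of its corrupted tokens detected.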
To plot the accuracy of corrupted token detection, as in Fig.4 of our paper, please see layout_corrector_error_token_detection.ipynb.
Save Layout Generation Process at all timesteps as a pickle file and images.
python tools/visualize_generation_process.py <LAYOUTDM_JOB_DIR> [--corr_job_dir <CORRECTOR_JOB_DIR>] [--num_samples <NUM_SAMPLES>] [--corr_timesteps <CORR_TIMESTEP1> <CORR_TIMESTEP2> ...] [--save_dir <SAVE_DIR>] [--save_images]
- --corr_job_dir: Corrector job directory where ckpt is included. If not given, the results are generated by LayoutDM alone.
- --num_samples: The total number of generated samples (default: 100).
- --corr_timesteps: The timesteps at which the corrector is applied (default: [10, 20, 30]).
- --corrector_mask_mode: The masking strategy for the corrector. topk or thresh are allowed (default: thresh).
- --corrector_threshold: The threshold value to determine tokens to reset to MASK when corrector_mask_mode == "thresh" (default: 0.7).
- --save_dir: A directory where the result is saved (default: generation_process_results).
- --save_images: If given, the results are also saved as images.
Plot the token and element sticking rate of LayoutDM.
Note that you need to run tools/visualize_generation_process.py before using this tool.
The output is saved in the directory where the pickle file is located.
python tools/analyze_token_sticking.py <PICKLE_PATH>
To compare the sticking rate for different schedules, as in Fig.2 (a) of our paper, please see layoutdm_token_sticking.ipynb.
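As a rough illustration of what a token sticking rate could measure, the sketch below counts, at each generation step, the tokens that are already unmasked and never change afterwards. This definition, the `MASK` id, and the in-memory `trajectory` format are all assumptions for illustration; they do not describe the pickle schema or the exact metric used by tools/analyze_token_sticking.py.

```python
MASK = -1  # hypothetical MASK token id, chosen only for this example


def token_sticking_rate(trajectory):
    """Fraction of tokens per step that are unmasked and never change afterwards.

    `trajectory` is a list of token lists ordered from the first generation
    step to the final layout (a hypothetical in-memory format).
    """
    rates = []
    for step, tokens in enumerate(trajectory):
        stuck = 0
        for i, tok in enumerate(tokens):
            if tok == MASK:
                continue
            # A token counts as stuck if every later step keeps the same value.
            if all(later[i] == tok for later in trajectory[step + 1:]):
                stuck += 1
        rates.append(stuck / len(tokens))
    return rates


trajectory = [
    [MASK, MASK, MASK, MASK],
    [7, MASK, MASK, 2],
    [7, 5, MASK, 3],
    [7, 5, 9, 3],
]
print(token_sticking_rate(trajectory))  # [0.0, 0.25, 0.75, 1.0]
```

In the second step, the token `7` is never revised while the token `2` later becomes `3`, so only one of the four tokens counts as stuck at that step.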
This codebase is largely based on LayoutDM (Inoue+, CVPR2023). We sincerely appreciate their effort and contribution to the research community.
If you find this code useful for your research, please consider citing our paper:
@inproceedings{iwai2024layout,
title={Layout-Corrector: Alleviating Layout Sticking Phenomenon in Discrete Diffusion Model},
author={Shoma Iwai and Atsuki Osanai and Shunsuke Kitada and Shinichiro Omachi},
booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
year={2024},
}
- If you have any questions or encounter any issues, please feel free to open an issue!
- Pull requests are welcome! Please open an issue first to discuss what you would like to change.