Skip to content

Commit 2b03e07

Browse files
author
Kaiyu Yang
committed
first commit
0 parents  commit 2b03e07

27 files changed

+3727
-0
lines changed

.gitignore

+133
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
data/
2+
lightning_logs/
3+
sbatch_*.sh
4+
5+
# Byte-compiled / optimized / DLL files
6+
__pycache__/
7+
*.py[cod]
8+
*$py.class
9+
10+
# C extensions
11+
*.so
12+
13+
# Distribution / packaging
14+
.Python
15+
build/
16+
develop-eggs/
17+
dist/
18+
downloads/
19+
eggs/
20+
.eggs/
21+
lib/
22+
lib64/
23+
parts/
24+
sdist/
25+
var/
26+
wheels/
27+
pip-wheel-metadata/
28+
share/python-wheels/
29+
*.egg-info/
30+
.installed.cfg
31+
*.egg
32+
MANIFEST
33+
34+
# PyInstaller
35+
# Usually these files are written by a python script from a template
36+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
37+
*.manifest
38+
*.spec
39+
40+
# Installer logs
41+
pip-log.txt
42+
pip-delete-this-directory.txt
43+
44+
# Unit test / coverage reports
45+
htmlcov/
46+
.tox/
47+
.nox/
48+
.coverage
49+
.coverage.*
50+
.cache
51+
nosetests.xml
52+
coverage.xml
53+
*.cover
54+
*.py,cover
55+
.hypothesis/
56+
.pytest_cache/
57+
58+
# Translations
59+
*.mo
60+
*.pot
61+
62+
# Django stuff:
63+
*.log
64+
local_settings.py
65+
db.sqlite3
66+
db.sqlite3-journal
67+
68+
# Flask stuff:
69+
instance/
70+
.webassets-cache
71+
72+
# Scrapy stuff:
73+
.scrapy
74+
75+
# Sphinx documentation
76+
docs/_build/
77+
78+
# PyBuilder
79+
target/
80+
81+
# Jupyter Notebook
82+
.ipynb_checkpoints
83+
84+
# IPython
85+
profile_default/
86+
ipython_config.py
87+
88+
# pyenv
89+
.python-version
90+
91+
# pipenv
92+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
94+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
95+
# install all needed dependencies.
96+
#Pipfile.lock
97+
98+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
99+
__pypackages__/
100+
101+
# Celery stuff
102+
celerybeat-schedule
103+
celerybeat.pid
104+
105+
# SageMath parsed files
106+
*.sage.py
107+
108+
# Environments
109+
.env
110+
.venv
111+
env/
112+
venv/
113+
ENV/
114+
env.bak/
115+
venv.bak/
116+
117+
# Spyder project settings
118+
.spyderproject
119+
.spyproject
120+
121+
# Rope project settings
122+
.ropeproject
123+
124+
# mkdocs documentation
125+
/site
126+
127+
# mypy
128+
.mypy_cache/
129+
.dmypy.json
130+
dmypy.json
131+
132+
# Pyre type checker
133+
.pyre/

LICENSE

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2022 Princeton Natural Language Processing
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.

README.md

+167
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
# Generating Natural Language Proofs with Verifier-Guided Search
2+
3+
![Task](images/nlproofs.jpg)
4+
5+
Code for the paper:
6+
7+
[Generating Natural Language Proofs with Verifier-Guided Search](https://arxiv.org/abs/2205.12443)
8+
[Kaiyu Yang](https://www.cs.princeton.edu/~kaiyuy/), [Jia Deng](https://www.cs.princeton.edu/~jiadeng/), and [Danqi Chen](https://www.cs.princeton.edu/~danqic/)
9+
10+
11+
## Quick Links
12+
13+
- [Requirements](#requirements)
14+
- [Data Preprocessing](#data-preprocessing)
15+
- [EntailmentBank Experiments](#entailmentbank-experiments)
16+
- [RuleTaker Experiments](#ruletaker-experiments)
17+
- [Citation](#citation)
18+
- [Credits](#credits)
19+
20+
21+
## Requirements
22+
23+
1. Download and install [Miniconda Python 3](https://docs.conda.io/en/latest/miniconda.html) (Anaconda should also work).
24+
1. Clone this repo and `cd` its root.
25+
1. Install Python dependencies: `conda env create -f nlproofs.yaml`. You may need to edit [nlproofs.yaml](./nlproofs.yaml) according to your system, e.g., use a different CUDA version. If you have trouble running the installation command, you may also manually install the packages in [nlproofs.yaml](./nlproofs.yaml) in whatever way that works for you.
26+
1. Activate the conda environment: `conda activate nlproofs`, and prepend the root of this repo to the `PYTHONPATH` environment variable.
27+
28+
## Data Preprocessing
29+
30+
1. Download the v3_May6_2022 version of [EntailmentBank](https://allenai.org/data/entailmentbank) (MD5: 9cb91896325157cee1f35616be0be179) and unzip it as `./data/entailment_trees_emnlp2021_data_v3/`.
31+
1. Download the OWA version of [RuleTaker](https://allenai.org/data/proofwriter) (MD5: bf490364bca241bb5ff9f0ab0c78b71a) and unzip it as `./data/proofwriter-dataset-V2020.12.3/`.
32+
1. Run `python check_data.py` to check.
33+
1. Run `python preprocess_ruletaker.py` to preprocess the RuleTaker dataset.
34+
35+
36+
## EntailmentBank Experiments
37+
38+
We use [Lightning CLI](https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_cli.html) to create scripts for training, validation, and testing: [prover/main.py](prover/main.py) and [verifier/main.py](verifier/main.py) for the prover and the verifier, respectively. They take arguments from the command line as well as YAML configuration files. Please run `python main.py --help` or refer to the documentation of [Lightning CLI](https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_cli.html) for details.
39+
40+
We provide YAML files for our hyperparameters and experimental settings in [./prover/](./prover/) and [./verifier/](./verifier/). We run all experiments on a single NVIDIA A6000 GPU with 48GB memory. For running them on GPUs with smaller memory, you may have to change `batch_size` and `accumulate_grad_batches`.
41+
42+
43+
### Training
44+
45+
#### Prover
46+
47+
First, `cd` into [./prover/](./prover). Then run `python main.py fit --help` to see how to use the training script. Below are example commands used in our experiments:
48+
```bash
49+
python main.py fit --config cli_task1_single_shot_t5-large.yaml # Train a single-shot prover on Task 1 of EntailmentBank.
50+
python main.py fit --config cli_task1_stepwise_t5-large.yaml # Train a stepwise prover on Task 1 of EntailmentBank.
51+
python main.py fit --config cli_task2_single_shot_t5-large.yaml # Train a single-shot prover on Task 2 of EntailmentBank.
52+
python main.py fit --config cli_task2_stepwise_t5-large.yaml # Train a stepwise prover on Task 2 of EntailmentBank.
53+
```
54+
55+
The training script saves hyperparameters, model checkpoints, and other information to `./prover/lightning_logs/EXP_ID/`, where `EXP_ID` is an arbitrary experiment ID that will be printed by the training script.
56+
57+
#### Verifier
58+
59+
First, `cd` into [./verifier/](./verifier). Then run `python main.py fit --help` to see how to use the training script. Below are example commands used in our experiments:
60+
```bash
61+
python main.py fit --config cli_entailmentbank_task1.yaml # Train a verifier on Task 1 of EntailmentBank.
62+
python main.py fit --config cli_entailmentbank_task2.yaml # Train a verifier on Task 2 of EntailmentBank.
63+
```
64+
65+
The training script saves hyperparameters, model checkpoints, and other information to `./verifier/lightning_logs/EXP_ID/`.
66+
67+
### Validation and Testing
68+
69+
Once training completes, we use the model checkpoint to predict on the validation and testing data. `cd` into [./prover/](./prover) and run `python main.py validate --help` and `python main.py test --help` to see how to use the script for validation and testing. Assume we have a prover checkpoint `PATH_TO_PROVER_CKPT` and a verifier checkpoint `PATH_TO_VERIFIER_CKPT`, below are example commands:
70+
```bash
71+
python main validate --config cli_task2_stepwise_t5-large.yaml --ckpt_path PATH_TO_PROVER_CKPT # Validate the stepwise prover without verifier-guided search on Task 2 of EntailmentBank.
72+
python main validate --config cli_task2_stepwise_t5-large.yaml --ckpt_path PATH_TO_PROVER_CKPT --model.verifier_weight 0.5 --model.verifier_ckpt PATH_TO_VERIFIER_CKPT --model.proof_search true # Validate NLProofS (stepwise prover + verifier-guided search).
73+
python main validate --config cli_task2_stepwise_t5-large.yaml --ckpt_path PATH_TO_PROVER_CKPT --model.verifier_weight 1.0 --model.verifier_ckpt PATH_TO_VERIFIER_CKPT --model.proof_search true # Validate NLProofS w/o prover score.
74+
python main test --config cli_task2_stepwise_t5-large.yaml --ckpt_path PATH_TO_PROVER_CKPT --model.verifier_weight 0.5 --model.verifier_ckpt PATH_TO_VERIFIER_CKPT --model.proof_search true # Test NLProofS (stepwise prover + verifier-guided search).
75+
python main.py test --config cli_task1_single_shot_t5-large.yaml --ckpt_path PATH_TO_PROVER_CKPT # Test the single-shot prover on Task 1 of EntailmentBank.
76+
python main.py test --confing cli_task2_single_shot_t5-large.yaml --ckpt_path PATH_TO_PROVER_CKPT --data.path_test ../data/entailment_trees_emnlp2021_data_v3/dataset/task_3/test.jsonl # Test the single-shot prover (trained on Task 2) on Task 3 of EntailmentBank.
77+
```
78+
79+
Validation and testing results are saved as `./prover/lightning_logs/EXP_ID/results_val.tsv` and `./prover/lightning_logs/EXP_ID/results_test.tsv`. They are the input to the [EntailmentBank's official evaluation code](https://github.com/allenai/entailment_bank/tree/71385b6d7cc42ac394006bc2fe84d5bd1117f9ac) for calculating the evaluation metrics.
80+
81+
### Test Results and Model Checkpoints
82+
83+
#### Task 1
84+
85+
| Model | Leaves-F1 | Leaves-AllCorrect | Steps-F1 | Steps-AllCorrect | Intermediates-F1 | Intermediates-AllCorrect | Overall-AllCorrect | Model checkpoints | Validation predictions | Test predictions |
86+
| ------------- | -------- | ------- | --------------- | ------------- | ---------------- | ---------------- | ---------------- | ---------------- | ---------------- | ---------------- |
87+
| NLProofS | 97.6 | 90.0 | 54.8 | 41.8 | 72.0 | 39.7 | 38.2 | [prover](https://drive.google.com/file/d/16Mgor1gT_bJx3tCfT0DMyais79RM3Ka-/view?usp=sharing), [verifier](https://drive.google.com/file/d/1qR8JLwMUQWPHn_m9QImLP7RK1-PbjjuV/view?usp=sharing) | [results_val.tsv](https://drive.google.com/file/d/1pHhCt3JEbfRv7krixcyBcXa-1S-A3QYj/view?usp=sharing) | [results_test.tsv](https://drive.google.com/file/d/11DZRqzefU55AOobsvm8dpTLVHnMc-hHN/view?usp=sharing) |
88+
| Stepwise prover | 98.8 | 98.5 | 54.8 | 41.5 | 71.9 | 38.5 | 36.8 | The `prover` above | [results_val.tsv](https://drive.google.com/file/d/1PTbs_pO5Fds-RtpHtnK6ddXFgyu8trMl/view?usp=sharing) | [results_test.tsv](https://drive.google.com/file/d/1v6qN9xjX9RRDquXtgX0rwONU7XJKubE5/view?usp=sharing) |
89+
| Single-shot prover | 98.2 | 82.7 | 51.8 | 40.9 | 66.7 | 36.5 | 34.7 | [prover](https://drive.google.com/file/d/1l4ULsNqdNMco-tyOKiLxvxhaSo6TSyzy/view?usp=sharing) | [results_val.tsv](https://drive.google.com/file/d/1xncmUFBFQTO1ksflhZhiHMC65yaRN5zC/view?usp=sharing) | [results_test.tsv](https://drive.google.com/file/d/1MVGheKhG7XpP9Ejy0WCuZHUnlmhISl_T/view?usp=sharing) |
90+
91+
#### Task 2
92+
93+
| Model | Leaves-F1 | Leaves-AllCorrect | Steps-F1 | Steps-AllCorrect | Intermediates-F1 | Intermediates-AllCorrect | Overall-AllCorrect | Model checkpoints | Validation predictions | Test predictions |
94+
| ------------- | -------- | ------- | --------------- | ------------- | ---------------- | ---------------- | ---------------- | ---------------- | ---------------- | ---------------- |
95+
| NLProofS | 90.3 | 60.6 | 48.6 | 35.6 | 70.3 | 39.4 | 34.4 | [prover](https://drive.google.com/file/d/1T10tJ7S1RBWYc-_uQ11ALAhHiQc2V6aN/view?usp=sharing), [verifier](https://drive.google.com/file/d/1l9xLcKbJoFzmnvgQgnvuvncdP8u2FQbV/view?usp=sharing) | [results_val.tsv](https://drive.google.com/file/d/15emy7mokuqnoFjf5t5uLM98W3Oq94U_5/view?usp=sharing) | [results_test.tsv](https://drive.google.com/file/d/1_L4_25KLgxhoEMcOAKPMYAzyvrdYwXmI/view?usp=sharing) |
96+
| Stepwise prover | 90.3 | 57.1 | 48.6 | 35.6 | 70.1 | 38.5 | 33.8 | The `prover` above | [results_val.tsv](https://drive.google.com/file/d/1Mkj_uTg4COo16TR5F7lIaoUpqoYMKWQ8/view?usp=sharing) | [results_test.tsv](https://drive.google.com/file/d/1XyO_cDE8RYEEwiaBUH5efpWWjjHe0D2C/view?usp=sharing) |
97+
| Single-shot prover | 85.9 | 44.7 | 41.3 | 29.1 | 62.5 | 31.5 | 27.7 | [prover](https://drive.google.com/file/d/1hk5ekp4FQb1-lqEDbbZnYQhxZiafL8Vm/view?usp=sharing) | [results_val.tsv](https://drive.google.com/file/d/16zZopp0DHNMFXHVX_zmLAEhFNd9i-D3z/view?usp=sharing) | [results_test.tsv](https://drive.google.com/file/d/1ubmt1zCfg__I4f0odiKs4fcT78Sr1Rku/view?usp=sharing) |
98+
99+
#### Task 3
100+
101+
Results on Task 3 are produced by evaluating Task 2 models zero-shot on Task 3 data (by changing `--data.path_val` and `--data.path_test`).
102+
103+
| Model | Leaves-F1 | Leaves-AllCorrect | Steps-F1 | Steps-AllCorrect | Intermediates-F1 | Intermediates-AllCorrect | Overall-AllCorrect | Model checkpoints | Validation predictions | Test predictions |
104+
| ------------- | -------- | ------- | --------------- | ------------- | ---------------- | ---------------- | ---------------- | ---------------- | ---------------- | ---------------- |
105+
| NLProofS | 43.9 | 9.1 | 10.6 | 6.8 | 42.4 | 15.9 | 6.8 | Same as Task 2 | [results_val.tsv](https://drive.google.com/file/d/19Ohmhgb8HhUL6UMkcjx2nCvEC9n6hODJ/view?usp=sharing) | [results_test.tsv](https://drive.google.com/file/d/1kAfBQLe-cIlQknwWXLf4-v5KkflEUkdq/view?usp=sharing) |
106+
| Stepwise prover | 42.8 | 7.4 | 9.3 | 5.9 | 42.1 | 15.0 | 5.9 | Same as Task 2 | [results_val.tsv](https://drive.google.com/file/d/1yOOkixHfo3l6KuG3eJlCYEkGdxDfqLki/view?usp=sharing) | [results_test.tsv](https://drive.google.com/file/d/1TOzYy_8qSli1--ej7ku7deTntEwUq3IP/view?usp=sharing) |
107+
| Single-shot prover | 40.5 | 4.4 | 9.1 | 3.8 | 35.3 | 7.9 | 3.8 | Same as Task 2 | [results_val.tsv](https://drive.google.com/file/d/1PICdgzf48nsE1vgJlaAiNoA4S3HjBA77/view?usp=sharing) | [results_test.tsv](https://drive.google.com/file/d/1nRPLQ-DfbDj3b9EL7pLhs4IJl2vKkJKd/view?usp=sharing) |
108+
109+
110+
## RuleTaker Experiments
111+
112+
### Training
113+
114+
#### Prover
115+
116+
Training on RuleTaker is similar to training on EntailmentBank but with different configuration files. Run the following commands in [./prover/](./prover):
117+
```bash
118+
python main.py fit --config cli_ruletaker_single_shot_t5-large.yaml # Train a single-shot prover on D0–D3 of RuleTaker (OWA).
119+
python main.py fit --config cli_ruletaker_stepwise_t5-large.yaml # Train a stepwise prover on D0–D3 of RuleTaker (OWA).
120+
```
121+
122+
#### Verifier
123+
124+
Training the verifier is also similar. Run the following commands in [./verifier/](./verifier):
125+
```bash
126+
python main.py fit --config cli_ruletaker.yaml # Train a verifier on D0–D3 of RuleTaker (OWA).
127+
```
128+
129+
### Validation and Testing
130+
131+
`cd` into [./prover/](./prover). Assume we have a prover checkpoint `PATH_TO_PROVER_CKPT` and a verifier checkpoint `PATH_TO_VERIFIER_CKPT`.
132+
```bash
133+
python main.py validate --config cli_ruletaker_stepwise_t5-large.yaml --ckpt_path PATH_TO_PROVER_CKPT --model.verifier_weight 0.5 --model.verifier_ckpt PATH_TO_VERIFIER_CKPT --model.proof_search true --trainer.limit_val_batches 1.0 # Validate NLProofS on D0–D3 of RuleTaker (OWA).
134+
python main.py test --config cli_ruletaker_stepwise_t5-large.yaml --ckpt_path PATH_TO_PROVER_CKPT --model.verifier_weight 0.5 --model.verifier_ckpt PATH_TO_VERIFIER_CKPT --model.proof_search true # Test NLProofS on D0–D3 of RuleTaker (OWA).
135+
```
136+
137+
Note the `--trainer.limit_val_batches 1.0` above. By default, we use only 200 batches for RuleTaker validation (see [./prover/cli_ruletaker_stepwise_t5-large.yaml](prover/cli_ruletaker_stepwise_t5-large.yaml) and [./prover/cli_ruletaker_single_shot_t5-large.yaml](prover/cli_ruletaker_single_shot_t5-large.yaml)), but here we want to use all batches.
138+
139+
Validation and testing results are saved as `./prover/lightning_logs/EXP_ID/results_val.json` and `./prover/lightning_logs/EXP_ID/results_test.json`. Run the following command for final evaluation:
140+
```bash
141+
python evaluate.py ruletaker --path-val PATH_TO_VAL_RESULTS --path-test PATH_TO_TEST_RESULTS
142+
```
143+
144+
### Test Results and Model Checkpoints
145+
146+
| Model | Answer accuracy | Proof accuracy | Model checkpoints | Validation predictions | Test predictions |
147+
| ------------- | -------- | -------- | --------------- | ------------- | ----------------- |
148+
| NLProofS | 99.3 | 99.2 | [prover](https://drive.google.com/file/d/1Js-jCyt8yGvwMyowwOkq-5bn0lvY2FQz/view?usp=sharing), [verifier](https://drive.google.com/file/d/1l2-vCU6TQ4_OtTygXXLGiJ7UUzuDeyN2/view?usp=sharing) | [results_val.json](https://drive.google.com/file/d/1JhCIhPkPdyoNpNhB0uNZr2WXB6ffuSBG/view?usp=sharing) | [results_test.json](https://drive.google.com/file/d/134HjjaYztCb-nLUqtdQ7hWxO4SO1ZW6d/view?usp=sharing) |
149+
| Stepwise prover | 68.7 | 91.3 | The `prover` above | [results_val.json](https://drive.google.com/file/d/1aB3ciVCX2_h9qYVyKJ6qhlJS8T_w26wf/view?usp=sharing) | [results_test.json](https://drive.google.com/file/d/1WO2c1C_4WIjHRVU4emiMzWt4s1lndviV/view?usp=sharing) |
150+
| Single-shot prover | 56.3 | 72.6 | [prover](https://drive.google.com/file/d/1yg5c2MXnGFVr6b7g9dFCeIAhMyb0gS_m/view?usp=sharing) | [results_val.json](https://drive.google.com/file/d/1NGtCfAp4F3eUEEGT7J0IGK7-UCP8YYPA/view?usp=sharing) | [results_test.json](https://drive.google.com/file/d/1HeXiKU0IcRTCZ-_0u0unHmZge54E7Iyu/view?usp=sharing) |
151+
152+
153+
## Citation
154+
155+
```bibtex
156+
@article{yang2022nlproofs,
157+
title={Generating Natural Language Proofs with Verifier-Guided Search},
158+
author={Yang, Kaiyu and Deng, Jia and Chen, Danqi},
159+
journal={arXiv preprint arXiv:2205.12443},
160+
year={2022}
161+
}
162+
```
163+
164+
## Credits
165+
166+
* The code is formatted using [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black).
167+

0 commit comments

Comments
 (0)