Skip to content

Commit 47ee563

Browse files
committed
minor update to preprocess
1 parent 3cb2f0c commit 47ee563

File tree

1 file changed

+23
-19
lines changed

1 file changed

+23
-19
lines changed

generation/preprocess.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Script for preprocess state-tactic pairs into the format required by [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory)."""
22

3+
import os
34
import json
45
import random
56
import argparse
@@ -11,30 +12,33 @@ def main() -> None:
1112
parser.add_argument(
1213
"--data-path",
1314
type=str,
14-
default="./data/leandojo_benchmark_4/random/train.json",
15+
default="./data/leandojo_benchmark_4/random",
1516
)
16-
parser.add_argument("--dst-path", type=str, default="state_tactic_pairs.json")
17+
parser.add_argument("--dst-path", type=str, default="state_tactic_pairs")
1718
args = parser.parse_args()
1819
logger.info(args)
1920

20-
pairs = []
21-
for thm in json.load(open(args.data_path)):
22-
for tac in thm["traced_tactics"]:
23-
pairs.append({"state": tac["state_before"], "output": tac["tactic"]})
24-
logger.info(f"Read {len(pairs)} state-tactic paris from {args.data_path}")
21+
for split in ("train", "val"):
22+
data_path = os.path.join(args.data_path, f"{split}.json")
23+
pairs = []
24+
for thm in json.load(open(data_path)):
25+
for tac in thm["traced_tactics"]:
26+
pairs.append({"state": tac["state_before"], "output": tac["tactic"]})
27+
logger.info(f"Read {len(pairs)} state-tactic paris from {data_path}")
2528

26-
random.shuffle(pairs)
27-
data = [
28-
{
29-
"instruction": f"[GOAL]\n{pair['state']}\n[PROOFSTEP]\n",
30-
"input": "",
31-
"output": pair["output"],
32-
}
33-
for pair in pairs
34-
]
35-
logger.info(data[0])
36-
json.dump(data, open(args.dst_path, "wt"))
37-
logger.info(f"Preprocessed data saved to {args.dst_path}")
29+
random.shuffle(pairs)
30+
data = [
31+
{
32+
"instruction": f"[GOAL]\n{pair['state']}\n[PROOFSTEP]\n",
33+
"input": "",
34+
"output": pair["output"],
35+
}
36+
for pair in pairs
37+
]
38+
logger.info(data[0])
39+
dst_path = args.dst_path + f"_{split}.json"
40+
json.dump(data, open(dst_path, "wt"))
41+
logger.info(f"Preprocessed data saved to {dst_path}")
3842

3943

4044
if __name__ == "__main__":

0 commit comments

Comments
 (0)