1
1
"""Script for preprocess state-tactic pairs into the format required by [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory)."""
2
2
3
+ import os
3
4
import json
4
5
import random
5
6
import argparse
@@ -11,30 +12,33 @@ def main() -> None:
11
12
parser .add_argument (
12
13
"--data-path" ,
13
14
type = str ,
14
- default = "./data/leandojo_benchmark_4/random/train.json " ,
15
+ default = "./data/leandojo_benchmark_4/random" ,
15
16
)
16
- parser .add_argument ("--dst-path" , type = str , default = "state_tactic_pairs.json " )
17
+ parser .add_argument ("--dst-path" , type = str , default = "state_tactic_pairs" )
17
18
args = parser .parse_args ()
18
19
logger .info (args )
19
20
20
- pairs = []
21
- for thm in json .load (open (args .data_path )):
22
- for tac in thm ["traced_tactics" ]:
23
- pairs .append ({"state" : tac ["state_before" ], "output" : tac ["tactic" ]})
24
- logger .info (f"Read { len (pairs )} state-tactic paris from { args .data_path } " )
21
+ for split in ("train" , "val" ):
22
+ data_path = os .path .join (args .data_path , f"{ split } .json" )
23
+ pairs = []
24
+ for thm in json .load (open (data_path )):
25
+ for tac in thm ["traced_tactics" ]:
26
+ pairs .append ({"state" : tac ["state_before" ], "output" : tac ["tactic" ]})
27
+ logger .info (f"Read { len (pairs )} state-tactic paris from { data_path } " )
25
28
26
- random .shuffle (pairs )
27
- data = [
28
- {
29
- "instruction" : f"[GOAL]\n { pair ['state' ]} \n [PROOFSTEP]\n " ,
30
- "input" : "" ,
31
- "output" : pair ["output" ],
32
- }
33
- for pair in pairs
34
- ]
35
- logger .info (data [0 ])
36
- json .dump (data , open (args .dst_path , "wt" ))
37
- logger .info (f"Preprocessed data saved to { args .dst_path } " )
29
+ random .shuffle (pairs )
30
+ data = [
31
+ {
32
+ "instruction" : f"[GOAL]\n { pair ['state' ]} \n [PROOFSTEP]\n " ,
33
+ "input" : "" ,
34
+ "output" : pair ["output" ],
35
+ }
36
+ for pair in pairs
37
+ ]
38
+ logger .info (data [0 ])
39
+ dst_path = args .dst_path + f"_{ split } .json"
40
+ json .dump (data , open (dst_path , "wt" ))
41
+ logger .info (f"Preprocessed data saved to { dst_path } " )
38
42
39
43
40
44
if __name__ == "__main__" :
0 commit comments