@@ -135,21 +135,21 @@ The preprocessing command accepts `.xyz`, `.lmdb`/`.aselmdb`, and `.h5` inputs;
135135``` python
136136from equitrain import get_args_parser_preprocess, preprocess
137137
138- def test_preprocess ():
139- args = get_args_parser_preprocess().parse_args()
140- args.train_file = ' data.xyz'
141- args.valid_file = ' data.xyz'
142- args.output_dir = ' test_preprocess/'
138+
139+ def run_preprocess ():
140+ args = get_args_parser_preprocess().parse_args([])
141+ args.train_file = ' data.xyz'
142+ args.valid_file = ' data.xyz'
143+ args.output_dir = ' test_preprocess'
143144 args.compute_statistics = True
144- # Compute atomic energies
145- args.atomic_energies = " average"
146- # Cutoff radius for computing graphs
145+ args.atomic_energies = ' average'
147146 args.r_max = 4.5
148147
149148 preprocess(args)
150149
151- if __name__ == " __main__" :
152- test_preprocess()
150+
151+ if __name__ == ' __main__' :
152+ run_preprocess()
153153```
154154
155155---
@@ -166,7 +166,7 @@ equitrain -v \
166166 --train-file data/train.h5 \
167167 --valid-file data/valid.h5 \
168168 --output-dir result_mace \
169- --model mace.model \
169+ --model path/to/ mace.model \
170170 --model-wrapper ' mace' \
171171 --epochs 10 \
172172 --tqdm
@@ -176,6 +176,7 @@ equitrain -v \
176176 --train-file data/train.h5 \
177177 --valid-file data/valid.h5 \
178178 --output-dir result_orb \
179+ --model path/to/orb.model \
179180 --model-wrapper ' orb' \
180181 --epochs 10 \
181182 --tqdm
@@ -186,26 +187,45 @@ equitrain -v \
186187
187188``` python
188189from equitrain import get_args_parser_train, train
189- from equitrain.backends.torch_wrappers import MaceWrapper, OrbWrapper
190190
191- # Training with MACE
192- def test_train_mace ():
193- args = get_args_parser_train().parse_args()
194- args.train_file = ' data/train.h5'
195- args.valid_file = ' data/valid.h5'
196- args.output_dir = ' test_train_mace'
197- args.epochs = 10
198- args.batch_size = 64
199- args.lr = 0.01
200- args.verbose = 1
201- args.tqdm = True
202- args.model = MaceWrapper(args, " mace.model" )
191+
192+ def train_mace ():
193+ args = get_args_parser_train().parse_args([])
194+ args.train_file = ' data/train.h5'
195+ args.valid_file = ' data/valid.h5'
196+ args.output_dir = ' runs/mace'
197+ args.epochs = 10
198+ args.batch_size = 64
199+ args.lr = 1e-2
200+ args.verbose = 1
201+ args.tqdm = True
202+
203+ args.model = ' path/to/mace.model'
204+ args.model_wrapper = ' mace'
205+
206+ train(args)
207+
208+
209+ def train_orb ():
210+ args = get_args_parser_train().parse_args([])
211+ args.train_file = ' data/train.h5'
212+ args.valid_file = ' data/valid.h5'
213+ args.output_dir = ' runs/orb'
214+ args.epochs = 10
215+ args.batch_size = 32
216+ args.lr = 5e-4
217+ args.verbose = 1
218+ args.tqdm = True
219+
220+ args.model = ' path/to/orb.model'
221+ args.model_wrapper = ' orb'
203222
204223 train(args)
205224
206- if __name__ == " __main__" :
207- test_train_mace()
208- # test_train_orb()
225+
226+ if __name__ == ' __main__' :
227+ train_mace()
228+ # train_orb()
209229```
210230
211231#### Running the JAX backend
@@ -233,22 +253,23 @@ Use a trained model to make predictions on new data:
233253
234254``` python
235255from equitrain import get_args_parser_predict, predict
236- from equitrain.backends.torch_wrappers import MaceWrapper
237256
238- def test_mace_predict ():
239- args = get_args_parser_predict().parse_args()
257+
258+ def predict_with_mace ():
259+ args = get_args_parser_predict().parse_args([])
240260 args.predict_file = ' data/valid.h5'
241- args.batch_size = 64
242- args.model = MaceWrapper(args, " mace.model" )
261+ args.batch_size = 64
262+ args.model = ' path/to/mace.model'
263+ args.model_wrapper = ' mace'
243264
244265 energy_pred, forces_pred, stress_pred = predict(args)
245-
246266 print (energy_pred)
247267 print (forces_pred)
248268 print (stress_pred)
249269
250- if __name__ == " __main__" :
251- test_mace_predict()
270+
271+ if __name__ == ' __main__' :
272+ predict_with_mace()
252273```
253274
254275---
0 commit comments