Skip to content

Commit d99be76

Browse files
committed
2025/11/06-11:24:21 (Linux sv1224 x86_64)
1 parent 7c0c44f commit d99be76

File tree

1 file changed

+56
-35
lines changed

1 file changed

+56
-35
lines changed

README.md

Lines changed: 56 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -135,21 +135,21 @@ The preprocessing command accepts `.xyz`, `.lmdb`/`.aselmdb`, and `.h5` inputs;
135135
```python
136136
from equitrain import get_args_parser_preprocess, preprocess
137137

138-
def test_preprocess():
139-
args = get_args_parser_preprocess().parse_args()
140-
args.train_file = 'data.xyz'
141-
args.valid_file = 'data.xyz'
142-
args.output_dir = 'test_preprocess/'
138+
139+
def run_preprocess():
140+
args = get_args_parser_preprocess().parse_args([])
141+
args.train_file = 'data.xyz'
142+
args.valid_file = 'data.xyz'
143+
args.output_dir = 'test_preprocess'
143144
args.compute_statistics = True
144-
# Compute atomic energies
145-
args.atomic_energies = "average"
146-
# Cutoff radius for computing graphs
145+
args.atomic_energies = 'average'
147146
args.r_max = 4.5
148147

149148
preprocess(args)
150149

151-
if __name__ == "__main__":
152-
test_preprocess()
150+
151+
if __name__ == '__main__':
152+
run_preprocess()
153153
```
154154

155155
---
@@ -166,7 +166,7 @@ equitrain -v \
166166
--train-file data/train.h5 \
167167
--valid-file data/valid.h5 \
168168
--output-dir result_mace \
169-
--model mace.model \
169+
--model path/to/mace.model \
170170
--model-wrapper 'mace' \
171171
--epochs 10 \
172172
--tqdm
@@ -176,6 +176,7 @@ equitrain -v \
176176
--train-file data/train.h5 \
177177
--valid-file data/valid.h5 \
178178
--output-dir result_orb \
179+
--model path/to/orb.model \
179180
--model-wrapper 'orb' \
180181
--epochs 10 \
181182
--tqdm
@@ -186,26 +187,45 @@ equitrain -v \
186187

187188
```python
188189
from equitrain import get_args_parser_train, train
189-
from equitrain.backends.torch_wrappers import MaceWrapper, OrbWrapper
190190

191-
# Training with MACE
192-
def test_train_mace():
193-
args = get_args_parser_train().parse_args()
194-
args.train_file = 'data/train.h5'
195-
args.valid_file = 'data/valid.h5'
196-
args.output_dir = 'test_train_mace'
197-
args.epochs = 10
198-
args.batch_size = 64
199-
args.lr = 0.01
200-
args.verbose = 1
201-
args.tqdm = True
202-
args.model = MaceWrapper(args, "mace.model")
191+
192+
def train_mace():
193+
args = get_args_parser_train().parse_args([])
194+
args.train_file = 'data/train.h5'
195+
args.valid_file = 'data/valid.h5'
196+
args.output_dir = 'runs/mace'
197+
args.epochs = 10
198+
args.batch_size = 64
199+
args.lr = 1e-2
200+
args.verbose = 1
201+
args.tqdm = True
202+
203+
args.model = 'path/to/mace.model'
204+
args.model_wrapper = 'mace'
205+
206+
train(args)
207+
208+
209+
def train_orb():
210+
args = get_args_parser_train().parse_args([])
211+
args.train_file = 'data/train.h5'
212+
args.valid_file = 'data/valid.h5'
213+
args.output_dir = 'runs/orb'
214+
args.epochs = 10
215+
args.batch_size = 32
216+
args.lr = 5e-4
217+
args.verbose = 1
218+
args.tqdm = True
219+
220+
args.model = 'path/to/orb.model'
221+
args.model_wrapper = 'orb'
203222

204223
train(args)
205224

206-
if __name__ == "__main__":
207-
test_train_mace()
208-
# test_train_orb()
225+
226+
if __name__ == '__main__':
227+
train_mace()
228+
# train_orb()
209229
```
210230

211231
#### Running the JAX backend
@@ -233,22 +253,23 @@ Use a trained model to make predictions on new data:
233253

234254
```python
235255
from equitrain import get_args_parser_predict, predict
236-
from equitrain.backends.torch_wrappers import MaceWrapper
237256

238-
def test_mace_predict():
239-
args = get_args_parser_predict().parse_args()
257+
258+
def predict_with_mace():
259+
args = get_args_parser_predict().parse_args([])
240260
args.predict_file = 'data/valid.h5'
241-
args.batch_size = 64
242-
args.model = MaceWrapper(args, "mace.model")
261+
args.batch_size = 64
262+
args.model = 'path/to/mace.model'
263+
args.model_wrapper = 'mace'
243264

244265
energy_pred, forces_pred, stress_pred = predict(args)
245-
246266
print(energy_pred)
247267
print(forces_pred)
248268
print(stress_pred)
249269

250-
if __name__ == "__main__":
251-
test_mace_predict()
270+
271+
if __name__ == '__main__':
272+
predict_with_mace()
252273
```
253274

254275
---

0 commit comments

Comments
 (0)