Skip to content

Commit 2fbaae1

Browse files
MM0hsinElliottKasoaroerc0122alinelena
authored
Add fine tuning example workgraph (#202)
Co-authored-by: Elliott Kasoar <[email protected]> Co-authored-by: Jacob Wilkins <[email protected]> Co-authored-by: Alin Marin Elena <[email protected]>
1 parent 7aebec2 commit 2fbaae1

File tree

7 files changed

+708
-1214
lines changed

7 files changed

+708
-1214
lines changed

examples/tutorials/structures/NaCl-traj.xyz

Lines changed: 48 additions & 0 deletions
Large diffs are not rendered by default.

examples/tutorials/structures/lj-traj.xyz

Lines changed: 0 additions & 1122 deletions
This file was deleted.
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
name: "test"
2+
E0s: 'average'
3+
max_num_epochs: 1
4+
model: 'MACE'
5+
energy_key: 'qe_energy'
6+
forces_key: 'qe_forces'
7+
stress_key: 'qe_stress'
8+
loss: 'universal'
9+
energy_weight: 1
10+
forces_weight: 10
11+
stress_weight: 100
12+
compute_stress: True
13+
eval_interval: 2
14+
error_table: 'PerAtomRMSE'
15+
interaction_first: 'RealAgnosticResidualInteractionBlock'
16+
interaction: 'RealAgnosticResidualInteractionBlock'
17+
num_interactions: 2
18+
correlation: 3
19+
max_ell: 3
20+
r_max: 4.0
21+
max_L: 0
22+
num_channels: 16
23+
num_radial_basis: 6
24+
MLP_irreps: '16x0e'
25+
scaling: 'rms_forces_scaling'
26+
lr: 0.005
27+
weight_decay: 1e-8
28+
ema: True
29+
ema_decay: 0.995
30+
scheduler_patience: 5
31+
batch_size: 2
32+
valid_batch_size: 2
33+
patience: 50
34+
amsgrad: True
35+
device: 'cpu'
36+
distributed: False
37+
clip_grad: 100
38+
keep_checkpoints: False
39+
keep_isolated_atoms: True
40+
save_cpu: True

examples/tutorials/workgraphs/descriptors_filter.ipynb

Lines changed: 43 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,32 @@
145145
" The descriptors job reads the structure and computes numerical features\n",
146146
" (fingerprints) for each structure.\n",
147147
"3. Collect the descriptor outputs, as StructureData, for all structures\n",
148-
" and pass them to `process_and_split_data` (a calcfunction).\n",
149-
"4. `process_and_split_data` writes the structures to `train.xyz`, `test.xyz`,\n",
150-
" and `valid.xyz` files, and returns a Dict node with the file paths."
148+
" and pass them to `create_aiida_files` (a calcfunction task).\n",
149+
"4. This calls `process_and_split_data` (a Python function) which writes the structures to `train.xyz`, `test.xyz`,\n",
150+
" and `valid.xyz` files. The task returns `SinglefileData` AiiDA data types, hence why we have to create a `calcfunction` task as oppose to just a `task`, we do this so the files are available on the WorkGraph."
151+
]
152+
},
153+
{
154+
"cell_type": "code",
155+
"execution_count": null,
156+
"id": "eec0b5ca",
157+
"metadata": {},
158+
"outputs": [],
159+
"source": [
160+
"from aiida_workgraph import task\n",
161+
"from sample_split import process_and_split_data\n",
162+
"from aiida.orm import SinglefileData\n",
163+
"\n",
164+
"@task.calcfunction(outputs = [\"test_file\", \"train_file\", \"valid_file\"])\n",
165+
"def create_qe_files(**inputs):\n",
166+
" \n",
167+
" files = process_and_split_data(**inputs)\n",
168+
"\n",
169+
" return {\n",
170+
" \"train_file\": SinglefileData(files[\"train_file\"]),\n",
171+
" \"test_file\": SinglefileData(files[\"test_file\"]),\n",
172+
" \"valid_file\": SinglefileData(files[\"valid_file\"])\n",
173+
" }"
151174
]
152175
},
153176
{
@@ -158,19 +181,18 @@
158181
"outputs": [],
159182
"source": [
160183
"from aiida.orm import Str, Float, Bool, Int\n",
161-
"from ase.io import read\n",
184+
"from ase.io import iread\n",
162185
"from aiida_workgraph import WorkGraph\n",
163186
"from aiida.orm import StructureData\n",
164187
"from sample_split import process_and_split_data\n",
165188
"\n",
166-
"initail_structure = \"../structures/lj-traj.xyz\"\n",
167-
"num_structs = len(read(initail_structure, index=\":\"))\n",
189+
"initial_structure = \"../structures/NaCl-traj.xyz\"\n",
168190
"\n",
169191
"with WorkGraph(\"Calculation Workgraph\") as wg:\n",
170192
" final_structures = {}\n",
171193
"\n",
172-
" for i in range(num_structs):\n",
173-
" structure = StructureData(ase=read(initail_structure, index=i))\n",
194+
" for i, struct in enumerate(iread(initial_structure)):\n",
195+
" structure = StructureData(ase=struct)\n",
174196
"\n",
175197
" geomopt_calc = wg.add_task(\n",
176198
" geomoptCalc,\n",
@@ -199,9 +221,9 @@
199221
" final_structures[f\"structs{i}\"] = descriptors_calc.outputs.xyz_output\n",
200222
"\n",
201223
" split_task = wg.add_task(\n",
202-
" process_and_split_data,\n",
224+
" create_qe_files,\n",
203225
" config_types= Str(\"\"),\n",
204-
" n_samples=Int(num_structs),\n",
226+
" n_samples=Int(len(final_structures)),\n",
205227
" prefix= Str(\"\"),\n",
206228
" scale= Float(1.0e5),\n",
207229
" append_mode= Bool(False),\n",
@@ -256,33 +278,23 @@
256278
{
257279
"cell_type": "code",
258280
"execution_count": null,
259-
"id": "fe7291b7",
281+
"id": "d2f463f6",
260282
"metadata": {},
261283
"outputs": [],
262284
"source": [
263-
"wg.tasks.process_and_split_data.outputs.result.value.get_dict()"
285+
"test_file = wg.tasks.create_qe_files.outputs.test_file.value\n",
286+
"train_file = wg.tasks.create_qe_files.outputs.train_file.value\n",
287+
"valid_file = wg.tasks.create_qe_files.outputs.valid_file.value"
264288
]
265289
},
266290
{
267291
"cell_type": "markdown",
268-
"id": "514aeb77",
292+
"id": "8dedf8c5",
269293
"metadata": {},
270294
"source": [
271295
"We can use the outputs to visualise the data. For example, below we will plot a histogram of `mace_mp_descriptor`"
272296
]
273297
},
274-
{
275-
"cell_type": "code",
276-
"execution_count": null,
277-
"id": "d2f463f6",
278-
"metadata": {},
279-
"outputs": [],
280-
"source": [
281-
"test_file = wg.tasks.process_and_split_data.outputs.result.value.get_dict()[\"test_file\"]\n",
282-
"train_file = wg.tasks.process_and_split_data.outputs.result.value.get_dict()[\"train_file\"]\n",
283-
"valid_file = wg.tasks.process_and_split_data.outputs.result.value.get_dict()[\"valid_file\"]"
284-
]
285-
},
286298
{
287299
"cell_type": "code",
288300
"execution_count": null,
@@ -294,9 +306,12 @@
294306
"from ase.io import iread\n",
295307
"import matplotlib.pyplot as plt\n",
296308
"\n",
297-
"test_mace_desc = np.array([i.info['mace_mp_descriptor'] for i in iread(test_file, index=':')])\n",
298-
"train_mace_desc = np.array([i.info['mace_mp_descriptor'] for i in iread(train_file, index=':')])\n",
299-
"valid_mace_desc = np.array([i.info['mace_mp_descriptor'] for i in iread(valid_file, index=':')])\n",
309+
"with test_file.as_path() as path:\n",
310+
" test_mace_desc = np.array([struct.info['mace_mp_descriptor'] for struct in iread(path, index=':')])\n",
311+
"with train_file.as_path() as path:\n",
312+
" train_mace_desc = np.array([struct.info['mace_mp_descriptor'] for struct in iread(path, index=':')])\n",
313+
"with valid_file.as_path() as path:\n",
314+
" valid_mace_desc = np.array([struct.info['mace_mp_descriptor'] for struct in iread(path, index=':')])\n",
300315
"\n",
301316
"all_values = np.concatenate([train_mace_desc, valid_mace_desc, test_mace_desc])\n",
302317
"bins = np.linspace(all_values.min(), all_values.max(), len(all_values))\n",

0 commit comments

Comments
 (0)