Skip to content

Commit 843a6fe

Browse files
committed
Address PR comments
1 parent adaea1f commit 843a6fe

File tree

3 files changed

+49
-35
lines changed

3 files changed

+49
-35
lines changed

examples/tutorials/workgraphs/descriptors_filter.ipynb

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -181,19 +181,19 @@
181181
"outputs": [],
182182
"source": [
183183
"from aiida.orm import Str, Float, Bool, Int\n",
184-
"from ase.io import read\n",
184+
"from ase.io import iread\n",
185185
"from aiida_workgraph import WorkGraph\n",
186186
"from aiida.orm import StructureData\n",
187187
"from sample_split import process_and_split_data\n",
188188
"\n",
189189
"initial_structure = \"../structures/NaCl-traj.xyz\"\n",
190-
"num_structs = len(read(initial_structure, index=\":\"))\n",
190+
"# num_structs = len(read(initial_structure, index=\":\"))\n",
191191
"\n",
192192
"with WorkGraph(\"Calculation Workgraph\") as wg:\n",
193193
" final_structures = {}\n",
194194
"\n",
195-
" for i in range(num_structs):\n",
196-
" structure = StructureData(ase=read(initial_structure, index=i))\n",
195+
" for i, struct in enumerate(iread(initial_structure)):\n",
196+
" structure = StructureData(ase=struct)\n",
197197
"\n",
198198
" geomopt_calc = wg.add_task(\n",
199199
" geomoptCalc,\n",
@@ -224,7 +224,7 @@
224224
" split_task = wg.add_task(\n",
225225
" create_aiida_files,\n",
226226
" config_types= Str(\"\"),\n",
227-
" n_samples=Int(num_structs),\n",
227+
" n_samples=Int(len(final_structures)),\n",
228228
" prefix= Str(\"\"),\n",
229229
" scale= Float(1.0e5),\n",
230230
" append_mode= Bool(False),\n",
@@ -308,11 +308,11 @@
308308
"import matplotlib.pyplot as plt\n",
309309
"\n",
310310
"with test_file.as_path() as path:\n",
311-
" test_mace_desc = np.array([i.info['mace_mp_descriptor'] for i in iread(path, index=':')])\n",
311+
" test_mace_desc = np.array([struct.info['mace_mp_descriptor'] for struct in iread(path, index=':')])\n",
312312
"with train_file.as_path() as path:\n",
313-
" train_mace_desc = np.array([i.info['mace_mp_descriptor'] for i in iread(path, index=':')])\n",
313+
" train_mace_desc = np.array([struct.info['mace_mp_descriptor'] for struct in iread(path, index=':')])\n",
314314
"with valid_file.as_path() as path:\n",
315-
" valid_mace_desc = np.array([i.info['mace_mp_descriptor'] for i in iread(path, index=':')])\n",
315+
" valid_mace_desc = np.array([struct.info['mace_mp_descriptor'] for struct in iread(path, index=':')])\n",
316316
"\n",
317317
"all_values = np.concatenate([train_mace_desc, valid_mace_desc, test_mace_desc])\n",
318318
"bins = np.linspace(all_values.min(), all_values.max(), len(all_values))\n",

examples/tutorials/workgraphs/descriptors_filter_qe.ipynb

Lines changed: 40 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@
139139
"source": [
140140
"from aiida_workgraph import task\n",
141141
"from aiida_workgraph.manager import get_current_graph\n",
142-
"from aiida.orm import StructureData, load_group, KpointsData, SinglefileData\n",
142+
"from aiida.orm import StructureData, load_group, KpointsData, SinglefileData, InstalledCode, List, Dict\n",
143143
"from ase.io import iread\n",
144144
"from pathlib import Path\n",
145145
"import yaml\n",
@@ -148,18 +148,23 @@
148148
"\n",
149149
"\n",
150150
"@task.graph(outputs = [\"test_file\", \"train_file\", \"valid_file\"])\n",
151-
"def qe(**inputs):\n",
151+
"def qe(\n",
152+
" code: InstalledCode,\n",
153+
" kpoints_mesh: List,\n",
154+
" task_metadata: Dict,\n",
155+
" test_file: SinglefileData,\n",
156+
" train_file: SinglefileData,\n",
157+
" valid_file: SinglefileData\n",
158+
" ):\n",
152159
"\n",
153160
" wg = get_current_graph()\n",
154161
"\n",
155-
" task_inputs = inputs[\"task_params\"]['task_inputs']\n",
156-
" code =inputs[\"task_params\"][\"code\"]\n",
157-
"\n",
158162
" kpoints = KpointsData()\n",
159-
" kpoints.set_kpoints_mesh(task_inputs['kpoint_mesh'])\n",
163+
" kpoints.set_kpoints_mesh(kpoints_mesh)\n",
160164
"\n",
161165
" pseudo_family = load_group('SSSP/1.3/PBE/efficiency')\n",
162-
" files = {\"test_file\": inputs['test_file'],\"train_file\":inputs['train_file'],\"valid_file\":inputs['valid_file']}\n",
166+
" \n",
167+
" files = {\"test_file\": test_file, \"train_file\": train_file, \"valid_file\": valid_file}\n",
163168
"\n",
164169
" for file_name, file in files.items():\n",
165170
" with file.as_path() as path:\n",
@@ -187,12 +192,12 @@
187192
" \n",
188193
" qe_task = wg.add_task(\n",
189194
" PwCalculation,\n",
190-
" code = code,\n",
191-
" parameters= pw_params,\n",
192-
" kpoints= kpoints,\n",
193-
" pseudos= pseudos,\n",
194-
" metadata= task_inputs[\"metadata\"],\n",
195-
" structure= structure,\n",
195+
" code=code,\n",
196+
" parameters=pw_params,\n",
197+
" kpoints=kpoints,\n",
198+
" pseudos=pseudos,\n",
199+
" metadata=task_metadata.value,\n",
200+
" structure=structure,\n",
196201
" )\n",
197202
" \n",
198203
" structfile = f\"{file_name}.struct{i}\"\n",
@@ -239,9 +244,9 @@
239244
" for file_name, structs in inputs.items():\n",
240245
" path = Path(f\"mlip_{file_name}.extxyz\")\n",
241246
"\n",
242-
" for stuct_out_params in structs.values():\n",
247+
" for struct_out_params in structs.values():\n",
243248
" \n",
244-
" trajectory = stuct_out_params[\"trajectory\"]\n",
249+
" trajectory = struct_out_params[\"trajectory\"]\n",
245250
"\n",
246251
" fileStructure = trajectory.get_structure(index=0)\n",
247252
" fileAtoms = fileStructure.get_ase()\n",
@@ -253,7 +258,7 @@
253258
" fileAtoms.info[\"units\"] = {\"energy\": \"eV\",\"forces\": \"ev/Ang\",\"stress\": \"ev/Ang^3\"}\n",
254259
" fileAtoms.set_array(\"qe_forces\", trajectory.arrays[\"forces\"][0])\n",
255260
"\n",
256-
" parameters = stuct_out_params[\"parameters\"]\n",
261+
" parameters = struct_out_params[\"parameters\"]\n",
257262
" fileParams = parameters.get_dict()\n",
258263
" fileAtoms.info[\"qe_energy\"] = fileParams[\"energy\"]\n",
259264
" write(path, fileAtoms, append=True)\n",
@@ -318,7 +323,7 @@
318323
" \"metadata\": {\"options\": {\"resources\": {\"num_machines\": 1}}},\n",
319324
"}\n",
320325
"\n",
321-
"goemopt_inputs = {\n",
326+
"geomopt_inputs = {\n",
322327
" \"fmax\": Float(0.1),\n",
323328
" \"opt_cell_lengths\": Bool(False),\n",
324329
" \"opt_cell_fully\": Bool(True),\n",
@@ -332,8 +337,7 @@
332337
"}\n",
333338
"\n",
334339
"qe_inputs = {\n",
335-
" \"task_inputs\": Dict({\n",
336-
" \"metadata\": {\n",
340+
" \"task_metadata\": Dict({\n",
337341
" \"options\": {\n",
338342
" \"resources\": {\n",
339343
" \"num_machines\": 1,\n",
@@ -352,9 +356,8 @@
352356
" \"\"\",\n",
353357
" \"append_text\": \"\",\n",
354358
" },\n",
355-
" },\n",
356-
" \"kpoint_mesh\": List([1, 1, 1]),\n",
357359
" }),\n",
360+
" \"kpoints_mesh\": List([1, 1, 1]),\n",
358361
" \"code\": qe_code,\n",
359362
"}"
360363
]
@@ -364,7 +367,7 @@
364367
"id": "06ee80fd",
365368
"metadata": {},
366369
"source": [
367-
"Now we can build the `Workgraph`. First we iterate through each structure in the initail structure file, and run `Geomopt` and `Descriptors` on them these give a `SinglefileData` instance of the structure outputs. These structures can then be passed to the `split_task`, which splits these structures up into training files. Then we run `QE` task, getting the outputs and passing them into the `training_files` task which, as the name suggests, it creates the training file from the `QE` task outputs. Finally we can run the training script. Ideally, if any of the inputs need to changed, they should be done in the cell above."
370+
"Now we can build the `Workgraph`. First we iterate through each structure in the initial structure file, and run `Geomopt` and `Descriptors` on them; these give a `SinglefileData` instance of the structure outputs. These structures can then be passed to the `split_task`, which splits these structures up into train, test and validation files. Then we run the `QE` task, getting the outputs and passing them into the `training_files` task which, as the name suggests, creates the training file from the `QE` task outputs. Finally we can run the training script. Ideally, if any of the inputs need to be changed, they should be done in the cell above."
368371
]
369372
},
370373
{
@@ -390,7 +393,7 @@
390393
" geomopt_calc = wg.add_task(\n",
391394
" geomoptCalc,\n",
392395
" **calc_inputs,\n",
393-
" **goemopt_inputs,\n",
396+
" **geomopt_inputs,\n",
394397
" struct=structure,\n",
395398
" )\n",
396399
" \n",
@@ -416,7 +419,7 @@
416419
" test_file= split_task.outputs.test_file,\n",
417420
" train_file= split_task.outputs.train_file,\n",
418421
" valid_file= split_task.outputs.valid_file,\n",
419-
" task_params = qe_inputs\n",
422+
" **qe_inputs\n",
420423
" )\n",
421424
"\n",
422425
" training_files = wg.add_task(\n",
@@ -441,7 +444,7 @@
441444
"id": "7f3c72ca",
442445
"metadata": {},
443446
"source": [
444-
"Run and visualise the workgraph"
447+
"Visualise and run the workgraph"
445448
]
446449
},
447450
{
@@ -464,6 +467,16 @@
464467
"wg.run()"
465468
]
466469
},
470+
{
471+
"cell_type": "code",
472+
"execution_count": null,
473+
"id": "13e509e3",
474+
"metadata": {},
475+
"outputs": [],
476+
"source": [
477+
"wg.tasks.create_aiida_files.outputs.test_file"
478+
]
479+
},
467480
{
468481
"cell_type": "markdown",
469482
"id": "2e36396f",
@@ -484,7 +497,8 @@
484497
"import matplotlib.image as mpimg\n",
485498
"\n",
486499
"folder = wg.tasks.Train.outputs.remote_folder.value\n",
487-
"picturePath = f\"{os.getcwd()}/traingraph.png\"\n",
500+
"picturePath = Path.cwd() / \"traingraph.png\"\n",
501+
"\n",
488502
"folder.getfile(relpath='results/test_run-123_train_Default_stage_one.png',destpath=picturePath)\n",
489503
"\n",
490504
"img = mpimg.imread(picturePath)\n",

examples/tutorials/workgraphs/sample_split.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def process_and_split_data(
4747
scale: float,
4848
append_mode: bool,
4949
**trajectory_data,
50-
) -> dict:
50+
) -> dict[str, Path]:
5151
"""
5252
Split a trajectory into training, validation, and test sets.
5353

0 commit comments

Comments
 (0)