|
145 | 145 | " The descriptors job reads the structure and computes numerical features\n", |
146 | 146 | " (fingerprints) for each structure.\n", |
147 | 147 | "3. Collect the descriptor outputs, as StructureData, for all structures\n", |
148 | | - " and pass them to `process_and_split_data` (a calcfunction).\n", |
149 | | - "4. `process_and_split_data` writes the structures to `train.xyz`, `test.xyz`,\n", |
150 | | - " and `valid.xyz` files, and returns a Dict node with the file paths." |
 |  | 148 | + "   and pass them to `create_qe_files` (a calcfunction task).\n", |
| 149 | + "4. This calls `process_and_split_data` (a Python function) which writes the structures to `train.xyz`, `test.xyz`,\n", |
 |  | 150 | + "   and `valid.xyz` files. The task returns `SinglefileData` AiiDA data types, which is why we create a `calcfunction` task as opposed to just a `task`; this ensures the files are available on the WorkGraph." |
| 151 | + ] |
| 152 | + }, |
| 153 | + { |
| 154 | + "cell_type": "code", |
| 155 | + "execution_count": null, |
| 156 | + "id": "eec0b5ca", |
| 157 | + "metadata": {}, |
| 158 | + "outputs": [], |
| 159 | + "source": [ |
| 160 | + "from aiida_workgraph import task\n", |
| 161 | + "from sample_split import process_and_split_data\n", |
| 162 | + "from aiida.orm import SinglefileData\n", |
| 163 | + "\n", |
| 164 | + "@task.calcfunction(outputs = [\"test_file\", \"train_file\", \"valid_file\"])\n", |
| 165 | + "def create_qe_files(**inputs):\n", |
| 166 | + " \n", |
| 167 | + " files = process_and_split_data(**inputs)\n", |
| 168 | + "\n", |
| 169 | + " return {\n", |
| 170 | + " \"train_file\": SinglefileData(files[\"train_file\"]),\n", |
| 171 | + " \"test_file\": SinglefileData(files[\"test_file\"]),\n", |
| 172 | + " \"valid_file\": SinglefileData(files[\"valid_file\"])\n", |
| 173 | + " }" |
151 | 174 | ] |
152 | 175 | }, |
153 | 176 | { |
|
158 | 181 | "outputs": [], |
159 | 182 | "source": [ |
160 | 183 | "from aiida.orm import Str, Float, Bool, Int\n", |
161 | | - "from ase.io import read\n", |
| 184 | + "from ase.io import iread\n", |
162 | 185 | "from aiida_workgraph import WorkGraph\n", |
163 | 186 | "from aiida.orm import StructureData\n", |
164 | 187 | "from sample_split import process_and_split_data\n", |
165 | 188 | "\n", |
166 | | - "initail_structure = \"../structures/lj-traj.xyz\"\n", |
167 | | - "num_structs = len(read(initail_structure, index=\":\"))\n", |
| 189 | + "initial_structure = \"../structures/NaCl-traj.xyz\"\n", |
168 | 190 | "\n", |
169 | 191 | "with WorkGraph(\"Calculation Workgraph\") as wg:\n", |
170 | 192 | " final_structures = {}\n", |
171 | 193 | "\n", |
172 | | - " for i in range(num_structs):\n", |
173 | | - " structure = StructureData(ase=read(initail_structure, index=i))\n", |
| 194 | + " for i, struct in enumerate(iread(initial_structure)):\n", |
| 195 | + " structure = StructureData(ase=struct)\n", |
174 | 196 | "\n", |
175 | 197 | " geomopt_calc = wg.add_task(\n", |
176 | 198 | " geomoptCalc,\n", |
|
199 | 221 | " final_structures[f\"structs{i}\"] = descriptors_calc.outputs.xyz_output\n", |
200 | 222 | "\n", |
201 | 223 | " split_task = wg.add_task(\n", |
202 | | - " process_and_split_data,\n", |
| 224 | + " create_qe_files,\n", |
203 | 225 | " config_types= Str(\"\"),\n", |
204 | | - " n_samples=Int(num_structs),\n", |
| 226 | + " n_samples=Int(len(final_structures)),\n", |
205 | 227 | " prefix= Str(\"\"),\n", |
206 | 228 | " scale= Float(1.0e5),\n", |
207 | 229 | " append_mode= Bool(False),\n", |
|
256 | 278 | { |
257 | 279 | "cell_type": "code", |
258 | 280 | "execution_count": null, |
259 | | - "id": "fe7291b7", |
| 281 | + "id": "d2f463f6", |
260 | 282 | "metadata": {}, |
261 | 283 | "outputs": [], |
262 | 284 | "source": [ |
263 | | - "wg.tasks.process_and_split_data.outputs.result.value.get_dict()" |
| 285 | + "test_file = wg.tasks.create_qe_files.outputs.test_file.value\n", |
| 286 | + "train_file = wg.tasks.create_qe_files.outputs.train_file.value\n", |
| 287 | + "valid_file = wg.tasks.create_qe_files.outputs.valid_file.value" |
264 | 288 | ] |
265 | 289 | }, |
266 | 290 | { |
267 | 291 | "cell_type": "markdown", |
268 | | - "id": "514aeb77", |
| 292 | + "id": "8dedf8c5", |
269 | 293 | "metadata": {}, |
270 | 294 | "source": [ |
271 | 295 | "We can use the outputs to visualise the data. For example, below we will plot a histogram of `mace_mp_descriptor`" |
272 | 296 | ] |
273 | 297 | }, |
274 | | - { |
275 | | - "cell_type": "code", |
276 | | - "execution_count": null, |
277 | | - "id": "d2f463f6", |
278 | | - "metadata": {}, |
279 | | - "outputs": [], |
280 | | - "source": [ |
281 | | - "test_file = wg.tasks.process_and_split_data.outputs.result.value.get_dict()[\"test_file\"]\n", |
282 | | - "train_file = wg.tasks.process_and_split_data.outputs.result.value.get_dict()[\"train_file\"]\n", |
283 | | - "valid_file = wg.tasks.process_and_split_data.outputs.result.value.get_dict()[\"valid_file\"]" |
284 | | - ] |
285 | | - }, |
286 | 298 | { |
287 | 299 | "cell_type": "code", |
288 | 300 | "execution_count": null, |
|
294 | 306 | "from ase.io import iread\n", |
295 | 307 | "import matplotlib.pyplot as plt\n", |
296 | 308 | "\n", |
297 | | - "test_mace_desc = np.array([i.info['mace_mp_descriptor'] for i in iread(test_file, index=':')])\n", |
298 | | - "train_mace_desc = np.array([i.info['mace_mp_descriptor'] for i in iread(train_file, index=':')])\n", |
299 | | - "valid_mace_desc = np.array([i.info['mace_mp_descriptor'] for i in iread(valid_file, index=':')])\n", |
| 309 | + "with test_file.as_path() as path:\n", |
| 310 | + " test_mace_desc = np.array([struct.info['mace_mp_descriptor'] for struct in iread(path, index=':')])\n", |
| 311 | + "with train_file.as_path() as path:\n", |
| 312 | + " train_mace_desc = np.array([struct.info['mace_mp_descriptor'] for struct in iread(path, index=':')])\n", |
| 313 | + "with valid_file.as_path() as path:\n", |
| 314 | + " valid_mace_desc = np.array([struct.info['mace_mp_descriptor'] for struct in iread(path, index=':')])\n", |
300 | 315 | "\n", |
301 | 316 | "all_values = np.concatenate([train_mace_desc, valid_mace_desc, test_mace_desc])\n", |
302 | 317 | "bins = np.linspace(all_values.min(), all_values.max(), len(all_values))\n", |
|
0 commit comments