|
6 | 6 | -- no_rtx2080 no_k40 no_gtx780 input @ data/large.in.gz
|
7 | 7 | -- output { 952131i64 }
|
8 | 8 |
|
| 9 | +-- == |
| 10 | +-- entry: diff |
| 11 | +-- input @ data/small.in.gz |
| 12 | +-- no_rtx2080 no_k40 no_gtx780 input @ data/large.in.gz |
| 13 | + |
| 14 | + |
9 | 15 | type nuclide_grid_point =
|
10 | 16 | { energy: f64,
|
11 | 17 | total_xs: f64,
|
@@ -225,9 +231,29 @@ def unpack n_isotopes n_gridpoints grid_type hash_bins lookups
|
225 | 231 | {num_nucs, concs, mats, nuclide_grid, index_grid, unionized_energy_array}
|
226 | 232 | in (inputs, sd)
|
227 | 233 |
|
228 |
| -def main n_isotopes n_gridpoints grid_type hash_bins lookups |
229 |
| - num_nucs concs mats nuclide_grid index_grid unionized_energy_array = |
| 234 | +entry main n_isotopes n_gridpoints grid_type hash_bins lookups |
| 235 | + num_nucs concs mats nuclide_grid index_grid unionized_energy_array = |
230 | 236 | let (inputs, sd) =
|
231 | 237 | unpack n_isotopes n_gridpoints grid_type hash_bins lookups
|
232 | 238 | num_nucs concs mats nuclide_grid index_grid unionized_energy_array
|
233 | 239 | in #[unsafe] verification (run_event_based_simulation inputs sd)
|
| 240 | + |
| 241 | +-- Performs a single vjp pass with an all-unit seed vector. This is |
| 242 | +-- unlikely to produce a useful gradient, but does show the overhead |
| 243 | +-- of a single jvp invocation. |
| 244 | +entry diff n_isotopes n_gridpoints grid_type hash_bins lookups |
| 245 | + num_nucs concs mats nuclide_grid index_grid unionized_energy_array = |
| 246 | + let (inputs, sd) = |
| 247 | + unpack n_isotopes n_gridpoints grid_type hash_bins lookups |
| 248 | + num_nucs concs mats nuclide_grid index_grid unionized_energy_array |
| 249 | + let diff_res = #[unsafe] |
| 250 | + (vjp (run_event_based_simulation inputs) |
| 251 | + sd |
| 252 | + (replicate inputs.lookups (1,1,1,1,1))).nuclide_grid |
| 253 | + in (map (map (.absorbtion_xs)) diff_res, |
| 254 | + map (map (.elastic_xs)) diff_res, |
| 255 | + map (map (.energy)) diff_res, |
| 256 | + map (map (.fission_xs)) diff_res, |
| 257 | + map (map (.nu_fission_xs)) diff_res, |
| 258 | + map (map (.total_xs)) diff_res, |
| 259 | + ) |
0 commit comments