Skip to content

Commit b87d5ed

Browse files
committed
Add AD versions of these.
1 parent 6805553 commit b87d5ed

File tree

2 files changed

+49
-2
lines changed

2 files changed

+49
-2
lines changed

rsbench/rsbench.fut

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,11 @@
33
-- input @ data/small.in.gz output { 880018i64 }
44
-- input @ data/large.in.gz output { 358389i64 }
55

6+
-- ==
7+
-- entry: diff
8+
-- input @ data/small.in.gz
9+
-- input @ data/large.in.gz
10+
611
type input =
712
{ lookups: i64,
813
doppler: i32
@@ -287,3 +292,19 @@ def main lookups doppler
287292
let (input, sd) = unpack lookups doppler
288293
n_windows poles_ls poles_cs windows_f64s windows_i32s pseudo_K0RS num_nucs mats concs
289294
in #[unsafe] verification (run_event_based_simulation input.lookups input.doppler sd)
295+
296+
entry diff lookups doppler
297+
n_windows poles_ls poles_cs windows_f64s windows_i32s pseudo_K0RS num_nucs mats concs =
298+
let (input, sd) = unpack lookups doppler
299+
n_windows poles_ls poles_cs windows_f64s windows_i32s pseudo_K0RS num_nucs mats concs
300+
let diff_res = (vjp (run_event_based_simulation input.lookups input.doppler)
301+
sd
302+
(replicate input.lookups (1,1,1,1))).poles
303+
in (map (map (.l_value)) diff_res,
304+
map (map (.mp_ea.i)) diff_res,
305+
map (map (.mp_ea.r)) diff_res,
306+
map (map (.mp_ra.i)) diff_res,
307+
map (map (.mp_ra.r)) diff_res,
308+
map (map (.mp_rf.i)) diff_res,
309+
map (map (.mp_rt.r)) diff_res,
310+
)

xsbench/xsbench.fut

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@
66
-- no_rtx2080 no_k40 no_gtx780 input @ data/large.in.gz
77
-- output { 952131i64 }
88

9+
-- ==
10+
-- entry: diff
11+
-- input @ data/small.in.gz
12+
-- no_rtx2080 no_k40 no_gtx780 input @ data/large.in.gz
13+
14+
915
type nuclide_grid_point =
1016
{ energy: f64,
1117
total_xs: f64,
@@ -225,9 +231,29 @@ def unpack n_isotopes n_gridpoints grid_type hash_bins lookups
225231
{num_nucs, concs, mats, nuclide_grid, index_grid, unionized_energy_array}
226232
in (inputs, sd)
227233

228-
def main n_isotopes n_gridpoints grid_type hash_bins lookups
229-
num_nucs concs mats nuclide_grid index_grid unionized_energy_array =
234+
entry main n_isotopes n_gridpoints grid_type hash_bins lookups
235+
num_nucs concs mats nuclide_grid index_grid unionized_energy_array =
230236
let (inputs, sd) =
231237
unpack n_isotopes n_gridpoints grid_type hash_bins lookups
232238
num_nucs concs mats nuclide_grid index_grid unionized_energy_array
233239
in #[unsafe] verification (run_event_based_simulation inputs sd)
240+
241+
-- Performs a single vjp pass with an all-unit seed vector. This is
242+
-- unlikely to produce a useful gradient, but does show the overhead
243+
-- of a single jvp invocation.
244+
entry diff n_isotopes n_gridpoints grid_type hash_bins lookups
245+
num_nucs concs mats nuclide_grid index_grid unionized_energy_array =
246+
let (inputs, sd) =
247+
unpack n_isotopes n_gridpoints grid_type hash_bins lookups
248+
num_nucs concs mats nuclide_grid index_grid unionized_energy_array
249+
let diff_res = #[unsafe]
250+
(vjp (run_event_based_simulation inputs)
251+
sd
252+
(replicate inputs.lookups (1,1,1,1,1))).nuclide_grid
253+
in (map (map (.absorbtion_xs)) diff_res,
254+
map (map (.elastic_xs)) diff_res,
255+
map (map (.energy)) diff_res,
256+
map (map (.fission_xs)) diff_res,
257+
map (map (.nu_fission_xs)) diff_res,
258+
map (map (.total_xs)) diff_res,
259+
)

0 commit comments

Comments
 (0)