@@ -18,6 +18,7 @@ struct DiagnosticsHandler{
1818 STORAGE <: Dict ,
1919 ACC <: Dict ,
2020 COUNT <: Dict ,
21+ BROAD <: Dict ,
2122}
2223 """ An iterable with the `ScheduledDiagnostic`s that are scheduled."""
2324 scheduled_diagnostics:: SD
@@ -34,6 +35,11 @@ struct DiagnosticsHandler{
3435 many times the given diagnostics was computed from the last time it was output to
3536 disk."""
3637 counters:: COUNT
38+
39+ """ Dictionary that maps a given `ScheduledDiagnostic` to a Base.Broadcast.Broadcasted
40+ expression. This is used to allow lazy evaluation of expressions, which can lead to
41+ reduce code verbosity and improved performance."""
42+ broadcasted_expressions:: BROAD
3743end
3844
3945"""
@@ -57,9 +63,11 @@ function DiagnosticsHandler(scheduled_diagnostics, Y, p, t; dt = nothing)
5763
5864 # For diagnostics that perform reductions, the storage is used for the values computed
5965 # at each call. Reductions also save the accumulated value in accumulators.
66+ # broadcasted_expressions maps diagnostics with LazyBroadcast objects.
6067 storage = Dict ()
6168 accumulators = Dict ()
6269 counters = Dict ()
70+ broadcasted_expressions = Dict ()
6371
6472 unique_scheduled_diagnostics = Tuple (unique (scheduled_diagnostics))
6573 if length (unique_scheduled_diagnostics) != length (scheduled_diagnostics)
@@ -90,8 +98,18 @@ function DiagnosticsHandler(scheduled_diagnostics, Y, p, t; dt = nothing)
9098 isa_time_reduction = ! isnothing (diag. reduction_time_func)
9199
92100 # The first time we call compute! we use its return value. All the subsequent times
93- # (in the callbacks), we will write the result in place
94- storage[diag] = variable. compute! (nothing , Y, p, t)
101+ # (in the callbacks), we will write the result in place. ClimaDiagnostics supports
102+ # LazyBroadcast.jl. In this case, the return value of `compute!` is a
103+ # `Base.Broadcast.Broadcasted` and we have to manually materialize the result.
104+ out_or_broadcasted_expr = variable. compute! (nothing , Y, p, t)
105+ if out_or_broadcasted_expr isa Base. Broadcast. Broadcasted
106+ broadcasted_expressions[diag] = out_or_broadcasted_expr
107+ storage[diag] =
108+ Base. Broadcast. materialize (broadcasted_expressions[diag])
109+ else
110+ storage[diag] = out_or_broadcasted_expr
111+ end
112+
95113 counters[diag] = 1
96114
97115 # If it is not a reduction, call the output writer as well
@@ -115,6 +133,7 @@ function DiagnosticsHandler(scheduled_diagnostics, Y, p, t; dt = nothing)
115133 storage,
116134 accumulators,
117135 counters,
136+ broadcasted_expressions,
118137 )
119138end
120139
@@ -153,13 +172,40 @@ function orchestrate_diagnostics(
153172 active_compute[diag_index] || continue
154173 diag = scheduled_diagnostics[diag_index]
155174
156- diag. variable. compute! (
175+ diagnostic_handler. counters[diag] += 1
176+
177+ # ClimaDiagnostics supports LazyBroadcast.jl. When used, the return value of
178+ # `compute!` is a `Base.Broadcast.Broadcasted`. We materialize the output to
179+ # diagnostic_handler.storage[diag] in the next for loop.
180+
181+ out_or_broadcasted_expr = diag. variable. compute! (
157182 diagnostic_handler. storage[diag],
158183 integrator. u,
159184 integrator. p,
160185 integrator. t,
161186 )
162- diagnostic_handler. counters[diag] += 1
187+ if out_or_broadcasted_expr isa Base. Broadcast. Broadcasted
188+ diagnostic_handler. broadcasted_expressions[diag] =
189+ out_or_broadcasted_expr
190+ end
191+ end
192+
193+ # Evaluate the lazy compute (aka, materialize everything)
194+ for diag_index in 1 : length (scheduled_diagnostics)
195+ active_compute[diag_index] || continue
196+ diag = scheduled_diagnostics[diag_index]
197+ haskey (diagnostic_handler. broadcasted_expressions, diag) || continue
198+
199+ Base. Broadcast. materialize! (
200+ diagnostic_handler. storage[diag],
201+ diagnostic_handler. broadcasted_expressions[diag],
202+ )
203+ end
204+
205+ # Process possible time reductions (now we have evaluated storage[diag])
206+ for diag_index in 1 : length (scheduled_diagnostics)
207+ active_compute[diag_index] || continue
208+ diag = scheduled_diagnostics[diag_index]
163209
164210 isa_time_reduction = ! isnothing (diag. reduction_time_func)
165211 if isa_time_reduction
0 commit comments