@@ -13,7 +13,14 @@ include("reduction_identities.jl")
1313A struct that contains the scheduled diagnostics, ancillary data and areas of memory needed
1414to store and accumulate results.
1515"""
16- struct DiagnosticsHandler{SD, V <: Vector{Int} , STORAGE, ACC <: Dict , COUNT}
16+ struct DiagnosticsHandler{
17+ SD,
18+ V <: Vector{Int} ,
19+ STORAGE,
20+ ACC <: Dict ,
21+ COUNT,
22+ BROAD,
23+ }
1724 """ An iterable with the `ScheduledDiagnostic`s that are scheduled."""
1825 scheduled_diagnostics:: SD
1926
@@ -31,6 +38,11 @@ struct DiagnosticsHandler{SD, V <: Vector{Int}, STORAGE, ACC <: Dict, COUNT}
3138 """ Container holding a counter that tracks how many times the given
3239 diagnostic was computed since the last time it was output to disk."""
3340 counters:: COUNT
41+
42+ """ Dictionary that maps a given `ScheduledDiagnostic` to a Base.Broadcast.Broadcasted
43+ object. This is used to allow lazy evaluation of expressions, which can lead to reduced
44+ code verbosity and improved performance."""
45+ broadcasted_expressions:: BROAD
3446end
3547
3648"""
@@ -78,6 +90,10 @@ function DiagnosticsHandler(scheduled_diagnostics, Y, p, t; dt = nothing)
7890 counters = Int[]
7991 scheduled_diagnostics_keys = Int[]
8092
93+ # broadcasted_expressions holds the unevaluated (lazy) expressions, which are later
94+ # materialized into storage
95+ broadcasted_expressions = Dict ()
96+
8197 # NOTE: unique requires isequal and hash to both be implemented. We don't
8298 # really want to do that. So, we roll our own unique. This is O(N^2) but it
8399 # is run only once, so it should be fine.
@@ -119,11 +135,31 @@ function DiagnosticsHandler(scheduled_diagnostics, Y, p, t; dt = nothing)
119135 isa_time_reduction = ! isnothing (diag. reduction_time_func)
120136
121137 # The first time we call compute! we use its return value. All the subsequent times
122- # (in the callbacks), we will write the result in place. We call copy to acquire ownership
123- # of the data in case compute! returned a reference.
124- push! (storage, copy (variable. compute! (nothing , Y, p, t)))
125- push! (counters, 1 )
138+ # (in the callbacks), we will write the result in place.
139+ #
140+ # ClimaDiagnostics supports LazyBroadcast.jl. In this case, the return value of
141+ # `compute!` is a `Base.Broadcast.Broadcasted` and we have to manually materialize
142+ # the result. When using LazyBroadcast.jl, compute! does not need to have the `out`
143+ # argument, so we have to check how many arguments to pass.
144+
145+ three_args = (Y, p, t)
126146
147+ full_args =
148+ hasmethod (variable. compute!, typeof .(three_args)) ? three_args :
149+ (nothing , three_args... )
150+
151+ out_or_broadcasted_expr = variable. compute! (full_args... )
152+
153+ if out_or_broadcasted_expr isa Base. Broadcast. Broadcasted
154+ broadcasted_expressions[diag] = out_or_broadcasted_expr
155+ push! (storage, Base. Broadcast. materialize (out_or_broadcasted_expr))
156+ else
157+ # We call copy to acquire ownership of the data in case compute! returned a
158+ # reference.
159+ push! (storage, copy (out_or_broadcasted_expr))
160+ end
161+
162+ push! (counters, 1 )
127163 # If it is not a reduction, call the output writer as well
128164 if ! isa_time_reduction
129165 interpolate_field! (diag. output_writer, storage[i], diag, Y, p, t)
@@ -148,6 +184,7 @@ function DiagnosticsHandler(scheduled_diagnostics, Y, p, t; dt = nothing)
148184 storage,
149185 accumulators,
150186 counters,
187+ broadcasted_expressions,
151188 )
152189end
153190
@@ -175,6 +212,10 @@ function orchestrate_diagnostics(
175212 active_output = Bool[]
176213 active_sync = Bool[]
177214
215+ # Used in the compute loop
216+ three_args_type =
217+ (typeof (integrator. u), typeof (integrator. p), typeof (integrator. t))
218+
178219 for diag in scheduled_diagnostics
179220 push! (active_compute, diag. compute_schedule_func (integrator))
180221 push! (active_output, diag. output_schedule_func (integrator))
@@ -186,13 +227,52 @@ function orchestrate_diagnostics(
186227 active_compute[diag_index] || continue
187228 diag = scheduled_diagnostics[diag_index]
188229
189- diag. variable. compute! (
230+ diagnostic_handler. counters[diag_index] += 1
231+
232+ # ClimaDiagnostics supports LazyBroadcast.jl. When used, the return value of
233+ # `compute!` is a `Base.Broadcast.Broadcasted`, which we materialize into
234+ # diagnostic_handler.storage[diag_index] in the next for loop. Otherwise there is
235+ # nothing to do here: the four-argument `compute!` has already updated
236+ # diagnostic_handler.storage[diag_index] in place. Here, too, we have to check
237+ # whether `compute!` is the three- or four-argument variant.
238+ #
239+ # Here we use a more verbose way of writing this to avoid working with tuples with
240+ # very complex types (that might increase compilation costs).
241+
242+ if hasmethod (diag. variable. compute!, three_args_type)
243+ out_or_broadcasted_expr =
244+ diag. variable. compute! (integrator. u, integrator. p, integrator. t)
245+ else
246+ out_or_broadcasted_expr = diag. variable. compute! (
247+ diagnostic_handler. storage[diag_index],
248+ integrator. u,
249+ integrator. p,
250+ integrator. t,
251+ )
252+ end
253+
254+ if out_or_broadcasted_expr isa Base. Broadcast. Broadcasted
255+ diagnostic_handler. broadcasted_expressions[diag] =
256+ out_or_broadcasted_expr
257+ end
258+ end
259+
260+ # Evaluate the lazy compute (aka, materialize everything)
261+ for diag_index in 1 : length (scheduled_diagnostics)
262+ active_compute[diag_index] || continue
263+ diag = scheduled_diagnostics[diag_index]
264+ haskey (diagnostic_handler. broadcasted_expressions, diag) || continue
265+
266+ Base. Broadcast. materialize! (
190267 diagnostic_handler. storage[diag_index],
191- integrator. u,
192- integrator. p,
193- integrator. t,
268+ diagnostic_handler. broadcasted_expressions[diag],
194269 )
195- diagnostic_handler. counters[diag_index] += 1
270+ end
271+
272+ # Process possible time reductions (storage[diag_index] has been evaluated by now)
273+ for diag_index in 1 : length (scheduled_diagnostics)
274+ active_compute[diag_index] || continue
275+ diag = scheduled_diagnostics[diag_index]
196276
197277 isa_time_reduction = ! isnothing (diag. reduction_time_func)
198278 if isa_time_reduction
0 commit comments