@@ -18,6 +18,7 @@ struct DiagnosticsHandler{
1818 STORAGE <: Dict ,
1919 ACC <: Dict ,
2020 COUNT <: Dict ,
21+ BROAD <: Dict ,
2122}
2223 """ An iterable with the `ScheduledDiagnostic`s that are scheduled."""
2324 scheduled_diagnostics:: SD
@@ -34,6 +35,11 @@ struct DiagnosticsHandler{
3435 many times the given diagnostic was computed since the last time it was output to
3536 disk."""
3637 counters:: COUNT
38+
39+ """ Dictionary that maps a given `ScheduledDiagnostic` to a `Base.Broadcast.Broadcasted`
40+ object. This is used to allow lazy evaluation of expressions, which can lead to reduced
41+ code verbosity and improved performance."""
42+ broadcasted_expressions:: BROAD
3743end
3844
3945"""
@@ -57,9 +63,11 @@ function DiagnosticsHandler(scheduled_diagnostics, Y, p, t; dt = nothing)
5763
5864 # For diagnostics that perform reductions, the storage is used for the values computed
5965 # at each call. Reductions also save the accumulated value in accumulators.
66+ # broadcasted_expressions maps each diagnostic to its lazy `Broadcasted` expression.
6067 storage = Dict ()
6168 accumulators = Dict ()
6269 counters = Dict ()
70+ broadcasted_expressions = Dict ()
6371
6472 unique_scheduled_diagnostics = Tuple (unique (scheduled_diagnostics))
6573 if length (unique_scheduled_diagnostics) != length (scheduled_diagnostics)
@@ -90,8 +98,18 @@ function DiagnosticsHandler(scheduled_diagnostics, Y, p, t; dt = nothing)
9098 isa_time_reduction = ! isnothing (diag. reduction_time_func)
9199
92100 # The first time we call compute! we use its return value. All the subsequent times
93- # (in the callbacks), we will write the result in place
94- storage[diag] = variable. compute! (nothing , Y, p, t)
101+ # (in the callbacks), we will write the result in place. ClimaDiagnostics supports
102+ # LazyBroadcast.jl. When it is used, the return value of `compute!` is a
103+ # `Base.Broadcast.Broadcasted` and we have to manually materialize the result.
104+ out_or_broadcasted_expr = variable. compute! (nothing , Y, p, t)
105+ if out_or_broadcasted_expr isa Base. Broadcast. Broadcasted
106+ broadcasted_expressions[diag] = out_or_broadcasted_expr
107+ storage[diag] =
108+ Base. Broadcast. materialize (broadcasted_expressions[diag])
109+ else
110+ storage[diag] = out_or_broadcasted_expr
111+ end
112+
95113 counters[diag] = 1
96114
97115 # If it is not a reduction, call the output writer as well
@@ -115,6 +133,7 @@ function DiagnosticsHandler(scheduled_diagnostics, Y, p, t; dt = nothing)
115133 storage,
116134 accumulators,
117135 counters,
136+ broadcasted_expressions,
118137 )
119138end
120139
@@ -153,13 +172,42 @@ function orchestrate_diagnostics(
153172 active_compute[diag_index] || continue
154173 diag = scheduled_diagnostics[diag_index]
155174
156- diag. variable. compute! (
175+ diagnostic_handler. counters[diag] += 1
176+
177+ # ClimaDiagnostics supports LazyBroadcast.jl. When used, the return value of
178+ # `compute!` is a `Base.Broadcast.Broadcasted`, which we materialize into
179+ # diagnostic_handler.storage[diag] in the next for loop. If the return value is not a
180+ # `Base.Broadcast.Broadcasted`, there is nothing more to do because `compute!` has
181+ # already updated diagnostic_handler.storage[diag] in place.
182+
183+ out_or_broadcasted_expr = diag. variable. compute! (
157184 diagnostic_handler. storage[diag],
158185 integrator. u,
159186 integrator. p,
160187 integrator. t,
161188 )
162- diagnostic_handler. counters[diag] += 1
189+ if out_or_broadcasted_expr isa Base. Broadcast. Broadcasted
190+ diagnostic_handler. broadcasted_expressions[diag] =
191+ out_or_broadcasted_expr
192+ end
193+ end
194+
195+ # Evaluate the lazy expressions (i.e., materialize them into storage)
196+ for diag_index in 1 : length (scheduled_diagnostics)
197+ active_compute[diag_index] || continue
198+ diag = scheduled_diagnostics[diag_index]
199+ haskey (diagnostic_handler. broadcasted_expressions, diag) || continue
200+
201+ Base. Broadcast. materialize! (
202+ diagnostic_handler. storage[diag],
203+ diagnostic_handler. broadcasted_expressions[diag],
204+ )
205+ end
206+
207+ # Process possible time reductions (now we have evaluated storage[diag])
208+ for diag_index in 1 : length (scheduled_diagnostics)
209+ active_compute[diag_index] || continue
210+ diag = scheduled_diagnostics[diag_index]
163211
164212 isa_time_reduction = ! isnothing (diag. reduction_time_func)
165213 if isa_time_reduction
0 commit comments