Skip to content

Commit c6b7525

Browse files
committed
Add support for unevaluated compute! functions
`LazyBroadcast.jl` provides a way to return an unevaluated function. This is useful in two cases: 1. reduce code verbosity to handle the `isnothing(out)` case 2. allow clustering all the broadcasted expressions in a single place In turn, 2. is useful because it is the first step in fusing different broadcasted calls. This commit adds support for such functions.
1 parent 19005a0 commit c6b7525

File tree

4 files changed

+76
-6
lines changed

4 files changed

+76
-6
lines changed

.buildkite/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ ClimaTimeSteppers = "595c0a79-7f3d-439a-bc5a-b232dc3bde79"
1010
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
1111
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
1212
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
13+
LazyBroadcast = "9dccce8e-a116-406d-9fcc-a88ed4f510c8"
1314
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
1415
NCDatasets = "85f8d34a-cbdd-5861-8df4-14fed0d494ab"
1516
Profile = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ ClimaCore = "0.13.4, 0.14"
2020
ClimaTimeSteppers = "0.7.10"
2121
Dates = "1"
2222
Documenter = "1"
23+
LazyBroadcast = "0.1.3"
2324
JuliaFormatter = "1"
2425
NCDatasets = "0.13.1, 0.14"
2526
Profile = "1"
@@ -35,10 +36,11 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
3536
ClimaTimeSteppers = "595c0a79-7f3d-439a-bc5a-b232dc3bde79"
3637
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
3738
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
39+
LazyBroadcast = "9dccce8e-a116-406d-9fcc-a88ed4f510c8"
3840
Profile = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"
3941
ProfileCanvas = "efd6af41-a80b-495e-886c-e51b0c7d77a3"
4042
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
4143
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4244

4345
[targets]
44-
test = ["Aqua", "BenchmarkTools", "ClimaTimeSteppers", "Documenter", "JuliaFormatter", "Profile", "ProfileCanvas", "SafeTestsets", "Test"]
46+
test = ["Aqua", "BenchmarkTools", "ClimaTimeSteppers", "Documenter", "JuliaFormatter", "LazyBroadcast", "Profile", "ProfileCanvas", "SafeTestsets", "Test"]

src/clima_diagnostics.jl

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ struct DiagnosticsHandler{
1818
STORAGE <: Dict,
1919
ACC <: Dict,
2020
COUNT <: Dict,
21+
BROAD <: Dict,
2122
}
2223
"""An iterable with the `ScheduledDiagnostic`s that are scheduled."""
2324
scheduled_diagnostics::SD
@@ -34,6 +35,11 @@ struct DiagnosticsHandler{
3435
many times the given diagnostics was computed from the last time it was output to
3536
disk."""
3637
counters::COUNT
38+
39+
"""Dictionary that maps a given `ScheduledDiagnostic` to a Base.Broadcast.Broadcasted
40+
expression. This is used to allow lazy evaluation of expressions, which can lead to
41+
reduce code verbosity and improved performance."""
42+
broadcasted_expressions::BROAD
3743
end
3844

3945
"""
@@ -57,9 +63,11 @@ function DiagnosticsHandler(scheduled_diagnostics, Y, p, t; dt = nothing)
5763

5864
# For diagnostics that perform reductions, the storage is used for the values computed
5965
# at each call. Reductions also save the accumulated value in accumulators.
66+
# broadcasted_expressions maps diagnostics with LazyBroadcast objects.
6067
storage = Dict()
6168
accumulators = Dict()
6269
counters = Dict()
70+
broadcasted_expressions = Dict()
6371

6472
unique_scheduled_diagnostics = Tuple(unique(scheduled_diagnostics))
6573
if length(unique_scheduled_diagnostics) != length(scheduled_diagnostics)
@@ -90,8 +98,18 @@ function DiagnosticsHandler(scheduled_diagnostics, Y, p, t; dt = nothing)
9098
isa_time_reduction = !isnothing(diag.reduction_time_func)
9199

92100
# The first time we call compute! we use its return value. All the subsequent times
93-
# (in the callbacks), we will write the result in place
94-
storage[diag] = variable.compute!(nothing, Y, p, t)
101+
# (in the callbacks), we will write the result in place. ClimaDiagnostics supports
102+
# LazyBroadcast.jl. In this case, the return value of `compute!` is a
103+
# `Base.Broadcast.Broadcasted` and we have to manually materialize the result.
104+
out_or_broadcasted_expr = variable.compute!(nothing, Y, p, t)
105+
if out_or_broadcasted_expr isa Base.Broadcast.Broadcasted
106+
broadcasted_expressions[diag] = out_or_broadcasted_expr
107+
storage[diag] =
108+
Base.Broadcast.materialize(broadcasted_expressions[diag])
109+
else
110+
storage[diag] = out_or_broadcasted_expr
111+
end
112+
95113
counters[diag] = 1
96114

97115
# If it is not a reduction, call the output writer as well
@@ -115,6 +133,7 @@ function DiagnosticsHandler(scheduled_diagnostics, Y, p, t; dt = nothing)
115133
storage,
116134
accumulators,
117135
counters,
136+
broadcasted_expressions,
118137
)
119138
end
120139

@@ -153,13 +172,40 @@ function orchestrate_diagnostics(
153172
active_compute[diag_index] || continue
154173
diag = scheduled_diagnostics[diag_index]
155174

156-
diag.variable.compute!(
175+
diagnostic_handler.counters[diag] += 1
176+
177+
# ClimaDiagnostics supports LazyBroadcast.jl. When used, the return value of
178+
# `compute!` is a `Base.Broadcast.Broadcasted`. We materialize the output to
179+
# diagnostic_handler.storage[diag] in the next for loop.
180+
181+
out_or_broadcasted_expr = diag.variable.compute!(
157182
diagnostic_handler.storage[diag],
158183
integrator.u,
159184
integrator.p,
160185
integrator.t,
161186
)
162-
diagnostic_handler.counters[diag] += 1
187+
if out_or_broadcasted_expr isa Base.Broadcast.Broadcasted
188+
diagnostic_handler.broadcasted_expressions[diag] =
189+
out_or_broadcasted_expr
190+
end
191+
end
192+
193+
# Evaluate the lazy compute (aka, materialize everything)
194+
for diag_index in 1:length(scheduled_diagnostics)
195+
active_compute[diag_index] || continue
196+
diag = scheduled_diagnostics[diag_index]
197+
haskey(diagnostic_handler.broadcasted_expressions, diag) || continue
198+
199+
Base.Broadcast.materialize!(
200+
diagnostic_handler.storage[diag],
201+
diagnostic_handler.broadcasted_expressions[diag],
202+
)
203+
end
204+
205+
# Process possible time reductions (now we have evaluated storage[diag])
206+
for diag_index in 1:length(scheduled_diagnostics)
207+
active_compute[diag_index] || continue
208+
diag = scheduled_diagnostics[diag_index]
163209

164210
isa_time_reduction = !isnothing(diag.reduction_time_func)
165211
if isa_time_reduction

test/integration_test.jl

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ import ClimaComms
1212
ClimaComms.@import_required_backends
1313
end
1414

15+
import LazyBroadcast: @lazy
16+
1517
const context = ClimaComms.context()
1618
ClimaComms.init(context)
1719

@@ -41,7 +43,15 @@ function setup_integrator(output_dir; context, more_compute_diagnostics = 0)
4143
)
4244

4345
function compute_my_var!(out, u, p, t)
44-
ClimaDiagnostics.@assign out copy(u.my_var)
46+
if isnothing(out)
47+
return copy(u.my_var)
48+
else
49+
out .= u.my_var
50+
end
51+
end
52+
53+
function compute_my_var_lazy!(out, u, p, t)
54+
return @lazy @. out = copy(u.my_var)
4555
end
4656

4757
simple_var = ClimaDiagnostics.DiagnosticVariable(;
@@ -50,6 +60,12 @@ function setup_integrator(output_dir; context, more_compute_diagnostics = 0)
5060
long_name = "YO YO",
5161
)
5262

63+
simple_var_lazy = ClimaDiagnostics.DiagnosticVariable(;
64+
compute! = compute_my_var_lazy!,
65+
short_name = "YO LAZY",
66+
long_name = "YO YO LAZY",
67+
)
68+
5369
average_diagnostic = ClimaDiagnostics.ScheduledDiagnostic(
5470
variable = simple_var,
5571
output_writer = nc_writer,
@@ -61,6 +77,10 @@ function setup_integrator(output_dir; context, more_compute_diagnostics = 0)
6177
variable = simple_var,
6278
output_writer = nc_writer,
6379
)
80+
inst_diagnostic_lazy = ClimaDiagnostics.ScheduledDiagnostic(
81+
variable = simple_var_lazy,
82+
output_writer = nc_writer,
83+
)
6484
inst_every3s_diagnostic = ClimaDiagnostics.ScheduledDiagnostic(
6585
variable = simple_var,
6686
output_writer = nc_writer,
@@ -76,6 +96,7 @@ function setup_integrator(output_dir; context, more_compute_diagnostics = 0)
7696
scheduled_diagnostics = [
7797
average_diagnostic,
7898
inst_diagnostic,
99+
inst_diagnostic_lazy,
79100
inst_diagnostic_h5,
80101
inst_every3s_diagnostic,
81102
]

0 commit comments

Comments
 (0)