Skip to content

Commit ba077ff

Browse files
committed
Add support for unevaluated compute! functions
`LazyBroadcast.jl` provides a way to return an unevaluated function. This is useful in two cases: 1. reduce code verbosity to handle the `isnothing(out)` case 2. allow clustering all the broadcasted expressions in a single place In turn, 2. is useful because it is the first step in fusing different broadcasted calls. This commit adds support for such functions.
1 parent 47d6f4b commit ba077ff

File tree

8 files changed

+118
-19
lines changed

8 files changed

+118
-19
lines changed

.buildkite/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ ClimaTimeSteppers = "595c0a79-7f3d-439a-bc5a-b232dc3bde79"
1010
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
1111
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
1212
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
13+
LazyBroadcast = "9dccce8e-a116-406d-9fcc-a88ed4f510c8"
1314
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
1415
NCDatasets = "85f8d34a-cbdd-5861-8df4-14fed0d494ab"
1516
Profile = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ v0.2.4
44
-------
55

66
- Add `@assign`.
7+
- Add support for lazy compute functions that return `Base.Broadcast.Broadcasted` objects.
78

89
v0.2.3
910
-------

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ ClimaCore = "0.13.4, 0.14"
2020
ClimaTimeSteppers = "0.7.10"
2121
Dates = "1"
2222
Documenter = "1"
23+
LazyBroadcast = "0.1.3"
2324
JuliaFormatter = "1"
2425
NCDatasets = "0.13.1, 0.14"
2526
Profile = "1"
@@ -35,10 +36,11 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
3536
ClimaTimeSteppers = "595c0a79-7f3d-439a-bc5a-b232dc3bde79"
3637
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
3738
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
39+
LazyBroadcast = "9dccce8e-a116-406d-9fcc-a88ed4f510c8"
3840
Profile = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"
3941
ProfileCanvas = "efd6af41-a80b-495e-886c-e51b0c7d77a3"
4042
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
4143
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4244

4345
[targets]
44-
test = ["Aqua", "BenchmarkTools", "ClimaTimeSteppers", "Documenter", "JuliaFormatter", "Profile", "ProfileCanvas", "SafeTestsets", "Test"]
46+
test = ["Aqua", "BenchmarkTools", "ClimaTimeSteppers", "Documenter", "JuliaFormatter", "LazyBroadcast", "Profile", "ProfileCanvas", "SafeTestsets", "Test"]

docs/src/developer_guide.md

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,27 @@ add_diagnostic_variable!(
201201
)
202202
```
203203

204+
When writing compute functions, consider using [`@assign`](@ref) to simplify
205+
your code and, when possible, making your expression lazy with
206+
[LazyBroadcast.jl](https://github.com/CliMA/LazyBroadcast.jl) to further improve
207+
clarity and performance. To do that, add `LazyBroadcast` to your dependencies
208+
and import `@lazy`. The previous example would look like:
209+
210+
```julia
211+
###
212+
# Density (3d)
213+
###
214+
add_diagnostic_variable!(
215+
short_name = "rhoa",
216+
long_name = "Air Density",
217+
standard_name = "air_density",
218+
units = "kg m^-3",
219+
compute! = (out, state, cache, time) -> begin
220+
return @lazy @. state.c.ρ
221+
end,
222+
)
223+
```
224+
204225
It is a good idea to put safeguards in place to ensure that your users will not
205226
be allowed to call diagnostics that do not make sense for the simulation they
206227
are running. If your package has a notion of `Model` that is stored in `p`, you
@@ -224,11 +245,7 @@ function compute_hus!(
224245
time,
225246
moisture_model::T,
226247
) where {T <: Union{EquilMoistModel, NonEquilMoistModel}}
227-
if isnothing(out)
228-
return state.c.ρq_tot ./ state.c.ρ
229-
else
230-
out .= state.c.ρq_tot ./ state.c.ρ
231-
end
248+
@assign out state.c.ρq_tot ./ state.c.ρ
232249
end
233250

234251
add_diagnostic_variable!(

docs/src/index.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,6 @@ the Developer guide page.
2323
- Allow users to define arbitrary new diagnostics;
2424
- Trigger diagnostics on arbitrary conditions;
2525
- Save output to HDF5 or NetCDF files, or a dictionary in memory;
26-
26+
- Work with lazy expressions (such as the ones produced by
27+
[LazyBroadcast.jl](https://github.com/CliMA/LazyBroadcast.jl)).
2728

docs/src/user_guide.md

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ var = DiagnosticVariable(;
5858
`compute_ta!` is the key function here. It determines how the variable should be
5959
computed from the `state`, `cache`, and `time` of the simulation. Typically,
6060
these are packaged within an `integrator` object (e.g., `state = integrator.u`
61-
or `integrator.Y`).
61+
or `integrator.Y`). `copy` is needed because we want to return a new area of
62+
memory (without `copy`, `ClimaDiagnostics` might end up modifying `state.ta`).
6263

6364
`compute_ta!` takes another argument, `out`. `out` is an area of memory managed
6465
by `ClimaDiagnostics` that is used to reduce the number of allocations needed
@@ -69,9 +70,8 @@ overwritten, leading to much better performance. You should follow this pattern
6970
in all your diagnostics. To help you with that, `ClimaDiagnostics` provides a
7071
macro, `@assign`, that expands to the if switch, so that `compute_ta!` can be
7172
written as
72-
7373
```julia
74-
import ClimaDiagnostics: DiagnosticVariable, @assign
74+
import ClimaDiagnostics: @assign
7575

7676
function compute_ta!(out, state, cache, time)
7777
@assign out copy(state.ta) state.ta
@@ -80,10 +80,18 @@ end
8080
Note that when the second and third arguments of `@assign` are the same, one of
8181
the two can be omitted.
8282

83-
> Note, in the future, we hope to improve this rather clumsy way to write
84-
> diagnostics. Hopefully, at some point you will just have to write something like
85-
> `state.ta` and not worry about the `out` at all. For the time being, we recommend
86-
> using the `@assign` macro to help with possible future transitions.
83+
`ClimaDiagnostics` supports working with unevaluated expressions represented by
84+
`Base.Broadcast.Broadcasted` objects, such as the ones produced with
85+
[LazyBroadcast.jl](https://github.com/CliMA/LazyBroadcast.jl). Using
86+
`LazyBroadcast.jl`, the previous snippet can be rewritten as
87+
```julia
88+
import LazyBroadcast: @lazy
89+
90+
function compute_ta!(out, state, cache, time)
91+
@lazy @. out = state.ta
92+
end
93+
```
94+
Using lazy expressions can lead to improved performance and clearer code.
8795

8896
A `DiagnosticVariable` defines what a variable is and how to compute it, but
8997
does not specify when to compute/output it. For that, we need

src/clima_diagnostics.jl

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ struct DiagnosticsHandler{
1818
STORAGE <: Dict,
1919
ACC <: Dict,
2020
COUNT <: Dict,
21+
BROAD <: Dict,
2122
}
2223
"""An iterable with the `ScheduledDiagnostic`s that are scheduled."""
2324
scheduled_diagnostics::SD
@@ -34,6 +35,11 @@ struct DiagnosticsHandler{
3435
many times the given diagnostics was computed from the last time it was output to
3536
disk."""
3637
counters::COUNT
38+
39+
"""Dictionary that maps a given `ScheduledDiagnostic` to a Base.Broadcast.Broadcasted
40+
object. This is used to allow lazy evaluation of expressions, which can lead to reduced
41+
code verbosity and improved performance."""
42+
broadcasted_expressions::BROAD
3743
end
3844

3945
"""
@@ -57,9 +63,11 @@ function DiagnosticsHandler(scheduled_diagnostics, Y, p, t; dt = nothing)
5763

5864
# For diagnostics that perform reductions, the storage is used for the values computed
5965
# at each call. Reductions also save the accumulated value in accumulators.
66+
# broadcasted_expressions maps diagnostics with LazyBroadcast objects.
6067
storage = Dict()
6168
accumulators = Dict()
6269
counters = Dict()
70+
broadcasted_expressions = Dict()
6371

6472
unique_scheduled_diagnostics = Tuple(unique(scheduled_diagnostics))
6573
if length(unique_scheduled_diagnostics) != length(scheduled_diagnostics)
@@ -90,8 +98,18 @@ function DiagnosticsHandler(scheduled_diagnostics, Y, p, t; dt = nothing)
9098
isa_time_reduction = !isnothing(diag.reduction_time_func)
9199

92100
# The first time we call compute! we use its return value. All the subsequent times
93-
# (in the callbacks), we will write the result in place
94-
storage[diag] = variable.compute!(nothing, Y, p, t)
101+
# (in the callbacks), we will write the result in place. ClimaDiagnostics supports
102+
# LazyBroadcast.jl. In this case, the return value of `compute!` is a
103+
# `Base.Broadcast.Broadcasted` and we have to manually materialize the result.
104+
out_or_broadcasted_expr = variable.compute!(nothing, Y, p, t)
105+
if out_or_broadcasted_expr isa Base.Broadcast.Broadcasted
106+
broadcasted_expressions[diag] = out_or_broadcasted_expr
107+
storage[diag] =
108+
Base.Broadcast.materialize(broadcasted_expressions[diag])
109+
else
110+
storage[diag] = out_or_broadcasted_expr
111+
end
112+
95113
counters[diag] = 1
96114

97115
# If it is not a reduction, call the output writer as well
@@ -115,6 +133,7 @@ function DiagnosticsHandler(scheduled_diagnostics, Y, p, t; dt = nothing)
115133
storage,
116134
accumulators,
117135
counters,
136+
broadcasted_expressions,
118137
)
119138
end
120139

@@ -153,13 +172,42 @@ function orchestrate_diagnostics(
153172
active_compute[diag_index] || continue
154173
diag = scheduled_diagnostics[diag_index]
155174

156-
diag.variable.compute!(
175+
diagnostic_handler.counters[diag] += 1
176+
177+
# ClimaDiagnostics supports LazyBroadcast.jl. When used, the return value of
178+
# `compute!` is a `Base.Broadcast.Broadcasted`. We materialize the output to
179+
# diagnostic_handler.storage[diag] in the next for loop. If the output is not a
180+
# `Base.Broadcast.Broadcasted`, we don't have to do anything because compute! will
181+
# already update diagnostic_handler.storage[diag].
182+
183+
out_or_broadcasted_expr = diag.variable.compute!(
157184
diagnostic_handler.storage[diag],
158185
integrator.u,
159186
integrator.p,
160187
integrator.t,
161188
)
162-
diagnostic_handler.counters[diag] += 1
189+
if out_or_broadcasted_expr isa Base.Broadcast.Broadcasted
190+
diagnostic_handler.broadcasted_expressions[diag] =
191+
out_or_broadcasted_expr
192+
end
193+
end
194+
195+
# Evaluate the lazy compute (aka, materialize everything)
196+
for diag_index in 1:length(scheduled_diagnostics)
197+
active_compute[diag_index] || continue
198+
diag = scheduled_diagnostics[diag_index]
199+
haskey(diagnostic_handler.broadcasted_expressions, diag) || continue
200+
201+
Base.Broadcast.materialize!(
202+
diagnostic_handler.storage[diag],
203+
diagnostic_handler.broadcasted_expressions[diag],
204+
)
205+
end
206+
207+
# Process possible time reductions (now we have evaluated storage[diag])
208+
for diag_index in 1:length(scheduled_diagnostics)
209+
active_compute[diag_index] || continue
210+
diag = scheduled_diagnostics[diag_index]
163211

164212
isa_time_reduction = !isnothing(diag.reduction_time_func)
165213
if isa_time_reduction

test/integration_test.jl

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ import ClimaComms
1212
ClimaComms.@import_required_backends
1313
end
1414

15+
import LazyBroadcast: @lazy
16+
1517
const context = ClimaComms.context()
1618
ClimaComms.init(context)
1719

@@ -41,7 +43,15 @@ function setup_integrator(output_dir; context, more_compute_diagnostics = 0)
4143
)
4244

4345
function compute_my_var!(out, u, p, t)
44-
ClimaDiagnostics.@assign out copy(u.my_var)
46+
if isnothing(out)
47+
return copy(u.my_var)
48+
else
49+
out .= u.my_var
50+
end
51+
end
52+
53+
function compute_my_var_lazy!(out, u, p, t)
54+
return @lazy @. out = u.my_var
4555
end
4656

4757
simple_var = ClimaDiagnostics.DiagnosticVariable(;
@@ -50,6 +60,12 @@ function setup_integrator(output_dir; context, more_compute_diagnostics = 0)
5060
long_name = "YO YO",
5161
)
5262

63+
simple_var_lazy = ClimaDiagnostics.DiagnosticVariable(;
64+
compute! = compute_my_var_lazy!,
65+
short_name = "YO LAZY",
66+
long_name = "YO YO LAZY",
67+
)
68+
5369
average_diagnostic = ClimaDiagnostics.ScheduledDiagnostic(
5470
variable = simple_var,
5571
output_writer = nc_writer,
@@ -61,6 +77,10 @@ function setup_integrator(output_dir; context, more_compute_diagnostics = 0)
6177
variable = simple_var,
6278
output_writer = nc_writer,
6379
)
80+
inst_diagnostic_lazy = ClimaDiagnostics.ScheduledDiagnostic(
81+
variable = simple_var_lazy,
82+
output_writer = nc_writer,
83+
)
6484
inst_every3s_diagnostic = ClimaDiagnostics.ScheduledDiagnostic(
6585
variable = simple_var,
6686
output_writer = nc_writer,
@@ -76,6 +96,7 @@ function setup_integrator(output_dir; context, more_compute_diagnostics = 0)
7696
scheduled_diagnostics = [
7797
average_diagnostic,
7898
inst_diagnostic,
99+
inst_diagnostic_lazy,
79100
inst_diagnostic_h5,
80101
inst_every3s_diagnostic,
81102
]

0 commit comments

Comments
 (0)