Skip to content

Commit 93ec15f

Browse files
committed
Add support for unevaluated compute! functions
`LazyBroadcast.jl` provides a way to return an unevaluated function. This is useful in two cases: 1. to reduce the code verbosity needed to handle the `isnothing(out)` case; 2. to allow clustering all the broadcasted expressions in a single place. In turn, 2. is useful because it is the first step in fusing different broadcasted calls. This commit adds support for such functions.
1 parent 20195df commit 93ec15f

File tree

8 files changed

+199
-22
lines changed

8 files changed

+199
-22
lines changed

.buildkite/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
1212
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
1313
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
1414
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
15+
LazyBroadcast = "9dccce8e-a116-406d-9fcc-a88ed4f510c8"
1516
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
1617
NCDatasets = "85f8d34a-cbdd-5861-8df4-14fed0d494ab"
1718
Profile = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"

NEWS.md

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,37 @@
11
# NEWS
2+
3+
v0.2.13
4+
-------
5+
6+
## Features
7+
8+
### Support for `@lazy`
9+
10+
Starting with version `0.2.13`, `ClimaDiagnostics` supports diagnostic variables
11+
specified with un-evaluated expressions (as provided by
12+
[LazyBroadcast.jl](https://github.com/CliMA/LazyBroadcast.jl)).
13+
14+
Instead of
15+
```julia
16+
function compute_ta!(out, state, cache, time)
17+
if isnothing(out)
18+
return state.ta
19+
else
20+
out .= state.ta
21+
end
22+
end
23+
```
24+
25+
You can now write
26+
```julia
27+
import LazyBroadcast: @lazy
28+
29+
function compute_ta!(state, cache, time)
30+
return @lazy @. state.ta
31+
end
32+
```
33+
34+
235
v0.2.12
336
-------
437

@@ -110,6 +143,7 @@ v0.2.4
110143

111144
- Add `EveryCalendarDtSchedule` for schedules with calendar periods.
112145

146+
113147
v0.2.3
114148
-------
115149

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ClimaDiagnostics"
22
uuid = "1ecacbb8-0713-4841-9a07-eb5aa8a2d53f"
33
authors = ["Gabriele Bozzola <[email protected]>"]
4-
version = "0.2.12"
4+
version = "0.2.13"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
@@ -24,6 +24,7 @@ ClimaUtilities = "0.1.22"
2424
Dates = "1"
2525
Documenter = "1"
2626
ExplicitImports = "1.6"
27+
LazyBroadcast = "0.1.4"
2728
JuliaFormatter = "1"
2829
NCDatasets = "0.14"
2930
OrderedCollections = "1.4"
@@ -41,10 +42,11 @@ ClimaTimeSteppers = "595c0a79-7f3d-439a-bc5a-b232dc3bde79"
4142
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
4243
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
4344
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
45+
LazyBroadcast = "9dccce8e-a116-406d-9fcc-a88ed4f510c8"
4446
Profile = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"
4547
ProfileCanvas = "efd6af41-a80b-495e-886c-e51b0c7d77a3"
4648
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
4749
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4850

4951
[targets]
50-
test = ["Aqua", "BenchmarkTools", "ClimaTimeSteppers", "Documenter", "ExplicitImports", "JuliaFormatter", "Profile", "ProfileCanvas", "SafeTestsets", "Test"]
52+
test = ["Aqua", "BenchmarkTools", "ClimaTimeSteppers", "Documenter", "ExplicitImports", "JuliaFormatter", "LazyBroadcast", "Profile", "ProfileCanvas", "SafeTestsets", "Test"]

docs/src/developer_guide.md

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,27 @@ add_diagnostic_variable!(
246246
)
247247
```
248248

249+
When writing compute functions, consider using [`@assign`](@ref) to simplify
250+
your code and, when possible, making your expression lazy with
251+
[LazyBroadcast.jl](https://github.com/CliMA/LazyBroadcast.jl) to further improve
252+
clarity and performance. To do that, add `LazyBroadcast` to your dependencies
253+
and import `@lazy`. The previous example would look like:
254+
255+
```julia
256+
###
257+
# Density (3d)
258+
###
259+
add_diagnostic_variable!(
260+
short_name = "rhoa",
261+
long_name = "Air Density",
262+
standard_name = "air_density",
263+
units = "kg m^-3",
264+
compute! = (out, state, cache, time) -> begin
265+
return @lazy @. state.c.ρ
266+
end,
267+
)
268+
```
269+
249270
It is a good idea to put safeguards in place to ensure that your users will not
250271
be allowed to call diagnostics that do not make sense for the simulation they
251272
are running. If your package has a notion of `Model` that is stored in `p`, you
@@ -269,11 +290,7 @@ function compute_hus!(
269290
time,
270291
moisture_model::T,
271292
) where {T <: Union{EquilMoistModel, NonEquilMoistModel}}
272-
if isnothing(out)
273-
return state.c.ρq_tot ./ state.c.ρ
274-
else
275-
out .= state.c.ρq_tot ./ state.c.ρ
276-
end
293+
@assign out state.c.ρq_tot ./ state.c.ρ
277294
end
278295

279296
add_diagnostic_variable!(

docs/src/index.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,5 +24,6 @@ the Developer guide page.
2424
- Allow users to define arbitrary new diagnostics;
2525
- Trigger diagnostics on arbitrary conditions;
2626
- Save output to HDF5 or NetCDF files, or a dictionary in memory;
27-
27+
- Work with lazy expressions (such as the ones produced by
28+
[LazyBroadcast.jl](https://github.com/CliMA/LazyBroadcast.jl)).
2829

docs/src/user_guide.md

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ var = DiagnosticVariable(;
5858
`compute_ta!` is the key function here. It determines how the variable should be
5959
computed from the `state`, `cache`, and `time` of the simulation. Typically,
6060
these are packaged within an `integrator` object (e.g., `state = integrator.u`
61-
or `integrator.Y`).
61+
or `integrator.Y`). `copy` is needed because we want to return a new area of
62+
memory (without `copy`, `ClimaDiagnostics` might end up modifying `state.ta`).
6263

6364
`compute_ta!` takes another argument, `out`. `out` is an area of memory managed
6465
by `ClimaDiagnostics` that is used to reduce the number of allocations needed
@@ -67,9 +68,18 @@ of memory is allocated and filled with the value (this is when `out` is
6768
`nothing`). All the subsequent times, the same space is overwritten, leading to
6869
much better performance. You should follow this pattern in all your diagnostics.
6970

70-
> Note, in the future, we hope to improve this rather clumsy way to write
71-
> diagnostics. Hopefully, at some point you will just have to write something like
72-
> `state.ta` and not worry about the `out` at all.
71+
`ClimaDiagnostics` supports working with unevaluated expressions represented by
72+
`Base.Broadcast.Broadcasted` objects, such as the ones produced with
73+
[LazyBroadcast.jl](https://github.com/CliMA/LazyBroadcast.jl). Using
74+
`LazyBroadcast.jl`, the snippet above can be rewritten as
75+
```julia
76+
import LazyBroadcast: @lazy
77+
78+
function compute_ta!(state, cache, time)
79+
return @lazy @. state.ta
80+
end
81+
```
82+
Using lazy expressions can lead to improved performance and clearer code.
7383

7484
A `DiagnosticVariable` defines what a variable is and how to compute it, but
7585
does not specify when to compute/output it. For that, we need

src/clima_diagnostics.jl

Lines changed: 90 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,14 @@ include("reduction_identities.jl")
1313
A struct that contains the scheduled diagnostics, ancillary data and areas of memory needed
1414
to store and accumulate results.
1515
"""
16-
struct DiagnosticsHandler{SD, V <: Vector{Int}, STORAGE, ACC <: Dict, COUNT}
16+
struct DiagnosticsHandler{
17+
SD,
18+
V <: Vector{Int},
19+
STORAGE,
20+
ACC <: Dict,
21+
COUNT,
22+
BROAD,
23+
}
1724
"""An iterable with the `ScheduledDiagnostic`s that are scheduled."""
1825
scheduled_diagnostics::SD
1926

@@ -31,6 +38,11 @@ struct DiagnosticsHandler{SD, V <: Vector{Int}, STORAGE, ACC <: Dict, COUNT}
3138
"""Container holding a counter that tracks how many times the given
3239
diagnostic was computed since the last time it was output to disk."""
3340
counters::COUNT
41+
42+
"""Dictionary that maps a given `ScheduledDiagnostic` to a Base.Broadcast.Broadcasted
43+
object. This is used to allow lazy evaluation of expressions, which can lead to reduced
44+
code verbosity and improved performance."""
45+
broadcasted_expressions::BROAD
3446
end
3547

3648
"""
@@ -78,6 +90,10 @@ function DiagnosticsHandler(scheduled_diagnostics, Y, p, t; dt = nothing)
7890
counters = Int[]
7991
scheduled_diagnostics_keys = Int[]
8092

93+
# broadcasted_expressions contains the un-evaluated expressions, which are then moved to
94+
# storage
95+
broadcasted_expressions = Dict()
96+
8197
# NOTE: unique requires isequal and hash to both be implemented. We don't
8298
# really want to do that. So, we roll our own unique. This is O(N^2) but it
8399
# is run only once, so it should be fine.
@@ -119,11 +135,31 @@ function DiagnosticsHandler(scheduled_diagnostics, Y, p, t; dt = nothing)
119135
isa_time_reduction = !isnothing(diag.reduction_time_func)
120136

121137
# The first time we call compute! we use its return value. All the subsequent times
122-
# (in the callbacks), we will write the result in place. We call copy to acquire ownership
123-
# of the data in case compute! returned a reference.
124-
push!(storage, copy(variable.compute!(nothing, Y, p, t)))
125-
push!(counters, 1)
138+
# (in the callbacks), we will write the result in place.
139+
#
140+
# ClimaDiagnostics supports LazyBroadcast.jl. In this case, the return value of
141+
# `compute!` is a `Base.Broadcast.Broadcasted` and we have to manually materialize
142+
# the result. When using LazyBroadcast.jl, compute! does not need to have the `out`
143+
# argument, so we have to check how many arguments to pass.
144+
145+
three_args = (Y, p, t)
126146

147+
full_args =
148+
hasmethod(variable.compute!, typeof.(three_args)) ? three_args :
149+
(nothing, three_args...)
150+
151+
out_or_broadcasted_expr = variable.compute!(full_args...)
152+
153+
if out_or_broadcasted_expr isa Base.Broadcast.Broadcasted
154+
broadcasted_expressions[diag] = out_or_broadcasted_expr
155+
push!(storage, Base.Broadcast.materialize(out_or_broadcasted_expr))
156+
else
157+
# We call copy to acquire ownership of the data in case compute! returned a
158+
# reference.
159+
push!(storage, copy(out_or_broadcasted_expr))
160+
end
161+
162+
push!(counters, 1)
127163
# If it is not a reduction, call the output writer as well
128164
if !isa_time_reduction
129165
interpolate_field!(diag.output_writer, storage[i], diag, Y, p, t)
@@ -148,6 +184,7 @@ function DiagnosticsHandler(scheduled_diagnostics, Y, p, t; dt = nothing)
148184
storage,
149185
accumulators,
150186
counters,
187+
broadcasted_expressions,
151188
)
152189
end
153190

@@ -175,6 +212,10 @@ function orchestrate_diagnostics(
175212
active_output = Bool[]
176213
active_sync = Bool[]
177214

215+
# Used in the compute loop
216+
three_args_type =
217+
(typeof(integrator.u), typeof(integrator.p), typeof(integrator.t))
218+
178219
for diag in scheduled_diagnostics
179220
push!(active_compute, diag.compute_schedule_func(integrator))
180221
push!(active_output, diag.output_schedule_func(integrator))
@@ -186,13 +227,52 @@ function orchestrate_diagnostics(
186227
active_compute[diag_index] || continue
187228
diag = scheduled_diagnostics[diag_index]
188229

189-
diag.variable.compute!(
230+
diagnostic_handler.counters[diag_index] += 1
231+
232+
# ClimaDiagnostics supports LazyBroadcast.jl. When used, the return value of
233+
# `compute!` is a `Base.Broadcast.Broadcasted`. We materialize the output to
234+
# diagnostic_handler.storage[diag_index] in the next for loop. If the output is not a
235+
# `Base.Broadcast.Broadcasted`, we don't have to do anything because compute! will
236+
# already update diagnostic_handler.storage[diag_index]. Here, too, we have to check if
237+
# compute! is the three or four argument variant.
238+
#
239+
# Here we use a more verbose way of writing this to avoid working with tuples with
240+
# very complex types (that might increase compilation costs).
241+
242+
if hasmethod(diag.variable.compute!, three_args_type)
243+
out_or_broadcasted_expr =
244+
diag.variable.compute!(integrator.u, integrator.p, integrator.t)
245+
else
246+
out_or_broadcasted_expr = diag.variable.compute!(
247+
diagnostic_handler.storage[diag_index],
248+
integrator.u,
249+
integrator.p,
250+
integrator.t,
251+
)
252+
end
253+
254+
if out_or_broadcasted_expr isa Base.Broadcast.Broadcasted
255+
diagnostic_handler.broadcasted_expressions[diag] =
256+
out_or_broadcasted_expr
257+
end
258+
end
259+
260+
# Evaluate the lazy compute (aka, materialize everything)
261+
for diag_index in 1:length(scheduled_diagnostics)
262+
active_compute[diag_index] || continue
263+
diag = scheduled_diagnostics[diag_index]
264+
haskey(diagnostic_handler.broadcasted_expressions, diag) || continue
265+
266+
Base.Broadcast.materialize!(
190267
diagnostic_handler.storage[diag_index],
191-
integrator.u,
192-
integrator.p,
193-
integrator.t,
268+
diagnostic_handler.broadcasted_expressions[diag],
194269
)
195-
diagnostic_handler.counters[diag_index] += 1
270+
end
271+
272+
# Process possible time reductions (now we have evaluated storage[diag_index])
273+
for diag_index in 1:length(scheduled_diagnostics)
274+
active_compute[diag_index] || continue
275+
diag = scheduled_diagnostics[diag_index]
196276

197277
isa_time_reduction = !isnothing(diag.reduction_time_func)
198278
if isa_time_reduction

test/integration_test.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ import ClimaComms
1414
ClimaComms.@import_required_backends
1515
end
1616

17+
import LazyBroadcast: @lazy
18+
1719
const context = ClimaComms.context()
1820
ClimaComms.init(context)
1921

@@ -52,12 +54,32 @@ function setup_integrator(output_dir; context, more_compute_diagnostics = 0)
5254
end
5355
end
5456

57+
function compute_my_var_lazy!(out, u, p, t)
58+
return @lazy @. out = u.my_var
59+
end
60+
61+
function compute_my_var_lazy_three_args!(u, p, t)
62+
return @lazy @. u.my_var
63+
end
64+
5565
simple_var = ClimaDiagnostics.DiagnosticVariable(;
5666
compute! = compute_my_var!,
5767
short_name = "YO",
5868
long_name = "YO YO",
5969
)
6070

71+
simple_var_lazy = ClimaDiagnostics.DiagnosticVariable(;
72+
compute! = compute_my_var_lazy!,
73+
short_name = "YO LAZY",
74+
long_name = "YO YO LAZY",
75+
)
76+
77+
simple_var_lazy_three = ClimaDiagnostics.DiagnosticVariable(;
78+
compute! = compute_my_var_lazy_three_args!,
79+
short_name = "YO LAZY THREE",
80+
long_name = "YO YO LAZY THREE",
81+
)
82+
6183
average_diagnostic = ClimaDiagnostics.ScheduledDiagnostic(
6284
variable = simple_var,
6385
output_writer = nc_writer,
@@ -69,6 +91,14 @@ function setup_integrator(output_dir; context, more_compute_diagnostics = 0)
6991
variable = simple_var,
7092
output_writer = nc_writer,
7193
)
94+
inst_diagnostic_lazy = ClimaDiagnostics.ScheduledDiagnostic(
95+
variable = simple_var_lazy,
96+
output_writer = nc_writer,
97+
)
98+
inst_diagnostic_lazy_three = ClimaDiagnostics.ScheduledDiagnostic(
99+
variable = simple_var_lazy_three,
100+
output_writer = nc_writer,
101+
)
72102
inst_every3s_diagnostic = ClimaDiagnostics.ScheduledDiagnostic(
73103
variable = simple_var,
74104
output_writer = nc_writer,
@@ -92,6 +122,8 @@ function setup_integrator(output_dir; context, more_compute_diagnostics = 0)
92122
scheduled_diagnostics = [
93123
average_diagnostic,
94124
inst_diagnostic,
125+
inst_diagnostic_lazy,
126+
inst_diagnostic_lazy_three,
95127
inst_diagnostic_h5,
96128
inst_every3s_diagnostic,
97129
inst_every3s_diagnostic_another,

0 commit comments

Comments
 (0)