Skip to content

Commit d237e8f

Browse files
committed
Add support for unevaluated compute! functions
`LazyBroadcast.jl` provides a way to return an unevaluated function. This is useful in two cases: 1. reduce code verbosity to handle the `isnothing(out)` case 2. allow clustering all the broadcasted expressions in a single place In turn, 2. is useful because it is the first step in fusing different broadcasted calls. This commit adds support for such functions.
1 parent 20195df commit d237e8f

File tree

10 files changed

+319
-82
lines changed

10 files changed

+319
-82
lines changed

.buildkite/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
1212
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
1313
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
1414
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
15+
LazyBroadcast = "9dccce8e-a116-406d-9fcc-a88ed4f510c8"
1516
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
1617
NCDatasets = "85f8d34a-cbdd-5861-8df4-14fed0d494ab"
1718
Profile = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"

NEWS.md

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,41 @@
11
# NEWS
2+
3+
v0.2.13
4+
-------
5+
6+
## Features
7+
8+
### Support for `lazy`
9+
10+
Starting version `0.2.13`, `ClimaDiagnostics` supports diagnostic variables
11+
specified with un-evaluated expressions (as provided by
12+
[LazyBroadcast.jl](https://github.com/CliMA/LazyBroadcast.jl)).
13+
14+
Instead of
15+
```julia
16+
function compute_ta!(out, state, cache, time)
17+
if isnothing(out)
18+
return state.ta .- 273.15
19+
else
20+
out .= state.ta .- 273.15
21+
end
22+
end
23+
```
24+
You can now write
25+
```julia
26+
import LazyBroadcast: lazy
27+
28+
function compute_ta(state, cache, time)
29+
return lazy.(state.ta .- 273.15)
30+
end
31+
```
32+
Or, for `Field`s
33+
```julia
34+
function compute_ta(state, cache, time)
35+
return state.ta
36+
end
37+
```
38+
239
v0.2.12
340
-------
441

@@ -110,6 +147,7 @@ v0.2.4
110147

111148
- Add `EveryCalendarDtSchedule` for schedules with calendar periods.
112149

150+
113151
v0.2.3
114152
-------
115153

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ClimaDiagnostics"
22
uuid = "1ecacbb8-0713-4841-9a07-eb5aa8a2d53f"
33
authors = ["Gabriele Bozzola <[email protected]>"]
4-
version = "0.2.12"
4+
version = "0.2.13"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
@@ -24,6 +24,7 @@ ClimaUtilities = "0.1.22"
2424
Dates = "1"
2525
Documenter = "1"
2626
ExplicitImports = "1.6"
27+
LazyBroadcast = "1"
2728
JuliaFormatter = "1"
2829
NCDatasets = "0.14"
2930
OrderedCollections = "1.4"
@@ -41,10 +42,11 @@ ClimaTimeSteppers = "595c0a79-7f3d-439a-bc5a-b232dc3bde79"
4142
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
4243
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
4344
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
45+
LazyBroadcast = "9dccce8e-a116-406d-9fcc-a88ed4f510c8"
4446
Profile = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"
4547
ProfileCanvas = "efd6af41-a80b-495e-886c-e51b0c7d77a3"
4648
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
4749
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4850

4951
[targets]
50-
test = ["Aqua", "BenchmarkTools", "ClimaTimeSteppers", "Documenter", "ExplicitImports", "JuliaFormatter", "Profile", "ProfileCanvas", "SafeTestsets", "Test"]
52+
test = ["Aqua", "BenchmarkTools", "ClimaTimeSteppers", "Documenter", "ExplicitImports", "JuliaFormatter", "LazyBroadcast", "Profile", "ProfileCanvas", "SafeTestsets", "Test"]

docs/src/developer_guide.md

Lines changed: 39 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ Let us see the simplest example to accomplish this
105105
import ClimaDiagnostics: DiagnosticVariable, ScheduledDiagnostic
106106
import ClimaDiagnostics.Writers: DictWriter
107107

108-
myvar = DiagnosticVariable(; compute! = (out, u, p, t) -> u.var1)
108+
myvar = DiagnosticVariable(; compute = (u, p, t) -> u.var1)
109109

110110
myschedule = (integrator) -> maximum(integrator.u.var2) > 10.0
111111

@@ -148,7 +148,7 @@ const ALL_DIAGNOSTICS = Dict{String, DiagnosticVariable}()
148148
standard_name,
149149
units,
150150
description,
151-
compute!)
151+
compute)
152152
153153
154154
Add a new variable to the `ALL_DIAGNOSTICS` dictionary (this function mutates the state of
@@ -173,27 +173,25 @@ Keyword arguments
173173
- `comments`: More verbose explanation of what the variable is, or comments related to how
174174
it is defined or computed.
175175
176-
- `compute!`: Function that compute the diagnostic variable from the state. It has to take
177-
two arguments: the `integrator`, and a pre-allocated area of memory where to
178-
write the result of the computation. It the no pre-allocated area is
179-
available, a new one will be allocated. To avoid extra allocations, this
180-
function should perform the calculation in-place (i.e., using `.=`).
181-
176+
- `compute`: Function that computes the diagnostic variable from the state, cache, and time. The function
177+
should return a `Field` or a `Base.Broadcast.Broadcasted` expression. It should not allocate
178+
new `Field`: if you find yourself using a dot, that is a good indication you should be using
179+
`lazy`.
182180
"""
183181
function add_diagnostic_variable!(;
184182
short_name,
185183
long_name,
186184
standard_name = "",
187185
units,
188186
comments = "",
189-
compute!,
187+
compute,
190188
)
191189
haskey(ALL_DIAGNOSTICS, short_name) && @warn(
192190
"overwriting diagnostic `$short_name` entry containing fields\n" *
193191
"$(map(
194192
field -> "$(getfield(ALL_DIAGNOSTICS[short_name], field))",
195193
# We cannot really compare functions...
196-
filter(field -> field != :compute!, fieldnames(DiagnosticVariable)),
194+
filter(field -> !(field in (:compute!, :compute)), fieldnames(DiagnosticVariable)),
197195
))"
198196
)
199197

@@ -203,7 +201,7 @@ function add_diagnostic_variable!(;
203201
standard_name,
204202
units,
205203
comments,
206-
compute!,
204+
compute,
207205
)
208206

209207
"""
@@ -236,15 +234,30 @@ add_diagnostic_variable!(
236234
long_name = "Air Density",
237235
standard_name = "air_density",
238236
units = "kg m^-3",
239-
compute! = (out, state, cache, time) -> begin
240-
if isnothing(out)
241-
return state.c.ρ
242-
else
243-
out .= state.c.ρ
244-
end
245-
end,
237+
compute = (state, cache, time) -> state.c.ρ,
238+
)
239+
```
240+
241+
When writing compute functions, make them lazy with
242+
[LazyBroadcast.jl](https://github.com/CliMA/LazyBroadcast.jl) to improve
243+
performance and avoid intermediate allocations. To do that, add `LazyBroadcast`
244+
to your dependencies and import `lazy`. A slight variation of the previous
245+
example would look like
246+
247+
```julia
248+
###
249+
# Density (3d)
250+
###
251+
add_diagnostic_variable!(
252+
short_name = "rhoa",
253+
long_name = "Air Density",
254+
standard_name = "air_density",
255+
units = "kg m^-3",
256+
compute = (state, cache, time) -> lazy.(1000 .* state.c.ρ),
246257
)
247258
```
259+
Where we added the `1000` to simulate a more complex expression. If you didn't have
260+
`lazy`, the diagnostic would allocate an intermediate `Field`, severly hurting performance.
248261

249262
It is a good idea to put safeguards in place to ensure that your users will not
250263
be allowed to call diagnostics that do not make sense for the simulation they
@@ -254,26 +267,21 @@ can dispatch over that and return an error. A simple example might be
254267
###
255268
# Specific Humidity
256269
###
257-
compute_hus!(out, state, cache, time) =
258-
compute_hus!(out, state, cache, time, cache.atmos.moisture_model)
270+
compute_hus(state, cache, time) =
271+
compute_hus(state, cache, time, cache.atmos.moisture_model)
259272

260-
compute_hus!(out, state, cache, time) =
261-
compute_hus!(out, state, cache, time, cache.model.moisture_model)
262-
compute_hus!(_, _, _, _, model::T) where {T} =
273+
compute_hus(state, cache, time) =
274+
compute_hus!(state, cache, time, cache.model.moisture_model)
275+
compute_hus(_, _, _, model::T) where {T} =
263276
error("Cannot compute hus with $model")
264277

265-
function compute_hus!(
266-
out,
278+
function compute_hus(
267279
state,
268280
cache,
269281
time,
270282
moisture_model::T,
271283
) where {T <: Union{EquilMoistModel, NonEquilMoistModel}}
272-
if isnothing(out)
273-
return state.c.ρq_tot ./ state.c.ρ
274-
else
275-
out .= state.c.ρq_tot ./ state.c.ρ
276-
end
284+
return lazy.(state.c.ρq_tot ./ state.c.ρ)
277285
end
278286

279287
add_diagnostic_variable!(
@@ -282,7 +290,7 @@ add_diagnostic_variable!(
282290
standard_name = "specific_humidity",
283291
units = "kg kg^-1",
284292
comments = "Mass of all water phases per mass of air",
285-
compute! = compute_hus!,
293+
compute = compute_hus,
286294
)
287295
```
288296
This relies on dispatching over `moisture_model`. If `model` is not in

docs/src/index.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,5 +24,6 @@ the Developer guide page.
2424
- Allow users to define arbitrary new diagnostics;
2525
- Trigger diagnostics on arbitrary conditions;
2626
- Save output to HDF5 or NetCDF files, or a dictionary in memory;
27-
27+
- Work with lazy expressions (such as the ones produced by
28+
[LazyBroadcast.jl](https://github.com/CliMA/LazyBroadcast.jl)).
2829

docs/src/user_guide.md

Lines changed: 58 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,8 @@ Let us see how we would define a `DiagnosticVariable`
3737
```julia
3838
import ClimaDiagnostics: DiagnosticVariable
3939

40-
function compute_ta!(out, state, cache, time)
41-
if isnothing(out)
42-
return state.ta
43-
else
44-
out .= state.ta
45-
end
40+
function compute_ta(state, cache, time)
41+
return state.ta
4642
end
4743

4844
var = DiagnosticVariable(;
@@ -51,25 +47,67 @@ var = DiagnosticVariable(;
5147
standard_name = "air_temperature",
5248
comments = "Measured assuming that the air is in quantum equilibrium with the metaverse",
5349
units = "K",
54-
compute! = compute_ta!
50+
compute = compute_ta
5551
)
5652
```
5753

58-
`compute_ta!` is the key function here. It determines how the variable should be
54+
`compute_ta` is the key function here. It determines how the variable should be
5955
computed from the `state`, `cache`, and `time` of the simulation. Typically,
6056
these are packaged within an `integrator` object (e.g., `state = integrator.u`
61-
or `integrator.Y`).
62-
63-
`compute_ta!` takes another argument, `out`. `out` is an area of memory managed
64-
by `ClimaDiagnostics` that is used to reduce the number of allocations needed
65-
when working with diagnostics. The first time the diagnostic is called, an area
66-
of memory is allocated and filled with the value (this is when `out` is
67-
`nothing`). All the subsequent times, the same space is overwritten, leading to
68-
much better performance. You should follow this pattern in all your diagnostics.
69-
70-
> Note, in the future, we hope to improve this rather clumsy way to write
71-
> diagnostics. Hopefully, at some point you will just have to write something like
72-
> `state.ta` and not worry about the `out` at all.
57+
or `integrator.Y`). `copy` is needed because we want to return a new area of
58+
memory (without `copy`, `ClimaDiagnostics` might end up modifying `state.ta`).
59+
60+
!!! compat "ClimaDiagnostics 0.2.13"
61+
62+
Support for `compute` was introduced in version `0.2.13`. Prior to this
63+
version, the in-place `compute!` had to be provided. In this case, `compute`
64+
has to take an extra argument, `out`. `out` is an area of memory managed by
65+
`ClimaDiagnostics` that is used to reduce the number of allocations needed
66+
when working with diagnostics. The first time the diagnostic is called, an
67+
area of memory is allocated and filled with the value (this is when `out` is
68+
`nothing`). All the subsequent times, the same space is overwritten, leading
69+
to much better performance. You should follow this pattern in all your
70+
diagnostics. This is left to developer to implement, so `compute_ta` would
71+
look like
72+
73+
```julia
74+
function compute_ta!(out, state, cache, time)
75+
if isnothing(out)
76+
return state.ta
77+
else
78+
out .= state.ta
79+
end
80+
end
81+
```
82+
83+
In general, we do not recommend implementing `compute!`, unless required for
84+
backward compatibility.
85+
86+
When the expression is anything more complicated than just returning a `Field`,
87+
it is best to return an unevaluated expression represented by a
88+
`Base.Broadcast.Broadcasted` object (such as the ones produced with
89+
[LazyBroadcast.jl](https://github.com/CliMA/LazyBroadcast.jl)). Consider the
90+
following example where we want to shift the temperature to Celsius:
91+
```julia
92+
function compute_ta(state, cache, time)
93+
return state.ta .- 273.15
94+
end
95+
```
96+
97+
This `compute` function is inefficient because it allocates an entire `Field`
98+
before returning it. Instead, we can return just a recipe on how the diagnostic
99+
should be compute: Using `LazyBroadcast.jl`, the snippet above can be rewritten
100+
as
101+
```julia
102+
import LazyBroadcast: lazy
103+
104+
function compute_ta(state, cache, time)
105+
return lazy.(state.ta .- 273.15)
106+
end
107+
```
108+
The return value of `compute_ta` is a `Base.Broadcast.Broadcasted` object and
109+
`ClimaDiagnostics` knows how to handle it efficiently avoiding the intermediate
110+
allocations.
73111

74112
A `DiagnosticVariable` defines what a variable is and how to compute it, but
75113
does not specify when to compute/output it. For that, we need

0 commit comments

Comments
 (0)