Skip to content

Commit

Permalink
Improve DiagnosticsHandler field concreteness
Browse files Browse the repository at this point in the history
Update src/clima_diagnostics.jl

Co-authored-by: Gabriele Bozzola <[email protected]>

Bump patch version
  • Loading branch information
charleskawczynski committed Sep 30, 2024
1 parent 7b08031 commit 2943dd4
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 35 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ClimaDiagnostics"
uuid = "1ecacbb8-0713-4841-9a07-eb5aa8a2d53f"
authors = ["Gabriele Bozzola <[email protected]>"]
version = "0.2.6"
version = "0.2.7"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down
93 changes: 59 additions & 34 deletions src/clima_diagnostics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,43 @@ include("reduction_identities.jl")
A struct that contains the scheduled diagnostics, ancillary data and areas of memory needed
to store and accumulate results.
"""
struct DiagnosticsHandler{SD, STORAGE <: Dict, ACC <: Dict, COUNT <: Dict}
struct DiagnosticsHandler{SD, V <: Vector{Int}, STORAGE, ACC <: Dict, COUNT}
"""An iterable with the `ScheduledDiagnostic`s that are scheduled."""
scheduled_diagnostics::SD

"""Dictionary that maps a given `ScheduledDiagnostic` to a potentially pre-allocated
"""A Vector containing keys to index into `scheduled_diagnostics`."""
scheduled_diagnostics_keys::V

"""Container holding a potentially pre-allocated
area of memory where to save the newly computed results."""
storage::STORAGE

"""Dictionary that maps a given `ScheduledDiagnostic` to a potentially pre-allocated
"""Container holding a potentially pre-allocated
area of memory where to accumulate results."""
accumulators::ACC

"""Dictionary that maps a given `ScheduledDiagnostic` to the counter that tracks how
many times the given diagnostics was computed from the last time it was output to
disk."""
"""Container holding a counter that tracks how many times the given
diagnostics was computed from the last time it was output to disk."""
counters::COUNT
end

"""
value_types(
data;
value_map = unionall_type,
)
Given `data`, return a type `Union{V...}` where `V` are the `Union` of all found types in
the values of `data`.
"""
function value_types(data)
ret_types = Set([])
for k in eachindex(data)
push!(ret_types, typeof(data[k]))
end
return Union{ret_types...}
end

"""
DiagnosticsHandler(scheduled_diagnostics, Y, p, t; dt = nothing)
Expand All @@ -52,16 +71,18 @@ function DiagnosticsHandler(scheduled_diagnostics, Y, p, t; dt = nothing)

# For diagnostics that perform reductions, the storage is used for the values computed
# at each call. Reductions also save the accumulated value in accumulators.
storage = Dict()
accumulators = Dict()
counters = Dict()
storage = []
# Not all diagnostics need an accumulator, so we put them in a dictionary key-ed over the diagnostic index
accumulators = Dict{Int, Any}()
counters = Int[]
scheduled_diagnostics_keys = Int[]

unique_scheduled_diagnostics = unique(scheduled_diagnostics)
if length(unique_scheduled_diagnostics) != length(scheduled_diagnostics)
@warn "Given list of diagnostics contains duplicates, removing them"
end

for diag in unique_scheduled_diagnostics
for (i, diag) in enumerate(unique_scheduled_diagnostics)
if isnothing(dt)
@warn "dt was not passed to DiagnosticsHandler. No checks will be performed on the frequency of the diagnostics"
else
Expand All @@ -80,33 +101,37 @@ function DiagnosticsHandler(scheduled_diagnostics, Y, p, t; dt = nothing)
)
end
end
push!(scheduled_diagnostics_keys, i)

variable = diag.variable
isa_time_reduction = !isnothing(diag.reduction_time_func)

# The first time we call compute! we use its return value. All the subsequent times
# (in the callbacks), we will write the result in place
storage[diag] = variable.compute!(nothing, Y, p, t)
counters[diag] = 1
push!(storage, variable.compute!(nothing, Y, p, t))
push!(counters, 1)

# If it is not a reduction, call the output writer as well
if !isa_time_reduction
interpolate_field!(diag.output_writer, storage[diag], diag, Y, p, t)
write_field!(diag.output_writer, storage[diag], diag, Y, p, t)
interpolate_field!(diag.output_writer, storage[i], diag, Y, p, t)
write_field!(diag.output_writer, storage[i], diag, Y, p, t)
else
# Add to the accumulator

# We use similar + .= instead of copy because CUDA 5.2 does not supported nested
# wrappers with view(reshape(view)) objects. See discussion in
# https://github.com/CliMA/ClimaAtmos.jl/pull/2579 and
# https://github.com/JuliaGPU/Adapt.jl/issues/21
accumulators[diag] = similar(storage[diag])
accumulators[diag] .= storage[diag]
accumulators[i] = similar(storage[i])
accumulators[i] .= storage[i]
end
end
storage = value_types(storage)[storage...]
accumulators = Dict{Int, value_types(accumulators)}(accumulators...)

return DiagnosticsHandler(
unique_scheduled_diagnostics,
scheduled_diagnostics_keys,
storage,
accumulators,
counters,
Expand All @@ -132,7 +157,7 @@ function orchestrate_diagnostics(
integrator,
diagnostic_handler::DiagnosticsHandler,
)
scheduled_diagnostics = diagnostic_handler.scheduled_diagnostics
(; scheduled_diagnostics, scheduled_diagnostics_keys) = diagnostic_handler
active_compute = Bool[]
active_output = Bool[]
active_sync = Bool[]
Expand All @@ -144,30 +169,30 @@ function orchestrate_diagnostics(
end

# Compute
for diag_index in 1:length(scheduled_diagnostics)
for diag_index in scheduled_diagnostics_keys
active_compute[diag_index] || continue
diag = scheduled_diagnostics[diag_index]

diag.variable.compute!(
diagnostic_handler.storage[diag],
diagnostic_handler.storage[diag_index],
integrator.u,
integrator.p,
integrator.t,
)
diagnostic_handler.counters[diag] += 1
diagnostic_handler.counters[diag_index] += 1

isa_time_reduction = !isnothing(diag.reduction_time_func)
if isa_time_reduction
diagnostic_handler.accumulators[diag] .=
diagnostic_handler.accumulators[diag_index] .=
diag.reduction_time_func.(
diagnostic_handler.accumulators[diag],
diagnostic_handler.storage[diag],
diagnostic_handler.accumulators[diag_index],
diagnostic_handler.storage[diag_index],
)
end
end

# Pre-output (averages/interpolation)
for diag_index in 1:length(scheduled_diagnostics)
for diag_index in scheduled_diagnostics_keys
active_output[diag_index] || continue
diag = scheduled_diagnostics[diag_index]

Expand All @@ -176,20 +201,20 @@ function orchestrate_diagnostics(
# additional copy. If this copy turns out to be too expensive, we can move the if
# statement below.
isnothing(diag.reduction_time_func) || (
diagnostic_handler.storage[diag] .=
diagnostic_handler.accumulators[diag]
diagnostic_handler.storage[diag_index] .=
diagnostic_handler.accumulators[diag_index]
)

# Any operations we have to perform before writing to output? Here is where we would
# divide by N to obtain an arithmetic average
diag.pre_output_hook!(
diagnostic_handler.storage[diag],
diagnostic_handler.counters[diag],
diagnostic_handler.storage[diag_index],
diagnostic_handler.counters[diag_index],
)

interpolate_field!(
diag.output_writer,
diagnostic_handler.storage[diag],
diagnostic_handler.storage[diag_index],
diag,
integrator.u,
integrator.p,
Expand All @@ -198,13 +223,13 @@ function orchestrate_diagnostics(
end

# Save to disk
for diag_index in 1:length(scheduled_diagnostics)
for diag_index in scheduled_diagnostics_keys
active_output[diag_index] || continue
diag = scheduled_diagnostics[diag_index]

write_field!(
diag.output_writer,
diagnostic_handler.storage[diag],
diagnostic_handler.storage[diag_index],
diag,
integrator.u,
integrator.p,
Expand All @@ -213,7 +238,7 @@ function orchestrate_diagnostics(
end

# Post-output clean-up
for diag_index in 1:length(scheduled_diagnostics)
for diag_index in scheduled_diagnostics_keys
diag = scheduled_diagnostics[diag_index]

# First, maybe call sync for the writer. This might happen regardless of
Expand All @@ -229,10 +254,10 @@ function orchestrate_diagnostics(
# identity_of_reduction works by dispatching over operation.
# The function is defined in reduction_identities.jl
identity = identity_of_reduction(diag.reduction_time_func)
fill!(parent(diagnostic_handler.accumulators[diag]), identity)
fill!(parent(diagnostic_handler.accumulators[diag_index]), identity)
end
# Reset counter
diagnostic_handler.counters[diag] = 0
diagnostic_handler.counters[diag_index] = 0
end

return nothing
Expand Down

0 comments on commit 2943dd4

Please sign in to comment.