Skip to content

Commit

Permalink
refactor: interpolation with higher dim arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
sathvikbhagavan committed Oct 19, 2024
1 parent caf0b76 commit bae3edc
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 42 deletions.
40 changes: 8 additions & 32 deletions src/interpolation_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,29 +173,24 @@ Extrapolation extends the last cubic polynomial on each side.
for a test based on the normalized standard deviation of the difference with respect
to the straight line (see [`looks_linear`](@ref)). Defaults to 1e-2.
"""
struct AkimaInterpolation{uType, tType, IType, bType, cType, dType, T, N} <:
struct AkimaInterpolation{uType, tType, IType, pType, T, N} <:
AbstractInterpolation{T, N}
u::uType
t::tType
I::IType
b::bType
c::cType
d::dType
p::pType
extrapolate::Bool
iguesser::Guesser{tType}
cache_parameters::Bool
linear_lookup::Bool
function AkimaInterpolation(
u, t, I, b, c, d, extrapolate, cache_parameters, assume_linear_t)
u, t, I, p, extrapolate, cache_parameters, assume_linear_t)
linear_lookup = seems_linear(assume_linear_t, t)
N = get_output_dim(u)
new{typeof(u), typeof(t), typeof(I), typeof(b), typeof(c),
typeof(d), eltype(u), N}(u,
new{typeof(u), typeof(t), typeof(I), typeof(p), eltype(u), N}(u,
t,
I,
b,
c,
d,
p,
extrapolate,
Guesser(t),
cache_parameters,
Expand All @@ -208,30 +203,11 @@ function AkimaInterpolation(
u, t; extrapolate = false, cache_parameters = false, assume_linear_t = 1e-2)
u, t = munge_data(u, t)
linear_lookup = seems_linear(assume_linear_t, t)
n = length(t)
dt = diff(t)
m = Array{eltype(u)}(undef, n + 3)
m[3:(end - 2)] = diff(u) ./ dt
m[2] = 2m[3] - m[4]
m[1] = 2m[2] - m[3]
m[end - 1] = 2m[end - 2] - m[end - 3]
m[end] = 2m[end - 1] - m[end - 2]

b = 0.5 .* (m[4:end] .+ m[1:(end - 3)])
dm = abs.(diff(m))
f1 = dm[3:(n + 2)]
f2 = dm[1:n]
f12 = f1 + f2
ind = findall(f12 .> 1e-9 * maximum(f12))
b[ind] = (f1[ind] .* m[ind .+ 1] .+
f2[ind] .* m[ind .+ 2]) ./ f12[ind]
c = (3.0 .* m[3:(end - 2)] .- 2.0 .* b[1:(end - 1)] .- b[2:end]) ./ dt
d = (b[1:(end - 1)] .+ b[2:end] .- 2.0 .* m[3:(end - 2)]) ./ dt .^ 2

p = AkimaParameterCache(u, t)
A = AkimaInterpolation(
u, t, nothing, b, c, d, extrapolate, cache_parameters, linear_lookup)
u, t, nothing, p, extrapolate, cache_parameters, linear_lookup)
I = cumulative_integral(A, cache_parameters)
AkimaInterpolation(u, t, I, b, c, d, extrapolate, cache_parameters, linear_lookup)
AkimaInterpolation(u, t, I, p, extrapolate, cache_parameters, linear_lookup)
end

"""
Expand Down
47 changes: 39 additions & 8 deletions src/interpolation_methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,14 @@ function _interpolate(A::LagrangeInterpolation{<:AbstractVector}, t::Number, igu
end

function _interpolate(
A::LagrangeInterpolation{<:AbstractArray{T, N}}, t::Number, iguess) where {T, N}
A::LagrangeInterpolation{<:AbstractArray}, t::Number, iguess)
idx = get_idx(A, t, iguess)
findRequiredIdxs!(A, t, idx)
ax = axes(A.u)[1:(end - 1)]
if A.t[A.idxs[1]] == t
return A.u[ax..., A.idxs[1]]
end
N1 = zero(A.u[ax..., 1])
N = zero(A.u[ax..., 1])
D = zero(A.t[1])
tmp = D
for i in 1:length(A.idxs)
Expand All @@ -113,15 +113,22 @@ function _interpolate(
end
tmp = inv((t - A.t[A.idxs[i]]) * mult)
D += tmp
@. N1 += (tmp * A.u[ax..., A.idxs[i]])
@. N += (tmp * A.u[ax..., A.idxs[i]])
end
N1 / D
N / D
end

function _interpolate(A::AkimaInterpolation{<:AbstractVector}, t::Number, iguess)
idx = get_idx(A, t, iguess)
wj = t - A.t[idx]
@evalpoly wj A.u[idx] A.b[idx] A.c[idx] A.d[idx]
@evalpoly wj A.u[idx] A.p.b[idx] A.p.c[idx] A.p.d[idx]
end

function _interpolate(A::AkimaInterpolation{<:AbstractArray}, t::Number, iguess)
idx = get_idx(A, t, iguess)
wj = t - A.t[idx]
ax = axes(A.u)[1:(end - 1)]
@. @evalpoly wj A.u[ax..., idx] A.p.b[ax..., idx] A.p.c[ax..., idx] A.p.d[ax..., idx]
end

# ConstantInterpolation Interpolation
Expand All @@ -137,7 +144,7 @@ function _interpolate(A::ConstantInterpolation{<:AbstractVector}, t::Number, igu
end

function _interpolate(
A::ConstantInterpolation{<:AbstractArray{T, N}}, t::Number, iguess) where {T, N}
A::ConstantInterpolation{<:AbstractArray}, t::Number, iguess)
if A.dir === :left
# :left means that value to the left is used for interpolation
idx = get_idx(A, t, iguess; lb = 1, ub_shift = 0)
Expand All @@ -158,7 +165,7 @@ function _interpolate(A::QuadraticSpline{<:AbstractVector}, t::Number, iguess)
end

function _interpolate(
A::QuadraticSpline{<:AbstractArray{T, N}}, t::Number, iguess) where {T, N}
A::QuadraticSpline{<:AbstractArray}, t::Number, iguess)
idx = get_idx(A, t, iguess)
ax = axes(A.u)[1:(end - 1)]
Cᵢ = A.u[ax..., idx]
Expand All @@ -179,7 +186,7 @@ function _interpolate(A::CubicSpline{<:AbstractVector}, t::Number, iguess)
I + C + D
end

function _interpolate(A::CubicSpline{<:AbstractArray{T, N}}, t::Number, iguess) where {T, N}
function _interpolate(A::CubicSpline{<:AbstractArray}, t::Number, iguess)
idx = get_idx(A, t, iguess)
Δt₁ = t - A.t[idx]
Δt₂ = A.t[idx + 1] - t
Expand Down Expand Up @@ -238,6 +245,18 @@ function _interpolate(
out
end

function _interpolate(
A::CubicHermiteSpline{<:AbstractArray}, t::Number, iguess)
idx = get_idx(A, t, iguess)
Δt₀ = t - A.t[idx]
Δt₁ = t - A.t[idx + 1]
ax = axes(A.u)[1:(end - 1)]
out = A.u[ax..., idx] .+ Δt₀ .* A.du[ax..., idx]
c₁, c₂ = get_parameters(A, idx)
out .+= Δt₀^2 .* (c₁ .+ Δt₁ .* c₂)
out
end

# Quintic Hermite Spline
function _interpolate(
A::QuinticHermiteSpline{<:AbstractVector{<:Number}}, t::Number, iguess)
Expand All @@ -249,3 +268,15 @@ function _interpolate(
out += Δt₀^3 * (c₁ + Δt₁ * (c₂ + c₃ * Δt₁))
out
end

function _interpolate(
A::QuinticHermiteSpline{<:AbstractArray}, t::Number, iguess)
idx = get_idx(A, t, iguess)
Δt₀ = t - A.t[idx]
Δt₁ = t - A.t[idx + 1]
ax = axes(A.u)[1:(end - 1)]
out = A.u[ax..., idx] + Δt₀ * (A.du[ax..., idx] + A.ddu[ax..., idx] * Δt₀ / 2)
c₁, c₂, c₃ = get_parameters(A, idx)
out .+= Δt₀^3 .* (c₁ .+ Δt₁ .* (c₂ .+ c₃ .* Δt₁))
out
end
97 changes: 95 additions & 2 deletions src/parameter_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,72 @@ function quadratic_interpolation_parameters(u::AbstractArray{T, N}, t, idx) wher
return l₀, l₁, l₂
end

struct AkimaParameterCache{pType}
b::pType
c::pType
d::pType
end

function AkimaParameterCache(u, t)
b, c, d = akima_interpolation_parameters(u, t)
AkimaParameterCache(b, c, d)
end

function akima_interpolation_parameters(u::AbstractVector, t)
n = length(t)
dt = diff(t)
m = Array{eltype(u)}(undef, n + 3)
m[3:(end - 2)] = diff(u) ./ dt
m[2] = 2m[3] - m[4]
m[1] = 2m[2] - m[3]
m[end - 1] = 2m[end - 2] - m[end - 3]
m[end] = 2m[end - 1] - m[end - 2]
b = 0.5 .* (m[4:end] .+ m[1:(end - 3)])
dm = abs.(diff(m))
f1 = dm[3:(n + 2)]
f2 = dm[1:n]
f12 = f1 + f2
ind = findall(f12 .> 1e-9 * maximum(f12))
b[ind] = (f1[ind] .* m[ind .+ 1] .+
f2[ind] .* m[ind .+ 2]) ./ f12[ind]
c = (3.0 .* m[3:(end - 2)] .- 2.0 .* b[1:(end - 1)] .- b[2:end]) ./ dt
d = (b[1:(end - 1)] .+ b[2:end] .- 2.0 .* m[3:(end - 2)]) ./ dt .^ 2
return b, c, d
end

function akima_interpolation_parameters(u::AbstractArray, t)
n = length(t)
dt = diff(t)
ax = axes(u)[1:(end - 1)]
su = size(u)
m = zeros(eltype(u), su[1:(end - 1)]..., n + 3)
m[ax..., 3:(end - 2)] .= mapslices(
x -> x ./ dt, diff(u, dims = length(su)); dims = length(su))
m[ax..., 2] .= 2m[ax..., 3] .- m[ax..., 4]
m[ax..., 1] .= 2m[ax..., 2] .- m[3]
m[ax..., end - 1] .= 2m[ax..., end - 2] - m[ax..., end - 3]
m[ax..., end] .= 2m[ax..., end - 1] .- m[ax..., end - 2]
b = 0.5 .* (m[ax..., 4:end] .+ m[ax..., 1:(end - 3)])
dm = abs.(diff(m, dims = length(su)))
f1 = dm[ax..., 3:(n + 2)]
f2 = dm[ax..., 1:n]
f12 = f1 .+ f2
ind = findall(f12 .> 1e-9 * maximum(f12))
indi = map(i -> i.I, ind)
b[ind] .= (f1[ind] .*
m[CartesianIndex.(map(i -> (i[1:(end - 1)]..., i[end] + 1), indi))] .+
f2[ind] .*
m[CartesianIndex.(map(i -> (i[1:(end - 1)]..., i[end] + 2), indi))]) ./
f12[ind]
c = mapslices(x -> x ./ dt,
(3.0 .* m[ax..., 3:(end - 2)] .- 2.0 .* b[ax..., 1:(end - 1)] .- b[ax..., 2:end]);
dims = length(su))
d = mapslices(x -> x ./ dt .^ 2,
(b[ax..., 1:(end - 1)] .+ b[ax..., 2:end] .- 2.0 .* m[ax..., 3:(end - 2)]);
dims = length(su))
return b, c, d
end

struct QuadraticSplineParameterCache{pType}
σ::pType
end
Expand Down Expand Up @@ -152,7 +218,19 @@ function CubicHermiteParameterCache(du, u, t, cache_parameters)
end
end

function cubic_hermite_spline_parameters(du, u, t, idx)
function cubic_hermite_spline_parameters(du::AbstractArray, u, t, idx)
ax = axes(u)[1:(end - 1)]
Δt = t[idx + 1] - t[idx]
u₀ = u[ax..., idx]
u₁ = u[ax..., idx + 1]
du₀ = du[ax..., idx]
du₁ = du[ax..., idx + 1]
c₁ = (u₁ - u₀ - du₀ * Δt) / Δt^2
c₂ = (du₁ - du₀ - 2c₁ * Δt) / Δt^2
return c₁, c₂
end

function cubic_hermite_spline_parameters(du::AbstractVector, u, t, idx)
Δt = t[idx + 1] - t[idx]
u₀ = u[idx]
u₁ = u[idx + 1]
Expand Down Expand Up @@ -183,7 +261,7 @@ function QuinticHermiteParameterCache(ddu, du, u, t, cache_parameters)
end
end

function quintic_hermite_spline_parameters(ddu, du, u, t, idx)
function quintic_hermite_spline_parameters(ddu::AbstractVector, du, u, t, idx)
Δt = t[idx + 1] - t[idx]
u₀ = u[idx]
u₁ = u[idx + 1]
Expand All @@ -196,3 +274,18 @@ function quintic_hermite_spline_parameters(ddu, du, u, t, idx)
c₃ = (6u₁ - 6u₀ - 3(du₀ + du₁)Δt + (ddu₁ - ddu₀)Δt^2 / 2) / Δt^5
return c₁, c₂, c₃
end

function quintic_hermite_spline_parameters(ddu::AbstractArray, du, u, t, idx)
ax = axes(ddu)[1:(end - 1)]
Δt = t[idx + 1] - t[idx]
u₀ = u[ax..., idx]
u₁ = u[ax..., idx + 1]
du₀ = du[ax..., idx]
du₁ = du[ax..., idx + 1]
ddu₀ = ddu[ax..., idx]
ddu₁ = ddu[ax..., idx + 1]
c₁ = (u₁ - u₀ - du₀ * Δt - ddu₀ * Δt^2 / 2) / Δt^3
c₂ = (3u₀ - 3u₁ + 2(du₀ + du₁ / 2)Δt + ddu₀ * Δt^2 / 2) / Δt^4
c₃ = (6u₁ - 6u₀ - 3(du₀ + du₁)Δt + (ddu₁ - ddu₀)Δt^2 / 2) / Δt^5
return c₁, c₂, c₃
end

0 comments on commit bae3edc

Please sign in to comment.