Skip to content

Commit

Permalink
inferable iip passing
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Jan 6, 2018
1 parent 3605ad3 commit bc1c36f
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 48 deletions.
100 changes: 56 additions & 44 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,39 +28,45 @@ type NoiseProcess{T,N,Tt,T2,T3,ZType,F,F2,inplace,S1,S2,RSWM,RNGType} <: Abstrac
rng::RNGType
reset::Bool
reseed::Bool

function NoiseProcess{iip}(t0,W0,Z0,dist,bridge;
rswm = RSWM(),save_everystep=true,timeseries_steps=1,
rng = Xorshifts.Xoroshiro128Plus(rand(UInt64)),
reset = true, reseed = true) where iip
S₁ = DataStructures.Stack(Tuple{typeof(t0),typeof(W0),typeof(Z0)})
S₂ = ResettableStacks.ResettableStack(
Tuple{typeof(t0),typeof(W0),typeof(Z0)})
if Z0==nothing
Z=nothing
curZ = nothing
dZ = nothing
dZtilde= nothing
dZtmp = nothing
else
Z=[copy(Z0)]
curZ = copy(Z0)
dZ = copy(Z0)
dZtilde= copy(Z0)
dZtmp = copy(Z0)
end
W = [copy(W0)]
N = length((size(W0)..., length(W)))
new{eltype(eltype(W0)),N,typeof(t0),typeof(W0),typeof(dZ),typeof(Z),
typeof(dist),typeof(bridge),
iip,typeof(S₁),typeof(S₂),typeof(rswm),typeof(rng)}(
dist,bridge,[t0],W,W,Z,t0,
copy(W0),curZ,t0,copy(W0),dZ,copy(W0),dZtilde,copy(W0),dZtmp,
S₁,S₂,rswm,0,0,save_everystep,timeseries_steps,0,rng,reset,reseed)
end

end
(W::NoiseProcess)(t) = interpolate!(W,t)
(W::NoiseProcess)(out1,out2,t) = interpolate!(out1,out2,W,t)
adaptive_alg(W::NoiseProcess) = adaptive_alg(W.rswm)

function NoiseProcess(t0,W0,Z0,dist,bridge;iip=DiffEqBase.isinplace(dist,4),
rswm = RSWM(),save_everystep=true,timeseries_steps=1,
rng = Xorshifts.Xoroshiro128Plus(rand(UInt64)),
reset = true, reseed = true)
S₁ = DataStructures.Stack{}(Tuple{typeof(t0),typeof(W0),typeof(Z0)})
S₂ = ResettableStacks.ResettableStack{}(
Tuple{typeof(t0),typeof(W0),typeof(Z0)})
if Z0==nothing
Z=nothing
curZ = nothing
dZ = nothing
dZtilde= nothing
dZtmp = nothing
else
Z=[copy(Z0)]
curZ = copy(Z0)
dZ = copy(Z0)
dZtilde= copy(Z0)
dZtmp = copy(Z0)
end
W = [copy(W0)]
N = length((size(W0)..., length(W)))
NoiseProcess{eltype(eltype(W0)),N,typeof(t0),typeof(W0),typeof(dZ),typeof(Z),
typeof(dist),typeof(bridge),
iip,typeof(S₁),typeof(S₂),typeof(rswm),typeof(rng)}(
dist,bridge,[t0],W,W,Z,t0,
copy(W0),curZ,t0,copy(W0),dZ,copy(W0),dZtilde,copy(W0),dZtmp,
S₁,S₂,rswm,0,0,save_everystep,timeseries_steps,0,rng,reset,reseed)
function NoiseProcess(t0,W0,Z0,dist,bridge;kwargs...)
iip=DiffEqBase.isinplace(dist,4)
NoiseProcess{iip}(t0,W0,Z0,dist,bridge;kwargs...)
end

type NoiseWrapper{T,N,Tt,T2,T3,T4,ZType,inplace} <: AbstractNoiseProcess{T,N,inplace}
Expand Down Expand Up @@ -108,6 +114,25 @@ type NoiseFunction{T,N,wType,zType,Tt,T2,T3,inplace} <: AbstractNoiseProcess{T,N
dW::T2
dZ::T3
reset::Bool

function NoiseFunction{iip}(t0,W,Z=nothing;
noise_prototype=W(t0),reset=true) where iip
curt = t0
dt = t0
curW = copy(noise_prototype)
dW = copy(noise_prototype)
if Z==nothing
curZ = nothing
dZ = nothing
else
curZ = copy(noise_prototype)
dZ = copy(noise_prototype)
end
new{typeof(noise_prototype),ndims(noise_prototype),typeof(W),typeof(Z),
typeof(curt),typeof(curW),typeof(curZ),iip}(W,Z,curt,curW,curZ,
dt,dW,dZ,reset)
end

end

function (W::NoiseFunction)(t)
Expand All @@ -134,22 +159,9 @@ function (W::NoiseFunction)(out1,out2,t)
W.Z != nothing && W.Z(out2,t)
end

function NoiseFunction(t0,W,Z=nothing;iip=DiffEqBase.isinplace(W,2),
noise_prototype=W(t0),reset=true)
curt = t0
dt = t0
curW = copy(noise_prototype)
dW = copy(noise_prototype)
if Z==nothing
curZ = nothing
dZ = nothing
else
curZ = copy(noise_prototype)
dZ = copy(noise_prototype)
end
NoiseFunction{typeof(noise_prototype),ndims(noise_prototype),typeof(W),typeof(Z),
typeof(curt),typeof(curW),typeof(curZ),iip}(W,Z,curt,curW,curZ,
dt,dW,dZ,reset)
function NoiseFunction(t0,W,Z=nothing;kwargs...)
iip=DiffEqBase.isinplace(W,2)
NoiseFunction{iip}(t0,W,Z;kwargs...)
end

type NoiseGrid{T,N,Tt,T2,T3,ZType,inplace} <: AbstractNoiseProcess{T,N,inplace}
Expand Down
8 changes: 4 additions & 4 deletions src/wiener.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ function WHITE_NOISE_BRIDGE(W,W0,Wh,q,h,rng)
return sqrt((1-q)*q*abs(h))*wiener_randn(rng,typeof(W.dW))+q*Wh
end
end
WienerProcess(t0,W0,Z0=nothing;kwargs...) = NoiseProcess(t0,W0,Z0,WHITE_NOISE_DIST,WHITE_NOISE_BRIDGE;kwargs...)
WienerProcess(t0,W0,Z0=nothing;kwargs...) = NoiseProcess{false}(t0,W0,Z0,WHITE_NOISE_DIST,WHITE_NOISE_BRIDGE;kwargs...)

function INPLACE_WHITE_NOISE_DIST(rand_vec,W,dt,rng)
wiener_randn!(rng,rand_vec)
Expand All @@ -47,7 +47,7 @@ function INPLACE_WHITE_NOISE_BRIDGE(rand_vec,W,W0,Wh,q,h,rng)
rand_vec[i] = sqrt((1.-q)*q*abs(h))*rand_vec[i]+q*Wh[i]
end
end
WienerProcess!(t0,W0,Z0=nothing;kwargs...) = NoiseProcess(t0,W0,Z0,INPLACE_WHITE_NOISE_DIST,INPLACE_WHITE_NOISE_BRIDGE;kwargs...)
WienerProcess!(t0,W0,Z0=nothing;kwargs...) = NoiseProcess{true}(t0,W0,Z0,INPLACE_WHITE_NOISE_DIST,INPLACE_WHITE_NOISE_BRIDGE;kwargs...)



Expand All @@ -66,7 +66,7 @@ function REAL_WHITE_NOISE_BRIDGE(W,W0,Wh,q,h,rng)
return sqrt((1-q)*q*abs(h))*randn(rng)+q*Wh
end
end
RealWienerProcess(t0,W0,Z0=nothing;kwargs...) = NoiseProcess(t0,W0,Z0,REAL_WHITE_NOISE_DIST,REAL_WHITE_NOISE_BRIDGE;kwargs...)
RealWienerProcess(t0,W0,Z0=nothing;kwargs...) = NoiseProcess{false}(t0,W0,Z0,REAL_WHITE_NOISE_DIST,REAL_WHITE_NOISE_BRIDGE;kwargs...)

function REAL_INPLACE_WHITE_NOISE_DIST(rand_vec,W,dt,rng)
sqabsdt = sqrt(abs(dt))
Expand All @@ -84,4 +84,4 @@ function REAL_INPLACE_WHITE_NOISE_BRIDGE(rand_vec,W,W0,Wh,q,h,rng)
rand_vec[i] = sqrt((1.-q)*q*abs(h))*rand_vec[i]+q*Wh[i]
end
end
RealWienerProcess!(t0,W0,Z0=nothing;kwargs...) = NoiseProcess(t0,W0,Z0,REAL_INPLACE_WHITE_NOISE_DIST,REAL_INPLACE_WHITE_NOISE_BRIDGE;kwargs...)
RealWienerProcess!(t0,W0,Z0=nothing;kwargs...) = NoiseProcess{true}(t0,W0,Z0,REAL_INPLACE_WHITE_NOISE_DIST,REAL_INPLACE_WHITE_NOISE_BRIDGE;kwargs...)

0 comments on commit bc1c36f

Please sign in to comment.