From bc1c36f68e896aa78e5eb6dd48b4cc75a428b67b Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Fri, 5 Jan 2018 19:21:45 -0800 Subject: [PATCH] inferable iip passing --- src/types.jl | 100 ++++++++++++++++++++++++++++---------------------- src/wiener.jl | 8 ++-- 2 files changed, 60 insertions(+), 48 deletions(-) diff --git a/src/types.jl b/src/types.jl index d57e57e..3c4fbe4 100644 --- a/src/types.jl +++ b/src/types.jl @@ -28,39 +28,45 @@ type NoiseProcess{T,N,Tt,T2,T3,ZType,F,F2,inplace,S1,S2,RSWM,RNGType} <: Abstrac rng::RNGType reset::Bool reseed::Bool + + function NoiseProcess{iip}(t0,W0,Z0,dist,bridge; + rswm = RSWM(),save_everystep=true,timeseries_steps=1, + rng = Xorshifts.Xoroshiro128Plus(rand(UInt64)), + reset = true, reseed = true) where iip + S₁ = DataStructures.Stack(Tuple{typeof(t0),typeof(W0),typeof(Z0)}) + S₂ = ResettableStacks.ResettableStack( + Tuple{typeof(t0),typeof(W0),typeof(Z0)}) + if Z0==nothing + Z=nothing + curZ = nothing + dZ = nothing + dZtilde= nothing + dZtmp = nothing + else + Z=[copy(Z0)] + curZ = copy(Z0) + dZ = copy(Z0) + dZtilde= copy(Z0) + dZtmp = copy(Z0) + end + W = [copy(W0)] + N = length((size(W0)..., length(W))) + new{eltype(eltype(W0)),N,typeof(t0),typeof(W0),typeof(dZ),typeof(Z), + typeof(dist),typeof(bridge), + iip,typeof(S₁),typeof(S₂),typeof(rswm),typeof(rng)}( + dist,bridge,[t0],W,W,Z,t0, + copy(W0),curZ,t0,copy(W0),dZ,copy(W0),dZtilde,copy(W0),dZtmp, + S₁,S₂,rswm,0,0,save_everystep,timeseries_steps,0,rng,reset,reseed) + end + end (W::NoiseProcess)(t) = interpolate!(W,t) (W::NoiseProcess)(out1,out2,t) = interpolate!(out1,out2,W,t) adaptive_alg(W::NoiseProcess) = adaptive_alg(W.rswm) -function NoiseProcess(t0,W0,Z0,dist,bridge;iip=DiffEqBase.isinplace(dist,4), - rswm = RSWM(),save_everystep=true,timeseries_steps=1, - rng = Xorshifts.Xoroshiro128Plus(rand(UInt64)), - reset = true, reseed = true) - S₁ = DataStructures.Stack{}(Tuple{typeof(t0),typeof(W0),typeof(Z0)}) - S₂ = ResettableStacks.ResettableStack{}( - Tuple{typeof(t0),typeof(W0),typeof(Z0)}) - if Z0==nothing - Z=nothing - curZ = nothing - dZ = nothing - dZtilde= nothing - dZtmp = nothing - else - Z=[copy(Z0)] - curZ = copy(Z0) - dZ = copy(Z0) - dZtilde= copy(Z0) - dZtmp = copy(Z0) - end - W = [copy(W0)] - N = length((size(W0)..., length(W))) - NoiseProcess{eltype(eltype(W0)),N,typeof(t0),typeof(W0),typeof(dZ),typeof(Z), - typeof(dist),typeof(bridge), - iip,typeof(S₁),typeof(S₂),typeof(rswm),typeof(rng)}( - dist,bridge,[t0],W,W,Z,t0, - copy(W0),curZ,t0,copy(W0),dZ,copy(W0),dZtilde,copy(W0),dZtmp, - S₁,S₂,rswm,0,0,save_everystep,timeseries_steps,0,rng,reset,reseed) +function NoiseProcess(t0,W0,Z0,dist,bridge;kwargs...) + iip=DiffEqBase.isinplace(dist,4) + NoiseProcess{iip}(t0,W0,Z0,dist,bridge;kwargs...) end type NoiseWrapper{T,N,Tt,T2,T3,T4,ZType,inplace} <: AbstractNoiseProcess{T,N,inplace} @@ -108,6 +114,25 @@ type NoiseFunction{T,N,wType,zType,Tt,T2,T3,inplace} <: AbstractNoiseProcess{T,N dW::T2 dZ::T3 reset::Bool + + function NoiseFunction{iip}(t0,W,Z=nothing; + noise_prototype=W(t0),reset=true) where iip + curt = t0 + dt = t0 + curW = copy(noise_prototype) + dW = copy(noise_prototype) + if Z==nothing + curZ = nothing + dZ = nothing + else + curZ = copy(noise_prototype) + dZ = copy(noise_prototype) + end + new{typeof(noise_prototype),ndims(noise_prototype),typeof(W),typeof(Z), + typeof(curt),typeof(curW),typeof(curZ),iip}(W,Z,curt,curW,curZ, + dt,dW,dZ,reset) + end + end function (W::NoiseFunction)(t) @@ -134,22 +159,9 @@ function (W::NoiseFunction)(out1,out2,t) W.Z != nothing && W.Z(out2,t) end -function NoiseFunction(t0,W,Z=nothing;iip=DiffEqBase.isinplace(W,2), - noise_prototype=W(t0),reset=true) - curt = t0 - dt = t0 - curW = copy(noise_prototype) - dW = copy(noise_prototype) - if Z==nothing - curZ = nothing - dZ = nothing - else - curZ = copy(noise_prototype) - dZ = copy(noise_prototype) - end - NoiseFunction{typeof(noise_prototype),ndims(noise_prototype),typeof(W),typeof(Z), - typeof(curt),typeof(curW),typeof(curZ),iip}(W,Z,curt,curW,curZ, - dt,dW,dZ,reset) +function NoiseFunction(t0,W,Z=nothing;kwargs...) + iip=DiffEqBase.isinplace(W,2) + NoiseFunction{iip}(t0,W,Z;kwargs...) end type NoiseGrid{T,N,Tt,T2,T3,ZType,inplace} <: AbstractNoiseProcess{T,N,inplace} diff --git a/src/wiener.jl b/src/wiener.jl index 92af5ff..84ea63d 100644 --- a/src/wiener.jl +++ b/src/wiener.jl @@ -31,7 +31,7 @@ function WHITE_NOISE_BRIDGE(W,W0,Wh,q,h,rng) return sqrt((1-q)*q*abs(h))*wiener_randn(rng,typeof(W.dW))+q*Wh end end -WienerProcess(t0,W0,Z0=nothing;kwargs...) = NoiseProcess(t0,W0,Z0,WHITE_NOISE_DIST,WHITE_NOISE_BRIDGE;kwargs...) +WienerProcess(t0,W0,Z0=nothing;kwargs...) = NoiseProcess{false}(t0,W0,Z0,WHITE_NOISE_DIST,WHITE_NOISE_BRIDGE;kwargs...) function INPLACE_WHITE_NOISE_DIST(rand_vec,W,dt,rng) wiener_randn!(rng,rand_vec) @@ -47,7 +47,7 @@ function INPLACE_WHITE_NOISE_BRIDGE(rand_vec,W,W0,Wh,q,h,rng) rand_vec[i] = sqrt((1.-q)*q*abs(h))*rand_vec[i]+q*Wh[i] end end -WienerProcess!(t0,W0,Z0=nothing;kwargs...) = NoiseProcess(t0,W0,Z0,INPLACE_WHITE_NOISE_DIST,INPLACE_WHITE_NOISE_BRIDGE;kwargs...) +WienerProcess!(t0,W0,Z0=nothing;kwargs...) = NoiseProcess{true}(t0,W0,Z0,INPLACE_WHITE_NOISE_DIST,INPLACE_WHITE_NOISE_BRIDGE;kwargs...) @@ -66,7 +66,7 @@ function REAL_WHITE_NOISE_BRIDGE(W,W0,Wh,q,h,rng) return sqrt((1-q)*q*abs(h))*randn(rng)+q*Wh end end -RealWienerProcess(t0,W0,Z0=nothing;kwargs...) = NoiseProcess(t0,W0,Z0,REAL_WHITE_NOISE_DIST,REAL_WHITE_NOISE_BRIDGE;kwargs...) +RealWienerProcess(t0,W0,Z0=nothing;kwargs...) = NoiseProcess{false}(t0,W0,Z0,REAL_WHITE_NOISE_DIST,REAL_WHITE_NOISE_BRIDGE;kwargs...) function REAL_INPLACE_WHITE_NOISE_DIST(rand_vec,W,dt,rng) sqabsdt = sqrt(abs(dt)) @@ -84,4 +84,4 @@ function REAL_INPLACE_WHITE_NOISE_BRIDGE(rand_vec,W,W0,Wh,q,h,rng) rand_vec[i] = sqrt((1.-q)*q*abs(h))*rand_vec[i]+q*Wh[i] end end -RealWienerProcess!(t0,W0,Z0=nothing;kwargs...) = NoiseProcess(t0,W0,Z0,REAL_INPLACE_WHITE_NOISE_DIST,REAL_INPLACE_WHITE_NOISE_BRIDGE;kwargs...) +RealWienerProcess!(t0,W0,Z0=nothing;kwargs...) = NoiseProcess{true}(t0,W0,Z0,REAL_INPLACE_WHITE_NOISE_DIST,REAL_INPLACE_WHITE_NOISE_BRIDGE;kwargs...)