From 38cfa616245a1411e75303a66c9484b732a36ce9 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Mon, 28 Jan 2019 08:27:19 -0500 Subject: [PATCH] fix for neural SDE --- src/geometric_bm.jl | 4 ++-- .../noise_process_interface.jl | 14 ++++++++------ src/ornstein_uhlenbeck.jl | 2 +- src/wiener.jl | 19 +++++++++++-------- 4 files changed, 22 insertions(+), 17 deletions(-) diff --git a/src/geometric_bm.jl b/src/geometric_bm.jl index 221438f..c7b0b51 100644 --- a/src/geometric_bm.jl +++ b/src/geometric_bm.jl @@ -6,7 +6,7 @@ end function (p::GeometricBrownianMotion)(W,dt,rng) #dist drift = p.μ-(1/2)*p.σ^2 if typeof(W.dW) <: AbstractArray - rand_val = wiener_randn(rng,size(W.dW)) + rand_val = wiener_randn(rng,W.dW) else rand_val = wiener_randn(rng,typeof(W.dW)) end @@ -27,7 +27,7 @@ https://math.stackexchange.com/questions/412470/conditional-distribution-in-brow =# function gbm_bridge(gbm,W,W0,Wh,q,h,rng) if typeof(W.dW) <: AbstractArray - return gbm.σ*sqrt((1-q)*q*abs(h))*wiener_randn(rng,size(W.dW))+q*Wh + return gbm.σ*sqrt((1-q)*q*abs(h))*wiener_randn(rng,W.dW)+q*Wh else return gbm.σ*sqrt((1-q)*q*abs(h))*wiener_randn(rng,typeof(W.dW))+q*Wh end diff --git a/src/noise_interfaces/noise_process_interface.jl b/src/noise_interfaces/noise_process_interface.jl index 4582fca..70eae16 100644 --- a/src/noise_interfaces/noise_process_interface.jl +++ b/src/noise_interfaces/noise_process_interface.jl @@ -56,12 +56,12 @@ end end elseif adaptive_alg(W)==:RSwM2 || adaptive_alg(W)==:RSwM3 if !isinplace(W) - dttmp = 0.0; W.dW = zero(W.dW) + dttmp = zero(eltype(W.dWtmp)); W.dW = zero(W.dW) if W.Z != nothing W.dZ = zero(W.dZ) end else - dttmp = 0.0; fill!(W.dW,zero(eltype(W.dW))) + dttmp = zero(eltype(W.dWtmp)); fill!(W.dW,zero(eltype(W.dW))) if W.Z != nothing fill!(W.dZ,zero(eltype(W.dZ))) end @@ -148,7 +148,8 @@ end break end end #end while empty - dtleft = W.dt - dttmp + # This is a control variable so do not diff through it + dtleft = DiffEqBase.ODE_DEFAULT_NORM(W.dt - dttmp) if dtleft > W.rswm.discard_length #Stack emptied if isinplace(W) W.dist(W.dWtilde,W,dtleft,W.rng) @@ -239,12 +240,12 @@ end W.dt = dtnew else # RSwM3 if !isinplace(W) - dttmp = 0.0; W.dWtmp = zero(W.dW) + dttmp = zero(eltype(W.dWtmp)); W.dWtmp = zero(W.dW) if W.Z != nothing W.dZtmp = zero(W.dZtmp) end else - dttmp = 0.0; fill!(W.dWtmp,zero(eltype(W.dWtmp))) + dttmp = zero(eltype(W.dWtmp)); fill!(W.dWtmp,zero(eltype(W.dWtmp))) if W.Z!= nothing fill!(W.dZtmp,zero(eltype(W.dZtmp))) end @@ -303,7 +304,8 @@ end W.dZtilde = W.bridge(W,0,W.dZtmp,qK,dtK,W.rng)# - W.curZ end end - cutLength = (1-qK)*dtK + # This is a control variable so do not diff through it + cutLength = DiffEqBase.ODE_DEFAULT_NORM((1-qK)*dtK) if cutLength > W.rswm.discard_length if W.Z == nothing push!(W.S₁,(cutLength,W.dWtmp-W.dWtilde,nothing)) diff --git a/src/ornstein_uhlenbeck.jl b/src/ornstein_uhlenbeck.jl index ffd9900..b676511 100644 --- a/src/ornstein_uhlenbeck.jl +++ b/src/ornstein_uhlenbeck.jl @@ -7,7 +7,7 @@ end # http://www.math.ku.dk/~susanne/StatDiff/Overheads1b.pdf function (p::OrnsteinUhlenbeck)(W,dt,rng) #dist if typeof(W.dW) <: AbstractArray - rand_val = wiener_randn(rng,size(W.dW)) + rand_val = wiener_randn(rng,W.dW) else rand_val = wiener_randn(rng,typeof(W.dW)) end diff --git a/src/wiener.jl b/src/wiener.jl index 7cb8977..27bc1be 100644 --- a/src/wiener.jl +++ b/src/wiener.jl @@ -1,21 +1,26 @@ const one_over_sqrt2 = 1/sqrt(2) @inline wiener_randn(rng::AbstractRNG,::Type{T}) where T = randn(rng,T) -@inline wiener_randn(rng::AbstractRNG,y) = randn(rng,y) +@inline function wiener_randn(rng::AbstractRNG,proto::Array{T}) where T + randn(rng,size(proto)) +end +@inline function wiener_randn(rng::AbstractRNG,proto) + convert(typeof(proto),randn(rng,size(proto))) +end @inline wiener_randn!(rng::AbstractRNG,rand_vec::Array) = randn!(rng,rand_vec) # TODO: This needs an overload for GPUs @inline wiener_randn!(rng::AbstractRNG,rand_vec) = rand_vec .= Base.Broadcast.Broadcasted(randn,()) -@inline wiener_randn(y::AbstractRNG,::Type{Complex{T}}) where T = one_over_sqrt2*(randn(y,T)+im*randn(y,T)) +@inline wiener_randn(y::AbstractRNG,::Type{Complex{T}}) where T = convert(T,one_over_sqrt2)*(randn(y,T)+im*randn(y,T)) @inline function wiener_randn!(y::AbstractRNG,x::AbstractArray{<:Complex{T}}) where T<:Number @inbounds for i in eachindex(x) - x[i] = one_over_sqrt2*(randn(y,T)+im*randn(y,T)) + x[i] = convert(T,one_over_sqrt2)*(randn(y,T)+im*randn(y,T)) end end @inline function WHITE_NOISE_DIST(W,dt,rng) if typeof(W.dW) <: AbstractArray && !(typeof(W.dW) <: SArray) - return @fastmath sqrt(abs(dt))*wiener_randn(rng,size(W.dW)) + return @fastmath sqrt(abs(dt))*wiener_randn(rng,W.dW) else return @fastmath sqrt(abs(dt))*wiener_randn(rng,typeof(W.dW)) end @@ -23,7 +28,7 @@ end function WHITE_NOISE_BRIDGE(W,W0,Wh,q,h,rng) if typeof(W.dW) <: AbstractArray - return @fastmath sqrt((1-q)*q*abs(h))*wiener_randn(rng,size(W.dW))+q*Wh + return @fastmath sqrt((1-q)*q*abs(h))*wiener_randn(rng,W.dW)+q*Wh else return @fastmath sqrt((1-q)*q*abs(h))*wiener_randn(rng,typeof(W.dW))+q*Wh end @@ -43,9 +48,7 @@ function INPLACE_WHITE_NOISE_BRIDGE(rand_vec,W,W0,Wh,q,h,rng) wiener_randn!(rng,rand_vec) #rand_vec .= sqrt((1.-q).*q.*abs(h)).*rand_vec.+q.*Wh sqrtcoeff = @fastmath sqrt((1-q)*q*abs(h)) - @inbounds for i in eachindex(rand_vec) - rand_vec[i] = sqrtcoeff*rand_vec[i]+q*Wh[i] - end + @. rand_vec = sqrtcoeff*rand_vec+q*Wh end WienerProcess!(t0,W0,Z0=nothing;kwargs...) = NoiseProcess{true}(t0,W0,Z0,INPLACE_WHITE_NOISE_DIST,INPLACE_WHITE_NOISE_BRIDGE;kwargs...)