Skip to content

Commit

Permalink
fix for neural SDE
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Jan 28, 2019
1 parent 621d17d commit 38cfa61
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 17 deletions.
4 changes: 2 additions & 2 deletions src/geometric_bm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ end
function (p::GeometricBrownianMotion)(W,dt,rng) #dist
drift = p.μ-(1/2)*p.σ^2
if typeof(W.dW) <: AbstractArray
rand_val = wiener_randn(rng,size(W.dW))
rand_val = wiener_randn(rng,W.dW)
else
rand_val = wiener_randn(rng,typeof(W.dW))
end
Expand All @@ -27,7 +27,7 @@ https://math.stackexchange.com/questions/412470/conditional-distribution-in-brow
=#
function gbm_bridge(gbm,W,W0,Wh,q,h,rng)
if typeof(W.dW) <: AbstractArray
return gbm.σ*sqrt((1-q)*q*abs(h))*wiener_randn(rng,size(W.dW))+q*Wh
return gbm.σ*sqrt((1-q)*q*abs(h))*wiener_randn(rng,W.dW)+q*Wh
else
return gbm.σ*sqrt((1-q)*q*abs(h))*wiener_randn(rng,typeof(W.dW))+q*Wh
end
Expand Down
14 changes: 8 additions & 6 deletions src/noise_interfaces/noise_process_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,12 @@ end
end
elseif adaptive_alg(W)==:RSwM2 || adaptive_alg(W)==:RSwM3
if !isinplace(W)
dttmp = 0.0; W.dW = zero(W.dW)
dttmp = zero(eltype(W.dWtmp)); W.dW = zero(W.dW)
if W.Z != nothing
W.dZ = zero(W.dZ)
end
else
dttmp = 0.0; fill!(W.dW,zero(eltype(W.dW)))
dttmp = zero(eltype(W.dWtmp)); fill!(W.dW,zero(eltype(W.dW)))
if W.Z != nothing
fill!(W.dZ,zero(eltype(W.dZ)))
end
Expand Down Expand Up @@ -148,7 +148,8 @@ end
break
end
end #end while empty
dtleft = W.dt - dttmp
# This is a control variable so do not diff through it
dtleft = DiffEqBase.ODE_DEFAULT_NORM(W.dt - dttmp)
if dtleft > W.rswm.discard_length #Stack emptied
if isinplace(W)
W.dist(W.dWtilde,W,dtleft,W.rng)
Expand Down Expand Up @@ -239,12 +240,12 @@ end
W.dt = dtnew
else # RSwM3
if !isinplace(W)
dttmp = 0.0; W.dWtmp = zero(W.dW)
dttmp = zero(eltype(W.dWtmp)); W.dWtmp = zero(W.dW)
if W.Z != nothing
W.dZtmp = zero(W.dZtmp)
end
else
dttmp = 0.0; fill!(W.dWtmp,zero(eltype(W.dWtmp)))
dttmp = zero(eltype(W.dWtmp)); fill!(W.dWtmp,zero(eltype(W.dWtmp)))

This comment has been minimized.

Copy link
@devmotion

devmotion Jan 29, 2019

Member

This change breaks StochasticDiffEq (see https://travis-ci.org/JuliaDiffEq/StochasticDiffEq.jl/jobs/485690976), I guess it should be dttmp = zero(W.dt) here and in similar lines in this commit

This comment has been minimized.

Copy link
@ChrisRackauckas

ChrisRackauckas Jan 29, 2019

Author Member

I'm looking into it right now. The time-based numbers shouldn't end up complex at all which is what the test failure looks to be.

This comment has been minimized.

Copy link
@devmotion

devmotion Jan 29, 2019

Member

Yes, my second hypothesis was a problem similar to SciML/StochasticDiffEq.jl#126

This comment has been minimized.

Copy link
@devmotion

devmotion Jan 29, 2019

Member

I just applied the change in the first comment and complex tests in StochasticDiffEq pass again on my computer it seems

This comment has been minimized.

Copy link
@ChrisRackauckas

ChrisRackauckas Jan 29, 2019

Author Member

ahh yes, because this is pulling the element type from state instead of time which gives it the wrong typing. Can you file a PR for it? Thanks.

This comment has been minimized.

Copy link
@devmotion

devmotion Jan 29, 2019

Member

Yes, I'm on it 👍

if W.Z!= nothing
fill!(W.dZtmp,zero(eltype(W.dZtmp)))
end
Expand Down Expand Up @@ -303,7 +304,8 @@ end
W.dZtilde = W.bridge(W,0,W.dZtmp,qK,dtK,W.rng)# - W.curZ
end
end
cutLength = (1-qK)*dtK
# This is a control variable so do not diff through it
cutLength = DiffEqBase.ODE_DEFAULT_NORM((1-qK)*dtK)
if cutLength > W.rswm.discard_length
if W.Z == nothing
push!(W.S₁,(cutLength,W.dWtmp-W.dWtilde,nothing))
Expand Down
2 changes: 1 addition & 1 deletion src/ornstein_uhlenbeck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ end
# http://www.math.ku.dk/~susanne/StatDiff/Overheads1b.pdf
function (p::OrnsteinUhlenbeck)(W,dt,rng) #dist
if typeof(W.dW) <: AbstractArray
rand_val = wiener_randn(rng,size(W.dW))
rand_val = wiener_randn(rng,W.dW)
else
rand_val = wiener_randn(rng,typeof(W.dW))
end
Expand Down
19 changes: 11 additions & 8 deletions src/wiener.jl
Original file line number Diff line number Diff line change
@@ -1,29 +1,34 @@
const one_over_sqrt2 = 1/sqrt(2)
@inline wiener_randn(rng::AbstractRNG,::Type{T}) where T = randn(rng,T)
@inline wiener_randn(rng::AbstractRNG,y) = randn(rng,y)
@inline function wiener_randn(rng::AbstractRNG,proto::Array{T}) where T
randn(rng,size(proto))
end
@inline function wiener_randn(rng::AbstractRNG,proto)
convert(typeof(proto),randn(rng,size(proto)))
end
@inline wiener_randn!(rng::AbstractRNG,rand_vec::Array) = randn!(rng,rand_vec)

# TODO: This needs an overload for GPUs
@inline wiener_randn!(rng::AbstractRNG,rand_vec) = rand_vec .= Base.Broadcast.Broadcasted(randn,())
@inline wiener_randn(y::AbstractRNG,::Type{Complex{T}}) where T = one_over_sqrt2*(randn(y,T)+im*randn(y,T))
@inline wiener_randn(y::AbstractRNG,::Type{Complex{T}}) where T = convert(T,one_over_sqrt2)*(randn(y,T)+im*randn(y,T))

@inline function wiener_randn!(y::AbstractRNG,x::AbstractArray{<:Complex{T}}) where T<:Number
@inbounds for i in eachindex(x)
x[i] = one_over_sqrt2*(randn(y,T)+im*randn(y,T))
x[i] = convert(T,one_over_sqrt2)*(randn(y,T)+im*randn(y,T))
end
end

@inline function WHITE_NOISE_DIST(W,dt,rng)
if typeof(W.dW) <: AbstractArray && !(typeof(W.dW) <: SArray)
return @fastmath sqrt(abs(dt))*wiener_randn(rng,size(W.dW))
return @fastmath sqrt(abs(dt))*wiener_randn(rng,W.dW)
else
return @fastmath sqrt(abs(dt))*wiener_randn(rng,typeof(W.dW))
end
end

function WHITE_NOISE_BRIDGE(W,W0,Wh,q,h,rng)
if typeof(W.dW) <: AbstractArray
return @fastmath sqrt((1-q)*q*abs(h))*wiener_randn(rng,size(W.dW))+q*Wh
return @fastmath sqrt((1-q)*q*abs(h))*wiener_randn(rng,W.dW)+q*Wh
else
return @fastmath sqrt((1-q)*q*abs(h))*wiener_randn(rng,typeof(W.dW))+q*Wh
end
Expand All @@ -43,9 +48,7 @@ function INPLACE_WHITE_NOISE_BRIDGE(rand_vec,W,W0,Wh,q,h,rng)
wiener_randn!(rng,rand_vec)
#rand_vec .= sqrt((1.-q).*q.*abs(h)).*rand_vec.+q.*Wh
sqrtcoeff = @fastmath sqrt((1-q)*q*abs(h))
@inbounds for i in eachindex(rand_vec)
rand_vec[i] = sqrtcoeff*rand_vec[i]+q*Wh[i]
end
@. rand_vec = sqrtcoeff*rand_vec+q*Wh
end
WienerProcess!(t0,W0,Z0=nothing;kwargs...) = NoiseProcess{true}(t0,W0,Z0,INPLACE_WHITE_NOISE_DIST,INPLACE_WHITE_NOISE_BRIDGE;kwargs...)

Expand Down

0 comments on commit 38cfa61

Please sign in to comment.