Skip to content

Commit f812984

Browse files
committed
Ensure bandwidth is preserved in jac_gbmm! matrices
1 parent 6b0e72b commit f812984

File tree

1 file changed

+30
-32
lines changed

1 file changed

+30
-32
lines changed

src/Spaces/PolynomialSpace.jl

Lines changed: 30 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -151,43 +151,56 @@ See https://github.com/JuliaLinearAlgebra/BandedMatrices.jl/blob/master/LICENSE
151151
end
152152

153153
_view(::Any, A, b) = view(A, b)
154-
_view(::Val{true}, A::BandedMatrix, b) = dataview(view(A, b))
154+
function _view(::Val{true}, A::BandedMatrix, b::Band)
155+
l, u = bandwidths(A)
156+
-l <= b.i <= u || throw(ArgumentError("invalid band $b for bandwidths $((-l,u))"))
157+
dataview(view(A, b))
158+
end
155159

156-
function _get_bands(B, C, bmk, f, ValBC)
160+
function _get_bands(B, C, bmk, f, valB)
157161
Cbmk = _view(Val(true), C, band(bmk*f))
158162
Bm = _view(Val(true), B, band(flipsign(bmk-1, f)))
159163
B0 = _view(Val(true), B, band(flipsign(bmk, f)))
160-
Bp = _view(ValBC, B, band(flipsign(bmk+1, f)))
164+
Bp = _view(valB, B, band(flipsign(bmk+1, f)))
161165
Cbmk, Bm, B0, Bp
162166
end
163167

164-
function _jac_gbmm!(α, J, B, β, C, b, (Cn, Cm), n, ValJ, ValBC)
165-
Jp = _view(ValJ, J, band(1))
166-
J0 = _view(ValJ, J, band(0))
167-
Jm = _view(ValJ, J, band(-1))
168+
# Fast implementation of C[:,:] = α*J*B+β*C where the bandediwth of B is
169+
# specified by b, not by the parameters in B
170+
function jac_gbmm!(α, J, B, β, C, b, valB)
171+
if β 1
172+
lmul!(β,C)
173+
end
174+
175+
n = size(J,1)
176+
Cn, Cm = size(C)
177+
178+
Jp = _view(Val(true), J, band(1))
179+
J0 = _view(Val(true), J, band(0))
180+
Jm = _view(Val(true), J, band(-1))
168181

169182
kr = intersect(-1:b-1, b-Cm+1:b-1+Cn)
170183

171184
# unwrap the loops to forward indexing to the data wherever applicable
172185
# this might also help with cache localization
173186
k = -1
174187
if k in kr
175-
Cbmk, Bm, B0, Bp = _get_bands(B, C, b-k, 1, ValBC)
188+
Cbmk, Bm, B0, Bp = _get_bands(B, C, b-k, 1, valB)
176189
for i in 1:n-b+k
177190
Cbmk[i] += α * Bm[i+1] * Jp[i]
178191
end
179192
end
180193

181194
k = 0
182195
if k in kr
183-
Cbmk, Bm, B0, Bp = _get_bands(B, C, b-k, 1, Val(true))
196+
Cbmk, Bm, B0, Bp = _get_bands(B, C, b-k, 1, valB)
184197
for i in 1:n-b+k
185198
Cbmk[i] += α * (Bm[i+1] * Jp[i] + B0[i] * J0[i])
186199
end
187200
end
188201

189202
for k in max(1, first(kr)):last(kr)
190-
Cbmk, Bm, B0, Bp = _get_bands(B, C, b-k, 1, Val(true))
203+
Cbmk, Bm, B0, Bp = _get_bands(B, C, b-k, 1, valB)
191204
Cbmk[1] += α * (Bm[2] * Jp[1] + B0[1] * J0[1])
192205
for i in 2:n-b+k
193206
Cbmk[i] += α * (Bm[i+1] * Jp[i] + B0[i] * J0[i] + Bp[i-1] * Jm[i-1])
@@ -198,15 +211,15 @@ function _jac_gbmm!(α, J, B, β, C, b, (Cn, Cm), n, ValJ, ValBC)
198211

199212
k = -1
200213
if k in kr
201-
Ckmb, Bp, B0, Bm = _get_bands(B, C, b-k, -1, ValBC)
214+
Ckmb, Bp, B0, Bm = _get_bands(B, C, b-k, -1, valB)
202215
for (i, Ji) in enumerate(b-k:n-1)
203216
Ckmb[i] += α * Bp[i] * Jm[Ji]
204217
end
205218
end
206219

207220
k = 0
208221
if k in kr
209-
Ckmb, Bp, B0, Bm = _get_bands(B, C, b-k, -1, Val(true))
222+
Ckmb, Bp, B0, Bm = _get_bands(B, C, b-k, -1, valB)
210223
Ckmb[1] += α * Bp[1] * Jm[b-k]
211224
for (i, Ji) in enumerate(b-k+1:n-1)
212225
Ckmb[i] += α * B0[i] * J0[Ji]
@@ -238,21 +251,6 @@ function _jac_gbmm!(α, J, B, β, C, b, (Cn, Cm), n, ValJ, ValBC)
238251
return C
239252
end
240253

241-
# Fast implementation of C[:,:] = α*J*B+β*C where the bandediwth of B is
242-
# specified by b, not by the parameters in B
243-
function jac_gbmm!(α, J, B, β, C, b, valJ, valBC)
244-
if β 1
245-
lmul!(β,C)
246-
end
247-
248-
n = size(J,1)
249-
Cn, Cm = size(C)
250-
251-
_jac_gbmm!(α, J, B, β, C, b, (Cn, Cm), n, valJ, valBC)
252-
253-
C
254-
end
255-
256254
function BandedMatrix(S::SubOperator{T,ConcreteMultiplication{C,PS,T},
257255
NTuple{2,UnitRange{Int}}}) where {PS<:PolynomialSpace,T,C<:PolynomialSpace}
258256
M=parent(S)
@@ -285,31 +283,31 @@ function BandedMatrix(S::SubOperator{T,ConcreteMultiplication{C,PS,T},
285283

286284
#Multiplication is transpose
287285
J=Operator{T}(Recurrence(M.space))[jkr,jkr]
288-
valJ = all(>=(1), bandwidths(J)) ? Val(true) : Val(false)
289286

290287
B=n-1 # final bandwidth
291288

292289
# Clenshaw for operators
293290
Bk2 = BandedMatrix(Zeros{T}(size(J)), (B,B))
294291
dataview(view(Bk2, band(0))) .= a[n]/recβ(T,sp,n-1)
295292
α,β = recα(T,sp,n-1),recβ(T,sp,n-2)
296-
Bk1 = (-α/β)*Bk2
293+
Bk1 = lmul!(-α/β, copy(Bk2))
297294
dataview(view(Bk1, band(0))) .+= a[n-1]/β
298-
jac_gbmm!(one(T)/β,J,Bk2,one(T),Bk1,0,valJ, Val(true))
295+
jac_gbmm!(one(T)/β,J,Bk2,one(T),Bk1,0, Val(true))
299296
b=1 # we keep track of bandwidths manually to reuse memory
300297
for k=n-2:-1:2
298+
# b goes from 1:
301299
α,β,γ=recα(T,sp,k),recβ(T,sp,k-1),recγ(T,sp,k+1)
302300
lmul!(-γ/β,Bk2)
303301
dataview(view(Bk2, band(0))) .+= a[k]/β
304-
jac_gbmm!(1/β,J,Bk1,one(T),Bk2,b,valJ,Val(true))
302+
jac_gbmm!(1/β,J,Bk1,one(T),Bk2,b,Val(true))
305303
LinearAlgebra.axpy!(-α/β,Bk1,Bk2)
306304
Bk2,Bk1=Bk1,Bk2
307305
b+=1
308306
end
309307
α,γ=recα(T,sp,1),recγ(T,sp,2)
310308
lmul!(-γ,Bk2)
311309
dataview(view(Bk2, band(0))) .+= a[1]
312-
jac_gbmm!(one(T),J,Bk1,one(T),Bk2,b,valJ,Val(false))
310+
jac_gbmm!(one(T),J,Bk1,one(T),Bk2,b,Val(false))
313311
LinearAlgebra.axpy!(-α,Bk1,Bk2)
314312

315313
# relationship between jkr and kr, jr

0 commit comments

Comments
 (0)