Skip to content

Commit 10d56a1

Browse files
committed
add special-cases for dia-dia in csr arithmetic
SolveTo renamed per gonum/gonum#830
1 parent 7d83420 commit 10d56a1

File tree

2 files changed

+56
-4
lines changed

2 files changed

+56
-4
lines changed

compressed_arith.go

+55-3
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,12 @@ func (c *CSR) Mul(a, b mat.Matrix) {
108108
c.mulCSRCSR(lhs, rhs)
109109
return
110110
}
111-
112111
if dia, ok := a.(*DIA); ok {
112+
if diaB, okB := b.(*DIA); okB {
113+
// handle DIA * DIA
114+
c.mulDIADIA(dia, diaB)
115+
return
116+
}
113117
if isRCsr {
114118
// handle DIA * CSR
115119
c.mulDIACSR(dia, rhs, false)
@@ -129,7 +133,6 @@ func (c *CSR) Mul(a, b mat.Matrix) {
129133
c.mulDIAMat(dia, a, true)
130134
return
131135
}
132-
// TODO: handle cases where both matrices are DIA
133136

134137
srcA, isLSparse := a.(TypeConverter)
135138
srcB, isRSparse := b.(TypeConverter)
@@ -328,6 +331,47 @@ func (c *CSR) mulDIAMat(dia *DIA, other mat.Matrix, trans bool) {
328331
}
329332
}
330333

334+
// mulDIADIA multiplies two diagonal matrices
335+
func (c *CSR) mulDIADIA(a, b *DIA) {
336+
_, ac := a.Dims()
337+
br, _ := b.Dims()
338+
aDiagonal := a.Diagonal()
339+
bDiagonal := a.Diagonal()
340+
if ac != br {
341+
panic(mat.ErrShape)
342+
}
343+
for i := 0; i < br; i++ {
344+
var v float64
345+
v = aDiagonal[i] * bDiagonal[i]
346+
if v != 0 {
347+
c.matrix.Ind = append(c.matrix.Ind, i)
348+
c.matrix.Data = append(c.matrix.Data, v)
349+
}
350+
c.matrix.Indptr[i+1] = i + 1
351+
}
352+
}
353+
354+
// addDIADIA add two diagonal matrices
355+
func (c *CSR) addDIADIA(a, b *DIA, alpha, beta float64) {
356+
ar, ac := a.Dims()
357+
br, bc := b.Dims()
358+
aDiagonal := a.Diagonal()
359+
bDiagonal := a.Diagonal()
360+
if ac != bc {
361+
panic(mat.ErrShape)
362+
}
363+
if ar != br {
364+
panic(mat.ErrShape)
365+
}
366+
for i := 0; i < br; i++ {
367+
var v float64
368+
v = aDiagonal[i]*alpha + bDiagonal[i]*beta
369+
c.matrix.Ind = append(c.matrix.Ind, i)
370+
c.matrix.Data = append(c.matrix.Data, v)
371+
c.matrix.Indptr[i+1] = i + 1
372+
}
373+
}
374+
331375
// Sub subtracts matrix b from a and stores the result in the receiver.
332376
// If matrices a and b are not the same shape then the method will panic.
333377
func (c *CSR) Sub(a, b mat.Matrix) {
@@ -354,9 +398,17 @@ func (c *CSR) addScaled(a mat.Matrix, b mat.Matrix, alpha float64, beta float64)
354398
c = m
355399
}
356400

401+
// special case both diagonal
402+
lDIA, lIsDIA := a.(*DIA)
403+
rDIA, rIsDIA := b.(*DIA)
404+
if lIsDIA && rIsDIA {
405+
c.addDIADIA(lDIA, rDIA, alpha, beta)
406+
return
407+
}
408+
409+
// and then one or both csr
357410
lCsr, lIsCsr := a.(*CSR)
358411
rCsr, rIsCsr := b.(*CSR)
359-
// TODO optimisation for DIA matrices
360412
if lIsCsr && rIsCsr {
361413
c.addCSRCSR(lCsr, rCsr, alpha, beta)
362414
return

example_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ func ExampleLU() {
8585
var lu mat.LU
8686
lu.Factorize(A)
8787

88-
err := lu.Solve(d, false, f)
88+
err := lu.SolveTo(d, false, f)
8989
if err != nil {
9090
fmt.Printf("err = %v", err)
9191
return

0 commit comments

Comments
 (0)