@@ -108,8 +108,12 @@ func (c *CSR) Mul(a, b mat.Matrix) {
108
108
c .mulCSRCSR (lhs , rhs )
109
109
return
110
110
}
111
-
112
111
if dia , ok := a .(* DIA ); ok {
112
+ if diaB , okB := b .(* DIA ); okB {
113
+ // handle DIA * DIA
114
+ c .mulDIADIA (dia , diaB )
115
+ return
116
+ }
113
117
if isRCsr {
114
118
// handle DIA * CSR
115
119
c .mulDIACSR (dia , rhs , false )
@@ -129,7 +133,6 @@ func (c *CSR) Mul(a, b mat.Matrix) {
129
133
c .mulDIAMat (dia , a , true )
130
134
return
131
135
}
132
- // TODO: handle cases where both matrices are DIA
133
136
134
137
srcA , isLSparse := a .(TypeConverter )
135
138
srcB , isRSparse := b .(TypeConverter )
@@ -328,6 +331,47 @@ func (c *CSR) mulDIAMat(dia *DIA, other mat.Matrix, trans bool) {
328
331
}
329
332
}
330
333
334
+ // mulDIADIA multiplies two diagonal matrices
335
+ func (c * CSR ) mulDIADIA (a , b * DIA ) {
336
+ _ , ac := a .Dims ()
337
+ br , _ := b .Dims ()
338
+ aDiagonal := a .Diagonal ()
339
+ bDiagonal := a .Diagonal ()
340
+ if ac != br {
341
+ panic (mat .ErrShape )
342
+ }
343
+ for i := 0 ; i < br ; i ++ {
344
+ var v float64
345
+ v = aDiagonal [i ] * bDiagonal [i ]
346
+ if v != 0 {
347
+ c .matrix .Ind = append (c .matrix .Ind , i )
348
+ c .matrix .Data = append (c .matrix .Data , v )
349
+ }
350
+ c .matrix .Indptr [i + 1 ] = i + 1
351
+ }
352
+ }
353
+
354
+ // addDIADIA add two diagonal matrices
355
+ func (c * CSR ) addDIADIA (a , b * DIA , alpha , beta float64 ) {
356
+ ar , ac := a .Dims ()
357
+ br , bc := b .Dims ()
358
+ aDiagonal := a .Diagonal ()
359
+ bDiagonal := a .Diagonal ()
360
+ if ac != bc {
361
+ panic (mat .ErrShape )
362
+ }
363
+ if ar != br {
364
+ panic (mat .ErrShape )
365
+ }
366
+ for i := 0 ; i < br ; i ++ {
367
+ var v float64
368
+ v = aDiagonal [i ]* alpha + bDiagonal [i ]* beta
369
+ c .matrix .Ind = append (c .matrix .Ind , i )
370
+ c .matrix .Data = append (c .matrix .Data , v )
371
+ c .matrix .Indptr [i + 1 ] = i + 1
372
+ }
373
+ }
374
+
331
375
// Sub subtracts matrix b from a and stores the result in the receiver.
332
376
// If matrices a and b are not the same shape then the method will panic.
333
377
func (c * CSR ) Sub (a , b mat.Matrix ) {
@@ -354,9 +398,17 @@ func (c *CSR) addScaled(a mat.Matrix, b mat.Matrix, alpha float64, beta float64)
354
398
c = m
355
399
}
356
400
401
+ // special case both diagonal
402
+ lDIA , lIsDIA := a .(* DIA )
403
+ rDIA , rIsDIA := b .(* DIA )
404
+ if lIsDIA && rIsDIA {
405
+ c .addDIADIA (lDIA , rDIA , alpha , beta )
406
+ return
407
+ }
408
+
409
+ // and then one or both csr
357
410
lCsr , lIsCsr := a .(* CSR )
358
411
rCsr , rIsCsr := b .(* CSR )
359
- // TODO optimisation for DIA matrices
360
412
if lIsCsr && rIsCsr {
361
413
c .addCSRCSR (lCsr , rCsr , alpha , beta )
362
414
return
0 commit comments