Skip to content

Commit e56f85f

Browse files
authored
Merge pull request #23 from jonreiter/mulelemvec
add MulElemVec for element-wise multiplication of sparse vectors
2 parents 30083ba + c130c4d commit e56f85f

File tree

3 files changed

+70
-22
lines changed

3 files changed

+70
-22
lines changed

cholesky.go

+2-8
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ func (ch *Cholesky) At(i, j int) float64 {
4141
ri := ch.chol.RowView(i).(*Vector)
4242
rj := ch.chol.RowView(j).(*Vector)
4343
// FIXME: check types
44-
val = dotSparseSparseNoSortBefore(ri, rj, min(i, j)+1)
44+
val = dotSparseSparseNoSortBefore(ri, rj, nil, min(i, j)+1)
4545
return val
4646
}
4747

@@ -212,10 +212,6 @@ func cholCSR(matrix *CSR, lower *CSR) {
212212
if matrix.RowNNZ(i) == 0 {
213213
continue
214214
}
215-
// rowDotSum := 0.0
216-
// aPos := 0
217-
// bPos := 0
218-
// thisSum := 0.0
219215
for j := 0; j <= i; j++ {
220216
iRow := lower.RowView(i)
221217
iRowS, iRowIsSparse := iRow.(*Vector)
@@ -231,9 +227,7 @@ func cholCSR(matrix *CSR, lower *CSR) {
231227
}
232228
lower.Set(j, j, math.Sqrt(matrix.At(i, i)-sum))
233229
} else {
234-
// thisSum, _, _ = dotSparseSparseNoSortBeforeWithStart(iRowS, jRowS, j, aPos, bPos)
235-
// rowDotSum = thisSum
236-
rowDotSum := dotSparseSparseNoSort(iRowS, jRowS)
230+
rowDotSum := dotSparseSparseNoSort(iRowS, jRowS, nil)
237231
if rowDotSum == 0.0 && matrix.At(i, j) == 0.0 {
238232
continue
239233
}

vector.go

+31-9
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ func Dot(a, b mat.Vector) float64 {
349349

350350
if aIsSparse {
351351
if bIsSparse {
352-
return dotSparseSparse(as, bs)
352+
return dotSparseSparse(as, bs, nil)
353353
}
354354
if bdense, bIsDense := b.(mat.RawVectorer); bIsDense {
355355
raw := bdense.RawVector()
@@ -557,23 +557,23 @@ func (v *Vector) IsSorted() bool {
557557

558558
// dotSparseSparse computes the dot product of two sparse vectors.
559559
// This will be called by Dot if both entered vectors are Sparse.
560-
func dotSparseSparse(a, b *Vector) float64 {
560+
func dotSparseSparse(a, b, c *Vector) float64 {
561561
a.Sort()
562562
b.Sort()
563-
return dotSparseSparseNoSort(a, b)
563+
return dotSparseSparseNoSort(a, b, c)
564564
}
565565

566-
func dotSparseSparseNoSort(a, b *Vector) float64 {
566+
func dotSparseSparseNoSort(a, b, c *Vector) float64 {
567567
n := a.Len()
568-
return dotSparseSparseNoSortBefore(a, b, n)
568+
return dotSparseSparseNoSortBefore(a, b, c, n)
569569
}
570570

571-
func dotSparseSparseNoSortBefore(a, b *Vector, n int) float64 {
572-
v, _, _ := dotSparseSparseNoSortBeforeWithStart(a, b, n, 0, 0)
571+
func dotSparseSparseNoSortBefore(a, b, c *Vector, n int) float64 {
572+
v, _, _ := dotSparseSparseNoSortBeforeWithStart(a, b, c, n, 0, 0)
573573
return v
574574
}
575575

576-
func dotSparseSparseNoSortBeforeWithStart(a, b *Vector, n, aStart, bStart int) (float64, int, int) {
576+
func dotSparseSparseNoSortBeforeWithStart(a, b, c *Vector, n, aStart, bStart int) (float64, int, int) {
577577
tot := 0.0
578578
aPos := aStart
579579
bPos := bStart
@@ -583,7 +583,11 @@ func dotSparseSparseNoSortBeforeWithStart(a, b *Vector, n, aStart, bStart int) (
583583
aIndex = a.ind[aPos]
584584
bIndex = b.ind[bPos]
585585
if aIndex == bIndex {
586-
tot += a.data[aPos] * b.data[bPos]
586+
val := a.data[aPos] * b.data[bPos]
587+
tot += val
588+
if c != nil {
589+
c.SetVec(aIndex, val)
590+
}
587591
aPos++
588592
bPos++
589593
} else {
@@ -596,3 +600,21 @@ func dotSparseSparseNoSortBeforeWithStart(a, b *Vector, n, aStart, bStart int) (
596600
}
597601
return tot, aPos, bPos
598602
}
603+
604+
// MulElemVec does element-by-element multiplication of a and b
605+
// and puts the result in the receiver.
606+
func (v *Vector) MulElemVec(a, b *Vector) {
607+
ar := a.Len()
608+
br := b.Len()
609+
if ar != br {
610+
panic(mat.ErrShape)
611+
}
612+
aNNZ := a.NNZ()
613+
bNNZ := b.NNZ()
614+
minNNZ := aNNZ
615+
if bNNZ < minNNZ {
616+
minNNZ = bNNZ
617+
}
618+
v.reuseAs(ar, minNNZ, true)
619+
dotSparseSparse(a, b, v)
620+
}

vector_test.go

+37-5
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,11 @@ func TestDot(t *testing.T) {
156156
b: NewVector(6, []int{0, 1, 3}, []float64{1, 1, 2}),
157157
r: 5,
158158
},
159+
{
160+
a: NewVector(6, []int{4, 3, 1}, []float64{1, 2, 1}),
161+
b: NewVector(6, []int{1, 0, 3}, []float64{1, 1, 2}),
162+
r: 5,
163+
},
159164
{
160165
a: mat.NewVecDense(6, []float64{0, 1, 0, 2, 1, 0}),
161166
b: NewVector(6, []int{0, 1, 3}, []float64{1, 1, 2}),
@@ -171,13 +176,40 @@ func TestDot(t *testing.T) {
171176
for ti, test := range tests {
172177

173178
result := Dot(test.a, test.b)
174-
175179
if result != test.r {
176180
t.Errorf("Test %d: Incorrect result for Dot - expected:\n%v\nbut received:\n%v\n", ti+1, test.r, result)
177181
}
178182
}
179183
}
180184

185+
func TestMulElemVec(t *testing.T) {
186+
v := NewVector(6, nil, nil)
187+
tests := []struct {
188+
a *Vector
189+
b *Vector
190+
r *Vector
191+
}{
192+
{
193+
a: NewVector(6, []int{1, 3, 4}, []float64{1, 2, 1}),
194+
b: NewVector(6, []int{0, 1, 3}, []float64{1, 1, 2}),
195+
r: NewVector(6, []int{1, 3}, []float64{1, 4}),
196+
},
197+
{
198+
a: NewVector(6, []int{4, 3, 1}, []float64{1, 2, 1}),
199+
b: NewVector(6, []int{1, 0, 3}, []float64{1, 1, 2}),
200+
r: NewVector(6, []int{1, 3}, []float64{1, 4}),
201+
},
202+
}
203+
204+
for ti, test := range tests {
205+
206+
v.MulElemVec(test.a, test.b)
207+
if !mat.Equal(v, test.r) {
208+
t.Errorf("Test %d: Incorrect result for MulElemVec - expected:\n%v\nbut received:\n%v\n", ti+1, test.r, v)
209+
}
210+
}
211+
}
212+
181213
func TestVectorNorm(t *testing.T) {
182214
tests := []struct {
183215
a mat.Vector
@@ -538,7 +570,7 @@ func TestVectorSet(t *testing.T) {
538570
},
539571
}
540572

541-
for ti, test := range tests {
573+
for ti, test := range tests {
542574
act := new(Vector)
543575
act.CloneVec(test.source)
544576
act.Set(test.idx, 0, test.val)
@@ -635,9 +667,9 @@ func TestVecSetPanic(t *testing.T) {
635667
act.Set(2, col, 1.1)
636668

637669
}(ti, test.col, test.source)
638-
}
670+
}
639671
}
640-
672+
641673
func TestMulMatSparseVec(t *testing.T) {
642674
permsB := []struct {
643675
name string
@@ -687,7 +719,7 @@ func TestMulMatSparseVec(t *testing.T) {
687719
},
688720
}
689721

690-
for ti, test := range tests {
722+
for ti, test := range tests {
691723
for _, b := range permsB {
692724
for _, rawa := range matrixPermutationsForA {
693725
var matPair = []struct {

0 commit comments

Comments
 (0)