Skip to content

Commit 6d30712

Browse files
authored
perf: do api queries in parallel using errgroup (#347)
* refactor: use `errgroup.Group` to handle parallel fetching
* perf: do batch checks in parallel
* fix: avoid data race
* fix: avoid another data race
* test: use content-length check
* fix: have at least some request limit
1 parent dcbe42d commit 6d30712

File tree

5 files changed

+51
-55
lines changed

5 files changed

+51
-55
lines changed

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ require (
1010
github.com/google/osv-scalibr v0.3.5-0.20251002191929-de9496dc5aa2
1111
github.com/tidwall/jsonc v0.3.2
1212
golang.org/x/mod v0.30.0
13+
golang.org/x/sync v0.16.0
1314
gopkg.in/yaml.v3 v3.0.1
1415
)
1516

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
4444
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
4545
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
4646
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
47+
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
48+
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
4749
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
4850
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
4951
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=

pkg/database/api-check.go

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"path"
1313

1414
"github.com/g-rath/osv-detector/internal"
15+
"golang.org/x/sync/errgroup"
1516
)
1617

1718
func (db APIDB) buildAPIPayload(pkg internal.PackageDetails) apiQuery {
@@ -171,15 +172,38 @@ func findOrDefault(vulns Vulnerabilities, def OSV) OSV {
171172
func (db APIDB) Check(pkgs []internal.PackageDetails) ([]Vulnerabilities, error) {
172173
batches := batchPkgs(pkgs, db.BatchSize)
173174

174-
vulnerabilities := make([]Vulnerabilities, 0, len(pkgs))
175+
var eg errgroup.Group
175176

176-
for _, batch := range batches {
177-
results, err := db.checkBatch(batch)
177+
// use a sensible upper limit so it's not possible to have an unbounded number of operations in flight
178+
// even though it's very unlikely there will be more than a couple of batches
179+
eg.SetLimit(100)
178180

179-
if err != nil {
180-
return nil, err
181-
}
181+
batchResults := make([][][]ObjectWithID, len(batches))
182+
183+
for i, batch := range batches {
184+
eg.Go(func() error {
185+
results, err := db.checkBatch(batch)
186+
187+
if err != nil {
188+
return err
189+
}
190+
191+
batchResults[i] = results
192+
193+
return nil
194+
})
195+
}
196+
197+
err := eg.Wait()
198+
199+
if err != nil {
200+
return nil, err
201+
}
202+
203+
vulnerabilities := make([]Vulnerabilities, 0, len(pkgs))
182204

205+
// todo: pretty sure some of these loops and slices can be merged and simplified
206+
for _, results := range batchResults {
183207
for _, withIDs := range results {
184208
vulns := make(Vulnerabilities, 0, len(withIDs))
185209

pkg/database/api-check_test.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -553,11 +553,15 @@ func TestAPIDB_Check_Batches(t *testing.T) {
553553
mux.HandleFunc("/querybatch", func(w http.ResponseWriter, r *http.Request) {
554554
requestCount++
555555

556+
if requestCount > 2 {
557+
t.Errorf("unexpected number of requests (%d)", requestCount)
558+
}
559+
556560
var expectedPayload []apiQuery
557561
var batchResponse []objectsWithIDs
558562

559-
switch requestCount {
560-
case 1:
563+
// strictly speaking not the best of checks, but it should be good enough
564+
if r.ContentLength > 100 {
561565
expectedPayload = []apiQuery{
562566
{
563567
Version: "1.0.0",
@@ -569,16 +573,14 @@ func TestAPIDB_Check_Batches(t *testing.T) {
569573
},
570574
}
571575
batchResponse = []objectsWithIDs{{}, {}}
572-
case 2:
576+
} else if r.ContentLength > 50 {
573577
expectedPayload = []apiQuery{
574578
{
575579
Version: "2.3.1",
576580
Package: apiPackage{Name: "their-package", Ecosystem: lockfile.NpmEcosystem},
577581
},
578582
}
579583
batchResponse = []objectsWithIDs{{}}
580-
default:
581-
t.Errorf("unexpected number of requests (%d)", requestCount)
582584
}
583585

584586
expectRequestPayload(t, r, expectedPayload)

pkg/database/api-fetch-all.go

Lines changed: 11 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -2,65 +2,32 @@ package database
22

33
import (
44
"sort"
5-
)
65

7-
// a struct to hold the result from each request including an index
8-
// which will be used for sorting the results after they come in
9-
type result struct {
10-
index int
11-
res OSV
12-
err error
13-
}
6+
"golang.org/x/sync/errgroup"
7+
)
148

159
func (db APIDB) FetchAll(ids []string) Vulnerabilities {
16-
conLimit := 200
10+
var eg errgroup.Group
1711

18-
var osvs Vulnerabilities
19-
20-
if len(ids) == 0 {
21-
return osvs
22-
}
12+
eg.SetLimit(200)
2313

24-
// buffered channel which controls the number of concurrent operations
25-
semaphoreChan := make(chan struct{}, conLimit)
26-
resultsChan := make(chan *result)
27-
28-
defer func() {
29-
close(semaphoreChan)
30-
close(resultsChan)
31-
}()
14+
osvs := make(Vulnerabilities, len(ids))
3215

3316
for i, id := range ids {
34-
go func(i int, id string) {
35-
// read from the buffered semaphore channel, which will block if we've
36-
// already got as many goroutines as our concurrency limit allows
37-
//
38-
// when one of those routines finish they'll read from this channel,
39-
// freeing up a slot to unblock this send
40-
semaphoreChan <- struct{}{}
41-
17+
eg.Go(func() error {
4218
// if we error, still report the vulnerability as hopefully the ID should be
4319
// enough to manually look up the details - in future we should ideally warn
4420
// the user too, but for now we just silently eat the error
4521
osv, _ := db.Fetch(id)
46-
result := &result{i, osv, nil}
4722

48-
resultsChan <- result
23+
osvs[i] = osv
4924

50-
// read from the buffered semaphore to free up a slot to allow
51-
// another goroutine to start, since this one is wrapping up
52-
<-semaphoreChan
53-
}(i, id)
25+
return nil
26+
})
5427
}
5528

56-
for {
57-
result := <-resultsChan
58-
osvs = append(osvs, result.res)
59-
60-
if len(osvs) == len(ids) {
61-
break
62-
}
63-
}
29+
// errors are handled within the goroutines
30+
_ = eg.Wait()
6431

6532
sort.Slice(osvs, func(i, j int) bool {
6633
return osvs[i].ID < osvs[j].ID

0 commit comments

Comments
 (0)