Skip to content

Commit 3323d33

Browse files
feat: tests for generator package and small refactor
1 parent df3c373 commit 3323d33

File tree

3 files changed

+153
-28
lines changed

3 files changed

+153
-28
lines changed

cmd/generate.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package cmd
22

33
import (
44
"fmt"
5+
"os"
56
"path/filepath"
67
"strings"
78
"time"
@@ -13,6 +14,7 @@ import (
1314
"github.com/idlab-discover/AIBoMGen-cli/internal/enricher"
1415
"github.com/idlab-discover/AIBoMGen-cli/internal/fetcher"
1516
"github.com/idlab-discover/AIBoMGen-cli/internal/generator"
17+
bomio "github.com/idlab-discover/AIBoMGen-cli/internal/io"
1618
"github.com/idlab-discover/AIBoMGen-cli/internal/metadata"
1719
"github.com/idlab-discover/AIBoMGen-cli/internal/scanner"
1820
)
@@ -174,6 +176,9 @@ var generateCmd = &cobra.Command{
174176
outputDir = "."
175177
}
176178
outputDir = filepath.Clean(outputDir)
179+
if err := os.MkdirAll(outputDir, 0o755); err != nil {
180+
return err
181+
}
177182

178183
fileExt := ".json"
179184
if fmtChosen == "xml" {
@@ -198,7 +203,7 @@ var generateCmd = &cobra.Command{
198203
fileName := fmt.Sprintf("%s_aibom%s", sanitized, fileExt)
199204
dest := filepath.Join(outputDir, fileName)
200205

201-
if err := generator.WriteWithFormatAndSpec(dest, d.BOM, fmtChosen, generateSpecVersion); err != nil {
206+
if err := bomio.WriteBOM(d.BOM, dest, fmtChosen, generateSpecVersion); err != nil {
202207
return err
203208
}
204209
written = append(written, dest)

internal/generator/generator.go

Lines changed: 9 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,11 @@ package generator
33
import (
44
"context"
55
"net/http"
6-
"os"
7-
"path/filepath"
86
"strings"
97
"time"
108

119
"github.com/idlab-discover/AIBoMGen-cli/internal/builder"
1210
"github.com/idlab-discover/AIBoMGen-cli/internal/fetcher"
13-
bomio "github.com/idlab-discover/AIBoMGen-cli/internal/io"
1411
"github.com/idlab-discover/AIBoMGen-cli/internal/scanner"
1512

1613
cdx "github.com/CycloneDX/cyclonedx-go"
@@ -21,6 +18,14 @@ type DiscoveredBOM struct {
2118
BOM *cdx.BOM
2219
}
2320

21+
type bomBuilder interface {
22+
Build(builder.BuildContext) (*cdx.BOM, error)
23+
}
24+
25+
var newBOMBuilder = func() bomBuilder {
26+
return builder.NewBOMBuilder(builder.DefaultOptions())
27+
}
28+
2429
// BuildPerDiscovery orchestrates: fetch HF API (optional) → build BOM per model via registry-driven builder.
2530
func BuildPerDiscovery(discoveries []scanner.Discovery, hfToken string, timeout time.Duration) ([]DiscoveredBOM, error) {
2631
results := make([]DiscoveredBOM, 0, len(discoveries))
@@ -32,7 +37,7 @@ func BuildPerDiscovery(discoveries []scanner.Discovery, hfToken string, timeout
3237
httpClient := &http.Client{Timeout: timeout}
3338
apiFetcher := &fetcher.ModelAPIFetcher{Client: httpClient, Token: hfToken}
3439

35-
bomBuilder := builder.NewBOMBuilder(builder.DefaultOptions())
40+
bomBuilder := newBOMBuilder()
3641

3742
for _, d := range discoveries {
3843
modelID := strings.TrimSpace(d.ID)
@@ -75,26 +80,3 @@ func BuildPerDiscovery(discoveries []scanner.Discovery, hfToken string, timeout
7580

7681
return results, nil
7782
}
78-
79-
// Write writes the BOM to the given output path, creating directories as needed.
80-
func Write(outputPath string, bom *cdx.BOM) error { return WriteWithFormat(outputPath, bom, "json") }
81-
82-
// WriteWithFormat writes the BOM in the specified format (json|xml). If format is auto, infer from extension.
83-
func WriteWithFormat(outputPath string, bom *cdx.BOM, format string) error {
84-
return WriteWithFormatAndSpec(outputPath, bom, format, "")
85-
}
86-
87-
// WriteWithFormatAndSpec writes the BOM with the specified file format and optional spec version.
88-
// If spec is non-empty (e.g. "1.3"), EncodeVersion is used.
89-
func WriteWithFormatAndSpec(outputPath string, bom *cdx.BOM, format string, spec string) error {
90-
if err := os.MkdirAll(filepath.Dir(outputPath), 0o755); err != nil {
91-
return err
92-
}
93-
return bomio.WriteBOM(bom, outputPath, format, spec)
94-
}
95-
96-
// ParseSpecVersion parses a spec version string.
97-
// Deprecated: Use bomio.ParseSpecVersion instead.
98-
func ParseSpecVersion(s string) (cdx.SpecVersion, bool) {
99-
return bomio.ParseSpecVersion(s)
100-
}
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
package generator
2+
3+
import (
4+
"bytes"
5+
"errors"
6+
"io"
7+
"net/http"
8+
"strings"
9+
"testing"
10+
"time"
11+
12+
cdx "github.com/CycloneDX/cyclonedx-go"
13+
"github.com/idlab-discover/AIBoMGen-cli/internal/builder"
14+
"github.com/idlab-discover/AIBoMGen-cli/internal/scanner"
15+
"github.com/idlab-discover/AIBoMGen-cli/internal/ui"
16+
)
17+
18+
type failingBuilder struct {
19+
err error
20+
}
21+
22+
func (f *failingBuilder) Build(builder.BuildContext) (*cdx.BOM, error) {
23+
return nil, f.err
24+
}
25+
26+
type roundTripFunc func(*http.Request) (*http.Response, error)
27+
28+
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
29+
return f(req)
30+
}
31+
32+
func TestBuildPerDiscovery_FetchesMetadataAndBuilds(t *testing.T) {
33+
responses := []string{
34+
`{"id":"hf-alpha","modelId":"hf-alpha","author":"org","pipeline_tag":"tag","library_name":"lib","tags":["t1"],"license":"mit","sha":"abc","downloads":1,"likes":1,"lastModified":"2024-01-01","createdAt":"2023-01-01","private":false,"usedStorage":1,"cardData":{"license":"mit"}}`,
35+
`{"id":"hf-beta","modelId":"hf-beta","author":"org","pipeline_tag":"tag","library_name":"lib","tags":["t2"],"license":"apache","sha":"def","downloads":2,"likes":2,"lastModified":"2024-02-01","createdAt":"2023-02-01","private":false,"usedStorage":2,"cardData":{"license":"apache"}}`,
36+
}
37+
38+
origTransport := http.DefaultTransport
39+
var paths []string
40+
http.DefaultTransport = roundTripFunc(func(req *http.Request) (*http.Response, error) {
41+
idx := len(paths)
42+
if idx >= len(responses) {
43+
t.Fatalf("unexpected request #%d to %s", idx+1, req.URL)
44+
}
45+
if got, want := req.Header.Get("Authorization"), "Bearer test-token"; got != want {
46+
t.Fatalf("authorization header = %q, want %q", got, want)
47+
}
48+
paths = append(paths, req.URL.Path)
49+
body := io.NopCloser(strings.NewReader(responses[idx]))
50+
return &http.Response{StatusCode: http.StatusOK, Body: body, Header: make(http.Header)}, nil
51+
})
52+
t.Cleanup(func() { http.DefaultTransport = origTransport })
53+
54+
discoveries := []scanner.Discovery{
55+
{ID: "org-model", Path: "model.py"},
56+
{Name: "beta"},
57+
{},
58+
}
59+
60+
got, err := BuildPerDiscovery(discoveries, "test-token", 0)
61+
if err != nil {
62+
t.Fatalf("BuildPerDiscovery() error = %v", err)
63+
}
64+
if len(got) != len(discoveries) {
65+
t.Fatalf("results len = %d, want %d", len(got), len(discoveries))
66+
}
67+
if len(paths) != 2 {
68+
t.Fatalf("expected 2 fetches, got %d", len(paths))
69+
}
70+
if got[0].BOM == nil || got[0].BOM.Metadata == nil || got[0].BOM.Metadata.Component == nil {
71+
t.Fatalf("first BOM missing metadata/component")
72+
}
73+
if got[0].BOM.Metadata.Component.Name != "hf-alpha" {
74+
t.Fatalf("first component name = %q, want hf-alpha", got[0].BOM.Metadata.Component.Name)
75+
}
76+
if got[1].Discovery.Name != "beta" {
77+
t.Fatalf("second discovery preserved name, got %q", got[1].Discovery.Name)
78+
}
79+
if got[2].BOM.Metadata.Component.Name != "model" {
80+
t.Fatalf("third component default name = %q, want model", got[2].BOM.Metadata.Component.Name)
81+
}
82+
if !strings.Contains(paths[1], "beta") {
83+
t.Fatalf("second request path %q missing beta", paths[1])
84+
}
85+
}
86+
87+
func TestBuildPerDiscovery_FetchErrorStillBuilds(t *testing.T) {
88+
origTransport := http.DefaultTransport
89+
http.DefaultTransport = roundTripFunc(func(req *http.Request) (*http.Response, error) {
90+
return nil, errors.New("boom")
91+
})
92+
t.Cleanup(func() { http.DefaultTransport = origTransport })
93+
94+
discoveries := []scanner.Discovery{{ID: "err-model"}}
95+
got, err := BuildPerDiscovery(discoveries, "", 5*time.Second)
96+
if err != nil {
97+
t.Fatalf("BuildPerDiscovery() error = %v", err)
98+
}
99+
if len(got) != 1 {
100+
t.Fatalf("len = %d, want 1", len(got))
101+
}
102+
if got[0].BOM == nil || got[0].BOM.Metadata == nil || got[0].BOM.Metadata.Component == nil {
103+
t.Fatalf("result bom missing metadata/component")
104+
}
105+
if got[0].BOM.Metadata.Component.Name != "err-model" {
106+
t.Fatalf("component name = %q, want err-model", got[0].BOM.Metadata.Component.Name)
107+
}
108+
}
109+
110+
func TestLogfWritesWithConfiguredLogger(t *testing.T) {
111+
ui.Init(true)
112+
113+
var buf bytes.Buffer
114+
SetLogger(&buf)
115+
t.Cleanup(func() { SetLogger(nil) })
116+
117+
logf("model-x", "hello %s", "world")
118+
119+
got := buf.String()
120+
for _, want := range []string{"Generator:", "model=model-x", "hello world"} {
121+
if !strings.Contains(got, want) {
122+
t.Fatalf("log output %q missing %q", got, want)
123+
}
124+
}
125+
}
126+
127+
func TestBuildPerDiscovery_PropagatesBuilderError(t *testing.T) {
128+
origFactory := newBOMBuilder
129+
newBOMBuilder = func() bomBuilder {
130+
return &failingBuilder{err: errors.New("builder boom")}
131+
}
132+
t.Cleanup(func() { newBOMBuilder = origFactory })
133+
134+
_, err := BuildPerDiscovery([]scanner.Discovery{{}}, "", time.Second)
135+
if err == nil || !strings.Contains(err.Error(), "builder boom") {
136+
t.Fatalf("expected builder error, got %v", err)
137+
}
138+
}

0 commit comments

Comments
 (0)