Skip to content

Commit b4d1390

Browse files
Additional tests for 100% coverage
1 parent 65fe78d commit b4d1390

File tree

3 files changed

+342
-4
lines changed

3 files changed

+342
-4
lines changed

internal/fetcher/model_api_fetcher.go

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,9 @@ func (v *BoolOrString) UnmarshalJSON(b []byte) error {
4848

4949
// ModelAPIFetcher fetches model metadata from the Hugging Face Hub API.
5050
type ModelAPIFetcher struct {
51-
Client *http.Client
52-
Token string
51+
Client *http.Client
52+
Token string
53+
BaseURL string // optional; defaults to "https://huggingface.co"
5354
}
5455

5556
// ModelAPIResponse is the decoded response from GET https://huggingface.co/api/models/:id
@@ -83,9 +84,15 @@ func (f *ModelAPIFetcher) Fetch(ctx context.Context, modelID string) (*ModelAPIR
8384
client = http.DefaultClient
8485
}
8586

86-
logf(modelID, "GET /api/models/%s", strings.TrimPrefix(strings.TrimSpace(modelID), "/"))
87+
trimmedModelID := strings.TrimPrefix(strings.TrimSpace(modelID), "/")
88+
logf(modelID, "GET /api/models/%s", trimmedModelID)
8789

88-
url := fmt.Sprintf("https://huggingface.co/api/models/%s", modelID)
90+
baseURL := strings.TrimRight(strings.TrimSpace(f.BaseURL), "/")
91+
if baseURL == "" {
92+
baseURL = "https://huggingface.co"
93+
}
94+
95+
url := fmt.Sprintf("%s/api/models/%s", baseURL, trimmedModelID)
8996
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
9097
if err != nil {
9198
return nil, err
Lines changed: 254 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,254 @@
1+
package fetcher
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"encoding/json"
7+
"errors"
8+
"io"
9+
"net/http"
10+
"net/http/httptest"
11+
"net/url"
12+
"strings"
13+
"testing"
14+
)
15+
16+
type roundTripperFunc func(*http.Request) (*http.Response, error)
17+
18+
func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) }
19+
20+
func rewriteToServer(t *testing.T, srvURL string) http.RoundTripper {
21+
t.Helper()
22+
u, err := url.Parse(srvURL)
23+
if err != nil {
24+
t.Fatalf("parse server url: %v", err)
25+
}
26+
return roundTripperFunc(func(r *http.Request) (*http.Response, error) {
27+
rr := r.Clone(r.Context())
28+
rr.URL.Scheme = u.Scheme
29+
rr.URL.Host = u.Host
30+
rr.Host = u.Host
31+
rr.RequestURI = ""
32+
return http.DefaultTransport.RoundTrip(rr)
33+
})
34+
}
35+
36+
func TestBoolOrString_UnmarshalJSON(t *testing.T) {
37+
t.Run("empty bytes", func(t *testing.T) {
38+
var v BoolOrString
39+
if err := v.UnmarshalJSON([]byte("")); err != nil {
40+
t.Fatalf("expected nil error, got %v", err)
41+
}
42+
if v.Bool != nil || v.String != nil {
43+
t.Fatalf("expected nil fields, got Bool=%v String=%v", v.Bool, v.String)
44+
}
45+
})
46+
47+
t.Run("null", func(t *testing.T) {
48+
var v BoolOrString
49+
if err := v.UnmarshalJSON([]byte("null")); err != nil {
50+
t.Fatalf("expected nil error, got %v", err)
51+
}
52+
if v.Bool != nil || v.String != nil {
53+
t.Fatalf("expected nil fields, got Bool=%v String=%v", v.Bool, v.String)
54+
}
55+
})
56+
57+
t.Run("string", func(t *testing.T) {
58+
var v BoolOrString
59+
if err := v.UnmarshalJSON([]byte(`" auto "`)); err != nil {
60+
t.Fatalf("expected nil error, got %v", err)
61+
}
62+
if v.String == nil || *v.String != "auto" {
63+
t.Fatalf("expected String=auto, got %v", v.String)
64+
}
65+
if v.Bool != nil {
66+
t.Fatalf("expected Bool=nil, got %v", v.Bool)
67+
}
68+
})
69+
70+
t.Run("bool", func(t *testing.T) {
71+
var v BoolOrString
72+
if err := v.UnmarshalJSON([]byte(`true`)); err != nil {
73+
t.Fatalf("expected nil error, got %v", err)
74+
}
75+
if v.Bool == nil || *v.Bool != true {
76+
t.Fatalf("expected Bool=true, got %v", v.Bool)
77+
}
78+
if v.String != nil {
79+
t.Fatalf("expected String=nil, got %v", v.String)
80+
}
81+
})
82+
83+
t.Run("invalid string json", func(t *testing.T) {
84+
var v BoolOrString
85+
if err := v.UnmarshalJSON([]byte(`"unterminated`)); err == nil {
86+
t.Fatalf("expected error, got nil")
87+
}
88+
})
89+
90+
t.Run("invalid bool json", func(t *testing.T) {
91+
var v BoolOrString
92+
if err := v.UnmarshalJSON([]byte(`notabool`)); err == nil {
93+
t.Fatalf("expected error, got nil")
94+
}
95+
})
96+
}
97+
98+
func TestFetch_Success_DefaultClientNil_NoToken(t *testing.T) {
99+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
100+
if r.Method != http.MethodGet {
101+
t.Fatalf("method = %s", r.Method)
102+
}
103+
if r.URL.Path != "/api/models/my/model" {
104+
t.Fatalf("path = %q", r.URL.Path)
105+
}
106+
if got := r.Header.Get("Accept"); got != "application/json" {
107+
t.Fatalf("Accept = %q", got)
108+
}
109+
if got := r.Header.Get("Authorization"); got != "" {
110+
t.Fatalf("Authorization should be empty, got %q", got)
111+
}
112+
113+
w.Header().Set("Content-Type", "application/json")
114+
_ = json.NewEncoder(w).Encode(map[string]any{
115+
"id": "my/model",
116+
"modelId": "my/model",
117+
"library_name": "transformers",
118+
"pipeline_tag": "text-generation",
119+
"gated": "auto",
120+
})
121+
}))
122+
defer srv.Close()
123+
124+
f := &ModelAPIFetcher{
125+
Client: nil, // cover default-client branch
126+
BaseURL: srv.URL,
127+
}
128+
resp, err := f.Fetch(context.Background(), " /my/model ")
129+
if err != nil {
130+
t.Fatalf("Fetch error: %v", err)
131+
}
132+
if resp == nil {
133+
t.Fatalf("expected response")
134+
}
135+
if resp.Gated.String == nil || *resp.Gated.String != "auto" {
136+
t.Fatalf("expected gated string auto, got %#v", resp.Gated)
137+
}
138+
}
139+
140+
func TestFetch_SetsAuthorizationHeader_And_TrimsBaseURL(t *testing.T) {
141+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
142+
if got := r.Header.Get("Authorization"); got != "Bearer t0k" {
143+
t.Fatalf("Authorization = %q", got)
144+
}
145+
w.Header().Set("Content-Type", "application/json")
146+
_, _ = io.WriteString(w, `{"id":"x","modelId":"x","gated":true}`)
147+
}))
148+
defer srv.Close()
149+
150+
f := &ModelAPIFetcher{
151+
BaseURL: srv.URL + "/", // cover TrimRight branch
152+
Token: " t0k ",
153+
}
154+
resp, err := f.Fetch(context.Background(), "x")
155+
if err != nil {
156+
t.Fatalf("Fetch error: %v", err)
157+
}
158+
if resp.Gated.Bool == nil || *resp.Gated.Bool != true {
159+
t.Fatalf("expected gated bool true, got %#v", resp.Gated)
160+
}
161+
}
162+
163+
func TestFetch_Non200(t *testing.T) {
164+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
165+
w.WriteHeader(http.StatusForbidden)
166+
}))
167+
defer srv.Close()
168+
169+
f := &ModelAPIFetcher{BaseURL: srv.URL}
170+
_, err := f.Fetch(context.Background(), "x")
171+
if err == nil || !strings.Contains(err.Error(), "status 403") {
172+
t.Fatalf("expected status error, got %v", err)
173+
}
174+
}
175+
176+
func TestFetch_DecodeError(t *testing.T) {
177+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
178+
w.Header().Set("Content-Type", "application/json")
179+
_, _ = io.WriteString(w, "{") // invalid json
180+
}))
181+
defer srv.Close()
182+
183+
f := &ModelAPIFetcher{BaseURL: srv.URL}
184+
_, err := f.Fetch(context.Background(), "x")
185+
if err == nil {
186+
t.Fatalf("expected decode error, got nil")
187+
}
188+
}
189+
190+
func TestFetch_RequestError(t *testing.T) {
191+
want := errors.New("boom")
192+
f := &ModelAPIFetcher{
193+
BaseURL: "http://invalid.local",
194+
Client: &http.Client{
195+
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
196+
return nil, want
197+
}),
198+
},
199+
}
200+
_, err := f.Fetch(context.Background(), "x")
201+
if err == nil || !errors.Is(err, want) {
202+
t.Fatalf("expected %v, got %v", want, err)
203+
}
204+
}
205+
206+
func TestFetch_DefaultBaseURLBranch_WithoutNetwork(t *testing.T) {
207+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
208+
if r.URL.Path != "/api/models/p/q" {
209+
t.Fatalf("path = %q", r.URL.Path)
210+
}
211+
w.Header().Set("Content-Type", "application/json")
212+
_, _ = io.WriteString(w, `{"id":"p/q","modelId":"p/q","gated":false}`)
213+
}))
214+
defer srv.Close()
215+
216+
// BaseURL left empty to cover default-BaseURL branch, but transport rewrites to httptest server.
217+
f := &ModelAPIFetcher{
218+
BaseURL: " ",
219+
Client: &http.Client{Transport: rewriteToServer(t, srv.URL)},
220+
}
221+
resp, err := f.Fetch(context.Background(), "/p/q")
222+
if err != nil {
223+
t.Fatalf("Fetch error: %v", err)
224+
}
225+
if resp.Gated.Bool == nil || *resp.Gated.Bool != false {
226+
t.Fatalf("expected gated bool false, got %#v", resp.Gated)
227+
}
228+
}
229+
230+
func TestSetLoggerAndLogf_Writes(t *testing.T) {
231+
var buf bytes.Buffer
232+
SetLogger(&buf)
233+
logf("m", "hello %s", "world")
234+
if buf.Len() == 0 {
235+
t.Fatalf("expected log output")
236+
}
237+
if !strings.Contains(buf.String(), "hello") {
238+
t.Fatalf("expected message to contain %q, got %q", "hello", buf.String())
239+
}
240+
}
241+
242+
func TestFetch_NewRequestError_InvalidBaseURL(t *testing.T) {
243+
f := &ModelAPIFetcher{
244+
// Invalid host (missing closing bracket) => NewRequestWithContext should error.
245+
BaseURL: "http://[::1",
246+
}
247+
got, err := f.Fetch(context.Background(), "x")
248+
if err == nil {
249+
t.Fatalf("expected error, got nil")
250+
}
251+
if got != nil {
252+
t.Fatalf("expected nil response, got %#v", got)
253+
}
254+
}

internal/io/io_test.go

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,22 @@ func TestReadBOM_OpenError(t *testing.T) {
5555
}
5656
}
5757

58+
func TestReadBOM_Auto_SelectsJSONByExtension(t *testing.T) {
59+
dir := t.TempDir()
60+
p := filepath.Join(dir, "bom.json")
61+
if err := os.WriteFile(p, []byte(`{}`), 0o600); err != nil {
62+
t.Fatalf("WriteFile: %v", err)
63+
}
64+
65+
got, err := ReadBOM(p, "auto")
66+
if err != nil {
67+
t.Fatalf("ReadBOM(auto): %v", err)
68+
}
69+
if got == nil {
70+
t.Fatalf("expected BOM")
71+
}
72+
}
73+
5874
func TestReadBOM_DecodeError_WhenFormatDoesNotMatchContent(t *testing.T) {
5975
dir := t.TempDir()
6076
p := filepath.Join(dir, "bom.json")
@@ -180,3 +196,64 @@ func TestWriteBOM_UnsupportedFormat(t *testing.T) {
180196
t.Fatalf("expected error for unsupported format")
181197
}
182198
}
199+
200+
func TestWriteBOM_OpenError_WhenOutputIsDirectory(t *testing.T) {
201+
dir := t.TempDir()
202+
203+
// Make a *directory* that still has a valid ".json" extension so we get past
204+
// extension validation and hit the os.Create(...) error path.
205+
outDir := filepath.Join(dir, "bom.json")
206+
if err := os.Mkdir(outDir, 0o700); err != nil {
207+
t.Fatalf("Mkdir: %v", err)
208+
}
209+
210+
if err := WriteBOM(minimalBOM(), outDir, "json", ""); err == nil {
211+
t.Fatalf("expected error when output path is a directory")
212+
}
213+
}
214+
215+
func TestWriteBOM_UnsupportedFormat_Errors(t *testing.T) {
216+
dir := t.TempDir()
217+
out := filepath.Join(dir, "bom.json")
218+
219+
if err := WriteBOM(minimalBOM(), out, "yaml", ""); err == nil {
220+
t.Fatalf("expected error for unsupported write format")
221+
}
222+
}
223+
224+
func TestWriteBOM_ExtensionMismatch_XMLFormatButJSONPath(t *testing.T) {
225+
dir := t.TempDir()
226+
out := filepath.Join(dir, "bom.json")
227+
228+
if err := WriteBOM(minimalBOM(), out, "xml", ""); err == nil {
229+
t.Fatalf("expected error for extension/format mismatch")
230+
}
231+
}
232+
233+
func TestWriteBOM_ExtensionMismatch_JSONFormatButXMLPath(t *testing.T) {
234+
dir := t.TempDir()
235+
out := filepath.Join(dir, "bom.xml")
236+
237+
if err := WriteBOM(minimalBOM(), out, "json", ""); err == nil {
238+
t.Fatalf("expected error for extension/format mismatch")
239+
}
240+
}
241+
242+
func TestWriteBOM_SpecProvidedButInvalid_ReturnsError(t *testing.T) {
243+
dir := t.TempDir()
244+
out := filepath.Join(dir, "bom.json")
245+
246+
if err := WriteBOM(minimalBOM(), out, "json", "9.9"); err == nil {
247+
t.Fatalf("expected error for unsupported CycloneDX spec version")
248+
}
249+
}
250+
251+
func TestWriteBOM_Auto_UppercaseXMLExtension_HitsEqualFoldThenValidationMismatch(t *testing.T) {
252+
dir := t.TempDir()
253+
out := filepath.Join(dir, "bom.XML") // ext is ".XML"
254+
255+
// auto picks "xml" due to EqualFold(ext, ".xml"), then validation compares ext != ".xml" and errors.
256+
if err := WriteBOM(minimalBOM(), out, "auto", ""); err == nil {
257+
t.Fatalf("expected error for uppercase .XML extension validation mismatch")
258+
}
259+
}

0 commit comments

Comments
 (0)