generated from milosgajdos/go-repo-template
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add ollama embeddings API (#23)
Signed-off-by: Milos Gajdos <[email protected]>
- Loading branch information
1 parent
bf24cee
commit 4ffb3f8
Showing
7 changed files
with
227 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
package main | ||
|
||
import ( | ||
"context" | ||
"flag" | ||
"fmt" | ||
"log" | ||
|
||
"github.com/milosgajdos/go-embeddings/ollama" | ||
) | ||
|
||
var ( | ||
prompt string | ||
model string | ||
) | ||
|
||
func init() { | ||
flag.StringVar(&prompt, "prompt", "what is life", "input prompt") | ||
flag.StringVar(&model, "model", "", "model name") | ||
} | ||
|
||
func main() { | ||
flag.Parse() | ||
|
||
if model == "" { | ||
log.Fatal("missing ollama model") | ||
} | ||
|
||
c := ollama.NewClient() | ||
|
||
embReq := &ollama.EmbeddingRequest{ | ||
Prompt: prompt, | ||
Model: model, | ||
} | ||
|
||
embs, err := c.Embed(context.Background(), embReq) | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
|
||
fmt.Printf("got %d embeddings", len(embs)) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
package ollama | ||
|
||
import ( | ||
"github.com/milosgajdos/go-embeddings" | ||
"github.com/milosgajdos/go-embeddings/client" | ||
) | ||
|
||
const ( | ||
// BaseURL is Ollama HTTP API embeddings base URL. | ||
BaseURL = "http://localhost:11434/api" | ||
) | ||
|
||
// Client is an OpenAI HTTP API client. | ||
type Client struct { | ||
opts Options | ||
} | ||
|
||
type Options struct { | ||
BaseURL string | ||
HTTPClient *client.HTTP | ||
} | ||
|
||
// Option is functional graph option. | ||
type Option func(*Options) | ||
|
||
// NewClient creates a new Ollama HTTP API client and returns it. | ||
// You can override the default options via the client methods. | ||
func NewClient(opts ...Option) *Client { | ||
options := Options{ | ||
BaseURL: BaseURL, | ||
HTTPClient: client.NewHTTP(), | ||
} | ||
|
||
for _, apply := range opts { | ||
apply(&options) | ||
} | ||
|
||
return &Client{ | ||
opts: options, | ||
} | ||
} | ||
|
||
// NewEmbedder creates a client that implements embeddings.Embedder | ||
func NewEmbedder(opts ...Option) embeddings.Embedder[*EmbeddingRequest] { | ||
return NewClient(opts...) | ||
} | ||
|
||
// WithBaseURL sets the API base URL. | ||
func WithBaseURL(baseURL string) Option { | ||
return func(o *Options) { | ||
o.BaseURL = baseURL | ||
} | ||
} | ||
|
||
// WithVersion sets the API version. | ||
// WithHTTPClient sets the HTTP client. | ||
func WithHTTPClient(httpClient *client.HTTP) Option { | ||
return func(o *Options) { | ||
o.HTTPClient = httpClient | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
package ollama | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/milosgajdos/go-embeddings/client" | ||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestClient(t *testing.T) { | ||
t.Run("BaseURL", func(t *testing.T) { | ||
c := NewClient() | ||
assert.Equal(t, c.opts.BaseURL, BaseURL) | ||
|
||
testVal := "http://foo" | ||
c = NewClient(WithBaseURL(testVal)) | ||
assert.Equal(t, c.opts.BaseURL, testVal) | ||
}) | ||
|
||
t.Run("http client", func(t *testing.T) { | ||
c := NewClient() | ||
assert.NotNil(t, c.opts.HTTPClient) | ||
|
||
testVal := client.NewHTTP() | ||
c = NewClient(WithHTTPClient(testVal)) | ||
assert.NotNil(t, c.opts.HTTPClient) | ||
}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
package ollama | ||
|
||
import ( | ||
"bytes" | ||
"context" | ||
"encoding/json" | ||
"net/http" | ||
"net/url" | ||
|
||
"github.com/milosgajdos/go-embeddings" | ||
"github.com/milosgajdos/go-embeddings/request" | ||
) | ||
|
||
// EmbeddingRequest is serialized and sent to the API server. | ||
type EmbeddingRequest struct { | ||
Prompt any `json:"prompt"` | ||
Model string `json:"model"` | ||
} | ||
|
||
// EmbedddingResponse received from API. | ||
type EmbedddingResponse struct { | ||
Embedding []float64 `json:"embedding"` | ||
} | ||
|
||
// ToEmbeddings converts the API response, | ||
// into a slice of embeddings and returns it. | ||
func (e *EmbedddingResponse) ToEmbeddings() ([]*embeddings.Embedding, error) { | ||
floats := make([]float64, len(e.Embedding)) | ||
copy(floats, e.Embedding) | ||
return []*embeddings.Embedding{ | ||
{Vector: floats}, | ||
}, nil | ||
} | ||
|
||
// Embed returns embeddings for every object in EmbeddingRequest. | ||
func (c *Client) Embed(ctx context.Context, embReq *EmbeddingRequest) ([]*embeddings.Embedding, error) { | ||
u, err := url.Parse(c.opts.BaseURL + "/embeddings") | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
var body = &bytes.Buffer{} | ||
enc := json.NewEncoder(body) | ||
enc.SetEscapeHTML(false) | ||
if err := enc.Encode(embReq); err != nil { | ||
return nil, err | ||
} | ||
|
||
options := []request.Option{} | ||
req, err := request.NewHTTP(ctx, http.MethodPost, u.String(), body, options...) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
resp, err := request.Do[APIError](c.opts.HTTPClient, req) | ||
if err != nil { | ||
return nil, err | ||
} | ||
defer resp.Body.Close() | ||
|
||
e := new(EmbedddingResponse) | ||
if err := json.NewDecoder(resp.Body).Decode(e); err != nil { | ||
return nil, err | ||
} | ||
|
||
return e.ToEmbeddings() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
package ollama | ||
|
||
import "encoding/json" | ||
|
||
// APIError is Ollama API error. | ||
type APIError struct { | ||
ErrorMessage string `json:"error"` | ||
} | ||
|
||
// Error implements errors interface. | ||
func (e APIError) Error() string { | ||
b, err := json.Marshal(e) | ||
if err != nil { | ||
return "unknown error" | ||
} | ||
return string(b) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters