Skip to content

Commit

Permalink
feat: add ollama embeddings API (#23)
Browse files Browse the repository at this point in the history
Signed-off-by: Milos Gajdos <[email protected]>
  • Loading branch information
milosgajdos authored Apr 9, 2024
1 parent bf24cee commit 4ffb3f8
Show file tree
Hide file tree
Showing 7 changed files with 227 additions and 4 deletions.
14 changes: 11 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@ This project provides an implementation of API clients for fetching embeddings f

Currently supported APIs:
* [x] [OpenAI](https://platform.openai.com/docs/api-reference/embeddings)
* [x] [Cohere AI](https://docs.cohere.com/reference/embed)
* [x] [Google Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings)
* [x] [Cohere](https://docs.cohere.com/reference/embed)
* [x] [Google Vertex](https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings)
* [x] [VoyageAI](https://docs.voyageai.com/reference/embeddings-api)
* [x] [Ollama](https://ollama.com/)

You can find sample programs that demonstrate how to use the client packages to fetch the embeddings in `cmd` directory of this project.

Expand All @@ -19,7 +20,10 @@ It's essentially a Go rewrite of character and recursive character text splitter

## Environment variables

Each client package lets you initialize a default API client for a specific embeddings provider by reading the API keys from environment variables.
> [!NOTE]
> Each client package lets you initialize a default API client for a specific embeddings provider by reading the API keys from environment variables
Here's a list of the env vars for each supported client

### OpenAI

Expand All @@ -36,6 +40,10 @@ Each client package lets you initialize a default API client for a specific embe
* `GOOGLE_PROJECT_ID`: Google Project ID
* `VOYAGE_API_KEY`: VoyageAI API key

### Voyage

* `VOYAGE_API_KEY`: Voyage AI API key

## nix

The project provides a simple `nix` flake tha leverages [gomod2nix](https://github.com/nix-community/gomod2nix) for consistent Go environments and builds.
Expand Down
42 changes: 42 additions & 0 deletions cmd/ollama/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package main

import (
"context"
"flag"
"fmt"
"log"

"github.com/milosgajdos/go-embeddings/ollama"
)

var (
prompt string
model string
)

func init() {
flag.StringVar(&prompt, "prompt", "what is life", "input prompt")
flag.StringVar(&model, "model", "", "model name")
}

func main() {
flag.Parse()

if model == "" {
log.Fatal("missing ollama model")
}

c := ollama.NewClient()

embReq := &ollama.EmbeddingRequest{
Prompt: prompt,
Model: model,
}

embs, err := c.Embed(context.Background(), embReq)
if err != nil {
log.Fatal(err)
}

fmt.Printf("got %d embeddings", len(embs))
}
61 changes: 61 additions & 0 deletions ollama/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package ollama

import (
"github.com/milosgajdos/go-embeddings"
"github.com/milosgajdos/go-embeddings/client"
)

const (
// BaseURL is Ollama HTTP API embeddings base URL.
BaseURL = "http://localhost:11434/api"
)

// Client is an OpenAI HTTP API client.
type Client struct {
opts Options
}

type Options struct {
BaseURL string
HTTPClient *client.HTTP
}

// Option is functional graph option.
type Option func(*Options)

// NewClient creates a new Ollama HTTP API client and returns it.
// You can override the default options via the client methods.
func NewClient(opts ...Option) *Client {
options := Options{
BaseURL: BaseURL,
HTTPClient: client.NewHTTP(),
}

for _, apply := range opts {
apply(&options)
}

return &Client{
opts: options,
}
}

// NewEmbedder creates a client that implements embeddings.Embedder
func NewEmbedder(opts ...Option) embeddings.Embedder[*EmbeddingRequest] {
return NewClient(opts...)
}

// WithBaseURL sets the API base URL.
func WithBaseURL(baseURL string) Option {
return func(o *Options) {
o.BaseURL = baseURL
}
}

// WithVersion sets the API version.
// WithHTTPClient sets the HTTP client.
func WithHTTPClient(httpClient *client.HTTP) Option {
return func(o *Options) {
o.HTTPClient = httpClient
}
}
28 changes: 28 additions & 0 deletions ollama/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package ollama

import (
"testing"

"github.com/milosgajdos/go-embeddings/client"
"github.com/stretchr/testify/assert"
)

func TestClient(t *testing.T) {
t.Run("BaseURL", func(t *testing.T) {
c := NewClient()
assert.Equal(t, c.opts.BaseURL, BaseURL)

testVal := "http://foo"
c = NewClient(WithBaseURL(testVal))
assert.Equal(t, c.opts.BaseURL, testVal)
})

t.Run("http client", func(t *testing.T) {
c := NewClient()
assert.NotNil(t, c.opts.HTTPClient)

testVal := client.NewHTTP()
c = NewClient(WithHTTPClient(testVal))
assert.NotNil(t, c.opts.HTTPClient)
})
}
67 changes: 67 additions & 0 deletions ollama/embedding.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package ollama

import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/url"

"github.com/milosgajdos/go-embeddings"
"github.com/milosgajdos/go-embeddings/request"
)

// EmbeddingRequest is serialized and sent to the API server.
type EmbeddingRequest struct {
Prompt any `json:"prompt"`
Model string `json:"model"`
}

// EmbedddingResponse received from API.
type EmbedddingResponse struct {
Embedding []float64 `json:"embedding"`
}

// ToEmbeddings converts the API response,
// into a slice of embeddings and returns it.
func (e *EmbedddingResponse) ToEmbeddings() ([]*embeddings.Embedding, error) {
floats := make([]float64, len(e.Embedding))
copy(floats, e.Embedding)
return []*embeddings.Embedding{
{Vector: floats},
}, nil
}

// Embed returns embeddings for every object in EmbeddingRequest.
func (c *Client) Embed(ctx context.Context, embReq *EmbeddingRequest) ([]*embeddings.Embedding, error) {
u, err := url.Parse(c.opts.BaseURL + "/embeddings")
if err != nil {
return nil, err
}

var body = &bytes.Buffer{}
enc := json.NewEncoder(body)
enc.SetEscapeHTML(false)
if err := enc.Encode(embReq); err != nil {
return nil, err
}

options := []request.Option{}
req, err := request.NewHTTP(ctx, http.MethodPost, u.String(), body, options...)
if err != nil {
return nil, err
}

resp, err := request.Do[APIError](c.opts.HTTPClient, req)
if err != nil {
return nil, err
}
defer resp.Body.Close()

e := new(EmbedddingResponse)
if err := json.NewDecoder(resp.Body).Decode(e); err != nil {
return nil, err
}

return e.ToEmbeddings()
}
17 changes: 17 additions & 0 deletions ollama/error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package ollama

import "encoding/json"

// APIError is Ollama API error.
type APIError struct {
ErrorMessage string `json:"error"`
}

// Error implements errors interface.
func (e APIError) Error() string {
b, err := json.Marshal(e)
if err != nil {
return "unknown error"
}
return string(b)
}
2 changes: 1 addition & 1 deletion openai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ type Options struct {
// Option is functional graph option.
type Option func(*Options)

// NewClient creates a new HTTP API client and returns it.
// NewClient creates a new OpenAI HTTP API client and returns it.
// By default it reads the OpenAI API key from OPENAI_API_KEY
// env var and uses the default Go http.Client for making API requests.
// You can override the default options via the client methods.
Expand Down

0 comments on commit 4ffb3f8

Please sign in to comment.