Skip to content

Commit

Permalink
Merge pull request #253 from nidzola/embeddings-azure-openai
Browse files Browse the repository at this point in the history
embeddings/azure-openai: Fixing issues with azure openai implementation
  • Loading branch information
tmc authored Aug 18, 2023
2 parents eb0cbd3 + 97c7bc7 commit fef0821
Show file tree
Hide file tree
Showing 8 changed files with 81 additions and 25 deletions.
11 changes: 10 additions & 1 deletion chains/llm_azure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,22 @@ import (

func TestLLMChainAzure(t *testing.T) {
t.Parallel()
// Azure OpenAI Key is used as OPENAI_API_KEY
if openaiKey := os.Getenv("OPENAI_API_KEY"); openaiKey == "" {
t.Skip("OPENAI_API_KEY not set")
}
// Azure OpenAI URL is used as OPENAI_BASE_URL
if openaiBase := os.Getenv("OPENAI_BASE_URL"); openaiBase == "" {
t.Skip("OPENAI_BASE_URL not set")
}
model, err := openai.New(openai.WithModel("gpt-35-turbo"), openai.WithAPIType(openai.APITypeAzure))

model, err := openai.New(
openai.WithAPIType(openai.APITypeAzure),
// Azure deployment that uses desired model, the name depends on what we define in the Azure deployment section
openai.WithModel("model-name"),
// Azure deployment that uses embeddings model, the name depends on what we define in the Azure deployment section
openai.WithEmbeddingModel("embeddings-model-name"),
)
require.NoError(t, err)

prompt := prompts.NewPromptTemplate(
Expand Down
32 changes: 32 additions & 0 deletions embeddings/openai/openai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,35 @@ func TestOpenaiEmbeddingsWithOptions(t *testing.T) {
require.NoError(t, err)
assert.Len(t, embeddings, 1)
}

// TestOpenaiEmbeddingsWithAzureAPI exercises EmbedQuery and EmbedDocuments
// against an Azure OpenAI deployment. It is skipped unless both the Azure key
// (OPENAI_API_KEY) and the Azure endpoint (OPENAI_BASE_URL) are configured in
// the environment.
func TestOpenaiEmbeddingsWithAzureAPI(t *testing.T) {
	t.Parallel()

	// Azure OpenAI Key is used as OPENAI_API_KEY
	if openaiKey := os.Getenv("OPENAI_API_KEY"); openaiKey == "" {
		t.Skip("OPENAI_API_KEY not set")
	}
	// Azure OpenAI URL is used as OPENAI_BASE_URL
	if openaiBase := os.Getenv("OPENAI_BASE_URL"); openaiBase == "" {
		t.Skip("OPENAI_BASE_URL not set")
	}

	client, err := openai.New(
		openai.WithAPIType(openai.APITypeAzure),
		// Azure deployment that uses desired model, the name depends on what we define in the Azure deployment section
		openai.WithModel("model"),
		// Azure deployment that uses embeddings model, the name depends on what we define in the Azure deployment section
		openai.WithEmbeddingModel("model-embedding"),
	)
	// require (not assert): assert.NoError records the failure but keeps
	// executing, so a failed client construction would dereference the nil
	// *client on the next line and panic. require stops the test here.
	require.NoError(t, err)

	e, err := NewOpenAI(WithClient(*client), WithBatchSize(1), WithStripNewLines(false))
	require.NoError(t, err)

	_, err = e.EmbedQuery(context.Background(), "Hello world!")
	require.NoError(t, err)

	embeddings, err := e.EmbedDocuments(context.Background(), []string{"Hello world"})
	require.NoError(t, err)
	assert.Len(t, embeddings, 1)
}
2 changes: 1 addition & 1 deletion llms/openai/internal/openaiclient/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func (c *Client) createChat(ctx context.Context, payload *ChatRequest) (*ChatRes
if c.baseURL == "" {
c.baseURL = defaultBaseURL
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.buildURL("/chat/completions"), body)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.buildURL("/chat/completions", c.Model), body)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion llms/openai/internal/openaiclient/completions.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func (c *Client) createCompletion(ctx context.Context, payload *completionPayloa
}

// Build request
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.buildURL("/completions"), bytes.NewReader(payloadBytes))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.buildURL("/completions", c.Model), bytes.NewReader(payloadBytes))
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion llms/openai/internal/openaiclient/embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func (c *Client) createEmbedding(ctx context.Context, payload *embeddingPayload)
if c.baseURL == "" {
c.baseURL = defaultBaseURL
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.buildURL("/embeddings"), bytes.NewReader(payloadBytes))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.buildURL("/embeddings", c.embeddingsModel), bytes.NewReader(payloadBytes))
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
Expand Down
34 changes: 18 additions & 16 deletions llms/openai/internal/openaiclient/openaiclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,12 @@ type Client struct {
Model string
baseURL string
organization string
apiType APIType
httpClient Doer

apiType APIType
apiVersion string // required when APIType is APITypeAzure or APITypeAzureAD

httpClient Doer
// required when APIType is APITypeAzure or APITypeAzureAD
apiVersion string
embeddingsModel string
}

// Option is an option for the OpenAI client.
Expand All @@ -47,17 +48,18 @@ type Doer interface {

// New returns a new OpenAI client.
func New(token string, model string, baseURL string, organization string,
apiType APIType, apiVersion string, httpClient Doer,
apiType APIType, apiVersion string, httpClient Doer, embeddingsModel string,
opts ...Option,
) (*Client, error) {
c := &Client{
token: token,
Model: model,
baseURL: baseURL,
organization: organization,
apiType: apiType,
apiVersion: apiVersion,
httpClient: httpClient,
token: token,
Model: model,
embeddingsModel: embeddingsModel,
baseURL: baseURL,
organization: organization,
apiType: apiType,
apiVersion: apiVersion,
httpClient: httpClient,
}

for _, opt := range opts {
Expand Down Expand Up @@ -181,22 +183,22 @@ func (c *Client) setHeaders(req *http.Request) {
}
}

func (c *Client) buildURL(suffix string) string {
func (c *Client) buildURL(suffix string, model string) string {
if IsAzure(c.apiType) {
return c.buildAzureURL(suffix)
return c.buildAzureURL(suffix, model)
}

// non-Azure case: the standard OpenAI URL layout is just baseURL + suffix.
return fmt.Sprintf("%s%s", c.baseURL, suffix)
}

func (c *Client) buildAzureURL(suffix string) string {
func (c *Client) buildAzureURL(suffix string, model string) string {
baseURL := c.baseURL
baseURL = strings.TrimRight(baseURL, "/")

// azure example url:
// /openai/deployments/{model}/chat/completions?api-version={api_version}
return fmt.Sprintf("%s/openai/deployments/%s%s?api-version=%s",
baseURL, c.Model, suffix, c.apiVersion,
baseURL, model, suffix, c.apiVersion,
)
}
10 changes: 7 additions & 3 deletions llms/openai/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ import (
)

var (
ErrEmptyResponse = errors.New("no response")
ErrMissingToken = errors.New("missing the OpenAI API key, set it in the OPENAI_API_KEY environment variable")
ErrEmptyResponse = errors.New("no response")
ErrMissingToken = errors.New("missing the OpenAI API key, set it in the OPENAI_API_KEY environment variable") //nolint:lll
ErrMissingAzureEmbeddingModel = errors.New("embeddings model needs to be provided when using Azure API")

ErrUnexpectedResponseLength = errors.New("unexpected length of response")
)
Expand All @@ -33,12 +34,15 @@ func newClient(opts ...Option) (*openaiclient.Client, error) {
// set of options needed for Azure client
if openaiclient.IsAzure(openaiclient.APIType(options.apiType)) && options.apiVersion == "" {
options.apiVersion = DefaultAPIVersion
if options.embeddingModel == "" {
return nil, ErrMissingAzureEmbeddingModel
}
}

if len(options.token) == 0 {
return nil, ErrMissingToken
}

return openaiclient.New(options.token, options.model, options.baseURL, options.organization,
openaiclient.APIType(options.apiType), options.apiVersion, options.httpClient)
openaiclient.APIType(options.apiType), options.apiVersion, options.httpClient, options.embeddingModel)
}
13 changes: 11 additions & 2 deletions llms/openai/openaillm_option.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@ type options struct {
model string
baseURL string
organization string
apiType APIType

apiType APIType
apiVersion string // required when APIType is APITypeAzure or APITypeAzureAD
// required when APIType is APITypeAzure or APITypeAzureAD
apiVersion string
embeddingModel string

httpClient openaiclient.Doer
}
Expand All @@ -51,6 +53,13 @@ func WithModel(model string) Option {
}
}

// WithEmbeddingModel passes the OpenAI model to the client. Required when ApiType is Azure.
func WithEmbeddingModel(embeddingModel string) Option {
return func(opts *options) {
opts.embeddingModel = embeddingModel
}
}

// WithBaseURL passes the OpenAI base url to the client. If not set, the base url
// is read from the OPENAI_BASE_URL environment variable. If still not set in ENV
// VAR OPENAI_BASE_URL, then the default value is https://api.openai.com/v1 is used.
Expand Down

0 comments on commit fef0821

Please sign in to comment.