Skip to content

Commit fef0821

Browse files
authored
Merge pull request #253 from nidzola/embeddings-azure-openai
embeddings/azure-openai: Fixing issues with azure openai implementation
2 parents eb0cbd3 + 97c7bc7 commit fef0821

File tree

8 files changed

+81
-25
lines changed

8 files changed

+81
-25
lines changed

chains/llm_azure_test.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,22 @@ import (
1313

1414
func TestLLMChainAzure(t *testing.T) {
1515
t.Parallel()
16+
// Azure OpenAI Key is used as OPENAI_API_KEY
1617
if openaiKey := os.Getenv("OPENAI_API_KEY"); openaiKey == "" {
1718
t.Skip("OPENAI_API_KEY not set")
1819
}
20+
// Azure OpenAI URL is used as OPENAI_BASE_URL
1921
if openaiBase := os.Getenv("OPENAI_BASE_URL"); openaiBase == "" {
2022
t.Skip("OPENAI_BASE_URL not set")
2123
}
22-
model, err := openai.New(openai.WithModel("gpt-35-turbo"), openai.WithAPIType(openai.APITypeAzure))
24+
25+
model, err := openai.New(
26+
openai.WithAPIType(openai.APITypeAzure),
27+
// Azure deployment that uses desired model, the name depends on what we define in the Azure deployment section
28+
openai.WithModel("model-name"),
29+
// Azure deployment that uses embeddings model, the name depends on what we define in the Azure deployment section
30+
openai.WithEmbeddingModel("embeddings-model-name"),
31+
)
2332
require.NoError(t, err)
2433

2534
prompt := prompts.NewPromptTemplate(

embeddings/openai/openai_test.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,35 @@ func TestOpenaiEmbeddingsWithOptions(t *testing.T) {
4747
require.NoError(t, err)
4848
assert.Len(t, embeddings, 1)
4949
}
50+
51+
func TestOpenaiEmbeddingsWithAzureAPI(t *testing.T) {
52+
t.Parallel()
53+
54+
// Azure OpenAI Key is used as OPENAI_API_KEY
55+
if openaiKey := os.Getenv("OPENAI_API_KEY"); openaiKey == "" {
56+
t.Skip("OPENAI_API_KEY not set")
57+
}
58+
// Azure OpenAI URL is used as OPENAI_BASE_URL
59+
if openaiBase := os.Getenv("OPENAI_BASE_URL"); openaiBase == "" {
60+
t.Skip("OPENAI_BASE_URL not set")
61+
}
62+
63+
client, err := openai.New(
64+
openai.WithAPIType(openai.APITypeAzure),
65+
// Azure deployment that uses desired model the name depends on what we define in the Azure deployment section
66+
openai.WithModel("model"),
67+
// Azure deployment that uses embeddings model, the name depends on what we define in the Azure deployment section
68+
openai.WithEmbeddingModel("model-embedding"),
69+
)
70+
assert.NoError(t, err)
71+
72+
e, err := NewOpenAI(WithClient(*client), WithBatchSize(1), WithStripNewLines(false))
73+
require.NoError(t, err)
74+
75+
_, err = e.EmbedQuery(context.Background(), "Hello world!")
76+
require.NoError(t, err)
77+
78+
embeddings, err := e.EmbedDocuments(context.Background(), []string{"Hello world"})
79+
require.NoError(t, err)
80+
assert.Len(t, embeddings, 1)
81+
}

llms/openai/internal/openaiclient/chat.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ func (c *Client) createChat(ctx context.Context, payload *ChatRequest) (*ChatRes
145145
if c.baseURL == "" {
146146
c.baseURL = defaultBaseURL
147147
}
148-
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.buildURL("/chat/completions"), body)
148+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.buildURL("/chat/completions", c.Model), body)
149149
if err != nil {
150150
return nil, err
151151
}

llms/openai/internal/openaiclient/completions.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ func (c *Client) createCompletion(ctx context.Context, payload *completionPayloa
8888
}
8989

9090
// Build request
91-
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.buildURL("/completions"), bytes.NewReader(payloadBytes))
91+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.buildURL("/completions", c.Model), bytes.NewReader(payloadBytes))
9292
if err != nil {
9393
return nil, fmt.Errorf("create request: %w", err)
9494
}

llms/openai/internal/openaiclient/embeddings.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ func (c *Client) createEmbedding(ctx context.Context, payload *embeddingPayload)
4141
if c.baseURL == "" {
4242
c.baseURL = defaultBaseURL
4343
}
44-
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.buildURL("/embeddings"), bytes.NewReader(payloadBytes))
44+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.buildURL("/embeddings", c.embeddingsModel), bytes.NewReader(payloadBytes))
4545
if err != nil {
4646
return nil, fmt.Errorf("create request: %w", err)
4747
}

llms/openai/internal/openaiclient/openaiclient.go

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,12 @@ type Client struct {
3030
Model string
3131
baseURL string
3232
organization string
33+
apiType APIType
34+
httpClient Doer
3335

34-
apiType APIType
35-
apiVersion string // required when APIType is APITypeAzure or APITypeAzureAD
36-
37-
httpClient Doer
36+
// required when APIType is APITypeAzure or APITypeAzureAD
37+
apiVersion string
38+
embeddingsModel string
3839
}
3940

4041
// Option is an option for the OpenAI client.
@@ -47,17 +48,18 @@ type Doer interface {
4748

4849
// New returns a new OpenAI client.
4950
func New(token string, model string, baseURL string, organization string,
50-
apiType APIType, apiVersion string, httpClient Doer,
51+
apiType APIType, apiVersion string, httpClient Doer, embeddingsModel string,
5152
opts ...Option,
5253
) (*Client, error) {
5354
c := &Client{
54-
token: token,
55-
Model: model,
56-
baseURL: baseURL,
57-
organization: organization,
58-
apiType: apiType,
59-
apiVersion: apiVersion,
60-
httpClient: httpClient,
55+
token: token,
56+
Model: model,
57+
embeddingsModel: embeddingsModel,
58+
baseURL: baseURL,
59+
organization: organization,
60+
apiType: apiType,
61+
apiVersion: apiVersion,
62+
httpClient: httpClient,
6163
}
6264

6365
for _, opt := range opts {
@@ -181,22 +183,22 @@ func (c *Client) setHeaders(req *http.Request) {
181183
}
182184
}
183185

184-
func (c *Client) buildURL(suffix string) string {
186+
func (c *Client) buildURL(suffix string, model string) string {
185187
if IsAzure(c.apiType) {
186-
return c.buildAzureURL(suffix)
188+
return c.buildAzureURL(suffix, model)
187189
}
188190

189191
// open ai implement:
190192
return fmt.Sprintf("%s%s", c.baseURL, suffix)
191193
}
192194

193-
func (c *Client) buildAzureURL(suffix string) string {
195+
func (c *Client) buildAzureURL(suffix string, model string) string {
194196
baseURL := c.baseURL
195197
baseURL = strings.TrimRight(baseURL, "/")
196198

197199
// azure example url:
198200
// /openai/deployments/{model}/chat/completions?api-version={api_version}
199201
return fmt.Sprintf("%s/openai/deployments/%s%s?api-version=%s",
200-
baseURL, c.Model, suffix, c.apiVersion,
202+
baseURL, model, suffix, c.apiVersion,
201203
)
202204
}

llms/openai/llm.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@ import (
99
)
1010

1111
var (
12-
ErrEmptyResponse = errors.New("no response")
13-
ErrMissingToken = errors.New("missing the OpenAI API key, set it in the OPENAI_API_KEY environment variable")
12+
ErrEmptyResponse = errors.New("no response")
13+
ErrMissingToken = errors.New("missing the OpenAI API key, set it in the OPENAI_API_KEY environment variable") //nolint:lll
14+
ErrMissingAzureEmbeddingModel = errors.New("embeddings model needs to be provided when using Azure API")
1415

1516
ErrUnexpectedResponseLength = errors.New("unexpected length of response")
1617
)
@@ -33,12 +34,15 @@ func newClient(opts ...Option) (*openaiclient.Client, error) {
3334
// set of options needed for Azure client
3435
if openaiclient.IsAzure(openaiclient.APIType(options.apiType)) && options.apiVersion == "" {
3536
options.apiVersion = DefaultAPIVersion
37+
if options.embeddingModel == "" {
38+
return nil, ErrMissingAzureEmbeddingModel
39+
}
3640
}
3741

3842
if len(options.token) == 0 {
3943
return nil, ErrMissingToken
4044
}
4145

4246
return openaiclient.New(options.token, options.model, options.baseURL, options.organization,
43-
openaiclient.APIType(options.apiType), options.apiVersion, options.httpClient)
47+
openaiclient.APIType(options.apiType), options.apiVersion, options.httpClient, options.embeddingModel)
4448
}

llms/openai/openaillm_option.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,11 @@ type options struct {
2626
model string
2727
baseURL string
2828
organization string
29+
apiType APIType
2930

30-
apiType APIType
31-
apiVersion string // required when APIType is APITypeAzure or APITypeAzureAD
31+
// required when APIType is APITypeAzure or APITypeAzureAD
32+
apiVersion string
33+
embeddingModel string
3234

3335
httpClient openaiclient.Doer
3436
}
@@ -51,6 +53,13 @@ func WithModel(model string) Option {
5153
}
5254
}
5355

56+
// WithEmbeddingModel passes the OpenAI model to the client. Required when ApiType is Azure.
57+
func WithEmbeddingModel(embeddingModel string) Option {
58+
return func(opts *options) {
59+
opts.embeddingModel = embeddingModel
60+
}
61+
}
62+
5463
// WithBaseURL passes the OpenAI base url to the client. If not set, the base url
5564
// is read from the OPENAI_BASE_URL environment variable. If still not set in ENV
5665
// VAR OPENAI_BASE_URL, then the default value is https://api.openai.com/v1 is used.

0 commit comments

Comments
 (0)