embedding.go
package sdk

import (
	"context"
	"net/http"
	"net/url"

	"github.com/gogf/gf/v2/os/gtime"
	"github.com/iimeta/fastapi-sdk/logger"
	"github.com/iimeta/fastapi-sdk/model"
	"github.com/iimeta/go-openai"
)
// EmbeddingClient wraps an OpenAI client for embedding requests.
type EmbeddingClient struct {
	client *openai.Client
}
// NewEmbeddingClient builds an OpenAI embeddings client. A non-empty baseURL
// overrides the default API endpoint, and the first proxyURL, if provided,
// routes requests through an HTTP proxy.
func NewEmbeddingClient(ctx context.Context, model, key, baseURL, path string, proxyURL ...string) *EmbeddingClient {

	logger.Infof(ctx, "NewClient OpenAI model: %s, key: %s", model, key)

	config := openai.DefaultConfig(key)

	if baseURL != "" {
		logger.Infof(ctx, "NewClient OpenAI model: %s, baseURL: %s", model, baseURL)
		config.BaseURL = baseURL
	}

	if len(proxyURL) > 0 && proxyURL[0] != "" {
		logger.Infof(ctx, "NewClient OpenAI model: %s, proxyURL: %s", model, proxyURL[0])

		proxyUrl, err := url.Parse(proxyURL[0])
		if err != nil {
			panic(err)
		}

		config.HTTPClient = &http.Client{
			Transport: &http.Transport{
				Proxy: http.ProxyURL(proxyUrl),
			},
		}
	}

	return &EmbeddingClient{
		client: openai.NewClientWithConfig(config),
	}
}
// Embeddings sends an embedding request to the OpenAI API and converts the
// response into the SDK's model.EmbeddingResponse, recording the total
// request time in milliseconds via the deferred function.
func (c *EmbeddingClient) Embeddings(ctx context.Context, request model.EmbeddingRequest) (res model.EmbeddingResponse, err error) {

	logger.Infof(ctx, "Embeddings OpenAI model: %s start", request.Model)

	now := gtime.Now().UnixMilli()
	defer func() {
		res.TotalTime = gtime.Now().UnixMilli() - now
		logger.Infof(ctx, "Embeddings OpenAI model: %s totalTime: %d ms", request.Model, res.TotalTime)
	}()

	response, err := c.client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{
		Input:          request.Input,
		Model:          request.Model,
		User:           request.User,
		EncodingFormat: request.EncodingFormat,
		Dimensions:     request.Dimensions,
	})
	if err != nil {
		logger.Errorf(ctx, "Embeddings OpenAI model: %s, error: %v", request.Model, err)
		return res, err
	}

	logger.Infof(ctx, "Embeddings OpenAI model: %s finished", request.Model)

	res = model.EmbeddingResponse{
		Object: response.Object,
		Data:   response.Data,
		Model:  response.Model,
		Usage: &model.Usage{
			PromptTokens:     response.Usage.PromptTokens,
			CompletionTokens: response.Usage.CompletionTokens,
			TotalTokens:      response.Usage.TotalTokens,
		},
	}

	return res, nil
}
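
// Usage sketch (not part of the original file): one plausible way to call this
// client. The model name, API key, and Input value below are illustrative
// placeholders, and the exact shape of model.EmbeddingRequest.Input follows
// whatever the SDK's model package defines (go-openai itself accepts a string
// or a slice of strings).
//
//	client := NewEmbeddingClient(ctx, "text-embedding-3-small", "sk-placeholder", "", "")
//	res, err := client.Embeddings(ctx, model.EmbeddingRequest{
//		Model: "text-embedding-3-small",
//		Input: []string{"hello world"},
//	})
//	if err != nil {
//		// handle the error
//		return
//	}
//	logger.Infof(ctx, "embedding usage: %d total tokens", res.Usage.TotalTokens)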