From 442bc94f512118a76ff8e0ee4019fd4d329c0675 Mon Sep 17 00:00:00 2001
From: Travis Cline
Date: Sat, 29 Jul 2023 15:41:30 -0700
Subject: [PATCH] llms: Add anthropic llm implementation

---
 .../sql_database_chain.go                   |   2 +-
 examples/zapier-llm-example/main.go         |   1 -
 llms/anthropic/anthropicllm.go              |  99 ++++++++++
 llms/anthropic/anthropicllm_option.go       |  27 +++
 .../anthropicclient/anthropicclient.go      | 105 +++++++++++
 .../internal/anthropicclient/completions.go | 176 ++++++++++++++++++
 6 files changed, 408 insertions(+), 2 deletions(-)
 create mode 100644 llms/anthropic/anthropicllm.go
 create mode 100644 llms/anthropic/anthropicllm_option.go
 create mode 100644 llms/anthropic/internal/anthropicclient/anthropicclient.go
 create mode 100644 llms/anthropic/internal/anthropicclient/completions.go

diff --git a/examples/sql-database-chain-example/sql_database_chain.go b/examples/sql-database-chain-example/sql_database_chain.go
index 2cb9a0fc3..c274fe57c 100644
--- a/examples/sql-database-chain-example/sql_database_chain.go
+++ b/examples/sql-database-chain-example/sql_database_chain.go
@@ -70,8 +70,8 @@ func makeSample(dsn string) {
 	if err != nil {
 		log.Fatal(err)
 	}
-
 }
+
 func run() error {
 	llm, err := openai.New()
 	if err != nil {
diff --git a/examples/zapier-llm-example/main.go b/examples/zapier-llm-example/main.go
index 617439bf2..7ffe3e2b7 100644
--- a/examples/zapier-llm-example/main.go
+++ b/examples/zapier-llm-example/main.go
@@ -54,5 +54,4 @@ func main() {
 		panic(err)
 	}
 	fmt.Println(answer)
-
 }
diff --git a/llms/anthropic/anthropicllm.go b/llms/anthropic/anthropicllm.go
new file mode 100644
index 000000000..fe9bb5db2
--- /dev/null
+++ b/llms/anthropic/anthropicllm.go
@@ -0,0 +1,99 @@
+package anthropic
+
+import (
+	"context"
+	"errors"
+	"os"
+
+	"github.com/tmc/langchaingo/llms"
+	"github.com/tmc/langchaingo/llms/anthropic/internal/anthropicclient"
+	"github.com/tmc/langchaingo/schema"
+)
+
+var (
+	ErrEmptyResponse = errors.New("no response")
+	ErrMissingToken  = errors.New("missing the Anthropic API key, set it in the ANTHROPIC_API_KEY environment variable")
+
+	ErrUnexpectedResponseLength = errors.New("unexpected length of response")
+)
+
+type LLM struct {
+	client *anthropicclient.Client
+}
+
+var (
+	_ llms.LLM           = (*LLM)(nil)
+	_ llms.LanguageModel = (*LLM)(nil)
+)
+
+// New returns a new Anthropic LLM.
+func New(opts ...Option) (*LLM, error) {
+	c, err := newClient(opts...)
+	return &LLM{
+		client: c,
+	}, err
+}
+
+func newClient(opts ...Option) (*anthropicclient.Client, error) {
+	options := &options{
+		token: os.Getenv(tokenEnvVarName),
+	}
+
+	for _, opt := range opts {
+		opt(options)
+	}
+
+	if len(options.token) == 0 {
+		return nil, ErrMissingToken
+	}
+
+	return anthropicclient.New(options.token, options.model)
+}
+
+// Call requests a completion for the given prompt.
+func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) {
+	r, err := o.Generate(ctx, []string{prompt}, options...)
+	if err != nil {
+		return "", err
+	}
+	if len(r) == 0 {
+		return "", ErrEmptyResponse
+	}
+	return r[0].Text, nil
+}
+
+func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.CallOption) ([]*llms.Generation, error) {
+	opts := llms.CallOptions{}
+	for _, opt := range options {
+		opt(&opts)
+	}
+
+	generations := make([]*llms.Generation, 0, len(prompts))
+	for _, prompt := range prompts {
+		result, err := o.client.CreateCompletion(ctx, &anthropicclient.CompletionRequest{
+			Model:         opts.Model,
+			Prompt:        prompt,
+			MaxTokens:     opts.MaxTokens,
+			StopWords:     opts.StopWords,
+			Temperature:   opts.Temperature,
+			TopP:          opts.TopP,
+			StreamingFunc: opts.StreamingFunc,
+		})
+		if err != nil {
+			return nil, err
+		}
+		generations = append(generations, &llms.Generation{
+			Text: result.Text,
+		})
+	}
+
+	return generations, nil
+}
+
+func (o *LLM) GeneratePrompt(ctx context.Context, promptValues []schema.PromptValue, options ...llms.CallOption) (llms.LLMResult, error) { //nolint:lll
+	return llms.GeneratePrompt(ctx, o, promptValues, options...)
+}
+
+func (o *LLM) GetNumTokens(text string) int {
+	return llms.CountTokens(o.client.Model, text)
+}
diff --git a/llms/anthropic/anthropicllm_option.go b/llms/anthropic/anthropicllm_option.go
new file mode 100644
index 000000000..de543c3b8
--- /dev/null
+++ b/llms/anthropic/anthropicllm_option.go
@@ -0,0 +1,27 @@
+package anthropic
+
+const (
+	tokenEnvVarName = "ANTHROPIC_API_KEY" //nolint:gosec
+)
+
+type options struct {
+	token string
+	model string
+}
+
+type Option func(*options)
+
+// WithToken passes the Anthropic API token to the client. If not set, the token
+// is read from the ANTHROPIC_API_KEY environment variable.
+func WithToken(token string) Option {
+	return func(opts *options) {
+		opts.token = token
+	}
+}
+
+// WithModel passes the Anthropic model to the client.
+func WithModel(model string) Option {
+	return func(opts *options) {
+		opts.model = model
+	}
+}
diff --git a/llms/anthropic/internal/anthropicclient/anthropicclient.go b/llms/anthropic/internal/anthropicclient/anthropicclient.go
new file mode 100644
index 000000000..05d8c3c30
--- /dev/null
+++ b/llms/anthropic/internal/anthropicclient/anthropicclient.go
@@ -0,0 +1,105 @@
+package anthropicclient
+
+import (
+	"context"
+	"errors"
+	"net/http"
+)
+
+const (
+	defaultBaseURL = "https://api.anthropic.com/v1"
+)
+
+// ErrEmptyResponse is returned when the Anthropic API returns an empty response.
+var ErrEmptyResponse = errors.New("empty response")
+
+// Client is a client for the Anthropic API.
+type Client struct {
+	token   string
+	Model   string
+	baseURL string
+
+	httpClient Doer
+}
+
+// Option is an option for the Anthropic client.
+type Option func(*Client) error
+
+// Doer performs an HTTP request.
+type Doer interface {
+	Do(req *http.Request) (*http.Response, error)
+}
+
+// WithHTTPClient allows setting a custom HTTP client.
+func WithHTTPClient(client Doer) Option {
+	return func(c *Client) error {
+		c.httpClient = client
+
+		return nil
+	}
+}
+
+// New returns a new Anthropic client.
+func New(token string, model string, opts ...Option) (*Client, error) {
+	c := &Client{
+		Model:      model,
+		token:      token,
+		baseURL:    defaultBaseURL,
+		httpClient: http.DefaultClient,
+	}
+
+	for _, opt := range opts {
+		if err := opt(c); err != nil {
+			return nil, err
+		}
+	}
+
+	return c, nil
+}
+
+// CompletionRequest is a request to create a completion.
+type CompletionRequest struct {
+	Model       string   `json:"model"`
+	Prompt      string   `json:"prompt"`
+	Temperature float64  `json:"temperature,omitempty"`
+	MaxTokens   int      `json:"max_tokens_to_sample,omitempty"`
+	StopWords   []string `json:"stop_sequences,omitempty"`
+	TopP        float64  `json:"top_p,omitempty"`
+	Stream      bool     `json:"stream,omitempty"`
+
+	// StreamingFunc is a function to be called for each chunk of a streaming response.
+	// Return an error to stop streaming early.
+	StreamingFunc func(ctx context.Context, chunk []byte) error `json:"-"`
+}
+
+// Completion is a completion.
+type Completion struct {
+	Text string `json:"text"`
+}
+
+// CreateCompletion creates a completion.
+func (c *Client) CreateCompletion(ctx context.Context, r *CompletionRequest) (*Completion, error) {
+	resp, err := c.createCompletion(ctx, &completionPayload{
+		Model:         r.Model,
+		Prompt:        r.Prompt,
+		Temperature:   r.Temperature,
+		MaxTokens:     r.MaxTokens,
+		StopWords:     r.StopWords,
+		TopP:          r.TopP,
+		Stream:        r.Stream,
+		StreamingFunc: r.StreamingFunc,
+	})
+	if err != nil {
+		return nil, err
+	}
+	return &Completion{
+		Text: resp.Completion,
+	}, nil
+}
+
+func (c *Client) setHeaders(req *http.Request) {
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("x-api-key", c.token)
+	// TODO: expose the version as an option/parameter.
+	req.Header.Set("anthropic-version", "2023-06-01")
+}
diff --git a/llms/anthropic/internal/anthropicclient/completions.go b/llms/anthropic/internal/anthropicclient/completions.go
new file mode 100644
index 000000000..c96033b2a
--- /dev/null
+++ b/llms/anthropic/internal/anthropicclient/completions.go
@@ -0,0 +1,176 @@
+package anthropicclient
+
+import (
+	"bufio"
+	"bytes"
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"log"
+	"net/http"
+	"strings"
+)
+
+const (
+	defaultCompletionModel = "claude-instant-1"
+)
+
+type completionPayload struct {
+	Model       string   `json:"model"`
+	Prompt      string   `json:"prompt"`
+	Temperature float64  `json:"temperature,omitempty"`
+	MaxTokens   int      `json:"max_tokens_to_sample,omitempty"`
+	TopP        float64  `json:"top_p,omitempty"`
+	StopWords   []string `json:"stop_sequences,omitempty"`
+	Stream      bool     `json:"stream,omitempty"`
+
+	StreamingFunc func(ctx context.Context, chunk []byte) error `json:"-"`
+}
+
+type CompletionResponsePayload struct {
+	Completion string `json:"completion,omitempty"`
+	LogID      string `json:"log_id,omitempty"`
+	Model      string `json:"model,omitempty"`
+	Stop       string `json:"stop,omitempty"`
+	StopReason string `json:"stop_reason,omitempty"`
+}
+
+type errorMessage struct {
+	Error struct {
+		Message string `json:"message"`
+		Type    string `json:"type"`
+	} `json:"error"`
+}
+
+func (c *Client) setCompletionDefaults(payload *completionPayload) {
+	// Set defaults
+	if payload.MaxTokens == 0 {
+		payload.MaxTokens = 256
+	}
+
+	if len(payload.StopWords) == 0 {
+		payload.StopWords = nil
+	}
+
+	switch {
+	// Prefer the model specified in the payload.
+	case payload.Model != "":
+
+	// If no model is set in the payload, take the one specified in the client.
+	case c.Model != "":
+		payload.Model = c.Model
+	// Fallback: use the default model.
+	default:
+		payload.Model = defaultCompletionModel
+	}
+	if payload.StreamingFunc != nil {
+		payload.Stream = true
+	}
+}
+
+// nolint:lll
+func (c *Client) createCompletion(ctx context.Context, payload *completionPayload) (*CompletionResponsePayload, error) {
+	c.setCompletionDefaults(payload)
+
+	// Build request payload
+	payloadBytes, err := json.Marshal(payload)
+	if err != nil {
+		return nil, fmt.Errorf("marshal payload: %w", err)
+	}
+
+	if c.baseURL == "" {
+		c.baseURL = defaultBaseURL
+	}
+
+	url := fmt.Sprintf("%s/complete", c.baseURL)
+	// Build request
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payloadBytes))
+	if err != nil {
+		return nil, fmt.Errorf("create request: %w", err)
+	}
+
+	c.setHeaders(req)
+
+	// Send request
+	r, err := c.httpClient.Do(req)
+	if err != nil {
+		return nil, fmt.Errorf("send request: %w", err)
+	}
+	defer r.Body.Close()
+
+	if r.StatusCode != http.StatusOK {
+		msg := fmt.Sprintf("API returned unexpected status code: %d", r.StatusCode)
+
+		// If the error payload can't be decoded, fall back to reporting
+		// just the status code.
+		var errResp errorMessage
+		if err := json.NewDecoder(r.Body).Decode(&errResp); err != nil {
+			return nil, errors.New(msg) // nolint:goerr113
+		}
+
+		return nil, fmt.Errorf("%s: %s", msg, errResp.Error.Message) // nolint:goerr113
+	}
+	if payload.StreamingFunc != nil {
+		// Read chunks
+		return parseStreamingCompletionResponse(ctx, r, payload)
+	}
+
+	// Parse response
+	var response CompletionResponsePayload
+	if err := json.NewDecoder(r.Body).Decode(&response); err != nil {
+		return nil, fmt.Errorf("parse response: %w", err)
+	}
+
+	return &response, nil
+}
+
+func parseStreamingCompletionResponse(ctx context.Context, r *http.Response, payload *completionPayload) (*CompletionResponsePayload, error) { // nolint:lll
+	scanner := bufio.NewScanner(r.Body)
+	responseChan := make(chan *CompletionResponsePayload)
+	go func() {
+		defer close(responseChan)
+		for scanner.Scan() {
+			line := scanner.Text()
+			if line == "" {
+				continue
+			}
+			if !strings.HasPrefix(line, "data: ") {
+				continue
+			}
+			data := strings.TrimPrefix(line, "data: ")
+			streamPayload := &CompletionResponsePayload{}
+			// Skip malformed chunks; log.Fatalf here would kill the process.
+			if err := json.Unmarshal([]byte(data), streamPayload); err != nil {
+				log.Printf("failed to decode stream payload: %v", err)
+				continue
+			}
+			responseChan <- streamPayload
+		}
+		if err := scanner.Err(); err != nil {
+			log.Println("issue scanning response:", err)
+		}
+	}()
+	// Accumulate the streamed chunks into a single response.
+	response := CompletionResponsePayload{}
+
+	var lastResponse *CompletionResponsePayload
+	for streamResponse := range responseChan {
+		response.Completion += streamResponse.Completion
+		if payload.StreamingFunc != nil {
+			err := payload.StreamingFunc(ctx, []byte(streamResponse.Completion))
+			if err != nil {
+				return nil, fmt.Errorf("streaming func returned an error: %w", err)
+			}
+		}
+		lastResponse = streamResponse
+	}
+	// Guard against an empty stream: lastResponse is nil if no chunks arrived.
+	if lastResponse == nil {
+		return nil, ErrEmptyResponse
+	}
+	response.Model = lastResponse.Model
+	response.LogID = lastResponse.LogID
+	response.Stop = lastResponse.Stop
+	response.StopReason = lastResponse.StopReason
+	return &response, nil
+}
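
Note for reviewers: below is a minimal usage sketch of the new package, not part of the diff above. The Human:/Assistant: prompt framing follows the convention of Anthropic's /v1/complete API rather than anything this patch enforces, and llms.WithStreamingFunc is assumed to be the existing langchaingo call option that populates CallOptions.StreamingFunc; treat both as assumptions, and the program itself as hypothetical.

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/anthropic"
)

func main() {
	// Reads the key from ANTHROPIC_API_KEY; anthropic.WithToken and
	// anthropic.WithModel override the defaults.
	llm, err := anthropic.New()
	if err != nil {
		log.Fatal(err)
	}
	ctx := context.Background()

	// claude-instant-1 (the package default) expects Human:/Assistant: turns
	// in the prompt (an assumption about the API contract).
	prompt := "\n\nHuman: Name one Go proverb.\n\nAssistant:"

	// Plain completion: blocks until the full response arrives.
	completion, err := llm.Call(ctx, prompt)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(completion)

	// Streaming: the callback receives each SSE chunk as it arrives;
	// returning a non-nil error aborts the stream.
	_, err = llm.Call(ctx, prompt, llms.WithStreamingFunc(
		func(_ context.Context, chunk []byte) error {
			fmt.Print(string(chunk))
			return nil
		},
	))
	if err != nil {
		log.Fatal(err)
	}
}

Passing the streaming option exercises the parseStreamingCompletionResponse path in completions.go: CreateCompletion sets Stream to true whenever StreamingFunc is non-nil, so callers opt in to SSE simply by supplying the callback.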