llms: Add anthropic LLM implementation #226

Merged 2 commits on Jul 29, 2023
2 changes: 1 addition & 1 deletion examples/sql-database-chain-example/sql_database_chain.go
@@ -70,8 +70,8 @@ func makeSample(dsn string) {
	if err != nil {
		log.Fatal(err)
	}

}

func run() error {
	llm, err := openai.New()
	if err != nil {
1 change: 0 additions & 1 deletion examples/zapier-llm-example/main.go
@@ -54,5 +54,4 @@ func main() {
		panic(err)
	}
	fmt.Println(answer)

}
99 changes: 99 additions & 0 deletions llms/anthropic/anthropicllm.go
@@ -0,0 +1,99 @@
package anthropic

import (
	"context"
	"errors"
	"os"

	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/anthropic/internal/anthropicclient"
	"github.com/tmc/langchaingo/schema"
)

var (
	ErrEmptyResponse = errors.New("no response")
	ErrMissingToken  = errors.New("missing the Anthropic API key, set it in the ANTHROPIC_API_KEY environment variable")

	ErrUnexpectedResponseLength = errors.New("unexpected length of response")
)

type LLM struct {
	client *anthropicclient.Client
}

var (
	_ llms.LLM           = (*LLM)(nil)
	_ llms.LanguageModel = (*LLM)(nil)
)

// New returns a new Anthropic LLM.
func New(opts ...Option) (*LLM, error) {
	c, err := newClient(opts...)
	return &LLM{
		client: c,
	}, err
}

func newClient(opts ...Option) (*anthropicclient.Client, error) {
	options := &options{
		token: os.Getenv(tokenEnvVarName),
	}

	for _, opt := range opts {
		opt(options)
	}

	if len(options.token) == 0 {
		return nil, ErrMissingToken
	}

	return anthropicclient.New(options.token, options.model)
}

// Call requests a completion for the given prompt.
func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) {
	r, err := o.Generate(ctx, []string{prompt}, options...)
	if err != nil {
		return "", err
	}
	if len(r) == 0 {
		return "", ErrEmptyResponse
	}
	return r[0].Text, nil
}

func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.CallOption) ([]*llms.Generation, error) {
	opts := llms.CallOptions{}
	for _, opt := range options {
		opt(&opts)
	}

	generations := make([]*llms.Generation, 0, len(prompts))
	for _, prompt := range prompts {
		result, err := o.client.CreateCompletion(ctx, &anthropicclient.CompletionRequest{
			Model:         opts.Model,
			Prompt:        prompt,
			MaxTokens:     opts.MaxTokens,
			StopWords:     opts.StopWords,
			Temperature:   opts.Temperature,
			TopP:          opts.TopP,
			StreamingFunc: opts.StreamingFunc,
		})
		if err != nil {
			return nil, err
		}
		generations = append(generations, &llms.Generation{
			Text: result.Text,
		})
	}

	return generations, nil
}

func (o *LLM) GeneratePrompt(ctx context.Context, promptValues []schema.PromptValue, options ...llms.CallOption) (llms.LLMResult, error) { //nolint:lll
	return llms.GeneratePrompt(ctx, o, promptValues, options...)
}

func (o *LLM) GetNumTokens(text string) int {
	return llms.CountTokens(o.client.Model, text)
}
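For context when reviewing, here is a minimal usage sketch of the new package; it is not part of the diff. It assumes ANTHROPIC_API_KEY is set in the environment (otherwise anthropic.WithToken can be passed), and the model name "claude-2" and the prompt are placeholder values:

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/tmc/langchaingo/llms/anthropic"
)

func main() {
	// Reads the API key from ANTHROPIC_API_KEY; "claude-2" is a placeholder model name.
	llm, err := anthropic.New(anthropic.WithModel("claude-2"))
	if err != nil {
		log.Fatal(err)
	}
	// Call wraps Generate for a single prompt and returns the first generation's text.
	completion, err := llm.Call(context.Background(), "Hello, how are you?")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(completion)
}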
27 changes: 27 additions & 0 deletions llms/anthropic/anthropicllm_option.go
@@ -0,0 +1,27 @@
package anthropic

const (
	tokenEnvVarName = "ANTHROPIC_API_KEY" //nolint:gosec
)

type options struct {
	token string
	model string
}

type Option func(*options)

// WithToken passes the Anthropic API token to the client. If not set, the token
// is read from the ANTHROPIC_API_KEY environment variable.
func WithToken(token string) Option {
	return func(opts *options) {
		opts.token = token
	}
}

// WithModel passes the Anthropic model to the client.
func WithModel(model string) Option {
	return func(opts *options) {
		opts.model = model
	}
}
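A short sketch of how these functional options compose with New when the token should come from somewhere other than ANTHROPIC_API_KEY; the MY_ANTHROPIC_KEY variable and the "claude-instant-1" model name are placeholders, not values from this change:

package main

import (
	"log"
	"os"

	"github.com/tmc/langchaingo/llms/anthropic"
)

func main() {
	// MY_ANTHROPIC_KEY and "claude-instant-1" are placeholder values.
	llm, err := anthropic.New(
		anthropic.WithToken(os.Getenv("MY_ANTHROPIC_KEY")),
		anthropic.WithModel("claude-instant-1"),
	)
	if err != nil {
		log.Fatal(err) // ErrMissingToken if no token was provided
	}
	_ = llm
}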
105 changes: 105 additions & 0 deletions llms/anthropic/internal/anthropicclient/anthropicclient.go
@@ -0,0 +1,105 @@
package anthropicclient

import (
	"context"
	"errors"
	"net/http"
)

const (
	defaultBaseURL = "https://api.anthropic.com/v1"
)

// ErrEmptyResponse is returned when the Anthropic API returns an empty response.
var ErrEmptyResponse = errors.New("empty response")

// Client is a client for the Anthropic API.
type Client struct {
	token   string
	Model   string
	baseURL string

	httpClient Doer
}

// Option is an option for the Anthropic client.
type Option func(*Client) error

// Doer performs an HTTP request.
type Doer interface {
	Do(req *http.Request) (*http.Response, error)
}

// WithHTTPClient allows setting a custom HTTP client.
func WithHTTPClient(client Doer) Option {
	return func(c *Client) error {
		c.httpClient = client

		return nil
	}
}

// New returns a new Anthropic client.
func New(token string, model string, opts ...Option) (*Client, error) {
	c := &Client{
		Model:      model,
		token:      token,
		baseURL:    defaultBaseURL,
		httpClient: http.DefaultClient,
	}

	for _, opt := range opts {
		if err := opt(c); err != nil {
			return nil, err
		}
	}

	return c, nil
}

// CompletionRequest is a request to create a completion.
type CompletionRequest struct {
	Model       string   `json:"model"`
	Prompt      string   `json:"prompt"`
	Temperature float64  `json:"temperature,omitempty"`
	MaxTokens   int      `json:"max_tokens_to_sample,omitempty"`
	StopWords   []string `json:"stop_sequences,omitempty"`
	TopP        float64  `json:"top_p,omitempty"`
	Stream      bool     `json:"stream,omitempty"`

	// StreamingFunc is a function to be called for each chunk of a streaming response.
	// Return an error to stop streaming early.
	StreamingFunc func(ctx context.Context, chunk []byte) error `json:"-"`
}

// Completion is a completion.
type Completion struct {
	Text string `json:"text"`
}

// CreateCompletion creates a completion.
func (c *Client) CreateCompletion(ctx context.Context, r *CompletionRequest) (*Completion, error) {
	resp, err := c.createCompletion(ctx, &completionPayload{
		Model:         r.Model,
		Prompt:        r.Prompt,
		Temperature:   r.Temperature,
		MaxTokens:     r.MaxTokens,
		StopWords:     r.StopWords,
		TopP:          r.TopP,
		Stream:        r.Stream,
		StreamingFunc: r.StreamingFunc,
	})
	if err != nil {
		return nil, err
	}
	return &Completion{
		Text: resp.Completion,
	}, nil
}

func (c *Client) setHeaders(req *http.Request) {
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("x-api-key", c.token)
	// TODO: expose version as an option/parameter.
	req.Header.Set("anthropic-version", "2023-06-01")
}
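Because CompletionRequest carries a StreamingFunc, streaming should be reachable through the generic call options, provided the unexported createCompletion helper and completionPayload (referenced above but not shown in this section) forward the chunks. A hedged sketch, assuming llms.WithStreamingFunc is available as it is for the other backends that populate CallOptions.StreamingFunc:

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/anthropic"
)

func main() {
	llm, err := anthropic.New() // reads ANTHROPIC_API_KEY
	if err != nil {
		log.Fatal(err)
	}
	// Print each chunk as it arrives; returning an error would stop the stream early.
	_, err = llm.Call(context.Background(), "Write a haiku about Go.",
		llms.WithStreamingFunc(func(ctx context.Context, chunk []byte) error {
			fmt.Print(string(chunk))
			return nil
		}),
	)
	if err != nil {
		log.Fatal(err)
	}
}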