Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tools: add arxiv tool #1042

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions tools/arxiv/arxiv.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package arxiv

import (
"context"
"errors"

"github.com/tmc/langchaingo/callbacks"
"github.com/tmc/langchaingo/tools"
"github.com/tmc/langchaingo/tools/arxiv/internal"
)

// DefaultUserAgent is a ready-made value for the User-Agent header that
// callers can pass to New as its userAgent argument.
const DefaultUserAgent = "github.com/tmc/langchaingo/tools/arxiv"

// Tool is the arXiv Search tool. Its Call method forwards a query to an
// internal arXiv client and, when CallbacksHandler is set, notifies the
// handler about the start, end, and failure of the call.
type Tool struct {
// CallbacksHandler optionally receives notifications from Call; when it is
// nil no notifications are sent.
CallbacksHandler callbacks.Handler
// client performs the actual search against arXiv; New populates it.
client *internal.Client
}

// Compile-time assertion that Tool implements the tools.Tool interface.
var _ tools.Tool = Tool{}

// New builds an arXiv Search tool backed by a fresh internal client.
//
// maxResults is the number of results requested per search query and
// userAgent is the value used for the User-Agent header; both are handed to
// internal.NewClient unchanged. The returned error is currently always nil.
func New(maxResults int, userAgent string) (*Tool, error) {
	tool := &Tool{}
	tool.client = internal.NewClient(maxResults, userAgent)

	return tool, nil
}

// Name returns the fixed display name of the tool, "arXiv Search".
func (t Tool) Name() string {
	const toolName = "arXiv Search"

	return toolName
}

// Description returns the text that explains what the tool does and what
// input it expects. The raw string literal is returned verbatim, so the
// result starts with a newline and keeps the double quotes around each
// sentence.
func (t Tool) Description() string {
return `
"A wrapper around arXiv Search API."
"Search for scientific papers on arXiv."
"Input should be a search query."`
}

// Call runs an arXiv search for input and returns the formatted results.
//
// When CallbacksHandler is set, every call is bracketed: HandleToolStart is
// invoked with the input before the search, followed by exactly one of
// HandleToolEnd (with the text returned to the caller) or HandleToolError
// (with the error returned to the caller).
//
// If the client reports internal.ErrNoGoodResult, Call does not fail: it
// returns a fixed explanatory message and a nil error, and that message is
// what HandleToolEnd receives. Any other error is returned unchanged together
// with an empty string.
func (t Tool) Call(ctx context.Context, input string) (string, error) {
	const noResultsMessage = "No good arXiv Search Results were found"

	if t.CallbacksHandler != nil {
		t.CallbacksHandler.HandleToolStart(ctx, input)
	}

	result, err := t.client.Search(ctx, input)
	switch {
	case errors.Is(err, internal.ErrNoGoodResult):
		// An empty search is a valid outcome for the caller, so it is turned
		// into a regular result and still reported through HandleToolEnd
		// below; otherwise the start notification would never be paired.
		result = noResultsMessage
	case err != nil:
		if t.CallbacksHandler != nil {
			t.CallbacksHandler.HandleToolError(ctx, err)
		}
		return "", err
	}

	if t.CallbacksHandler != nil {
		t.CallbacksHandler.HandleToolEnd(ctx, result)
	}

	return result, nil
}
19 changes: 19 additions & 0 deletions tools/arxiv/arxiv_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package arxiv

import (
"context"
"testing"
)

// TestNew constructs a tool with the default user agent, runs one search
// through it, and logs the returned text. Any error from construction or
// from the search fails the test.
func TestNew(t *testing.T) {
	t.Parallel()

	arxivTool, err := New(10, DefaultUserAgent)
	if err != nil {
		t.Fatal(err)
	}

	output, err := arxivTool.Call(context.Background(), "electron")
	if err != nil {
		t.Fatal(err)
	}

	t.Log(output)
}
1 change: 1 addition & 0 deletions tools/arxiv/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
// Package arxiv provides a tool for searching scientific papers on arXiv.
package arxiv
145 changes: 145 additions & 0 deletions tools/arxiv/internal/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
package internal

import (
"context"
"encoding/xml"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
)

// Client is a small HTTP client for the arXiv query API.
type Client struct {
	// maxResults is the number of results requested per search query.
	maxResults int
	// userAgent is the User-Agent header value; when empty the header is
	// not set by this client.
	userAgent string
}

// Result holds the fields extracted from one entry of a search response.
type Result struct {
	// Title is the entry's title.
	Title string
	// Authors lists the entry's author names.
	Authors []string
	// Summary is the entry's summary text.
	Summary string
	// PdfURL is the link of type application/pdf, or empty if there is none.
	PdfURL string
	// PublishedAt is the publication timestamp as a string, unparsed.
	PublishedAt string
}

var (
	// ErrNoGoodResult indicates that a search yielded no usable results.
	ErrNoGoodResult = errors.New("no good search results found")
	// ErrAPIResponse indicates that the arXiv API answered with a
	// non-success HTTP status.
	ErrAPIResponse = errors.New("arXiv api responded with error")
)

// NewClient returns a Client that requests maxResults results per search
// query and sends userAgent as the User-Agent header.
//
// A maxResults of zero or less is replaced by 1, so the client never asks the
// API for an empty or negative number of results.
func NewClient(maxResults int, userAgent string) *Client {
	// Negative values are as meaningless as zero for a result count, so both
	// are normalised to the smallest useful value.
	if maxResults <= 0 {
		maxResults = 1
	}

	return &Client{
		maxResults: maxResults,
		userAgent:  userAgent,
	}
}

// newRequest builds a GET request for queryURL that is bound to ctx. When the
// client has a non-empty user agent it is added as the User-Agent header.
// A failure to construct the request is returned wrapped with context.
func (client *Client) newRequest(ctx context.Context, queryURL string) (*http.Request, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, queryURL, nil)
	if err != nil {
		return nil, fmt.Errorf("creating arXiv request: %w", err)
	}

	if ua := client.userAgent; ua != "" {
		req.Header.Add("User-Agent", ua)
	}

	return req, nil
}

// Search sends query to the arXiv query API and returns the matching entries
// rendered as text by formatResults.
//
// The query is URL-escaped and the request asks for results starting at
// offset 0, limited to the client's maxResults. The request is bound to ctx
// and sent with http.DefaultClient.
//
// Errors:
//   - ErrAPIResponse, wrapped with the HTTP status code, when the API does
//     not answer with 200 OK (match it with errors.Is);
//   - ErrNoGoodResult when the response contains no entries;
//   - a wrapped error when building or sending the request, reading the
//     body, or decoding the XML fails.
func (client *Client) Search(ctx context.Context, query string) (string, error) {
	queryURL := fmt.Sprintf("https://export.arxiv.org/api/query?search_query=%s&start=0&max_results=%d",
		url.QueryEscape(query), client.maxResults)

	request, err := client.newRequest(ctx, queryURL)
	if err != nil {
		return "", err
	}

	response, err := http.DefaultClient.Do(request)
	if err != nil {
		return "", fmt.Errorf("get %s error: %w", queryURL, err)
	}
	defer response.Body.Close()

	if response.StatusCode != http.StatusOK {
		// Keep the sentinel in the chain while exposing the actual status.
		return "", fmt.Errorf("%w: status %d", ErrAPIResponse, response.StatusCode)
	}

	body, err := io.ReadAll(response.Body)
	if err != nil {
		return "", fmt.Errorf("reading response body: %w", err)
	}

	results, err := parseAtomFeed(body)
	if err != nil {
		return "", err
	}

	// Report an empty feed explicitly instead of returning an empty string,
	// so callers can tell "nothing found" apart from a real result.
	if len(results) == 0 {
		return "", ErrNoGoodResult
	}

	return client.formatResults(results), nil
}

// parseAtomFeed decodes the XML feed in body and converts every entry
// element into a Result. The PDF URL of an entry is the href of its first
// link whose type is application/pdf, or empty when there is no such link.
func parseAtomFeed(body []byte) ([]Result, error) {
	var feed struct {
		Entries []struct {
			Title     string `xml:"title"`
			Summary   string `xml:"summary"`
			Published string `xml:"published"`
			Authors   []struct {
				Name string `xml:"name"`
			} `xml:"author"`
			Link []struct {
				Href string `xml:"href,attr"`
				Type string `xml:"type,attr"`
			} `xml:"link"`
		} `xml:"entry"`
	}

	if err := xml.Unmarshal(body, &feed); err != nil {
		return nil, fmt.Errorf("unmarshaling XML: %w", err)
	}

	results := make([]Result, 0, len(feed.Entries))
	for _, entry := range feed.Entries {
		authors := make([]string, 0, len(entry.Authors))
		for _, author := range entry.Authors {
			authors = append(authors, author.Name)
		}

		pdfURL := ""
		for _, link := range entry.Link {
			if link.Type == "application/pdf" {
				pdfURL = link.Href
				break
			}
		}

		results = append(results, Result{
			Title:       entry.Title,
			Authors:     authors,
			Summary:     entry.Summary,
			PdfURL:      pdfURL,
			PublishedAt: entry.Published,
		})
	}

	return results, nil
}

// formatResults renders results as plain text. Each result becomes five
// labelled lines (Title, Authors, Summary, PDF URL, Published) followed by a
// blank line; author names are joined with ", ". An empty slice yields an
// empty string.
func (client *Client) formatResults(results []Result) string {
	var out strings.Builder

	for i := range results {
		entry := &results[i]
		authorList := strings.Join(entry.Authors, ", ")

		fmt.Fprintf(&out, "Title: %s\n", entry.Title)
		fmt.Fprintf(&out, "Authors: %s\n", authorList)
		fmt.Fprintf(&out, "Summary: %s\n", entry.Summary)
		fmt.Fprintf(&out, "PDF URL: %s\n", entry.PdfURL)
		fmt.Fprintf(&out, "Published: %s\n\n", entry.PublishedAt)
	}

	return out.String()
}