Skip to content

Commit 156b7bd

Browse files
committed
Enhance: support prompting user to chosse pat/oauth
Signed-off-by: Daishan Peng <[email protected]>
1 parent ba8be10 commit 156b7bd

File tree

3 files changed

+90
-46
lines changed

3 files changed

+90
-46
lines changed

oauth2/go.mod

+8-6
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,23 @@ module gateway-oauth2
22

33
go 1.23.0
44

5+
replace github.com/gptscript-ai/go-gptscript => ../../go-gptscript
6+
57
require (
68
github.com/gptscript-ai/go-gptscript v0.9.6-0.20250204133419-744b25b84a61
79
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
810
)
911

1012
require (
11-
github.com/getkin/kin-openapi v0.124.0 // indirect
12-
github.com/go-openapi/jsonpointer v0.20.2 // indirect
13-
github.com/go-openapi/swag v0.22.8 // indirect
14-
github.com/invopop/yaml v0.2.0 // indirect
13+
github.com/getkin/kin-openapi v0.129.0 // indirect
14+
github.com/go-openapi/jsonpointer v0.21.0 // indirect
15+
github.com/go-openapi/swag v0.23.0 // indirect
1516
github.com/josharian/intern v1.0.0 // indirect
16-
github.com/mailru/easyjson v0.7.7 // indirect
17+
github.com/mailru/easyjson v0.9.0 // indirect
1718
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect
19+
github.com/oasdiff/yaml v0.0.0-20241210131133-6b86fb107d80 // indirect
20+
github.com/oasdiff/yaml3 v0.0.0-20241210130736-a94c01f36349 // indirect
1821
github.com/perimeterx/marshmallow v1.1.5 // indirect
19-
github.com/stretchr/testify v1.9.0 // indirect
2022
golang.org/x/sys v0.20.0 // indirect
2123
gopkg.in/yaml.v3 v3.0.1 // indirect
2224
)

oauth2/go.sum

+14-15
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,27 @@
11
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
22
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
3-
github.com/getkin/kin-openapi v0.124.0 h1:VSFNMB9C9rTKBnQ/fpyDU8ytMTr4dWI9QovSKj9kz/M=
4-
github.com/getkin/kin-openapi v0.124.0/go.mod h1:wb1aSZA/iWmorQP9KTAS/phLj/t17B5jT7+fS8ed9NM=
5-
github.com/go-openapi/jsonpointer v0.20.2 h1:mQc3nmndL8ZBzStEo3JYF8wzmeWffDH4VbXz58sAx6Q=
6-
github.com/go-openapi/jsonpointer v0.20.2/go.mod h1:bHen+N0u1KEO3YlmqOjTT9Adn1RfD91Ar825/PuiRVs=
7-
github.com/go-openapi/swag v0.22.8 h1:/9RjDSQ0vbFR+NyjGMkFTsA1IA0fmhKSThmfGZjicbw=
8-
github.com/go-openapi/swag v0.22.8/go.mod h1:6QT22icPLEqAM/z/TChgb4WAveCHF92+2gF0CNjHpPI=
3+
github.com/getkin/kin-openapi v0.129.0 h1:QGYTNcmyP5X0AtFQ2Dkou9DGBJsUETeLH9rFrJXZh30=
4+
github.com/getkin/kin-openapi v0.129.0/go.mod h1:gmWI+b/J45xqpyK5wJmRRZse5wefA5H0RDMK46kLUtI=
5+
github.com/go-openapi/jsonpointer v0.21.0 h1:YgdVicSA9vH5RiHs9TZW5oyafXZFc6+2Vc1rr/O9oNQ=
6+
github.com/go-openapi/jsonpointer v0.21.0/go.mod h1:IUyH9l/+uyhIYQ/PXVA41Rexl+kOkAPDdXEYns6fzUY=
7+
github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE=
8+
github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ=
99
github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM=
1010
github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
11-
github.com/gptscript-ai/go-gptscript v0.9.6-0.20250204133419-744b25b84a61 h1:QxLjsLOYlsVLPwuRkP0Q8EcAoZT1s8vU2ZBSX0+R6CI=
12-
github.com/gptscript-ai/go-gptscript v0.9.6-0.20250204133419-744b25b84a61/go.mod h1:/FVuLwhz+sIfsWUgUHWKi32qT0i6+IXlUlzs70KKt/Q=
13-
github.com/invopop/yaml v0.2.0 h1:7zky/qH+O0DwAyoobXUqvVBwgBFRxKoQ/3FjcVpjTMY=
14-
github.com/invopop/yaml v0.2.0/go.mod h1:2XuRLgs/ouIrW3XNzuNj7J3Nvu/Dig5MXvbCEdiBN3Q=
1511
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
1612
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
1713
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
1814
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
1915
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
2016
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
21-
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
22-
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
17+
github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4=
18+
github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
2319
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9M+97sNutRR1RKhG96O6jWumTTnw=
2420
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8=
21+
github.com/oasdiff/yaml v0.0.0-20241210131133-6b86fb107d80 h1:nZspmSkneBbtxU9TopEAE0CY+SBJLxO8LPUlw2vG4pU=
22+
github.com/oasdiff/yaml v0.0.0-20241210131133-6b86fb107d80/go.mod h1:7tFDb+Y51LcDpn26GccuUgQXUk6t0CXZsivKjyimYX8=
23+
github.com/oasdiff/yaml3 v0.0.0-20241210130736-a94c01f36349 h1:t05Ww3DxZutOqbMN+7OIuqDwXbhl32HiZGpLy26BAPc=
24+
github.com/oasdiff/yaml3 v0.0.0-20241210130736-a94c01f36349/go.mod h1:y5+oSEHCPT/DGrS++Wc/479ERge0zTFxaF8PbGKcg2o=
2525
github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX50IvK2s=
2626
github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw=
2727
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
@@ -30,8 +30,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
3030
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
3131
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
3232
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
33-
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
34-
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
33+
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
34+
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
3535
github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0=
3636
github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY=
3737
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -40,6 +40,5 @@ golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
4040
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
4141
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
4242
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
43-
gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
4443
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
4544
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

oauth2/main.go

+68-25
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,13 @@ import (
1919
"github.com/pkg/browser"
2020
)
2121

22+
type authType string
23+
24+
const (
25+
authTypePAT authType = "Personal Access Token (PAT)"
26+
authTypeOAuth authType = "OAuth"
27+
)
28+
2229
type input struct {
2330
OAuthInfo oauthInfo `json:"oauthInfo"`
2431
PromptInfo *promptInfo `json:"promptInfo,omitempty"`
@@ -68,18 +75,10 @@ type cred struct {
6875
RefreshToken string `json:"refreshToken"`
6976
}
7077

71-
func normalizeForEnv(appName string) string {
72-
return strings.ToUpper(strings.ReplaceAll(appName, "-", "_"))
73-
}
74-
75-
func getURLs(appName string) (string, string, string) {
76-
var (
77-
normalizedAppName = normalizeForEnv(appName)
78-
authorizeURL = os.Getenv(fmt.Sprintf("GPTSCRIPT_OAUTH_%s_AUTH_URL", normalizedAppName))
79-
refreshURL = os.Getenv(fmt.Sprintf("GPTSCRIPT_OAUTH_%s_REFRESH_URL", normalizedAppName))
80-
tokenURL = os.Getenv(fmt.Sprintf("GPTSCRIPT_OAUTH_%s_TOKEN_URL", normalizedAppName))
81-
)
82-
return authorizeURL, refreshURL, tokenURL
78+
func getURLs(serverURL, appName string) (string, string, string) {
79+
return fmt.Sprintf("%s/api/app-oauth/authorize/%s", serverURL, appName),
80+
fmt.Sprintf("%s/api/app-oauth/refresh/%s", serverURL, appName),
81+
fmt.Sprintf("%s/api/app-oauth/get-token", serverURL)
8382
}
8483

8584
func main() {
@@ -96,6 +95,11 @@ func mainErr() (err error) {
9695
return fmt.Errorf("main: TOOL_CALL_BODY environment variable not set")
9796
}
9897

98+
serverURL := os.Getenv("OAUTH_SERVER_URL")
99+
if serverURL == "" {
100+
return fmt.Errorf("main: OAUTH_SERVER_URL environment variable not set")
101+
}
102+
99103
var in input
100104
if err = json.Unmarshal([]byte(inputStr), &in); err != nil {
101105
return fmt.Errorf("main: error parsing input JSON: %w", err)
@@ -118,22 +122,27 @@ func mainErr() (err error) {
118122
}
119123
}()
120124

121-
authorizeURL, refreshURL, tokenURL := getURLs(in.OAuthInfo.Integration)
122-
if authorizeURL == "" || refreshURL == "" || tokenURL == "" {
123-
// The URLs aren't set for this credential. Check to see if we should prompt the user for other tokens
124-
if in.PromptInfo == nil {
125-
fmt.Printf("All the following environment variables must be set: GPTSCRIPT_OAUTH_%s_AUTH_URL, GPTSCRIPT_OAUTH_%[1]s_REFRESH_URL, GPTSCRIPT_OAUTH_%[1]s_TOKEN_URL", normalizeForEnv(in.OAuthInfo.Integration))
126-
fmt.Printf("Or the promptInfo configuration must be provided for token prompting.")
127-
os.Exit(1)
125+
if in.PromptInfo != nil && os.Getenv("GPTSCRIPT_EXISTING_CREDENTIAL") == "" {
126+
authType, err := promptForSelect(ctx, gs)
127+
if err != nil {
128+
return fmt.Errorf("main: failed to prompt for auth type: %w", err)
128129
}
129130

130-
credJSON, err = promptForTokens(ctx, gs, in.OAuthInfo.Integration, in.PromptInfo)
131-
if err != nil {
132-
return fmt.Errorf("main: failed to prompt for tokens: %w", err)
131+
if authType == authTypePAT {
132+
credJSON, err = promptForTokens(ctx, gs, in.OAuthInfo.Integration, in.PromptInfo)
133+
if err != nil {
134+
return fmt.Errorf("main: failed to prompt for tokens: %w", err)
135+
}
136+
return nil
133137
}
134-
return nil
135138
}
136139

140+
return promptForOauth(ctx, gs, serverURL, &in, &credJSON)
141+
}
142+
143+
func promptForOauth(ctx context.Context, gs *gptscript.GPTScript, serverURL string, in *input, credJSON *[]byte) error {
144+
authorizeURL, refreshURL, tokenURL := getURLs(serverURL, in.OAuthInfo.Integration)
145+
137146
// Refresh existing credential if there is one.
138147
existing := os.Getenv("GPTSCRIPT_EXISTING_CREDENTIAL")
139148
if existing != "" {
@@ -195,7 +204,7 @@ func mainErr() (err error) {
195204
out.ExpiresAt = &expiresAt
196205
}
197206

198-
credJSON, err = json.Marshal(out)
207+
*credJSON, err = json.Marshal(out)
199208
if err != nil {
200209
return fmt.Errorf("main: failed to marshal refreshed credential: %w", err)
201210
}
@@ -301,7 +310,7 @@ func mainErr() (err error) {
301310
out.ExpiresAt = &expiresAt
302311
}
303312

304-
credJSON, err = json.Marshal(out)
313+
*credJSON, err = json.Marshal(out)
305314
if err != nil {
306315
return fmt.Errorf("main: failed to marshal token credential: %w", err)
307316
}
@@ -357,6 +366,40 @@ func generateString() (string, error) {
357366
return string(b), nil
358367
}
359368

369+
func promptForSelect(ctx context.Context, g *gptscript.GPTScript) (authType, error) {
370+
fieldName := "Authentication Method"
371+
372+
fields := gptscript.Fields{gptscript.Field{Name: fieldName, Description: "The authentication method to use for this tool.", Options: []string{string(authTypePAT), string(authTypeOAuth)}}}
373+
sysPromptIn, err := json.Marshal(sysPromptInput{
374+
Message: "This tool has personal access token (PAT) and OAuth support. Select the authentication method you would like to use for this tool.",
375+
Fields: fields,
376+
})
377+
if err != nil {
378+
return "", fmt.Errorf("promptForSelect: error marshalling sys prompt input: %w", err)
379+
}
380+
381+
run, err := g.Run(ctx, "sys.prompt", gptscript.Options{
382+
Input: string(sysPromptIn),
383+
})
384+
if err != nil {
385+
return "", fmt.Errorf("promptForSelect: failed to run sys.prompt: %w", err)
386+
}
387+
388+
out, err := run.Text()
389+
if err != nil {
390+
return "", fmt.Errorf("promptForSelect: failed to get prompt response: %w", err)
391+
}
392+
393+
m := make(map[string]string)
394+
if err = json.Unmarshal([]byte(out), &m); err != nil {
395+
return "", fmt.Errorf("promptForSelect: failed to unmarshal prompt response: %w", err)
396+
}
397+
398+
authType := authType(m[fieldName])
399+
400+
return authType, nil
401+
}
402+
360403
func promptForTokens(ctx context.Context, g *gptscript.GPTScript, integration string, prompt *promptInfo) ([]byte, error) {
361404
if prompt.Metadata == nil {
362405
prompt.Metadata = make(map[string]string)

0 commit comments

Comments
 (0)