Skip to content

Commit 717ce77

Browse files
committed
Enhance: support prompting user to chosse pat/oauth
Signed-off-by: Daishan Peng <[email protected]>
1 parent ba8be10 commit 717ce77

File tree

3 files changed

+103
-38
lines changed

3 files changed

+103
-38
lines changed

oauth2/go.mod

+8-6
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,23 @@ module gateway-oauth2
22

33
go 1.23.0
44

5+
replace github.com/gptscript-ai/go-gptscript => ../../go-gptscript
6+
57
require (
68
github.com/gptscript-ai/go-gptscript v0.9.6-0.20250204133419-744b25b84a61
79
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
810
)
911

1012
require (
11-
github.com/getkin/kin-openapi v0.124.0 // indirect
12-
github.com/go-openapi/jsonpointer v0.20.2 // indirect
13-
github.com/go-openapi/swag v0.22.8 // indirect
14-
github.com/invopop/yaml v0.2.0 // indirect
13+
github.com/getkin/kin-openapi v0.129.0 // indirect
14+
github.com/go-openapi/jsonpointer v0.21.0 // indirect
15+
github.com/go-openapi/swag v0.23.0 // indirect
1516
github.com/josharian/intern v1.0.0 // indirect
16-
github.com/mailru/easyjson v0.7.7 // indirect
17+
github.com/mailru/easyjson v0.9.0 // indirect
1718
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect
19+
github.com/oasdiff/yaml v0.0.0-20241210131133-6b86fb107d80 // indirect
20+
github.com/oasdiff/yaml3 v0.0.0-20241210130736-a94c01f36349 // indirect
1821
github.com/perimeterx/marshmallow v1.1.5 // indirect
19-
github.com/stretchr/testify v1.9.0 // indirect
2022
golang.org/x/sys v0.20.0 // indirect
2123
gopkg.in/yaml.v3 v3.0.1 // indirect
2224
)

oauth2/go.sum

+14-15
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,27 @@
11
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
22
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
3-
github.com/getkin/kin-openapi v0.124.0 h1:VSFNMB9C9rTKBnQ/fpyDU8ytMTr4dWI9QovSKj9kz/M=
4-
github.com/getkin/kin-openapi v0.124.0/go.mod h1:wb1aSZA/iWmorQP9KTAS/phLj/t17B5jT7+fS8ed9NM=
5-
github.com/go-openapi/jsonpointer v0.20.2 h1:mQc3nmndL8ZBzStEo3JYF8wzmeWffDH4VbXz58sAx6Q=
6-
github.com/go-openapi/jsonpointer v0.20.2/go.mod h1:bHen+N0u1KEO3YlmqOjTT9Adn1RfD91Ar825/PuiRVs=
7-
github.com/go-openapi/swag v0.22.8 h1:/9RjDSQ0vbFR+NyjGMkFTsA1IA0fmhKSThmfGZjicbw=
8-
github.com/go-openapi/swag v0.22.8/go.mod h1:6QT22icPLEqAM/z/TChgb4WAveCHF92+2gF0CNjHpPI=
3+
github.com/getkin/kin-openapi v0.129.0 h1:QGYTNcmyP5X0AtFQ2Dkou9DGBJsUETeLH9rFrJXZh30=
4+
github.com/getkin/kin-openapi v0.129.0/go.mod h1:gmWI+b/J45xqpyK5wJmRRZse5wefA5H0RDMK46kLUtI=
5+
github.com/go-openapi/jsonpointer v0.21.0 h1:YgdVicSA9vH5RiHs9TZW5oyafXZFc6+2Vc1rr/O9oNQ=
6+
github.com/go-openapi/jsonpointer v0.21.0/go.mod h1:IUyH9l/+uyhIYQ/PXVA41Rexl+kOkAPDdXEYns6fzUY=
7+
github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE=
8+
github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ=
99
github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM=
1010
github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
11-
github.com/gptscript-ai/go-gptscript v0.9.6-0.20250204133419-744b25b84a61 h1:QxLjsLOYlsVLPwuRkP0Q8EcAoZT1s8vU2ZBSX0+R6CI=
12-
github.com/gptscript-ai/go-gptscript v0.9.6-0.20250204133419-744b25b84a61/go.mod h1:/FVuLwhz+sIfsWUgUHWKi32qT0i6+IXlUlzs70KKt/Q=
13-
github.com/invopop/yaml v0.2.0 h1:7zky/qH+O0DwAyoobXUqvVBwgBFRxKoQ/3FjcVpjTMY=
14-
github.com/invopop/yaml v0.2.0/go.mod h1:2XuRLgs/ouIrW3XNzuNj7J3Nvu/Dig5MXvbCEdiBN3Q=
1511
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
1612
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
1713
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
1814
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
1915
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
2016
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
21-
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
22-
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
17+
github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4=
18+
github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
2319
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9M+97sNutRR1RKhG96O6jWumTTnw=
2420
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8=
21+
github.com/oasdiff/yaml v0.0.0-20241210131133-6b86fb107d80 h1:nZspmSkneBbtxU9TopEAE0CY+SBJLxO8LPUlw2vG4pU=
22+
github.com/oasdiff/yaml v0.0.0-20241210131133-6b86fb107d80/go.mod h1:7tFDb+Y51LcDpn26GccuUgQXUk6t0CXZsivKjyimYX8=
23+
github.com/oasdiff/yaml3 v0.0.0-20241210130736-a94c01f36349 h1:t05Ww3DxZutOqbMN+7OIuqDwXbhl32HiZGpLy26BAPc=
24+
github.com/oasdiff/yaml3 v0.0.0-20241210130736-a94c01f36349/go.mod h1:y5+oSEHCPT/DGrS++Wc/479ERge0zTFxaF8PbGKcg2o=
2525
github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX50IvK2s=
2626
github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw=
2727
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
@@ -30,8 +30,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
3030
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
3131
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
3232
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
33-
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
34-
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
33+
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
34+
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
3535
github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0=
3636
github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY=
3737
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -40,6 +40,5 @@ golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
4040
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
4141
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
4242
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
43-
gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
4443
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
4544
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

oauth2/main.go

+81-17
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,13 @@ import (
1919
"github.com/pkg/browser"
2020
)
2121

22+
type authType string
23+
24+
const (
25+
authTypePAT authType = "Personal Access Token (PAT)"
26+
authTypeOAuth authType = "OAuth"
27+
)
28+
2229
type input struct {
2330
OAuthInfo oauthInfo `json:"oauthInfo"`
2431
PromptInfo *promptInfo `json:"promptInfo,omitempty"`
@@ -68,18 +75,28 @@ type cred struct {
6875
RefreshToken string `json:"refreshToken"`
6976
}
7077

78+
type urls struct {
79+
authorizeURL string
80+
refreshURL string
81+
tokenURL string
82+
}
83+
7184
func normalizeForEnv(appName string) string {
7285
return strings.ToUpper(strings.ReplaceAll(appName, "-", "_"))
7386
}
7487

75-
func getURLs(appName string) (string, string, string) {
88+
func getURLs(appName string) urls {
7689
var (
7790
normalizedAppName = normalizeForEnv(appName)
7891
authorizeURL = os.Getenv(fmt.Sprintf("GPTSCRIPT_OAUTH_%s_AUTH_URL", normalizedAppName))
7992
refreshURL = os.Getenv(fmt.Sprintf("GPTSCRIPT_OAUTH_%s_REFRESH_URL", normalizedAppName))
8093
tokenURL = os.Getenv(fmt.Sprintf("GPTSCRIPT_OAUTH_%s_TOKEN_URL", normalizedAppName))
8194
)
82-
return authorizeURL, refreshURL, tokenURL
95+
return urls{
96+
authorizeURL: authorizeURL,
97+
refreshURL: refreshURL,
98+
tokenURL: tokenURL,
99+
}
83100
}
84101

85102
func main() {
@@ -118,22 +135,35 @@ func mainErr() (err error) {
118135
}
119136
}()
120137

121-
authorizeURL, refreshURL, tokenURL := getURLs(in.OAuthInfo.Integration)
122-
if authorizeURL == "" || refreshURL == "" || tokenURL == "" {
123-
// The URLs aren't set for this credential. Check to see if we should prompt the user for other tokens
124-
if in.PromptInfo == nil {
125-
fmt.Printf("All the following environment variables must be set: GPTSCRIPT_OAUTH_%s_AUTH_URL, GPTSCRIPT_OAUTH_%[1]s_REFRESH_URL, GPTSCRIPT_OAUTH_%[1]s_TOKEN_URL", normalizeForEnv(in.OAuthInfo.Integration))
126-
fmt.Printf("Or the promptInfo configuration must be provided for token prompting.")
127-
os.Exit(1)
138+
urls := getURLs(in.OAuthInfo.Integration)
139+
140+
if in.PromptInfo != nil && os.Getenv("GPTSCRIPT_EXISTING_CREDENTIAL") == "" {
141+
if urls.authorizeURL == "" && urls.refreshURL == "" && urls.tokenURL == "" {
142+
credJSON, err = promptForTokens(ctx, gs, in.OAuthInfo.Integration, in.PromptInfo)
143+
if err != nil {
144+
return fmt.Errorf("main: failed to prompt for tokens: %w", err)
145+
}
146+
return nil
128147
}
129148

130-
credJSON, err = promptForTokens(ctx, gs, in.OAuthInfo.Integration, in.PromptInfo)
149+
authType, err := promptForSelect(ctx, gs)
131150
if err != nil {
132-
return fmt.Errorf("main: failed to prompt for tokens: %w", err)
151+
return fmt.Errorf("main: failed to prompt for auth type: %w", err)
152+
}
153+
154+
if authType == authTypePAT {
155+
credJSON, err = promptForTokens(ctx, gs, in.OAuthInfo.Integration, in.PromptInfo)
156+
if err != nil {
157+
return fmt.Errorf("main: failed to prompt for tokens: %w", err)
158+
}
159+
return nil
133160
}
134-
return nil
135161
}
136162

163+
return promptForOauth(gs, &urls, &in, &credJSON)
164+
}
165+
166+
func promptForOauth(gs *gptscript.GPTScript, urls *urls, in *input, credJSON *[]byte) error {
137167
// Refresh existing credential if there is one.
138168
existing := os.Getenv("GPTSCRIPT_EXISTING_CREDENTIAL")
139169
if existing != "" {
@@ -142,7 +172,7 @@ func mainErr() (err error) {
142172
return fmt.Errorf("main: failed to unmarshal existing credential: %w", err)
143173
}
144174

145-
u, err := url.Parse(refreshURL)
175+
u, err := url.Parse(urls.refreshURL)
146176
if err != nil {
147177
return fmt.Errorf("main: failed to parse refresh URL: %w", err)
148178
}
@@ -195,7 +225,7 @@ func mainErr() (err error) {
195225
out.ExpiresAt = &expiresAt
196226
}
197227

198-
credJSON, err = json.Marshal(out)
228+
*credJSON, err = json.Marshal(out)
199229
if err != nil {
200230
return fmt.Errorf("main: failed to marshal refreshed credential: %w", err)
201231
}
@@ -217,7 +247,7 @@ func mainErr() (err error) {
217247
h.Write([]byte(verifier))
218248
challenge := hex.EncodeToString(h.Sum(nil))
219249

220-
u, err := url.Parse(authorizeURL)
250+
u, err := url.Parse(urls.authorizeURL)
221251
if err != nil {
222252
return fmt.Errorf("main: failed to parse authorize URL: %w", err)
223253
}
@@ -274,7 +304,7 @@ func mainErr() (err error) {
274304
t := time.NewTicker(2 * time.Second)
275305
for range t.C {
276306
now := time.Now()
277-
oauthResp, retry, err := makeTokenRequest(tokenURL, state, verifier)
307+
oauthResp, retry, err := makeTokenRequest(urls.tokenURL, state, verifier)
278308
if err != nil {
279309
if !retry {
280310
return err
@@ -301,7 +331,7 @@ func mainErr() (err error) {
301331
out.ExpiresAt = &expiresAt
302332
}
303333

304-
credJSON, err = json.Marshal(out)
334+
*credJSON, err = json.Marshal(out)
305335
if err != nil {
306336
return fmt.Errorf("main: failed to marshal token credential: %w", err)
307337
}
@@ -357,6 +387,40 @@ func generateString() (string, error) {
357387
return string(b), nil
358388
}
359389

390+
func promptForSelect(ctx context.Context, g *gptscript.GPTScript) (authType, error) {
391+
fieldName := "Authentication Method"
392+
393+
fields := gptscript.Fields{gptscript.Field{Name: fieldName, Description: "The authentication method to use for this tool.", Options: []string{string(authTypePAT), string(authTypeOAuth)}}}
394+
sysPromptIn, err := json.Marshal(sysPromptInput{
395+
Message: "This tool has personal access token (PAT) and OAuth support. Select the authentication method you would like to use for this tool.",
396+
Fields: fields,
397+
})
398+
if err != nil {
399+
return "", fmt.Errorf("promptForSelect: error marshalling sys prompt input: %w", err)
400+
}
401+
402+
run, err := g.Run(ctx, "sys.prompt", gptscript.Options{
403+
Input: string(sysPromptIn),
404+
})
405+
if err != nil {
406+
return "", fmt.Errorf("promptForSelect: failed to run sys.prompt: %w", err)
407+
}
408+
409+
out, err := run.Text()
410+
if err != nil {
411+
return "", fmt.Errorf("promptForSelect: failed to get prompt response: %w", err)
412+
}
413+
414+
m := make(map[string]string)
415+
if err = json.Unmarshal([]byte(out), &m); err != nil {
416+
return "", fmt.Errorf("promptForSelect: failed to unmarshal prompt response: %w", err)
417+
}
418+
419+
authType := authType(m[fieldName])
420+
421+
return authType, nil
422+
}
423+
360424
func promptForTokens(ctx context.Context, g *gptscript.GPTScript, integration string, prompt *promptInfo) ([]byte, error) {
361425
if prompt.Metadata == nil {
362426
prompt.Metadata = make(map[string]string)

0 commit comments

Comments
 (0)