@@ -19,6 +19,13 @@ import (
19
19
"github.com/pkg/browser"
20
20
)
21
21
22
+ type authType string
23
+
24
+ const (
25
+ authTypePAT authType = "Personal Access Token (PAT)"
26
+ authTypeOAuth authType = "OAuth"
27
+ )
28
+
22
29
type input struct {
23
30
OAuthInfo oauthInfo `json:"oauthInfo"`
24
31
PromptInfo * promptInfo `json:"promptInfo,omitempty"`
@@ -68,18 +75,28 @@ type cred struct {
68
75
RefreshToken string `json:"refreshToken"`
69
76
}
70
77
78
+ type urls struct {
79
+ authorizeURL string
80
+ refreshURL string
81
+ tokenURL string
82
+ }
83
+
71
84
func normalizeForEnv (appName string ) string {
72
85
return strings .ToUpper (strings .ReplaceAll (appName , "-" , "_" ))
73
86
}
74
87
75
- func getURLs (appName string ) ( string , string , string ) {
88
+ func getURLs (appName string ) urls {
76
89
var (
77
90
normalizedAppName = normalizeForEnv (appName )
78
91
authorizeURL = os .Getenv (fmt .Sprintf ("GPTSCRIPT_OAUTH_%s_AUTH_URL" , normalizedAppName ))
79
92
refreshURL = os .Getenv (fmt .Sprintf ("GPTSCRIPT_OAUTH_%s_REFRESH_URL" , normalizedAppName ))
80
93
tokenURL = os .Getenv (fmt .Sprintf ("GPTSCRIPT_OAUTH_%s_TOKEN_URL" , normalizedAppName ))
81
94
)
82
- return authorizeURL , refreshURL , tokenURL
95
+ return urls {
96
+ authorizeURL : authorizeURL ,
97
+ refreshURL : refreshURL ,
98
+ tokenURL : tokenURL ,
99
+ }
83
100
}
84
101
85
102
func main () {
@@ -118,22 +135,35 @@ func mainErr() (err error) {
118
135
}
119
136
}()
120
137
121
- authorizeURL , refreshURL , tokenURL := getURLs (in .OAuthInfo .Integration )
122
- if authorizeURL == "" || refreshURL == "" || tokenURL == "" {
123
- // The URLs aren't set for this credential. Check to see if we should prompt the user for other tokens
124
- if in .PromptInfo == nil {
125
- fmt .Printf ("All the following environment variables must be set: GPTSCRIPT_OAUTH_%s_AUTH_URL, GPTSCRIPT_OAUTH_%[1]s_REFRESH_URL, GPTSCRIPT_OAUTH_%[1]s_TOKEN_URL" , normalizeForEnv (in .OAuthInfo .Integration ))
126
- fmt .Printf ("Or the promptInfo configuration must be provided for token prompting." )
127
- os .Exit (1 )
138
+ urls := getURLs (in .OAuthInfo .Integration )
139
+
140
+ if in .PromptInfo != nil && os .Getenv ("GPTSCRIPT_EXISTING_CREDENTIAL" ) == "" {
141
+ if urls .authorizeURL == "" && urls .refreshURL == "" && urls .tokenURL == "" {
142
+ credJSON , err = promptForTokens (ctx , gs , in .OAuthInfo .Integration , in .PromptInfo )
143
+ if err != nil {
144
+ return fmt .Errorf ("main: failed to prompt for tokens: %w" , err )
145
+ }
146
+ return nil
128
147
}
129
148
130
- credJSON , err = promptForTokens (ctx , gs , in . OAuthInfo . Integration , in . PromptInfo )
149
+ authType , err := promptForSelect (ctx , gs )
131
150
if err != nil {
132
- return fmt .Errorf ("main: failed to prompt for tokens: %w" , err )
151
+ return fmt .Errorf ("main: failed to prompt for auth type: %w" , err )
152
+ }
153
+
154
+ if authType == authTypePAT {
155
+ credJSON , err = promptForTokens (ctx , gs , in .OAuthInfo .Integration , in .PromptInfo )
156
+ if err != nil {
157
+ return fmt .Errorf ("main: failed to prompt for tokens: %w" , err )
158
+ }
159
+ return nil
133
160
}
134
- return nil
135
161
}
136
162
163
+ return promptForOauth (gs , & urls , & in , & credJSON )
164
+ }
165
+
166
+ func promptForOauth (gs * gptscript.GPTScript , urls * urls , in * input , credJSON * []byte ) error {
137
167
// Refresh existing credential if there is one.
138
168
existing := os .Getenv ("GPTSCRIPT_EXISTING_CREDENTIAL" )
139
169
if existing != "" {
@@ -142,7 +172,7 @@ func mainErr() (err error) {
142
172
return fmt .Errorf ("main: failed to unmarshal existing credential: %w" , err )
143
173
}
144
174
145
- u , err := url .Parse (refreshURL )
175
+ u , err := url .Parse (urls . refreshURL )
146
176
if err != nil {
147
177
return fmt .Errorf ("main: failed to parse refresh URL: %w" , err )
148
178
}
@@ -195,7 +225,7 @@ func mainErr() (err error) {
195
225
out .ExpiresAt = & expiresAt
196
226
}
197
227
198
- credJSON , err = json .Marshal (out )
228
+ * credJSON , err = json .Marshal (out )
199
229
if err != nil {
200
230
return fmt .Errorf ("main: failed to marshal refreshed credential: %w" , err )
201
231
}
@@ -217,7 +247,7 @@ func mainErr() (err error) {
217
247
h .Write ([]byte (verifier ))
218
248
challenge := hex .EncodeToString (h .Sum (nil ))
219
249
220
- u , err := url .Parse (authorizeURL )
250
+ u , err := url .Parse (urls . authorizeURL )
221
251
if err != nil {
222
252
return fmt .Errorf ("main: failed to parse authorize URL: %w" , err )
223
253
}
@@ -274,7 +304,7 @@ func mainErr() (err error) {
274
304
t := time .NewTicker (2 * time .Second )
275
305
for range t .C {
276
306
now := time .Now ()
277
- oauthResp , retry , err := makeTokenRequest (tokenURL , state , verifier )
307
+ oauthResp , retry , err := makeTokenRequest (urls . tokenURL , state , verifier )
278
308
if err != nil {
279
309
if ! retry {
280
310
return err
@@ -301,7 +331,7 @@ func mainErr() (err error) {
301
331
out .ExpiresAt = & expiresAt
302
332
}
303
333
304
- credJSON , err = json .Marshal (out )
334
+ * credJSON , err = json .Marshal (out )
305
335
if err != nil {
306
336
return fmt .Errorf ("main: failed to marshal token credential: %w" , err )
307
337
}
@@ -357,6 +387,40 @@ func generateString() (string, error) {
357
387
return string (b ), nil
358
388
}
359
389
390
+ func promptForSelect (ctx context.Context , g * gptscript.GPTScript ) (authType , error ) {
391
+ fieldName := "Authentication Method"
392
+
393
+ fields := gptscript.Fields {gptscript.Field {Name : fieldName , Description : "The authentication method to use for this tool." , Options : []string {string (authTypePAT ), string (authTypeOAuth )}}}
394
+ sysPromptIn , err := json .Marshal (sysPromptInput {
395
+ Message : "This tool has personal access token (PAT) and OAuth support. Select the authentication method you would like to use for this tool." ,
396
+ Fields : fields ,
397
+ })
398
+ if err != nil {
399
+ return "" , fmt .Errorf ("promptForSelect: error marshalling sys prompt input: %w" , err )
400
+ }
401
+
402
+ run , err := g .Run (ctx , "sys.prompt" , gptscript.Options {
403
+ Input : string (sysPromptIn ),
404
+ })
405
+ if err != nil {
406
+ return "" , fmt .Errorf ("promptForSelect: failed to run sys.prompt: %w" , err )
407
+ }
408
+
409
+ out , err := run .Text ()
410
+ if err != nil {
411
+ return "" , fmt .Errorf ("promptForSelect: failed to get prompt response: %w" , err )
412
+ }
413
+
414
+ m := make (map [string ]string )
415
+ if err = json .Unmarshal ([]byte (out ), & m ); err != nil {
416
+ return "" , fmt .Errorf ("promptForSelect: failed to unmarshal prompt response: %w" , err )
417
+ }
418
+
419
+ authType := authType (m [fieldName ])
420
+
421
+ return authType , nil
422
+ }
423
+
360
424
func promptForTokens (ctx context.Context , g * gptscript.GPTScript , integration string , prompt * promptInfo ) ([]byte , error ) {
361
425
if prompt .Metadata == nil {
362
426
prompt .Metadata = make (map [string ]string )
0 commit comments