@@ -19,6 +19,13 @@ import (
19
19
"github.com/pkg/browser"
20
20
)
21
21
22
+ type authType string
23
+
24
+ const (
25
+ authTypePAT authType = "Personal Access Token (PAT)"
26
+ authTypeOAuth authType = "OAuth"
27
+ )
28
+
22
29
type input struct {
23
30
OAuthInfo oauthInfo `json:"oauthInfo"`
24
31
PromptInfo * promptInfo `json:"promptInfo,omitempty"`
@@ -68,18 +75,10 @@ type cred struct {
68
75
RefreshToken string `json:"refreshToken"`
69
76
}
70
77
71
- func normalizeForEnv (appName string ) string {
72
- return strings .ToUpper (strings .ReplaceAll (appName , "-" , "_" ))
73
- }
74
-
75
- func getURLs (appName string ) (string , string , string ) {
76
- var (
77
- normalizedAppName = normalizeForEnv (appName )
78
- authorizeURL = os .Getenv (fmt .Sprintf ("GPTSCRIPT_OAUTH_%s_AUTH_URL" , normalizedAppName ))
79
- refreshURL = os .Getenv (fmt .Sprintf ("GPTSCRIPT_OAUTH_%s_REFRESH_URL" , normalizedAppName ))
80
- tokenURL = os .Getenv (fmt .Sprintf ("GPTSCRIPT_OAUTH_%s_TOKEN_URL" , normalizedAppName ))
81
- )
82
- return authorizeURL , refreshURL , tokenURL
78
+ func getURLs (serverURL , appName string ) (string , string , string ) {
79
+ return fmt .Sprintf ("%s/api/app-oauth/authorize/%s" , serverURL , appName ),
80
+ fmt .Sprintf ("%s/api/app-oauth/refresh/%s" , serverURL , appName ),
81
+ fmt .Sprintf ("%s/api/app-oauth/get-token" , serverURL )
83
82
}
84
83
85
84
func main () {
@@ -96,6 +95,11 @@ func mainErr() (err error) {
96
95
return fmt .Errorf ("main: TOOL_CALL_BODY environment variable not set" )
97
96
}
98
97
98
+ serverURL := os .Getenv ("OAUTH_SERVER_URL" )
99
+ if serverURL == "" {
100
+ return fmt .Errorf ("main: OAUTH_SERVER_URL environment variable not set" )
101
+ }
102
+
99
103
var in input
100
104
if err = json .Unmarshal ([]byte (inputStr ), & in ); err != nil {
101
105
return fmt .Errorf ("main: error parsing input JSON: %w" , err )
@@ -118,22 +122,27 @@ func mainErr() (err error) {
118
122
}
119
123
}()
120
124
121
- authorizeURL , refreshURL , tokenURL := getURLs (in .OAuthInfo .Integration )
122
- if authorizeURL == "" || refreshURL == "" || tokenURL == "" {
123
- // The URLs aren't set for this credential. Check to see if we should prompt the user for other tokens
124
- if in .PromptInfo == nil {
125
- fmt .Printf ("All the following environment variables must be set: GPTSCRIPT_OAUTH_%s_AUTH_URL, GPTSCRIPT_OAUTH_%[1]s_REFRESH_URL, GPTSCRIPT_OAUTH_%[1]s_TOKEN_URL" , normalizeForEnv (in .OAuthInfo .Integration ))
126
- fmt .Printf ("Or the promptInfo configuration must be provided for token prompting." )
127
- os .Exit (1 )
125
+ if in .PromptInfo != nil && os .Getenv ("GPTSCRIPT_EXISTING_CREDENTIAL" ) == "" {
126
+ authType , err := promptForSelect (ctx , gs )
127
+ if err != nil {
128
+ return fmt .Errorf ("main: failed to prompt for auth type: %w" , err )
128
129
}
129
130
130
- credJSON , err = promptForTokens (ctx , gs , in .OAuthInfo .Integration , in .PromptInfo )
131
- if err != nil {
132
- return fmt .Errorf ("main: failed to prompt for tokens: %w" , err )
131
+ if authType == authTypePAT {
132
+ credJSON , err = promptForTokens (ctx , gs , in .OAuthInfo .Integration , in .PromptInfo )
133
+ if err != nil {
134
+ return fmt .Errorf ("main: failed to prompt for tokens: %w" , err )
135
+ }
136
+ return nil
133
137
}
134
- return nil
135
138
}
136
139
140
+ return promptForOauth (ctx , gs , serverURL , & in , & credJSON )
141
+ }
142
+
143
+ func promptForOauth (ctx context.Context , gs * gptscript.GPTScript , serverURL string , in * input , credJSON * []byte ) error {
144
+ authorizeURL , refreshURL , tokenURL := getURLs (serverURL , in .OAuthInfo .Integration )
145
+
137
146
// Refresh existing credential if there is one.
138
147
existing := os .Getenv ("GPTSCRIPT_EXISTING_CREDENTIAL" )
139
148
if existing != "" {
@@ -195,7 +204,7 @@ func mainErr() (err error) {
195
204
out .ExpiresAt = & expiresAt
196
205
}
197
206
198
- credJSON , err = json .Marshal (out )
207
+ * credJSON , err = json .Marshal (out )
199
208
if err != nil {
200
209
return fmt .Errorf ("main: failed to marshal refreshed credential: %w" , err )
201
210
}
@@ -301,7 +310,7 @@ func mainErr() (err error) {
301
310
out .ExpiresAt = & expiresAt
302
311
}
303
312
304
- credJSON , err = json .Marshal (out )
313
+ * credJSON , err = json .Marshal (out )
305
314
if err != nil {
306
315
return fmt .Errorf ("main: failed to marshal token credential: %w" , err )
307
316
}
@@ -357,6 +366,40 @@ func generateString() (string, error) {
357
366
return string (b ), nil
358
367
}
359
368
369
+ func promptForSelect (ctx context.Context , g * gptscript.GPTScript ) (authType , error ) {
370
+ fieldName := "Authentication Method"
371
+
372
+ fields := gptscript.Fields {gptscript.Field {Name : fieldName , Description : "The authentication method to use for this tool." , Options : []string {string (authTypePAT ), string (authTypeOAuth )}}}
373
+ sysPromptIn , err := json .Marshal (sysPromptInput {
374
+ Message : "This tool has personal access token (PAT) and OAuth support. Select the authentication method you would like to use for this tool." ,
375
+ Fields : fields ,
376
+ })
377
+ if err != nil {
378
+ return "" , fmt .Errorf ("promptForSelect: error marshalling sys prompt input: %w" , err )
379
+ }
380
+
381
+ run , err := g .Run (ctx , "sys.prompt" , gptscript.Options {
382
+ Input : string (sysPromptIn ),
383
+ })
384
+ if err != nil {
385
+ return "" , fmt .Errorf ("promptForSelect: failed to run sys.prompt: %w" , err )
386
+ }
387
+
388
+ out , err := run .Text ()
389
+ if err != nil {
390
+ return "" , fmt .Errorf ("promptForSelect: failed to get prompt response: %w" , err )
391
+ }
392
+
393
+ m := make (map [string ]string )
394
+ if err = json .Unmarshal ([]byte (out ), & m ); err != nil {
395
+ return "" , fmt .Errorf ("promptForSelect: failed to unmarshal prompt response: %w" , err )
396
+ }
397
+
398
+ authType := authType (m [fieldName ])
399
+
400
+ return authType , nil
401
+ }
402
+
360
403
func promptForTokens (ctx context.Context , g * gptscript.GPTScript , integration string , prompt * promptInfo ) ([]byte , error ) {
361
404
if prompt .Metadata == nil {
362
405
prompt .Metadata = make (map [string ]string )
0 commit comments