Skip to content

Commit

Permalink
refactor: eliminate JWTHelper
Browse files Browse the repository at this point in the history
  • Loading branch information
vivshankar committed Aug 6, 2023
1 parent 7d4dfe9 commit 6e3faec
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 115 deletions.
79 changes: 7 additions & 72 deletions client_authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,13 @@ package fosite

import (
"context"
"crypto/ecdsa"
"crypto/rsa"
"encoding/json"
"net/http"
"net/url"
"time"

"github.com/ory/x/errorsx"

"github.com/go-jose/go-jose/v3"

"github.com/ory/fosite/token/jwt"
)

Expand All @@ -26,29 +22,11 @@ type ClientAuthenticationStrategy func(context.Context, *http.Request, url.Value
const clientAssertionJWTBearerType = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"

func (f *Fosite) findClientPublicJWK(ctx context.Context, oidcClient OpenIDConnectClient, t *jwt.Token, expectsRSAKey bool) (interface{}, error) {
if set := oidcClient.GetJSONWebKeys(); set != nil {
return findPublicKey(t, set, expectsRSAKey)
}

if location := oidcClient.GetJSONWebKeysURI(); len(location) > 0 {
keys, err := f.Config.GetJWKSFetcherStrategy(ctx).Resolve(ctx, location, false)
if err != nil {
return nil, err
}

if key, err := findPublicKey(t, keys, expectsRSAKey); err == nil {
return key, nil
}

keys, err = f.Config.GetJWKSFetcherStrategy(ctx).Resolve(ctx, location, true)
if err != nil {
return nil, err
}

return findPublicKey(t, keys, expectsRSAKey)
if oidcClient.GetJSONWebKeys() == nil && oidcClient.GetJSONWebKeysURI() == "" {
return nil, errorsx.WithStack(ErrInvalidClient.WithHint("The OAuth 2.0 Client has no JSON Web Keys set registered, but they are needed to complete the request."))
}

return nil, errorsx.WithStack(ErrInvalidClient.WithHint("The OAuth 2.0 Client has no JSON Web Keys set registered, but they are needed to complete the request."))
return findPublicJWK(ctx, f.Config, t, oidcClient.GetJSONWebKeysURI(), oidcClient.GetJSONWebKeys(), expectsRSAKey, ErrInvalidClient)
}

// AuthenticateClient authenticates client requests using the configured strategy
Expand All @@ -69,16 +47,8 @@ func (f *Fosite) DefaultClientAuthenticationStrategy(ctx context.Context, r *htt
return nil, errorsx.WithStack(ErrInvalidRequest.WithHintf("The client_assertion request parameter must be set when using client_assertion_type of '%s'.", clientAssertionJWTBearerType))
}

// for backward compatibility
if f.JWTHelper == nil {
f.JWTHelper = &JWTHelper{
JWTStrategy: nil,
Config: f.Config,
}
}

// Parse the assertion
token, parsedToken, isJWE, err := f.newToken(assertion, "client_assertion", ErrInvalidClient)
token, parsedToken, isJWE, err := newToken(assertion, "client_assertion", ErrInvalidClient)
if err != nil {
return nil, errorsx.WithStack(ErrInvalidClient.WithHint("Unable to parse the client_assertion").WithWrap(err).WithDebug(err.Error()))
}
Expand Down Expand Up @@ -133,7 +103,9 @@ func (f *Fosite) DefaultClientAuthenticationStrategy(ctx context.Context, r *htt
return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("The client_assertion uses signing algorithm '%s', but the requested OAuth 2.0 Client enforces signing algorithm '%s'.", parsedToken.Headers[0].Algorithm, oidcClient.GetTokenEndpointAuthSigningAlgorithm()))
}

if token, parsedToken, err = f.ValidateParsedAssertionWithClient(ctx, "client_assertion", assertion, token, parsedToken, oidcClient, false, ErrInvalidClient); err != nil {
ctx = context.WithValue(ctx, AssertionTypeContextKey, "client_assertion")
ctx = context.WithValue(ctx, BaseErrorContextKey, ErrInvalidClient)
if token, parsedToken, err = ValidateParsedAssertionWithClient(ctx, f.Config, assertion, token, parsedToken, oidcClient, false); err != nil {
return nil, err
}

Expand Down Expand Up @@ -248,43 +220,6 @@ func (f *Fosite) checkClientSecret(ctx context.Context, client Client, clientSec
return err
}

func findPublicKey(t *jwt.Token, set *jose.JSONWebKeySet, expectsRSAKey bool) (interface{}, error) {
keys := set.Keys
if len(keys) == 0 {
return nil, errorsx.WithStack(ErrInvalidRequest.WithHintf("The retrieved JSON Web Key Set does not contain any key."))
}

kid, ok := t.Header["kid"].(string)
if ok {
keys = set.Key(kid)
}

if len(keys) == 0 {
return nil, errorsx.WithStack(ErrInvalidRequest.WithHintf("The JSON Web Token uses signing key with kid '%s', which could not be found.", kid))
}

for _, key := range keys {
if key.Use != "sig" {
continue
}
if expectsRSAKey {
if k, ok := key.Key.(*rsa.PublicKey); ok {
return k, nil
}
} else {
if k, ok := key.Key.(*ecdsa.PublicKey); ok {
return k, nil
}
}
}

if expectsRSAKey {
return nil, errorsx.WithStack(ErrInvalidRequest.WithHintf("Unable to find RSA public key with use='sig' for kid '%s' in JSON Web Key Set.", kid))
} else {
return nil, errorsx.WithStack(ErrInvalidRequest.WithHintf("Unable to find ECDSA public key with use='sig' for kid '%s' in JSON Web Key Set.", kid))
}
}

func clientCredentialsFromRequest(r *http.Request, form url.Values) (clientID, clientSecret string, err error) {
if id, secret, ok := r.BasicAuth(); !ok {
return clientCredentialsFromRequestBody(form, true)
Expand Down
10 changes: 4 additions & 6 deletions client_authentication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,17 +107,15 @@ func TestAuthenticateClient(t *testing.T) {
ClientSecretsHasher: hasher,
TokenURL: "token-url",
HTTPClient: retryablehttp.NewClient(),
JWTStrategy: jwt.NewDefaultStrategy(
func(ctx context.Context, context *jwt.KeyContext) (interface{}, error) {
return encKey, nil
}),
}

f := &Fosite{
Store: storage.NewMemoryStore(),
Config: config,
JWTHelper: &JWTHelper{
Config: config,
JWTStrategy: jwt.NewDefaultStrategy(func(ctx context.Context, context *jwt.KeyContext) (interface{}, error) {
return encKey, nil
}),
},
}

barSecret, err := hasher.Hash(context.TODO(), []byte("bar"))
Expand Down
6 changes: 6 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,3 +300,9 @@ type PushedAuthorizeRequestConfigProvider interface {
// must contain the PAR request_uri.
EnforcePushedAuthorize(ctx context.Context) bool
}

// JWTStrategyProvider returns the provider for configuring the JWT strategy.
type JWTStrategyProvider interface {
// GetJWTStrategy returns the JWT strategy.
GetJWTStrategy(ctx context.Context) jwt.Strategy
}
9 changes: 9 additions & 0 deletions config_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ var (
_ RevocationHandlersProvider = (*Config)(nil)
_ PushedAuthorizeRequestHandlersProvider = (*Config)(nil)
_ PushedAuthorizeRequestConfigProvider = (*Config)(nil)
_ JWTStrategyProvider = (*Config)(nil)
)

type Config struct {
Expand Down Expand Up @@ -212,6 +213,9 @@ type Config struct {

// IsPushedAuthorizeEnforced enforces pushed authorization request for /authorize
IsPushedAuthorizeEnforced bool

// JWTStrategy is used to provide additional JWT encrypt/decrypt/sign/verify capabilities
JWTStrategy jwt.Strategy
}

func (c *Config) GetGlobalSecret(ctx context.Context) ([]byte, error) {
Expand Down Expand Up @@ -488,3 +492,8 @@ func (c *Config) GetPushedAuthorizeContextLifespan(ctx context.Context) time.Dur
func (c *Config) EnforcePushedAuthorize(ctx context.Context) bool {
return c.IsPushedAuthorizeEnforced
}

// GetJWTStrategy returns the JWT strategy.
func (c *Config) GetJWTStrategy(ctx context.Context) jwt.Strategy {
return c.JWTStrategy
}
3 changes: 3 additions & 0 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,7 @@ const (
AuthorizeResponseContextKey = ContextKey("authorizeResponse")
// PushedAuthorizeResponseContextKey is the response context
PushedAuthorizeResponseContextKey = ContextKey("pushedAuthorizeResponse")

AssertionTypeContextKey = ContextKey("assertionType")
BaseErrorContextKey = ContextKey("baseError")
)
2 changes: 0 additions & 2 deletions fosite.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,6 @@ type Fosite struct {
Store Storage

Config Configurator

*JWTHelper
}

// GetMinParameterEntropy returns MinParameterEntropy if set. Defaults to fosite.MinParameterEntropy.
Expand Down
Loading

0 comments on commit 6e3faec

Please sign in to comment.