Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move token resolving logic from the request ctx to separate method #3606

Merged
merged 2 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion adapter/config/default_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ var defaultConfig = &Config{
DropConsoleTestHeaders: true,
},
APIKeyConfig: apiKeyConfig{
InternalAPIKeyHeader: "Choreo-API-Key",
InternalAPIKeyHeader: "choreo-api-key",
OAuthAgentURL: "https://localhost:9443",
},
PATConfig: patConfig{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@
import org.wso2.choreo.connect.enforcer.commons.model.AuthenticationContext;
import org.wso2.choreo.connect.enforcer.commons.model.RequestContext;
import org.wso2.choreo.connect.enforcer.config.ConfigHolder;
import org.wso2.choreo.connect.enforcer.constants.APIConstants;
import org.wso2.choreo.connect.enforcer.exception.APISecurityException;
import org.wso2.choreo.connect.enforcer.security.jwt.validator.JWTConstants;

import java.util.Base64;
import java.util.Map;
Expand Down Expand Up @@ -65,6 +63,18 @@ public boolean canAuthenticate(RequestContext requestContext) {
@Override
public AuthenticationContext authenticate(RequestContext requestContext) throws APISecurityException {

return super.authenticate(requestContext);
}

private String getAPIKeyFromRequest(RequestContext requestContext) {
Map<String, String> headers = requestContext.getHeaders();
return headers.get(ConfigHolder.getInstance().getConfig().getApiKeyConfig()
.getApiKeyInternalHeader().toLowerCase());
}

@Override
protected String retrieveTokenFromRequestCtx(RequestContext requestContext) {

String apiKeyHeaderValue = getAPIKeyFromRequest(requestContext);
// Skipping the prefix(`chk_`) and checksum.
String apiKeyData = apiKeyHeaderValue.substring(4, apiKeyHeaderValue.length() - 6);
Expand All @@ -73,16 +83,7 @@ public AuthenticationContext authenticate(RequestContext requestContext) throws
// Convert data into JSON.
JSONObject jsonObject = (JSONObject) JSONValue.parse(decodedKeyData);
// Extracting the jwt token.
String jwtToken = jsonObject.getAsString(APIKeyConstants.API_KEY_JSON_KEY);
// Add the JWT as the Authorization header to authenticate the request.
requestContext.getHeaders().put(APIConstants.AUTHORIZATION_HEADER_DEFAULT,
JWTConstants.BEARER + " " + jwtToken);
return super.authenticate(requestContext);
}

private String getAPIKeyFromRequest(RequestContext requestContext) {
Map<String, String> headers = requestContext.getHeaders();
return headers.get(ConfigHolder.getInstance().getConfig().getApiKeyConfig().getApiKeyInternalHeader());
return jsonObject.getAsString(APIKeyConstants.API_KEY_JSON_KEY);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,30 +171,8 @@ public AuthenticationContext authenticate(RequestContext requestContext) throws
Utils.setTag(jwtAuthenticatorInfoSpan, APIConstants.LOG_TRACE_ID,
ThreadContext.get(APIConstants.LOG_TRACE_ID));
}
String authHeaderVal = retrieveAuthHeaderValue(requestContext);

if (authHeaderVal == null
&& requestContext.getMatchedAPI().getApiType().equalsIgnoreCase(APIConstants.ApiType.WEB_SOCKET)) {
String tokenValue = extractJWTInWSProtocolHeader(requestContext);
if (StringUtils.isNotEmpty(tokenValue)) {
authHeaderVal = JWTConstants.BEARER + " " + tokenValue;
}
}

if (authHeaderVal == null || !authHeaderVal.toLowerCase().contains(JWTConstants.BEARER)) {
throw new APISecurityException(APIConstants.StatusCodes.UNAUTHENTICATED.getCode(),
APISecurityConstants.API_AUTH_MISSING_CREDENTIALS, "Missing Credentials");
}
String[] splitToken = authHeaderVal.split("\\s");
String token = authHeaderVal;
// Extract the token when it is sent as bearer token. i.e Authorization: Bearer <token>
if (splitToken.length > 1) {
token = splitToken[1];
}
// Handle PAT logic
if (isPATEnabled && token.startsWith(APIKeyConstants.PAT_PREFIX)) {
token = exchangeJWTForPAT(requestContext, token);
}
String token = retrieveTokenFromRequestCtx(requestContext);
String context = requestContext.getMatchedAPI().getBasePath();
String name = requestContext.getMatchedAPI().getName();
String version = requestContext.getMatchedAPI().getVersion();
Expand Down Expand Up @@ -266,7 +244,7 @@ public AuthenticationContext authenticate(RequestContext requestContext) throws
ThreadContext.get(APIConstants.LOG_TRACE_ID));
}
// if the token is self contained, validation subscription from `subscribedApis` claim
JSONObject api = validateSubscriptionFromClaim(name, version, claims, splitToken,
JSONObject api = validateSubscriptionFromClaim(name, version, claims, token,
apiKeyValidationInfoDTO, true);
if (api == null) {
if (log.isDebugEnabled()) {
Expand Down Expand Up @@ -527,6 +505,40 @@ private String retrieveAuthHeaderValue(RequestContext requestContext) {
return headers.get(FilterUtils.getAuthHeaderName(requestContext));
}

/**
* Extract the JWT token from the request context.
*
* @param requestContext Request context
* @return JWT token
* @throws APISecurityException If an error occurs while extracting the JWT token
*/
protected String retrieveTokenFromRequestCtx(RequestContext requestContext) throws APISecurityException {

String authHeaderVal = retrieveAuthHeaderValue(requestContext);
if (authHeaderVal == null
&& requestContext.getMatchedAPI().getApiType().equalsIgnoreCase(APIConstants.ApiType.WEB_SOCKET)) {
String tokenValue = extractJWTInWSProtocolHeader(requestContext);
if (StringUtils.isNotEmpty(tokenValue)) {
authHeaderVal = JWTConstants.BEARER + " " + tokenValue;
}
}
if (authHeaderVal == null || !authHeaderVal.toLowerCase().contains(JWTConstants.BEARER)) {
throw new APISecurityException(APIConstants.StatusCodes.UNAUTHENTICATED.getCode(),
APISecurityConstants.API_AUTH_MISSING_CREDENTIALS, "Missing Credentials");
}
String[] splitToken = authHeaderVal.split("\\s");
String token = authHeaderVal;
// Extract the token when it is sent as bearer token. i.e Authorization: Bearer <token>
if (splitToken.length > 1) {
token = splitToken[1];
}
// Handle PAT logic
if (isPATEnabled && token.startsWith(APIKeyConstants.PAT_PREFIX)) {
token = exchangeJWTForPAT(requestContext, token);
}
return token;
}

@Override
public int getPriority() {
return 10;
Expand Down Expand Up @@ -612,9 +624,9 @@ private APIKeyValidationInfoDTO validateSubscriptionUsingKeyManager(RequestConte
* If the subscription information is not found, return a null object.
* @throws APISecurityException if the user is not subscribed to the API
*/
private JSONObject validateSubscriptionFromClaim(String name, String version, JWTClaimsSet payload,
String[] splitToken, APIKeyValidationInfoDTO validationInfo,
boolean isOauth) throws APISecurityException {
private JSONObject validateSubscriptionFromClaim(String name, String version, JWTClaimsSet payload, String token,
APIKeyValidationInfoDTO validationInfo, boolean isOauth)
throws APISecurityException {
JSONObject api = null;
try {
validationInfo.setEndUserName(payload.getSubject());
Expand Down Expand Up @@ -678,15 +690,15 @@ private JSONObject validateSubscriptionFromClaim(String name, String version, JW
}
if (log.isDebugEnabled()) {
log.debug("User is subscribed to the API: " + name + ", " +
"version: " + version + ". Token: " + FilterUtils.getMaskedToken(splitToken[0]));
"version: " + version + ". Token: " + FilterUtils.getMaskedToken(token));
}
break;
}
}
if (api == null) {
if (log.isDebugEnabled()) {
log.debug("User is not subscribed to access the API: " + name +
", version: " + version + ". Token: " + FilterUtils.getMaskedToken(splitToken[0]));
", version: " + version + ". Token: " + FilterUtils.getMaskedToken(token));
}
log.error("User is not subscribed to access the API.");
throw new APISecurityException(APIConstants.StatusCodes.UNAUTHORIZED.getCode(),
Expand Down
Loading