Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-alhuang committed Aug 9, 2023
1 parent 9f6d4dd commit 033a4e3
Show file tree
Hide file tree
Showing 8 changed files with 68 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.snowflake.kafka.connector;

import com.snowflake.kafka.connector.internal.KCLogger;
import com.snowflake.kafka.connector.internal.OAuthConstants;
import com.snowflake.kafka.connector.internal.SnowflakeConnectionService;
import com.snowflake.kafka.connector.internal.SnowflakeConnectionServiceFactory;
import com.snowflake.kafka.connector.internal.SnowflakeErrors;
Expand Down Expand Up @@ -212,8 +213,8 @@ public Config validate(Map<String, String> connectorConfigs) {
// If using snowflake_jwt and authentication, and private key or private key passphrase is
// provided through file, skip validation
if (connectorConfigs
.getOrDefault(Utils.SF_AUTHENTICATOR, Utils.SNOWFLAKE_JWT)
.equals(Utils.SNOWFLAKE_JWT)
.getOrDefault(Utils.SF_AUTHENTICATOR, OAuthConstants.SNOWFLAKE_JWT)
.equals(OAuthConstants.SNOWFLAKE_JWT)
&& (connectorConfigs.getOrDefault(Utils.SF_PRIVATE_KEY, "").contains("${file:")
|| connectorConfigs.getOrDefault(Utils.PRIVATE_KEY_PASSPHRASE, "").contains("${file:")))
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableSet;
import com.snowflake.kafka.connector.internal.KCLogger;
import com.snowflake.kafka.connector.internal.OAuthConstants;
import com.snowflake.kafka.connector.internal.streaming.IngestionMethodConfig;
import com.snowflake.kafka.connector.internal.streaming.StreamingUtils;
import java.util.Arrays;
Expand Down Expand Up @@ -307,8 +308,8 @@ static ConfigDef newConfigDef() {
SNOWFLAKE_ROLE)
.define(
AUTHENTICATOR_TYPE,
Type.STRING,
Utils.SNOWFLAKE_JWT,
Type.STRING, // TODO: SNOW-889748 change to enum and add validator
OAuthConstants.SNOWFLAKE_JWT,
Importance.LOW,
"Authenticator for JDBC and streaming ingest sdk",
SNOWFLAKE_LOGIN_INFO,
Expand Down
57 changes: 27 additions & 30 deletions src/main/java/com/snowflake/kafka/connector/Utils.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import com.google.common.collect.ImmutableMap;
import com.snowflake.kafka.connector.internal.BufferThreshold;
import com.snowflake.kafka.connector.internal.KCLogger;
import com.snowflake.kafka.connector.internal.OAuthConstants;
import com.snowflake.kafka.connector.internal.SnowflakeErrors;
import com.snowflake.kafka.connector.internal.SnowflakeURL;
import com.snowflake.kafka.connector.internal.streaming.IngestionMethodConfig;
Expand Down Expand Up @@ -81,7 +82,8 @@ public class Utils {
public static final String SF_SSL = "sfssl"; // for test only
public static final String SF_WAREHOUSE = "sfwarehouse"; // for test only
public static final String PRIVATE_KEY_PASSPHRASE = "snowflake.private.key" + ".passphrase";
public static final String SF_AUTHENTICATOR = "snowflake.authenticator";
public static final String SF_AUTHENTICATOR =
"snowflake.authenticator"; // TODO: SNOW-889748 change to enum
public static final String SF_OAUTH_CLIENT_ID = "snowflake.oauth.client.id";
public static final String SF_OAUTH_CLIENT_SECRET = "snowflake.oauth.client.secret";
public static final String SF_OAUTH_REFRESH_TOKEN = "snowflake.oauth.refresh.token";
Expand Down Expand Up @@ -128,19 +130,6 @@ public class Utils {
public static final String GET_EXCEPTION_MISSING_MESSAGE = "missing exception message";
public static final String GET_EXCEPTION_MISSING_CAUSE = "missing exception cause";

// OAuth
public static final String TOKEN_REQUEST_ENDPOINT = "/oauth/token-request";
public static final String OAUTH_CONTENT_TYPE_HEADER = "application/x-www-form-urlencoded";
public static final String BASIC_AUTH_HEADER_PREFIX = "Basic ";
public static final String GRANT_TYPE_PARAM = "grant_type";
public static final String REFRESH_TOKEN = "refresh_token";
public static final String ACCESS_TOKEN = "access_token";
public static final String SNOWFLAKE_JWT = "snowflake_jwt";
public static final String OAUTH = "oauth";
public static final String REDIRECT_URI = "redirect_uri";
public static final String DEFAULT_REDIRECT_URI = "https://localhost.com/oauth";
public static final int OAUTH_MAX_RETRY = 5;

private static final KCLogger LOGGER = new KCLogger(Utils.class.getName());

/**
Expand Down Expand Up @@ -476,41 +465,43 @@ && parseTopicToTableMap(config.get(SnowflakeSinkConnectorConfig.TOPICS_TABLES_MA
Utils.formatString("{} cannot be empty.", SnowflakeSinkConnectorConfig.SNOWFLAKE_SCHEMA));
}

switch (config.getOrDefault(SnowflakeSinkConnectorConfig.AUTHENTICATOR_TYPE, SNOWFLAKE_JWT)) {
case SNOWFLAKE_JWT:
switch (config.getOrDefault(
SnowflakeSinkConnectorConfig.AUTHENTICATOR_TYPE, OAuthConstants.SNOWFLAKE_JWT)) {
// TODO: SNOW-889748 change to enum
case OAuthConstants.SNOWFLAKE_JWT:
if (!config.containsKey(SnowflakeSinkConnectorConfig.SNOWFLAKE_PRIVATE_KEY)) {
invalidConfigParams.put(
SnowflakeSinkConnectorConfig.SNOWFLAKE_PRIVATE_KEY,
Utils.formatString(
"{} cannot be empty when using {} authenticator.",
SnowflakeSinkConnectorConfig.SNOWFLAKE_PRIVATE_KEY,
SNOWFLAKE_JWT));
OAuthConstants.SNOWFLAKE_JWT));
}
break;
case OAUTH:
case OAuthConstants.OAUTH:
if (!config.containsKey(SnowflakeSinkConnectorConfig.OAUTH_CLIENT_ID)) {
invalidConfigParams.put(
SnowflakeSinkConnectorConfig.OAUTH_CLIENT_ID,
Utils.formatString(
"{} cannot be empty when using {} authenticator.",
SnowflakeSinkConnectorConfig.OAUTH_CLIENT_ID,
OAUTH));
OAuthConstants.OAUTH));
}
if (!config.containsKey(SnowflakeSinkConnectorConfig.OAUTH_CLIENT_SECRET)) {
invalidConfigParams.put(
SnowflakeSinkConnectorConfig.OAUTH_CLIENT_SECRET,
Utils.formatString(
"{} cannot be empty when using {} authenticator.",
SnowflakeSinkConnectorConfig.OAUTH_CLIENT_SECRET,
OAUTH));
OAuthConstants.OAUTH));
}
if (!config.containsKey(SnowflakeSinkConnectorConfig.OAUTH_REFRESH_TOKEN)) {
invalidConfigParams.put(
SnowflakeSinkConnectorConfig.OAUTH_REFRESH_TOKEN,
Utils.formatString(
"{} cannot be empty when using {} authenticator.",
SnowflakeSinkConnectorConfig.OAUTH_REFRESH_TOKEN,
OAUTH));
OAuthConstants.OAUTH));
}
break;
default:
Expand All @@ -519,8 +510,8 @@ && parseTopicToTableMap(config.get(SnowflakeSinkConnectorConfig.TOPICS_TABLES_MA
Utils.formatString(
"{} should be one of {} or {}.",
SnowflakeSinkConnectorConfig.AUTHENTICATOR_TYPE,
SNOWFLAKE_JWT,
OAUTH));
OAuthConstants.SNOWFLAKE_JWT,
OAuthConstants.OAUTH));
}

if (!config.containsKey(SnowflakeSinkConnectorConfig.SNOWFLAKE_USER)) {
Expand Down Expand Up @@ -792,7 +783,13 @@ public static String formatString(String format, Object... vars) {
public static String getSnowflakeOAuthAccessToken(
SnowflakeURL url, String clientId, String clientSecret, String refreshToken) {
return getSnowflakeOAuthToken(
url, clientId, clientSecret, refreshToken, REFRESH_TOKEN, REFRESH_TOKEN, ACCESS_TOKEN);
url,
clientId,
clientSecret,
refreshToken,
OAuthConstants.REFRESH_TOKEN,
OAuthConstants.REFRESH_TOKEN,
OAuthConstants.ACCESS_TOKEN);
}

/**
Expand All @@ -818,16 +815,16 @@ public static String getSnowflakeOAuthToken(
String credentialType,
String tokenType) {
Map<String, String> headers = new HashMap<>();
headers.put(HttpHeaders.CONTENT_TYPE, OAUTH_CONTENT_TYPE_HEADER);
headers.put(HttpHeaders.CONTENT_TYPE, OAuthConstants.OAUTH_CONTENT_TYPE_HEADER);
headers.put(
HttpHeaders.AUTHORIZATION,
BASIC_AUTH_HEADER_PREFIX
OAuthConstants.BASIC_AUTH_HEADER_PREFIX
+ Base64.getEncoder().encodeToString((clientId + ":" + clientSecret).getBytes()));

Map<String, String> payload = new HashMap<>();
payload.put(GRANT_TYPE_PARAM, grantType);
payload.put(OAuthConstants.GRANT_TYPE_PARAM, grantType);
payload.put(credentialType, credential);
payload.put(REDIRECT_URI, DEFAULT_REDIRECT_URI);
payload.put(OAuthConstants.REDIRECT_URI, OAuthConstants.DEFAULT_REDIRECT_URI);

// Encode and convert payload into string entity
String payloadString =
Expand All @@ -844,11 +841,11 @@ public static String getSnowflakeOAuthToken(
final StringEntity entity =
new StringEntity(payloadString, ContentType.APPLICATION_FORM_URLENCODED);

HttpPost post = makeOAuthHttpPost(url, TOKEN_REQUEST_ENDPOINT, headers, entity);
HttpPost post = makeOAuthHttpPost(url, OAuthConstants.TOKEN_REQUEST_ENDPOINT, headers, entity);

// Request access token
CloseableHttpClient client = HttpClientBuilder.create().build();
for (int retries = 0; retries < OAUTH_MAX_RETRY; retries++) {
for (int retries = 0; retries < OAuthConstants.OAUTH_MAX_RETRY; retries++) {
try (CloseableHttpResponse httpResponse = client.execute(post)) {
String respBodyString = EntityUtils.toString(httpResponse.getEntity());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,9 @@ static Properties createProperties(

// Set credential
if (!properties.containsKey(JDBC_AUTHENTICATOR)) {
properties.put(JDBC_AUTHENTICATOR, Utils.SNOWFLAKE_JWT);
properties.put(JDBC_AUTHENTICATOR, OAuthConstants.SNOWFLAKE_JWT);
}
if (properties.getProperty(JDBC_AUTHENTICATOR).equals(Utils.SNOWFLAKE_JWT)) {
if (properties.getProperty(JDBC_AUTHENTICATOR).equals(OAuthConstants.SNOWFLAKE_JWT)) {
// JWT key pair auth
if (!privateKeyPassphrase.isEmpty()) {
properties.put(
Expand All @@ -195,7 +195,7 @@ static Properties createProperties(
} else if (!privateKey.isEmpty()) {
properties.put(JDBC_PRIVATE_KEY, parsePrivateKey(privateKey));
}
} else if (properties.getProperty(JDBC_AUTHENTICATOR).equals(Utils.OAUTH)) {
} else if (properties.getProperty(JDBC_AUTHENTICATOR).equals(OAuthConstants.OAUTH)) {
// OAuth auth
if (oAuthClientId.isEmpty()) {
throw SnowflakeErrors.ERROR_0026.getException();
Expand Down Expand Up @@ -243,7 +243,7 @@ static Properties createProperties(
}

// required parameter check, the OAuth parameter is already checked when fetching access token
if (properties.getProperty(JDBC_AUTHENTICATOR).equals(Utils.SNOWFLAKE_JWT)
if (properties.getProperty(JDBC_AUTHENTICATOR).equals(OAuthConstants.SNOWFLAKE_JWT)
&& !properties.containsKey(JDBC_PRIVATE_KEY)) {
throw SnowflakeErrors.ERROR_0013.getException();
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package com.snowflake.kafka.connector.internal;

/**
* This class contains constants for OAuth request.
*
* @see <a
* href="https://github.com/snowflakedb/snowflake/blob/4fdb96cd5849f266cda430c5d49a13c29e866af5/GlobalServices/src/main/java/com/snowflake/resources/ResourceConstants.java">ResourceConstants</a>
*/
public class OAuthConstants {
public static final String TOKEN_REQUEST_ENDPOINT = "/oauth/token-request";
public static final String OAUTH_CONTENT_TYPE_HEADER = "application/x-www-form-urlencoded";
public static final String BASIC_AUTH_HEADER_PREFIX = "Basic ";
public static final String GRANT_TYPE_PARAM = "grant_type";
public static final String REFRESH_TOKEN = "refresh_token";
public static final String ACCESS_TOKEN = "access_token";
public static final String SNOWFLAKE_JWT = "snowflake_jwt";
public static final String OAUTH = "oauth";
public static final String REDIRECT_URI = "redirect_uri";
public static final String DEFAULT_REDIRECT_URI = "https://localhost.com/oauth";
public static final int OAUTH_MAX_RETRY = 5;
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig;
import com.snowflake.kafka.connector.Utils;
import com.snowflake.kafka.connector.internal.BufferThreshold;
import com.snowflake.kafka.connector.internal.OAuthConstants;
import java.time.Duration;
import java.util.HashMap;
import java.util.Map;
Expand Down Expand Up @@ -101,11 +102,11 @@ public static Map<String, String> convertConfigForStreamingClient(
connectorConfig.computeIfPresent(
Utils.SF_AUTHENTICATOR,
(key, value) -> {
if (value.equals(Utils.SNOWFLAKE_JWT)) {
if (value.equals(OAuthConstants.SNOWFLAKE_JWT)) {
streamingPropertiesMap.put(
STREAMING_CONSTANT_AUTHORIZATION_TYPE, STREAMING_CONSTANT_JWT);
}
if (value.equals(Utils.OAUTH)) {
if (value.equals(OAuthConstants.OAUTH)) {
streamingPropertiesMap.put(
STREAMING_CONSTANT_AUTHORIZATION_TYPE, STREAMING_CONSTANT_OAUTH);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import static com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig.ERRORS_TOLERANCE_CONFIG;
import static com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig.NAME;
import static com.snowflake.kafka.connector.Utils.HTTP_NON_PROXY_HOSTS;
import static com.snowflake.kafka.connector.Utils.OAUTH;
import static com.snowflake.kafka.connector.internal.TestUtils.getConfig;
import static org.junit.Assert.assertEquals;

import com.snowflake.kafka.connector.internal.OAuthConstants;
import com.snowflake.kafka.connector.internal.SnowflakeErrors;
import com.snowflake.kafka.connector.internal.SnowflakeKafkaConnectorException;
import com.snowflake.kafka.connector.internal.streaming.IngestionMethodConfig;
Expand Down Expand Up @@ -902,7 +902,7 @@ public void testMultipleInvalidConfigs() {
@Test
public void testOAuthAuthenticator() {
Map<String, String> config = getConfig();
config.put(SnowflakeSinkConnectorConfig.AUTHENTICATOR_TYPE, OAUTH);
config.put(SnowflakeSinkConnectorConfig.AUTHENTICATOR_TYPE, OAuthConstants.OAUTH);
config.put(SnowflakeSinkConnectorConfig.OAUTH_CLIENT_ID, "client_id");
config.put(SnowflakeSinkConnectorConfig.OAUTH_CLIENT_SECRET, "client_secret");
config.put(SnowflakeSinkConnectorConfig.OAUTH_REFRESH_TOKEN, "refresh_token");
Expand All @@ -924,7 +924,7 @@ public void testInvalidAuthenticator() {
public void testEmptyClientId() {
try {
Map<String, String> config = getConfig();
config.put(SnowflakeSinkConnectorConfig.AUTHENTICATOR_TYPE, OAUTH);
config.put(SnowflakeSinkConnectorConfig.AUTHENTICATOR_TYPE, OAuthConstants.OAUTH);
config.put(SnowflakeSinkConnectorConfig.OAUTH_CLIENT_SECRET, "client_secret");
config.put(SnowflakeSinkConnectorConfig.OAUTH_REFRESH_TOKEN, "refresh_token");
Utils.validateConfig(config);
Expand All @@ -937,7 +937,7 @@ public void testEmptyClientId() {
public void testEmptyClientSecret() {
try {
Map<String, String> config = getConfig();
config.put(SnowflakeSinkConnectorConfig.AUTHENTICATOR_TYPE, OAUTH);
config.put(SnowflakeSinkConnectorConfig.AUTHENTICATOR_TYPE, OAuthConstants.OAUTH);
config.put(SnowflakeSinkConnectorConfig.OAUTH_CLIENT_ID, "client_id");
config.put(SnowflakeSinkConnectorConfig.OAUTH_REFRESH_TOKEN, "refresh_token");
Utils.validateConfig(config);
Expand All @@ -950,7 +950,7 @@ public void testEmptyClientSecret() {
public void testEmptyRefreshToken() {
try {
Map<String, String> config = getConfig();
config.put(SnowflakeSinkConnectorConfig.AUTHENTICATOR_TYPE, OAUTH);
config.put(SnowflakeSinkConnectorConfig.AUTHENTICATOR_TYPE, OAuthConstants.OAUTH);
config.put(SnowflakeSinkConnectorConfig.OAUTH_CLIENT_ID, "client_id");
config.put(SnowflakeSinkConnectorConfig.OAUTH_CLIENT_SECRET, "client_secret");
Utils.validateConfig(config);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@
import static com.snowflake.kafka.connector.Utils.HTTP_PROXY_USER;
import static com.snowflake.kafka.connector.Utils.HTTP_USE_PROXY;
import static com.snowflake.kafka.connector.Utils.JDK_HTTP_AUTH_TUNNELING;
import static com.snowflake.kafka.connector.Utils.OAUTH;
import static com.snowflake.kafka.connector.Utils.REFRESH_TOKEN;
import static com.snowflake.kafka.connector.Utils.SF_DATABASE;
import static com.snowflake.kafka.connector.Utils.SF_SCHEMA;
import static com.snowflake.kafka.connector.Utils.SF_URL;
Expand Down Expand Up @@ -365,7 +363,7 @@ public static Map<String, String> getConfWithOAuth() {
if (!confWithOAuth.containsKey(Utils.SF_OAUTH_REFRESH_TOKEN)) {
confWithOAuth.put(Utils.SF_OAUTH_REFRESH_TOKEN, getRefreshToken(confWithOAuth));
}
confWithOAuth.put(Utils.SF_AUTHENTICATOR, OAUTH);
confWithOAuth.put(Utils.SF_AUTHENTICATOR, OAuthConstants.OAUTH);
confWithOAuth.remove(Utils.SF_PRIVATE_KEY);
confWithOAuth.put(Utils.SF_ROLE, getProfile(PROFILE_PATH).get(ROLE).asText());
}
Expand Down Expand Up @@ -875,7 +873,7 @@ public static String getRefreshToken(Map<String, String> config) {
getAZCode(config),
AZ_GRANT_TYPE,
AZ_CREDENTIAL_TYPE_CODE,
REFRESH_TOKEN);
OAuthConstants.REFRESH_TOKEN);
}

private static String getAZCode(Map<String, String> config) {
Expand Down

0 comments on commit 033a4e3

Please sign in to comment.