2323
2424import com .google .common .collect .ImmutableMap ;
2525import com .snowflake .kafka .connector .internal .BufferThreshold ;
26+ import com .snowflake .kafka .connector .internal .InternalUtils ;
2627import com .snowflake .kafka .connector .internal .KCLogger ;
28+ import com .snowflake .kafka .connector .internal .OAuthConstants ;
2729import com .snowflake .kafka .connector .internal .SnowflakeErrors ;
30+ import com .snowflake .kafka .connector .internal .SnowflakeInternalOperations ;
31+ import com .snowflake .kafka .connector .internal .SnowflakeURL ;
2832import com .snowflake .kafka .connector .internal .streaming .IngestionMethodConfig ;
2933import com .snowflake .kafka .connector .internal .streaming .StreamingUtils ;
3034import java .io .BufferedReader ;
3135import java .io .File ;
3236import java .io .InputStream ;
3337import java .io .InputStreamReader ;
38+ import java .io .UnsupportedEncodingException ;
3439import java .net .Authenticator ;
3540import java .net .PasswordAuthentication ;
41+ import java .net .URI ;
42+ import java .net .URISyntaxException ;
3643import java .net .URL ;
3744import java .net .URLConnection ;
45+ import java .net .URLEncoder ;
3846import java .util .Arrays ;
47+ import java .util .Base64 ;
3948import java .util .HashMap ;
4049import java .util .Map ;
4150import java .util .Objects ;
4251import java .util .Random ;
4352import java .util .regex .Matcher ;
4453import java .util .regex .Pattern ;
54+ import java .util .stream .Collectors ;
55+ import net .snowflake .client .jdbc .internal .apache .http .HttpHeaders ;
56+ import net .snowflake .client .jdbc .internal .apache .http .client .methods .CloseableHttpResponse ;
57+ import net .snowflake .client .jdbc .internal .apache .http .client .methods .HttpPost ;
58+ import net .snowflake .client .jdbc .internal .apache .http .client .utils .URIBuilder ;
59+ import net .snowflake .client .jdbc .internal .apache .http .entity .ContentType ;
60+ import net .snowflake .client .jdbc .internal .apache .http .entity .StringEntity ;
61+ import net .snowflake .client .jdbc .internal .apache .http .impl .client .CloseableHttpClient ;
62+ import net .snowflake .client .jdbc .internal .apache .http .impl .client .HttpClientBuilder ;
63+ import net .snowflake .client .jdbc .internal .apache .http .util .EntityUtils ;
64+ import net .snowflake .client .jdbc .internal .google .gson .JsonObject ;
65+ import net .snowflake .client .jdbc .internal .google .gson .JsonParser ;
4566import org .apache .kafka .common .config .Config ;
4667import org .apache .kafka .common .config .ConfigException ;
4768import org .apache .kafka .common .config .ConfigValue ;
@@ -62,6 +83,15 @@ public class Utils {
6283 public static final String SF_SSL = "sfssl" ; // for test only
6384 public static final String SF_WAREHOUSE = "sfwarehouse" ; // for test only
6485 public static final String PRIVATE_KEY_PASSPHRASE = "snowflake.private.key" + ".passphrase" ;
86+ public static final String SF_AUTHENTICATOR =
87+ "snowflake.authenticator" ; // TODO: SNOW-889748 change to enum
88+ public static final String SF_OAUTH_CLIENT_ID = "snowflake.oauth.client.id" ;
89+ public static final String SF_OAUTH_CLIENT_SECRET = "snowflake.oauth.client.secret" ;
90+ public static final String SF_OAUTH_REFRESH_TOKEN = "snowflake.oauth.refresh.token" ;
91+
92+ // authenticator type
93+ public static final String SNOWFLAKE_JWT = "snowflake_jwt" ;
94+ public static final String OAUTH = "oauth" ;
6595
6696 /**
6797 * This value should be present if ingestion method is {@link
@@ -440,11 +470,54 @@ && parseTopicToTableMap(config.get(SnowflakeSinkConnectorConfig.TOPICS_TABLES_MA
440470 Utils .formatString ("{} cannot be empty." , SnowflakeSinkConnectorConfig .SNOWFLAKE_SCHEMA ));
441471 }
442472
443- if (!config .containsKey (SnowflakeSinkConnectorConfig .SNOWFLAKE_PRIVATE_KEY )) {
444- invalidConfigParams .put (
445- SnowflakeSinkConnectorConfig .SNOWFLAKE_PRIVATE_KEY ,
446- Utils .formatString (
447- "{} cannot be empty." , SnowflakeSinkConnectorConfig .SNOWFLAKE_PRIVATE_KEY ));
473+ switch (config
474+ .getOrDefault (SnowflakeSinkConnectorConfig .AUTHENTICATOR_TYPE , Utils .SNOWFLAKE_JWT )
475+ .toLowerCase ()) {
476+ // TODO: SNOW-889748 change to enum
477+ case Utils .SNOWFLAKE_JWT :
478+ if (!config .containsKey (SnowflakeSinkConnectorConfig .SNOWFLAKE_PRIVATE_KEY )) {
479+ invalidConfigParams .put (
480+ SnowflakeSinkConnectorConfig .SNOWFLAKE_PRIVATE_KEY ,
481+ Utils .formatString (
482+ "{} cannot be empty when using {} authenticator." ,
483+ SnowflakeSinkConnectorConfig .SNOWFLAKE_PRIVATE_KEY ,
484+ Utils .SNOWFLAKE_JWT ));
485+ }
486+ break ;
487+ case Utils .OAUTH :
488+ if (!config .containsKey (SnowflakeSinkConnectorConfig .OAUTH_CLIENT_ID )) {
489+ invalidConfigParams .put (
490+ SnowflakeSinkConnectorConfig .OAUTH_CLIENT_ID ,
491+ Utils .formatString (
492+ "{} cannot be empty when using {} authenticator." ,
493+ SnowflakeSinkConnectorConfig .OAUTH_CLIENT_ID ,
494+ Utils .OAUTH ));
495+ }
496+ if (!config .containsKey (SnowflakeSinkConnectorConfig .OAUTH_CLIENT_SECRET )) {
497+ invalidConfigParams .put (
498+ SnowflakeSinkConnectorConfig .OAUTH_CLIENT_SECRET ,
499+ Utils .formatString (
500+ "{} cannot be empty when using {} authenticator." ,
501+ SnowflakeSinkConnectorConfig .OAUTH_CLIENT_SECRET ,
502+ Utils .OAUTH ));
503+ }
504+ if (!config .containsKey (SnowflakeSinkConnectorConfig .OAUTH_REFRESH_TOKEN )) {
505+ invalidConfigParams .put (
506+ SnowflakeSinkConnectorConfig .OAUTH_REFRESH_TOKEN ,
507+ Utils .formatString (
508+ "{} cannot be empty when using {} authenticator." ,
509+ SnowflakeSinkConnectorConfig .OAUTH_REFRESH_TOKEN ,
510+ Utils .OAUTH ));
511+ }
512+ break ;
513+ default :
514+ invalidConfigParams .put (
515+ SnowflakeSinkConnectorConfig .AUTHENTICATOR_TYPE ,
516+ Utils .formatString (
517+ "{} should be one of {} or {}." ,
518+ SnowflakeSinkConnectorConfig .AUTHENTICATOR_TYPE ,
519+ Utils .SNOWFLAKE_JWT ,
520+ Utils .OAUTH ));
448521 }
449522
450523 if (!config .containsKey (SnowflakeSinkConnectorConfig .SNOWFLAKE_USER )) {
@@ -719,6 +792,133 @@ public static String formatString(String format, Object... vars) {
719792 return format ;
720793 }
721794
795+ /**
796+ * Get OAuth access token given refresh token
797+ *
798+ * @param url OAuth server url
799+ * @param clientId OAuth clientId
800+ * @param clientSecret OAuth clientSecret
801+ * @param refreshToken OAuth refresh token
802+ * @return OAuth access token
803+ */
804+ public static String getSnowflakeOAuthAccessToken (
805+ SnowflakeURL url , String clientId , String clientSecret , String refreshToken ) {
806+ return getSnowflakeOAuthToken (
807+ url ,
808+ clientId ,
809+ clientSecret ,
810+ refreshToken ,
811+ OAuthConstants .REFRESH_TOKEN ,
812+ OAuthConstants .REFRESH_TOKEN ,
813+ OAuthConstants .ACCESS_TOKEN );
814+ }
815+
816+ /**
817+ * Get OAuth token given integration info <a
818+ * href="https://docs.snowflake.com/en/user-guide/oauth-snowflake-overview">Snowflake OAuth
819+ * Overview</a>
820+ *
821+ * @param url OAuth server url
822+ * @param clientId OAuth clientId
823+ * @param clientSecret OAuth clientSecret
824+ * @param credential OAuth credential, either az code or refresh token
825+ * @param grantType OAuth grant type, either authorization_code or refresh_token
826+ * @param credentialType OAuth credential key, either code or refresh_token
827+ * @param tokenType type of OAuth token to get, either access_token or refresh_token
828+ * @return OAuth token
829+ */
830+ // TODO: SNOW-895296 Integrate OAuth utils with streaming ingest SDK
831+ public static String getSnowflakeOAuthToken (
832+ SnowflakeURL url ,
833+ String clientId ,
834+ String clientSecret ,
835+ String credential ,
836+ String grantType ,
837+ String credentialType ,
838+ String tokenType ) {
839+ Map <String , String > headers = new HashMap <>();
840+ headers .put (HttpHeaders .CONTENT_TYPE , OAuthConstants .OAUTH_CONTENT_TYPE_HEADER );
841+ headers .put (
842+ HttpHeaders .AUTHORIZATION ,
843+ OAuthConstants .BASIC_AUTH_HEADER_PREFIX
844+ + Base64 .getEncoder ().encodeToString ((clientId + ":" + clientSecret ).getBytes ()));
845+
846+ Map <String , String > payload = new HashMap <>();
847+ payload .put (OAuthConstants .GRANT_TYPE_PARAM , grantType );
848+ payload .put (credentialType , credential );
849+ payload .put (OAuthConstants .REDIRECT_URI , OAuthConstants .DEFAULT_REDIRECT_URI );
850+
851+ // Encode and convert payload into string entity
852+ String payloadString =
853+ payload .entrySet ().stream ()
854+ .map (
855+ e -> {
856+ try {
857+ return e .getKey () + "=" + URLEncoder .encode (e .getValue (), "UTF-8" );
858+ } catch (UnsupportedEncodingException ex ) {
859+ throw SnowflakeErrors .ERROR_1004 .getException (ex );
860+ }
861+ })
862+ .collect (Collectors .joining ("&" ));
863+ final StringEntity entity =
864+ new StringEntity (payloadString , ContentType .APPLICATION_FORM_URLENCODED );
865+
866+ HttpPost post =
867+ buildOAuthHttpPostRequest (url , OAuthConstants .TOKEN_REQUEST_ENDPOINT , headers , entity );
868+
869+ // Request access token
870+ CloseableHttpClient client = HttpClientBuilder .create ().build ();
871+ try {
872+ return InternalUtils .backoffAndRetry (
873+ null ,
874+ SnowflakeInternalOperations .FETCH_OAUTH_TOKEN ,
875+ () -> {
876+ try (CloseableHttpResponse httpResponse = client .execute (post )) {
877+ String respBodyString = EntityUtils .toString (httpResponse .getEntity ());
878+ JsonObject respBody = JsonParser .parseString (respBodyString ).getAsJsonObject ();
879+ // Trim surrounding quotation marks
880+ return respBody .get (tokenType ).toString ().replaceAll ("^\" |\" $" , "" );
881+ } catch (Exception e ) {
882+ throw SnowflakeErrors .ERROR_1004 .getException (
883+ "Failed to get Oauth access token after retries" );
884+ }
885+ })
886+ .toString ();
887+ } catch (Exception e ) {
888+ throw SnowflakeErrors .ERROR_1004 .getException (e );
889+ }
890+ }
891+
892+ /**
893+ * Build OAuth http post request base on headers and payload
894+ *
895+ * @param url target url
896+ * @param headers headers key value pairs
897+ * @param entity payload entity
898+ * @return HttpPost request for OAuth
899+ */
900+ public static HttpPost buildOAuthHttpPostRequest (
901+ SnowflakeURL url , String path , Map <String , String > headers , StringEntity entity ) {
902+ // Build post request
903+ URI uri ;
904+ try {
905+ uri =
906+ new URIBuilder ().setHost (url .toString ()).setScheme (url .getScheme ()).setPath (path ).build ();
907+ } catch (URISyntaxException e ) {
908+ throw SnowflakeErrors .ERROR_1004 .getException (e );
909+ }
910+
911+ // Add headers
912+ HttpPost post = new HttpPost (uri );
913+ for (Map .Entry <String , String > e : headers .entrySet ()) {
914+ post .addHeader (e .getKey (), e .getValue ());
915+ }
916+
917+ post .setEntity (entity );
918+
919+ return post ;
920+ }
921+
722922 /**
723923 * Get the message and cause of a missing exception, handling the null or empty cases of each
724924 *
0 commit comments