2323
2424import com .google .common .collect .ImmutableMap ;
2525import com .snowflake .kafka .connector .internal .BufferThreshold ;
26+ import com .snowflake .kafka .connector .internal .InternalUtils ;
2627import com .snowflake .kafka .connector .internal .KCLogger ;
28+ import com .snowflake .kafka .connector .internal .OAuthConstants ;
2729import com .snowflake .kafka .connector .internal .SnowflakeErrors ;
30+ import com .snowflake .kafka .connector .internal .SnowflakeInternalOperations ;
31+ import com .snowflake .kafka .connector .internal .SnowflakeURL ;
2832import com .snowflake .kafka .connector .internal .streaming .IngestionMethodConfig ;
2933import com .snowflake .kafka .connector .internal .streaming .StreamingUtils ;
3034import java .io .BufferedReader ;
3135import java .io .File ;
3236import java .io .InputStream ;
3337import java .io .InputStreamReader ;
38+ import java .io .UnsupportedEncodingException ;
3439import java .net .Authenticator ;
3540import java .net .PasswordAuthentication ;
41+ import java .net .URI ;
42+ import java .net .URISyntaxException ;
3643import java .net .URL ;
3744import java .net .URLConnection ;
45+ import java .net .URLEncoder ;
3846import java .util .Arrays ;
47+ import java .util .Base64 ;
3948import java .util .HashMap ;
4049import java .util .Map ;
4150import java .util .Objects ;
4251import java .util .Random ;
4352import java .util .regex .Matcher ;
4453import java .util .regex .Pattern ;
54+ import java .util .stream .Collectors ;
55+ import net .snowflake .client .jdbc .internal .apache .http .HttpHeaders ;
56+ import net .snowflake .client .jdbc .internal .apache .http .client .methods .CloseableHttpResponse ;
57+ import net .snowflake .client .jdbc .internal .apache .http .client .methods .HttpPost ;
58+ import net .snowflake .client .jdbc .internal .apache .http .client .utils .URIBuilder ;
59+ import net .snowflake .client .jdbc .internal .apache .http .entity .ContentType ;
60+ import net .snowflake .client .jdbc .internal .apache .http .entity .StringEntity ;
61+ import net .snowflake .client .jdbc .internal .apache .http .impl .client .CloseableHttpClient ;
62+ import net .snowflake .client .jdbc .internal .apache .http .impl .client .HttpClientBuilder ;
63+ import net .snowflake .client .jdbc .internal .apache .http .util .EntityUtils ;
64+ import net .snowflake .client .jdbc .internal .google .gson .JsonObject ;
65+ import net .snowflake .client .jdbc .internal .google .gson .JsonParser ;
4566import org .apache .kafka .common .config .Config ;
4667import org .apache .kafka .common .config .ConfigException ;
4768import org .apache .kafka .common .config .ConfigValue ;
@@ -62,6 +83,15 @@ public class Utils {
6283 public static final String SF_SSL = "sfssl" ; // for test only
6384 public static final String SF_WAREHOUSE = "sfwarehouse" ; // for test only
6485 public static final String PRIVATE_KEY_PASSPHRASE = "snowflake.private.key" + ".passphrase" ;
86+ public static final String SF_AUTHENTICATOR =
87+ "snowflake.authenticator" ; // TODO: SNOW-889748 change to enum
88+ public static final String SF_OAUTH_CLIENT_ID = "snowflake.oauth.client.id" ;
89+ public static final String SF_OAUTH_CLIENT_SECRET = "snowflake.oauth.client.secret" ;
90+ public static final String SF_OAUTH_REFRESH_TOKEN = "snowflake.oauth.refresh.token" ;
91+
92+ // authenticator type
93+ public static final String SNOWFLAKE_JWT = "snowflake_jwt" ;
94+ public static final String OAUTH = "oauth" ;
6595
6696 /**
6797 * This value should be present if ingestion method is {@link
@@ -441,11 +471,54 @@ && parseTopicToTableMap(config.get(SnowflakeSinkConnectorConfig.TOPICS_TABLES_MA
441471 Utils .formatString ("{} cannot be empty." , SnowflakeSinkConnectorConfig .SNOWFLAKE_SCHEMA ));
442472 }
443473
444- if (!config .containsKey (SnowflakeSinkConnectorConfig .SNOWFLAKE_PRIVATE_KEY )) {
445- invalidConfigParams .put (
446- SnowflakeSinkConnectorConfig .SNOWFLAKE_PRIVATE_KEY ,
447- Utils .formatString (
448- "{} cannot be empty." , SnowflakeSinkConnectorConfig .SNOWFLAKE_PRIVATE_KEY ));
474+ switch (config
475+ .getOrDefault (SnowflakeSinkConnectorConfig .AUTHENTICATOR_TYPE , Utils .SNOWFLAKE_JWT )
476+ .toLowerCase ()) {
477+ // TODO: SNOW-889748 change to enum
478+ case Utils .SNOWFLAKE_JWT :
479+ if (!config .containsKey (SnowflakeSinkConnectorConfig .SNOWFLAKE_PRIVATE_KEY )) {
480+ invalidConfigParams .put (
481+ SnowflakeSinkConnectorConfig .SNOWFLAKE_PRIVATE_KEY ,
482+ Utils .formatString (
483+ "{} cannot be empty when using {} authenticator." ,
484+ SnowflakeSinkConnectorConfig .SNOWFLAKE_PRIVATE_KEY ,
485+ Utils .SNOWFLAKE_JWT ));
486+ }
487+ break ;
488+ case Utils .OAUTH :
489+ if (!config .containsKey (SnowflakeSinkConnectorConfig .OAUTH_CLIENT_ID )) {
490+ invalidConfigParams .put (
491+ SnowflakeSinkConnectorConfig .OAUTH_CLIENT_ID ,
492+ Utils .formatString (
493+ "{} cannot be empty when using {} authenticator." ,
494+ SnowflakeSinkConnectorConfig .OAUTH_CLIENT_ID ,
495+ Utils .OAUTH ));
496+ }
497+ if (!config .containsKey (SnowflakeSinkConnectorConfig .OAUTH_CLIENT_SECRET )) {
498+ invalidConfigParams .put (
499+ SnowflakeSinkConnectorConfig .OAUTH_CLIENT_SECRET ,
500+ Utils .formatString (
501+ "{} cannot be empty when using {} authenticator." ,
502+ SnowflakeSinkConnectorConfig .OAUTH_CLIENT_SECRET ,
503+ Utils .OAUTH ));
504+ }
505+ if (!config .containsKey (SnowflakeSinkConnectorConfig .OAUTH_REFRESH_TOKEN )) {
506+ invalidConfigParams .put (
507+ SnowflakeSinkConnectorConfig .OAUTH_REFRESH_TOKEN ,
508+ Utils .formatString (
509+ "{} cannot be empty when using {} authenticator." ,
510+ SnowflakeSinkConnectorConfig .OAUTH_REFRESH_TOKEN ,
511+ Utils .OAUTH ));
512+ }
513+ break ;
514+ default :
515+ invalidConfigParams .put (
516+ SnowflakeSinkConnectorConfig .AUTHENTICATOR_TYPE ,
517+ Utils .formatString (
518+ "{} should be one of {} or {}." ,
519+ SnowflakeSinkConnectorConfig .AUTHENTICATOR_TYPE ,
520+ Utils .SNOWFLAKE_JWT ,
521+ Utils .OAUTH ));
449522 }
450523
451524 if (!config .containsKey (SnowflakeSinkConnectorConfig .SNOWFLAKE_USER )) {
@@ -705,6 +778,133 @@ public static String formatString(String format, Object... vars) {
705778 return format ;
706779 }
707780
781+ /**
782+ * Get OAuth access token given refresh token
783+ *
784+ * @param url OAuth server url
785+ * @param clientId OAuth clientId
786+ * @param clientSecret OAuth clientSecret
787+ * @param refreshToken OAuth refresh token
788+ * @return OAuth access token
789+ */
790+ public static String getSnowflakeOAuthAccessToken (
791+ SnowflakeURL url , String clientId , String clientSecret , String refreshToken ) {
792+ return getSnowflakeOAuthToken (
793+ url ,
794+ clientId ,
795+ clientSecret ,
796+ refreshToken ,
797+ OAuthConstants .REFRESH_TOKEN ,
798+ OAuthConstants .REFRESH_TOKEN ,
799+ OAuthConstants .ACCESS_TOKEN );
800+ }
801+
802+ /**
803+ * Get OAuth token given integration info <a
804+ * href="https://docs.snowflake.com/en/user-guide/oauth-snowflake-overview">Snowflake OAuth
805+ * Overview</a>
806+ *
807+ * @param url OAuth server url
808+ * @param clientId OAuth clientId
809+ * @param clientSecret OAuth clientSecret
810+ * @param credential OAuth credential, either az code or refresh token
811+ * @param grantType OAuth grant type, either authorization_code or refresh_token
812+ * @param credentialType OAuth credential key, either code or refresh_token
813+ * @param tokenType type of OAuth token to get, either access_token or refresh_token
814+ * @return OAuth token
815+ */
816+ // TODO: SNOW-895296 Integrate OAuth utils with streaming ingest SDK
817+ public static String getSnowflakeOAuthToken (
818+ SnowflakeURL url ,
819+ String clientId ,
820+ String clientSecret ,
821+ String credential ,
822+ String grantType ,
823+ String credentialType ,
824+ String tokenType ) {
825+ Map <String , String > headers = new HashMap <>();
826+ headers .put (HttpHeaders .CONTENT_TYPE , OAuthConstants .OAUTH_CONTENT_TYPE_HEADER );
827+ headers .put (
828+ HttpHeaders .AUTHORIZATION ,
829+ OAuthConstants .BASIC_AUTH_HEADER_PREFIX
830+ + Base64 .getEncoder ().encodeToString ((clientId + ":" + clientSecret ).getBytes ()));
831+
832+ Map <String , String > payload = new HashMap <>();
833+ payload .put (OAuthConstants .GRANT_TYPE_PARAM , grantType );
834+ payload .put (credentialType , credential );
835+ payload .put (OAuthConstants .REDIRECT_URI , OAuthConstants .DEFAULT_REDIRECT_URI );
836+
837+ // Encode and convert payload into string entity
838+ String payloadString =
839+ payload .entrySet ().stream ()
840+ .map (
841+ e -> {
842+ try {
843+ return e .getKey () + "=" + URLEncoder .encode (e .getValue (), "UTF-8" );
844+ } catch (UnsupportedEncodingException ex ) {
845+ throw SnowflakeErrors .ERROR_1004 .getException (ex );
846+ }
847+ })
848+ .collect (Collectors .joining ("&" ));
849+ final StringEntity entity =
850+ new StringEntity (payloadString , ContentType .APPLICATION_FORM_URLENCODED );
851+
852+ HttpPost post =
853+ buildOAuthHttpPostRequest (url , OAuthConstants .TOKEN_REQUEST_ENDPOINT , headers , entity );
854+
855+ // Request access token
856+ CloseableHttpClient client = HttpClientBuilder .create ().build ();
857+ try {
858+ return InternalUtils .backoffAndRetry (
859+ null ,
860+ SnowflakeInternalOperations .FETCH_OAUTH_TOKEN ,
861+ () -> {
862+ try (CloseableHttpResponse httpResponse = client .execute (post )) {
863+ String respBodyString = EntityUtils .toString (httpResponse .getEntity ());
864+ JsonObject respBody = JsonParser .parseString (respBodyString ).getAsJsonObject ();
865+ // Trim surrounding quotation marks
866+ return respBody .get (tokenType ).toString ().replaceAll ("^\" |\" $" , "" );
867+ } catch (Exception e ) {
868+ throw SnowflakeErrors .ERROR_1004 .getException (
869+ "Failed to get Oauth access token after retries" );
870+ }
871+ })
872+ .toString ();
873+ } catch (Exception e ) {
874+ throw SnowflakeErrors .ERROR_1004 .getException (e );
875+ }
876+ }
877+
878+ /**
879+ * Build OAuth http post request base on headers and payload
880+ *
881+ * @param url target url
882+ * @param headers headers key value pairs
883+ * @param entity payload entity
884+ * @return HttpPost request for OAuth
885+ */
886+ public static HttpPost buildOAuthHttpPostRequest (
887+ SnowflakeURL url , String path , Map <String , String > headers , StringEntity entity ) {
888+ // Build post request
889+ URI uri ;
890+ try {
891+ uri =
892+ new URIBuilder ().setHost (url .toString ()).setScheme (url .getScheme ()).setPath (path ).build ();
893+ } catch (URISyntaxException e ) {
894+ throw SnowflakeErrors .ERROR_1004 .getException (e );
895+ }
896+
897+ // Add headers
898+ HttpPost post = new HttpPost (uri );
899+ for (Map .Entry <String , String > e : headers .entrySet ()) {
900+ post .addHeader (e .getKey (), e .getValue ());
901+ }
902+
903+ post .setEntity (entity );
904+
905+ return post ;
906+ }
907+
708908 /**
709909 * Get the message and cause of a missing exception, handling the null or empty cases of each
710910 *
0 commit comments