2323
2424import com .google .common .collect .ImmutableMap ;
2525import com .snowflake .kafka .connector .internal .BufferThreshold ;
26+ import com .snowflake .kafka .connector .internal .InternalUtils ;
2627import com .snowflake .kafka .connector .internal .KCLogger ;
28+ import com .snowflake .kafka .connector .internal .OAuthConstants ;
2729import com .snowflake .kafka .connector .internal .SnowflakeErrors ;
30+ import com .snowflake .kafka .connector .internal .SnowflakeInternalOperations ;
31+ import com .snowflake .kafka .connector .internal .SnowflakeURL ;
2832import com .snowflake .kafka .connector .internal .streaming .IngestionMethodConfig ;
2933import com .snowflake .kafka .connector .internal .streaming .StreamingUtils ;
3034import java .io .BufferedReader ;
3135import java .io .File ;
3236import java .io .InputStream ;
3337import java .io .InputStreamReader ;
38+ import java .io .UnsupportedEncodingException ;
3439import java .net .Authenticator ;
3540import java .net .PasswordAuthentication ;
41+ import java .net .URI ;
42+ import java .net .URISyntaxException ;
3643import java .net .URL ;
3744import java .net .URLConnection ;
45+ import java .net .URLEncoder ;
3846import java .util .Arrays ;
47+ import java .util .Base64 ;
3948import java .util .HashMap ;
4049import java .util .Map ;
4150import java .util .Objects ;
4251import java .util .Random ;
4352import java .util .regex .Matcher ;
4453import java .util .regex .Pattern ;
54+ import java .util .stream .Collectors ;
55+ import net .snowflake .client .jdbc .internal .apache .http .HttpHeaders ;
56+ import net .snowflake .client .jdbc .internal .apache .http .client .methods .CloseableHttpResponse ;
57+ import net .snowflake .client .jdbc .internal .apache .http .client .methods .HttpPost ;
58+ import net .snowflake .client .jdbc .internal .apache .http .client .utils .URIBuilder ;
59+ import net .snowflake .client .jdbc .internal .apache .http .entity .ContentType ;
60+ import net .snowflake .client .jdbc .internal .apache .http .entity .StringEntity ;
61+ import net .snowflake .client .jdbc .internal .apache .http .impl .client .CloseableHttpClient ;
62+ import net .snowflake .client .jdbc .internal .apache .http .impl .client .HttpClientBuilder ;
63+ import net .snowflake .client .jdbc .internal .apache .http .util .EntityUtils ;
64+ import net .snowflake .client .jdbc .internal .google .gson .JsonObject ;
65+ import net .snowflake .client .jdbc .internal .google .gson .JsonParser ;
4566import org .apache .kafka .common .config .Config ;
4667import org .apache .kafka .common .config .ConfigException ;
4768import org .apache .kafka .common .config .ConfigValue ;
@@ -62,6 +83,15 @@ public class Utils {
6283 public static final String SF_SSL = "sfssl" ; // for test only
6384 public static final String SF_WAREHOUSE = "sfwarehouse" ; // for test only
6485 public static final String PRIVATE_KEY_PASSPHRASE = "snowflake.private.key" + ".passphrase" ;
86+ public static final String SF_AUTHENTICATOR =
87+ "snowflake.authenticator" ; // TODO: SNOW-889748 change to enum
88+ public static final String SF_OAUTH_CLIENT_ID = "snowflake.oauth.client.id" ;
89+ public static final String SF_OAUTH_CLIENT_SECRET = "snowflake.oauth.client.secret" ;
90+ public static final String SF_OAUTH_REFRESH_TOKEN = "snowflake.oauth.refresh.token" ;
91+
92+ // authenticator type
93+ public static final String SNOWFLAKE_JWT = "snowflake_jwt" ;
94+ public static final String OAUTH = "oauth" ;
6595
6696 /**
6797 * This value should be present if ingestion method is {@link
@@ -440,11 +470,54 @@ && parseTopicToTableMap(config.get(SnowflakeSinkConnectorConfig.TOPICS_TABLES_MA
440470 Utils .formatString ("{} cannot be empty." , SnowflakeSinkConnectorConfig .SNOWFLAKE_SCHEMA ));
441471 }
442472
443- if (!config .containsKey (SnowflakeSinkConnectorConfig .SNOWFLAKE_PRIVATE_KEY )) {
444- invalidConfigParams .put (
445- SnowflakeSinkConnectorConfig .SNOWFLAKE_PRIVATE_KEY ,
446- Utils .formatString (
447- "{} cannot be empty." , SnowflakeSinkConnectorConfig .SNOWFLAKE_PRIVATE_KEY ));
473+ switch (config
474+ .getOrDefault (SnowflakeSinkConnectorConfig .AUTHENTICATOR_TYPE , Utils .SNOWFLAKE_JWT )
475+ .toLowerCase ()) {
476+ // TODO: SNOW-889748 change to enum
477+ case Utils .SNOWFLAKE_JWT :
478+ if (!config .containsKey (SnowflakeSinkConnectorConfig .SNOWFLAKE_PRIVATE_KEY )) {
479+ invalidConfigParams .put (
480+ SnowflakeSinkConnectorConfig .SNOWFLAKE_PRIVATE_KEY ,
481+ Utils .formatString (
482+ "{} cannot be empty when using {} authenticator." ,
483+ SnowflakeSinkConnectorConfig .SNOWFLAKE_PRIVATE_KEY ,
484+ Utils .SNOWFLAKE_JWT ));
485+ }
486+ break ;
487+ case Utils .OAUTH :
488+ if (!config .containsKey (SnowflakeSinkConnectorConfig .OAUTH_CLIENT_ID )) {
489+ invalidConfigParams .put (
490+ SnowflakeSinkConnectorConfig .OAUTH_CLIENT_ID ,
491+ Utils .formatString (
492+ "{} cannot be empty when using {} authenticator." ,
493+ SnowflakeSinkConnectorConfig .OAUTH_CLIENT_ID ,
494+ Utils .OAUTH ));
495+ }
496+ if (!config .containsKey (SnowflakeSinkConnectorConfig .OAUTH_CLIENT_SECRET )) {
497+ invalidConfigParams .put (
498+ SnowflakeSinkConnectorConfig .OAUTH_CLIENT_SECRET ,
499+ Utils .formatString (
500+ "{} cannot be empty when using {} authenticator." ,
501+ SnowflakeSinkConnectorConfig .OAUTH_CLIENT_SECRET ,
502+ Utils .OAUTH ));
503+ }
504+ if (!config .containsKey (SnowflakeSinkConnectorConfig .OAUTH_REFRESH_TOKEN )) {
505+ invalidConfigParams .put (
506+ SnowflakeSinkConnectorConfig .OAUTH_REFRESH_TOKEN ,
507+ Utils .formatString (
508+ "{} cannot be empty when using {} authenticator." ,
509+ SnowflakeSinkConnectorConfig .OAUTH_REFRESH_TOKEN ,
510+ Utils .OAUTH ));
511+ }
512+ break ;
513+ default :
514+ invalidConfigParams .put (
515+ SnowflakeSinkConnectorConfig .AUTHENTICATOR_TYPE ,
516+ Utils .formatString (
517+ "{} should be one of {} or {}." ,
518+ SnowflakeSinkConnectorConfig .AUTHENTICATOR_TYPE ,
519+ Utils .SNOWFLAKE_JWT ,
520+ Utils .OAUTH ));
448521 }
449522
450523 if (!config .containsKey (SnowflakeSinkConnectorConfig .SNOWFLAKE_USER )) {
@@ -704,6 +777,133 @@ public static String formatString(String format, Object... vars) {
704777 return format ;
705778 }
706779
780+ /**
781+ * Get OAuth access token given refresh token
782+ *
783+ * @param url OAuth server url
784+ * @param clientId OAuth clientId
785+ * @param clientSecret OAuth clientSecret
786+ * @param refreshToken OAuth refresh token
787+ * @return OAuth access token
788+ */
789+ public static String getSnowflakeOAuthAccessToken (
790+ SnowflakeURL url , String clientId , String clientSecret , String refreshToken ) {
791+ return getSnowflakeOAuthToken (
792+ url ,
793+ clientId ,
794+ clientSecret ,
795+ refreshToken ,
796+ OAuthConstants .REFRESH_TOKEN ,
797+ OAuthConstants .REFRESH_TOKEN ,
798+ OAuthConstants .ACCESS_TOKEN );
799+ }
800+
801+ /**
802+ * Get OAuth token given integration info <a
803+ * href="https://docs.snowflake.com/en/user-guide/oauth-snowflake-overview">Snowflake OAuth
804+ * Overview</a>
805+ *
806+ * @param url OAuth server url
807+ * @param clientId OAuth clientId
808+ * @param clientSecret OAuth clientSecret
809+ * @param credential OAuth credential, either az code or refresh token
810+ * @param grantType OAuth grant type, either authorization_code or refresh_token
811+ * @param credentialType OAuth credential key, either code or refresh_token
812+ * @param tokenType type of OAuth token to get, either access_token or refresh_token
813+ * @return OAuth token
814+ */
815+ // TODO: SNOW-895296 Integrate OAuth utils with streaming ingest SDK
816+ public static String getSnowflakeOAuthToken (
817+ SnowflakeURL url ,
818+ String clientId ,
819+ String clientSecret ,
820+ String credential ,
821+ String grantType ,
822+ String credentialType ,
823+ String tokenType ) {
824+ Map <String , String > headers = new HashMap <>();
825+ headers .put (HttpHeaders .CONTENT_TYPE , OAuthConstants .OAUTH_CONTENT_TYPE_HEADER );
826+ headers .put (
827+ HttpHeaders .AUTHORIZATION ,
828+ OAuthConstants .BASIC_AUTH_HEADER_PREFIX
829+ + Base64 .getEncoder ().encodeToString ((clientId + ":" + clientSecret ).getBytes ()));
830+
831+ Map <String , String > payload = new HashMap <>();
832+ payload .put (OAuthConstants .GRANT_TYPE_PARAM , grantType );
833+ payload .put (credentialType , credential );
834+ payload .put (OAuthConstants .REDIRECT_URI , OAuthConstants .DEFAULT_REDIRECT_URI );
835+
836+ // Encode and convert payload into string entity
837+ String payloadString =
838+ payload .entrySet ().stream ()
839+ .map (
840+ e -> {
841+ try {
842+ return e .getKey () + "=" + URLEncoder .encode (e .getValue (), "UTF-8" );
843+ } catch (UnsupportedEncodingException ex ) {
844+ throw SnowflakeErrors .ERROR_1004 .getException (ex );
845+ }
846+ })
847+ .collect (Collectors .joining ("&" ));
848+ final StringEntity entity =
849+ new StringEntity (payloadString , ContentType .APPLICATION_FORM_URLENCODED );
850+
851+ HttpPost post =
852+ buildOAuthHttpPostRequest (url , OAuthConstants .TOKEN_REQUEST_ENDPOINT , headers , entity );
853+
854+ // Request access token
855+ CloseableHttpClient client = HttpClientBuilder .create ().build ();
856+ try {
857+ return InternalUtils .backoffAndRetry (
858+ null ,
859+ SnowflakeInternalOperations .FETCH_OAUTH_TOKEN ,
860+ () -> {
861+ try (CloseableHttpResponse httpResponse = client .execute (post )) {
862+ String respBodyString = EntityUtils .toString (httpResponse .getEntity ());
863+ JsonObject respBody = JsonParser .parseString (respBodyString ).getAsJsonObject ();
864+ // Trim surrounding quotation marks
865+ return respBody .get (tokenType ).toString ().replaceAll ("^\" |\" $" , "" );
866+ } catch (Exception e ) {
867+ throw SnowflakeErrors .ERROR_1004 .getException (
868+ "Failed to get Oauth access token after retries" );
869+ }
870+ })
871+ .toString ();
872+ } catch (Exception e ) {
873+ throw SnowflakeErrors .ERROR_1004 .getException (e );
874+ }
875+ }
876+
877+ /**
878+ * Build OAuth http post request base on headers and payload
879+ *
880+ * @param url target url
881+ * @param headers headers key value pairs
882+ * @param entity payload entity
883+ * @return HttpPost request for OAuth
884+ */
885+ public static HttpPost buildOAuthHttpPostRequest (
886+ SnowflakeURL url , String path , Map <String , String > headers , StringEntity entity ) {
887+ // Build post request
888+ URI uri ;
889+ try {
890+ uri =
891+ new URIBuilder ().setHost (url .toString ()).setScheme (url .getScheme ()).setPath (path ).build ();
892+ } catch (URISyntaxException e ) {
893+ throw SnowflakeErrors .ERROR_1004 .getException (e );
894+ }
895+
896+ // Add headers
897+ HttpPost post = new HttpPost (uri );
898+ for (Map .Entry <String , String > e : headers .entrySet ()) {
899+ post .addHeader (e .getKey (), e .getValue ());
900+ }
901+
902+ post .setEntity (entity );
903+
904+ return post ;
905+ }
906+
707907 /**
708908 * Get the message and cause of a missing exception, handling the null or empty cases of each
709909 *
0 commit comments