diff --git a/client/src/main/scala/io/delta/sharing/client/DeltaSharingClient.scala b/client/src/main/scala/io/delta/sharing/client/DeltaSharingClient.scala index fb3b1711a..bdc1b0136 100644 --- a/client/src/main/scala/io/delta/sharing/client/DeltaSharingClient.scala +++ b/client/src/main/scala/io/delta/sharing/client/DeltaSharingClient.scala @@ -28,19 +28,23 @@ import org.apache.commons.io.IOUtils import org.apache.commons.io.input.BoundedInputStream import org.apache.hadoop.conf.Configuration import org.apache.hadoop.util.VersionInfo -import org.apache.http.{HttpHeaders, HttpHost, HttpStatus} +import org.apache.http.{HttpHeaders, HttpHost, HttpRequest, HttpStatus} import org.apache.http.client.config.RequestConfig import org.apache.http.client.methods.{HttpGet, HttpPost, HttpRequestBase} import org.apache.http.client.protocol.HttpClientContext +import org.apache.http.conn.routing.HttpRoute import org.apache.http.conn.ssl.{SSLConnectionSocketFactory, SSLContextBuilder, TrustSelfSignedStrategy} import org.apache.http.entity.StringEntity import org.apache.http.impl.client.{HttpClientBuilder, HttpClients} +import org.apache.http.impl.conn.{DefaultRoutePlanner, DefaultSchemePortResolver} +import org.apache.http.protocol.HttpContext import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import io.delta.sharing.client.auth.{AuthConfig, AuthCredentialProviderFactory} import io.delta.sharing.client.model._ import io.delta.sharing.client.util.{ConfUtils, JsonUtils, RetryUtils, UnexpectedHttpStatus} +import io.delta.sharing.client.util.ConfUtils.ProxyConfig import io.delta.sharing.spark.MissingEndStreamActionException /** An interface to fetch Delta metadata from remote server. */ @@ -198,7 +202,8 @@ class DeltaSharingRestClient( asyncQueryMaxDuration: Long = 600000L, tokenExchangeMaxRetries: Int = 5, tokenExchangeMaxRetryDurationInSeconds: Int = 60, - tokenRenewalThresholdInSeconds: Int = 600 + tokenRenewalThresholdInSeconds: Int = 600, + proxyConfigOpt: Option[ProxyConfig] = None ) extends DeltaSharingClient with Logging { logInfo(s"DeltaSharingRestClient with endStreamActionEnabled: $endStreamActionEnabled, " + @@ -211,7 +216,7 @@ class DeltaSharingRestClient( // Convert the responseFormat to a Seq to be used later. private val responseFormatSet = responseFormat.split(",").toSet - private lazy val client = { + private[sharing] lazy val client = { val clientBuilder: HttpClientBuilder = if (sslTrustAll) { val sslBuilder = new SSLContextBuilder() .loadTrustMaterial(null, new TrustSelfSignedStrategy()) @@ -227,6 +232,31 @@ class DeltaSharingRestClient( .setConnectTimeout(timeoutInSeconds * 1000) .setConnectionRequestTimeout(timeoutInSeconds * 1000) .setSocketTimeout(timeoutInSeconds * 1000).build() + proxyConfigOpt.foreach { proxyConfig => + if (sslTrustAll) { + throw new IllegalStateException( + "Proxy configuration is not supported when sslTrustAll is enabled.") + } + val proxy = new HttpHost(proxyConfig.host, proxyConfig.port) + clientBuilder.setProxy(proxy) + + if (proxyConfig.noProxyHosts.nonEmpty) { + val routePlanner = new DefaultRoutePlanner(DefaultSchemePortResolver.INSTANCE) { + override def determineRoute(target: HttpHost, + request: HttpRequest, + context: HttpContext): HttpRoute = { + if (proxyConfig.noProxyHosts.contains(target.getHostName)) { + // Direct route (no proxy) + new HttpRoute(target) + } else { + // Route via proxy + new HttpRoute(target, proxy) + } + } + } + clientBuilder.setRoutePlanner(routePlanner) + } + } val client = clientBuilder // Disable the default retry behavior because we have our own retry logic. // See `RetryUtils.runWithExponentialBackoff`. @@ -1401,6 +1431,7 @@ object DeltaSharingRestClient extends Logging { val endStreamActionEnabled = ConfUtils.includeEndStreamAction(sqlConf) val asyncQueryMaxDurationMillis = ConfUtils.asyncQueryTimeout(sqlConf) val asyncQueryPollDurationMillis = ConfUtils.asyncQueryPollIntervalMillis(sqlConf) + val proxyConfig = ConfUtils.getClientProxyConfig(sqlConf) val tokenExchangeMaxRetries = ConfUtils.tokenExchangeMaxRetries(sqlConf) val tokenExchangeMaxRetryDurationInSeconds = @@ -1427,7 +1458,8 @@ object DeltaSharingRestClient extends Logging { classOf[Long], classOf[Int], classOf[Int], - classOf[Int] + classOf[Int], + classOf[Option[ProxyConfig]] ).newInstance(profileProvider, java.lang.Integer.valueOf(timeoutInSeconds), java.lang.Integer.valueOf(numRetries), @@ -1445,7 +1477,8 @@ object DeltaSharingRestClient extends Logging { java.lang.Long.valueOf(asyncQueryMaxDurationMillis), java.lang.Integer.valueOf(tokenExchangeMaxRetries), java.lang.Integer.valueOf(tokenExchangeMaxRetryDurationInSeconds), - java.lang.Integer.valueOf(tokenRenewalThresholdInSeconds) + java.lang.Integer.valueOf(tokenRenewalThresholdInSeconds), + proxyConfig ).asInstanceOf[DeltaSharingClient] } } diff --git a/client/src/main/scala/io/delta/sharing/client/util/ConfUtils.scala b/client/src/main/scala/io/delta/sharing/client/util/ConfUtils.scala index a6b9b5cfa..ef3487822 100644 --- a/client/src/main/scala/io/delta/sharing/client/util/ConfUtils.scala +++ b/client/src/main/scala/io/delta/sharing/client/util/ConfUtils.scala @@ -84,6 +84,10 @@ object ConfUtils { val PROXY_PORT = "spark.delta.sharing.network.proxyPort" val NO_PROXY_HOSTS = "spark.delta.sharing.network.noProxyHosts" + val CLIENT_PROXY_HOST = "spark.delta.sharing.client.network.proxyHost" + val CLIENT_PROXY_PORT = "spark.delta.sharing.client.network.proxyPort" + val CLIENT_NO_PROXY_HOSTS = "spark.delta.sharing.client.network.noProxyHosts" + val OAUTH_RETRIES_CONF = "spark.delta.sharing.oauth.tokenExchangeMaxRetries" val OAUTH_RETRIES_DEFAULT = 5 @@ -118,6 +122,23 @@ object ConfUtils { Some(ProxyConfig(proxyHost, proxyPort, noProxyHosts = noProxyList)) } + def getClientProxyConfig(conf: SQLConf): Option[ProxyConfig] = { + val proxyHost = conf.getConfString(CLIENT_PROXY_HOST, null) + val proxyPortAsString = conf.getConfString(CLIENT_PROXY_PORT, null) + + if (proxyHost == null && proxyPortAsString == null) { + return None + } + + validateNonEmpty(proxyHost, CLIENT_PROXY_HOST) + validateNonEmpty(proxyPortAsString, CLIENT_PROXY_PORT) + val proxyPort = proxyPortAsString.toInt + validatePortNumber(proxyPort, CLIENT_PROXY_PORT) + + val noProxyList = conf.getConfString(CLIENT_NO_PROXY_HOSTS, "").split(",").map(_.trim).toSeq + Some(ProxyConfig(proxyHost, proxyPort, noProxyHosts = noProxyList)) + } + def getNeverUseHttps(conf: Configuration): Boolean = { conf.getBoolean(NEVER_USE_HTTPS, NEVER_USE_HTTPS_DEFAULT.toBoolean) } diff --git a/client/src/test/scala/io/delta/sharing/client/DeltaSharingRestClientSuite.scala b/client/src/test/scala/io/delta/sharing/client/DeltaSharingRestClientSuite.scala index 8e25692c2..ac1e9bbdd 100644 --- a/client/src/test/scala/io/delta/sharing/client/DeltaSharingRestClientSuite.scala +++ b/client/src/test/scala/io/delta/sharing/client/DeltaSharingRestClientSuite.scala @@ -17,24 +17,17 @@ package io.delta.sharing.client import java.sql.Timestamp +import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} import org.apache.http.HttpHeaders import org.apache.http.client.methods.{HttpGet, HttpRequestBase} +import org.apache.http.util.EntityUtils +import org.sparkproject.jetty.server.Server +import org.sparkproject.jetty.servlet.{ServletHandler, ServletHolder} -import io.delta.sharing.client.model.{ - AddCDCFile, - AddFile, - AddFileForCDF, - DeltaTableFiles, - EndStreamAction, - Format, - Metadata, - Protocol, - RemoveFile, - Table -} -import io.delta.sharing.client.util.JsonUtils -import io.delta.sharing.client.util.UnexpectedHttpStatus +import io.delta.sharing.client.model.{AddCDCFile, AddFile, AddFileForCDF, DeltaTableFiles, EndStreamAction, Format, Metadata, Protocol, RemoveFile, Table} +import io.delta.sharing.client.util.{JsonUtils, ProxyServer, UnexpectedHttpStatus} +import io.delta.sharing.client.util.ConfUtils.ProxyConfig import io.delta.sharing.spark.MissingEndStreamActionException // scalastyle:off maxLineLength @@ -1330,4 +1323,178 @@ class DeltaSharingRestClientSuite extends DeltaSharingIntegrationTest { } checkErrorMessage(e, s"and 0 lines, and last line as [Empty_Seq_in_checkEndStreamAction].") } + + integrationTest("traffic goes through a proxy when a proxy configured") { + // Create a local HTTP server. + val server = new Server(0) + val handler = new ServletHandler() + server.setHandler(handler) + handler.addServletWithMapping(new ServletHolder(new HttpServlet { + override def doGet(req: HttpServletRequest, resp: HttpServletResponse): Unit = { + resp.setContentType("text/plain") + resp.setStatus(HttpServletResponse.SC_OK) + + // scalastyle:off println + resp.getWriter.println("Hello, World!") + // scalastyle:on println + } + }), "/*") + server.start() + do { + Thread.sleep(100) + } while (!server.isStarted()) + + // Create a local HTTP proxy server. + val proxyServer = new ProxyServer(0) + proxyServer.initialize() + + try { + val dsClient = new DeltaSharingRestClient( + testProfileProvider, + sslTrustAll = false, + proxyConfigOpt = Some( + ProxyConfig( + host = proxyServer.getHost(), + port = proxyServer.getPort(), + noProxyHosts = Seq(server.getURI.getHost) + ) + ) + ) + + // Send a request to the local server through the httpClient. + val response = dsClient.client.execute(new HttpGet(server.getURI.toString)) + + // Assert that the request is successful. + assert(response.getStatusLine.getStatusCode == HttpServletResponse.SC_OK) + val content = EntityUtils.toString(response.getEntity) + assert(content.trim == "Hello, World!") + + // Assert that the request is passed through proxy. + assert(proxyServer.getCapturedRequests().size == 1) + } finally { + server.stop() + proxyServer.stop() + } + } + + integrationTest("traffic skips the proxy when a noProxyHosts configured") { + // Create a local HTTP server. + val server = new Server(0) + val handler = new ServletHandler() + server.setHandler(handler) + handler.addServletWithMapping(new ServletHolder(new HttpServlet { + override def doGet(req: HttpServletRequest, resp: HttpServletResponse): Unit = { + resp.setContentType("text/plain") + resp.setStatus(HttpServletResponse.SC_OK) + + // scalastyle:off println + resp.getWriter.println("Hello, World!") + // scalastyle:on println + } + }), "/*") + server.start() + do { + Thread.sleep(100) + } while (!server.isStarted()) + + // Create a local HTTP proxy server. + val proxyServer = new ProxyServer(0) + proxyServer.initialize() + try { + val dsClient = new DeltaSharingRestClient( + testProfileProvider, + sslTrustAll = false, + proxyConfigOpt = Some( + ProxyConfig( + host = proxyServer.getHost(), + port = proxyServer.getPort(), + noProxyHosts = Seq(server.getURI.getHost) + ) + ) + ) + + // Send a request to the local server through the httpClient. + val response = dsClient.client.execute(new HttpGet(server.getURI.toString)) + + // Assert that the request is successful. + assert(response.getStatusLine.getStatusCode == HttpServletResponse.SC_OK) + val content = EntityUtils.toString(response.getEntity) + assert(content.trim == "Hello, World!") + + // Assert that the request is not passed through proxy. + assert(proxyServer.getCapturedRequests().isEmpty) + } finally { + server.stop() + proxyServer.stop() + } + } + + integrationTest("traffic goes through the proxy when noProxyHosts does not include destination") { + // Create a local HTTP server. + val server = new Server(0) + val handler = new ServletHandler() + server.setHandler(handler) + handler.addServletWithMapping(new ServletHolder(new HttpServlet { + override def doGet(req: HttpServletRequest, resp: HttpServletResponse): Unit = { + resp.setContentType("text/plain") + resp.setStatus(HttpServletResponse.SC_OK) + + // scalastyle:off println + resp.getWriter.println("Hello, World!") + // scalastyle:on println + } + }), "/*") + server.start() + do { + Thread.sleep(100) + } while (!server.isStarted()) + + // Create a local HTTP proxy server. + val proxyServer = new ProxyServer(0) + proxyServer.initialize() + try { + val dsClient = new DeltaSharingRestClient( + testProfileProvider, + sslTrustAll = false, + proxyConfigOpt = Some( + ProxyConfig( + host = proxyServer.getHost(), + port = proxyServer.getPort(), + noProxyHosts = Seq(server.getURI.getHost) + ) + ) + ) + + // Send a request to the local server through the httpClient. + val response = dsClient.client.execute(new HttpGet(server.getURI.toString)) + + // Assert that the request is successful. + assert(response.getStatusLine.getStatusCode == HttpServletResponse.SC_OK) + val content = EntityUtils.toString(response.getEntity) + assert(content.trim == "Hello, World!") + + // Assert that the request is not passed through proxy. + assert(proxyServer.getCapturedRequests().size == 1) + } finally { + server.stop() + proxyServer.stop() + } + } + + integrationTest("sslTrustAll cannot be true if proxy configured") { + val e = intercept[IllegalStateException] { + new DeltaSharingRestClient( + testProfileProvider, + sslTrustAll = true, + proxyConfigOpt = Some( + ProxyConfig( + host = "localhost", + port = 8080, + noProxyHosts = Seq() + ) + ) + ).client + } + assert(e.getMessage.contains("Proxy configuration is not supported when sslTrustAll is enabled.")) + } } diff --git a/spark/src/test/scala/io/delta/sharing/spark/TestDeltaSharingClient.scala b/spark/src/test/scala/io/delta/sharing/spark/TestDeltaSharingClient.scala index 11475139b..d04ebbae0 100644 --- a/spark/src/test/scala/io/delta/sharing/spark/TestDeltaSharingClient.scala +++ b/spark/src/test/scala/io/delta/sharing/spark/TestDeltaSharingClient.scala @@ -16,23 +16,9 @@ package io.delta.sharing.spark -import io.delta.sharing.client.{ - DeltaSharingClient, - DeltaSharingProfile, - DeltaSharingProfileProvider -} -import io.delta.sharing.client.model.{ - AddCDCFile, - AddFile, - AddFileForCDF, - DeltaTableFiles, - DeltaTableMetadata, - Metadata, - Protocol, - RemoveFile, - SingleAction, - Table -} +import io.delta.sharing.client.{DeltaSharingClient, DeltaSharingProfile, DeltaSharingProfileProvider} +import io.delta.sharing.client.model.{AddCDCFile, AddFile, AddFileForCDF, DeltaTableFiles, DeltaTableMetadata, Metadata, Protocol, RemoveFile, SingleAction, Table} +import io.delta.sharing.client.util.ConfUtils.ProxyConfig import io.delta.sharing.client.util.JsonUtils import io.delta.sharing.spark.TestDeltaSharingClient.TESTING_TIMESTAMP @@ -54,7 +40,8 @@ class TestDeltaSharingClient( asyncQueryMaxDuration: Long = Long.MaxValue, tokenExchangeMaxRetries: Int = 5, tokenExchangeMaxRetryDurationInSeconds: Int = 60, - tokenRenewalThresholdInSeconds: Int = 600 + tokenRenewalThresholdInSeconds: Int = 600, + proxyConfigOpt: Option[ProxyConfig] = None ) extends DeltaSharingClient { import DeltaSharingOptions.RESPONSE_FORMAT_PARQUET