diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 1bd9725a9a7f0..eb42aa9554062 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -591,6 +591,8 @@ class SparkConnectClient(object):
     Conceptually the remote spark session that communicates with the server
     """
 
+    _local_auth_token = os.environ.get("SPARK_CONNECT_LOCAL_AUTH_TOKEN")
+
     def __init__(
         self,
         connection: Union[str, ChannelBuilder],
@@ -637,6 +639,8 @@ def __init__(
             if isinstance(connection, ChannelBuilder)
             else DefaultChannelBuilder(connection, channel_options)
         )
+        if SparkConnectClient._local_auth_token is not None:
+            self._builder.set(ChannelBuilder.PARAM_TOKEN, SparkConnectClient._local_auth_token)
         self._user_id = None
         self._retry_policies: List[RetryPolicy] = []
 
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index c01c1e42a3185..8718834ca05e5 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -40,6 +40,7 @@
     TYPE_CHECKING,
     ClassVar,
 )
+import uuid
 
 import numpy as np
 import pandas as pd
@@ -893,6 +894,9 @@ def stop(self) -> None:
             if self is getattr(SparkSession._active_session, "session", None):
                 SparkSession._active_session.session = None
 
+        # It should be `None` always on stop.
+        SparkConnectClient._local_auth_token = None
+
         if "SPARK_LOCAL_REMOTE" in os.environ:
             # When local mode is in use, follow the regular Spark session's
             # behavior by terminating the Spark Connect server,
@@ -900,6 +904,10 @@ def stop(self) -> None:
             # client with a different remote address.
             if PySparkSession._activeSession is not None:
                 try:
+                    getattr(
+                        PySparkSession._activeSession._jvm,
+                        "org.apache.spark.sql.connect.common.config.ConnectCommon",
+                    ).setLocalAuthToken(None)
                     PySparkSession._activeSession.stop()
                 except Exception as e:
                     logger.warn(
@@ -1060,6 +1068,9 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
             overwrite_conf["spark.connect.grpc.binding.port"] = "0"
 
             origin_remote = os.environ.get("SPARK_REMOTE", None)
+            local_auth_token = str(uuid.uuid4())
+            SparkConnectClient._local_auth_token = local_auth_token
+            os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"] = local_auth_token
             try:
                 if origin_remote is not None:
                     # So SparkSubmit thinks no remote is set in order to
@@ -1072,6 +1083,13 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
                 conf.setAll(list(overwrite_conf.items())).setAll(list(default_conf.items()))
                 PySparkSession(SparkContext.getOrCreate(conf))
 
+                # In Python local mode, session.stop does not terminate JVM itself
+                # so we can't control it via environment variable.
+                getattr(
+                    SparkContext._jvm,
+                    "org.apache.spark.sql.connect.common.config.ConnectCommon",
+                ).setLocalAuthToken(local_auth_token)
+
                 # Lastly only keep runtime configurations because other configurations are
                 # disallowed to set in the regular Spark Connect session.
                 utl = SparkContext._jvm.PythonSQLUtils  # type: ignore[union-attr]
@@ -1084,6 +1102,7 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
                 if origin_remote is not None:
                     os.environ["SPARK_REMOTE"] = origin_remote
                 del os.environ["SPARK_LOCAL_CONNECT"]
+                del os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"]
         else:
             raise PySparkRuntimeError(
                 errorClass="SESSION_OR_CONTEXT_EXISTS",
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
index 7e7b1a3632087..12dcf0b36a1d8 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
@@ -48,6 +48,7 @@ import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toLiteral
 import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, SparkConnectClient, SparkResult}
 import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
 import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
+import org.apache.spark.sql.connect.common.config.ConnectCommon
 import org.apache.spark.sql.internal.{SessionState, SharedState, SqlApiConf}
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.types.StructType
@@ -627,6 +628,19 @@ class SparkSession private[sql] (
     }
     allocator.close()
     SparkSession.onSessionClose(this)
+    SparkSession.server.synchronized {
+      // It should be `None` always on stop.
+      ConnectCommon.setLocalAuthToken(null)
+      if (SparkSession.server.isDefined) {
+        // When local mode is in use, follow the regular Spark session's
+        // behavior by terminating the Spark Connect server,
+        // meaning that you can stop local mode, and restart the Spark Connect
+        // client with a different remote address.
+        new ProcessBuilder(SparkSession.maybeConnectStopScript.get.toString)
+          .start()
+        SparkSession.server = None
+      }
+    }
   }
 
   /** @inheritdoc */
@@ -679,6 +693,10 @@ object SparkSession extends SparkSessionCompanion with Logging {
   private val MAX_CACHED_SESSIONS = 100
   private val planIdGenerator = new AtomicLong
   private var server: Option[Process] = None
+  private val maybeConnectStartScript =
+    Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))
+  private val maybeConnectStopScript =
+    Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "stop-connect-server.sh"))
   private[sql] val sparkOptions = sys.props.filter { p =>
     p._1.startsWith("spark.") && p._2.nonEmpty
   }.toMap
@@ -712,37 +730,43 @@ object SparkSession extends SparkSessionCompanion with Logging {
         }
       }
 
-      val maybeConnectScript =
-        Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))
-
-      if (server.isEmpty &&
-        (remoteString.exists(_.startsWith("local")) ||
-          (remoteString.isDefined && isAPIModeConnect)) &&
-        maybeConnectScript.exists(Files.exists(_))) {
-        server = Some {
-          val args =
-            Seq(maybeConnectScript.get.toString, "--master", remoteString.get) ++ sparkOptions
-              .filter(p => !p._1.startsWith("spark.remote"))
-              .flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
-          val pb = new ProcessBuilder(args: _*)
-          // So don't exclude spark-sql jar in classpath
-          pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
-          pb.start()
-        }
-
-        // Let the server start. We will directly request to set the configurations
-        // and this sleep makes less noisy with retries.
-        Thread.sleep(2000L)
-        System.setProperty("spark.remote", "sc://localhost")
-
-        // scalastyle:off runtimeaddshutdownhook
-        Runtime.getRuntime.addShutdownHook(new Thread() {
-          override def run(): Unit = if (server.isDefined) {
-            new ProcessBuilder(maybeConnectScript.get.toString)
-              .start()
+      server.synchronized {
+        if (server.isEmpty &&
+          (remoteString.exists(_.startsWith("local")) ||
+            (remoteString.isDefined && isAPIModeConnect)) &&
+          maybeConnectStartScript.exists(Files.exists(_))) {
+          val localAuthToken = java.util.UUID.randomUUID().toString()
+          server = Some {
+            ConnectCommon.setLocalAuthToken(localAuthToken)
+            val args =
+              Seq(
+                maybeConnectStartScript.get.toString,
+                "--master",
+                remoteString.get) ++ sparkOptions
+                .filter(p => !p._1.startsWith("spark.remote"))
+                .flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
+            val pb = new ProcessBuilder(args: _*)
+            // So don't exclude spark-sql jar in classpath
+            pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
+            pb.environment()
+              .put(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME, localAuthToken)
+            pb.start()
           }
-        })
-        // scalastyle:on runtimeaddshutdownhook
+
+          // Let the server start. We will directly request to set the configurations
+          // and this sleep makes less noisy with retries.
+          Thread.sleep(2000L)
+          System.setProperty("spark.remote", "sc://localhost")
+
+          // scalastyle:off runtimeaddshutdownhook
+          Runtime.getRuntime.addShutdownHook(new Thread() {
+            override def run(): Unit = if (server.isDefined) {
+              new ProcessBuilder(maybeConnectStopScript.get.toString)
+                .start()
+            }
+          })
+          // scalastyle:on runtimeaddshutdownhook
+        }
       }
     }
     f
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index dd241c50c9340..23a0764e7d63a 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -404,7 +404,7 @@ object SparkConnectClient {
   private val DEFAULT_USER_AGENT: String = "_SPARK_CONNECT_SCALA"
 
   private val AUTH_TOKEN_META_DATA_KEY: Metadata.Key[String] =
-    Metadata.Key.of("Authentication", Metadata.ASCII_STRING_MARSHALLER)
+    ConnectCommon.AUTH_TOKEN_META_DATA_KEY
 
   private val AUTH_TOKEN_ON_INSECURE_CONN_ERROR_MSG: String =
     "Authentication token cannot be passed over insecure connections. " +
@@ -422,7 +422,13 @@
    * port or a NameResolver-compliant URI connection string.
    */
  class Builder(private var _configuration: Configuration) {
-    def this() = this(Configuration())
+    def this() = this {
+      ConnectCommon.getLocalAuthToken
+        .map { _ =>
+          Configuration(token = ConnectCommon.getLocalAuthToken, isSslEnabled = Some(true))
+        }
+        .getOrElse(Configuration())
+    }
 
     def configuration: Configuration = _configuration
 
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/config/ConnectCommon.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/config/ConnectCommon.scala
index e244fd13595b2..d67aebf0e94c3 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/config/ConnectCommon.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/config/ConnectCommon.scala
@@ -16,9 +16,23 @@
  */
 package org.apache.spark.sql.connect.common.config
 
+import io.grpc.Metadata
+
 private[sql] object ConnectCommon {
   val CONNECT_GRPC_BINDING_PORT: Int = 15002
   val CONNECT_GRPC_PORT_MAX_RETRIES: Int = 0
   val CONNECT_GRPC_MAX_MESSAGE_SIZE: Int = 128 * 1024 * 1024
   val CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT: Int = 1024
+  val AUTH_TOKEN_META_DATA_KEY: Metadata.Key[String] =
+    Metadata.Key.of("Authentication", Metadata.ASCII_STRING_MARSHALLER)
+
+  val CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME = "SPARK_CONNECT_LOCAL_AUTH_TOKEN"
+  private var CONNECT_LOCAL_AUTH_TOKEN: Option[String] = Option(
+    System.getenv(CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME))
+  def setLocalAuthToken(token: String): Unit = CONNECT_LOCAL_AUTH_TOKEN = Option(token)
+  def getLocalAuthToken: Option[String] = CONNECT_LOCAL_AUTH_TOKEN
+  def assertLocalAuthToken(token: Option[String]): Unit = token.foreach { t =>
+    assert(CONNECT_LOCAL_AUTH_TOKEN.isDefined)
+    assert(t.substring("Bearer ".length) == CONNECT_LOCAL_AUTH_TOKEN.get)
+  }
 }
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/LocalAuthInterceptor.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/LocalAuthInterceptor.scala
new file mode 100644
index 0000000000000..a1a5b6d63f776
--- /dev/null
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/LocalAuthInterceptor.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.service
+
+import io.grpc.{Metadata, ServerCall, ServerCallHandler, ServerInterceptor}
+
+import org.apache.spark.sql.connect.common.config.ConnectCommon
+
+/**
+ * A gRPC interceptor to check if the header contains token for authentication.
+ */
+class LocalAuthInterceptor extends ServerInterceptor {
+  override def interceptCall[ReqT, RespT](
+      call: ServerCall[ReqT, RespT],
+      headers: Metadata,
+      next: ServerCallHandler[ReqT, RespT]): ServerCall.Listener[ReqT] = {
+    ConnectCommon.assertLocalAuthToken(
+      Option(headers.get(ConnectCommon.AUTH_TOKEN_META_DATA_KEY)))
+    next.startCall(call, headers)
+  }
+}
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index e62c19b66c8e5..739be91085f0f 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -39,7 +39,8 @@ import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys.HOST
 import org.apache.spark.internal.config.UI.UI_ENABLED
 import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerEvent}
-import org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_BINDING_ADDRESS, CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT, CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE, CONNECT_GRPC_PORT_MAX_RETRIES}
+import org.apache.spark.sql.connect.common.config.ConnectCommon
+import org.apache.spark.sql.connect.config.Connect._
 import org.apache.spark.sql.connect.execution.ConnectProgressExecutionListener
 import org.apache.spark.sql.connect.ui.{SparkConnectServerAppStatusStore, SparkConnectServerListener, SparkConnectServerTab}
 import org.apache.spark.sql.connect.utils.ErrorUtils
@@ -366,10 +367,14 @@ object SparkConnectService extends Logging {
     val debugMode = SparkEnv.get.conf.getBoolean("spark.connect.grpc.debug.enabled", true)
     val bindAddress = SparkEnv.get.conf.get(CONNECT_GRPC_BINDING_ADDRESS)
     val startPort = SparkEnv.get.conf.get(CONNECT_GRPC_BINDING_PORT)
+    val localAuthToken = System.getenv(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME)
+    ConnectCommon.setLocalAuthToken(localAuthToken)
     val sparkConnectService = new SparkConnectService(debugMode)
     val protoReflectionService =
       if (debugMode) Some(ProtoReflectionService.newInstance()) else None
-    val configuredInterceptors = SparkConnectInterceptorRegistry.createConfiguredInterceptors()
+    val configuredInterceptors =
+      SparkConnectInterceptorRegistry.createConfiguredInterceptors() ++
+        (if (localAuthToken != null) Seq(new LocalAuthInterceptor()) else Nil)
 
     val startServiceFn = (port: Int) => {
       val sb = bindAddress match {
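
Note (not part of the patch): the sketch below is a minimal, self-contained illustration of the handshake this diff wires up, under the assumption of a single simplified process. The driver generates a random token, hands it to the locally launched Connect server through the SPARK_CONNECT_LOCAL_AUTH_TOKEN environment variable, the client sends it back as "Bearer <token>" in the "Authentication" gRPC header, and the server-side interceptor compares it with the expected value. The object LocalAuthSketch and its members are hypothetical names invented for this example; only ConnectCommon, LocalAuthInterceptor, and the environment variable name above come from the diff.

import java.util.UUID

object LocalAuthSketch {
  // Prefix used when the token travels in the "Authentication" gRPC header.
  private val BearerPrefix = "Bearer "

  // Server side: the token the server expects, as it would be read from
  // SPARK_CONNECT_LOCAL_AUTH_TOKEN at startup (None when the variable is unset).
  final case class ServerState(expectedToken: Option[String]) {
    // Loosely mirrors ConnectCommon.assertLocalAuthToken in the diff: the header
    // is only validated when it is present and a local token was configured.
    def validate(headerValue: Option[String]): Boolean = headerValue match {
      case None => true // no header sent; the interceptor in the diff also lets this pass
      case Some(v) => expectedToken.contains(v.stripPrefix(BearerPrefix))
    }
  }

  def main(args: Array[String]): Unit = {
    // Driver side: generate a one-off token, like str(uuid.uuid4()) in session.py.
    val token = UUID.randomUUID().toString
    val server = ServerState(expectedToken = Some(token))

    // Client side: send the token back with the Bearer prefix.
    assert(server.validate(Some(BearerPrefix + token))) // matching token is accepted
    assert(!server.validate(Some("Bearer wrong")))      // mismatched token is rejected
  }
}

One consequence of the token.foreach in assertLocalAuthToken (modeled by the None => true branch above) is that a request carrying no "Authentication" header is not rejected; only a present-but-mismatching token fails the server-side assertion.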