diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 1bd9725a9a7f0..83982d78ab9cf 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -125,6 +125,7 @@ class ChannelBuilder:
     PARAM_USER_ID = "user_id"
     PARAM_USER_AGENT = "user_agent"
     PARAM_SESSION_ID = "session_id"
+    CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME = "local_token"

     GRPC_MAX_MESSAGE_LENGTH_DEFAULT = 128 * 1024 * 1024

@@ -591,6 +592,8 @@ class SparkConnectClient(object):
     Conceptually the remote spark session that communicates with the server
     """

+    _local_auth_token = os.environ.get("SPARK_CONNECT_LOCAL_AUTH_TOKEN")
+
     def __init__(
         self,
         connection: Union[str, ChannelBuilder],
@@ -637,6 +640,10 @@ def __init__(
             if isinstance(connection, ChannelBuilder)
             else DefaultChannelBuilder(connection, channel_options)
         )
+        if SparkConnectClient._local_auth_token is not None:
+            self._builder.set(
+                ChannelBuilder.CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME,
+                SparkConnectClient._local_auth_token)
         self._user_id = None
         self._retry_policies: List[RetryPolicy] = []

diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index c01c1e42a3185..8718834ca05e5 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -40,6 +40,7 @@
     TYPE_CHECKING,
     ClassVar,
 )
+import uuid

 import numpy as np
 import pandas as pd
@@ -893,6 +894,9 @@ def stop(self) -> None:
             if self is getattr(SparkSession._active_session, "session", None):
                 SparkSession._active_session.session = None

+            # It should be `None` always on stop.
+            SparkConnectClient._local_auth_token = None
+
             if "SPARK_LOCAL_REMOTE" in os.environ:
                 # When local mode is in use, follow the regular Spark session's
                 # behavior by terminating the Spark Connect server,
@@ -900,6 +904,10 @@ def stop(self) -> None:
                 # client with a different remote address.
                 if PySparkSession._activeSession is not None:
                     try:
+                        getattr(
+                            PySparkSession._activeSession._jvm,
+                            "org.apache.spark.sql.connect.common.config.ConnectCommon",
+                        ).setLocalAuthToken(None)
                         PySparkSession._activeSession.stop()
                     except Exception as e:
                         logger.warn(
@@ -1060,6 +1068,9 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
             overwrite_conf["spark.connect.grpc.binding.port"] = "0"

             origin_remote = os.environ.get("SPARK_REMOTE", None)
+            local_auth_token = str(uuid.uuid4())
+            SparkConnectClient._local_auth_token = local_auth_token
+            os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"] = local_auth_token
             try:
                 if origin_remote is not None:
                     # So SparkSubmit thinks no remote is set in order to
@@ -1072,6 +1083,13 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
                 conf.setAll(list(overwrite_conf.items())).setAll(list(default_conf.items()))
                 PySparkSession(SparkContext.getOrCreate(conf))

+                # In Python local mode, session.stop does not terminate JVM itself
+                # so we can't control it via environment variable.
+                getattr(
+                    SparkContext._jvm,
+                    "org.apache.spark.sql.connect.common.config.ConnectCommon",
+                ).setLocalAuthToken(local_auth_token)
+
                 # Lastly only keep runtime configurations because other configurations are
                 # disallowed to set in the regular Spark Connect session.
                 utl = SparkContext._jvm.PythonSQLUtils  # type: ignore[union-attr]
@@ -1084,6 +1102,7 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
             if origin_remote is not None:
                 os.environ["SPARK_REMOTE"] = origin_remote
             del os.environ["SPARK_LOCAL_CONNECT"]
+            del os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"]
         else:
             raise PySparkRuntimeError(
                 errorClass="SESSION_OR_CONTEXT_EXISTS",
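The Python-side changes above establish the token handshake for local mode: the launching process mints a UUID, exports it as SPARK_CONNECT_LOCAL_AUTH_TOKEN so the forked Spark Connect server inherits it, and hands the same value to the client through the `local_token` channel-builder parameter. A minimal Scala sketch of that environment-variable handshake (the object name and the echo child process are illustrative, not part of the patch):

import java.util.UUID

object TokenHandshakeSketch {
  def main(args: Array[String]): Unit = {
    // The launcher mints a one-time token; it is never written to argv or disk.
    val token = UUID.randomUUID().toString

    // The child process (a shell echo standing in for start-connect-server.sh)
    // receives the token only through its inherited environment.
    val pb = new ProcessBuilder(
      "bash", "-c", "echo \"child sees: $SPARK_CONNECT_LOCAL_AUTH_TOKEN\"")
    pb.environment().put("SPARK_CONNECT_LOCAL_AUTH_TOKEN", token)
    pb.inheritIO().start().waitFor()
  }
}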
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
index 7e7b1a3632087..12dcf0b36a1d8 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
@@ -48,6 +48,7 @@ import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toLiteral
 import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, SparkConnectClient, SparkResult}
 import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
 import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
+import org.apache.spark.sql.connect.common.config.ConnectCommon
 import org.apache.spark.sql.internal.{SessionState, SharedState, SqlApiConf}
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.types.StructType
@@ -627,6 +628,19 @@ class SparkSession private[sql] (
     }
     allocator.close()
     SparkSession.onSessionClose(this)
+    SparkSession.server.synchronized {
+      // It should be `None` always on stop.
+      ConnectCommon.setLocalAuthToken(null)
+      if (SparkSession.server.isDefined) {
+        // When local mode is in use, follow the regular Spark session's
+        // behavior by terminating the Spark Connect server,
+        // meaning that you can stop local mode, and restart the Spark Connect
+        // client with a different remote address.
+        new ProcessBuilder(SparkSession.maybeConnectStopScript.get.toString)
+          .start()
+        SparkSession.server = None
+      }
+    }
   }

   /** @inheritdoc */
@@ -679,6 +693,10 @@ object SparkSession extends SparkSessionCompanion with Logging {
   private val MAX_CACHED_SESSIONS = 100
   private val planIdGenerator = new AtomicLong
   private var server: Option[Process] = None
+  private val maybeConnectStartScript =
+    Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))
+  private val maybeConnectStopScript =
+    Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "stop-connect-server.sh"))
   private[sql] val sparkOptions = sys.props.filter { p =>
     p._1.startsWith("spark.") && p._2.nonEmpty
   }.toMap
@@ -712,37 +730,43 @@ object SparkSession extends SparkSessionCompanion with Logging {
       }
     }

-    val maybeConnectScript =
-      Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))
-
-    if (server.isEmpty &&
-      (remoteString.exists(_.startsWith("local")) ||
-        (remoteString.isDefined && isAPIModeConnect)) &&
-      maybeConnectScript.exists(Files.exists(_))) {
-      server = Some {
-        val args =
-          Seq(maybeConnectScript.get.toString, "--master", remoteString.get) ++ sparkOptions
-            .filter(p => !p._1.startsWith("spark.remote"))
-            .flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
-        val pb = new ProcessBuilder(args: _*)
-        // So don't exclude spark-sql jar in classpath
-        pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
-        pb.start()
-      }
-
-      // Let the server start. We will directly request to set the configurations
-      // and this sleep makes less noisy with retries.
-      Thread.sleep(2000L)
-      System.setProperty("spark.remote", "sc://localhost")
-
-      // scalastyle:off runtimeaddshutdownhook
-      Runtime.getRuntime.addShutdownHook(new Thread() {
-        override def run(): Unit = if (server.isDefined) {
-          new ProcessBuilder(maybeConnectScript.get.toString)
-            .start()
+    server.synchronized {
+      if (server.isEmpty &&
+        (remoteString.exists(_.startsWith("local")) ||
+          (remoteString.isDefined && isAPIModeConnect)) &&
+        maybeConnectStartScript.exists(Files.exists(_))) {
+        val localAuthToken = java.util.UUID.randomUUID().toString()
+        server = Some {
+          ConnectCommon.setLocalAuthToken(localAuthToken)
+          val args =
+            Seq(
+              maybeConnectStartScript.get.toString,
+              "--master",
+              remoteString.get) ++ sparkOptions
+              .filter(p => !p._1.startsWith("spark.remote"))
+              .flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
+          val pb = new ProcessBuilder(args: _*)
+          // So don't exclude spark-sql jar in classpath
+          pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
+          pb.environment()
+            .put(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME, localAuthToken)
+          pb.start()
         }
-      })
-      // scalastyle:on runtimeaddshutdownhook
+
+        // Let the server start. We will directly request to set the configurations
+        // and this sleep makes less noisy with retries.
+        Thread.sleep(2000L)
+        System.setProperty("spark.remote", "sc://localhost")
+
+        // scalastyle:off runtimeaddshutdownhook
+        Runtime.getRuntime.addShutdownHook(new Thread() {
+          override def run(): Unit = if (server.isDefined) {
+            new ProcessBuilder(maybeConnectStopScript.get.toString)
+              .start()
+          }
+        })
+        // scalastyle:on runtimeaddshutdownhook
+      }
     }
   }
   f
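The SparkSession.scala hunks above move server startup under `server.synchronized`, pass the generated token to the launched process through its environment, and switch both the shutdown hook and `stop()` to the stop script. A condensed sketch of that launch/stop lifecycle, assuming SPARK_HOME and the sbin scripts exist (names such as LocalConnectLauncherSketch are hypothetical):

import java.nio.file.{Files, Paths}
import java.util.UUID

object LocalConnectLauncherSketch {
  private var server: Option[Process] = None
  private val startScript =
    Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))
  private val stopScript =
    Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "stop-connect-server.sh"))

  // Start the server at most once, handing it the token through its environment,
  // and make sure it is stopped again when the JVM exits.
  def startIfNeeded(master: String): Unit = synchronized {
    if (server.isEmpty && startScript.exists(Files.exists(_))) {
      val token = UUID.randomUUID().toString
      val pb = new ProcessBuilder(startScript.get.toString, "--master", master)
      pb.environment().put("SPARK_CONNECT_LOCAL_AUTH_TOKEN", token)
      server = Some(pb.start())
      Runtime.getRuntime.addShutdownHook(new Thread(() => stop()))
    }
  }

  // Idempotent: runs the stop script only while a server handle is still held.
  def stop(): Unit = synchronized {
    if (server.isDefined && stopScript.exists(Files.exists(_))) {
      new ProcessBuilder(stopScript.get.toString).start().waitFor()
      server = None
    }
  }
}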
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index dd241c50c9340..6a9c021f12526 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -620,6 +620,8 @@ object SparkConnectClient {
      * Configure the builder using the env SPARK_REMOTE environment variable.
      */
     def loadFromEnvironment(): Builder = {
+      ConnectCommon.getLocalAuthToken.foreach(t =>
+        option(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME, t))
       lazy val isAPIModeConnect =
         Option(System.getProperty(org.apache.spark.sql.SparkSessionBuilder.API_MODE_KEY))
           .getOrElse("classic")
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/config/ConnectCommon.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/config/ConnectCommon.scala
index e244fd13595b2..7b9bf65eb37ff 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/config/ConnectCommon.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/config/ConnectCommon.scala
@@ -21,4 +21,15 @@ private[sql] object ConnectCommon {
   val CONNECT_GRPC_PORT_MAX_RETRIES: Int = 0
   val CONNECT_GRPC_MAX_MESSAGE_SIZE: Int = 128 * 1024 * 1024
   val CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT: Int = 1024
+
+  val CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME = "local_token"
+  val CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME = "SPARK_CONNECT_LOCAL_AUTH_TOKEN"
+  private var CONNECT_LOCAL_AUTH_TOKEN: Option[String] = Option(
+    System.getenv(CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME))
+  def setLocalAuthToken(token: String): Unit = CONNECT_LOCAL_AUTH_TOKEN = Option(token)
+  def getLocalAuthToken: Option[String] = CONNECT_LOCAL_AUTH_TOKEN
+  def assertLocalAuthToken(token: Option[String]): Unit = token.foreach { t =>
+    assert(CONNECT_LOCAL_AUTH_TOKEN.isDefined)
+    assert(t.substring("Bearer ".length) == CONNECT_LOCAL_AUTH_TOKEN.get)
+  }
 }
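ConnectCommon above stores the token process-wide, and `assertLocalAuthToken` strips a leading "Bearer " before comparing. A small self-contained sketch of an equivalent check that handles a missing prefix explicitly and compares in constant time via MessageDigest.isEqual (the object name and the constant-time hardening are illustrative choices, not what the patch itself does):

import java.nio.charset.StandardCharsets
import java.security.MessageDigest

object BearerTokenCheckSketch {
  // Returns true only when the header is present in "Bearer <token>" form and
  // the token matches; MessageDigest.isEqual keeps the comparison constant-time.
  def isValid(headerValue: String, expectedToken: String): Boolean = {
    val prefix = "Bearer "
    headerValue.startsWith(prefix) &&
      MessageDigest.isEqual(
        headerValue.stripPrefix(prefix).getBytes(StandardCharsets.UTF_8),
        expectedToken.getBytes(StandardCharsets.UTF_8))
  }

  def main(args: Array[String]): Unit = {
    val token = "1b4e28ba-2fa1-11d2-883f-0016d3cca427"
    assert(isValid(s"Bearer $token", token))
    assert(!isValid("Bearer not-the-token", token))
    assert(!isValid(token, token)) // prefix missing
  }
}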
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/LocalAuthInterceptor.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/LocalAuthInterceptor.scala
new file mode 100644
index 0000000000000..d4a26af1b1545
--- /dev/null
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/LocalAuthInterceptor.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.service
+
+import io.grpc.{Metadata, ServerCall, ServerCallHandler, ServerInterceptor}
+
+import org.apache.spark.sql.connect.common.config.ConnectCommon
+
+/**
+ * A gRPC interceptor to check if the header contains token for authentication.
+ */
+class LocalAuthInterceptor extends ServerInterceptor {
+  override def interceptCall[ReqT, RespT](
+      call: ServerCall[ReqT, RespT],
+      headers: Metadata,
+      next: ServerCallHandler[ReqT, RespT]): ServerCall.Listener[ReqT] = {
+    ConnectCommon.assertLocalAuthToken(Option(headers.get(Metadata.Key
+      .of(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME, Metadata.ASCII_STRING_MARSHALLER))))
+    next.startCall(call, headers)
+  }
+}
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index e62c19b66c8e5..a242a427d72ca 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -17,7 +17,9 @@

 package org.apache.spark.sql.connect.service

+import java.io.ByteArrayInputStream
 import java.net.InetSocketAddress
+import java.nio.charset.StandardCharsets
 import java.util.concurrent.TimeUnit

 import scala.jdk.CollectionConverters._
@@ -25,10 +27,11 @@ import scala.jdk.CollectionConverters._
 import com.google.protobuf.Message
 import io.grpc.{BindableService, MethodDescriptor, Server, ServerMethodDefinition, ServerServiceDefinition}
 import io.grpc.MethodDescriptor.PrototypeMarshaller
-import io.grpc.netty.NettyServerBuilder
+import io.grpc.netty.{GrpcSslContexts, NettyServerBuilder}
 import io.grpc.protobuf.ProtoUtils
 import io.grpc.protobuf.services.ProtoReflectionService
 import io.grpc.stub.StreamObserver
+import io.netty.handler.ssl.SslContext
 import org.apache.commons.lang3.StringUtils

 import org.apache.spark.{SparkContext, SparkEnv}
@@ -39,7 +42,8 @@ import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys.HOST
 import org.apache.spark.internal.config.UI.UI_ENABLED
 import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerEvent}
-import org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_BINDING_ADDRESS, CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT, CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE, CONNECT_GRPC_PORT_MAX_RETRIES}
+import org.apache.spark.sql.connect.common.config.ConnectCommon
+import org.apache.spark.sql.connect.config.Connect._
 import org.apache.spark.sql.connect.execution.ConnectProgressExecutionListener
 import org.apache.spark.sql.connect.ui.{SparkConnectServerAppStatusStore, SparkConnectServerListener, SparkConnectServerTab}
 import org.apache.spark.sql.connect.utils.ErrorUtils
@@ -366,10 +370,14 @@ object SparkConnectService extends Logging {
     val debugMode = SparkEnv.get.conf.getBoolean("spark.connect.grpc.debug.enabled", true)
     val bindAddress = SparkEnv.get.conf.get(CONNECT_GRPC_BINDING_ADDRESS)
     val startPort = SparkEnv.get.conf.get(CONNECT_GRPC_BINDING_PORT)
+    val localAuthToken = System.getenv(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME)
+    ConnectCommon.setLocalAuthToken(localAuthToken)
     val sparkConnectService = new SparkConnectService(debugMode)
     val protoReflectionService =
       if (debugMode) Some(ProtoReflectionService.newInstance()) else None
-    val configuredInterceptors = SparkConnectInterceptorRegistry.createConfiguredInterceptors()
+    val configuredInterceptors =
+      SparkConnectInterceptorRegistry.createConfiguredInterceptors() ++
+        (if (localAuthToken != null) Seq(new LocalAuthInterceptor()) else Nil)

     val startServiceFn = (port: Int) => {
       val sb = bindAddress match {
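LocalAuthInterceptor validates the `local_token` metadata key, and the service registers it only when SPARK_CONNECT_LOCAL_AUTH_TOKEN is set. For reference, a sketch of the same idea as a standalone grpc-java interceptor that closes unauthenticated calls with UNAUTHENTICATED instead of relying on assert (class name, port, and the rejection behavior are illustrative, not the patch's code):

import io.grpc.{Metadata, ServerCall, ServerCallHandler, ServerInterceptor, Status}

// A header-checking interceptor in the same spirit as LocalAuthInterceptor above,
// but closing the call instead of throwing an assertion error.
class RejectingAuthInterceptorSketch(expectedToken: String) extends ServerInterceptor {
  private val tokenKey =
    Metadata.Key.of("local_token", Metadata.ASCII_STRING_MARSHALLER)

  override def interceptCall[ReqT, RespT](
      call: ServerCall[ReqT, RespT],
      headers: Metadata,
      next: ServerCallHandler[ReqT, RespT]): ServerCall.Listener[ReqT] = {
    val provided = Option(headers.get(tokenKey))
    if (provided.contains(s"Bearer $expectedToken")) {
      // Token matches: let the call proceed to the next handler.
      next.startCall(call, headers)
    } else {
      // Reject the call before it reaches any service method.
      call.close(Status.UNAUTHENTICATED.withDescription("invalid local token"), new Metadata())
      new ServerCall.Listener[ReqT] {}
    }
  }
}

// Wiring it up would look roughly like:
//   NettyServerBuilder.forPort(port).intercept(new RejectingAuthInterceptorSketch(token)).build()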
@@ -388,6 +396,11 @@ object SparkConnectService extends Logging {
       // grpcurl can introspect the API for debugging.
       protoReflectionService.foreach(service => sb.addService(service))

+      Option(localAuthToken).foreach { t =>
+        val token = new ByteArrayInputStream(t.getBytes(StandardCharsets.UTF_8))
+        sb.sslContext(GrpcSslContexts.forServer(token, token).build())
+      }
+
       server = sb.build
       server.start()
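On the wire, the client has to present the token the server was started with. A client-side sketch that attaches the `local_token: Bearer <token>` header to every call using grpc-java's MetadataUtils (host, port, and the plaintext channel are placeholder assumptions; the hunk above additionally configures an SSL context on the server side):

import io.grpc.{ManagedChannel, ManagedChannelBuilder, Metadata}
import io.grpc.stub.MetadataUtils

object LocalTokenChannelSketch {
  // Builds a channel whose every RPC carries the "local_token" header in the
  // "Bearer <token>" form that assertLocalAuthToken expects.
  def buildChannel(host: String, port: Int, token: String): ManagedChannel = {
    val headers = new Metadata()
    headers.put(
      Metadata.Key.of("local_token", Metadata.ASCII_STRING_MARSHALLER),
      s"Bearer $token")
    ManagedChannelBuilder
      .forAddress(host, port)
      .usePlaintext()
      .intercept(MetadataUtils.newAttachHeadersInterceptor(headers))
      .build()
  }
}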