From a935c0b83f974ff5d927477250bf7d8fb07cf95d Mon Sep 17 00:00:00 2001
From: Hyukjin Kwon
Date: Tue, 11 Feb 2025 15:07:05 +0900
Subject: [PATCH] Provide a basic authentication token when running Spark
 Connect server locally

---
 python/pyspark/sql/connect/session.py         | 27 +++++-
 .../spark/sql/connect/SparkSession.scala      | 88 ++++++++++++-------
 .../connect/client/SparkConnectClient.scala   |  6 +-
 .../connect/common/config/ConnectCommon.scala | 13 +++
 .../service/LocalAuthInterceptor.scala        | 36 ++++++++
 .../connect/service/SparkConnectService.scala |  9 +-
 6 files changed, 143 insertions(+), 36 deletions(-)
 create mode 100644 sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/LocalAuthInterceptor.scala

diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index c01c1e42a3185..8328ffdbc4c2b 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -40,6 +40,7 @@
     TYPE_CHECKING,
     ClassVar,
 )
+import uuid
 
 import numpy as np
 import pandas as pd
@@ -53,7 +54,7 @@
 from pyspark.sql.connect.dataframe import DataFrame
 from pyspark.sql.dataframe import DataFrame as ParentDataFrame
 from pyspark.sql.connect.logging import logger
-from pyspark.sql.connect.client import SparkConnectClient, DefaultChannelBuilder
+from pyspark.sql.connect.client import SparkConnectClient, DefaultChannelBuilder, ChannelBuilder
 from pyspark.sql.connect.conf import RuntimeConf
 from pyspark.sql.connect.plan import (
     SQL,
@@ -120,6 +121,7 @@ class SparkSession:
     # Reference to the root SparkSession
     _default_session: ClassVar[Optional["SparkSession"]] = None
     _lock: ClassVar[RLock] = RLock()
+    _local_auth_token = os.environ.get("SPARK_CONNECT_LOCAL_AUTH_TOKEN")
 
     class Builder:
         """Builder for :class:`SparkSession`."""
@@ -238,6 +240,11 @@ def create(self) -> "SparkSession":
             else:
                 spark_remote = to_str(self._options.get("spark.remote"))
                 assert spark_remote is not None
+                if SparkSession._local_auth_token is not None:
+                    spark_remote = DefaultChannelBuilder(spark_remote)  # type: ignore[assignment]
+                    spark_remote.set(  # type: ignore[attr-defined]
+                        ChannelBuilder.PARAM_TOKEN, SparkSession._local_auth_token
+                    )
                 session = SparkSession(connection=spark_remote)
 
             SparkSession._set_default_and_active_session(session)
@@ -893,6 +900,9 @@ def stop(self) -> None:
             if self is getattr(SparkSession._active_session, "session", None):
                 SparkSession._active_session.session = None
 
+            # It should be `None` always on stop.
+            SparkSession._local_auth_token = None
+
             if "SPARK_LOCAL_REMOTE" in os.environ:
                 # When local mode is in use, follow the regular Spark session's
                 # behavior by terminating the Spark Connect server,
                 # meaning that you can stop local mode, and restart the Spark Connect
@@ -900,6 +910,10 @@ def stop(self) -> None:
                 # client with a different remote address.
                 if PySparkSession._activeSession is not None:
                     try:
+                        getattr(
+                            PySparkSession._activeSession._jvm,
+                            "org.apache.spark.sql.connect.common.config.ConnectCommon",
+                        ).setLocalAuthToken(None)
                         PySparkSession._activeSession.stop()
                     except Exception as e:
                         logger.warn(
@@ -1060,6 +1074,9 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
             overwrite_conf["spark.connect.grpc.binding.port"] = "0"
 
             origin_remote = os.environ.get("SPARK_REMOTE", None)
+            local_auth_token = str(uuid.uuid4())
+            SparkSession._local_auth_token = local_auth_token
+            os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"] = local_auth_token
             try:
                 if origin_remote is not None:
                     # So SparkSubmit thinks no remote is set in order to
@@ -1072,6 +1089,13 @@
                 conf.setAll(list(overwrite_conf.items())).setAll(list(default_conf.items()))
                 PySparkSession(SparkContext.getOrCreate(conf))
 
+                # In Python local mode, session.stop does not terminate JVM itself
+                # so we can't control it via environment variable.
+                getattr(
+                    SparkContext._jvm,
+                    "org.apache.spark.sql.connect.common.config.ConnectCommon",
+                ).setLocalAuthToken(local_auth_token)
+
                 # Lastly only keep runtime configurations because other configurations are
                 # disallowed to set in the regular Spark Connect session.
                 utl = SparkContext._jvm.PythonSQLUtils  # type: ignore[union-attr]
@@ -1084,6 +1108,7 @@
                 if origin_remote is not None:
                     os.environ["SPARK_REMOTE"] = origin_remote
                 del os.environ["SPARK_LOCAL_CONNECT"]
+                del os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"]
         else:
             raise PySparkRuntimeError(
                 errorClass="SESSION_OR_CONTEXT_EXISTS",
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
index 7e7b1a3632087..1c342801ce3c1 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
@@ -48,6 +48,7 @@ import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toLiteral
 import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, SparkConnectClient, SparkResult}
 import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
 import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
+import org.apache.spark.sql.connect.common.config.ConnectCommon
 import org.apache.spark.sql.internal.{SessionState, SharedState, SqlApiConf}
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.types.StructType
@@ -627,6 +628,19 @@
     }
     allocator.close()
     SparkSession.onSessionClose(this)
+    SparkSession.server.synchronized {
+      // It should be `None` always on stop.
+      ConnectCommon.setLocalAuthToken(null)
+      if (SparkSession.server.isDefined) {
+        // When local mode is in use, follow the regular Spark session's
+        // behavior by terminating the Spark Connect server,
+        // meaning that you can stop local mode, and restart the Spark Connect
+        // client with a different remote address.
+        new ProcessBuilder(SparkSession.maybeConnectStopScript.get.toString)
+          .start()
+        SparkSession.server = None
+      }
+    }
   }
 
   /** @inheritdoc */
@@ -679,6 +693,10 @@ object SparkSession extends SparkSessionCompanion with Logging {
   private val MAX_CACHED_SESSIONS = 100
   private val planIdGenerator = new AtomicLong
   private var server: Option[Process] = None
+  private val maybeConnectStartScript =
+    Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))
+  private val maybeConnectStopScript =
+    Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "stop-connect-server.sh"))
   private[sql] val sparkOptions = sys.props.filter { p =>
     p._1.startsWith("spark.") && p._2.nonEmpty
   }.toMap
@@ -712,37 +730,43 @@
       }
     }
 
-    val maybeConnectScript =
-      Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))
-
-    if (server.isEmpty &&
-      (remoteString.exists(_.startsWith("local")) ||
-        (remoteString.isDefined && isAPIModeConnect)) &&
-      maybeConnectScript.exists(Files.exists(_))) {
-      server = Some {
-        val args =
-          Seq(maybeConnectScript.get.toString, "--master", remoteString.get) ++ sparkOptions
-            .filter(p => !p._1.startsWith("spark.remote"))
-            .flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
-        val pb = new ProcessBuilder(args: _*)
-        // So don't exclude spark-sql jar in classpath
-        pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
-        pb.start()
-      }
-
-      // Let the server start. We will directly request to set the configurations
-      // and this sleep makes less noisy with retries.
-      Thread.sleep(2000L)
-      System.setProperty("spark.remote", "sc://localhost")
-
-      // scalastyle:off runtimeaddshutdownhook
-      Runtime.getRuntime.addShutdownHook(new Thread() {
-        override def run(): Unit = if (server.isDefined) {
-          new ProcessBuilder(maybeConnectScript.get.toString)
-            .start()
+    server.synchronized {
+      if (server.isEmpty &&
+        (remoteString.exists(_.startsWith("local")) ||
+          (remoteString.isDefined && isAPIModeConnect)) &&
+        maybeConnectStartScript.exists(Files.exists(_))) {
+        val localAuthToken = java.util.UUID.randomUUID().toString()
+        server = Some {
+          ConnectCommon.setLocalAuthToken(localAuthToken)
+          val args =
+            Seq(
+              maybeConnectStartScript.get.toString,
+              "--master",
+              remoteString.get) ++ sparkOptions
+              .filter(p => !p._1.startsWith("spark.remote"))
+              .flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
+          val pb = new ProcessBuilder(args: _*)
+          // So don't exclude spark-sql jar in classpath
+          pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
+          pb.environment()
+            .put(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME, localAuthToken)
+          pb.start()
         }
-      })
-      // scalastyle:on runtimeaddshutdownhook
+
+        // Let the server start. We will directly request to set the configurations
+        // and this sleep makes less noisy with retries.
+        Thread.sleep(2000L)
+        System.setProperty("spark.remote", "sc://localhost")
+
+        // scalastyle:off runtimeaddshutdownhook
+        Runtime.getRuntime.addShutdownHook(new Thread() {
+          override def run(): Unit = if (server.isDefined) {
+            new ProcessBuilder(maybeConnectStopScript.get.toString)
+              .start()
+          }
+        })
+        // scalastyle:on runtimeaddshutdownhook
+      }
     }
     f
   }
@@ -752,6 +776,10 @@
   /**
    * Create a new [[SparkSession]] based on the connect client [[Configuration]].
    */
   private[sql] def create(configuration: Configuration): SparkSession = {
+    ConnectCommon.getLocalAuthToken.foreach { _ =>
+      configuration.token = ConnectCommon.getLocalAuthToken
+      configuration.isSslEnabled = Some(true)
+    }
     new SparkSession(configuration.toSparkConnectClient, planIdGenerator)
   }
 
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index dd241c50c9340..c52f20d36ad7f 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -404,7 +404,7 @@
   private val DEFAULT_USER_AGENT: String = "_SPARK_CONNECT_SCALA"
 
   private val AUTH_TOKEN_META_DATA_KEY: Metadata.Key[String] =
-    Metadata.Key.of("Authentication", Metadata.ASCII_STRING_MARSHALLER)
+    ConnectCommon.AUTH_TOKEN_META_DATA_KEY
 
   private val AUTH_TOKEN_ON_INSECURE_CONN_ERROR_MSG: String =
     "Authentication token cannot be passed over insecure connections. " +
@@ -725,8 +725,8 @@
       userName: String = null,
       host: String = "localhost",
       port: Int = ConnectCommon.CONNECT_GRPC_BINDING_PORT,
-      token: Option[String] = None,
-      isSslEnabled: Option[Boolean] = None,
+      var token: Option[String] = None,
+      var isSslEnabled: Option[Boolean] = None,
       metadata: Map[String, String] = Map.empty,
       userAgent: String = genUserAgent(
         sys.env.getOrElse("SPARK_CONNECT_USER_AGENT", DEFAULT_USER_AGENT)),
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/config/ConnectCommon.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/config/ConnectCommon.scala
index e244fd13595b2..ee9639a99bd5c 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/config/ConnectCommon.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/config/ConnectCommon.scala
@@ -16,9 +16,22 @@
  */
 package org.apache.spark.sql.connect.common.config
 
+import io.grpc.Metadata
+
 private[sql] object ConnectCommon {
   val CONNECT_GRPC_BINDING_PORT: Int = 15002
   val CONNECT_GRPC_PORT_MAX_RETRIES: Int = 0
   val CONNECT_GRPC_MAX_MESSAGE_SIZE: Int = 128 * 1024 * 1024
   val CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT: Int = 1024
+
+  val AUTH_TOKEN_META_DATA_KEY: Metadata.Key[String] =
+    Metadata.Key.of("Authentication", Metadata.ASCII_STRING_MARSHALLER)
+
+  val CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME = "SPARK_CONNECT_LOCAL_AUTH_TOKEN"
+  private var CONNECT_LOCAL_AUTH_TOKEN: Option[String] = None
+  def setLocalAuthToken(token: String): Unit = CONNECT_LOCAL_AUTH_TOKEN = Option(token)
+  def getLocalAuthToken: Option[String] = CONNECT_LOCAL_AUTH_TOKEN
+  def assertLocalAuthToken(token: Option[String]): Unit = token.foreach { t =>
+    assert(CONNECT_LOCAL_AUTH_TOKEN.isDefined)
+    assert(t.substring("Bearer ".length) == CONNECT_LOCAL_AUTH_TOKEN.get)
+  }
 }
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/LocalAuthInterceptor.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/LocalAuthInterceptor.scala
new file mode 100644
index 0000000000000..a1a5b6d63f776
--- /dev/null
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/LocalAuthInterceptor.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.service
+
+import io.grpc.{Metadata, ServerCall, ServerCallHandler, ServerInterceptor}
+
+import org.apache.spark.sql.connect.common.config.ConnectCommon
+
+/**
+ * A gRPC interceptor to check if the header contains token for authentication.
+ */
+class LocalAuthInterceptor extends ServerInterceptor {
+  override def interceptCall[ReqT, RespT](
+      call: ServerCall[ReqT, RespT],
+      headers: Metadata,
+      next: ServerCallHandler[ReqT, RespT]): ServerCall.Listener[ReqT] = {
+    ConnectCommon.assertLocalAuthToken(
+      Option(headers.get(ConnectCommon.AUTH_TOKEN_META_DATA_KEY)))
+    next.startCall(call, headers)
+  }
+}
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index e62c19b66c8e5..739be91085f0f 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -39,7 +39,8 @@ import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys.HOST
 import org.apache.spark.internal.config.UI.UI_ENABLED
 import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerEvent}
-import org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_BINDING_ADDRESS, CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT, CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE, CONNECT_GRPC_PORT_MAX_RETRIES}
+import org.apache.spark.sql.connect.common.config.ConnectCommon
+import org.apache.spark.sql.connect.config.Connect._
 import org.apache.spark.sql.connect.execution.ConnectProgressExecutionListener
 import org.apache.spark.sql.connect.ui.{SparkConnectServerAppStatusStore, SparkConnectServerListener, SparkConnectServerTab}
 import org.apache.spark.sql.connect.utils.ErrorUtils
@@ -366,10 +367,14 @@
     val debugMode = SparkEnv.get.conf.getBoolean("spark.connect.grpc.debug.enabled", true)
     val bindAddress = SparkEnv.get.conf.get(CONNECT_GRPC_BINDING_ADDRESS)
     val startPort = SparkEnv.get.conf.get(CONNECT_GRPC_BINDING_PORT)
+    val localAuthToken = System.getenv(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME)
+    ConnectCommon.setLocalAuthToken(localAuthToken)
     val sparkConnectService = new SparkConnectService(debugMode)
     val protoReflectionService =
       if (debugMode) Some(ProtoReflectionService.newInstance()) else None
-    val configuredInterceptors = SparkConnectInterceptorRegistry.createConfiguredInterceptors()
+    val configuredInterceptors =
+      SparkConnectInterceptorRegistry.createConfiguredInterceptors() ++
+        (if (localAuthToken != null) Seq(new LocalAuthInterceptor()) else Nil)
 
     val startServiceFn = (port: Int) => {
       val sb = bindAddress match {
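
Taken together, the changes above give the token the following life cycle: the local client generates a UUID, exports it as SPARK_CONNECT_LOCAL_AUTH_TOKEN to the launched server, attaches it as a bearer token to every gRPC call, and clears it on stop. A minimal sketch of that round trip using only the ConnectCommon helpers added in this patch (the object and package names here are illustrative; the package must sit under org.apache.spark.sql because ConnectCommon is private[sql]):

package org.apache.spark.sql.connect

import org.apache.spark.sql.connect.common.config.ConnectCommon

object LocalAuthTokenRoundTrip {
  def main(args: Array[String]): Unit = {
    // Server side: SparkConnectService reads the token from the
    // SPARK_CONNECT_LOCAL_AUTH_TOKEN environment variable and caches it.
    val token = java.util.UUID.randomUUID().toString
    ConnectCommon.setLocalAuthToken(token)

    // Request side: LocalAuthInterceptor hands the raw header value to
    // assertLocalAuthToken, which strips the "Bearer " prefix and compares.
    ConnectCommon.assertLocalAuthToken(Some(s"Bearer $token"))

    // stop() resets the token so a later session cannot reuse it.
    ConnectCommon.setLocalAuthToken(null)
    assert(ConnectCommon.getLocalAuthToken.isEmpty)
  }
}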
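
On the wire, the token travels in the same "Authentication" metadata key that SparkConnectClient already uses for token authentication, which is why AUTH_TOKEN_META_DATA_KEY moves into ConnectCommon. A hypothetical client-side interceptor, not part of this patch (real clients go through SparkConnectClient's token configuration instead), that would satisfy LocalAuthInterceptor's check:

package org.apache.spark.sql.connect

import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, ForwardingClientCall, Metadata, MethodDescriptor}

import org.apache.spark.sql.connect.common.config.ConnectCommon

// Attaches "Authentication: Bearer <token>" to every outgoing call, which is
// exactly the shape assertLocalAuthToken expects on the server side.
class LocalAuthClientInterceptor(token: String) extends ClientInterceptor {
  override def interceptCall[ReqT, RespT](
      method: MethodDescriptor[ReqT, RespT],
      callOptions: CallOptions,
      next: Channel): ClientCall[ReqT, RespT] = {
    new ForwardingClientCall.SimpleForwardingClientCall[ReqT, RespT](
        next.newCall(method, callOptions)) {
      override def start(listener: ClientCall.Listener[RespT], headers: Metadata): Unit = {
        headers.put(ConnectCommon.AUTH_TOKEN_META_DATA_KEY, s"Bearer $token")
        super.start(listener, headers)
      }
    }
  }
}

This also explains why create() sets isSslEnabled = Some(true) whenever a local token is present: SparkConnectClient refuses to send an authentication token over a connection it considers insecure, as its AUTH_TOKEN_ON_INSECURE_CONN_ERROR_MSG guard shows.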
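
The diff ends above partway through building the gRPC server, so the point where configuredInterceptors is consumed is not visible. For context only, a generic grpc-java sketch of how a list of interceptors is typically registered on a Netty server builder (assumed wiring, not the exact Spark code):

import io.grpc.{Server, ServerInterceptor}
import io.grpc.netty.NettyServerBuilder

object InterceptorWiring {
  // Every registered ServerInterceptor, including LocalAuthInterceptor when
  // SPARK_CONNECT_LOCAL_AUTH_TOKEN is set, runs before the RPC handler and can
  // reject the call. Services would be added to the builder the same way.
  def start(port: Int, interceptors: Seq[ServerInterceptor]): Server = {
    val sb = NettyServerBuilder.forPort(port)
    interceptors.foreach(sb.intercept(_))
    sb.build().start()
  }
}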