Provide a basic authentication token when running Spark Connect server locally

HyukjinKwon committed Feb 12, 2025
1 parent 42ecabf commit 66336ac
Showing 7 changed files with 142 additions and 34 deletions.
4 changes: 4 additions & 0 deletions python/pyspark/sql/connect/client/core.py
@@ -591,6 +591,8 @@ class SparkConnectClient(object):
    Conceptually the remote spark session that communicates with the server
    """

    _local_auth_token = os.environ.get("SPARK_CONNECT_LOCAL_AUTH_TOKEN")

    def __init__(
        self,
        connection: Union[str, ChannelBuilder],
@@ -637,6 +639,8 @@ def __init__(
            if isinstance(connection, ChannelBuilder)
            else DefaultChannelBuilder(connection, channel_options)
        )
        if SparkConnectClient._local_auth_token is not None:
            self._builder.set(ChannelBuilder.PARAM_TOKEN, SparkConnectClient._local_auth_token)
        self._user_id = None
        self._retry_policies: List[RetryPolicy] = []

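For orientation, here is a hedged Scala analogue of the pattern above (all names are stand-ins, not Spark APIs): the token is read from the environment once, at initialization time, and then stamped onto every builder the process creates.

```scala
object LocalTokenCaptureSketch {
  // Read once at initialization, like SparkConnectClient._local_auth_token.
  private val localAuthToken: Option[String] =
    sys.env.get("SPARK_CONNECT_LOCAL_AUTH_TOKEN")

  // Stamp the token onto builder parameters only when the process has one.
  def withLocalToken(params: Map[String, String]): Map[String, String] =
    localAuthToken.fold(params)(t => params + ("token" -> t))

  def main(args: Array[String]): Unit =
    println(withLocalToken(Map("user_id" -> "local")))
}
```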
19 changes: 19 additions & 0 deletions python/pyspark/sql/connect/session.py
@@ -40,6 +40,7 @@
    TYPE_CHECKING,
    ClassVar,
)
import uuid

import numpy as np
import pandas as pd
@@ -893,13 +894,20 @@ def stop(self) -> None:
        if self is getattr(SparkSession._active_session, "session", None):
            SparkSession._active_session.session = None

        # It should be `None` always on stop.
        SparkConnectClient._local_auth_token = None

        if "SPARK_LOCAL_REMOTE" in os.environ:
            # When local mode is in use, follow the regular Spark session's
            # behavior by terminating the Spark Connect server,
            # meaning that you can stop local mode, and restart the Spark Connect
            # client with a different remote address.
            if PySparkSession._activeSession is not None:
                try:
                    getattr(
                        PySparkSession._activeSession._jvm,
                        "org.apache.spark.sql.connect.common.config.ConnectCommon",
                    ).setLocalAuthToken(None)
                    PySparkSession._activeSession.stop()
                except Exception as e:
                    logger.warn(
@@ -1060,6 +1068,9 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
                overwrite_conf["spark.connect.grpc.binding.port"] = "0"

            origin_remote = os.environ.get("SPARK_REMOTE", None)
            local_auth_token = str(uuid.uuid4())
            SparkConnectClient._local_auth_token = local_auth_token
            os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"] = local_auth_token
            try:
                if origin_remote is not None:
                    # So SparkSubmit thinks no remote is set in order to
@@ -1072,6 +1083,13 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
                conf.setAll(list(overwrite_conf.items())).setAll(list(default_conf.items()))
                PySparkSession(SparkContext.getOrCreate(conf))

                # In Python local mode, session.stop does not terminate JVM itself
                # so we can't control it via environment variable.
                getattr(
                    SparkContext._jvm,
                    "org.apache.spark.sql.connect.common.config.ConnectCommon",
                ).setLocalAuthToken(local_auth_token)

                # Lastly only keep runtime configurations because other configurations are
                # disallowed to set in the regular Spark Connect session.
                utl = SparkContext._jvm.PythonSQLUtils  # type: ignore[union-attr]
@@ -1084,6 +1102,7 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
                if origin_remote is not None:
                    os.environ["SPARK_REMOTE"] = origin_remote
                del os.environ["SPARK_LOCAL_CONNECT"]
                del os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"]
        else:
            raise PySparkRuntimeError(
                errorClass="SESSION_OR_CONTEXT_EXISTS",
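Taken together, the session.py changes give the token a strict life cycle: it is minted before the local server starts, handed over only through the environment, and scrubbed from both the class and the environment on teardown. A compact Scala sketch of that life cycle, assuming a Unix-like system where `printenv` stands in for the launched server:

```scala
object LocalTokenLifecycleSketch {
  @volatile private var localAuthToken: Option[String] = None

  def startLocalServer(): Unit = {
    val token = java.util.UUID.randomUUID().toString
    localAuthToken = Some(token)
    // The token travels through the child's environment, never through
    // argv, so it does not show up in process listings.
    val pb = new ProcessBuilder("printenv", "SPARK_CONNECT_LOCAL_AUTH_TOKEN")
    pb.environment().put("SPARK_CONNECT_LOCAL_AUTH_TOKEN", token)
    pb.inheritIO()
    pb.start().waitFor()
  }

  // Mirrors stop() above: the token never outlives the session.
  def stopLocalServer(): Unit = localAuthToken = None

  def main(args: Array[String]): Unit = {
    startLocalServer()
    stopLocalServer()
  }
}
```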
SparkSession.scala (Spark Connect Scala client)
@@ -48,6 +48,7 @@ import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toLiteral
import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, SparkConnectClient, SparkResult}
import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
import org.apache.spark.sql.connect.common.config.ConnectCommon
import org.apache.spark.sql.internal.{SessionState, SharedState, SqlApiConf}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.StructType
@@ -627,6 +628,19 @@ class SparkSession private[sql] (
    }
    allocator.close()
    SparkSession.onSessionClose(this)
    SparkSession.server.synchronized {
      // It should be `None` always on stop.
      ConnectCommon.setLocalAuthToken(null)
      if (SparkSession.server.isDefined) {
        // When local mode is in use, follow the regular Spark session's
        // behavior by terminating the Spark Connect server,
        // meaning that you can stop local mode, and restart the Spark Connect
        // client with a different remote address.
        new ProcessBuilder(SparkSession.maybeConnectStopScript.get.toString)
          .start()
        SparkSession.server = None
      }
    }
  }

/** @inheritdoc */
@@ -679,6 +693,10 @@ object SparkSession extends SparkSessionCompanion with Logging {
  private val MAX_CACHED_SESSIONS = 100
  private val planIdGenerator = new AtomicLong
  private var server: Option[Process] = None
  private val maybeConnectStartScript =
    Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))
  private val maybeConnectStopScript =
    Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "stop-connect-server.sh"))
  private[sql] val sparkOptions = sys.props.filter { p =>
    p._1.startsWith("spark.") && p._2.nonEmpty
  }.toMap
@@ -712,37 +730,43 @@ object SparkSession extends SparkSessionCompanion with Logging {
      }
    }

-    val maybeConnectScript =
-      Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))
-
-    if (server.isEmpty &&
-      (remoteString.exists(_.startsWith("local")) ||
-        (remoteString.isDefined && isAPIModeConnect)) &&
-      maybeConnectScript.exists(Files.exists(_))) {
-      server = Some {
-        val args =
-          Seq(maybeConnectScript.get.toString, "--master", remoteString.get) ++ sparkOptions
-            .filter(p => !p._1.startsWith("spark.remote"))
-            .flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
-        val pb = new ProcessBuilder(args: _*)
-        // So don't exclude spark-sql jar in classpath
-        pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
-        pb.start()
-      }
-
-      // Let the server start. We will directly request to set the configurations
-      // and this sleep makes less noisy with retries.
-      Thread.sleep(2000L)
-      System.setProperty("spark.remote", "sc://localhost")
-
-      // scalastyle:off runtimeaddshutdownhook
-      Runtime.getRuntime.addShutdownHook(new Thread() {
-        override def run(): Unit = if (server.isDefined) {
-          new ProcessBuilder(maybeConnectScript.get.toString)
-            .start()
-        }
-      })
-      // scalastyle:on runtimeaddshutdownhook
-    }
+    server.synchronized {
+      if (server.isEmpty &&
+        (remoteString.exists(_.startsWith("local")) ||
+          (remoteString.isDefined && isAPIModeConnect)) &&
+        maybeConnectStartScript.exists(Files.exists(_))) {
+        val localAuthToken = java.util.UUID.randomUUID().toString()
+        server = Some {
+          ConnectCommon.setLocalAuthToken(localAuthToken)
+          val args =
+            Seq(
+              maybeConnectStartScript.get.toString,
+              "--master",
+              remoteString.get) ++ sparkOptions
+              .filter(p => !p._1.startsWith("spark.remote"))
+              .flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
+          val pb = new ProcessBuilder(args: _*)
+          // So don't exclude spark-sql jar in classpath
+          pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
+          pb.environment()
+            .put(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME, localAuthToken)
+          pb.start()
+        }
+
+        // Let the server start. We will directly request to set the configurations
+        // and this sleep makes less noisy with retries.
+        Thread.sleep(2000L)
+        System.setProperty("spark.remote", "sc://localhost")
+
+        // scalastyle:off runtimeaddshutdownhook
+        Runtime.getRuntime.addShutdownHook(new Thread() {
+          override def run(): Unit = if (server.isDefined) {
+            new ProcessBuilder(maybeConnectStopScript.get.toString)
+              .start()
+          }
+        })
+        // scalastyle:on runtimeaddshutdownhook
+      }
+    }
    f
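Two things changed in the block above besides the token: startup now happens under `server.synchronized`, so concurrent callers cannot race to launch two servers, and the shutdown hook now runs stop-connect-server.sh where it previously re-ran the start script. A stripped-down sketch of that start-once-plus-hook pattern, with placeholder commands:

```scala
object StartOnceWithHookSketch {
  private var server: Option[Process] = None

  // startCmd/stopCmd are placeholders for the sbin/ scripts under SPARK_HOME.
  def startOnce(startCmd: Seq[String], stopCmd: Seq[String]): Unit =
    this.synchronized {
      if (server.isEmpty) {
        server = Some(new ProcessBuilder(startCmd: _*).start())
        // On JVM exit, invoke the stop command (not the start command).
        Runtime.getRuntime.addShutdownHook(new Thread() {
          override def run(): Unit =
            if (server.isDefined) new ProcessBuilder(stopCmd: _*).start()
        })
      }
    }
}
```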
SparkConnectClient.scala
@@ -404,7 +404,7 @@ object SparkConnectClient {
  private val DEFAULT_USER_AGENT: String = "_SPARK_CONNECT_SCALA"

  private val AUTH_TOKEN_META_DATA_KEY: Metadata.Key[String] =
-    Metadata.Key.of("Authentication", Metadata.ASCII_STRING_MARSHALLER)
+    ConnectCommon.AUTH_TOKEN_META_DATA_KEY

  private val AUTH_TOKEN_ON_INSECURE_CONN_ERROR_MSG: String =
    "Authentication token cannot be passed over insecure connections. " +
@@ -422,7 +422,13 @@
   * port or a NameResolver-compliant URI connection string.
   */
  class Builder(private var _configuration: Configuration) {
-    def this() = this(Configuration())
+    def this() = this {
+      ConnectCommon.getLocalAuthToken
+        .map { _ =>
+          Configuration(token = ConnectCommon.getLocalAuthToken, isSslEnabled = Some(true))
+        }
+        .getOrElse(Configuration())
+    }

    def configuration: Configuration = _configuration

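To show where the `Authentication` key ends up on the wire, here is a hedged, standalone grpc-java sketch, not Spark's client code: it attaches `Bearer <token>` to every call on a channel. Plaintext is used only to keep the sketch short; the error message above makes clear the real client refuses to send tokens over insecure connections.

```scala
import io.grpc.{ManagedChannel, ManagedChannelBuilder, Metadata}
import io.grpc.stub.MetadataUtils

object LocalAuthClientSketch {
  // The same header key as ConnectCommon.AUTH_TOKEN_META_DATA_KEY below.
  private val authKey: Metadata.Key[String] =
    Metadata.Key.of("Authentication", Metadata.ASCII_STRING_MARSHALLER)

  def channelFor(host: String, port: Int, token: String): ManagedChannel = {
    val headers = new Metadata()
    headers.put(authKey, s"Bearer $token")
    ManagedChannelBuilder
      .forAddress(host, port)
      .usePlaintext() // brevity only; see the insecure-connection error above
      .intercept(MetadataUtils.newAttachHeadersInterceptor(headers))
      .build()
  }
}
```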
ConnectCommon.scala
@@ -16,9 +16,23 @@
 */
package org.apache.spark.sql.connect.common.config

import io.grpc.Metadata

private[sql] object ConnectCommon {
  val CONNECT_GRPC_BINDING_PORT: Int = 15002
  val CONNECT_GRPC_PORT_MAX_RETRIES: Int = 0
  val CONNECT_GRPC_MAX_MESSAGE_SIZE: Int = 128 * 1024 * 1024
  val CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT: Int = 1024
  val AUTH_TOKEN_META_DATA_KEY: Metadata.Key[String] =
    Metadata.Key.of("Authentication", Metadata.ASCII_STRING_MARSHALLER)

  val CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME = "SPARK_CONNECT_LOCAL_AUTH_TOKEN"
  private var CONNECT_LOCAL_AUTH_TOKEN: Option[String] = Option(
    System.getenv(CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME))
  def setLocalAuthToken(token: String): Unit = CONNECT_LOCAL_AUTH_TOKEN = Option(token)
  def getLocalAuthToken: Option[String] = CONNECT_LOCAL_AUTH_TOKEN
  def assertLocalAuthToken(token: Option[String]): Unit = token.foreach { t =>
    assert(CONNECT_LOCAL_AUTH_TOKEN.isDefined)
    assert(t.substring("Bearer ".length) == CONNECT_LOCAL_AUTH_TOKEN.get)
  }
}
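A self-contained sketch of the assertion's contract, using only the JDK: the header value is the literal prefix `Bearer ` followed by the token, and the check strips that prefix before comparing.

```scala
object LocalAuthTokenCheckSketch {
  private var expected: Option[String] = None

  def setLocalAuthToken(token: String): Unit = expected = Option(token)

  // Mirrors assertLocalAuthToken above, which only runs when a header exists.
  def check(headerValue: String): Unit = {
    assert(expected.isDefined, "no local auth token configured")
    assert(headerValue.stripPrefix("Bearer ") == expected.get, "token mismatch")
  }

  def main(args: Array[String]): Unit = {
    val token = java.util.UUID.randomUUID().toString
    setLocalAuthToken(token)
    check(s"Bearer $token") // passes; "Bearer wrong-token" would fail
  }
}
```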
LocalAuthInterceptor.scala (new file)
@@ -0,0 +1,36 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.connect.service

import io.grpc.{Metadata, ServerCall, ServerCallHandler, ServerInterceptor}

import org.apache.spark.sql.connect.common.config.ConnectCommon

/**
 * A gRPC interceptor to check if the header contains token for authentication.
 */
class LocalAuthInterceptor extends ServerInterceptor {
  override def interceptCall[ReqT, RespT](
      call: ServerCall[ReqT, RespT],
      headers: Metadata,
      next: ServerCallHandler[ReqT, RespT]): ServerCall.Listener[ReqT] = {
    ConnectCommon.assertLocalAuthToken(
      Option(headers.get(ConnectCommon.AUTH_TOKEN_META_DATA_KEY)))
    next.startCall(call, headers)
  }
}
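For context, a hedged wiring sketch that is not part of this commit: `ServerBuilder#intercept` is the standard grpc-java hook, so every incoming call crosses `interceptCall`, and therefore the token assertion, before it reaches any service. The port and interceptor argument are placeholders; the SparkConnectService change below does the real wiring, and only when a local token is present.

```scala
import io.grpc.{Server, ServerBuilder, ServerInterceptor}

object LocalAuthServerSketch {
  // Put an auth interceptor in front of every registered service.
  def start(port: Int, authInterceptor: ServerInterceptor): Server =
    ServerBuilder
      .forPort(port)
      .intercept(authInterceptor) // e.g. new LocalAuthInterceptor()
      .build()
      .start()
}
```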
SparkConnectService.scala
@@ -39,7 +39,8 @@ import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.HOST
import org.apache.spark.internal.config.UI.UI_ENABLED
import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerEvent}
-import org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_BINDING_ADDRESS, CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT, CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE, CONNECT_GRPC_PORT_MAX_RETRIES}
+import org.apache.spark.sql.connect.common.config.ConnectCommon
+import org.apache.spark.sql.connect.config.Connect._
import org.apache.spark.sql.connect.execution.ConnectProgressExecutionListener
import org.apache.spark.sql.connect.ui.{SparkConnectServerAppStatusStore, SparkConnectServerListener, SparkConnectServerTab}
import org.apache.spark.sql.connect.utils.ErrorUtils
@@ -366,10 +367,14 @@ object SparkConnectService extends Logging {
    val debugMode = SparkEnv.get.conf.getBoolean("spark.connect.grpc.debug.enabled", true)
    val bindAddress = SparkEnv.get.conf.get(CONNECT_GRPC_BINDING_ADDRESS)
    val startPort = SparkEnv.get.conf.get(CONNECT_GRPC_BINDING_PORT)
+    val localAuthToken = System.getenv(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME)
+    ConnectCommon.setLocalAuthToken(localAuthToken)
    val sparkConnectService = new SparkConnectService(debugMode)
    val protoReflectionService =
      if (debugMode) Some(ProtoReflectionService.newInstance()) else None
-    val configuredInterceptors = SparkConnectInterceptorRegistry.createConfiguredInterceptors()
+    val configuredInterceptors =
+      SparkConnectInterceptorRegistry.createConfiguredInterceptors() ++
+        (if (localAuthToken != null) Seq(new LocalAuthInterceptor()) else Nil)

    val startServiceFn = (port: Int) => {
      val sb = bindAddress match {