
Provide a basic authentication token when running Spark Connect server locally
HyukjinKwon committed Feb 12, 2025
1 parent 42ecabf commit d368742
Showing 6 changed files with 144 additions and 35 deletions.
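
Taken together, the flow is: whichever process launches the local Spark Connect server generates a random UUID, hands it to the server through the new SPARK_CONNECT_LOCAL_AUTH_TOKEN environment variable, and configures the client to present the same value as a bearer token; a server-side gRPC interceptor then checks the Authentication header against it. A minimal Python sketch of that handshake (the helper names are illustrative, not part of the patch):

import os
import uuid

AUTH_TOKEN_ENV = "SPARK_CONNECT_LOCAL_AUTH_TOKEN"  # environment variable introduced by this commit

def new_local_auth_token() -> str:
    """Generate the per-server token and expose it to child processes (sketch)."""
    token = str(uuid.uuid4())
    os.environ[AUTH_TOKEN_ENV] = token  # read back by both the launched server and the client
    return token

def auth_header(token: str) -> tuple:
    # The server strips a "Bearer " prefix before comparing, so the client
    # presents the token as a bearer credential under the Authentication key.
    return ("authentication", "Bearer " + token)
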
27 changes: 26 additions & 1 deletion python/pyspark/sql/connect/session.py
@@ -40,6 +40,7 @@
TYPE_CHECKING,
ClassVar,
)
import uuid

import numpy as np
import pandas as pd
@@ -53,7 +54,7 @@
from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.dataframe import DataFrame as ParentDataFrame
from pyspark.sql.connect.logging import logger
from pyspark.sql.connect.client import SparkConnectClient, DefaultChannelBuilder
from pyspark.sql.connect.client import SparkConnectClient, DefaultChannelBuilder, ChannelBuilder
from pyspark.sql.connect.conf import RuntimeConf
from pyspark.sql.connect.plan import (
SQL,
@@ -120,6 +121,7 @@ class SparkSession:
# Reference to the root SparkSession
_default_session: ClassVar[Optional["SparkSession"]] = None
_lock: ClassVar[RLock] = RLock()
_local_auth_token = os.environ.get("SPARK_CONNECT_LOCAL_AUTH_TOKEN")

class Builder:
"""Builder for :class:`SparkSession`."""
@@ -238,6 +240,11 @@ def create(self) -> "SparkSession":
else:
spark_remote = to_str(self._options.get("spark.remote"))
assert spark_remote is not None
if SparkSession._local_auth_token is not None:
spark_remote = DefaultChannelBuilder(spark_remote) # type: ignore[assignment]
spark_remote.set( # type: ignore[attr-defined]
ChannelBuilder.PARAM_TOKEN, SparkSession._local_auth_token
)
session = SparkSession(connection=spark_remote)

SparkSession._set_default_and_active_session(session)
@@ -893,13 +900,20 @@ def stop(self) -> None:
if self is getattr(SparkSession._active_session, "session", None):
SparkSession._active_session.session = None

# It should be `None` always on stop.
SparkSession._local_auth_token = None

if "SPARK_LOCAL_REMOTE" in os.environ:
# When local mode is in use, follow the regular Spark session's
# behavior by terminating the Spark Connect server,
# meaning that you can stop local mode, and restart the Spark Connect
# client with a different remote address.
if PySparkSession._activeSession is not None:
try:
getattr(
PySparkSession._activeSession._jvm,
"org.apache.spark.sql.connect.common.config.ConnectCommon",
).setLocalAuthToken(None)
PySparkSession._activeSession.stop()
except Exception as e:
logger.warn(
@@ -1060,6 +1074,9 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
overwrite_conf["spark.connect.grpc.binding.port"] = "0"

origin_remote = os.environ.get("SPARK_REMOTE", None)
local_auth_token = str(uuid.uuid4())
SparkSession._local_auth_token = local_auth_token
os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"] = local_auth_token
try:
if origin_remote is not None:
# So SparkSubmit thinks no remote is set in order to
@@ -1072,6 +1089,13 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
conf.setAll(list(overwrite_conf.items())).setAll(list(default_conf.items()))
PySparkSession(SparkContext.getOrCreate(conf))

# In Python local mode, session.stop does not terminate JVM itself
# so we can't control it via environment variable.
getattr(
SparkContext._jvm,
"org.apache.spark.sql.connect.common.config.ConnectCommon",
).setLocalAuthToken(local_auth_token)

# Lastly only keep runtime configurations because other configurations are
# disallowed to set in the regular Spark Connect session.
utl = SparkContext._jvm.PythonSQLUtils # type: ignore[union-attr]
@@ -1084,6 +1108,7 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
if origin_remote is not None:
os.environ["SPARK_REMOTE"] = origin_remote
del os.environ["SPARK_LOCAL_CONNECT"]
del os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"]
else:
raise PySparkRuntimeError(
errorClass="SESSION_OR_CONTEXT_EXISTS",
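
On the Python side, Builder.create now wraps the remote address in a DefaultChannelBuilder and attaches the token through ChannelBuilder.PARAM_TOKEN whenever SPARK_CONNECT_LOCAL_AUTH_TOKEN is set, which is roughly the same as carrying a token parameter in the connection string. A sketch of that equivalent, assuming a local server is already listening on the default port 15002 with the exported token:

import os
from pyspark.sql import SparkSession

token = os.environ.get("SPARK_CONNECT_LOCAL_AUTH_TOKEN")
remote = "sc://localhost:15002"
if token is not None:
    # comparable to DefaultChannelBuilder(remote).set(ChannelBuilder.PARAM_TOKEN, token)
    remote = f"{remote}/;token={token}"

spark = SparkSession.builder.remote(remote).getOrCreate()
spark.range(1).show()
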
SparkSession.scala (Spark Connect Scala client)
@@ -48,6 +48,7 @@ import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toLiteral
import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, SparkConnectClient, SparkResult}
import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
import org.apache.spark.sql.connect.common.config.ConnectCommon
import org.apache.spark.sql.internal.{SessionState, SharedState, SqlApiConf}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.StructType
@@ -627,6 +628,19 @@ class SparkSession private[sql] (
}
allocator.close()
SparkSession.onSessionClose(this)
SparkSession.server.synchronized {
// It should be `None` always on stop.
ConnectCommon.setLocalAuthToken(null)
if (SparkSession.server.isDefined) {
// When local mode is in use, follow the regular Spark session's
// behavior by terminating the Spark Connect server,
// meaning that you can stop local mode, and restart the Spark Connect
// client with a different remote address.
new ProcessBuilder(SparkSession.maybeConnectStopScript.get.toString)
.start()
SparkSession.server = None
}
}
}

/** @inheritdoc */
@@ -679,6 +693,10 @@ object SparkSession extends SparkSessionCompanion with Logging {
private val MAX_CACHED_SESSIONS = 100
private val planIdGenerator = new AtomicLong
private var server: Option[Process] = None
private val maybeConnectStartScript =
Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))
private val maybeConnectStopScript =
Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "stop-connect-server.sh"))
private[sql] val sparkOptions = sys.props.filter { p =>
p._1.startsWith("spark.") && p._2.nonEmpty
}.toMap
@@ -712,37 +730,43 @@
}
}

val maybeConnectScript =
Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))

if (server.isEmpty &&
(remoteString.exists(_.startsWith("local")) ||
(remoteString.isDefined && isAPIModeConnect)) &&
maybeConnectScript.exists(Files.exists(_))) {
server = Some {
val args =
Seq(maybeConnectScript.get.toString, "--master", remoteString.get) ++ sparkOptions
.filter(p => !p._1.startsWith("spark.remote"))
.flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
val pb = new ProcessBuilder(args: _*)
// So don't exclude spark-sql jar in classpath
pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
pb.start()
}

// Let the server start. We will directly request to set the configurations
// and this sleep makes less noisy with retries.
Thread.sleep(2000L)
System.setProperty("spark.remote", "sc://localhost")

// scalastyle:off runtimeaddshutdownhook
Runtime.getRuntime.addShutdownHook(new Thread() {
override def run(): Unit = if (server.isDefined) {
new ProcessBuilder(maybeConnectScript.get.toString)
.start()
server.synchronized {
if (server.isEmpty &&
(remoteString.exists(_.startsWith("local")) ||
(remoteString.isDefined && isAPIModeConnect)) &&
maybeConnectStartScript.exists(Files.exists(_))) {
val localAuthToken = java.util.UUID.randomUUID().toString()
server = Some {
ConnectCommon.setLocalAuthToken(localAuthToken)
val args =
Seq(
maybeConnectStartScript.get.toString,
"--master",
remoteString.get) ++ sparkOptions
.filter(p => !p._1.startsWith("spark.remote"))
.flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
val pb = new ProcessBuilder(args: _*)
// So don't exclude spark-sql jar in classpath
pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
pb.environment()
.put(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME, localAuthToken)
pb.start()
}
})
// scalastyle:on runtimeaddshutdownhook

// Let the server start. We will directly request to set the configurations
// and this sleep makes less noisy with retries.
Thread.sleep(2000L)
System.setProperty("spark.remote", "sc://localhost")

// scalastyle:off runtimeaddshutdownhook
Runtime.getRuntime.addShutdownHook(new Thread() {
override def run(): Unit = if (server.isDefined) {
new ProcessBuilder(maybeConnectStopScript.get.toString)
.start()
}
})
// scalastyle:on runtimeaddshutdownhook
}
}
}
f
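
The JVM client follows the same pattern as the Python one: the start script is launched at most once under server.synchronized, the token reaches the child process only through its environment, and both an explicit stop() and a shutdown hook run the stop script while clearing the token. A compact Python analogue of that lifecycle (script paths and the class name are placeholders):

import atexit
import os
import subprocess
import threading
import uuid

class LocalConnectServer:
    """Sketch of the start-once, stop-on-exit lifecycle used by the Scala client."""

    def __init__(self, start_script: str, stop_script: str) -> None:
        self._start, self._stop = start_script, stop_script
        self._lock = threading.Lock()   # analogue of server.synchronized
        self._proc = None
        self.token = None

    def ensure_started(self) -> None:
        with self._lock:
            if self._proc is None:
                self.token = str(uuid.uuid4())
                env = dict(os.environ, SPARK_CONNECT_LOCAL_AUTH_TOKEN=self.token)
                self._proc = subprocess.Popen([self._start], env=env)
                atexit.register(self.stop)  # analogue of the JVM shutdown hook

    def stop(self) -> None:
        with self._lock:
            self.token = None               # the token is always cleared on stop
            if self._proc is not None:
                subprocess.Popen([self._stop])
                self._proc = None
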
SparkConnectClient.scala
@@ -404,7 +404,7 @@ object SparkConnectClient {
private val DEFAULT_USER_AGENT: String = "_SPARK_CONNECT_SCALA"

private val AUTH_TOKEN_META_DATA_KEY: Metadata.Key[String] =
Metadata.Key.of("Authentication", Metadata.ASCII_STRING_MARSHALLER)
ConnectCommon.AUTH_TOKEN_META_DATA_KEY

private val AUTH_TOKEN_ON_INSECURE_CONN_ERROR_MSG: String =
"Authentication token cannot be passed over insecure connections. " +
@@ -422,7 +422,13 @@
* port or a NameResolver-compliant URI connection string.
*/
class Builder(private var _configuration: Configuration) {
def this() = this(Configuration())
def this() = this {
ConnectCommon.getLocalAuthToken
.map { _ =>
Configuration(token = ConnectCommon.getLocalAuthToken, isSslEnabled = Some(true))
}
.getOrElse(Configuration())
}

def configuration: Configuration = _configuration

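
With this change the builder's default configuration falls back to the process-local token when one has been set, and enables SSL for it, presumably so the existing guard against sending tokens over insecure connections stays satisfied. A small Python sketch of that fallback, using a stand-in Configuration type and reading the environment variable directly where the Scala code consults ConnectCommon:

import os
from dataclasses import dataclass
from typing import Optional

@dataclass
class Configuration:  # stand-in for SparkConnectClient.Configuration
    token: Optional[str] = None
    is_ssl_enabled: Optional[bool] = None

def default_configuration() -> Configuration:
    # Mirrors the new no-arg Builder constructor: prefer the local auth token if present.
    token = os.environ.get("SPARK_CONNECT_LOCAL_AUTH_TOKEN")
    if token is not None:
        return Configuration(token=token, is_ssl_enabled=True)
    return Configuration()
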
ConnectCommon.scala
@@ -16,9 +16,22 @@
*/
package org.apache.spark.sql.connect.common.config

import io.grpc.Metadata

private[sql] object ConnectCommon {
val CONNECT_GRPC_BINDING_PORT: Int = 15002
val CONNECT_GRPC_PORT_MAX_RETRIES: Int = 0
val CONNECT_GRPC_MAX_MESSAGE_SIZE: Int = 128 * 1024 * 1024
val CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT: Int = 1024
val AUTH_TOKEN_META_DATA_KEY: Metadata.Key[String] =
Metadata.Key.of("Authentication", Metadata.ASCII_STRING_MARSHALLER)

val CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME = "SPARK_CONNECT_LOCAL_AUTH_TOKEN"
private var CONNECT_LOCAL_AUTH_TOKEN: Option[String] = None
def setLocalAuthToken(token: String): Unit = CONNECT_LOCAL_AUTH_TOKEN = Option(token)
def getLocalAuthToken: Option[String] = CONNECT_LOCAL_AUTH_TOKEN
def assertLocalAuthToken(token: Option[String]): Unit = token.foreach { t =>
assert(CONNECT_LOCAL_AUTH_TOKEN.isDefined)
assert(t.substring("Bearer ".length) == CONNECT_LOCAL_AUTH_TOKEN.get)
}
}
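
The server-side check itself is small: when a local token has been configured, any Authentication header that arrives must equal "Bearer <token>" for that exact token. The same logic in a few lines of Python (names are illustrative):

from typing import Optional

_local_auth_token: Optional[str] = None  # counterpart of CONNECT_LOCAL_AUTH_TOKEN

def set_local_auth_token(token: Optional[str]) -> None:
    global _local_auth_token
    _local_auth_token = token

def assert_local_auth_token(header_value: Optional[str]) -> None:
    # Mirrors the Scala foreach: nothing is checked when no header was sent.
    if header_value is None:
        return
    assert _local_auth_token is not None, "no local auth token configured"
    assert header_value[len("Bearer "):] == _local_auth_token, "local auth token mismatch"
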
LocalAuthInterceptor.scala (new file)
@@ -0,0 +1,36 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connect.service

import io.grpc.{Metadata, ServerCall, ServerCallHandler, ServerInterceptor}

import org.apache.spark.sql.connect.common.config.ConnectCommon

/**
* A gRPC interceptor to check if the header contains token for authentication.
*/
class LocalAuthInterceptor extends ServerInterceptor {
override def interceptCall[ReqT, RespT](
call: ServerCall[ReqT, RespT],
headers: Metadata,
next: ServerCallHandler[ReqT, RespT]): ServerCall.Listener[ReqT] = {
ConnectCommon.assertLocalAuthToken(
Option(headers.get(ConnectCommon.AUTH_TOKEN_META_DATA_KEY)))
next.startCall(call, headers)
}
}
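
An illustrative Python counterpart of this interceptor, using grpc.ServerInterceptor (not part of the patch; the class name and messages are invented for the sketch). Like the Scala version, it validates the header when one is present and otherwise lets the call continue; a production interceptor would typically abort the call with UNAUTHENTICATED rather than raise:

import grpc

class PyLocalAuthInterceptor(grpc.ServerInterceptor):
    """Checks the Authentication metadata of every incoming call (sketch)."""

    def __init__(self, expected_token: str) -> None:
        self._expected = "Bearer " + expected_token

    def intercept_service(self, continuation, handler_call_details):
        metadata = dict(handler_call_details.invocation_metadata or ())
        header = metadata.get("authentication")  # gRPC lowercases metadata keys
        if header is not None and header != self._expected:
            raise ValueError("invalid local authentication token")
        return continuation(handler_call_details)

# Registered only when a local token is configured, as the SparkConnectService
# hunks below do for LocalAuthInterceptor:
#   interceptors = [PyLocalAuthInterceptor(token)] if token else []
#   server = grpc.server(executor, interceptors=interceptors)
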
SparkConnectService.scala
@@ -39,7 +39,8 @@ import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.HOST
import org.apache.spark.internal.config.UI.UI_ENABLED
import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerEvent}
import org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_BINDING_ADDRESS, CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT, CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE, CONNECT_GRPC_PORT_MAX_RETRIES}
import org.apache.spark.sql.connect.common.config.ConnectCommon
import org.apache.spark.sql.connect.config.Connect._
import org.apache.spark.sql.connect.execution.ConnectProgressExecutionListener
import org.apache.spark.sql.connect.ui.{SparkConnectServerAppStatusStore, SparkConnectServerListener, SparkConnectServerTab}
import org.apache.spark.sql.connect.utils.ErrorUtils
@@ -366,10 +367,14 @@ object SparkConnectService extends Logging {
val debugMode = SparkEnv.get.conf.getBoolean("spark.connect.grpc.debug.enabled", true)
val bindAddress = SparkEnv.get.conf.get(CONNECT_GRPC_BINDING_ADDRESS)
val startPort = SparkEnv.get.conf.get(CONNECT_GRPC_BINDING_PORT)
val localAuthToken = System.getenv(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME)
ConnectCommon.setLocalAuthToken(localAuthToken)
val sparkConnectService = new SparkConnectService(debugMode)
val protoReflectionService =
if (debugMode) Some(ProtoReflectionService.newInstance()) else None
val configuredInterceptors = SparkConnectInterceptorRegistry.createConfiguredInterceptors()
val configuredInterceptors =
SparkConnectInterceptorRegistry.createConfiguredInterceptors() ++
(if (localAuthToken != null) Seq(new LocalAuthInterceptor()) else Nil)

val startServiceFn = (port: Int) => {
val sb = bindAddress match {
