
Commit

Provide a basic authentication token when running Spark Connect server locally

HyukjinKwon committed Feb 13, 2025
1 parent 42ecabf commit 0558d14
Showing 7 changed files with 145 additions and 33 deletions.
7 changes: 7 additions & 0 deletions python/pyspark/sql/connect/client/core.py
@@ -125,6 +125,7 @@ class ChannelBuilder:
PARAM_USER_ID = "user_id"
PARAM_USER_AGENT = "user_agent"
PARAM_SESSION_ID = "session_id"
CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME = "local_token"

GRPC_MAX_MESSAGE_LENGTH_DEFAULT = 128 * 1024 * 1024

@@ -591,6 +592,8 @@ class SparkConnectClient(object):
Conceptually the remote spark session that communicates with the server
"""

_local_auth_token = os.environ.get("SPARK_CONNECT_LOCAL_AUTH_TOKEN")

def __init__(
self,
connection: Union[str, ChannelBuilder],
@@ -637,6 +640,10 @@ def __init__(
if isinstance(connection, ChannelBuilder)
else DefaultChannelBuilder(connection, channel_options)
)
if SparkConnectClient._local_auth_token is not None:
self._builder.set(
ChannelBuilder.CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME,
SparkConnectClient._local_auth_token)
self._user_id = None
self._retry_policies: List[RetryPolicy] = []

19 changes: 19 additions & 0 deletions python/pyspark/sql/connect/session.py
@@ -40,6 +40,7 @@
TYPE_CHECKING,
ClassVar,
)
import uuid

import numpy as np
import pandas as pd
@@ -893,13 +894,20 @@ def stop(self) -> None:
if self is getattr(SparkSession._active_session, "session", None):
SparkSession._active_session.session = None

# It should be `None` always on stop.
SparkConnectClient._local_auth_token = None

if "SPARK_LOCAL_REMOTE" in os.environ:
# When local mode is in use, follow the regular Spark session's
# behavior by terminating the Spark Connect server,
# meaning that you can stop local mode, and restart the Spark Connect
# client with a different remote address.
if PySparkSession._activeSession is not None:
try:
getattr(
PySparkSession._activeSession._jvm,
"org.apache.spark.sql.connect.common.config.ConnectCommon",
).setLocalAuthToken(None)
PySparkSession._activeSession.stop()
except Exception as e:
logger.warn(
@@ -1060,6 +1068,9 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
overwrite_conf["spark.connect.grpc.binding.port"] = "0"

origin_remote = os.environ.get("SPARK_REMOTE", None)
local_auth_token = str(uuid.uuid4())
SparkConnectClient._local_auth_token = local_auth_token
os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"] = local_auth_token
try:
if origin_remote is not None:
# So SparkSubmit thinks no remote is set in order to
@@ -1072,6 +1083,13 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
conf.setAll(list(overwrite_conf.items())).setAll(list(default_conf.items()))
PySparkSession(SparkContext.getOrCreate(conf))

# In Python local mode, session.stop does not terminate JVM itself
# so we can't control it via environment variable.
getattr(
SparkContext._jvm,
"org.apache.spark.sql.connect.common.config.ConnectCommon",
).setLocalAuthToken(local_auth_token)

# Lastly only keep runtime configurations because other configurations are
# disallowed to set in the regular Spark Connect session.
utl = SparkContext._jvm.PythonSQLUtils # type: ignore[union-attr]
@@ -1084,6 +1102,7 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
if origin_remote is not None:
os.environ["SPARK_REMOTE"] = origin_remote
del os.environ["SPARK_LOCAL_CONNECT"]
del os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"]
else:
raise PySparkRuntimeError(
errorClass="SESSION_OR_CONTEXT_EXISTS",
SparkSession.scala (Spark Connect Scala client)
@@ -48,6 +48,7 @@ import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toLiteral
import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, SparkConnectClient, SparkResult}
import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
import org.apache.spark.sql.connect.common.config.ConnectCommon
import org.apache.spark.sql.internal.{SessionState, SharedState, SqlApiConf}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.StructType
@@ -627,6 +628,19 @@ class SparkSession private[sql] (
}
allocator.close()
SparkSession.onSessionClose(this)
SparkSession.server.synchronized {
// It should be `None` always on stop.
ConnectCommon.setLocalAuthToken(null)
if (SparkSession.server.isDefined) {
// When local mode is in use, follow the regular Spark session's
// behavior by terminating the Spark Connect server,
// meaning that you can stop local mode, and restart the Spark Connect
// client with a different remote address.
new ProcessBuilder(SparkSession.maybeConnectStopScript.get.toString)
.start()
SparkSession.server = None
}
}
}

/** @inheritdoc */
@@ -679,6 +693,10 @@ object SparkSession extends SparkSessionCompanion with Logging {
private val MAX_CACHED_SESSIONS = 100
private val planIdGenerator = new AtomicLong
private var server: Option[Process] = None
private val maybeConnectStartScript =
Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))
private val maybeConnectStopScript =
Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "stop-connect-server.sh"))
private[sql] val sparkOptions = sys.props.filter { p =>
p._1.startsWith("spark.") && p._2.nonEmpty
}.toMap
@@ -712,37 +730,43 @@
}
}

val maybeConnectScript =
Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))

if (server.isEmpty &&
(remoteString.exists(_.startsWith("local")) ||
(remoteString.isDefined && isAPIModeConnect)) &&
maybeConnectScript.exists(Files.exists(_))) {
server = Some {
val args =
Seq(maybeConnectScript.get.toString, "--master", remoteString.get) ++ sparkOptions
.filter(p => !p._1.startsWith("spark.remote"))
.flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
val pb = new ProcessBuilder(args: _*)
// So don't exclude spark-sql jar in classpath
pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
pb.start()
}

// Let the server start. We will directly request to set the configurations
// and this sleep makes less noisy with retries.
Thread.sleep(2000L)
System.setProperty("spark.remote", "sc://localhost")

// scalastyle:off runtimeaddshutdownhook
Runtime.getRuntime.addShutdownHook(new Thread() {
override def run(): Unit = if (server.isDefined) {
new ProcessBuilder(maybeConnectScript.get.toString)
.start()
server.synchronized {
if (server.isEmpty &&
(remoteString.exists(_.startsWith("local")) ||
(remoteString.isDefined && isAPIModeConnect)) &&
maybeConnectStartScript.exists(Files.exists(_))) {
val localAuthToken = java.util.UUID.randomUUID().toString()
server = Some {
ConnectCommon.setLocalAuthToken(localAuthToken)
val args =
Seq(
maybeConnectStartScript.get.toString,
"--master",
remoteString.get) ++ sparkOptions
.filter(p => !p._1.startsWith("spark.remote"))
.flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
val pb = new ProcessBuilder(args: _*)
// So don't exclude spark-sql jar in classpath
pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
pb.environment()
.put(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME, localAuthToken)
pb.start()
}
})
// scalastyle:on runtimeaddshutdownhook

// Let the server start. We will directly request to set the configurations
// and this sleep makes less noisy with retries.
Thread.sleep(2000L)
System.setProperty("spark.remote", "sc://localhost")

// scalastyle:off runtimeaddshutdownhook
Runtime.getRuntime.addShutdownHook(new Thread() {
override def run(): Unit = if (server.isDefined) {
new ProcessBuilder(maybeConnectStopScript.get.toString)
.start()
}
})
// scalastyle:on runtimeaddshutdownhook
}
}
}
f
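The block above spawns start-connect-server.sh and hands the freshly generated token to the child process through its environment, while the in-process client keeps the same value. The Scala sketch below, which is not part of this commit, shows that handoff in isolation; the script arguments, master URL, and log line are assumptions.

import java.nio.file.Paths
import java.util.UUID

// Illustration of the launch pattern: mint a one-off token, keep it for the local
// client, and export it to the spawned Spark Connect server via its environment.
object LocalConnectLaunchSketch {
  def main(args: Array[String]): Unit = {
    val token = UUID.randomUUID().toString
    val script = Paths.get(sys.env.getOrElse("SPARK_HOME", "."), "sbin", "start-connect-server.sh")
    val pb = new ProcessBuilder(script.toString, "--master", "local[*]")
    // SparkConnectService reads this variable at startup and installs the auth interceptor.
    pb.environment().put("SPARK_CONNECT_LOCAL_AUTH_TOKEN", token)
    val process = pb.start()
    // The client keeps the same token and sends it with every RPC (see loadFromEnvironment below).
    println(s"Started local Spark Connect server (pid ${process.pid()}) with a per-session token.")
  }
}
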
SparkConnectClient.scala (Spark Connect Scala client)
@@ -620,6 +620,8 @@ object SparkConnectClient {
* Configure the builder using the env SPARK_REMOTE environment variable.
*/
def loadFromEnvironment(): Builder = {
ConnectCommon.getLocalAuthToken.foreach(t =>
option(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME, t))
lazy val isAPIModeConnect =
Option(System.getProperty(org.apache.spark.sql.SparkSessionBuilder.API_MODE_KEY))
.getOrElse("classic")
ConnectCommon.scala
@@ -21,4 +21,15 @@ private[sql] object ConnectCommon {
val CONNECT_GRPC_PORT_MAX_RETRIES: Int = 0
val CONNECT_GRPC_MAX_MESSAGE_SIZE: Int = 128 * 1024 * 1024
val CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT: Int = 1024

val CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME = "local_token"
val CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME = "SPARK_CONNECT_LOCAL_AUTH_TOKEN"
private var CONNECT_LOCAL_AUTH_TOKEN: Option[String] = Option(
System.getenv(CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME))
def setLocalAuthToken(token: String): Unit = CONNECT_LOCAL_AUTH_TOKEN = Option(token)
def getLocalAuthToken: Option[String] = CONNECT_LOCAL_AUTH_TOKEN
def assertLocalAuthToken(token: Option[String]): Unit = token.foreach { t =>
assert(CONNECT_LOCAL_AUTH_TOKEN.isDefined)
assert(t.substring("Bearer ".length) == CONNECT_LOCAL_AUTH_TOKEN.get)
}
}
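
ConnectCommon holds the token in process-wide state, seeded from SPARK_CONNECT_LOCAL_AUTH_TOKEN, and assertLocalAuthToken expects callers to present it with a "Bearer " prefix. A small standalone Scala sketch of that prefix-and-compare check follows; the object and method names are illustrative, not the real API.

// Standalone sketch of the check that assertLocalAuthToken performs.
object LocalTokenCheckSketch {
  @volatile private var expected: Option[String] = None

  def setToken(token: String): Unit = expected = Option(token)

  // Accepts header values shaped like "Bearer <token>" and fails loudly otherwise.
  def check(headerValue: Option[String]): Unit = headerValue.foreach { value =>
    assert(expected.isDefined, "no local auth token has been configured")
    assert(value.stripPrefix("Bearer ") == expected.get, "local auth token mismatch")
  }

  def main(args: Array[String]): Unit = {
    setToken("example-token")
    check(Some("Bearer example-token"))      // passes
    // check(Some("Bearer something-else"))  // would trip the assertion
  }
}
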
LocalAuthInterceptor.scala (new file)
@@ -0,0 +1,36 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connect.service

import io.grpc.{Metadata, ServerCall, ServerCallHandler, ServerInterceptor}

import org.apache.spark.sql.connect.common.config.ConnectCommon

/**
* A gRPC interceptor to check if the header contains token for authentication.
*/
class LocalAuthInterceptor extends ServerInterceptor {
override def interceptCall[ReqT, RespT](
call: ServerCall[ReqT, RespT],
headers: Metadata,
next: ServerCallHandler[ReqT, RespT]): ServerCall.Listener[ReqT] = {
ConnectCommon.assertLocalAuthToken(Option(headers.get(Metadata.Key
.of(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME, Metadata.ASCII_STRING_MARSHALLER))))
next.startCall(call, headers)
}
}
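
The interceptor only inspects incoming metadata; for a call to be accepted, the client must put the token into that header. A hedged grpc-java sketch of the client side follows; the header value mirrors the "Bearer " prefix checked in ConnectCommon, while the channel target and plaintext transport are assumptions, not part of this commit.

import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, Metadata, MethodDescriptor}
import io.grpc.ForwardingClientCall.SimpleForwardingClientCall

// Attaches "local_token: Bearer <token>" to every outgoing call so the server-side
// LocalAuthInterceptor finds the header it asserts on.
class LocalAuthClientInterceptor(token: String) extends ClientInterceptor {
  private val headerKey =
    Metadata.Key.of("local_token", Metadata.ASCII_STRING_MARSHALLER)

  override def interceptCall[ReqT, RespT](
      method: MethodDescriptor[ReqT, RespT],
      callOptions: CallOptions,
      next: Channel): ClientCall[ReqT, RespT] =
    new SimpleForwardingClientCall[ReqT, RespT](next.newCall(method, callOptions)) {
      override def start(listener: ClientCall.Listener[RespT], headers: Metadata): Unit = {
        headers.put(headerKey, s"Bearer $token")
        super.start(listener, headers)
      }
    }
}

// Usage sketch (target address and plaintext transport are assumptions):
// val channel = io.grpc.ManagedChannelBuilder
//   .forTarget("localhost:15002")
//   .usePlaintext()
//   .intercept(new LocalAuthClientInterceptor(sys.env("SPARK_CONNECT_LOCAL_AUTH_TOKEN")))
//   .build()
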
SparkConnectService.scala
@@ -17,18 +17,21 @@

package org.apache.spark.sql.connect.service

import java.io.ByteArrayInputStream
import java.net.InetSocketAddress
import java.nio.charset.StandardCharsets
import java.util.concurrent.TimeUnit

import scala.jdk.CollectionConverters._

import com.google.protobuf.Message
import io.grpc.{BindableService, MethodDescriptor, Server, ServerMethodDefinition, ServerServiceDefinition}
import io.grpc.MethodDescriptor.PrototypeMarshaller
import io.grpc.netty.NettyServerBuilder
import io.grpc.netty.{GrpcSslContexts, NettyServerBuilder}
import io.grpc.protobuf.ProtoUtils
import io.grpc.protobuf.services.ProtoReflectionService
import io.grpc.stub.StreamObserver
import io.netty.handler.ssl.SslContext
import org.apache.commons.lang3.StringUtils

import org.apache.spark.{SparkContext, SparkEnv}
@@ -39,7 +42,8 @@ import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.HOST
import org.apache.spark.internal.config.UI.UI_ENABLED
import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerEvent}
import org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_BINDING_ADDRESS, CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT, CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE, CONNECT_GRPC_PORT_MAX_RETRIES}
import org.apache.spark.sql.connect.common.config.ConnectCommon
import org.apache.spark.sql.connect.config.Connect._
import org.apache.spark.sql.connect.execution.ConnectProgressExecutionListener
import org.apache.spark.sql.connect.ui.{SparkConnectServerAppStatusStore, SparkConnectServerListener, SparkConnectServerTab}
import org.apache.spark.sql.connect.utils.ErrorUtils
@@ -366,10 +370,14 @@ object SparkConnectService extends Logging {
val debugMode = SparkEnv.get.conf.getBoolean("spark.connect.grpc.debug.enabled", true)
val bindAddress = SparkEnv.get.conf.get(CONNECT_GRPC_BINDING_ADDRESS)
val startPort = SparkEnv.get.conf.get(CONNECT_GRPC_BINDING_PORT)
val localAuthToken = System.getenv(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME)
ConnectCommon.setLocalAuthToken(localAuthToken)
val sparkConnectService = new SparkConnectService(debugMode)
val protoReflectionService =
if (debugMode) Some(ProtoReflectionService.newInstance()) else None
val configuredInterceptors = SparkConnectInterceptorRegistry.createConfiguredInterceptors()
val configuredInterceptors =
SparkConnectInterceptorRegistry.createConfiguredInterceptors() ++
(if (localAuthToken != null) Seq(new LocalAuthInterceptor()) else Nil)

val startServiceFn = (port: Int) => {
val sb = bindAddress match {
@@ -388,6 +396,11 @@
// grpcurl can introspect the API for debugging.
protoReflectionService.foreach(service => sb.addService(service))

Option(localAuthToken).foreach { t =>
val token = new ByteArrayInputStream(t.getBytes(StandardCharsets.UTF_8))
sb.sslContext(GrpcSslContexts.forServer(token, token).build())
}

server = sb.build
server.start()

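The interceptor is registered only when SPARK_CONNECT_LOCAL_AUTH_TOKEN is present in the server's environment, so remote deployments are unaffected. A simplified Scala sketch of that conditional wiring follows; it omits the service registration and the SSL context setup that the real startServiceFn also performs, and the port number is an assumption.

import io.grpc.Server
import io.grpc.netty.NettyServerBuilder

import org.apache.spark.sql.connect.service.LocalAuthInterceptor

object LocalServerWiringSketch {
  // Builds a gRPC server that enforces the local token only when one was provided.
  def build(localAuthToken: Option[String]): Server = {
    val builder = NettyServerBuilder.forPort(15002)
    // Enforce the token check only for locally started, token-protected servers.
    localAuthToken.foreach(_ => builder.intercept(new LocalAuthInterceptor()))
    builder.build()
  }
}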
