[SPARK-51156][CONNECT] Provide a basic authentication token when running Spark Connect server locally #49880

Open · wants to merge 1 commit into base: master
7 changes: 7 additions & 0 deletions python/pyspark/sql/connect/client/core.py
@@ -125,6 +125,7 @@ class ChannelBuilder:
PARAM_USER_ID = "user_id"
PARAM_USER_AGENT = "user_agent"
PARAM_SESSION_ID = "session_id"
CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME = "local_token"
Contributor:
It's unfortunate that the gRPC client tries to force you to use TLS if you want to use call credentials, when there are so many workarounds, like simply using a different header. Though in this case you could theoretically use local_channel_credentials, at least on the Python side, to use the built-in token mechanism.
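As a side note on the "different header" workaround mentioned here: a plain client interceptor can attach the token as gRPC metadata on an insecure channel, with no TLS-backed call credentials involved. The sketch below is illustrative grpc-java/Scala, not code from this PR; the class name and header name are assumptions.

import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, ForwardingClientCall, Metadata, MethodDescriptor}

// Hypothetical client-side interceptor: adds a pre-shared token as a plain
// metadata header on every outgoing call, so no TLS is required.
class TokenHeaderInterceptor(token: String) extends ClientInterceptor {
  private val tokenKey =
    Metadata.Key.of("local_token", Metadata.ASCII_STRING_MARSHALLER)

  override def interceptCall[ReqT, RespT](
      method: MethodDescriptor[ReqT, RespT],
      options: CallOptions,
      channel: Channel): ClientCall[ReqT, RespT] = {
    new ForwardingClientCall.SimpleForwardingClientCall[ReqT, RespT](
        channel.newCall(method, options)) {
      override def start(listener: ClientCall.Listener[RespT], headers: Metadata): Unit = {
        // Attach the token before the call starts.
        headers.put(tokenKey, token)
        super.start(listener, headers)
      }
    }
  }
}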


GRPC_MAX_MESSAGE_LENGTH_DEFAULT = 128 * 1024 * 1024

@@ -591,6 +592,8 @@ class SparkConnectClient(object):
Conceptually the remote spark session that communicates with the server
"""

_local_auth_token = os.environ.get("SPARK_CONNECT_LOCAL_AUTH_TOKEN")

def __init__(
self,
connection: Union[str, ChannelBuilder],
@@ -637,6 +640,10 @@ def __init__(
if isinstance(connection, ChannelBuilder)
else DefaultChannelBuilder(connection, channel_options)
)
if SparkConnectClient._local_auth_token is not None:
self._builder.set(
ChannelBuilder.CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME,
SparkConnectClient._local_auth_token)
self._user_id = None
self._retry_policies: List[RetryPolicy] = []

19 changes: 19 additions & 0 deletions python/pyspark/sql/connect/session.py
@@ -40,6 +40,7 @@
TYPE_CHECKING,
ClassVar,
)
import uuid

import numpy as np
import pandas as pd
@@ -893,13 +894,20 @@ def stop(self) -> None:
if self is getattr(SparkSession._active_session, "session", None):
SparkSession._active_session.session = None

# It should be `None` always on stop.
SparkConnectClient._local_auth_token = None

if "SPARK_LOCAL_REMOTE" in os.environ:
# When local mode is in use, follow the regular Spark session's
# behavior by terminating the Spark Connect server,
# meaning that you can stop local mode, and restart the Spark Connect
# client with a different remote address.
if PySparkSession._activeSession is not None:
try:
getattr(
PySparkSession._activeSession._jvm,
"org.apache.spark.sql.connect.common.config.ConnectCommon",
).setLocalAuthToken(None)
PySparkSession._activeSession.stop()
except Exception as e:
logger.warn(
@@ -1060,6 +1068,9 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
overwrite_conf["spark.connect.grpc.binding.port"] = "0"

origin_remote = os.environ.get("SPARK_REMOTE", None)
local_auth_token = str(uuid.uuid4())
SparkConnectClient._local_auth_token = local_auth_token
os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"] = local_auth_token
try:
if origin_remote is not None:
# So SparkSubmit thinks no remote is set in order to
@@ -1072,6 +1083,13 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
conf.setAll(list(overwrite_conf.items())).setAll(list(default_conf.items()))
PySparkSession(SparkContext.getOrCreate(conf))

# In Python local mode, session.stop does not terminate JVM itself
Contributor:
How about you set the environment variable when we start Spark?

Contributor:
Also, how do you ensure the LocalAuthInterceptor will be installed if there is no token yet?

# so we can't control it via environment variable.
getattr(
SparkContext._jvm,
"org.apache.spark.sql.connect.common.config.ConnectCommon",
).setLocalAuthToken(local_auth_token)

# Lastly only keep runtime configurations because other configurations are
# disallowed to set in the regular Spark Connect session.
utl = SparkContext._jvm.PythonSQLUtils # type: ignore[union-attr]
@@ -1084,6 +1102,7 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
if origin_remote is not None:
os.environ["SPARK_REMOTE"] = origin_remote
del os.environ["SPARK_LOCAL_CONNECT"]
del os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"]
else:
raise PySparkRuntimeError(
errorClass="SESSION_OR_CONTEXT_EXISTS",
@@ -48,6 +48,7 @@ import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toLiteral
import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, SparkConnectClient, SparkResult}
import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
import org.apache.spark.sql.connect.common.config.ConnectCommon
import org.apache.spark.sql.internal.{SessionState, SharedState, SqlApiConf}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.StructType
@@ -627,6 +628,19 @@ class SparkSession private[sql] (
}
allocator.close()
SparkSession.onSessionClose(this)
SparkSession.server.synchronized {
// It should be `None` always on stop.
ConnectCommon.setLocalAuthToken(null)
if (SparkSession.server.isDefined) {
// When local mode is in use, follow the regular Spark session's
// behavior by terminating the Spark Connect server,
// meaning that you can stop local mode, and restart the Spark Connect
// client with a different remote address.
new ProcessBuilder(SparkSession.maybeConnectStopScript.get.toString)
.start()
SparkSession.server = None
}
}
}

/** @inheritdoc */
@@ -679,6 +693,10 @@ object SparkSession extends SparkSessionCompanion with Logging {
private val MAX_CACHED_SESSIONS = 100
private val planIdGenerator = new AtomicLong
private var server: Option[Process] = None
private val maybeConnectStartScript =
Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))
private val maybeConnectStopScript =
Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "stop-connect-server.sh"))
private[sql] val sparkOptions = sys.props.filter { p =>
p._1.startsWith("spark.") && p._2.nonEmpty
}.toMap
@@ -712,37 +730,43 @@ object SparkSession extends SparkSessionCompanion with Logging {
}
}

val maybeConnectScript =
Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))

if (server.isEmpty &&
(remoteString.exists(_.startsWith("local")) ||
(remoteString.isDefined && isAPIModeConnect)) &&
maybeConnectScript.exists(Files.exists(_))) {
server = Some {
val args =
Seq(maybeConnectScript.get.toString, "--master", remoteString.get) ++ sparkOptions
.filter(p => !p._1.startsWith("spark.remote"))
.flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
val pb = new ProcessBuilder(args: _*)
// So don't exclude spark-sql jar in classpath
pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
pb.start()
}

// Let the server start. We will directly request to set the configurations
// and this sleep makes less noisy with retries.
Thread.sleep(2000L)
System.setProperty("spark.remote", "sc://localhost")

// scalastyle:off runtimeaddshutdownhook
Runtime.getRuntime.addShutdownHook(new Thread() {
override def run(): Unit = if (server.isDefined) {
new ProcessBuilder(maybeConnectScript.get.toString)
.start()
server.synchronized {
Member (author):
Just added a synchronized.

if (server.isEmpty &&
(remoteString.exists(_.startsWith("local")) ||
(remoteString.isDefined && isAPIModeConnect)) &&
maybeConnectStartScript.exists(Files.exists(_))) {
val localAuthToken = java.util.UUID.randomUUID().toString()
server = Some {
ConnectCommon.setLocalAuthToken(localAuthToken)
val args =
Seq(
maybeConnectStartScript.get.toString,
"--master",
remoteString.get) ++ sparkOptions
.filter(p => !p._1.startsWith("spark.remote"))
.flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
val pb = new ProcessBuilder(args: _*)
// So don't exclude spark-sql jar in classpath
pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
pb.environment()
.put(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME, localAuthToken)
pb.start()
}
})
// scalastyle:on runtimeaddshutdownhook

// Let the server start. We will directly request to set the configurations
// and this sleep makes less noisy with retries.
Thread.sleep(2000L)
System.setProperty("spark.remote", "sc://localhost")

// scalastyle:off runtimeaddshutdownhook
Runtime.getRuntime.addShutdownHook(new Thread() {
override def run(): Unit = if (server.isDefined) {
new ProcessBuilder(maybeConnectStopScript.get.toString)
.start()
}
})
// scalastyle:on runtimeaddshutdownhook
}
}
}
f
@@ -620,6 +620,8 @@ object SparkConnectClient {
* Configure the builder using the env SPARK_REMOTE environment variable.
*/
def loadFromEnvironment(): Builder = {
ConnectCommon.getLocalAuthToken.foreach(t =>
option(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME, t))
lazy val isAPIModeConnect =
Option(System.getProperty(org.apache.spark.sql.SparkSessionBuilder.API_MODE_KEY))
.getOrElse("classic")
@@ -21,4 +21,15 @@ private[sql] object ConnectCommon {
val CONNECT_GRPC_PORT_MAX_RETRIES: Int = 0
val CONNECT_GRPC_MAX_MESSAGE_SIZE: Int = 128 * 1024 * 1024
val CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT: Int = 1024

val CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME = "local_token"
val CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME = "SPARK_CONNECT_LOCAL_AUTH_TOKEN"
private var CONNECT_LOCAL_AUTH_TOKEN: Option[String] = Option(
Contributor:
Don't store this in some random local variable. There is no need for this. On the client side the SparkConnectClient will store the token. On the server the LocalAuthInterceptor should just hold on to the token.

Member (author):
The problem is that the local server can stop and start without shutting down the JVM (for Python), unlike Scala, where we always stop and start the JVM. So it has to be a variable.

System.getenv(CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME))
def setLocalAuthToken(token: String): Unit = CONNECT_LOCAL_AUTH_TOKEN = Option(token)
def getLocalAuthToken: Option[String] = CONNECT_LOCAL_AUTH_TOKEN
def assertLocalAuthToken(token: Option[String]): Unit = token.foreach { t =>
assert(CONNECT_LOCAL_AUTH_TOKEN.isDefined)
Contributor:
Again, you CANNOT use asserts for this. They will get elided if you disable assertions. Please throw a proper gRPC exception in the LocalAuthInterceptor.

Member (author):
For other places, I will fix it later. But here the assert is correct, because if the token is set, CONNECT_LOCAL_AUTH_TOKEN must be set too for local usage.

assert(t.substring("Bearer ".length) == CONNECT_LOCAL_AUTH_TOKEN.get)
}
}
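On the assertion concern above: Scala's Predef.assert can be compiled out when assertions are elided, in which case the check silently passes. A minimal sketch of an explicit-throw variant of the same check, purely illustrative and not part of this PR (the object and method names are assumptions):

object LocalAuthTokenCheck {
  private val BearerPrefix = "Bearer "

  // Fails loudly instead of relying on assert, which can be elided at compile time.
  def check(expected: Option[String], presented: Option[String]): Unit = {
    val configured = expected.getOrElse(
      throw new IllegalStateException("No local auth token configured on this server"))
    val offered = presented.getOrElse(
      throw new IllegalArgumentException("Missing local auth token header"))
    if (offered.stripPrefix(BearerPrefix) != configured) {
      throw new IllegalArgumentException("Invalid local auth token")
    }
  }
}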
@@ -0,0 +1,36 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connect.service

import io.grpc.{Metadata, ServerCall, ServerCallHandler, ServerInterceptor}

import org.apache.spark.sql.connect.common.config.ConnectCommon

/**
* A gRPC interceptor to check if the header contains token for authentication.
*/
class LocalAuthInterceptor extends ServerInterceptor {
Contributor:
How about we name this PreSharedKeyAuthenticationInterceptor? It is not a Local interceptor.

Contributor:
+1 in general for making this more generic and usable beyond the local Spark Connect use case. Having a pre-shared secret capability built-in goes a long way in making Spark Connect more usable in shared computer clusters.

Member (author):
I will scope it down to local usage for now. This whole thing is internal for now, and we don't need to generalize it at this moment.

override def interceptCall[ReqT, RespT](
call: ServerCall[ReqT, RespT],
headers: Metadata,
next: ServerCallHandler[ReqT, RespT]): ServerCall.Listener[ReqT] = {
ConnectCommon.assertLocalAuthToken(Option(headers.get(Metadata.Key
.of(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME, Metadata.ASCII_STRING_MARSHALLER))))
next.startCall(call, headers)
}
}
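Taking the review suggestions together — a non-"Local" name, the token held by the interceptor itself, and a proper gRPC status instead of an assert — a minimal sketch of that shape could look like the following. The class name, header name, and Bearer handling are assumptions for illustration, not the code merged in this PR.

import io.grpc.{Metadata, ServerCall, ServerCallHandler, ServerInterceptor, Status}

// Hypothetical pre-shared-key interceptor: owns the expected token and rejects
// unauthenticated calls with UNAUTHENTICATED rather than asserting.
class PreSharedKeyAuthenticationInterceptor(expectedToken: String) extends ServerInterceptor {
  private val tokenKey =
    Metadata.Key.of("local_token", Metadata.ASCII_STRING_MARSHALLER)

  override def interceptCall[ReqT, RespT](
      call: ServerCall[ReqT, RespT],
      headers: Metadata,
      next: ServerCallHandler[ReqT, RespT]): ServerCall.Listener[ReqT] = {
    val presented = Option(headers.get(tokenKey)).map(_.stripPrefix("Bearer "))
    if (presented.contains(expectedToken)) {
      next.startCall(call, headers)
    } else {
      // Reject with a proper gRPC status instead of an assert.
      call.close(
        Status.UNAUTHENTICATED.withDescription("Invalid or missing local auth token"),
        new Metadata())
      // Return a no-op listener for the rejected call.
      new ServerCall.Listener[ReqT] {}
    }
  }
}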
@@ -17,18 +17,21 @@

package org.apache.spark.sql.connect.service

import java.io.ByteArrayInputStream
import java.net.InetSocketAddress
import java.nio.charset.StandardCharsets
import java.util.concurrent.TimeUnit

import scala.jdk.CollectionConverters._

import com.google.protobuf.Message
import io.grpc.{BindableService, MethodDescriptor, Server, ServerMethodDefinition, ServerServiceDefinition}
import io.grpc.MethodDescriptor.PrototypeMarshaller
import io.grpc.netty.NettyServerBuilder
import io.grpc.netty.{GrpcSslContexts, NettyServerBuilder}
import io.grpc.protobuf.ProtoUtils
import io.grpc.protobuf.services.ProtoReflectionService
import io.grpc.stub.StreamObserver
import io.netty.handler.ssl.SslContext
import org.apache.commons.lang3.StringUtils

import org.apache.spark.{SparkContext, SparkEnv}
@@ -39,7 +42,8 @@ import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.HOST
import org.apache.spark.internal.config.UI.UI_ENABLED
import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerEvent}
import org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_BINDING_ADDRESS, CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT, CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE, CONNECT_GRPC_PORT_MAX_RETRIES}
import org.apache.spark.sql.connect.common.config.ConnectCommon
import org.apache.spark.sql.connect.config.Connect._
import org.apache.spark.sql.connect.execution.ConnectProgressExecutionListener
import org.apache.spark.sql.connect.ui.{SparkConnectServerAppStatusStore, SparkConnectServerListener, SparkConnectServerTab}
import org.apache.spark.sql.connect.utils.ErrorUtils
@@ -366,10 +370,14 @@ object SparkConnectService extends Logging {
val debugMode = SparkEnv.get.conf.getBoolean("spark.connect.grpc.debug.enabled", true)
val bindAddress = SparkEnv.get.conf.get(CONNECT_GRPC_BINDING_ADDRESS)
val startPort = SparkEnv.get.conf.get(CONNECT_GRPC_BINDING_PORT)
val localAuthToken = System.getenv(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME)
ConnectCommon.setLocalAuthToken(localAuthToken)
val sparkConnectService = new SparkConnectService(debugMode)
val protoReflectionService =
if (debugMode) Some(ProtoReflectionService.newInstance()) else None
val configuredInterceptors = SparkConnectInterceptorRegistry.createConfiguredInterceptors()
val configuredInterceptors =
SparkConnectInterceptorRegistry.createConfiguredInterceptors() ++
(if (localAuthToken != null) Seq(new LocalAuthInterceptor()) else Nil)
Contributor:
Just pass the token in as an argument to the LocalAuthInterceptor; there is absolutely no reason for putting this in some global variable...


val startServiceFn = (port: Int) => {
val sb = bindAddress match {
@@ -388,6 +396,11 @@ object SparkConnectService extends Logging {
// grpcurl can introspect the API for debugging.
protoReflectionService.foreach(service => sb.addService(service))

Option(localAuthToken).foreach { t =>
val token = new ByteArrayInputStream(t.getBytes(StandardCharsets.UTF_8))
sb.sslContext(GrpcSslContexts.forServer(token, token).build())
}

server = sb.build
server.start()

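Following up on the comment about passing the token directly: if the interceptor took the token as a constructor argument (as in the hypothetical PreSharedKeyAuthenticationInterceptor sketched after LocalAuthInterceptor.scala above), the registration site could read the environment variable once and skip the mutable state in ConnectCommon entirely. A rough sketch of that wiring, under the same assumptions:

// Hypothetical registration-site wiring: no setLocalAuthToken global needed.
val localAuthToken = Option(System.getenv(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME))
val configuredInterceptors =
  SparkConnectInterceptorRegistry.createConfiguredInterceptors() ++
    localAuthToken.map(t => new PreSharedKeyAuthenticationInterceptor(t)).toSeq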