[SPARK-51156][CONNECT] Provide a basic authentication token when running Spark Connect server locally #49880
```diff
@@ -40,6 +40,7 @@
     TYPE_CHECKING,
     ClassVar,
 )
+import uuid
 
 import numpy as np
 import pandas as pd
```
```diff
@@ -893,13 +894,20 @@ def stop(self) -> None:
         if self is getattr(SparkSession._active_session, "session", None):
             SparkSession._active_session.session = None
 
+        # It should be `None` always on stop.
+        SparkConnectClient._local_auth_token = None
+
         if "SPARK_LOCAL_REMOTE" in os.environ:
             # When local mode is in use, follow the regular Spark session's
             # behavior by terminating the Spark Connect server,
             # meaning that you can stop local mode, and restart the Spark Connect
             # client with a different remote address.
             if PySparkSession._activeSession is not None:
                 try:
+                    getattr(
+                        PySparkSession._activeSession._jvm,
+                        "org.apache.spark.sql.connect.common.config.ConnectCommon",
+                    ).setLocalAuthToken(None)
                     PySparkSession._activeSession.stop()
                 except Exception as e:
                     logger.warn(
```
```diff
@@ -1060,6 +1068,9 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
             overwrite_conf["spark.connect.grpc.binding.port"] = "0"
 
             origin_remote = os.environ.get("SPARK_REMOTE", None)
+            local_auth_token = str(uuid.uuid4())
+            SparkConnectClient._local_auth_token = local_auth_token
+            os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"] = local_auth_token
             try:
                 if origin_remote is not None:
                     # So SparkSubmit thinks no remote is set in order to
```
```diff
@@ -1072,6 +1083,13 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
                 conf.setAll(list(overwrite_conf.items())).setAll(list(default_conf.items()))
                 PySparkSession(SparkContext.getOrCreate(conf))
 
+                # In Python local mode, session.stop does not terminate JVM itself
+                # so we can't control it via environment variable.
+                getattr(
+                    SparkContext._jvm,
+                    "org.apache.spark.sql.connect.common.config.ConnectCommon",
+                ).setLocalAuthToken(local_auth_token)
+
                 # Lastly only keep runtime configurations because other configurations are
                 # disallowed to set in the regular Spark Connect session.
                 utl = SparkContext._jvm.PythonSQLUtils  # type: ignore[union-attr]
```

Review comment: How about you set the environment variable when we start Spark?

Review comment: Also, how do you ensure the LocalAuthInterceptor will be installed if there is no token yet?
```diff
@@ -1084,6 +1102,7 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
                 if origin_remote is not None:
                     os.environ["SPARK_REMOTE"] = origin_remote
                 del os.environ["SPARK_LOCAL_CONNECT"]
+                del os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"]
         else:
             raise PySparkRuntimeError(
                 errorClass="SESSION_OR_CONTEXT_EXISTS",
```
```diff
@@ -48,6 +48,7 @@ import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toLiteral
 import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, SparkConnectClient, SparkResult}
 import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
 import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
+import org.apache.spark.sql.connect.common.config.ConnectCommon
 import org.apache.spark.sql.internal.{SessionState, SharedState, SqlApiConf}
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.types.StructType
```
```diff
@@ -627,6 +628,19 @@ class SparkSession private[sql] (
     }
     allocator.close()
     SparkSession.onSessionClose(this)
+    SparkSession.server.synchronized {
+      // It should be `None` always on stop.
+      ConnectCommon.setLocalAuthToken(null)
+      if (SparkSession.server.isDefined) {
+        // When local mode is in use, follow the regular Spark session's
+        // behavior by terminating the Spark Connect server,
+        // meaning that you can stop local mode, and restart the Spark Connect
+        // client with a different remote address.
+        new ProcessBuilder(SparkSession.maybeConnectStopScript.get.toString)
+          .start()
+        SparkSession.server = None
+      }
+    }
   }
 
   /** @inheritdoc */
```
```diff
@@ -679,6 +693,10 @@ object SparkSession extends SparkSessionCompanion with Logging {
   private val MAX_CACHED_SESSIONS = 100
   private val planIdGenerator = new AtomicLong
   private var server: Option[Process] = None
+  private val maybeConnectStartScript =
+    Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))
+  private val maybeConnectStopScript =
+    Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "stop-connect-server.sh"))
   private[sql] val sparkOptions = sys.props.filter { p =>
     p._1.startsWith("spark.") && p._2.nonEmpty
   }.toMap
```
```diff
@@ -712,37 +730,43 @@
       }
     }
 
-    val maybeConnectScript =
-      Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))
-
-    if (server.isEmpty &&
-      (remoteString.exists(_.startsWith("local")) ||
-      (remoteString.isDefined && isAPIModeConnect)) &&
-      maybeConnectScript.exists(Files.exists(_))) {
-      server = Some {
-        val args =
-          Seq(maybeConnectScript.get.toString, "--master", remoteString.get) ++ sparkOptions
-            .filter(p => !p._1.startsWith("spark.remote"))
-            .flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
-        val pb = new ProcessBuilder(args: _*)
-        // So don't exclude spark-sql jar in classpath
-        pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
-        pb.start()
-      }
-
-      // Let the server start. We will directly request to set the configurations
-      // and this sleep makes less noisy with retries.
-      Thread.sleep(2000L)
-      System.setProperty("spark.remote", "sc://localhost")
-
-      // scalastyle:off runtimeaddshutdownhook
-      Runtime.getRuntime.addShutdownHook(new Thread() {
-        override def run(): Unit = if (server.isDefined) {
-          new ProcessBuilder(maybeConnectScript.get.toString)
-            .start()
-        }
-      })
-      // scalastyle:on runtimeaddshutdownhook
-    }
+    server.synchronized {
+      if (server.isEmpty &&
+        (remoteString.exists(_.startsWith("local")) ||
+        (remoteString.isDefined && isAPIModeConnect)) &&
+        maybeConnectStartScript.exists(Files.exists(_))) {
+        val localAuthToken = java.util.UUID.randomUUID().toString()
+        server = Some {
+          ConnectCommon.setLocalAuthToken(localAuthToken)
+          val args =
+            Seq(
+              maybeConnectStartScript.get.toString,
+              "--master",
+              remoteString.get) ++ sparkOptions
+              .filter(p => !p._1.startsWith("spark.remote"))
+              .flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
+          val pb = new ProcessBuilder(args: _*)
+          // So don't exclude spark-sql jar in classpath
+          pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
+          pb.environment()
+            .put(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME, localAuthToken)
+          pb.start()
+        }
+
+        // Let the server start. We will directly request to set the configurations
+        // and this sleep makes less noisy with retries.
+        Thread.sleep(2000L)
+        System.setProperty("spark.remote", "sc://localhost")
+
+        // scalastyle:off runtimeaddshutdownhook
+        Runtime.getRuntime.addShutdownHook(new Thread() {
+          override def run(): Unit = if (server.isDefined) {
+            new ProcessBuilder(maybeConnectStopScript.get.toString)
+              .start()
+          }
+        })
+        // scalastyle:on runtimeaddshutdownhook
+      }
+    }
   }
   f
```

Review comment on `server.synchronized`: Just added a …
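The token never crosses the wire during setup: the parent JVM generates a UUID and the child server process inherits it through its environment before any gRPC traffic happens. A minimal, self-contained sketch of that handoff pattern (plain Scala, not Spark code; `printenv` stands in for sbin/start-connect-server.sh):

```scala
import java.util.UUID

object TokenHandoffSketch {
  def main(args: Array[String]): Unit = {
    // Generate a fresh secret in the parent process.
    val token = UUID.randomUUID().toString
    // Hand it to the child only through its environment, never over the network.
    val pb = new ProcessBuilder("printenv", "SPARK_CONNECT_LOCAL_AUTH_TOKEN")
    pb.environment().put("SPARK_CONNECT_LOCAL_AUTH_TOKEN", token)
    pb.inheritIO() // the child prints the token it inherited to stdout
    pb.start().waitFor()
  }
}
```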
```diff
@@ -21,4 +21,15 @@ private[sql] object ConnectCommon {
   val CONNECT_GRPC_PORT_MAX_RETRIES: Int = 0
   val CONNECT_GRPC_MAX_MESSAGE_SIZE: Int = 128 * 1024 * 1024
   val CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT: Int = 1024
+
+  val CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME = "local_token"
+  val CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME = "SPARK_CONNECT_LOCAL_AUTH_TOKEN"
+  private var CONNECT_LOCAL_AUTH_TOKEN: Option[String] = Option(
+    System.getenv(CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME))
+  def setLocalAuthToken(token: String): Unit = CONNECT_LOCAL_AUTH_TOKEN = Option(token)
+  def getLocalAuthToken: Option[String] = CONNECT_LOCAL_AUTH_TOKEN
+  def assertLocalAuthToken(token: Option[String]): Unit = token.foreach { t =>
+    assert(CONNECT_LOCAL_AUTH_TOKEN.isDefined)
+    assert(t.substring("Bearer ".length) == CONNECT_LOCAL_AUTH_TOKEN.get)
+  }
 }
```

Review comment on `CONNECT_LOCAL_AUTH_TOKEN`: Don't store this in some random local variable. There is no need for this. On the client side the …

Reply: The problem is that the local server can stop and start without shutting down the JVM (for Python), unlike Scala, where we always stop/start the JVM. So it has to be a variable.

Review comment on `assert`: Again, you CANNOT use asserts for this. They will get elided if you disable assertions. Please throw a proper gRPC exception in the …

Reply: For other places, I will fix it later. But here `assert` is correct because if …
```diff
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.service
+
+import io.grpc.{Metadata, ServerCall, ServerCallHandler, ServerInterceptor}
+
+import org.apache.spark.sql.connect.common.config.ConnectCommon
+
+/**
+ * A gRPC interceptor to check if the header contains token for authentication.
+ */
+class LocalAuthInterceptor extends ServerInterceptor {
+  override def interceptCall[ReqT, RespT](
+      call: ServerCall[ReqT, RespT],
+      headers: Metadata,
+      next: ServerCallHandler[ReqT, RespT]): ServerCall.Listener[ReqT] = {
+    ConnectCommon.assertLocalAuthToken(Option(headers.get(Metadata.Key
+      .of(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME, Metadata.ASCII_STRING_MARSHALLER))))
+    next.startCall(call, headers)
+  }
+}
```

Review comment on `class LocalAuthInterceptor`: How about we name this …

Review comment: +1 in general for making this more generic and usable beyond the local Spark Connect use case. Having a pre-shared secret capability built in goes a long way in making Spark Connect more usable in shared computer clusters.

Reply: I will scope it down to local usage for now. All of this is internal for now, and we don't need to generalize it at this moment.
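The matching client-side change is not shown in this diff; the server-side check above implies the client attaches a `local_token: Bearer <uuid>` header to each call. A hedged sketch of what that could look like as a grpc-java client interceptor (the class name is hypothetical):

```scala
import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor,
  ForwardingClientCall, Metadata, MethodDescriptor}

class LocalTokenClientInterceptor(token: String) extends ClientInterceptor {
  private val tokenKey =
    Metadata.Key.of("local_token", Metadata.ASCII_STRING_MARSHALLER)

  override def interceptCall[ReqT, RespT](
      method: MethodDescriptor[ReqT, RespT],
      callOptions: CallOptions,
      next: Channel): ClientCall[ReqT, RespT] = {
    new ForwardingClientCall.SimpleForwardingClientCall[ReqT, RespT](
        next.newCall(method, callOptions)) {
      override def start(
          listener: ClientCall.Listener[RespT], headers: Metadata): Unit = {
        // Mirror the "Bearer " prefix that assertLocalAuthToken strips.
        headers.put(tokenKey, s"Bearer $token")
        super.start(listener, headers)
      }
    }
  }
}
```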
```diff
@@ -17,18 +17,21 @@
 
 package org.apache.spark.sql.connect.service
 
+import java.io.ByteArrayInputStream
 import java.net.InetSocketAddress
+import java.nio.charset.StandardCharsets
 import java.util.concurrent.TimeUnit
 
 import scala.jdk.CollectionConverters._
 
 import com.google.protobuf.Message
 import io.grpc.{BindableService, MethodDescriptor, Server, ServerMethodDefinition, ServerServiceDefinition}
 import io.grpc.MethodDescriptor.PrototypeMarshaller
-import io.grpc.netty.NettyServerBuilder
+import io.grpc.netty.{GrpcSslContexts, NettyServerBuilder}
 import io.grpc.protobuf.ProtoUtils
 import io.grpc.protobuf.services.ProtoReflectionService
 import io.grpc.stub.StreamObserver
+import io.netty.handler.ssl.SslContext
 import org.apache.commons.lang3.StringUtils
 
 import org.apache.spark.{SparkContext, SparkEnv}
```
```diff
@@ -39,7 +42,8 @@ import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys.HOST
 import org.apache.spark.internal.config.UI.UI_ENABLED
 import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerEvent}
-import org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_BINDING_ADDRESS, CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT, CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE, CONNECT_GRPC_PORT_MAX_RETRIES}
+import org.apache.spark.sql.connect.common.config.ConnectCommon
+import org.apache.spark.sql.connect.config.Connect._
 import org.apache.spark.sql.connect.execution.ConnectProgressExecutionListener
 import org.apache.spark.sql.connect.ui.{SparkConnectServerAppStatusStore, SparkConnectServerListener, SparkConnectServerTab}
 import org.apache.spark.sql.connect.utils.ErrorUtils
```
```diff
@@ -366,10 +370,14 @@
     val debugMode = SparkEnv.get.conf.getBoolean("spark.connect.grpc.debug.enabled", true)
     val bindAddress = SparkEnv.get.conf.get(CONNECT_GRPC_BINDING_ADDRESS)
     val startPort = SparkEnv.get.conf.get(CONNECT_GRPC_BINDING_PORT)
+    val localAuthToken = System.getenv(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME)
+    ConnectCommon.setLocalAuthToken(localAuthToken)
     val sparkConnectService = new SparkConnectService(debugMode)
     val protoReflectionService =
       if (debugMode) Some(ProtoReflectionService.newInstance()) else None
-    val configuredInterceptors = SparkConnectInterceptorRegistry.createConfiguredInterceptors()
+    val configuredInterceptors =
+      SparkConnectInterceptorRegistry.createConfiguredInterceptors() ++
+        (if (localAuthToken != null) Seq(new LocalAuthInterceptor()) else Nil)
 
     val startServiceFn = (port: Int) => {
       val sb = bindAddress match {
```

Review comment on `new LocalAuthInterceptor()`: Just pass in the token as an argument to the …
```diff
@@ -388,6 +396,11 @@
       // grpcurl can introspect the API for debugging.
       protoReflectionService.foreach(service => sb.addService(service))
 
+      Option(localAuthToken).foreach { t =>
+        val token = new ByteArrayInputStream(t.getBytes(StandardCharsets.UTF_8))
+        sb.sslContext(GrpcSslContexts.forServer(token, token).build())
+      }
+
       server = sb.build
       server.start()
 
```

Review comment: It's unfortunate that the gRPC client tries to force you to use TLS if you want to use call credentials, when there are so many workarounds, like simply using a different header. Though in this case you could theoretically use `local_channel_credentials`, at least on the Python side, to use the built-in token mechanism.