
Commit 66336ac

Provide a basic authentication token when running Spark Connect server locally
1 parent 42ecabf commit 66336ac

File tree

7 files changed: +142 -34 lines changed

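The mechanism, end to end: the client that launches a local Spark Connect server mints a random UUID token, hands it to the server process through the SPARK_CONNECT_LOCAL_AUTH_TOKEN environment variable, and presents the same token as a Bearer credential on every gRPC call; a new server-side interceptor rejects calls whose Authentication header does not match. A minimal, self-contained Scala sketch of the three roles (names here are illustrative, not Spark's):

    import java.util.UUID

    object LocalAuthSketch {
      def main(args: Array[String]): Unit = {
        val token = UUID.randomUUID().toString // launcher: minted once per server launch
        val header = s"Bearer $token" // client: attached to every RPC
        val accepted = header.stripPrefix("Bearer ") == token // server: checked per RPC
        println(s"accepted = $accepted")
      }
    }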

python/pyspark/sql/connect/client/core.py

Lines changed: 4 additions & 0 deletions
@@ -591,6 +591,8 @@ class SparkConnectClient(object):
     Conceptually the remote spark session that communicates with the server
     """
 
+    _local_auth_token = os.environ.get("SPARK_CONNECT_LOCAL_AUTH_TOKEN")
+
     def __init__(
         self,
         connection: Union[str, ChannelBuilder],
@@ -637,6 +639,8 @@ def __init__(
             if isinstance(connection, ChannelBuilder)
             else DefaultChannelBuilder(connection, channel_options)
         )
+        if SparkConnectClient._local_auth_token is not None:
+            self._builder.set(ChannelBuilder.PARAM_TOKEN, SparkConnectClient._local_auth_token)
         self._user_id = None
         self._retry_policies: List[RetryPolicy] = []
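The Python client picks the token up from the environment once, at class-definition time, and pushes it into every channel builder via PARAM_TOKEN. On the wire this amounts to a gRPC call credential that adds the token as a request header. A sketch of that mechanism in plain grpc-java terms (BearerToken is a stand-in class, not Spark's; the header name matches the key ConnectCommon defines below):

    import java.util.concurrent.Executor

    import io.grpc.{CallCredentials, Metadata}

    // Attach "Authentication: Bearer <token>" to every outgoing RPC.
    class BearerToken(token: String) extends CallCredentials {
      private val key =
        Metadata.Key.of("Authentication", Metadata.ASCII_STRING_MARSHALLER)

      override def applyRequestMetadata(
          requestInfo: CallCredentials.RequestInfo,
          appExecutor: Executor,
          applier: CallCredentials.MetadataApplier): Unit = {
        val headers = new Metadata()
        headers.put(key, s"Bearer $token")
        applier.apply(headers)
      }

      override def thisUsesUnstableApi(): Unit = ()
    }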

python/pyspark/sql/connect/session.py

Lines changed: 19 additions & 0 deletions
@@ -40,6 +40,7 @@
     TYPE_CHECKING,
     ClassVar,
 )
+import uuid
 
 import numpy as np
 import pandas as pd
@@ -893,13 +894,20 @@ def stop(self) -> None:
         if self is getattr(SparkSession._active_session, "session", None):
             SparkSession._active_session.session = None
 
+        # It should be `None` always on stop.
+        SparkConnectClient._local_auth_token = None
+
         if "SPARK_LOCAL_REMOTE" in os.environ:
             # When local mode is in use, follow the regular Spark session's
             # behavior by terminating the Spark Connect server,
             # meaning that you can stop local mode, and restart the Spark Connect
             # client with a different remote address.
             if PySparkSession._activeSession is not None:
                 try:
+                    getattr(
+                        PySparkSession._activeSession._jvm,
+                        "org.apache.spark.sql.connect.common.config.ConnectCommon",
+                    ).setLocalAuthToken(None)
                     PySparkSession._activeSession.stop()
                 except Exception as e:
                     logger.warn(
@@ -1060,6 +1068,9 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
             overwrite_conf["spark.connect.grpc.binding.port"] = "0"
 
             origin_remote = os.environ.get("SPARK_REMOTE", None)
+            local_auth_token = str(uuid.uuid4())
+            SparkConnectClient._local_auth_token = local_auth_token
+            os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"] = local_auth_token
             try:
                 if origin_remote is not None:
                     # So SparkSubmit thinks no remote is set in order to
@@ -1072,6 +1083,13 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
                 conf.setAll(list(overwrite_conf.items())).setAll(list(default_conf.items()))
                 PySparkSession(SparkContext.getOrCreate(conf))
 
+                # In Python local mode, session.stop does not terminate JVM itself
+                # so we can't control it via environment variable.
+                getattr(
+                    SparkContext._jvm,
+                    "org.apache.spark.sql.connect.common.config.ConnectCommon",
+                ).setLocalAuthToken(local_auth_token)
+
                 # Lastly only keep runtime configurations because other configurations are
                 # disallowed to set in the regular Spark Connect session.
                 utl = SparkContext._jvm.PythonSQLUtils  # type: ignore[union-attr]
@@ -1084,6 +1102,7 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
                 if origin_remote is not None:
                     os.environ["SPARK_REMOTE"] = origin_remote
                 del os.environ["SPARK_LOCAL_CONNECT"]
+                del os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"]
         else:
             raise PySparkRuntimeError(
                 errorClass="SESSION_OR_CONTEXT_EXISTS",
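Note the token's lifecycle on the Python side: it is minted before the JVM starts, exported so the forked server process inherits it, pushed into the already-running JVM directly through ConnectCommon (stopping the Python session does not stop the JVM, so an environment variable alone cannot reach it there), and finally removed from os.environ once startup completes. A condensed Scala sketch of that launcher-side sequence (stand-in names, not Spark's):

    import java.util.UUID

    object TokenLifecycle {
      @volatile private var current: Option[String] = None

      def launch(command: Seq[String]): Process = {
        val token = UUID.randomUUID().toString // fresh token per launch
        current = Some(token)
        val pb = new ProcessBuilder(command: _*)
        // Only the child process sees the token: it never hits the
        // command line and is not a JVM-wide system property.
        pb.environment().put("SPARK_CONNECT_LOCAL_AUTH_TOKEN", token)
        pb.start()
      }

      def stop(): Unit = current = None // drop the token with the session
    }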

sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala

Lines changed: 54 additions & 30 deletions
@@ -48,6 +48,7 @@ import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toLiteral
 import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, SparkConnectClient, SparkResult}
 import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
 import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
+import org.apache.spark.sql.connect.common.config.ConnectCommon
 import org.apache.spark.sql.internal.{SessionState, SharedState, SqlApiConf}
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.types.StructType
@@ -627,6 +628,19 @@ class SparkSession private[sql] (
     }
     allocator.close()
     SparkSession.onSessionClose(this)
+    SparkSession.server.synchronized {
+      // It should be `None` always on stop.
+      ConnectCommon.setLocalAuthToken(null)
+      if (SparkSession.server.isDefined) {
+        // When local mode is in use, follow the regular Spark session's
+        // behavior by terminating the Spark Connect server,
+        // meaning that you can stop local mode, and restart the Spark Connect
+        // client with a different remote address.
+        new ProcessBuilder(SparkSession.maybeConnectStopScript.get.toString)
+          .start()
+        SparkSession.server = None
+      }
+    }
   }
 
   /** @inheritdoc */
@@ -679,6 +693,10 @@ object SparkSession extends SparkSessionCompanion with Logging {
   private val MAX_CACHED_SESSIONS = 100
   private val planIdGenerator = new AtomicLong
   private var server: Option[Process] = None
+  private val maybeConnectStartScript =
+    Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))
+  private val maybeConnectStopScript =
+    Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "stop-connect-server.sh"))
   private[sql] val sparkOptions = sys.props.filter { p =>
     p._1.startsWith("spark.") && p._2.nonEmpty
   }.toMap
@@ -712,37 +730,43 @@ object SparkSession extends SparkSessionCompanion with Logging {
       }
     }
 
-    val maybeConnectScript =
-      Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))
-
-    if (server.isEmpty &&
-      (remoteString.exists(_.startsWith("local")) ||
-      (remoteString.isDefined && isAPIModeConnect)) &&
-      maybeConnectScript.exists(Files.exists(_))) {
-      server = Some {
-        val args =
-          Seq(maybeConnectScript.get.toString, "--master", remoteString.get) ++ sparkOptions
-            .filter(p => !p._1.startsWith("spark.remote"))
-            .flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
-        val pb = new ProcessBuilder(args: _*)
-        // So don't exclude spark-sql jar in classpath
-        pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
-        pb.start()
-      }
-
-      // Let the server start. We will directly request to set the configurations
-      // and this sleep makes less noisy with retries.
-      Thread.sleep(2000L)
-      System.setProperty("spark.remote", "sc://localhost")
-
-      // scalastyle:off runtimeaddshutdownhook
-      Runtime.getRuntime.addShutdownHook(new Thread() {
-        override def run(): Unit = if (server.isDefined) {
-          new ProcessBuilder(maybeConnectScript.get.toString)
-            .start()
+    server.synchronized {
+      if (server.isEmpty &&
+        (remoteString.exists(_.startsWith("local")) ||
+        (remoteString.isDefined && isAPIModeConnect)) &&
+        maybeConnectStartScript.exists(Files.exists(_))) {
+        val localAuthToken = java.util.UUID.randomUUID().toString()
+        server = Some {
+          ConnectCommon.setLocalAuthToken(localAuthToken)
+          val args =
+            Seq(
+              maybeConnectStartScript.get.toString,
+              "--master",
+              remoteString.get) ++ sparkOptions
+              .filter(p => !p._1.startsWith("spark.remote"))
+              .flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
+          val pb = new ProcessBuilder(args: _*)
+          // So don't exclude spark-sql jar in classpath
+          pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
+          pb.environment()
+            .put(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME, localAuthToken)
+          pb.start()
         }
-      })
-      // scalastyle:on runtimeaddshutdownhook
+
+        // Let the server start. We will directly request to set the configurations
+        // and this sleep makes less noisy with retries.
+        Thread.sleep(2000L)
+        System.setProperty("spark.remote", "sc://localhost")
+
+        // scalastyle:off runtimeaddshutdownhook
+        Runtime.getRuntime.addShutdownHook(new Thread() {
+          override def run(): Unit = if (server.isDefined) {
+            new ProcessBuilder(maybeConnectStopScript.get.toString)
+              .start()
+          }
+        })
+        // scalastyle:on runtimeaddshutdownhook
+      }
     }
   }
   f
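Besides threading the token through, this rewrite wraps the whole check-and-launch sequence in server.synchronized, so two threads creating sessions concurrently cannot both spawn a local server, and it splits the single maybeConnectScript into explicit start and stop scripts so stop() can terminate the server it started. The start-once shape, reduced to a sketch (not Spark's code):

    object LocalServer {
      private var server: Option[Process] = None

      // Check and launch under one lock: concurrent callers observe either
      // "no server" or "server started", never a duplicate launch.
      def ensureStarted(launch: () => Process): Unit = synchronized {
        if (server.isEmpty) {
          server = Some(launch())
        }
      }

      def stop(): Unit = synchronized {
        server.foreach(_.destroy()) // the real code runs stop-connect-server.sh
        server = None
      }
    }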

sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala

Lines changed: 8 additions & 2 deletions
@@ -404,7 +404,7 @@ object SparkConnectClient {
   private val DEFAULT_USER_AGENT: String = "_SPARK_CONNECT_SCALA"
 
   private val AUTH_TOKEN_META_DATA_KEY: Metadata.Key[String] =
-    Metadata.Key.of("Authentication", Metadata.ASCII_STRING_MARSHALLER)
+    ConnectCommon.AUTH_TOKEN_META_DATA_KEY
 
   private val AUTH_TOKEN_ON_INSECURE_CONN_ERROR_MSG: String =
     "Authentication token cannot be passed over insecure connections. " +
@@ -422,7 +422,13 @@
    * port or a NameResolver-compliant URI connection string.
    */
   class Builder(private var _configuration: Configuration) {
-    def this() = this(Configuration())
+    def this() = this {
+      ConnectCommon.getLocalAuthToken
+        .map { _ =>
+          Configuration(token = ConnectCommon.getLocalAuthToken, isSslEnabled = Some(true))
+        }
+        .getOrElse(Configuration())
+    }
 
     def configuration: Configuration = _configuration
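The no-argument Builder constructor now defaults to a token-bearing configuration whenever a local token is present, with isSslEnabled forced on because, per AUTH_TOKEN_ON_INSECURE_CONN_ERROR_MSG above, a token must not be passed over an insecure connection. The selection logic, reduced to a sketch (Configuration here is a stand-in case class, not Spark's):

    object DefaultConfig {
      case class Configuration(
          token: Option[String] = None,
          isSslEnabled: Option[Boolean] = None)

      // Prefer a token-bearing, SSL-enabled configuration when a local
      // token exists; otherwise fall back to the plain default.
      def defaultConfiguration(localToken: Option[String]): Configuration =
        localToken
          .map(t => Configuration(token = Some(t), isSslEnabled = Some(true)))
          .getOrElse(Configuration())
    }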

sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/config/ConnectCommon.scala

Lines changed: 14 additions & 0 deletions
@@ -16,9 +16,23 @@
  */
 package org.apache.spark.sql.connect.common.config
 
+import io.grpc.Metadata
+
 private[sql] object ConnectCommon {
   val CONNECT_GRPC_BINDING_PORT: Int = 15002
   val CONNECT_GRPC_PORT_MAX_RETRIES: Int = 0
   val CONNECT_GRPC_MAX_MESSAGE_SIZE: Int = 128 * 1024 * 1024
   val CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT: Int = 1024
+  val AUTH_TOKEN_META_DATA_KEY: Metadata.Key[String] =
+    Metadata.Key.of("Authentication", Metadata.ASCII_STRING_MARSHALLER)
+
+  val CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME = "SPARK_CONNECT_LOCAL_AUTH_TOKEN"
+  private var CONNECT_LOCAL_AUTH_TOKEN: Option[String] = Option(
+    System.getenv(CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME))
+  def setLocalAuthToken(token: String): Unit = CONNECT_LOCAL_AUTH_TOKEN = Option(token)
+  def getLocalAuthToken: Option[String] = CONNECT_LOCAL_AUTH_TOKEN
+  def assertLocalAuthToken(token: Option[String]): Unit = token.foreach { t =>
+    assert(CONNECT_LOCAL_AUTH_TOKEN.isDefined)
+    assert(t.substring("Bearer ".length) == CONNECT_LOCAL_AUTH_TOKEN.get)
+  }
 }
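assertLocalAuthToken only runs when the caller actually supplied a header: it strips the fixed "Bearer " prefix and compares the remainder against the stored token. A standalone version of that comparison, with an explicit prefix check added for illustration (the code above asserts instead of returning a Boolean):

    object BearerCheck {
      // Header values arrive as "Bearer <token>".
      def checkBearer(header: String, expected: String): Boolean =
        header.startsWith("Bearer ") &&
          header.substring("Bearer ".length) == expected
    }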
sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/LocalAuthInterceptor.scala

Lines changed: 36 additions & 0 deletions

@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.service
+
+import io.grpc.{Metadata, ServerCall, ServerCallHandler, ServerInterceptor}
+
+import org.apache.spark.sql.connect.common.config.ConnectCommon
+
+/**
+ * A gRPC interceptor to check if the header contains token for authentication.
+ */
+class LocalAuthInterceptor extends ServerInterceptor {
+  override def interceptCall[ReqT, RespT](
+      call: ServerCall[ReqT, RespT],
+      headers: Metadata,
+      next: ServerCallHandler[ReqT, RespT]): ServerCall.Listener[ReqT] = {
+    ConnectCommon.assertLocalAuthToken(
+      Option(headers.get(ConnectCommon.AUTH_TOKEN_META_DATA_KEY)))
+    next.startCall(call, headers)
+  }
+}
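The interceptor runs before every RPC handler on the server and asserts the token from the Authentication header. A common alternative shape in gRPC servers is to close the call with Status.UNAUTHENTICATED instead of asserting; a self-contained sketch of that variant (not Spark's code), including how an interceptor is attached:

    import io.grpc.{Metadata, ServerCall, ServerCallHandler, ServerInterceptor, Status}

    class RejectingAuthInterceptor(expected: String) extends ServerInterceptor {
      private val key =
        Metadata.Key.of("Authentication", Metadata.ASCII_STRING_MARSHALLER)

      override def interceptCall[ReqT, RespT](
          call: ServerCall[ReqT, RespT],
          headers: Metadata,
          next: ServerCallHandler[ReqT, RespT]): ServerCall.Listener[ReqT] = {
        val ok = Option(headers.get(key)).contains(s"Bearer $expected")
        if (ok) {
          next.startCall(call, headers)
        } else {
          call.close(Status.UNAUTHENTICATED.withDescription("invalid token"), new Metadata())
          new ServerCall.Listener[ReqT] {} // no-op listener for the rejected call
        }
      }
    }

    // Attaching it typically looks like:
    //   ServerBuilder.forPort(port).addService(svc).intercept(new RejectingAuthInterceptor(token))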

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala

Lines changed: 7 additions & 2 deletions
@@ -39,7 +39,8 @@ import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys.HOST
 import org.apache.spark.internal.config.UI.UI_ENABLED
 import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerEvent}
-import org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_BINDING_ADDRESS, CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT, CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE, CONNECT_GRPC_PORT_MAX_RETRIES}
+import org.apache.spark.sql.connect.common.config.ConnectCommon
+import org.apache.spark.sql.connect.config.Connect._
 import org.apache.spark.sql.connect.execution.ConnectProgressExecutionListener
 import org.apache.spark.sql.connect.ui.{SparkConnectServerAppStatusStore, SparkConnectServerListener, SparkConnectServerTab}
 import org.apache.spark.sql.connect.utils.ErrorUtils
@@ -366,10 +367,14 @@
     val debugMode = SparkEnv.get.conf.getBoolean("spark.connect.grpc.debug.enabled", true)
     val bindAddress = SparkEnv.get.conf.get(CONNECT_GRPC_BINDING_ADDRESS)
     val startPort = SparkEnv.get.conf.get(CONNECT_GRPC_BINDING_PORT)
+    val localAuthToken = System.getenv(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME)
+    ConnectCommon.setLocalAuthToken(localAuthToken)
     val sparkConnectService = new SparkConnectService(debugMode)
     val protoReflectionService =
       if (debugMode) Some(ProtoReflectionService.newInstance()) else None
-    val configuredInterceptors = SparkConnectInterceptorRegistry.createConfiguredInterceptors()
+    val configuredInterceptors =
+      SparkConnectInterceptorRegistry.createConfiguredInterceptors() ++
+        (if (localAuthToken != null) Seq(new LocalAuthInterceptor()) else Nil)
 
     val startServiceFn = (port: Int) => {
       val sb = bindAddress match {
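The check is strictly opt-in: the interceptor is appended only when the server process was launched with SPARK_CONNECT_LOCAL_AUTH_TOKEN in its environment, so ordinary deployments keep their configured interceptor chain unchanged. The wiring, reduced to a sketch (stand-in signature, not Spark's):

    import io.grpc.ServerInterceptor

    object InterceptorWiring {
      // Append the auth interceptor only when a local token was provided.
      def buildInterceptors(
          configured: Seq[ServerInterceptor],
          localAuthToken: Option[String],
          authInterceptor: () => ServerInterceptor): Seq[ServerInterceptor] =
        configured ++ localAuthToken.map(_ => authInterceptor()).toSeq
    }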
