Commit d368742

Provide a basic authentication token when running Spark Connect server locally

1 parent 42ecabf

6 files changed (+144, -35 lines)

python/pyspark/sql/connect/session.py

Lines changed: 26 additions & 1 deletion
@@ -40,6 +40,7 @@
     TYPE_CHECKING,
     ClassVar,
 )
+import uuid

 import numpy as np
 import pandas as pd
@@ -53,7 +54,7 @@
 from pyspark.sql.connect.dataframe import DataFrame
 from pyspark.sql.dataframe import DataFrame as ParentDataFrame
 from pyspark.sql.connect.logging import logger
-from pyspark.sql.connect.client import SparkConnectClient, DefaultChannelBuilder
+from pyspark.sql.connect.client import SparkConnectClient, DefaultChannelBuilder, ChannelBuilder
 from pyspark.sql.connect.conf import RuntimeConf
 from pyspark.sql.connect.plan import (
     SQL,
@@ -120,6 +121,7 @@ class SparkSession:
     # Reference to the root SparkSession
     _default_session: ClassVar[Optional["SparkSession"]] = None
     _lock: ClassVar[RLock] = RLock()
+    _local_auth_token = os.environ.get("SPARK_CONNECT_LOCAL_AUTH_TOKEN")

     class Builder:
         """Builder for :class:`SparkSession`."""
@@ -238,6 +240,11 @@ def create(self) -> "SparkSession":
             else:
                 spark_remote = to_str(self._options.get("spark.remote"))
                 assert spark_remote is not None
+                if SparkSession._local_auth_token is not None:
+                    spark_remote = DefaultChannelBuilder(spark_remote)  # type: ignore[assignment]
+                    spark_remote.set(  # type: ignore[attr-defined]
+                        ChannelBuilder.PARAM_TOKEN, SparkSession._local_auth_token
+                    )
             session = SparkSession(connection=spark_remote)

             SparkSession._set_default_and_active_session(session)
@@ -893,13 +900,20 @@ def stop(self) -> None:
         if self is getattr(SparkSession._active_session, "session", None):
             SparkSession._active_session.session = None

+        # It should be `None` always on stop.
+        SparkSession._local_auth_token = None
+
         if "SPARK_LOCAL_REMOTE" in os.environ:
             # When local mode is in use, follow the regular Spark session's
             # behavior by terminating the Spark Connect server,
             # meaning that you can stop local mode, and restart the Spark Connect
             # client with a different remote address.
             if PySparkSession._activeSession is not None:
                 try:
+                    getattr(
+                        PySparkSession._activeSession._jvm,
+                        "org.apache.spark.sql.connect.common.config.ConnectCommon",
+                    ).setLocalAuthToken(None)
                     PySparkSession._activeSession.stop()
                 except Exception as e:
                     logger.warn(
@@ -1060,6 +1074,9 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
             overwrite_conf["spark.connect.grpc.binding.port"] = "0"

             origin_remote = os.environ.get("SPARK_REMOTE", None)
+            local_auth_token = str(uuid.uuid4())
+            SparkSession._local_auth_token = local_auth_token
+            os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"] = local_auth_token
             try:
                 if origin_remote is not None:
                     # So SparkSubmit thinks no remote is set in order to
@@ -1072,6 +1089,13 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
                 conf.setAll(list(overwrite_conf.items())).setAll(list(default_conf.items()))
                 PySparkSession(SparkContext.getOrCreate(conf))

+                # In Python local mode, session.stop does not terminate JVM itself
+                # so we can't control it via environment variable.
+                getattr(
+                    SparkContext._jvm,
+                    "org.apache.spark.sql.connect.common.config.ConnectCommon",
+                ).setLocalAuthToken(local_auth_token)
+
                 # Lastly only keep runtime configurations because other configurations are
                 # disallowed to set in the regular Spark Connect session.
                 utl = SparkContext._jvm.PythonSQLUtils  # type: ignore[union-attr]
@@ -1084,6 +1108,7 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
                 if origin_remote is not None:
                     os.environ["SPARK_REMOTE"] = origin_remote
                 del os.environ["SPARK_LOCAL_CONNECT"]
+                del os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"]
         else:
             raise PySparkRuntimeError(
                 errorClass="SESSION_OR_CONTEXT_EXISTS",
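A note on the flow above: _start_connect_server mints a per-run UUID, hands it to the launched server through the SPARK_CONNECT_LOCAL_AUTH_TOKEN environment variable, and the create() path attaches the same token to the client channel via ChannelBuilder.PARAM_TOKEN. As a minimal Scala sketch of the client-side half, assuming the standard Spark Connect ";token=" connection-string parameter (the address below is illustrative):

// Sketch only: pick up a locally minted token, if present, and fold it
// into the connection string the client will use.
val remote = "sc://localhost:15002" // illustrative address
val localToken = Option(System.getenv("SPARK_CONNECT_LOCAL_AUTH_TOKEN"))
val connectionString = localToken match {
  case Some(t) => s"$remote/;token=$t" // token rides along as a URI parameter
  case None => remote // no local token: connect with the plain remote string
}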

sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala

Lines changed: 54 additions & 30 deletions
@@ -48,6 +48,7 @@ import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toLiteral
 import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, SparkConnectClient, SparkResult}
 import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
 import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
+import org.apache.spark.sql.connect.common.config.ConnectCommon
 import org.apache.spark.sql.internal.{SessionState, SharedState, SqlApiConf}
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.types.StructType
@@ -627,6 +628,19 @@ class SparkSession private[sql] (
     }
     allocator.close()
     SparkSession.onSessionClose(this)
+    SparkSession.server.synchronized {
+      // It should be `None` always on stop.
+      ConnectCommon.setLocalAuthToken(null)
+      if (SparkSession.server.isDefined) {
+        // When local mode is in use, follow the regular Spark session's
+        // behavior by terminating the Spark Connect server,
+        // meaning that you can stop local mode, and restart the Spark Connect
+        // client with a different remote address.
+        new ProcessBuilder(SparkSession.maybeConnectStopScript.get.toString)
+          .start()
+        SparkSession.server = None
+      }
+    }
   }

   /** @inheritdoc */
@@ -679,6 +693,10 @@ object SparkSession extends SparkSessionCompanion with Logging {
   private val MAX_CACHED_SESSIONS = 100
   private val planIdGenerator = new AtomicLong
   private var server: Option[Process] = None
+  private val maybeConnectStartScript =
+    Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))
+  private val maybeConnectStopScript =
+    Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "stop-connect-server.sh"))
   private[sql] val sparkOptions = sys.props.filter { p =>
     p._1.startsWith("spark.") && p._2.nonEmpty
   }.toMap
@@ -712,37 +730,43 @@
       }
     }

-    val maybeConnectScript =
-      Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))
-
-    if (server.isEmpty &&
-      (remoteString.exists(_.startsWith("local")) ||
-        (remoteString.isDefined && isAPIModeConnect)) &&
-      maybeConnectScript.exists(Files.exists(_))) {
-      server = Some {
-        val args =
-          Seq(maybeConnectScript.get.toString, "--master", remoteString.get) ++ sparkOptions
-            .filter(p => !p._1.startsWith("spark.remote"))
-            .flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
-        val pb = new ProcessBuilder(args: _*)
-        // So don't exclude spark-sql jar in classpath
-        pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
-        pb.start()
-      }
-
-      // Let the server start. We will directly request to set the configurations
-      // and this sleep makes less noisy with retries.
-      Thread.sleep(2000L)
-      System.setProperty("spark.remote", "sc://localhost")
-
-      // scalastyle:off runtimeaddshutdownhook
-      Runtime.getRuntime.addShutdownHook(new Thread() {
-        override def run(): Unit = if (server.isDefined) {
-          new ProcessBuilder(maybeConnectScript.get.toString)
-            .start()
+    server.synchronized {
+      if (server.isEmpty &&
+        (remoteString.exists(_.startsWith("local")) ||
+          (remoteString.isDefined && isAPIModeConnect)) &&
+        maybeConnectStartScript.exists(Files.exists(_))) {
+        val localAuthToken = java.util.UUID.randomUUID().toString()
+        server = Some {
+          ConnectCommon.setLocalAuthToken(localAuthToken)
+          val args =
+            Seq(
+              maybeConnectStartScript.get.toString,
+              "--master",
+              remoteString.get) ++ sparkOptions
+              .filter(p => !p._1.startsWith("spark.remote"))
+              .flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
+          val pb = new ProcessBuilder(args: _*)
+          // So don't exclude spark-sql jar in classpath
+          pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
+          pb.environment()
+            .put(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME, localAuthToken)
+          pb.start()
         }
-      })
-      // scalastyle:on runtimeaddshutdownhook
+
+        // Let the server start. We will directly request to set the configurations
+        // and this sleep makes less noisy with retries.
+        Thread.sleep(2000L)
+        System.setProperty("spark.remote", "sc://localhost")
+
+        // scalastyle:off runtimeaddshutdownhook
+        Runtime.getRuntime.addShutdownHook(new Thread() {
+          override def run(): Unit = if (server.isDefined) {
+            new ProcessBuilder(maybeConnectStopScript.get.toString)
+              .start()
+          }
+        })
+        // scalastyle:on runtimeaddshutdownhook
+      }
     }
   }
   f
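Besides the token handoff itself (pb.environment().put, so the secret never appears on the command line, where other local users could read it via ps), this hunk also serializes server startup under server.synchronized and points the shutdown hook at the new stop script. The start-once-under-lock shape, reduced to a generic sketch using plain JDK process APIs (names here are illustrative, not Spark's):

object LazyServerSketch {
  private var proc: Option[Process] = None

  // Start the child process at most once, even with concurrent callers.
  def ensureStarted(cmd: String): Unit = synchronized {
    if (proc.isEmpty) {
      proc = Some(new ProcessBuilder(cmd).start())
      // Best-effort cleanup when the JVM exits.
      Runtime.getRuntime.addShutdownHook(new Thread(() => stop()))
    }
  }

  def stop(): Unit = synchronized {
    proc.foreach(_.destroy())
    proc = None
  }
}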

sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala

Lines changed: 8 additions & 2 deletions
@@ -404,7 +404,7 @@ object SparkConnectClient {
   private val DEFAULT_USER_AGENT: String = "_SPARK_CONNECT_SCALA"

   private val AUTH_TOKEN_META_DATA_KEY: Metadata.Key[String] =
-    Metadata.Key.of("Authentication", Metadata.ASCII_STRING_MARSHALLER)
+    ConnectCommon.AUTH_TOKEN_META_DATA_KEY

   private val AUTH_TOKEN_ON_INSECURE_CONN_ERROR_MSG: String =
     "Authentication token cannot be passed over insecure connections. " +
@@ -422,7 +422,13 @@
    * port or a NameResolver-compliant URI connection string.
    */
   class Builder(private var _configuration: Configuration) {
-    def this() = this(Configuration())
+    def this() = this {
+      ConnectCommon.getLocalAuthToken
+        .map { _ =>
+          Configuration(token = ConnectCommon.getLocalAuthToken, isSslEnabled = Some(true))
+        }
+        .getOrElse(Configuration())
+    }

     def configuration: Configuration = _configuration

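With a local token present, the no-arg Builder now defaults to Configuration(token = ..., isSslEnabled = Some(true)). On the wire this amounts to sending "Bearer <token>" under the Authentication metadata key defined in ConnectCommon. A sketch of that attachment using stock grpc-java utilities (the channel target is illustrative, and this is not how SparkConnectClient builds its channel internally):

import io.grpc.{ManagedChannelBuilder, Metadata}
import io.grpc.stub.MetadataUtils

val token = sys.env.getOrElse("SPARK_CONNECT_LOCAL_AUTH_TOKEN", "dev-only-token")
val headers = new Metadata()
headers.put(
  Metadata.Key.of("Authentication", Metadata.ASCII_STRING_MARSHALLER),
  s"Bearer $token")

// Every RPC on this channel now carries the Authentication header.
val channel = ManagedChannelBuilder
  .forAddress("localhost", 15002) // illustrative target
  .intercept(MetadataUtils.newAttachHeadersInterceptor(headers))
  .build()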
sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/config/ConnectCommon.scala

Lines changed: 13 additions & 0 deletions
@@ -16,9 +16,22 @@
  */
 package org.apache.spark.sql.connect.common.config

+import io.grpc.Metadata
+
 private[sql] object ConnectCommon {
   val CONNECT_GRPC_BINDING_PORT: Int = 15002
   val CONNECT_GRPC_PORT_MAX_RETRIES: Int = 0
   val CONNECT_GRPC_MAX_MESSAGE_SIZE: Int = 128 * 1024 * 1024
   val CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT: Int = 1024
+  val AUTH_TOKEN_META_DATA_KEY: Metadata.Key[String] =
+    Metadata.Key.of("Authentication", Metadata.ASCII_STRING_MARSHALLER)
+
+  val CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME = "SPARK_CONNECT_LOCAL_AUTH_TOKEN"
+  private var CONNECT_LOCAL_AUTH_TOKEN: Option[String] = None
+  def setLocalAuthToken(token: String): Unit = CONNECT_LOCAL_AUTH_TOKEN = Option(token)
+  def getLocalAuthToken: Option[String] = CONNECT_LOCAL_AUTH_TOKEN
+  def assertLocalAuthToken(token: Option[String]): Unit = token.foreach { t =>
+    assert(CONNECT_LOCAL_AUTH_TOKEN.isDefined)
+    assert(t.substring("Bearer ".length) == CONNECT_LOCAL_AUTH_TOKEN.get)
+  }
 }
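Two properties of assertLocalAuthToken are worth spelling out: the check runs under token.foreach, so a request carrying no Authentication header passes through unvalidated, and substring("Bearer ".length) strips seven characters without verifying the prefix is actually "Bearer ". A small runnable sketch of those semantics (names are illustrative):

object TokenCheckSketch {
  private var expected: Option[String] = None
  def set(token: String): Unit = expected = Option(token)

  // Mirrors assertLocalAuthToken: only validates when a header is present.
  def check(header: Option[String]): Unit = header.foreach { h =>
    assert(expected.isDefined)
    assert(h.substring("Bearer ".length) == expected.get)
  }

  def main(args: Array[String]): Unit = {
    set("s3cr3t")
    check(Some("Bearer s3cr3t")) // passes
    check(None) // no header: the check is skipped entirely
  }
}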
sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/LocalAuthInterceptor.scala

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.service
+
+import io.grpc.{Metadata, ServerCall, ServerCallHandler, ServerInterceptor}
+
+import org.apache.spark.sql.connect.common.config.ConnectCommon
+
+/**
+ * A gRPC interceptor to check if the header contains token for authentication.
+ */
+class LocalAuthInterceptor extends ServerInterceptor {
+  override def interceptCall[ReqT, RespT](
+      call: ServerCall[ReqT, RespT],
+      headers: Metadata,
+      next: ServerCallHandler[ReqT, RespT]): ServerCall.Listener[ReqT] = {
+    ConnectCommon.assertLocalAuthToken(
+      Option(headers.get(ConnectCommon.AUTH_TOKEN_META_DATA_KEY)))
+    next.startCall(call, headers)
+  }
+}
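Because the interceptor delegates to assert, a bad token surfaces as a server-side AssertionError. A common grpc-java alternative is to reject the call explicitly with UNAUTHENTICATED; a sketch of that variant (not what this commit does):

import io.grpc.{Metadata, ServerCall, ServerCallHandler, ServerInterceptor, Status}

class RejectingAuthInterceptor(expected: String) extends ServerInterceptor {
  private val key = Metadata.Key.of("Authentication", Metadata.ASCII_STRING_MARSHALLER)

  override def interceptCall[ReqT, RespT](
      call: ServerCall[ReqT, RespT],
      headers: Metadata,
      next: ServerCallHandler[ReqT, RespT]): ServerCall.Listener[ReqT] = {
    Option(headers.get(key)) match {
      case Some(h) if h == s"Bearer $expected" =>
        next.startCall(call, headers)
      case _ =>
        // Close the call instead of throwing; the client observes UNAUTHENTICATED.
        call.close(Status.UNAUTHENTICATED.withDescription("invalid token"), new Metadata())
        new ServerCall.Listener[ReqT] {}
    }
  }
}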

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala

Lines changed: 7 additions & 2 deletions
@@ -39,7 +39,8 @@ import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys.HOST
 import org.apache.spark.internal.config.UI.UI_ENABLED
 import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerEvent}
-import org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_BINDING_ADDRESS, CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT, CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE, CONNECT_GRPC_PORT_MAX_RETRIES}
+import org.apache.spark.sql.connect.common.config.ConnectCommon
+import org.apache.spark.sql.connect.config.Connect._
 import org.apache.spark.sql.connect.execution.ConnectProgressExecutionListener
 import org.apache.spark.sql.connect.ui.{SparkConnectServerAppStatusStore, SparkConnectServerListener, SparkConnectServerTab}
 import org.apache.spark.sql.connect.utils.ErrorUtils
@@ -366,10 +367,14 @@ object SparkConnectService extends Logging {
     val debugMode = SparkEnv.get.conf.getBoolean("spark.connect.grpc.debug.enabled", true)
     val bindAddress = SparkEnv.get.conf.get(CONNECT_GRPC_BINDING_ADDRESS)
     val startPort = SparkEnv.get.conf.get(CONNECT_GRPC_BINDING_PORT)
+    val localAuthToken = System.getenv(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME)
+    ConnectCommon.setLocalAuthToken(localAuthToken)
     val sparkConnectService = new SparkConnectService(debugMode)
     val protoReflectionService =
       if (debugMode) Some(ProtoReflectionService.newInstance()) else None
-    val configuredInterceptors = SparkConnectInterceptorRegistry.createConfiguredInterceptors()
+    val configuredInterceptors =
+      SparkConnectInterceptorRegistry.createConfiguredInterceptors() ++
+        (if (localAuthToken != null) Seq(new LocalAuthInterceptor()) else Nil)

     val startServiceFn = (port: Int) => {
       val sb = bindAddress match {
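For context, configuredInterceptors is later handed to the gRPC server builder; with stock grpc-java, wrapping a service so every call passes through the interceptors looks roughly like this (a sketch, not the actual SparkConnectService wiring; the port is illustrative):

import io.grpc.ServerInterceptors
import io.grpc.netty.NettyServerBuilder

// Wrap the service so interceptors run before every call reaches it.
val wrapped = ServerInterceptors.intercept(sparkConnectService, configuredInterceptors: _*)
val grpcServer = NettyServerBuilder
  .forPort(15002) // illustrative port
  .addService(wrapped)
  .build()
  .start()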
