Commit 0558d14

Provide a basic authentication token when running Spark Connect server locally

1 parent 42ecabf  commit 0558d14

7 files changed: +145 -33 lines changed

python/pyspark/sql/connect/client/core.py

Lines changed: 7 additions & 0 deletions
@@ -125,6 +125,7 @@ class ChannelBuilder:
     PARAM_USER_ID = "user_id"
     PARAM_USER_AGENT = "user_agent"
     PARAM_SESSION_ID = "session_id"
+    CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME = "local_token"
 
     GRPC_MAX_MESSAGE_LENGTH_DEFAULT = 128 * 1024 * 1024
 
@@ -591,6 +592,8 @@ class SparkConnectClient(object):
     Conceptually the remote spark session that communicates with the server
     """
 
+    _local_auth_token = os.environ.get("SPARK_CONNECT_LOCAL_AUTH_TOKEN")
+
     def __init__(
         self,
         connection: Union[str, ChannelBuilder],
@@ -637,6 +640,10 @@ def __init__(
             if isinstance(connection, ChannelBuilder)
             else DefaultChannelBuilder(connection, channel_options)
         )
+        if SparkConnectClient._local_auth_token is not None:
+            self._builder.set(
+                ChannelBuilder.CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME,
+                SparkConnectClient._local_auth_token)
         self._user_id = None
         self._retry_policies: List[RetryPolicy] = []
 
python/pyspark/sql/connect/session.py

Lines changed: 19 additions & 0 deletions
@@ -40,6 +40,7 @@
     TYPE_CHECKING,
     ClassVar,
 )
+import uuid
 
 import numpy as np
 import pandas as pd
@@ -893,13 +894,20 @@ def stop(self) -> None:
             if self is getattr(SparkSession._active_session, "session", None):
                 SparkSession._active_session.session = None
 
+            # It should always be `None` on stop.
+            SparkConnectClient._local_auth_token = None
+
             if "SPARK_LOCAL_REMOTE" in os.environ:
                 # When local mode is in use, follow the regular Spark session's
                 # behavior by terminating the Spark Connect server,
                 # meaning that you can stop local mode, and restart the Spark Connect
                 # client with a different remote address.
                 if PySparkSession._activeSession is not None:
                     try:
+                        getattr(
+                            PySparkSession._activeSession._jvm,
+                            "org.apache.spark.sql.connect.common.config.ConnectCommon",
+                        ).setLocalAuthToken(None)
                         PySparkSession._activeSession.stop()
                     except Exception as e:
                         logger.warn(
@@ -1060,6 +1068,9 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
                     overwrite_conf["spark.connect.grpc.binding.port"] = "0"
 
                     origin_remote = os.environ.get("SPARK_REMOTE", None)
+                    local_auth_token = str(uuid.uuid4())
+                    SparkConnectClient._local_auth_token = local_auth_token
+                    os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"] = local_auth_token
                     try:
                         if origin_remote is not None:
                             # So SparkSubmit thinks no remote is set in order to
@@ -1072,6 +1083,13 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
                         conf.setAll(list(overwrite_conf.items())).setAll(list(default_conf.items()))
                         PySparkSession(SparkContext.getOrCreate(conf))
 
+                        # In Python local mode, session.stop does not terminate the JVM
+                        # itself, so we can't control it via the environment variable.
+                        getattr(
+                            SparkContext._jvm,
+                            "org.apache.spark.sql.connect.common.config.ConnectCommon",
+                        ).setLocalAuthToken(local_auth_token)
+
                         # Lastly only keep runtime configurations because other configurations are
                         # disallowed to set in the regular Spark Connect session.
                         utl = SparkContext._jvm.PythonSQLUtils  # type: ignore[union-attr]
@@ -1084,6 +1102,7 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
                         if origin_remote is not None:
                             os.environ["SPARK_REMOTE"] = origin_remote
                         del os.environ["SPARK_LOCAL_CONNECT"]
+                        del os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"]
         else:
             raise PySparkRuntimeError(
                 errorClass="SESSION_OR_CONTEXT_EXISTS",
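The lifecycle added here is: mint a fresh UUID per local server, publish it to the client (class attribute) and to the server (environment variable of the launched process), then clear it on stop. A compact Scala sketch of the same pattern, with hypothetical names (LocalTokenLifecycle is not in the commit):

    import java.util.UUID

    // Hypothetical condensation of the token lifecycle in this file.
    object LocalTokenLifecycle {
      @volatile private var token: Option[String] = None

      // Mint a token and hand it to a server process about to be launched.
      def launch(pb: ProcessBuilder): Process = {
        val t = UUID.randomUUID().toString
        token = Some(t)
        pb.environment().put("SPARK_CONNECT_LOCAL_AUTH_TOKEN", t) // server reads this at boot
        pb.start()
      }

      // On stop, drop the token so nothing can authenticate against a dead server.
      def stop(): Unit = token = None
    }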

sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala

Lines changed: 54 additions & 30 deletions
@@ -48,6 +48,7 @@ import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toLiteral
 import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, SparkConnectClient, SparkResult}
 import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
 import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
+import org.apache.spark.sql.connect.common.config.ConnectCommon
 import org.apache.spark.sql.internal.{SessionState, SharedState, SqlApiConf}
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.types.StructType
@@ -627,6 +628,19 @@ class SparkSession private[sql] (
     }
     allocator.close()
     SparkSession.onSessionClose(this)
+    SparkSession.server.synchronized {
+      // It should always be `None` on stop.
+      ConnectCommon.setLocalAuthToken(null)
+      if (SparkSession.server.isDefined) {
+        // When local mode is in use, follow the regular Spark session's
+        // behavior by terminating the Spark Connect server,
+        // meaning that you can stop local mode, and restart the Spark Connect
+        // client with a different remote address.
+        new ProcessBuilder(SparkSession.maybeConnectStopScript.get.toString)
+          .start()
+        SparkSession.server = None
+      }
+    }
   }
 
   /** @inheritdoc */
@@ -679,6 +693,10 @@ object SparkSession extends SparkSessionCompanion with Logging {
   private val MAX_CACHED_SESSIONS = 100
   private val planIdGenerator = new AtomicLong
   private var server: Option[Process] = None
+  private val maybeConnectStartScript =
+    Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))
+  private val maybeConnectStopScript =
+    Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "stop-connect-server.sh"))
   private[sql] val sparkOptions = sys.props.filter { p =>
     p._1.startsWith("spark.") && p._2.nonEmpty
   }.toMap
@@ -712,37 +730,43 @@ object SparkSession extends SparkSessionCompanion with Logging {
       }
     }
 
-    val maybeConnectScript =
-      Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))
-
-    if (server.isEmpty &&
-      (remoteString.exists(_.startsWith("local")) ||
-        (remoteString.isDefined && isAPIModeConnect)) &&
-      maybeConnectScript.exists(Files.exists(_))) {
-      server = Some {
-        val args =
-          Seq(maybeConnectScript.get.toString, "--master", remoteString.get) ++ sparkOptions
-            .filter(p => !p._1.startsWith("spark.remote"))
-            .flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
-        val pb = new ProcessBuilder(args: _*)
-        // So don't exclude spark-sql jar in classpath
-        pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
-        pb.start()
-      }
-
-      // Let the server start. We will directly request to set the configurations
-      // and this sleep makes less noisy with retries.
-      Thread.sleep(2000L)
-      System.setProperty("spark.remote", "sc://localhost")
-
-      // scalastyle:off runtimeaddshutdownhook
-      Runtime.getRuntime.addShutdownHook(new Thread() {
-        override def run(): Unit = if (server.isDefined) {
-          new ProcessBuilder(maybeConnectScript.get.toString)
-            .start()
-        }
-      })
-      // scalastyle:on runtimeaddshutdownhook
+    server.synchronized {
+      if (server.isEmpty &&
+        (remoteString.exists(_.startsWith("local")) ||
+          (remoteString.isDefined && isAPIModeConnect)) &&
+        maybeConnectStartScript.exists(Files.exists(_))) {
+        val localAuthToken = java.util.UUID.randomUUID().toString()
+        server = Some {
+          ConnectCommon.setLocalAuthToken(localAuthToken)
+          val args =
+            Seq(
+              maybeConnectStartScript.get.toString,
+              "--master",
+              remoteString.get) ++ sparkOptions
+              .filter(p => !p._1.startsWith("spark.remote"))
+              .flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
+          val pb = new ProcessBuilder(args: _*)
+          // So don't exclude spark-sql jar in classpath
+          pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
+          pb.environment()
+            .put(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME, localAuthToken)
+          pb.start()
+        }
+
+        // Let the server start. We will directly request to set the configurations,
+        // and this sleep makes the retries less noisy.
+        Thread.sleep(2000L)
+        System.setProperty("spark.remote", "sc://localhost")
+
+        // scalastyle:off runtimeaddshutdownhook
+        Runtime.getRuntime.addShutdownHook(new Thread() {
+          override def run(): Unit = if (server.isDefined) {
+            new ProcessBuilder(maybeConnectStopScript.get.toString)
+              .start()
+          }
+        })
+        // scalastyle:on runtimeaddshutdownhook
+      }
     }
   }
   f
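The restructuring above mainly wraps the check-then-launch in server.synchronized so concurrent callers cannot start two local servers, and moves the stop script into both close() and the shutdown hook. A stripped-down sketch of that start-at-most-once pattern (a dedicated lock object stands in for the server.synchronized in the diff):

    // Illustrative only: start a child process at most once, stop it idempotently.
    object LocalServerHandle {
      private val lock = new Object
      private var server: Option[Process] = None

      def ensureStarted(launch: () => Process): Unit = lock.synchronized {
        if (server.isEmpty) {
          server = Some(launch()) // only the first caller launches
        }
      }

      def stop(stopScript: String): Unit = lock.synchronized {
        if (server.isDefined) {
          new ProcessBuilder(stopScript).start() // graceful stop via the sbin script
          server = None
        }
      }
    }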

sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala

Lines changed: 2 additions & 0 deletions
@@ -620,6 +620,8 @@ object SparkConnectClient {
     * Configure the builder using the env SPARK_REMOTE environment variable.
     */
    def loadFromEnvironment(): Builder = {
+     ConnectCommon.getLocalAuthToken.foreach(t =>
+       option(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME, t))
      lazy val isAPIModeConnect =
        Option(System.getProperty(org.apache.spark.sql.SparkSessionBuilder.API_MODE_KEY))
          .getOrElse("classic")
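Because loadFromEnvironment() now consults ConnectCommon first, a client built inside the same JVM that launched the local server picks the token up with no extra configuration. A hypothetical usage sketch:

    import org.apache.spark.sql.connect.client.SparkConnectClient

    // In the JVM that started the local server, the token set through
    // ConnectCommon.setLocalAuthToken flows into the builder automatically.
    val client = SparkConnectClient.builder().loadFromEnvironment().build()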

sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/config/ConnectCommon.scala

Lines changed: 11 additions & 0 deletions
@@ -21,4 +21,15 @@ private[sql] object ConnectCommon {
   val CONNECT_GRPC_PORT_MAX_RETRIES: Int = 0
   val CONNECT_GRPC_MAX_MESSAGE_SIZE: Int = 128 * 1024 * 1024
   val CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT: Int = 1024
+
+  val CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME = "local_token"
+  val CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME = "SPARK_CONNECT_LOCAL_AUTH_TOKEN"
+  private var CONNECT_LOCAL_AUTH_TOKEN: Option[String] = Option(
+    System.getenv(CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME))
+  def setLocalAuthToken(token: String): Unit = CONNECT_LOCAL_AUTH_TOKEN = Option(token)
+  def getLocalAuthToken: Option[String] = CONNECT_LOCAL_AUTH_TOKEN
+  def assertLocalAuthToken(token: Option[String]): Unit = token.foreach { t =>
+    assert(CONNECT_LOCAL_AUTH_TOKEN.isDefined)
+    assert(t.substring("Bearer ".length) == CONNECT_LOCAL_AUTH_TOKEN.get)
+  }
 }
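The assertion expects the header value in the conventional "Bearer <token>" shape and strips the prefix before comparing; note that foreach makes a missing header a no-op, so the comparison only runs when a header is actually present. A small illustrative round trip (not part of the commit):

    // Illustrative: the contract assertLocalAuthToken enforces.
    ConnectCommon.setLocalAuthToken("abc123")
    ConnectCommon.assertLocalAuthToken(Some("Bearer abc123")) // passes
    ConnectCommon.assertLocalAuthToken(None)                  // no header, no check
    // Some("Bearer wrong") would trip the second assert.
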
sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/LocalAuthInterceptor.scala

Lines changed: 36 additions & 0 deletions

@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.service
+
+import io.grpc.{Metadata, ServerCall, ServerCallHandler, ServerInterceptor}
+
+import org.apache.spark.sql.connect.common.config.ConnectCommon
+
+/**
+ * A gRPC interceptor that checks whether the request headers contain the token for
+ * authentication.
+ */
+class LocalAuthInterceptor extends ServerInterceptor {
+  override def interceptCall[ReqT, RespT](
+      call: ServerCall[ReqT, RespT],
+      headers: Metadata,
+      next: ServerCallHandler[ReqT, RespT]): ServerCall.Listener[ReqT] = {
+    ConnectCommon.assertLocalAuthToken(Option(headers.get(Metadata.Key
+      .of(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME, Metadata.ASCII_STRING_MARSHALLER))))
+    next.startCall(call, headers)
+  }
+}
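Server interceptors like this one take effect once registered on the server builder, which SparkConnectService does below. A self-contained sketch of the registration using gRPC's stock APIs (the ephemeral port and reflection service here are placeholders, not the commit's wiring):

    import io.grpc.netty.NettyServerBuilder
    import io.grpc.protobuf.services.ProtoReflectionService

    // Illustrative: every call to this server passes through LocalAuthInterceptor
    // before reaching a handler.
    val server = NettyServerBuilder
      .forPort(0)                                      // ephemeral port
      .addService(ProtoReflectionService.newInstance())
      .intercept(new LocalAuthInterceptor())
      .build()
      .start()
    server.shutdownNow()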

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala

Lines changed: 16 additions & 3 deletions
@@ -17,18 +17,21 @@
 
 package org.apache.spark.sql.connect.service
 
+import java.io.ByteArrayInputStream
 import java.net.InetSocketAddress
+import java.nio.charset.StandardCharsets
 import java.util.concurrent.TimeUnit
 
 import scala.jdk.CollectionConverters._
 
 import com.google.protobuf.Message
 import io.grpc.{BindableService, MethodDescriptor, Server, ServerMethodDefinition, ServerServiceDefinition}
 import io.grpc.MethodDescriptor.PrototypeMarshaller
-import io.grpc.netty.NettyServerBuilder
+import io.grpc.netty.{GrpcSslContexts, NettyServerBuilder}
 import io.grpc.protobuf.ProtoUtils
 import io.grpc.protobuf.services.ProtoReflectionService
 import io.grpc.stub.StreamObserver
+import io.netty.handler.ssl.SslContext
 import org.apache.commons.lang3.StringUtils
 
 import org.apache.spark.{SparkContext, SparkEnv}
@@ -39,7 +42,8 @@ import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys.HOST
 import org.apache.spark.internal.config.UI.UI_ENABLED
 import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerEvent}
-import org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_BINDING_ADDRESS, CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT, CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE, CONNECT_GRPC_PORT_MAX_RETRIES}
+import org.apache.spark.sql.connect.common.config.ConnectCommon
+import org.apache.spark.sql.connect.config.Connect._
 import org.apache.spark.sql.connect.execution.ConnectProgressExecutionListener
 import org.apache.spark.sql.connect.ui.{SparkConnectServerAppStatusStore, SparkConnectServerListener, SparkConnectServerTab}
 import org.apache.spark.sql.connect.utils.ErrorUtils
@@ -366,10 +370,14 @@ object SparkConnectService extends Logging {
     val debugMode = SparkEnv.get.conf.getBoolean("spark.connect.grpc.debug.enabled", true)
     val bindAddress = SparkEnv.get.conf.get(CONNECT_GRPC_BINDING_ADDRESS)
     val startPort = SparkEnv.get.conf.get(CONNECT_GRPC_BINDING_PORT)
+    val localAuthToken = System.getenv(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME)
+    ConnectCommon.setLocalAuthToken(localAuthToken)
     val sparkConnectService = new SparkConnectService(debugMode)
     val protoReflectionService =
       if (debugMode) Some(ProtoReflectionService.newInstance()) else None
-    val configuredInterceptors = SparkConnectInterceptorRegistry.createConfiguredInterceptors()
+    val configuredInterceptors =
+      SparkConnectInterceptorRegistry.createConfiguredInterceptors() ++
+        (if (localAuthToken != null) Seq(new LocalAuthInterceptor()) else Nil)
 
     val startServiceFn = (port: Int) => {
       val sb = bindAddress match {
@@ -388,6 +396,11 @@ object SparkConnectService extends Logging {
       // grpcurl can introspect the API for debugging.
       protoReflectionService.foreach(service => sb.addService(service))
 
+      Option(localAuthToken).foreach { t =>
+        val token = new ByteArrayInputStream(t.getBytes(StandardCharsets.UTF_8))
+        sb.sslContext(GrpcSslContexts.forServer(token, token).build())
+      }
+
       server = sb.build
       server.start()
 
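To summarize the server-side boot sequence: read SPARK_CONNECT_LOCAL_AUTH_TOKEN from the environment, stash it in ConnectCommon so assertLocalAuthToken can see it, and install the interceptor only when a token is present, so a server started without the variable behaves exactly as before. A condensed sketch of that gate (sb is a stand-in builder, not the real code path; the real code also derives a TLS context from the token bytes via GrpcSslContexts.forServer):

    import io.grpc.netty.NettyServerBuilder
    import org.apache.spark.sql.connect.common.config.ConnectCommon
    import org.apache.spark.sql.connect.service.LocalAuthInterceptor

    // Stand-in for SparkConnectService's wiring at startup.
    val localAuthToken = System.getenv(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME)
    ConnectCommon.setLocalAuthToken(localAuthToken) // Option(null) == None when unset
    val sb = NettyServerBuilder.forPort(0)          // ephemeral port, as in local mode
    if (localAuthToken != null) {
      sb.intercept(new LocalAuthInterceptor())      // token check for every call
    }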
