Commit 49bdde7

Provide a basic authentication token when running Spark Connect server locally
1 parent 53c326b commit 49bdde7

8 files changed: +123 −33 lines

common/utils/src/main/resources/error/error-conditions.json

Lines changed: 5 additions & 0 deletions
@@ -9304,6 +9304,11 @@
       "No enough memory for aggregation"
     ]
   },
+  "_LEGACY_ERROR_TEMP_3303" : {
+    "message" : [
+      "Invalid token."
+    ]
+  },
   "_LEGACY_ERROR_USER_RAISED_EXCEPTION" : {
     "message" : [
       "<errorMessage>"

python/pyspark/sql/connect/client/core.py

Lines changed: 8 additions & 2 deletions
@@ -20,12 +20,11 @@
     "SparkConnectClient",
 ]
 
-import atexit
-
 from pyspark.sql.connect.utils import check_dependencies
 
 check_dependencies(__name__)
 
+import atexit
 import logging
 import threading
 import os
@@ -125,6 +124,7 @@ class ChannelBuilder:
     PARAM_USER_ID = "user_id"
     PARAM_USER_AGENT = "user_agent"
     PARAM_SESSION_ID = "session_id"
+    CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME = "local_token"
 
     GRPC_MAX_MESSAGE_LENGTH_DEFAULT = 128 * 1024 * 1024
 
@@ -591,6 +591,9 @@ class SparkConnectClient(object):
     Conceptually the remote spark session that communicates with the server
     """
 
+    _local_auth_token = os.environ.get("SPARK_CONNECT_LOCAL_AUTH_TOKEN", str(uuid.uuid4()))
+    os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"] = _local_auth_token
+
     def __init__(
         self,
         connection: Union[str, ChannelBuilder],
@@ -637,6 +640,9 @@ def __init__(
             if isinstance(connection, ChannelBuilder)
             else DefaultChannelBuilder(connection, channel_options)
         )
+        self._builder.set(
+            ChannelBuilder.CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME,
+            SparkConnectClient._local_auth_token)
         self._user_id = None
         self._retry_policies: List[RetryPolicy] = []
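
This hunk only records the token as a channel-builder parameter; turning local_token into a request header happens in the channel plumbing, which this diff does not show. As a hedged sketch of what that client-side half could look like in grpc-java terms (the class and its wiring are hypothetical; only the local_token key and the "Bearer " prefix come from this commit):

import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, ForwardingClientCall, Metadata, MethodDescriptor}

// Hypothetical client-side interceptor: attaches "local_token: Bearer <token>"
// to every outgoing call, matching what LocalAuthInterceptor (added below)
// expects to find on the server side.
class LocalTokenClientInterceptor(token: String) extends ClientInterceptor {
  private val key =
    Metadata.Key.of("local_token", Metadata.ASCII_STRING_MARSHALLER)

  override def interceptCall[ReqT, RespT](
      method: MethodDescriptor[ReqT, RespT],
      callOptions: CallOptions,
      next: Channel): ClientCall[ReqT, RespT] = {
    new ForwardingClientCall.SimpleForwardingClientCall[ReqT, RespT](
      next.newCall(method, callOptions)) {
      override def start(
          responseListener: ClientCall.Listener[RespT],
          headers: Metadata): Unit = {
        headers.put(key, s"Bearer $token") // server strips the "Bearer " prefix
        super.start(responseListener, headers)
      }
    }
  }
}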

python/pyspark/sql/connect/session.py

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@
     TYPE_CHECKING,
     ClassVar,
 )
+import uuid
 
 import numpy as np
 import pandas as pd

sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala

Lines changed: 52 additions & 30 deletions
@@ -48,6 +48,7 @@ import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toLiteral
 import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, SparkConnectClient, SparkResult}
 import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
 import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
+import org.apache.spark.sql.connect.common.config.ConnectCommon
 import org.apache.spark.sql.internal.{SessionState, SharedState, SqlApiConf}
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.types.StructType
@@ -627,6 +628,17 @@ class SparkSession private[sql] (
     }
     allocator.close()
     SparkSession.onSessionClose(this)
+    SparkSession.server.synchronized {
+      if (SparkSession.server.isDefined) {
+        // When local mode is in use, follow the regular Spark session's
+        // behavior by terminating the Spark Connect server,
+        // meaning that you can stop local mode, and restart the Spark Connect
+        // client with a different remote address.
+        new ProcessBuilder(SparkSession.maybeConnectStopScript.get.toString)
+          .start()
+        SparkSession.server = None
+      }
+    }
   }
 
   /** @inheritdoc */
@@ -679,6 +691,10 @@ object SparkSession extends SparkSessionCompanion with Logging {
   private val MAX_CACHED_SESSIONS = 100
   private val planIdGenerator = new AtomicLong
   private var server: Option[Process] = None
+  private val maybeConnectStartScript =
+    Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))
+  private val maybeConnectStopScript =
+    Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "stop-connect-server.sh"))
   private[sql] val sparkOptions = sys.props.filter { p =>
     p._1.startsWith("spark.") && p._2.nonEmpty
   }.toMap
@@ -712,37 +728,43 @@ object SparkSession extends SparkSessionCompanion with Logging {
       }
     }
 
-    val maybeConnectScript =
-      Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))
-
-    if (server.isEmpty &&
-      (remoteString.exists(_.startsWith("local")) ||
-      (remoteString.isDefined && isAPIModeConnect)) &&
-      maybeConnectScript.exists(Files.exists(_))) {
-      server = Some {
-        val args =
-          Seq(maybeConnectScript.get.toString, "--master", remoteString.get) ++ sparkOptions
-            .filter(p => !p._1.startsWith("spark.remote"))
-            .flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
-        val pb = new ProcessBuilder(args: _*)
-        // So don't exclude spark-sql jar in classpath
-        pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
-        pb.start()
-      }
-
-      // Let the server start. We will directly request to set the configurations
-      // and this sleep makes less noisy with retries.
-      Thread.sleep(2000L)
-      System.setProperty("spark.remote", "sc://localhost")
-
-      // scalastyle:off runtimeaddshutdownhook
-      Runtime.getRuntime.addShutdownHook(new Thread() {
-        override def run(): Unit = if (server.isDefined) {
-          new ProcessBuilder(maybeConnectScript.get.toString)
-            .start()
-        }
-      })
-      // scalastyle:on runtimeaddshutdownhook
+    server.synchronized {
+      if (server.isEmpty &&
+        (remoteString.exists(_.startsWith("local")) ||
+        (remoteString.isDefined && isAPIModeConnect)) &&
+        maybeConnectStartScript.exists(Files.exists(_))) {
+        server = Some {
+          val args =
+            Seq(
+              maybeConnectStartScript.get.toString,
+              "--master",
+              remoteString.get) ++ sparkOptions
+              .filter(p => !p._1.startsWith("spark.remote"))
+              .flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
+          val pb = new ProcessBuilder(args: _*)
+          // So don't exclude spark-sql jar in classpath
+          pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
+          pb.environment()
+            .put(
+              ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME,
+              ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN)
+          pb.start()
+        }
+
+        // Let the server start. We will directly request to set the configurations
+        // and this sleep makes less noisy with retries.
+        Thread.sleep(2000L)
+        System.setProperty("spark.remote", "sc://localhost")
+
+        // scalastyle:off runtimeaddshutdownhook
+        Runtime.getRuntime.addShutdownHook(new Thread() {
+          override def run(): Unit = if (server.isDefined) {
+            new ProcessBuilder(maybeConnectStopScript.get.toString)
+              .start()
+          }
+        })
+        // scalastyle:on runtimeaddshutdownhook
+      }
     }
   }
   f
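
Taken together: when the remote string is local (or the API mode is Connect), the builder launches start-connect-server.sh with the shared token exported into the child's environment, and both stop() and the shutdown hook now run stop-connect-server.sh. A hedged usage sketch, assuming SPARK_HOME points at a distribution that contains the sbin scripts:

import org.apache.spark.sql.connect.SparkSession

// Run with e.g. SPARK_REMOTE="local[*]" (or --conf spark.remote=local[*]):
// getOrCreate() spawns the Connect server itself, and the child inherits
// SPARK_CONNECT_LOCAL_AUTH_TOKEN so both sides agree on the token.
val spark = SparkSession.builder().getOrCreate()
spark.sql("SELECT 1").show()

// stop() now also invokes stop-connect-server.sh and clears the cached
// server handle, so a later session can target a different remote address.
spark.stop()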

sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala

Lines changed: 3 additions & 0 deletions
@@ -620,6 +620,9 @@ object SparkConnectClient {
    * Configure the builder using the env SPARK_REMOTE environment variable.
    */
   def loadFromEnvironment(): Builder = {
+    option(
+      ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME,
+      ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN)
     lazy val isAPIModeConnect =
       Option(System.getProperty(org.apache.spark.sql.SparkSessionBuilder.API_MODE_KEY))
         .getOrElse("classic")
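
Since the token option is set before SPARK_REMOTE is parsed, any client configured from the environment carries the token automatically. A small usage sketch:

import org.apache.spark.sql.connect.client.SparkConnectClient

// Build a client from the environment; loadFromEnvironment() now seeds the
// local_token option from ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN first.
val client = SparkConnectClient.builder()
  .loadFromEnvironment()
  .build()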

sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/config/ConnectCommon.scala

Lines changed: 6 additions & 0 deletions
@@ -21,4 +21,10 @@ private[sql] object ConnectCommon {
   val CONNECT_GRPC_PORT_MAX_RETRIES: Int = 0
   val CONNECT_GRPC_MAX_MESSAGE_SIZE: Int = 128 * 1024 * 1024
   val CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT: Int = 1024
+
+  val CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME = "local_token"
+  val CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME = "SPARK_CONNECT_LOCAL_AUTH_TOKEN"
+  val CONNECT_LOCAL_AUTH_TOKEN: String =
+    Option(System.getenv(CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME))
+      .getOrElse(java.util.UUID.randomUUID().toString())
 }
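
CONNECT_LOCAL_AUTH_TOKEN is a val, so it is resolved once per JVM: an externally supplied SPARK_CONNECT_LOCAL_AUTH_TOKEN wins, otherwise a fresh UUID is generated. The handshake works only because the parent exports its resolved token before spawning the server, as the SparkSession change above does; a condensed sketch of that propagation (the script path is illustrative):

import org.apache.spark.sql.connect.common.config.ConnectCommon

// Parent side: pin the resolved token into the child's environment so the
// child's getOrElse branch never generates a second, mismatching UUID.
val pb = new ProcessBuilder("/path/to/sbin/start-connect-server.sh")
pb.environment().put(
  ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME,
  ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN)
pb.start()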
sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/LocalAuthInterceptor.scala

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.service
+
+import io.grpc.{Metadata, ServerCall, ServerCallHandler, ServerInterceptor}
+
+import org.apache.spark.SparkSecurityException
+import org.apache.spark.sql.connect.common.config.ConnectCommon
+
+/**
+ * A gRPC interceptor to check if the header contains token for local authentication.
+ */
+class LocalAuthInterceptor(localToken: String) extends ServerInterceptor {
+  override def interceptCall[ReqT, RespT](
+      call: ServerCall[ReqT, RespT],
+      headers: Metadata,
+      next: ServerCallHandler[ReqT, RespT]): ServerCall.Listener[ReqT] = {
+    val t = Option(
+      headers.get(Metadata.Key
+        .of(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME, Metadata.ASCII_STRING_MARSHALLER)))
+      .map(_.substring("Bearer ".length))
+    if (t.isEmpty || t.get != localToken) {
+      throw new SparkSecurityException(
+        errorClass = "_LEGACY_ERROR_TEMP_3303",
+        messageParameters = Map.empty)
+    }
+    next.startCall(call, headers)
+  }
+}
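
For context, a hedged sketch of attaching the interceptor to a gRPC server by hand; SparkConnectService below installs it through the interceptor registry instead, and serviceDefinition here is a stand-in for the real service binding:

import io.grpc.{ServerInterceptors, ServerServiceDefinition}
import io.grpc.netty.NettyServerBuilder
import org.apache.spark.sql.connect.common.config.ConnectCommon

// serviceDefinition is hypothetical; 15002 is the default Spark Connect port.
def startSecuredServer(serviceDefinition: ServerServiceDefinition): Unit = {
  val interceptor = new LocalAuthInterceptor(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN)
  NettyServerBuilder
    .forPort(15002)
    .addService(ServerInterceptors.intercept(serviceDefinition, interceptor))
    .build()
    .start()
}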

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala

Lines changed: 4 additions & 1 deletion
@@ -39,6 +39,7 @@ import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys.HOST
 import org.apache.spark.internal.config.UI.UI_ENABLED
 import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerEvent}
+import org.apache.spark.sql.connect.common.config.ConnectCommon
 import org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_BINDING_ADDRESS, CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT, CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE, CONNECT_GRPC_PORT_MAX_RETRIES}
 import org.apache.spark.sql.connect.execution.ConnectProgressExecutionListener
 import org.apache.spark.sql.connect.ui.{SparkConnectServerAppStatusStore, SparkConnectServerListener, SparkConnectServerTab}
@@ -369,7 +370,9 @@ object SparkConnectService extends Logging {
     val sparkConnectService = new SparkConnectService(debugMode)
     val protoReflectionService =
       if (debugMode) Some(ProtoReflectionService.newInstance()) else None
-    val configuredInterceptors = SparkConnectInterceptorRegistry.createConfiguredInterceptors()
+    val configuredInterceptors =
+      SparkConnectInterceptorRegistry.createConfiguredInterceptors() ++
+        ConnectCommon.localAuthToken.map(new LocalAuthInterceptor(_))
 
     val startServiceFn = (port: Int) => {
       val sb = bindAddress match {
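
Appending with ++ over an Option means the auth interceptor is installed only when a local token is defined, leaving other deployments untouched. A condensed sketch of the pattern (the empty base list stands in for the registry's result):

import io.grpc.ServerInterceptor

// The interceptor is appended only when a token is present.
val base: Seq[ServerInterceptor] = Seq.empty
val maybeToken: Option[String] = sys.env.get("SPARK_CONNECT_LOCAL_AUTH_TOKEN")
val all: Seq[ServerInterceptor] = base ++ maybeToken.map(new LocalAuthInterceptor(_))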
