
Commit 7e9547c

[SPARK-51156][CONNECT] Static token authentication support in Spark Connect
### What changes were proposed in this pull request?

Adds static token authentication support to Spark Connect, used by default for servers that are started automatically for local sessions.

### Why are the changes needed?

To add authentication support to Spark Connect, so that an automatically started Connect server is not inadvertently accessible to other users.

### Does this PR introduce _any_ user-facing change?

The local authentication is transparent to users, but users who start Connect servers manually gain the option to specify an authentication token.

### How was this patch tested?

Existing UTs.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #50006 from Kimahriman/spark-connect-local-auth.

Lead-authored-by: Adam Binford <[email protected]>
Co-authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
1 parent 30f4f4e commit 7e9547c

File tree: 15 files changed (+197, -33 lines)

python/pyspark/sql/connect/client/core.py

Lines changed: 19 additions & 10 deletions

```diff
@@ -220,7 +220,9 @@ def userId(self) -> Optional[str]:
 
     @property
     def token(self) -> Optional[str]:
-        return self._params.get(ChannelBuilder.PARAM_TOKEN, None)
+        return self._params.get(
+            ChannelBuilder.PARAM_TOKEN, os.environ.get("SPARK_CONNECT_AUTHENTICATE_TOKEN")
+        )
 
     def metadata(self) -> Iterable[Tuple[str, str]]:
         """
@@ -410,10 +412,11 @@ def _extract_attributes(self) -> None:
 
     @property
     def secure(self) -> bool:
-        return (
-            self.getDefault(ChannelBuilder.PARAM_USE_SSL, "").lower() == "true"
-            or self.token is not None
-        )
+        return self.use_ssl or self.token is not None
+
+    @property
+    def use_ssl(self) -> bool:
+        return self.getDefault(ChannelBuilder.PARAM_USE_SSL, "").lower() == "true"
 
     @property
     def host(self) -> str:
@@ -439,14 +442,20 @@ def toChannel(self) -> grpc.Channel:
 
         if not self.secure:
             return self._insecure_channel(self.endpoint)
+        elif not self.use_ssl and self._host == "localhost":
+            creds = grpc.local_channel_credentials()
+
+            if self.token is not None:
+                creds = grpc.composite_channel_credentials(
+                    creds, grpc.access_token_call_credentials(self.token)
+                )
+            return self._secure_channel(self.endpoint, creds)
         else:
-            ssl_creds = grpc.ssl_channel_credentials()
+            creds = grpc.ssl_channel_credentials()
 
-            if self.token is None:
-                creds = ssl_creds
-            else:
+            if self.token is not None:
                 creds = grpc.composite_channel_credentials(
-                    ssl_creds, grpc.access_token_call_credentials(self.token)
+                    creds, grpc.access_token_call_credentials(self.token)
                 )
 
             return self._secure_channel(self.endpoint, creds)
```
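For orientation, here is a minimal sketch (not part of the diff) of the channel the Python client now builds for a local, non-SSL server with a token; the endpoint and token values are illustrative:

```python
# Minimal sketch, assuming a local server started with
# SPARK_CONNECT_AUTHENTICATE_TOKEN=deadbeef; endpoint and token are examples.
import grpc

endpoint = "localhost:15002"
token = "deadbeef"

# Local transport credentials instead of TLS, composed with per-call
# bearer-token credentials so every RPC carries the token.
creds = grpc.composite_channel_credentials(
    grpc.local_channel_credentials(),
    grpc.access_token_call_credentials(token),
)
channel = grpc.secure_channel(endpoint, creds)
```

This is also why `secure` is now true whenever a token is present: call credentials must ride on a secure channel, which a plain insecure channel cannot provide.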

python/pyspark/sql/connect/session.py

Lines changed: 10 additions & 1 deletion

```diff
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import uuid
 from pyspark.sql.connect.utils import check_dependencies
 
 check_dependencies(__name__)
@@ -1030,6 +1031,8 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
 
     2. Starts a regular Spark session that automatically starts a Spark Connect server
        via ``spark.plugins`` feature.
+
+    Returns the authentication token that should be used to connect to this session.
     """
     from pyspark import SparkContext, SparkConf
 
@@ -1049,6 +1052,13 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
     if "spark.api.mode" in overwrite_conf:
         del overwrite_conf["spark.api.mode"]
 
+    # Check for a user-provided authentication token, creating a new one if not,
+    # and make sure it's set in the environment.
+    if "SPARK_CONNECT_AUTHENTICATE_TOKEN" not in os.environ:
+        os.environ["SPARK_CONNECT_AUTHENTICATE_TOKEN"] = opts.get(
+            "spark.connect.authenticate.token", str(uuid.uuid4())
+        )
+
     # Configurations to be set if unset.
     default_conf = {
         "spark.plugins": "org.apache.spark.sql.connect.SparkConnectPlugin",
@@ -1081,7 +1091,6 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
         new_opts = {k: opts[k] for k in opts if k in runtime_conf_keys}
         opts.clear()
         opts.update(new_opts)
-
     finally:
         if origin_remote is not None:
             os.environ["SPARK_REMOTE"] = origin_remote
```

python/pyspark/sql/tests/connect/test_connect_session.py

Lines changed: 17 additions & 1 deletion

```diff
@@ -43,6 +43,7 @@
 from pyspark.errors.exceptions.connect import (
     AnalysisException,
     SparkConnectException,
+    SparkConnectGrpcException,
     SparkUpgradeException,
 )
 
@@ -237,7 +238,13 @@ def test_custom_channel_builder(self):
 
         class CustomChannelBuilder(ChannelBuilder):
             def toChannel(self):
-                return self._insecure_channel(endpoint)
+                creds = grpc.local_channel_credentials()
+
+                if self.token is not None:
+                    creds = grpc.composite_channel_credentials(
+                        creds, grpc.access_token_call_credentials(self.token)
+                    )
+                return self._secure_channel(endpoint, creds)
 
         session = RemoteSparkSession.builder.channelBuilder(CustomChannelBuilder()).create()
         session.sql("select 1 + 1")
@@ -290,6 +297,15 @@ def test_api_mode(self):
         self.assertEqual(session.range(1).first()[0], 0)
         self.assertIsInstance(session, RemoteSparkSession)
 
+    def test_authentication(self):
+        # All servers start with a default token of "deadbeef", so supply an invalid one
+        session = RemoteSparkSession.builder.remote("sc://localhost/;token=invalid").create()
+
+        with self.assertRaises(SparkConnectGrpcException) as e:
+            session.range(3).collect()
+
+        self.assertTrue("Invalid authentication token" in str(e.exception))
+
 
 @unittest.skipIf(not should_test_connect, connect_requirement_message)
 class SparkConnectSessionWithOptionsTest(unittest.TestCase):
```
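For contrast with the failing case in `test_authentication`, a client that supplies the matching token connects normally; a usage sketch assuming a server configured with the test fixture's `spark.connect.authenticate.token=deadbeef`:

```python
from pyspark.sql import SparkSession

# Token matches the server's pre-shared secret, so RPCs succeed; an invalid
# token would raise SparkConnectGrpcException instead.
spark = SparkSession.builder.remote("sc://localhost/;token=deadbeef").getOrCreate()
spark.range(3).collect()
```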

python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py

Lines changed: 1 addition & 2 deletions

```diff
@@ -24,7 +24,6 @@
 import unittest
 from typing import cast
 
-from pyspark import SparkConf
 from pyspark.sql.streaming.state import GroupStateTimeout, GroupState
 from pyspark.sql.types import (
     LongType,
@@ -56,7 +55,7 @@
 class GroupedApplyInPandasWithStateTestsMixin:
     @classmethod
     def conf(cls):
-        cfg = SparkConf()
+        cfg = super().conf()
         cfg.set("spark.sql.shuffle.partitions", "5")
         return cfg
```

python/pyspark/testing/connectutils.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -155,6 +155,9 @@ def conf(cls):
         conf._jconf.remove("spark.master")
         conf.set("spark.connect.execute.reattachable.senderMaxStreamDuration", "1s")
         conf.set("spark.connect.execute.reattachable.senderMaxStreamSize", "123")
+        # Set a static token for all tests so the parallelism doesn't overwrite each
+        # test's environment variables
+        conf.set("spark.connect.authenticate.token", "deadbeef")
         return conf
 
     @classmethod
```

sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala

Lines changed: 1 addition & 1 deletion

```diff
@@ -125,7 +125,7 @@ class SparkConnectClientBuilderParseTestSuite extends ConnectFunSuite {
     assert(builder.host === "localhost")
     assert(builder.port === 15002)
     assert(builder.userAgent.contains("_SPARK_CONNECT_SCALA"))
-    assert(builder.sslEnabled)
+    assert(!builder.sslEnabled)
     assert(builder.token.contains("thisismysecret"))
     assert(builder.userId.isEmpty)
     assert(builder.userName.isEmpty)
```

sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala

Lines changed: 2 additions & 2 deletions

```diff
@@ -299,8 +299,8 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
       TestPackURI(s"sc://host:123/;session_id=${UUID.randomUUID().toString}", isCorrect = true),
       TestPackURI("sc://host:123/;use_ssl=true;token=mySecretToken", isCorrect = true),
       TestPackURI("sc://host:123/;token=mySecretToken;use_ssl=true", isCorrect = true),
-      TestPackURI("sc://host:123/;use_ssl=false;token=mySecretToken", isCorrect = false),
-      TestPackURI("sc://host:123/;token=mySecretToken;use_ssl=false", isCorrect = false),
+      TestPackURI("sc://host:123/;use_ssl=false;token=mySecretToken", isCorrect = true),
+      TestPackURI("sc://host:123/;token=mySecretToken;use_ssl=false", isCorrect = true),
       TestPackURI("sc://host:123/;param1=value1;param2=value2", isCorrect = true),
       TestPackURI(
         "sc://SPARK-45486",
```

sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala

Lines changed: 3 additions & 1 deletion

```diff
@@ -762,6 +762,7 @@ object SparkSession extends SparkSessionCompanion with Logging {
         (remoteString.exists(_.startsWith("local")) ||
           (remoteString.isDefined && isAPIModeConnect)) &&
         maybeConnectStartScript.exists(Files.exists(_))) {
+      val token = java.util.UUID.randomUUID().toString()
       val serverId = UUID.randomUUID().toString
       server = Some {
         val args =
@@ -779,6 +780,7 @@ object SparkSession extends SparkSessionCompanion with Logging {
         pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
         pb.environment().put("SPARK_IDENT_STRING", serverId)
         pb.environment().put("HOSTNAME", "local")
+        pb.environment().put("SPARK_CONNECT_AUTHENTICATE_TOKEN", token)
         pb.start()
       }
 
@@ -800,7 +802,7 @@ object SparkSession extends SparkSessionCompanion with Logging {
         }
       }
 
-      System.setProperty("spark.remote", "sc://localhost")
+      System.setProperty("spark.remote", s"sc://localhost/;token=$token")
 
       // scalastyle:off runtimeaddshutdownhook
       Runtime.getRuntime.addShutdownHook(new Thread() {
```

sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala

Lines changed: 15 additions & 12 deletions

```diff
@@ -468,20 +468,14 @@ object SparkConnectClient {
      * sc://localhost/;token=aaa;use_ssl=true
      * }}}
      *
-     * Throws exception if the token is set but use_ssl=false.
-     *
      * @param inputToken
      *   the user token.
      * @return
      *   this builder.
      */
     def token(inputToken: String): Builder = {
       require(inputToken != null && inputToken.nonEmpty)
-      if (_configuration.isSslEnabled.contains(false)) {
-        throw new IllegalArgumentException(AUTH_TOKEN_ON_INSECURE_CONN_ERROR_MSG)
-      }
-      _configuration =
-        _configuration.copy(token = Option(inputToken), isSslEnabled = Option(true))
+      _configuration = _configuration.copy(token = Option(inputToken))
       this
     }
 
@@ -499,7 +493,6 @@ object SparkConnectClient {
      * this builder.
      */
     def disableSsl(): Builder = {
-      require(token.isEmpty, AUTH_TOKEN_ON_INSECURE_CONN_ERROR_MSG)
       _configuration = _configuration.copy(isSslEnabled = Option(false))
       this
     }
@@ -737,6 +730,8 @@ object SparkConnectClient {
       grpcMaxMessageSize: Int = ConnectCommon.CONNECT_GRPC_MAX_MESSAGE_SIZE,
       grpcMaxRecursionLimit: Int = ConnectCommon.CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT) {
 
+    private def isLocal = host.equals("localhost")
+
     def userContext: proto.UserContext = {
       val builder = proto.UserContext.newBuilder()
       if (userId != null) {
@@ -749,7 +744,7 @@ object SparkConnectClient {
     }
 
     def credentials: ChannelCredentials = {
-      if (isSslEnabled.contains(true)) {
+      if (isSslEnabled.contains(true) || (token.isDefined && !isLocal)) {
         token match {
           case Some(t) =>
             // With access token added in the http header.
@@ -765,10 +760,18 @@ object SparkConnectClient {
     }
 
     def createChannel(): ManagedChannel = {
-      val channelBuilder = Grpc.newChannelBuilderForAddress(host, port, credentials)
+      val creds = credentials
+      val channelBuilder = Grpc.newChannelBuilderForAddress(host, port, creds)
+
+      // Workaround until LocalChannelCredentials are added in
+      // https://github.com/grpc/grpc-java/issues/9900
+      var metadataWithOptionalToken = metadata
+      if (!isSslEnabled.contains(true) && isLocal && token.isDefined) {
+        metadataWithOptionalToken = metadata + (("Authorization", s"Bearer ${token.get}"))
+      }
 
-      if (metadata.nonEmpty) {
-        channelBuilder.intercept(new MetadataHeaderClientInterceptor(metadata))
+      if (metadataWithOptionalToken.nonEmpty) {
+        channelBuilder.intercept(new MetadataHeaderClientInterceptor(metadataWithOptionalToken))
       }
 
       interceptors.foreach(channelBuilder.intercept(_))
```
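Because grpc-java lacks local channel credentials, the JVM client injects the token as a plain `Authorization` header instead. A rough Python analogue of that header-based approach (purely illustrative; the Python client actually uses `grpc.local_channel_credentials()`):

```python
import grpc

token = "deadbeef"  # assumed pre-shared secret
channel = grpc.insecure_channel("localhost:15002")

# Each RPC passes the token as explicit call metadata, mirroring the
# ("Authorization", "Bearer <token>") entry the interceptor injects.
call_metadata = (("authorization", f"Bearer {token}"),)
# stub.ExecutePlan(request, metadata=call_metadata)  # hypothetical stub call
```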

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala

Lines changed: 18 additions & 0 deletions

```diff
@@ -18,6 +18,7 @@ package org.apache.spark.sql.connect.config
 
 import java.util.concurrent.TimeUnit
 
+import org.apache.spark.SparkEnv
 import org.apache.spark.network.util.ByteUnit
 import org.apache.spark.sql.connect.common.config.ConnectCommon
 import org.apache.spark.sql.internal.SQLConf
@@ -313,4 +314,21 @@ object Connect {
       .internal()
       .booleanConf
       .createWithDefault(true)
+
+  val CONNECT_AUTHENTICATE_TOKEN =
+    buildStaticConf("spark.connect.authenticate.token")
+      .doc("A pre-shared token that will be used to authenticate clients. This secret must be" +
+        " passed as a bearer token for clients to connect.")
+      .version("4.0.0")
+      .internal()
+      .stringConf
+      .createOptional
+
+  val CONNECT_AUTHENTICATE_TOKEN_ENV = "SPARK_CONNECT_AUTHENTICATE_TOKEN"
+
+  def getAuthenticateToken: Option[String] = {
+    SparkEnv.get.conf.get(CONNECT_AUTHENTICATE_TOKEN).orElse {
+      Option(System.getenv.get(CONNECT_AUTHENTICATE_TOKEN_ENV))
+    }
+  }
 }
```
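The server resolves its token with the Spark conf taking precedence over the environment variable. A small Python sketch of the same precedence rule (lookup names mirror the diff):

```python
import os
from typing import Mapping, Optional

def get_authenticate_token(conf: Mapping[str, str]) -> Optional[str]:
    # spark.connect.authenticate.token wins; fall back to the env var.
    return conf.get("spark.connect.authenticate.token") or os.environ.get(
        "SPARK_CONNECT_AUTHENTICATE_TOKEN"
    )

print(get_authenticate_token({}))  # None unless the env var is set
```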
