29 changes: 19 additions & 10 deletions python/pyspark/sql/connect/client/core.py
@@ -220,7 +220,9 @@ def userId(self) -> Optional[str]:

@property
def token(self) -> Optional[str]:
return self._params.get(ChannelBuilder.PARAM_TOKEN, None)
return self._params.get(
ChannelBuilder.PARAM_TOKEN, os.environ.get("SPARK_CONNECT_AUTHENTICATE_TOKEN")
)

def metadata(self) -> Iterable[Tuple[str, str]]:
"""
@@ -410,10 +412,11 @@ def _extract_attributes(self) -> None:

@property
def secure(self) -> bool:
return (
self.getDefault(ChannelBuilder.PARAM_USE_SSL, "").lower() == "true"
or self.token is not None
)
return self.use_ssl or self.token is not None

@property
def use_ssl(self) -> bool:
return self.getDefault(ChannelBuilder.PARAM_USE_SSL, "").lower() == "true"

@property
def host(self) -> str:
@@ -439,14 +442,20 @@ def toChannel(self) -> grpc.Channel:

if not self.secure:
return self._insecure_channel(self.endpoint)
elif not self.use_ssl and self._host == "localhost":
creds = grpc.local_channel_credentials()

if self.token is not None:
creds = grpc.composite_channel_credentials(
creds, grpc.access_token_call_credentials(self.token)
)
return self._secure_channel(self.endpoint, creds)
else:
ssl_creds = grpc.ssl_channel_credentials()
creds = grpc.ssl_channel_credentials()

if self.token is None:
creds = ssl_creds
else:
if self.token is not None:
creds = grpc.composite_channel_credentials(
ssl_creds, grpc.access_token_call_credentials(self.token)
creds, grpc.access_token_call_credentials(self.token)
)

return self._secure_channel(self.endpoint, creds)
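For orientation, here is a minimal client-side usage sketch of what these ChannelBuilder changes enable; the host, port, and token below are placeholders rather than values from the patch. A token given in the remote URL (or picked up from SPARK_CONNECT_AUTHENTICATE_TOKEN) is attached as bearer call credentials even on a plaintext localhost channel.

    from pyspark.sql import SparkSession

    # Placeholder endpoint and secret; a real deployment supplies its own.
    spark = (
        SparkSession.builder
        .remote("sc://localhost:15002/;token=my-shared-secret")
        .getOrCreate()
    )
    spark.range(3).show()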
11 changes: 10 additions & 1 deletion python/pyspark/sql/connect/session.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import uuid
from pyspark.sql.connect.utils import check_dependencies

check_dependencies(__name__)
@@ -1030,6 +1031,8 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:

2. Starts a regular Spark session that automatically starts a Spark Connect server
via ``spark.plugins`` feature.

Returns the authentication token that should be used to connect to this session.
"""
from pyspark import SparkContext, SparkConf

@@ -1049,6 +1052,13 @@
if "spark.api.mode" in overwrite_conf:
del overwrite_conf["spark.api.mode"]

# Check for a user-provided authentication token, creating a new one if none
# was given, and make sure it's set in the environment.
if "SPARK_CONNECT_AUTHENTICATE_TOKEN" not in os.environ:
os.environ["SPARK_CONNECT_AUTHENTICATE_TOKEN"] = opts.get(
"spark.connect.authenticate.token", str(uuid.uuid4())
)

# Configurations to be set if unset.
default_conf = {
"spark.plugins": "org.apache.spark.sql.connect.SparkConnectPlugin",
@@ -1081,7 +1091,6 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
new_opts = {k: opts[k] for k in opts if k in runtime_conf_keys}
opts.clear()
opts.update(new_opts)

finally:
if origin_remote is not None:
os.environ["SPARK_REMOTE"] = origin_remote
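As a hedged illustration of the launcher change above (not code from the patch), a user can pin the token explicitly instead of relying on the generated UUID; the config value here is a placeholder, and the example assumes the local spark.api.mode=connect path this PR targets.

    from pyspark.sql import SparkSession

    # The token set here is exported as SPARK_CONNECT_AUTHENTICATE_TOKEN before
    # the local Spark Connect server starts, and the client's ChannelBuilder
    # falls back to the same environment variable.
    spark = (
        SparkSession.builder
        .master("local[*]")
        .config("spark.api.mode", "connect")
        .config("spark.connect.authenticate.token", "my-shared-secret")
        .getOrCreate()
    )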
18 changes: 17 additions & 1 deletion python/pyspark/sql/tests/connect/test_connect_session.py
@@ -43,6 +43,7 @@
from pyspark.errors.exceptions.connect import (
AnalysisException,
SparkConnectException,
SparkConnectGrpcException,
SparkUpgradeException,
)

@@ -237,7 +238,13 @@ def test_custom_channel_builder(self):

class CustomChannelBuilder(ChannelBuilder):
def toChannel(self):
return self._insecure_channel(endpoint)
creds = grpc.local_channel_credentials()

if self.token is not None:
creds = grpc.composite_channel_credentials(
creds, grpc.access_token_call_credentials(self.token)
)
return self._secure_channel(endpoint, creds)

session = RemoteSparkSession.builder.channelBuilder(CustomChannelBuilder()).create()
session.sql("select 1 + 1")
@@ -290,6 +297,15 @@ def test_api_mode(self):
self.assertEqual(session.range(1).first()[0], 0)
self.assertIsInstance(session, RemoteSparkSession)

def test_authentication(self):
# All servers start with a default token of "deadbeef", so supply an invalid one.
session = RemoteSparkSession.builder.remote("sc://localhost/;token=invalid").create()

with self.assertRaises(SparkConnectGrpcException) as e:
session.range(3).collect()

self.assertTrue("Invalid authentication token" in str(e.exception))
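A possible companion assertion, not included in this patch, would cover the happy path using the static test token from connectutils (assumed to be "deadbeef"):

    session = RemoteSparkSession.builder.remote("sc://localhost/;token=deadbeef").create()
    self.assertEqual(session.range(1).count(), 1)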


@unittest.skipIf(not should_test_connect, connect_requirement_message)
class SparkConnectSessionWithOptionsTest(unittest.TestCase):
@@ -24,7 +24,6 @@
import unittest
from typing import cast

from pyspark import SparkConf
from pyspark.sql.streaming.state import GroupStateTimeout, GroupState
from pyspark.sql.types import (
LongType,
@@ -56,7 +55,7 @@
class GroupedApplyInPandasWithStateTestsMixin:
@classmethod
def conf(cls):
cfg = SparkConf()
cfg = super().conf()
cfg.set("spark.sql.shuffle.partitions", "5")
return cfg

3 changes: 3 additions & 0 deletions python/pyspark/testing/connectutils.py
@@ -155,6 +155,9 @@ def conf(cls):
conf._jconf.remove("spark.master")
conf.set("spark.connect.execute.reattachable.senderMaxStreamDuration", "1s")
conf.set("spark.connect.execute.reattachable.senderMaxStreamSize", "123")
# Set a static token for all tests so that parallel test runs don't overwrite
# each other's environment variables.
conf.set("spark.connect.authenticate.token", "deadbeef")
return conf

@classmethod
@@ -125,7 +125,7 @@ class SparkConnectClientBuilderParseTestSuite extends ConnectFunSuite {
assert(builder.host === "localhost")
assert(builder.port === 15002)
assert(builder.userAgent.contains("_SPARK_CONNECT_SCALA"))
assert(builder.sslEnabled)
assert(!builder.sslEnabled)
assert(builder.token.contains("thisismysecret"))
assert(builder.userId.isEmpty)
assert(builder.userName.isEmpty)
@@ -299,8 +299,8 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
TestPackURI(s"sc://host:123/;session_id=${UUID.randomUUID().toString}", isCorrect = true),
TestPackURI("sc://host:123/;use_ssl=true;token=mySecretToken", isCorrect = true),
TestPackURI("sc://host:123/;token=mySecretToken;use_ssl=true", isCorrect = true),
TestPackURI("sc://host:123/;use_ssl=false;token=mySecretToken", isCorrect = false),
TestPackURI("sc://host:123/;token=mySecretToken;use_ssl=false", isCorrect = false),
TestPackURI("sc://host:123/;use_ssl=false;token=mySecretToken", isCorrect = true),
TestPackURI("sc://host:123/;token=mySecretToken;use_ssl=false", isCorrect = true),
TestPackURI("sc://host:123/;param1=value1;param2=value2", isCorrect = true),
TestPackURI(
"sc://SPARK-45486",
@@ -762,6 +762,7 @@ object SparkSession extends SparkSessionCompanion with Logging {
(remoteString.exists(_.startsWith("local")) ||
(remoteString.isDefined && isAPIModeConnect)) &&
maybeConnectStartScript.exists(Files.exists(_))) {
val token = java.util.UUID.randomUUID().toString()
val serverId = UUID.randomUUID().toString
server = Some {
val args =
@@ -779,6 +780,7 @@
pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
pb.environment().put("SPARK_IDENT_STRING", serverId)
pb.environment().put("HOSTNAME", "local")
pb.environment().put("SPARK_CONNECT_AUTHENTICATE_TOKEN", token)
pb.start()
}

@@ -800,7 +802,7 @@
}
}

System.setProperty("spark.remote", "sc://localhost")
System.setProperty("spark.remote", s"sc://localhost/;token=$token")

// scalastyle:off runtimeaddshutdownhook
Runtime.getRuntime.addShutdownHook(new Thread() {
@@ -468,20 +468,14 @@ object SparkConnectClient {
* sc://localhost/;token=aaa;use_ssl=true
* }}}
*
* Throws exception if the token is set but use_ssl=false.
*
* @param inputToken
* the user token.
* @return
* this builder.
*/
def token(inputToken: String): Builder = {
require(inputToken != null && inputToken.nonEmpty)
if (_configuration.isSslEnabled.contains(false)) {
throw new IllegalArgumentException(AUTH_TOKEN_ON_INSECURE_CONN_ERROR_MSG)
}
_configuration =
_configuration.copy(token = Option(inputToken), isSslEnabled = Option(true))
_configuration = _configuration.copy(token = Option(inputToken))
this
}

@@ -499,7 +493,6 @@
* this builder.
*/
def disableSsl(): Builder = {
require(token.isEmpty, AUTH_TOKEN_ON_INSECURE_CONN_ERROR_MSG)
_configuration = _configuration.copy(isSslEnabled = Option(false))
this
}
@@ -737,6 +730,8 @@
grpcMaxMessageSize: Int = ConnectCommon.CONNECT_GRPC_MAX_MESSAGE_SIZE,
grpcMaxRecursionLimit: Int = ConnectCommon.CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT) {

private def isLocal = host.equals("localhost")

def userContext: proto.UserContext = {
val builder = proto.UserContext.newBuilder()
if (userId != null) {
@@ -749,7 +744,7 @@
}

def credentials: ChannelCredentials = {
if (isSslEnabled.contains(true)) {
if (isSslEnabled.contains(true) || (token.isDefined && !isLocal)) {
token match {
case Some(t) =>
// With access token added in the http header.
@@ -765,10 +760,18 @@
}

def createChannel(): ManagedChannel = {
val channelBuilder = Grpc.newChannelBuilderForAddress(host, port, credentials)
val creds = credentials
val channelBuilder = Grpc.newChannelBuilderForAddress(host, port, creds)

// Workaround until LocalChannelCredentials are added in grpc-java; see
// https://github.com/grpc/grpc-java/issues/9900
var metadataWithOptionalToken = metadata
if (!isSslEnabled.contains(true) && isLocal && token.isDefined) {
metadataWithOptionalToken = metadata + (("Authorization", s"Bearer ${token.get}"))
}

if (metadata.nonEmpty) {
channelBuilder.intercept(new MetadataHeaderClientInterceptor(metadata))
if (metadataWithOptionalToken.nonEmpty) {
channelBuilder.intercept(new MetadataHeaderClientInterceptor(metadataWithOptionalToken))
}

interceptors.foreach(channelBuilder.intercept(_))
@@ -18,6 +18,7 @@ package org.apache.spark.sql.connect.config

import java.util.concurrent.TimeUnit

import org.apache.spark.SparkEnv
import org.apache.spark.network.util.ByteUnit
import org.apache.spark.sql.connect.common.config.ConnectCommon
import org.apache.spark.sql.internal.SQLConf
@@ -313,4 +314,21 @@ object Connect {
.internal()
.booleanConf
.createWithDefault(true)

val CONNECT_AUTHENTICATE_TOKEN =
buildStaticConf("spark.connect.authenticate.token")
@HyukjinKwon (Member) commented on Feb 19, 2025:

#49880 (comment)

I also still don't like this being a conf. Are we sure that this is not shown in ps? These values will be passed to spark-submit through the Py4J server launcher to start the Spark Connect server.

This is different from other cases because it will always show up in the ps output, whereas other configurations are generally set in the spark.conf file.

Member:

I am being particularly difficult on this because we're adding it mainly to block access from an arbitrary user, but it will be pretty useless if it can be seen with a simple ps aux.

@Kimahriman (Contributor Author) commented on Feb 19, 2025:

Okay, I see what you mean. I was not expecting the conf vars set on the Python SparkConf to end up on the command line; I thought they were just passed in memory through Py4J. Let me think about that.

There are a lot of use cases I'm trying to cover here, though, which is why there is the dual config/env-var setup:

  • pyspark --remote local / spark-shell --remote local: For the most part I assume this is for testing purposes. I'm not sure what the use case is for multiple users being on the same machine who need to prevent others from connecting to their own session. This would be the case where seeing the ps output could be considered a security hole. Having the authentication at least prevents users on other servers from remotely accessing this Connect server.
  • pyspark --conf spark.api.mode=connect / spark-shell --conf spark.api.mode=connect: I think this is effectively the same as the previous use case.
  • spark-submit --deploy-mode client --conf spark.api.mode=connect: Similar to the previous two: multiple users on the same machine where a job is submitted from, but you don't want them to access your own sessions. I guess if multiple users can remotely start sessions this way on a dedicated server, you could see the ps output from the Spark driver. I don't think this method would show the token on the command line, but I would need to verify.
  • spark-submit --deploy-mode cluster --conf spark.api.mode=connect: This is the case I am most worried about from a security perspective. You are launching a driver in a shared compute cluster, so anyone else on that cluster would be able to access your Spark Connect server without any authentication (the reason I brought up the security issue in the beginning). I also don't think this would show the token on the command line, but would need to verify.

Contributor Author:

The reason the config is designed to be used for more than just the local use case is something like:

  • spark-submit --deploy-mode cluster --conf spark.connect.authenticate.token=<token> --class org.apache.spark.sql.connect.service.SparkConnectServer: so you can launch a remote Spark session on behalf of a user and then only let that user authenticate to it.

Member:

Okie. I am fine with going ahead with this first, and following up.

Contributor Author:

Updated so that the tests specify a constant value via the config; normal usage only sets the env var for the automatically launched server.

.doc("A pre-shared token that will be used to authenticate clients. This secret must be" +
" passed as a bearer token by for clients to connect.")
.version("4.0.0")
.internal()
.stringConf
.createOptional

val CONNECT_AUTHENTICATE_TOKEN_ENV = "SPARK_CONNECT_AUTHENTICATE_TOKEN"

def getAuthenticateToken: Option[String] = {
SparkEnv.get.conf.get(CONNECT_AUTHENTICATE_TOKEN).orElse {
Option(System.getenv.get(CONNECT_AUTHENTICATE_TOKEN_ENV))
}
}
}
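Given the review concern above about the token showing up in ps, a hypothetical launcher sketch (not part of this patch) that prefers the environment-variable route when starting a standalone Connect server:

    import os
    import subprocess

    # Pass the secret through the environment rather than as a --conf argument,
    # so it does not appear in ps output; the server falls back to
    # SPARK_CONNECT_AUTHENTICATE_TOKEN when the conf is unset.
    env = dict(os.environ, SPARK_CONNECT_AUTHENTICATE_TOKEN="my-shared-secret")
    subprocess.Popen(["./sbin/start-connect-server.sh"], env=env)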
@@ -31,6 +31,7 @@ import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, PYTHON_EXEC, QUERY_ID, R
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders}
import org.apache.spark.sql.connect.common.ForeachWriterPacket
import org.apache.spark.sql.connect.config.Connect
import org.apache.spark.sql.connect.service.SessionHolder
import org.apache.spark.sql.connect.service.SparkConnectService
import org.apache.spark.sql.streaming.StreamingQuery
@@ -135,7 +136,10 @@
sessionHolder: SessionHolder): (ForeachBatchFnType, AutoCloseable) = {

val port = SparkConnectService.localPort
val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
var connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
Connect.getAuthenticateToken.foreach { token =>
Contributor:

Nit: if you like reusing code, factor this into a helper function in sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala:

def authenticationTokenParam: String = getAuthenticateToken.map(token => ";token=" + token).getOrElse("")

connectUrl = s"$connectUrl;token=$token"
}
val runner = StreamingPythonRunner(
pythonFn,
connectUrl,
@@ -23,6 +23,7 @@ import org.apache.spark.SparkException
import org.apache.spark.api.python.{PythonException, PythonWorkerUtils, SimplePythonFunction, SpecialLengths, StreamingPythonRunner}
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.FUNCTION_NAME
import org.apache.spark.sql.connect.config.Connect
import org.apache.spark.sql.connect.service.{SessionHolder, SparkConnectService}
import org.apache.spark.sql.streaming.StreamingQueryListener

@@ -36,7 +37,10 @@ class PythonStreamingQueryListener(listener: SimplePythonFunction, sessionHolder
with Logging {

private val port = SparkConnectService.localPort
private val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
private var connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
Connect.getAuthenticateToken.foreach { token =>
connectUrl = s"$connectUrl;token=$token"
}
// Scoped for testing
private[connect] val runner = StreamingPythonRunner(
listener,
@@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connect.service

import io.grpc.{Metadata, ServerCall, ServerCallHandler, ServerInterceptor, Status}

class PreSharedKeyAuthenticationInterceptor(token: String) extends ServerInterceptor {

val authorizationMetadataKey =
Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER)

val expectedValue = s"Bearer $token"

override def interceptCall[ReqT, RespT](
call: ServerCall[ReqT, RespT],
metadata: Metadata,
next: ServerCallHandler[ReqT, RespT]): ServerCall.Listener[ReqT] = {
val authHeaderValue = metadata.get(authorizationMetadataKey)

if (authHeaderValue == null) {
val status = Status.UNAUTHENTICATED.withDescription("No authentication token provided")
call.close(status, new Metadata())
new ServerCall.Listener[ReqT]() {}
} else if (authHeaderValue != expectedValue) {
val status = Status.UNAUTHENTICATED.withDescription("Invalid authentication token")
call.close(status, new Metadata())
new ServerCall.Listener[ReqT]() {}
} else {
next.startCall(call, metadata)
}
}
}
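To show what this interceptor expects from callers, a minimal Python gRPC sketch with placeholder endpoint and token (the real Python client assembles this inside ChannelBuilder.toChannel): every RPC must carry an "authorization: Bearer <token>" header, which gRPC attaches via access-token call credentials.

    import grpc

    # Local (non-TLS) channel credentials combined with a bearer token, matching
    # the check in PreSharedKeyAuthenticationInterceptor; values are placeholders.
    creds = grpc.composite_channel_credentials(
        grpc.local_channel_credentials(),
        grpc.access_token_call_credentials("my-shared-secret"),
    )
    channel = grpc.secure_channel("localhost:15002", creds)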