[SPARK-51156][CONNECT] Provide a basic authentication token when running Spark Connect server locally #49880

Open
wants to merge 1 commit into
base: master

HyukjinKwon
Member

What changes were proposed in this pull request?

This PR implements simple token-based authentication for a locally running Spark Connect server.

Why are the changes needed?

To prevent security issues: without authentication, any user can connect to the locally running server.

Does this PR introduce any user-facing change?

Yes. It requires an authentication token to access the Spark Connect server.

How was this patch tested?

Enabled by default, and will be tested in CI.

Was this patch authored or co-authored using generative AI tooling?

No.

Member

@dongjoon-hyun dongjoon-hyun left a comment

It's great to have the security feature. Are you going to add test cases?

@HyukjinKwon
Member Author

It's enabled by default .. so I think it's fine ...

@HyukjinKwon
Member Author

(And backward compatibility will be tested in the scheduled build.)

@HyukjinKwon HyukjinKwon force-pushed the localauth branch 5 times, most recently from 018c25a to bd46da2 Compare February 11, 2025 08:26
@pan3793
Member

pan3793 commented Feb 11, 2025

Can the AccessTokenCallCredentials added in SPARK-42533 protect the local Spark Connect server?

@HyukjinKwon
Member Author

If I am not wrong, that secures individual connections, but the problem is that any user can make a connection to the running server.

override def run(): Unit = if (server.isDefined) {
new ProcessBuilder(maybeConnectScript.get.toString)
.start()
server.synchronized {
Member Author

Just added a synchronized.

@HyukjinKwon HyukjinKwon marked this pull request as ready for review February 11, 2025 11:44
@@ -52,6 +53,8 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr
private var holder: SessionHolder = _

override def onNext(req: AddArtifactsRequest): Unit = try {
ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN.foreach(k =>
Contributor

Nit: would prefer this check to be centralised in some companion object

Contributor

We can also log any auth failure.

@@ -52,6 +53,8 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr
private var holder: SessionHolder = _

override def onNext(req: AddArtifactsRequest): Unit = try {
ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN.foreach(k =>
assert(k == req.getUserContext.getLocalAuthToken))
Contributor

Do we want to throw a more descriptive message in these cases of auth failures?

Contributor

Further, assertions can technically be disabled in the JVM. Doing so would silently remove the auth check.
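
As an illustration of the point above (a sketch added for clarity, not code from this PR): Python, which this PR also touches, has an analogous hazard, since `assert` statements are dropped when code is compiled with optimization, as with `python -O`. The function names below are hypothetical. The sketch compiles the same source twice with the built-in `compile(..., optimize=...)` to show that the assert-based check stops rejecting a bad token, while an explicit check keeps working.

```python
# Source for two hypothetical token checks: one relies on `assert`,
# the other raises explicitly.
SOURCE = '''
def check_with_assert(expected, presented):
    assert expected == presented, "token mismatch"

def check_explicit(expected, presented):
    if expected != presented:
        raise PermissionError("token mismatch")
'''


def load_checks(optimize: int) -> dict:
    """Compile SOURCE at the given optimization level and return its namespace.

    optimize=0 keeps assert statements; optimize=1 removes them,
    which is what `python -O` does for a whole program.
    """
    namespace: dict = {}
    exec(compile(SOURCE, "<auth-checks>", "exec", optimize=optimize), namespace)
    return namespace


def rejects(check, expected: str, presented: str) -> bool:
    """Return True if `check` refuses the presented token."""
    try:
        check(expected, presented)
    except (AssertionError, PermissionError):
        return True
    return False


normal = load_checks(optimize=0)
optimized = load_checks(optimize=1)

# With assertions enabled, the assert-based check rejects a wrong token.
print(rejects(normal["check_with_assert"], "secret", "wrong"))
# With assertions stripped, the same check lets the wrong token through.
print(rejects(optimized["check_with_assert"], "secret", "wrong"))
# The explicit check is unaffected by the optimization level.
print(rejects(optimized["check_explicit"], "secret", "wrong"))
```

The design point is the one raised in the review: an authentication check should be ordinary control flow that raises a real error, not a debugging aid that a runtime flag can remove.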

Comment on lines 388 to 389
ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN = Option(
System.getenv(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME))
Contributor

Would we need any special handling when the env var is set to "" (i.e., the empty string)?
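
One possible answer, as a minimal sketch: treat an empty or whitespace-only value the same as an unset variable, so that an empty string never becomes a valid token. The helper name `read_local_auth_token` is hypothetical, and whether an empty value should mean "no auth" or be rejected outright is a design decision this sketch does not settle. It is written in Python for illustration, while the lines under review are Scala.

```python
import os
from typing import Mapping, Optional

# Environment variable name taken from the diff under review.
ENV_NAME = "SPARK_CONNECT_LOCAL_AUTH_TOKEN"


def read_local_auth_token(env: Mapping[str, str] = os.environ) -> Optional[str]:
    """Return the token from `env`, or None if it is unset or blank.

    Hypothetical helper: the `env` parameter exists only so the function
    can be exercised without touching the real process environment.
    """
    token = env.get(ENV_NAME)
    if token is None or token.strip() == "":
        return None
    return token


# Unset, empty, and blank values all collapse to "no token".
print(read_local_auth_token({}))
print(read_local_auth_token({ENV_NAME: ""}))
print(read_local_auth_token({ENV_NAME: "   "}))
print(read_local_auth_token({ENV_NAME: "abc123"}))
```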

@@ -21,4 +21,10 @@ private[sql] object ConnectCommon {
val CONNECT_GRPC_PORT_MAX_RETRIES: Int = 0
val CONNECT_GRPC_MAX_MESSAGE_SIZE: Int = 128 * 1024 * 1024
val CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT: Int = 1024
// Set only when we locally run Spark Connect server.
val CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME = "SPARK_CONNECT_LOCAL_AUTH_TOKEN"
var CONNECT_LOCAL_AUTH_TOKEN: Option[String] = None
Contributor

Why use a var for storage? Use the conf system instead?

Member Author

Confs passed on the command line are visible to other users when they run ps 😢

Contributor

Isn't that true for any security related config if you pass it on the command line? And for this local use case won't it just be in memory when creating the SparkContext through py4j?
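
To make the trade-off in this exchange concrete, here is a small sketch (added for illustration, not taken from the PR) that hands a token to a child process through its environment instead of its argument list. The function name `launch_with_token` and the child command are hypothetical stand-ins for starting the server. The only claim is that the token does not appear in the command line; the sketch does not claim that environment variables are hidden from every other process or user.

```python
import os
import secrets
import subprocess
import sys

# Environment variable name taken from the diff under review.
ENV_NAME = "SPARK_CONNECT_LOCAL_AUTH_TOKEN"


def launch_with_token(token: str) -> subprocess.CompletedProcess:
    """Run a child process that receives `token` via its environment.

    The child here is a trivial Python one-liner standing in for the
    server process; it reports only the length of the token it received.
    """
    child_env = dict(os.environ)
    child_env[ENV_NAME] = token
    command = [
        sys.executable,
        "-c",
        f"import os; print(len(os.environ['{ENV_NAME}']))",
    ]
    return subprocess.run(
        command, env=child_env, capture_output=True, text=True, check=True
    )


token = secrets.token_urlsafe(32)
result = launch_with_token(token)

# The child saw the token, but the token never appeared in its arguments.
print(result.stdout.strip() == str(len(token)))
print(any(token in arg for arg in result.args))
```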

@hvanhovell
Contributor

@HyukjinKwon can you make sure we use SSL/TLS in this case? Otherwise it will be fairly easy to intercept the token.

@HyukjinKwon
Member Author

let me take a look

@HyukjinKwon HyukjinKwon force-pushed the localauth branch 7 times, most recently from a935c0b to d368742 Compare February 12, 2025 11:47
@HyukjinKwon
Member Author

Some tests might fail ... I need to go to sleep, but this should be ready for review. I addressed most of the major comments.

def setLocalAuthToken(token: String): Unit = CONNECT_LOCAL_AUTH_TOKEN = Option(token)
def getLocalAuthToken: Option[String] = CONNECT_LOCAL_AUTH_TOKEN
def assertLocalAuthToken(token: Option[String]): Unit = token.foreach { t =>
assert(CONNECT_LOCAL_AUTH_TOKEN.isDefined)
Contributor

Again, you CANNOT use asserts for this. They will get elided if you disable assertions. Please throw a proper gRPC exception in the LocalAuthInterceptor.

Member Author

I will fix the other places later. But here the assert is correct, because if the token is set, CONNECT_LOCAL_AUTH_TOKEN must also be set for local usage.

- val configuredInterceptors = SparkConnectInterceptorRegistry.createConfiguredInterceptors()
+ val configuredInterceptors =
+   SparkConnectInterceptorRegistry.createConfiguredInterceptors() ++
+     (if (localAuthToken != null) Seq(new LocalAuthInterceptor()) else Nil)
Contributor

Just pass in the token as an argument to the LocalAuthInterceptor; there is absolutely no reason for putting this in some global variable...
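
A rough sketch of what this suggestion could look like, written in plain Python rather than against the real gRPC interceptor API. The class name `PreSharedTokenInterceptor`, the `intercept` signature, and the lowercase `authentication` header key are all assumptions made for illustration; none of them are the PR's actual code. The interceptor receives the expected token through its constructor, holds it privately, and raises an explicit error on mismatch instead of asserting.

```python
import hmac
from typing import Callable, Mapping


class PreSharedTokenInterceptor:
    """Hypothetical stand-in for a server-side auth interceptor.

    The expected token is injected once at construction time, so no
    global variable is needed to share it with the request path.
    """

    def __init__(self, expected_token: str) -> None:
        if not expected_token:
            raise ValueError("expected_token must be a non-empty string")
        self._expected = expected_token.encode("utf-8")

    def intercept(
        self, headers: Mapping[str, str], handler: Callable[[], str]
    ) -> str:
        """Run `handler` only if the request carries the expected token."""
        presented = headers.get("authentication", "").encode("utf-8")
        # compare_digest avoids leaking how many leading bytes matched.
        if not hmac.compare_digest(presented, self._expected):
            raise PermissionError("invalid or missing local authentication token")
        return handler()


interceptor = PreSharedTokenInterceptor("s3cret")

# A request with the right token reaches the handler.
print(interceptor.intercept({"authentication": "s3cret"}, lambda: "ok"))

# A request with a wrong or missing token is rejected with a real error.
for bad_headers in ({"authentication": "wrong"}, {}):
    try:
        interceptor.intercept(bad_headers, lambda: "ok")
    except PermissionError as error:
        print(f"rejected: {error}")
```

In a real server the rejection would presumably be reported as a gRPC status rather than a Python exception; the sketch only shows where the token lives and how the check fails.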

Metadata.Key.of("Authentication", Metadata.ASCII_STRING_MARSHALLER)

val CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME = "SPARK_CONNECT_LOCAL_AUTH_TOKEN"
private var CONNECT_LOCAL_AUTH_TOKEN: Option[String] = Option(
Contributor

Don't store this in some random local variable. There is no need for this. On the client side the SparkConnectClient will store the token. On the server the LocalAuthInterceptor should just hold on to the token.

Member Author

The problem is that, for Python, the local server can stop and start without shutting down the JVM, whereas for Scala we always stop and start the JVM. So it has to be a variable.

/**
* A gRPC interceptor to check if the header contains token for authentication.
*/
class LocalAuthInterceptor extends ServerInterceptor {
Contributor

How about we name this PreSharedKeyAuthenticationInterceptor? It is not a Local interceptor.

Contributor

+1 in general for making this more generic and usable beyond the local Spark Connect use case. Having a pre-shared secret capability built-in goes a long way in making Spark Connect more usable in shared computer clusters.

Member Author

I will scope it down to local usage for now. This whole thing is internal for now, and we don't need to generalize it at this moment.

@@ -422,7 +422,13 @@ object SparkConnectClient {
* port or a NameResolver-compliant URI connection string.
*/
class Builder(private var _configuration: Configuration) {
def this() = this(Configuration())
Contributor

Please do this in SparkConnectClient.loadFromEnvironment()

@@ -1072,6 +1083,13 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
conf.setAll(list(overwrite_conf.items())).setAll(list(default_conf.items()))
PySparkSession(SparkContext.getOrCreate(conf))

# In Python local mode, session.stop does not terminate JVM itself
Contributor

How about you set the environment variable when we start spark?

Contributor

Also, how do you ensure the LocalAuthInterceptor will be installed if there is no token yet?

def this() = this {
ConnectCommon.getLocalAuthToken
.map { _ =>
Configuration(token = ConnectCommon.getLocalAuthToken, isSslEnabled = Some(true))
Contributor

Setting SSL enabled does not mean the server enforces SSL only.

@HyukjinKwon
Member Author

I think I can't just enable SSL by default. We would have to either expose the certificate or use an insecure connection. It seems the access token API cannot be used without SSL, so I can't reuse the existing token either.

@HyukjinKwon
Member Author

I think we will likely miss RC1. I will have to be away from the keyboard for about 3 days. Since technically a CVE isn't filed yet, and this is an optional distribution, I think we can go ahead with RC1. I will try to target RC2.

@@ -125,6 +125,7 @@ class ChannelBuilder:
PARAM_USER_ID = "user_id"
PARAM_USER_AGENT = "user_agent"
PARAM_SESSION_ID = "session_id"
CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME = "local_token"
Contributor

It's unfortunate that the gRPC client tries to force you to use TLS if you want to use call credentials, when there are so many workarounds, like simply using a different header. Though in this case you could theoretically use local_channel_credentials, at least on the Python side, to use the built-in token mechanism.

6 participants