[SPARK-51156][CONNECT] Provide a basic authentication token when running Spark Connect server locally #49880
base: master
Conversation
It's great to have the security feature. Are you going to add test cases?
It's enabled by default .. so I think it's fine ...
(and backward compat will be tested in the scheduled build)
Force-pushed 018c25a to bd46da2
Can
This can secure individual connections themselves if I am not wrong .. but the problem is that any user can make a connection to the running server
Force-pushed bd46da2 to 1edae37
override def run(): Unit = if (server.isDefined) {
  new ProcessBuilder(maybeConnectScript.get.toString)
    .start()
  server.synchronized {
Just added a synchronized.
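The race that the `synchronized` block guards against can be sketched in Python (hypothetical names; the actual code is the Scala `ProcessBuilder` snippet below): even if several threads reach the start path at once, only one local server process should be launched.

```python
import subprocess
import threading

_server_lock = threading.Lock()
_server_proc = None

def start_connect_server(script_path):
    # Double-checked start under a lock: concurrent callers all receive
    # the same process handle instead of spawning duplicate servers.
    global _server_proc
    with _server_lock:
        if _server_proc is None:
            _server_proc = subprocess.Popen([script_path])
        return _server_proc
```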
@@ -52,6 +53,8 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr
  private var holder: SessionHolder = _

  override def onNext(req: AddArtifactsRequest): Unit = try {
    ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN.foreach(k =>
Nit: would prefer this check to be centralised in some companion object
We can also log any auth failure.
@@ -52,6 +53,8 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr
  private var holder: SessionHolder = _

  override def onNext(req: AddArtifactsRequest): Unit = try {
    ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN.foreach(k =>
      assert(k == req.getUserContext.getLocalAuthToken))
Do we want to throw a more descriptive message in these cases of auth failures?
Further, assertions can technically be disabled in the JVM. Doing so might lead to accidental removal of auth.
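The concern is that JVM assertions vanish under `-da`, much as Python's `assert` vanishes under `python -O`. A sketch of the explicit check that survives either way (hypothetical function name, not from the PR):

```python
def check_local_auth_token(expected, actual):
    # An explicit raise, unlike `assert`, cannot be compiled away by the
    # runtime (JVM -da / CPython -O), so the auth check always runs.
    if expected is not None and actual != expected:
        raise PermissionError("UNAUTHENTICATED: invalid local auth token")
```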
ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN = Option(
  System.getenv(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME))
Would we need to do any special handling when the env var is set to "", i.e. an empty string?
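For illustration, a Python sketch of the concern (hypothetical helper): `Option(System.getenv(...))` wraps an empty string as `Some("")`, so treating unset and empty values the same requires an explicit check.

```python
import os

def read_local_auth_token(env_name="SPARK_CONNECT_LOCAL_AUTH_TOKEN"):
    # Return None for both an unset variable and an empty string, so an
    # accidental `export SPARK_CONNECT_LOCAL_AUTH_TOKEN=` does not
    # silently configure a zero-length token.
    value = os.environ.get(env_name)
    return value if value else None
```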
...connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -21,4 +21,10 @@ private[sql] object ConnectCommon {
  val CONNECT_GRPC_PORT_MAX_RETRIES: Int = 0
  val CONNECT_GRPC_MAX_MESSAGE_SIZE: Int = 128 * 1024 * 1024
  val CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT: Int = 1024
  // Set only when we locally run Spark Connect server.
  val CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME = "SPARK_CONNECT_LOCAL_AUTH_TOKEN"
  var CONNECT_LOCAL_AUTH_TOKEN: Option[String] = None
Why use a var for storage? Use the conf system instead?
Using confs means the token can be shown when other users run `ps`.
😢
Isn't that true for any security-related config if you pass it on the command line? And for this local use case, won't it just be in memory when creating the SparkContext through py4j?
@HyukjinKwon can you make sure we are using SSL/TLS in this case. Otherwise it will be kind of easy to intercept the token.
let me take a look
Force-pushed a935c0b to d368742, then d368742 to 66336ac
Some tests might fail ... need to go sleep .. but it should be possible to review this. I addressed most of the major comments.
def setLocalAuthToken(token: String): Unit = CONNECT_LOCAL_AUTH_TOKEN = Option(token)
def getLocalAuthToken: Option[String] = CONNECT_LOCAL_AUTH_TOKEN
def assertLocalAuthToken(token: Option[String]): Unit = token.foreach { t =>
  assert(CONNECT_LOCAL_AUTH_TOKEN.isDefined)
Again, you CANNOT use asserts for this. They will get elided if you disable assertions. Please throw a proper gRPC exception in the LocalAuthInterceptor.
For other places, I will fix it later. But here assert is correct because if token is set, CONNECT_LOCAL_AUTH_TOKEN must be set for local usage.
- val configuredInterceptors = SparkConnectInterceptorRegistry.createConfiguredInterceptors()
+ val configuredInterceptors =
+   SparkConnectInterceptorRegistry.createConfiguredInterceptors() ++
+     (if (localAuthToken != null) Seq(new LocalAuthInterceptor()) else Nil)
Just pass in the token as an argument to the LocalAuthInterceptor; there is absolutely no reason for putting this in some global variable...
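What the reviewer asks for, sketched in Python with hypothetical names: the interceptor receives the token at construction time and keeps it as instance state, so no global variable is involved.

```python
class PreSharedKeyInterceptor:
    """Sketch of a header-checking interceptor that owns its token."""

    def __init__(self, token):
        self._token = token  # injected at construction, not read from a global

    def check(self, headers):
        # Compare the pre-shared key from the request headers; raise a
        # proper error (not an assert) on mismatch.
        if headers.get("Authentication") != self._token:
            raise PermissionError("UNAUTHENTICATED: invalid local auth token")
```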
Metadata.Key.of("Authentication", Metadata.ASCII_STRING_MARSHALLER)

val CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME = "SPARK_CONNECT_LOCAL_AUTH_TOKEN"
private var CONNECT_LOCAL_AUTH_TOKEN: Option[String] = Option(
Don't store this in some random local variable. There is no need for this. On the client side the SparkConnectClient will store the token. On the server the LocalAuthInterceptor should just hold on to the token.
The problem is that the local server can stop and start without shutting down the JVM (for Python), unlike Scala where we always stop/start the JVM. So it has to be a variable.
/**
 * A gRPC interceptor to check if the header contains token for authentication.
 */
class LocalAuthInterceptor extends ServerInterceptor {
How about we name this PreSharedKeyAuthenticationInterceptor? It is not a Local interceptor.
+1 in general for making this more generic and usable beyond the local Spark Connect use case. Having a pre-shared secret capability built-in goes a long way in making Spark Connect more usable in shared computer clusters.
I will scope it down to local usage for now. This whole thing is internal for now, and we don't need to generalize it at this moment.
@@ -422,7 +422,13 @@ object SparkConnectClient {
 * port or a NameResolver-compliant URI connection string.
 */
class Builder(private var _configuration: Configuration) {
  def this() = this(Configuration())
Please do this in SparkConnectClient.loadFromEnvironment()
@@ -1072,6 +1083,13 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
    conf.setAll(list(overwrite_conf.items())).setAll(list(default_conf.items()))
    PySparkSession(SparkContext.getOrCreate(conf))

    # In Python local mode, session.stop does not terminate JVM itself
How about you set the environment variable when we start spark?
Also, how do you ensure the LocalAuthInterceptor will be installed if there is no token yet?
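One way to address both questions, sketched with hypothetical names: generate the token before the JVM is launched and export it, so the server always starts with a token (and therefore always installs the interceptor) and the Python client inherits the same value.

```python
import os
import secrets

def ensure_local_auth_token(env_name="SPARK_CONNECT_LOCAL_AUTH_TOKEN"):
    # Called before spawning the local server: if no token is configured
    # yet, mint a random one and export it so the child JVM inherits it.
    if not os.environ.get(env_name):
        os.environ[env_name] = secrets.token_hex(32)
    return os.environ[env_name]
```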
def this() = this {
  ConnectCommon.getLocalAuthToken
    .map { _ =>
      Configuration(token = ConnectCommon.getLocalAuthToken, isSslEnabled = Some(true))
Setting SSL enabled does not mean the server enforces SSL only.
I think I can't just enable SSL by default. We should expose the certificate or use an insecure connection. It seems the access token API cannot be used without SSL, so I can't reuse this existing token either.
I think we will likely miss RC1 - I will have to be away from keyboard for about 3 days. Since technically no CVE is filed yet, and this is an optional distribution, I think we can go ahead with RC1. I will try to target RC2.
Force-pushed 66336ac to 0558d14
@@ -125,6 +125,7 @@ class ChannelBuilder:
    PARAM_USER_ID = "user_id"
    PARAM_USER_AGENT = "user_agent"
    PARAM_SESSION_ID = "session_id"
    CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME = "local_token"
It's unfortunate that the grpc client tries to force you to use TLS if you want to use call credentials when there are so many workarounds, like simply using a different header. Though in this case you could theoretically use local_channel_credentials, at least on the Python side, to use the built-in token mechanism.
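The different-header workaround mentioned here can be sketched as follows (hypothetical helper; `local_token` is the param name the PR introduces in ChannelBuilder): the token travels as ordinary gRPC metadata, which does not require TLS the way call credentials do.

```python
def auth_metadata(token, header_name="local_token"):
    # Attach the pre-shared token as a plain metadata header; unlike gRPC
    # call credentials, this works on an insecure (non-TLS) local channel.
    return [(header_name, token)] if token else []
```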
What changes were proposed in this pull request?
This PR implements a simple authentication when running Spark Connect server locally.
Why are the changes needed?
To prevent security issues.
Does this PR introduce any user-facing change?
Yes. It requires the authentication token to access the Spark Connect server.
How was this patch tested?
Enabled by default, and will be tested in CI.
Was this patch authored or co-authored using generative AI tooling?
No.