[SPARK-51156][CONNECT] Static token authentication support in Spark Connect #50006
Changes from all commits: 2dae591, 1c77476, 2fba2e7, 619a2f2, ad2ebff, ae2abe9, 81f5e53, bd4323d, 5374049, 0b7ccf8, 42de37c, f0948f5
In `org.apache.spark.sql.connect.config.Connect`:

```diff
@@ -18,6 +18,7 @@ package org.apache.spark.sql.connect.config
 
 import java.util.concurrent.TimeUnit
 
+import org.apache.spark.SparkEnv
 import org.apache.spark.network.util.ByteUnit
 import org.apache.spark.sql.connect.common.config.ConnectCommon
 import org.apache.spark.sql.internal.SQLConf
@@ -313,4 +314,21 @@ object Connect {
       .internal()
       .booleanConf
       .createWithDefault(true)
+
+  val CONNECT_AUTHENTICATE_TOKEN =
+    buildStaticConf("spark.connect.authenticate.token")
+      .doc("A pre-shared token that will be used to authenticate clients. This secret must be" +
+        " passed as a bearer token by clients to connect.")
+      .version("4.0.0")
+      .internal()
+      .stringConf
+      .createOptional
+
+  val CONNECT_AUTHENTICATE_TOKEN_ENV = "SPARK_CONNECT_AUTHENTICATE_TOKEN"
+
+  def getAuthenticateToken: Option[String] = {
+    SparkEnv.get.conf.get(CONNECT_AUTHENTICATE_TOKEN).orElse {
+      Option(System.getenv.get(CONNECT_AUTHENTICATE_TOKEN_ENV))
+    }
+  }
 }
```
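With this change, the token can be supplied either through the static conf `spark.connect.authenticate.token` or the `SPARK_CONNECT_AUTHENTICATE_TOKEN` environment variable, with the conf taking precedence. A minimal client-side sketch, assuming the Spark Connect Scala client and a server configured with the token `my-secret` (the host, port, and token values are placeholders, not part of this diff):

```scala
// Hypothetical usage: pass the same pre-shared token in the connection
// string so it is sent to the server as a bearer token.
import org.apache.spark.sql.SparkSession

val spark = SparkSession
  .builder()
  .remote("sc://localhost:15002/;token=my-secret")
  .getOrCreate()

spark.range(5).show()
```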
In `StreamingForeachBatchHelper`:

```diff
@@ -31,6 +31,7 @@ import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, PYTHON_EXEC, QUERY_ID, R
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders}
 import org.apache.spark.sql.connect.common.ForeachWriterPacket
+import org.apache.spark.sql.connect.config.Connect
 import org.apache.spark.sql.connect.service.SessionHolder
 import org.apache.spark.sql.connect.service.SparkConnectService
 import org.apache.spark.sql.streaming.StreamingQuery
@@ -135,7 +136,10 @@ object StreamingForeachBatchHelper extends Logging {
       sessionHolder: SessionHolder): (ForeachBatchFnType, AutoCloseable) = {
 
     val port = SparkConnectService.localPort
-    val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
+    var connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
+    Connect.getAuthenticateToken.foreach { token =>
+      connectUrl = s"$connectUrl;token=$token"
+    }
     val runner = StreamingPythonRunner(
       pythonFn,
       connectUrl,
```
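The effect of the `StreamingForeachBatchHelper` change is that the local connect URL handed to the Python foreachBatch runner carries the token only when authentication is configured. An illustrative sketch of the resulting URL shape (the helper name and example values are made up for illustration, not code from the PR):

```scala
// Sketch of the URL construction pattern used above.
def localConnectUrl(port: Int, userId: String, token: Option[String]): String = {
  val base = s"sc://localhost:$port/;user_id=$userId"
  // Append the token parameter only when a pre-shared token is configured.
  token.fold(base)(t => s"$base;token=$t")
}

// localConnectUrl(15002, "alice", Some("my-secret"))
//   == "sc://localhost:15002/;user_id=alice;token=my-secret"
```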
A new `PreSharedKeyAuthenticationInterceptor` in `org.apache.spark.sql.connect.service`:

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.connect.service

import io.grpc.{Metadata, ServerCall, ServerCallHandler, ServerInterceptor, Status}

class PreSharedKeyAuthenticationInterceptor(token: String) extends ServerInterceptor {

  val authorizationMetadataKey =
    Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER)

  val expectedValue = s"Bearer $token"

  override def interceptCall[ReqT, RespT](
      call: ServerCall[ReqT, RespT],
      metadata: Metadata,
      next: ServerCallHandler[ReqT, RespT]): ServerCall.Listener[ReqT] = {
    val authHeaderValue = metadata.get(authorizationMetadataKey)

    if (authHeaderValue == null) {
      val status = Status.UNAUTHENTICATED.withDescription("No authentication token provided")
      call.close(status, new Metadata())
      new ServerCall.Listener[ReqT]() {}
    } else if (authHeaderValue != expectedValue) {
      val status = Status.UNAUTHENTICATED.withDescription("Invalid authentication token")
      call.close(status, new Metadata())
      new ServerCall.Listener[ReqT]() {}
    } else {
      next.startCall(call, metadata)
    }
  }
}
```
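The interceptor rejects calls with a missing or mismatched `Authorization` header before they reach the service, returning `UNAUTHENTICATED` in both cases; how it is registered with the gRPC server is not shown in this excerpt. As a complementary sketch (assuming plain grpc-java client APIs rather than Spark Connect client code), this is one way a raw gRPC client could attach the matching bearer header to every outgoing call:

```scala
import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, Metadata, MethodDescriptor}
import io.grpc.ForwardingClientCall.SimpleForwardingClientCall

// Hypothetical client-side counterpart: adds the "Authorization: Bearer <token>"
// header that PreSharedKeyAuthenticationInterceptor checks on the server.
class BearerTokenClientInterceptor(token: String) extends ClientInterceptor {

  private val authorizationKey =
    Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER)

  override def interceptCall[ReqT, RespT](
      method: MethodDescriptor[ReqT, RespT],
      callOptions: CallOptions,
      next: Channel): ClientCall[ReqT, RespT] = {
    new SimpleForwardingClientCall[ReqT, RespT](next.newCall(method, callOptions)) {
      override def start(listener: ClientCall.Listener[RespT], headers: Metadata): Unit = {
        // Inject the pre-shared token before the call headers are sent.
        headers.put(authorizationKey, s"Bearer $token")
        super.start(listener, headers)
      }
    }
  }
}

// A stub built with .withInterceptors(new BearerTokenClientInterceptor("my-secret"))
// would pass the server-side check when the same token is configured there.
```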