5 changes: 5 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -9304,6 +9304,11 @@
"No enough memory for aggregation"
]
},
"_LEGACY_ERROR_TEMP_3303" : {
"message" : [
"Invalid token."
]
},
"_LEGACY_ERROR_USER_RAISED_EXCEPTION" : {
"message" : [
"<errorMessage>"
10 changes: 8 additions & 2 deletions python/pyspark/sql/connect/client/core.py
@@ -20,12 +20,11 @@
"SparkConnectClient",
]

import atexit

from pyspark.sql.connect.utils import check_dependencies

check_dependencies(__name__)

import atexit
import logging
import threading
import os
@@ -125,6 +124,7 @@ class ChannelBuilder:
    PARAM_USER_ID = "user_id"
    PARAM_USER_AGENT = "user_agent"
    PARAM_SESSION_ID = "session_id"
    CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME = "local_token"
Contributor:
It's unfortunate that the gRPC client tries to force you to use TLS if you want to use call credentials, when there are so many workarounds, like simply using a different header. Though in this case you could theoretically use local_channel_credentials, at least on the Python side, to use the built-in token mechanism.
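A minimal sketch of that built-in mechanism, under illustrative assumptions (the endpoint and token below are placeholders, not from this PR): grpc.local_channel_credentials() lets a plaintext local channel carry call credentials that would otherwise require TLS.

import grpc

# Sketch only: "localhost:15002" and "my-local-token" are placeholders.
# local_channel_credentials() permits call credentials without TLS for
# local connections; access_token_call_credentials() attaches the token
# as an "authorization: Bearer ..." header on every call.
creds = grpc.composite_channel_credentials(
    grpc.local_channel_credentials(),
    grpc.access_token_call_credentials("my-local-token"),
)
channel = grpc.secure_channel("localhost:15002", creds)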

Contributor:

@Kimahriman are you suggesting using a different header than Authorization? If so, my issue with that is that folks will invariably use it to 'secure' remote connections as well, and in that case having TLS might not be such a bad thing.

We are planning to add UDS support in a follow-up. That should make this less of an issue.

Contributor:

No, more the opposite: I assume the custom local_token header was used here to avoid the TLS requirement, but it's probably fine to allow non-TLS use of the authorization bearer token header for local connections, which the Python gRPC client even has support for. Adding UDS support would just improve that further. local_channel_credentials even supports both of those cases, so you could use it with the authorization header to avoid the TLS requirement for this use case while still requiring it for remote connections.

The "workarounds" are simply that you can still use the authorization header without TLS by using a custom interceptor that injects it into the metadata after the fact. I figured that out while trying to build a custom dynamic proxy for launching cluster-deploy-mode Connect sessions to replace something like Livy.


    GRPC_MAX_MESSAGE_LENGTH_DEFAULT = 128 * 1024 * 1024

@@ -591,6 +591,9 @@ class SparkConnectClient(object):
    Conceptually the remote Spark session that communicates with the server
    """

    _local_auth_token = os.environ.get("SPARK_CONNECT_LOCAL_AUTH_TOKEN", str(uuid.uuid4()))
    os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"] = _local_auth_token

    def __init__(
        self,
        connection: Union[str, ChannelBuilder],
@@ -637,6 +640,9 @@ def __init__(
            if isinstance(connection, ChannelBuilder)
            else DefaultChannelBuilder(connection, channel_options)
        )
        self._builder.set(
            ChannelBuilder.CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME, SparkConnectClient._local_auth_token
        )
        self._user_id = None
        self._retry_policies: List[RetryPolicy] = []
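The handshake above, in isolation: the first client process pins SPARK_CONNECT_LOCAL_AUTH_TOKEN in its own environment (minting a UUID if unset), and any locally launched server child inherits the same value. A standalone sketch of the pattern (the script path matches the one this PR invokes; everything else is illustrative):

import os
import subprocess
import uuid

# Reuse the token if one is already set; otherwise mint a fresh UUID and
# pin it so later sessions in this process agree on it.
token = os.environ.get("SPARK_CONNECT_LOCAL_AUTH_TOKEN", str(uuid.uuid4()))
os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"] = token

# A child started from here inherits the environment, so client and server
# share the secret without it ever appearing on the command line.
subprocess.Popen(["sbin/start-connect-server.sh", "--master", "local[4]"])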
17 changes: 17 additions & 0 deletions python/pyspark/sql/tests/connect/test_connect_session.py
@@ -313,6 +313,23 @@ def test_config(self):
        self.assertEqual(self.spark.conf.get("integer"), "1")


@unittest.skipIf(not should_test_connect, connect_requirement_message)
class SparkConnectLocalAuthTests(unittest.TestCase):
    def test_auth_failure(self):
        os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"] = "invalid"
        try:
            (
                PySparkSession.builder.appName(self.__class__.__name__)
                .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
                .getOrCreate()
            )
            self.fail("Exception should occur.")
        except PySparkException as e:
            self.assertEqual(e.getCondition(), "_LEGACY_ERROR_TEMP_3303")
        finally:
            del os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"]


if should_test_connect:

    class TestError(grpc.RpcError, Exception):
@@ -48,6 +48,7 @@ import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toLiteral
import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, SparkConnectClient, SparkResult}
import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
import org.apache.spark.sql.connect.common.config.ConnectCommon
import org.apache.spark.sql.internal.{SessionState, SharedState, SqlApiConf}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.StructType
@@ -627,6 +628,17 @@ class SparkSession private[sql] (
    }
    allocator.close()
    SparkSession.onSessionClose(this)
    SparkSession.server.synchronized {
      if (SparkSession.server.isDefined) {
        // When local mode is in use, follow the regular Spark session's behavior
        // by terminating the Spark Connect server, so that the user can stop
        // local mode and restart the Spark Connect client with a different
        // remote address.
        new ProcessBuilder(SparkSession.maybeConnectStopScript.get.toString)
          .start()
        SparkSession.server = None
      }
    }
  }

  /** @inheritdoc */
Expand Down Expand Up @@ -679,6 +691,10 @@ object SparkSession extends SparkSessionCompanion with Logging {
  private val MAX_CACHED_SESSIONS = 100
  private val planIdGenerator = new AtomicLong
  private var server: Option[Process] = None
  private val maybeConnectStartScript =
    Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))
  private val maybeConnectStopScript =
    Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "stop-connect-server.sh"))
  private[sql] val sparkOptions = sys.props.filter { p =>
    p._1.startsWith("spark.") && p._2.nonEmpty
  }.toMap
@@ -712,37 +728,43 @@
}
}

      val maybeConnectScript =
        Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))

      if (server.isEmpty &&
        (remoteString.exists(_.startsWith("local")) ||
        (remoteString.isDefined && isAPIModeConnect)) &&
        maybeConnectScript.exists(Files.exists(_))) {
        server = Some {
          val args =
            Seq(maybeConnectScript.get.toString, "--master", remoteString.get) ++ sparkOptions
              .filter(p => !p._1.startsWith("spark.remote"))
              .flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
          val pb = new ProcessBuilder(args: _*)
          // So don't exclude spark-sql jar in classpath
          pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
          pb.start()
        }

        // Let the server start. We will directly request to set the configurations
        // and this sleep makes less noisy with retries.
        Thread.sleep(2000L)
        System.setProperty("spark.remote", "sc://localhost")

        // scalastyle:off runtimeaddshutdownhook
        Runtime.getRuntime.addShutdownHook(new Thread() {
          override def run(): Unit = if (server.isDefined) {
            new ProcessBuilder(maybeConnectScript.get.toString)
              .start()
          }
        })
        // scalastyle:on runtimeaddshutdownhook
      }
    server.synchronized {
Member Author: Just added a synchronized.

      if (server.isEmpty &&
        (remoteString.exists(_.startsWith("local")) ||
        (remoteString.isDefined && isAPIModeConnect)) &&
        maybeConnectStartScript.exists(Files.exists(_))) {
        server = Some {
          val args =
            Seq(
              maybeConnectStartScript.get.toString,
              "--master",
              remoteString.get) ++ sparkOptions
              .filter(p => !p._1.startsWith("spark.remote"))
              .flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
          val pb = new ProcessBuilder(args: _*)
          // So don't exclude spark-sql jar in classpath
          pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
          pb.environment()
            .put(
              ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME,
              ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN)
          pb.start()
        }

        // Let the server start. We will directly request to set the configurations
        // and this sleep makes less noisy with retries.
        Thread.sleep(2000L)
        System.setProperty("spark.remote", "sc://localhost")

        // scalastyle:off runtimeaddshutdownhook
        Runtime.getRuntime.addShutdownHook(new Thread() {
          override def run(): Unit = if (server.isDefined) {
            new ProcessBuilder(maybeConnectStopScript.get.toString)
              .start()
          }
        })
        // scalastyle:on runtimeaddshutdownhook
}
}
}
f
@@ -620,6 +620,9 @@ object SparkConnectClient {
   * Configure the builder using the SPARK_REMOTE environment variable.
   */
  def loadFromEnvironment(): Builder = {
    option(
      ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME,
      ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN)
    lazy val isAPIModeConnect =
      Option(System.getProperty(org.apache.spark.sql.SparkSessionBuilder.API_MODE_KEY))
        .getOrElse("classic")
@@ -21,4 +21,10 @@ private[sql] object ConnectCommon {
  val CONNECT_GRPC_PORT_MAX_RETRIES: Int = 0
  val CONNECT_GRPC_MAX_MESSAGE_SIZE: Int = 128 * 1024 * 1024
  val CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT: Int = 1024

  val CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME = "local_token"
  val CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME = "SPARK_CONNECT_LOCAL_AUTH_TOKEN"
  val CONNECT_LOCAL_AUTH_TOKEN: String =
    Option(System.getenv(CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME))
      .getOrElse(java.util.UUID.randomUUID().toString())
}
@@ -0,0 +1,43 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.connect.service

import io.grpc.{Metadata, ServerCall, ServerCallHandler, ServerInterceptor}

import org.apache.spark.SparkSecurityException
import org.apache.spark.sql.connect.common.config.ConnectCommon

/**
 * A gRPC interceptor to check whether the header contains the token for local authentication.
 */
class LocalAuthInterceptor(localToken: String) extends ServerInterceptor {
  override def interceptCall[ReqT, RespT](
      call: ServerCall[ReqT, RespT],
      headers: Metadata,
      next: ServerCallHandler[ReqT, RespT]): ServerCall.Listener[ReqT] = {
    val token = Option(
      headers.get(Metadata.Key
        .of(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME, Metadata.ASCII_STRING_MARSHALLER)))
    if (token.isEmpty || token.get != localToken) {
      throw new SparkSecurityException(
        errorClass = "_LEGACY_ERROR_TEMP_3303",
        messageParameters = Map.empty)
    }
    next.startCall(call, headers)
  }
}
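For comparison, the same check could be sketched with the Python gRPC server API (illustrative only; the PR ships the Scala interceptor above):

import grpc


class LocalAuthInterceptor(grpc.ServerInterceptor):
    """Rejects calls whose 'local_token' metadata entry does not match."""

    def __init__(self, local_token):
        self._local_token = local_token

        def abort(request, context):
            context.abort(grpc.StatusCode.UNAUTHENTICATED, "Invalid token.")

        # Handler substituted for the real one on authentication failure.
        self._reject = grpc.unary_unary_rpc_method_handler(abort)

    def intercept_service(self, continuation, handler_call_details):
        metadata = dict(handler_call_details.invocation_metadata)
        if metadata.get("local_token") != self._local_token:
            return self._reject
        return continuation(handler_call_details)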
@@ -39,6 +39,7 @@ import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.HOST
import org.apache.spark.internal.config.UI.UI_ENABLED
import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerEvent}
import org.apache.spark.sql.connect.common.config.ConnectCommon
import org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_BINDING_ADDRESS, CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT, CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE, CONNECT_GRPC_PORT_MAX_RETRIES}
import org.apache.spark.sql.connect.execution.ConnectProgressExecutionListener
import org.apache.spark.sql.connect.ui.{SparkConnectServerAppStatusStore, SparkConnectServerListener, SparkConnectServerTab}
@@ -369,7 +370,14 @@ object SparkConnectService extends Logging {
    val sparkConnectService = new SparkConnectService(debugMode)
    val protoReflectionService =
      if (debugMode) Some(ProtoReflectionService.newInstance()) else None
    val configuredInterceptors = SparkConnectInterceptorRegistry.createConfiguredInterceptors()
    val serverToken =
      Option(System.getenv(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME)).orElse {
        if (Utils.isTesting) Some(SparkEnv.get.conf.get("spark.testing.token"))
        else None
      }
    val configuredInterceptors =
      SparkConnectInterceptorRegistry.createConfiguredInterceptors() ++
        serverToken.map(new LocalAuthInterceptor(_))

    val startServiceFn = (port: Int) => {
      val sb = bindAddress match {
@@ -0,0 +1,34 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.sql.connect.service

import org.apache.spark.SparkException
import org.apache.spark.sql.connect.SparkConnectServerTest

class SparkConnectLocalAuthE2ESuite extends SparkConnectServerTest {
  override def beforeAll(): Unit = {
    spark.sparkContext.conf.set("spark.testing.token", "invalid")
    super.beforeAll()
  }

  test("Test local authentication") {
    val e = intercept[SparkException] {
      withClient { _ => () }
    }
    assert(e.getCondition == "_LEGACY_ERROR_TEMP_3303")
  }
}