Skip to content

Commit b86c198

Browse files
committed
fixup
1 parent d14401a commit b86c198

File tree

4 files changed

+59
-3
lines changed

4 files changed

+59
-3
lines changed

python/pyspark/sql/connect/client/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -641,8 +641,8 @@ def __init__(
641641
else DefaultChannelBuilder(connection, channel_options)
642642
)
643643
self._builder.set(
644-
ChannelBuilder.CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME,
645-
SparkConnectClient._local_auth_token)
644+
ChannelBuilder.CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME, SparkConnectClient._local_auth_token
645+
)
646646
self._user_id = None
647647
self._retry_policies: List[RetryPolicy] = []
648648

python/pyspark/sql/tests/connect/test_connect_session.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,23 @@ def test_config(self):
313313
self.assertEqual(self.spark.conf.get("integer"), "1")
314314

315315

316+
@unittest.skipIf(not should_test_connect, connect_requirement_message)
317+
class SparkConnectLocalAuthTests(unittest.TestCase):
318+
def test_auth_failure(self):
319+
os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"] = "invalid"
320+
try:
321+
(
322+
PySparkSession.builder.appName(self.__class__.__name__)
323+
.remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
324+
.getOrCreate()
325+
)
326+
except PySparkException as e:
327+
assert e.getCondition() == "_LEGACY_ERROR_TEMP_3303"
328+
finally:
329+
del os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"]
330+
self.fail("Exception should occur.")
331+
332+
316333
if should_test_connect:
317334

318335
class TestError(grpc.RpcError, Exception):

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,9 +370,14 @@ object SparkConnectService extends Logging {
370370
val sparkConnectService = new SparkConnectService(debugMode)
371371
val protoReflectionService =
372372
if (debugMode) Some(ProtoReflectionService.newInstance()) else None
373+
val serverToken =
374+
Option(System.getenv(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME)).orElse {
375+
if (Utils.isTesting) Some(SparkEnv.get.conf.get("spark.testing.token"))
376+
else None
377+
}
373378
val configuredInterceptors =
374379
SparkConnectInterceptorRegistry.createConfiguredInterceptors() ++
375-
ConnectCommon.localAuthToken.map(new LocalAuthInterceptor(_))
380+
serverToken.map(new LocalAuthInterceptor(_))
376381

377382
val startServiceFn = (port: Int) => {
378383
val sb = bindAddress match {
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.spark.sql.connect.service
18+
19+
import org.apache.spark.SparkException
20+
import org.apache.spark.sql.connect.SparkConnectServerTest
21+
22+
class SparkConnectLocalAuthE2ESuite extends SparkConnectServerTest {
23+
override def beforeAll(): Unit = {
24+
spark.sparkContext.conf.set("spark.testing.token", "invalid")
25+
super.beforeAll()
26+
}
27+
28+
test("Test local authentication") {
29+
val e = intercept[SparkException] {
30+
withClient { _ => () }
31+
}
32+
e.getCondition == "_LEGACY_ERROR_TEMP_3303"
33+
}
34+
}

0 commit comments

Comments
 (0)