4 files changed (+60, -3 lines), including files under:
main/scala/org/apache/spark/sql/connect/service
test/scala/org/apache/spark/sql/connect/service

@@ -641,8 +641,8 @@ def __init__(
             else DefaultChannelBuilder(connection, channel_options)
         )
         self._builder.set(
-            ChannelBuilder.CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME,
-            SparkConnectClient._local_auth_token)
+            ChannelBuilder.CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME, SparkConnectClient._local_auth_token
+        )
         self._user_id = None
         self._retry_policies: List[RetryPolicy] = []


@@ -313,6 +313,23 @@ def test_config(self):
         self.assertEqual(self.spark.conf.get("integer"), "1")


+@unittest.skipIf(not should_test_connect, connect_requirement_message)
+class SparkConnectLocalAuthTests(unittest.TestCase):
+    def test_auth_failure(self):
+        os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"] = "invalid"
+        try:
+            (
+                PySparkSession.builder.appName(self.__class__.__name__)
+                .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
+                .getOrCreate()
+            )
+            self.fail("Exception should occur.")
+        except PySparkException as e:
+            assert e.getCondition() == "_LEGACY_ERROR_TEMP_3303"
+        finally:
+            del os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"]
+
+
 if should_test_connect:

     class TestError(grpc.RpcError, Exception):

@@ -370,9 +370,15 @@ object SparkConnectService extends Logging {
     val sparkConnectService = new SparkConnectService(debugMode)
     val protoReflectionService =
       if (debugMode) Some(ProtoReflectionService.newInstance()) else None
+    val serverToken = Option(
+      System.getenv(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME)
+    ).orElse {
+      if (Utils.isTesting) Some(SparkEnv.get.conf.get("spark.testing.token"))
+      else None
+    }
     val configuredInterceptors =
       SparkConnectInterceptorRegistry.createConfiguredInterceptors() ++
-        ConnectCommon.localAuthToken.map(new LocalAuthInterceptor(_))
+        serverToken.map(new LocalAuthInterceptor(_))

     val startServiceFn = (port: Int) => {
       val sb = bindAddress match {
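The LocalAuthInterceptor wired in above is not part of this diff, so the following is only a rough sketch of how a token check via a grpc-java server interceptor can look. The class name TokenCheckingInterceptor and the "local_auth_token" metadata header are assumptions for illustration, not the names Spark actually uses.

import io.grpc.{Metadata, ServerCall, ServerCallHandler, ServerInterceptor, Status}

// Illustrative sketch: reject any call whose metadata token does not match the
// server-side token. Not the actual Spark implementation.
class TokenCheckingInterceptor(expectedToken: String) extends ServerInterceptor {
  // Assumed header name; the real Spark Connect header may differ.
  private val tokenKey =
    Metadata.Key.of("local_auth_token", Metadata.ASCII_STRING_MARSHALLER)

  override def interceptCall[ReqT, RespT](
      call: ServerCall[ReqT, RespT],
      headers: Metadata,
      next: ServerCallHandler[ReqT, RespT]): ServerCall.Listener[ReqT] = {
    if (Option(headers.get(tokenKey)).contains(expectedToken)) {
      // Token matches: let the call proceed to the actual service handler.
      next.startCall(call, headers)
    } else {
      // Token missing or wrong: fail the call before it reaches the service.
      call.close(
        Status.UNAUTHENTICATED.withDescription("Invalid local auth token"),
        new Metadata())
      new ServerCall.Listener[ReqT] {}
    }
  }
}

The point of doing the check in an interceptor is that a request with a missing or mismatched token is rejected before any service handler runs, which is what lets the test suites below observe an authentication error on session creation.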

@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.service
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.connect.SparkConnectServerTest
+
+class SparkConnectLocalAuthE2ESuite extends SparkConnectServerTest {
+  override def beforeAll(): Unit = {
+    spark.sparkContext.conf.set("spark.testing.token", "invalid")
+    super.beforeAll()
+  }
+
+  test("Test local authentication") {
+    val e = intercept[SparkException] {
+      withClient { _ => () }
+    }
+    assert(e.getCondition == "_LEGACY_ERROR_TEMP_3303")
+  }
+}
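On the client side, the first hunk above shows the PySpark client handing its token to the channel builder via CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME. For a plain grpc-java client, one comparable (purely illustrative) way to attach a token to every outgoing call is a ClientInterceptor; the class name and the "local_auth_token" header below are assumptions and may not match how the Spark Connect clients actually transmit the token.

import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, ForwardingClientCall, Metadata, MethodDescriptor}

// Illustrative sketch: add a token header to every outgoing gRPC call.
class TokenAttachingInterceptor(token: String) extends ClientInterceptor {
  // Assumed header name; must match whatever the server-side check expects.
  private val tokenKey =
    Metadata.Key.of("local_auth_token", Metadata.ASCII_STRING_MARSHALLER)

  override def interceptCall[ReqT, RespT](
      method: MethodDescriptor[ReqT, RespT],
      callOptions: CallOptions,
      next: Channel): ClientCall[ReqT, RespT] = {
    new ForwardingClientCall.SimpleForwardingClientCall[ReqT, RespT](
        next.newCall(method, callOptions)) {
      override def start(listener: ClientCall.Listener[RespT], headers: Metadata): Unit = {
        headers.put(tokenKey, token) // attach the token before the call starts
        super.start(listener, headers)
      }
    }
  }
}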