File tree Expand file tree Collapse file tree 5 files changed +61
-6
lines changed 
main/scala/org/apache/spark/sql/connect/service 
test/scala/org/apache/spark/sql/connect/service Expand file tree Collapse file tree 5 files changed +61
-6
lines changed Original file line number Diff line number Diff line change @@ -641,8 +641,8 @@ def __init__(
641641            else  DefaultChannelBuilder (connection , channel_options )
642642        )
643643        self ._builder .set (
644-             ChannelBuilder .CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME ,
645-              SparkConnectClient . _local_auth_token )
644+             ChannelBuilder .CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME ,  SparkConnectClient . _local_auth_token 
645+         )
646646        self ._user_id  =  None 
647647        self ._retry_policies : List [RetryPolicy ] =  []
648648
Original file line number Diff line number Diff line change @@ -313,6 +313,23 @@ def test_config(self):
313313        self .assertEqual (self .spark .conf .get ("integer" ), "1" )
314314
315315
316+ @unittest .skipIf (not  should_test_connect , connect_requirement_message ) 
317+ class  SparkConnectLocalAuthTests (unittest .TestCase ):
318+     def  test_auth_failure (self ):
319+         os .environ ["SPARK_CONNECT_LOCAL_AUTH_TOKEN" ] =  "invalid" 
320+         try :
321+             (
322+                 PySparkSession .builder .appName (self .__class__ .__name__ )
323+                 .remote (os .environ .get ("SPARK_CONNECT_TESTING_REMOTE" , "local[4]" ))
324+                 .getOrCreate ()
325+             )
326+         except  PySparkException  as  e :
327+             assert  e .getCondition () ==  "_LEGACY_ERROR_TEMP_3303" 
328+         finally :
329+             del  os .environ ["SPARK_CONNECT_LOCAL_AUTH_TOKEN" ]
330+         self .fail ("Exception should occur." )
331+ 
332+ 
316333if  should_test_connect :
317334
318335    class  TestError (grpc .RpcError , Exception ):
Original file line number Diff line number Diff line change @@ -30,11 +30,10 @@ class LocalAuthInterceptor(localToken: String) extends ServerInterceptor {
3030      call : ServerCall [ReqT , RespT ],
3131      headers : Metadata ,
3232      next : ServerCallHandler [ReqT , RespT ]):  ServerCall .Listener [ReqT ] =  {
33-     val  t  =  Option (
33+     val  token  =  Option (
3434      headers.get(Metadata .Key 
3535        .of(ConnectCommon .CONNECT_LOCAL_AUTH_TOKEN_PARAM_NAME , Metadata .ASCII_STRING_MARSHALLER )))
36-       .map(_.substring(" Bearer "  .length))
37-     if  (t.isEmpty ||  t.get !=  localToken) {
36+     if  (token.isEmpty ||  token.get !=  localToken) {
3837      throw  new  SparkSecurityException (
3938        errorClass =  " _LEGACY_ERROR_TEMP_3303"  ,
4039        messageParameters =  Map .empty)
Original file line number Diff line number Diff line change @@ -370,9 +370,14 @@ object SparkConnectService extends Logging {
370370    val  sparkConnectService  =  new  SparkConnectService (debugMode)
371371    val  protoReflectionService  = 
372372      if  (debugMode) Some (ProtoReflectionService .newInstance()) else  None 
373+     val  serverToken  = 
374+       Option (System .getenv(ConnectCommon .CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME )).orElse {
375+         if  (Utils .isTesting) Some (SparkEnv .get.conf.get(" spark.testing.token"  ))
376+         else  None 
377+       }
373378    val  configuredInterceptors  = 
374379      SparkConnectInterceptorRegistry .createConfiguredInterceptors() ++ 
375-         ConnectCommon .localAuthToken .map(new  LocalAuthInterceptor (_))
380+         serverToken .map(new  LocalAuthInterceptor (_))
376381
377382    val  startServiceFn  =  (port : Int ) =>  {
378383      val  sb  =  bindAddress match  {
Original file line number Diff line number Diff line change 1+ /* 
2+  * Licensed to the Apache Software Foundation (ASF) under one or more 
3+  * contributor license agreements.  See the NOTICE file distributed with 
4+  * this work for additional information regarding copyright ownership. 
5+  * The ASF licenses this file to You under the Apache License, Version 2.0 
6+  * (the "License"); you may not use this file except in compliance with 
7+  * the License.  You may obtain a copy of the License at 
8+  * 
9+  *    http://www.apache.org/licenses/LICENSE-2.0 
10+  * 
11+  * Unless required by applicable law or agreed to in writing, software 
12+  * distributed under the License is distributed on an "AS IS" BASIS, 
13+  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14+  * See the License for the specific language governing permissions and 
15+  * limitations under the License. 
16+  */  
17+ package  org .apache .spark .sql .connect .service 
18+ 
19+ import  org .apache .spark .SparkException 
20+ import  org .apache .spark .sql .connect .SparkConnectServerTest 
21+ 
22+ class  SparkConnectLocalAuthE2ESuite  extends  SparkConnectServerTest  {
23+   override  def  beforeAll ():  Unit  =  {
24+     spark.sparkContext.conf.set(" spark.testing.token"  , " invalid"  )
25+     super .beforeAll()
26+   }
27+ 
28+   test(" Test local authentication"  ) {
29+     val  e  =  intercept[SparkException ] {
30+       withClient { _ =>  () }
31+     }
32+     e.getCondition ==  " _LEGACY_ERROR_TEMP_3303" 
33+   }
34+ }
    
 
   
 
     
   
   
          
     
  
    
     
 
    
      
     
 
     
    You can’t perform that action at this time.
  
 
    
  
     
    
      
        
     
 
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments