
Commit bd46da2

Provide a basic authentication token when running Spark Connect server locally
1 parent cea79dc commit bd46da2

File tree: 12 files changed, +270 -179 lines

python/pyspark/sql/connect/client/artifact.py

Lines changed: 4 additions & 0 deletions

@@ -174,9 +174,13 @@ def __init__(
         channel: grpc.Channel,
         metadata: Iterable[Tuple[str, str]],
     ):
+        from pyspark.sql.connect.client import SparkConnectClient
+
         self._user_context = proto.UserContext()
         if user_id is not None:
             self._user_context.user_id = user_id
+        if SparkConnectClient._local_auth_token is not None:
+            self._user_context.local_auth_token = SparkConnectClient._local_auth_token
         self._stub = grpc_lib.SparkConnectServiceStub(channel)
         self._session_id = session_id
         self._metadata = metadata
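
Note on the hunk above: the import of SparkConnectClient is deferred into __init__, most likely to avoid a circular import between artifact.py and the client module, and because the token is a class attribute every ArtifactManager stamps the same process-wide value onto its UserContext. A minimal standalone sketch of that stamping pattern against the generated proto classes (the token value here is purely illustrative):

import pyspark.sql.connect.proto as proto

# Hypothetical process-wide token; in this commit it lives on SparkConnectClient.
local_auth_token = "123e4567-e89b-12d3-a456-426614174000"

ctx = proto.UserContext()
ctx.user_id = "local-user"
if local_auth_token is not None:
    # local_auth_token is a proto3 optional field: unset until assigned.
    ctx.local_auth_token = local_auth_token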

python/pyspark/sql/connect/client/core.py

Lines changed: 18 additions & 0 deletions

@@ -591,6 +591,8 @@ class SparkConnectClient(object):
     Conceptually the remote spark session that communicates with the server
     """
 
+    _local_auth_token: Optional[str] = None
+
     def __init__(
         self,
         connection: Union[str, ChannelBuilder],
@@ -1123,6 +1125,8 @@ def execute_command(
         req = self._execute_plan_request_with_metadata()
         if self._user_id:
             req.user_context.user_id = self._user_id
+        if self._local_auth_token:
+            req.user_context.local_auth_token = self._local_auth_token
         req.plan.command.CopyFrom(command)
         data, _, metrics, observed_metrics, properties = self._execute_and_fetch(
             req, observations or {}
@@ -1149,6 +1153,8 @@ def execute_command_as_iterator(
         req = self._execute_plan_request_with_metadata()
         if self._user_id:
             req.user_context.user_id = self._user_id
+        if self._local_auth_token:
+            req.user_context.local_auth_token = self._local_auth_token
         req.plan.command.CopyFrom(command)
         for response in self._execute_and_fetch_as_iterator(req, observations or {}):
             if isinstance(response, dict):
@@ -1217,6 +1223,8 @@ def _execute_plan_request_with_metadata(self) -> pb2.ExecutePlanRequest:
             req.client_observed_server_side_session_id = self._server_session_id
         if self._user_id:
             req.user_context.user_id = self._user_id
+        if self._local_auth_token:
+            req.user_context.local_auth_token = self._local_auth_token
         return req
 
     def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest:
@@ -1227,6 +1235,8 @@ def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest:
         req.client_type = self._builder.userAgent
         if self._user_id:
             req.user_context.user_id = self._user_id
+        if self._local_auth_token:
+            req.user_context.local_auth_token = self._local_auth_token
         return req
 
     def _analyze(self, method: str, **kwargs: Any) -> AnalyzeResult:
@@ -1591,6 +1601,8 @@ def _config_request_with_metadata(self) -> pb2.ConfigRequest:
         req.client_type = self._builder.userAgent
         if self._user_id:
             req.user_context.user_id = self._user_id
+        if self._local_auth_token:
+            req.user_context.local_auth_token = self._local_auth_token
         return req
 
     def get_configs(self, *keys: str) -> Tuple[Optional[str], ...]:
@@ -1667,6 +1679,8 @@ def _interrupt_request(
         )
         if self._user_id:
             req.user_context.user_id = self._user_id
+        if self._local_auth_token:
+            req.user_context.local_auth_token = self._local_auth_token
         return req
 
     def interrupt_all(self) -> Optional[List[str]]:
@@ -1711,6 +1725,8 @@ def release_session(self) -> None:
         req.client_type = self._builder.userAgent
         if self._user_id:
             req.user_context.user_id = self._user_id
+        if self._local_auth_token:
+            req.user_context.local_auth_token = self._local_auth_token
         try:
             for attempt in self._retrying():
                 with attempt:
@@ -1810,6 +1826,8 @@ def _fetch_enriched_error(self, info: "ErrorInfo") -> Optional[pb2.FetchErrorDet
             req.client_observed_server_side_session_id = self._server_session_id
         if self._user_id:
             req.user_context.user_id = self._user_id
+        if self._local_auth_token:
+            req.user_context.local_auth_token = self._local_auth_token
 
         try:
             return self._stub.FetchErrorDetails(req, metadata=self._builder.metadata())
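
Every request builder in this file repeats the same two-line guard after the user_id check. The commit keeps the guards inline; purely as an illustration, here is a hypothetical helper (not part of this change) that would centralize the pattern for any request exposing a user_context field:

def attach_local_auth_token(client, req):
    # Hypothetical consolidation of the repeated guard above: copy the
    # class-level token onto the outgoing request when one has been minted.
    if client._local_auth_token:
        req.user_context.local_auth_token = client._local_auth_token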

python/pyspark/sql/connect/plan.py

Lines changed: 2 additions & 0 deletions

@@ -633,6 +633,8 @@ def __del__(self) -> None:
             req = session.client._execute_plan_request_with_metadata()
             if session.client._user_id:
                 req.user_context.user_id = session.client._user_id
+            if session.client._local_auth_token:
+                req.user_context.local_auth_token = session.client._local_auth_token
             req.plan.command.CopyFrom(command)
 
             for attempt in session.client._retrying():

python/pyspark/sql/connect/proto/base_pb2.py

Lines changed: 177 additions & 177 deletions
Large diffs are not rendered by default.

python/pyspark/sql/connect/proto/base_pb2.pyi

Lines changed: 26 additions & 1 deletion

@@ -104,9 +104,15 @@ class UserContext(google.protobuf.message.Message):
 
     USER_ID_FIELD_NUMBER: builtins.int
     USER_NAME_FIELD_NUMBER: builtins.int
+    LOCAL_AUTH_TOKEN_FIELD_NUMBER: builtins.int
     EXTENSIONS_FIELD_NUMBER: builtins.int
     user_id: builtins.str
     user_name: builtins.str
+    local_auth_token: builtins.str
+    """(Optional)
+
+    Authentication token. This is used internally only for local execution.
+    """
     @property
     def extensions(
         self,
@@ -123,14 +129,33 @@ class UserContext(google.protobuf.message.Message):
         *,
         user_id: builtins.str = ...,
         user_name: builtins.str = ...,
+        local_auth_token: builtins.str | None = ...,
         extensions: collections.abc.Iterable[google.protobuf.any_pb2.Any] | None = ...,
     ) -> None: ...
+    def HasField(
+        self,
+        field_name: typing_extensions.Literal[
+            "_local_auth_token", b"_local_auth_token", "local_auth_token", b"local_auth_token"
+        ],
+    ) -> builtins.bool: ...
     def ClearField(
         self,
         field_name: typing_extensions.Literal[
-            "extensions", b"extensions", "user_id", b"user_id", "user_name", b"user_name"
+            "_local_auth_token",
+            b"_local_auth_token",
+            "extensions",
+            b"extensions",
+            "local_auth_token",
+            b"local_auth_token",
+            "user_id",
+            b"user_id",
+            "user_name",
+            b"user_name",
         ],
     ) -> None: ...
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["_local_auth_token", b"_local_auth_token"]
+    ) -> typing_extensions.Literal["local_auth_token"] | None: ...
 
 global___UserContext = UserContext
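
The new HasField and WhichOneof overloads appear because a proto3 optional field is implemented as a synthetic single-field oneof (here _local_auth_token), which gives the field explicit presence tracking. A quick presence-semantics check against the generated classes (a sketch; the token value is illustrative):

import pyspark.sql.connect.proto as proto

ctx = proto.UserContext(user_id="u")
print(ctx.HasField("local_auth_token"))     # False: optional fields start unset
ctx.local_auth_token = "token-for-illustration"
print(ctx.HasField("local_auth_token"))     # True once assigned
print(ctx.WhichOneof("_local_auth_token"))  # "local_auth_token"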

python/pyspark/sql/connect/session.py

Lines changed: 4 additions & 0 deletions

@@ -40,6 +40,7 @@
     TYPE_CHECKING,
     ClassVar,
 )
+import uuid
 
 import numpy as np
 import pandas as pd
@@ -1060,6 +1061,8 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
         overwrite_conf["spark.connect.grpc.binding.port"] = "0"
 
         origin_remote = os.environ.get("SPARK_REMOTE", None)
+        SparkConnectClient._local_auth_token = str(uuid.uuid4())
+        os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"] = SparkConnectClient._local_auth_token
         try:
             if origin_remote is not None:
                 # So SparkSubmit thinks no remote is set in order to
@@ -1084,6 +1087,7 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
             if origin_remote is not None:
                 os.environ["SPARK_REMOTE"] = origin_remote
             del os.environ["SPARK_LOCAL_CONNECT"]
+            del os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"]
     else:
         raise PySparkRuntimeError(
             errorClass="SESSION_OR_CONTEXT_EXISTS",
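
The flow in _start_connect_server is: mint a fresh UUID token, export it so the spawned local server inherits it, and delete it from the parent environment once startup is done (the client side keeps its copy on SparkConnectClient._local_auth_token). A standalone sketch of that handshake, with launch_server standing in hypothetically for the real submit step:

import os
import uuid
from typing import Callable

def start_with_local_token(launch_server: Callable[[], None]) -> str:
    # Mint a per-process token and expose it to the child via the environment.
    token = str(uuid.uuid4())
    os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"] = token
    try:
        launch_server()
    finally:
        # Scrub the variable from the parent once the child has inherited it.
        del os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"]
    return token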

sql/connect/common/src/main/protobuf/spark/connect/base.proto

Lines changed: 4 additions & 0 deletions

@@ -49,6 +49,10 @@ message Plan {
 message UserContext {
   string user_id = 1;
   string user_name = 2;
+  // (Optional)
+  //
+  // Authentication token. This is used internally only for local execution.
+  optional string local_auth_token = 3;
 
   // To extend the existing user context message that is used to identify incoming requests,
   // Spark Connect leverages the Any protobuf type that can be used to inject arbitrary other

sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala

Lines changed: 6 additions & 0 deletions

@@ -48,6 +48,7 @@ import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toLiteral
 import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, SparkConnectClient, SparkResult}
 import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
 import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
+import org.apache.spark.sql.connect.common.config.ConnectCommon
 import org.apache.spark.sql.internal.{SessionState, SharedState, SqlApiConf}
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.types.StructType
@@ -720,13 +721,18 @@ object SparkSession extends SparkSessionCompanion with Logging {
           (remoteString.isDefined && isAPIModeConnect)) &&
         maybeConnectScript.exists(Files.exists(_))) {
       server = Some {
+        ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN = Option(java.util.UUID.randomUUID().toString())
         val args =
           Seq(maybeConnectScript.get.toString, "--master", remoteString.get) ++ sparkOptions
             .filter(p => !p._1.startsWith("spark.remote"))
             .flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
         val pb = new ProcessBuilder(args: _*)
         // So don't exclude spark-sql jar in classpath
         pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
+        pb.environment()
+          .put(
+            ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME,
+            ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN.get)
         pb.start()
       }
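
The ProcessBuilder block hands the token to the child server process through its environment while also removing SPARK_REMOTE. A rough Python analogue of that environment handoff (the script name is illustrative only; in the Scala code the real path comes from maybeConnectScript):

import os
import subprocess
import uuid

token = str(uuid.uuid4())
env = dict(os.environ)
env.pop("SPARK_REMOTE", None)                  # mirrors pb.environment().remove(...)
env["SPARK_CONNECT_LOCAL_AUTH_TOKEN"] = token  # mirrors pb.environment().put(...)

# "start-connect-server.sh" is a stand-in command for illustration.
subprocess.Popen(["start-connect-server.sh", "--master", "local[*]"], env=env)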

sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala

Lines changed: 1 addition & 0 deletions

@@ -745,6 +745,7 @@ object SparkConnectClient {
     if (userName != null) {
       builder.setUserName(userName)
     }
+    ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN.foreach(builder.setLocalAuthToken)
     builder.build()
   }

sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/config/ConnectCommon.scala

Lines changed: 3 additions & 0 deletions

@@ -21,4 +21,7 @@ private[sql] object ConnectCommon {
   val CONNECT_GRPC_PORT_MAX_RETRIES: Int = 0
   val CONNECT_GRPC_MAX_MESSAGE_SIZE: Int = 128 * 1024 * 1024
   val CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT: Int = 1024
+  // Set only when we locally run Spark Connect server.
+  val CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME = "SPARK_CONNECT_LOCAL_AUTH_TOKEN"
+  var CONNECT_LOCAL_AUTH_TOKEN: Option[String] = None
 }
