Provide a basic authentication token when running Spark Connect server locally

HyukjinKwon committed Feb 11, 2025
1 parent cea79dc commit 018c25a
Showing 12 changed files with 266 additions and 179 deletions.
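The changes wire a shared secret between a locally launched Spark Connect server and its client: the client generates a UUID token, exports it through the SPARK_CONNECT_LOCAL_AUTH_TOKEN environment variable when it spawns the server, stamps the token into every request's UserContext, and each server RPC handler asserts that the incoming token matches. Below is a minimal, self-contained Python sketch of that handshake; LocalClient, LocalServer, and their methods are illustrative stand-ins, not part of the PySpark API.

```python
import os
import uuid
from dataclasses import dataclass
from typing import Optional

ENV_NAME = "SPARK_CONNECT_LOCAL_AUTH_TOKEN"  # env var introduced by this commit

@dataclass
class UserContext:
    user_id: str = ""
    local_auth_token: Optional[str] = None  # mirrors the new optional proto field

class LocalClient:
    _local_auth_token: Optional[str] = None  # mirrors SparkConnectClient._local_auth_token

    def start_local_server(self) -> None:
        # Generate a random token and hand it to the child server via the environment.
        LocalClient._local_auth_token = str(uuid.uuid4())
        os.environ[ENV_NAME] = LocalClient._local_auth_token

    def build_request(self, user_id: str) -> UserContext:
        ctx = UserContext(user_id=user_id)
        if self._local_auth_token:
            ctx.local_auth_token = self._local_auth_token
        return ctx

class LocalServer:
    def __init__(self) -> None:
        # The server reads the token once at startup, as the SparkConnectService
        # change further down does with System.getenv.
        self.expected = os.environ.get(ENV_NAME)

    def handle(self, ctx: UserContext) -> None:
        # Every RPC handler asserts the token when one is configured.
        if self.expected is not None:
            assert ctx.local_auth_token == self.expected, "invalid local auth token"

if __name__ == "__main__":
    client = LocalClient()
    client.start_local_server()
    server = LocalServer()
    server.handle(client.build_request("local-user"))  # passes; a stale token would fail
```

In the actual patch those three pieces live in python/pyspark/sql/connect/session.py (token generation), core.py and artifact.py (attaching it to requests), and the Scala SparkConnectService handlers (verification).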
4 changes: 4 additions & 0 deletions python/pyspark/sql/connect/client/artifact.py
@@ -174,9 +174,13 @@ def __init__(
channel: grpc.Channel,
metadata: Iterable[Tuple[str, str]],
):
from pyspark.sql.connect.client import SparkConnectClient

self._user_context = proto.UserContext()
if user_id is not None:
self._user_context.user_id = user_id
if SparkConnectClient._local_auth_token is not None:
self._user_context.local_auth_token = SparkConnectClient._local_auth_token
self._stub = grpc_lib.SparkConnectServiceStub(channel)
self._session_id = session_id
self._metadata = metadata
14 changes: 14 additions & 0 deletions python/pyspark/sql/connect/client/core.py
Expand Up @@ -591,6 +591,8 @@ class SparkConnectClient(object):
Conceptually the remote spark session that communicates with the server
"""

_local_auth_token = None

def __init__(
self,
connection: Union[str, ChannelBuilder],
@@ -1123,6 +1125,8 @@ def execute_command(
req = self._execute_plan_request_with_metadata()
if self._user_id:
req.user_context.user_id = self._user_id
if self._local_auth_token:
req.user_context.local_auth_token = self._local_auth_token
req.plan.command.CopyFrom(command)
data, _, metrics, observed_metrics, properties = self._execute_and_fetch(
req, observations or {}
@@ -1149,6 +1153,8 @@ def execute_command_as_iterator(
req = self._execute_plan_request_with_metadata()
if self._user_id:
req.user_context.user_id = self._user_id
if self._local_auth_token:
req.user_context.local_auth_token = self._local_auth_token
req.plan.command.CopyFrom(command)
for response in self._execute_and_fetch_as_iterator(req, observations or {}):
if isinstance(response, dict):
@@ -1217,6 +1223,8 @@ def _execute_plan_request_with_metadata(self) -> pb2.ExecutePlanRequest:
req.client_observed_server_side_session_id = self._server_session_id
if self._user_id:
req.user_context.user_id = self._user_id
if self._local_auth_token:
req.user_context.local_auth_token = self._local_auth_token
return req

def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest:
@@ -1227,6 +1235,8 @@ def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest:
req.client_type = self._builder.userAgent
if self._user_id:
req.user_context.user_id = self._user_id
if self._local_auth_token:
req.user_context.local_auth_token = self._local_auth_token
return req

def _analyze(self, method: str, **kwargs: Any) -> AnalyzeResult:
@@ -1591,6 +1601,8 @@ def _config_request_with_metadata(self) -> pb2.ConfigRequest:
req.client_type = self._builder.userAgent
if self._user_id:
req.user_context.user_id = self._user_id
if self._local_auth_token:
req.user_context.local_auth_token = self._local_auth_token
return req

def get_configs(self, *keys: str) -> Tuple[Optional[str], ...]:
@@ -1667,6 +1679,8 @@ def _interrupt_request(
)
if self._user_id:
req.user_context.user_id = self._user_id
if self._local_auth_token:
req.user_context.local_auth_token = self._local_auth_token
return req

def interrupt_all(self) -> Optional[List[str]]:
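Every request builder in core.py now copies SparkConnectClient._local_auth_token into user_context.local_auth_token when it is set. A small example of what that produces, assuming a PySpark build that already contains this patch (the new optional proto field does not exist in earlier releases):

```python
import uuid

import pyspark.sql.connect.proto as pb2

# Simulate the client-side state set up when the local server is started.
local_auth_token = str(uuid.uuid4())

req = pb2.ExecutePlanRequest()
req.user_context.user_id = "local-user"
# Mirrors the `if self._local_auth_token:` branches added above.
req.user_context.local_auth_token = local_auth_token

print(req.user_context.HasField("local_auth_token"))  # True once the field is set
```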
2 changes: 2 additions & 0 deletions python/pyspark/sql/connect/plan.py
@@ -633,6 +633,8 @@ def __del__(self) -> None:
req = session.client._execute_plan_request_with_metadata()
if session.client._user_id:
req.user_context.user_id = session.client._user_id
if session.client._local_auth_token:
req.user_context.local_auth_token = session.client._local_auth_token
req.plan.command.CopyFrom(command)

for attempt in session.client._retrying():
354 changes: 177 additions & 177 deletions python/pyspark/sql/connect/proto/base_pb2.py

Large diffs are not rendered by default.

27 changes: 26 additions & 1 deletion python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -104,9 +104,15 @@ class UserContext(google.protobuf.message.Message):

USER_ID_FIELD_NUMBER: builtins.int
USER_NAME_FIELD_NUMBER: builtins.int
LOCAL_AUTH_TOKEN_FIELD_NUMBER: builtins.int
EXTENSIONS_FIELD_NUMBER: builtins.int
user_id: builtins.str
user_name: builtins.str
local_auth_token: builtins.str
"""(Optional)
Authentication token. This is used internally only for local execution.
"""
@property
def extensions(
self,
@@ -123,14 +129,33 @@
*,
user_id: builtins.str = ...,
user_name: builtins.str = ...,
local_auth_token: builtins.str | None = ...,
extensions: collections.abc.Iterable[google.protobuf.any_pb2.Any] | None = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
"_local_auth_token", b"_local_auth_token", "local_auth_token", b"local_auth_token"
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"extensions", b"extensions", "user_id", b"user_id", "user_name", b"user_name"
"_local_auth_token",
b"_local_auth_token",
"extensions",
b"extensions",
"local_auth_token",
b"local_auth_token",
"user_id",
b"user_id",
"user_name",
b"user_name",
],
) -> None: ...
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_local_auth_token", b"_local_auth_token"]
) -> typing_extensions.Literal["local_auth_token"] | None: ...

global___UserContext = UserContext

4 changes: 4 additions & 0 deletions python/pyspark/sql/connect/session.py
@@ -40,6 +40,7 @@
TYPE_CHECKING,
ClassVar,
)
import uuid

import numpy as np
import pandas as pd
@@ -1060,6 +1061,8 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
overwrite_conf["spark.connect.grpc.binding.port"] = "0"

origin_remote = os.environ.get("SPARK_REMOTE", None)
SparkConnectClient._local_auth_token = str(uuid.uuid4())
os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"] = SparkConnectClient._local_auth_token
try:
if origin_remote is not None:
# So SparkSubmit thinks no remote is set in order to
@@ -1084,6 +1087,7 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
if origin_remote is not None:
os.environ["SPARK_REMOTE"] = origin_remote
del os.environ["SPARK_LOCAL_CONNECT"]
del os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"]
else:
raise PySparkRuntimeError(
errorClass="SESSION_OR_CONTEXT_EXISTS",
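The session.py change above generates the token right before the local server process is spawned and removes the environment variable again during cleanup. A simplified sketch of that lifecycle; launch_local_connect_server and the script path are hypothetical stand-ins for the spark-submit machinery that _start_connect_server actually drives:

```python
import os
import subprocess
import uuid

def launch_local_connect_server(connect_script: str) -> tuple[subprocess.Popen, str]:
    # Generate the per-session secret and expose it to the child process.
    token = str(uuid.uuid4())
    os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"] = token
    try:
        # The child inherits the environment at spawn time, so the server still
        # sees the token after the variable is deleted below.
        proc = subprocess.Popen([connect_script, "--master", "local[*]"])
    finally:
        del os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"]
    # The caller keeps the token in memory (the real code stores it on
    # SparkConnectClient._local_auth_token) and attaches it to every request.
    return proc, token
```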
4 changes: 4 additions & 0 deletions sql/connect/common/src/main/protobuf/spark/connect/base.proto
@@ -49,6 +49,10 @@ message Plan {
message UserContext {
string user_id = 1;
string user_name = 2;
// (Optional)
//
// Authentication token. This is used internally only for local execution.
optional string local_auth_token = 3;

// To extend the existing user context message that is used to identify incoming requests,
// Spark Connect leverages the Any protobuf type that can be used to inject arbitrary other
Original file line number Diff line number Diff line change
@@ -48,6 +48,7 @@ import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toLiteral
import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, SparkConnectClient, SparkResult}
import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
import org.apache.spark.sql.connect.common.config.ConnectCommon
import org.apache.spark.sql.internal.{SessionState, SharedState, SqlApiConf}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.StructType
@@ -720,13 +721,18 @@
(remoteString.isDefined && isAPIModeConnect)) &&
maybeConnectScript.exists(Files.exists(_))) {
server = Some {
ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN = Option(java.util.UUID.randomUUID().toString())
val args =
Seq(maybeConnectScript.get.toString, "--master", remoteString.get) ++ sparkOptions
.filter(p => !p._1.startsWith("spark.remote"))
.flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
val pb = new ProcessBuilder(args: _*)
// So don't exclude spark-sql jar in classpath
pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
pb.environment()
.put(
ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME,
ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN.get)
pb.start()
}

Original file line number Diff line number Diff line change
@@ -745,6 +745,7 @@ object SparkConnectClient {
if (userName != null) {
builder.setUserName(userName)
}
ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN.foreach(builder.setLocalAuthToken)
builder.build()
}

Original file line number Diff line number Diff line change
@@ -21,4 +21,7 @@ private[sql] object ConnectCommon {
val CONNECT_GRPC_PORT_MAX_RETRIES: Int = 0
val CONNECT_GRPC_MAX_MESSAGE_SIZE: Int = 128 * 1024 * 1024
val CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT: Int = 1024
// Set only when we locally run Spark Connect server.
val CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME = "SPARK_CONNECT_LOCAL_AUTH_TOKEN"
var CONNECT_LOCAL_AUTH_TOKEN: Option[String] = None
}
Original file line number Diff line number Diff line change
@@ -30,6 +30,7 @@ import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse}
import org.apache.spark.connect.proto.AddArtifactsResponse.ArtifactSummary
import org.apache.spark.sql.artifact.ArtifactManager
import org.apache.spark.sql.connect.common.config.ConnectCommon
import org.apache.spark.sql.connect.utils.ErrorUtils
import org.apache.spark.sql.util.ArtifactUtils
import org.apache.spark.util.Utils
@@ -52,6 +53,8 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr
private var holder: SessionHolder = _

override def onNext(req: AddArtifactsRequest): Unit = try {
ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN.foreach(k =>
assert(k == req.getUserContext.getLocalAuthToken))
if (this.holder == null) {
val previousSessionId = req.hasClientObservedServerSideSessionId match {
case true => Some(req.getClientObservedServerSideSessionId)
Original file line number Diff line number Diff line change
@@ -39,7 +39,8 @@ import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.HOST
import org.apache.spark.internal.config.UI.UI_ENABLED
import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerEvent}
import org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_BINDING_ADDRESS, CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT, CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE, CONNECT_GRPC_PORT_MAX_RETRIES}
import org.apache.spark.sql.connect.common.config.ConnectCommon
import org.apache.spark.sql.connect.config.Connect._
import org.apache.spark.sql.connect.execution.ConnectProgressExecutionListener
import org.apache.spark.sql.connect.ui.{SparkConnectServerAppStatusStore, SparkConnectServerListener, SparkConnectServerTab}
import org.apache.spark.sql.connect.utils.ErrorUtils
@@ -69,6 +70,8 @@ class SparkConnectService(debug: Boolean) extends AsyncService with BindableServ
override def executePlan(
request: proto.ExecutePlanRequest,
responseObserver: StreamObserver[proto.ExecutePlanResponse]): Unit = {
ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN.foreach(k =>
assert(k == request.getUserContext.getLocalAuthToken))
try {
new SparkConnectExecutePlanHandler(responseObserver).handle(request)
} catch {
@@ -95,6 +98,8 @@ class SparkConnectService(debug: Boolean) extends AsyncService with BindableServ
override def analyzePlan(
request: proto.AnalyzePlanRequest,
responseObserver: StreamObserver[proto.AnalyzePlanResponse]): Unit = {
ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN.foreach(k =>
assert(k == request.getUserContext.getLocalAuthToken))
try {
new SparkConnectAnalyzeHandler(responseObserver).handle(request)
} catch {
@@ -116,6 +121,8 @@ class SparkConnectService(debug: Boolean) extends AsyncService with BindableServ
override def config(
request: proto.ConfigRequest,
responseObserver: StreamObserver[proto.ConfigResponse]): Unit = {
ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN.foreach(k =>
assert(k == request.getUserContext.getLocalAuthToken))
try {
new SparkConnectConfigHandler(responseObserver).handle(request)
} catch {
@@ -143,6 +150,8 @@ class SparkConnectService(debug: Boolean) extends AsyncService with BindableServ
override def artifactStatus(
request: proto.ArtifactStatusesRequest,
responseObserver: StreamObserver[proto.ArtifactStatusesResponse]): Unit = {
ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN.foreach(k =>
assert(k == request.getUserContext.getLocalAuthToken))
try {
new SparkConnectArtifactStatusesHandler(responseObserver).handle(request)
} catch
@@ -159,6 +168,8 @@ class SparkConnectService(debug: Boolean) extends AsyncService with BindableServ
override def interrupt(
request: proto.InterruptRequest,
responseObserver: StreamObserver[proto.InterruptResponse]): Unit = {
ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN.foreach(k =>
assert(k == request.getUserContext.getLocalAuthToken))
try {
new SparkConnectInterruptHandler(responseObserver).handle(request)
} catch
@@ -175,6 +186,8 @@ class SparkConnectService(debug: Boolean) extends AsyncService with BindableServ
override def reattachExecute(
request: proto.ReattachExecuteRequest,
responseObserver: StreamObserver[proto.ExecutePlanResponse]): Unit = {
ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN.foreach(k =>
assert(k == request.getUserContext.getLocalAuthToken))
try {
new SparkConnectReattachExecuteHandler(responseObserver).handle(request)
} catch
@@ -191,6 +204,8 @@ class SparkConnectService(debug: Boolean) extends AsyncService with BindableServ
override def releaseExecute(
request: proto.ReleaseExecuteRequest,
responseObserver: StreamObserver[proto.ReleaseExecuteResponse]): Unit = {
ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN.foreach(k =>
assert(k == request.getUserContext.getLocalAuthToken))
try {
new SparkConnectReleaseExecuteHandler(responseObserver).handle(request)
} catch
@@ -207,6 +222,8 @@ class SparkConnectService(debug: Boolean) extends AsyncService with BindableServ
override def releaseSession(
request: proto.ReleaseSessionRequest,
responseObserver: StreamObserver[proto.ReleaseSessionResponse]): Unit = {
ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN.foreach(k =>
assert(k == request.getUserContext.getLocalAuthToken))
try {
new SparkConnectReleaseSessionHandler(responseObserver).handle(request)
} catch
@@ -220,6 +237,8 @@ class SparkConnectService(debug: Boolean) extends AsyncService with BindableServ
override def fetchErrorDetails(
request: proto.FetchErrorDetailsRequest,
responseObserver: StreamObserver[proto.FetchErrorDetailsResponse]): Unit = {
ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN.foreach(k =>
assert(k == request.getUserContext.getLocalAuthToken))
try {
new SparkConnectFetchErrorDetailsHandler(responseObserver).handle(request)
} catch {
@@ -366,6 +385,8 @@ object SparkConnectService extends Logging {
val debugMode = SparkEnv.get.conf.getBoolean("spark.connect.grpc.debug.enabled", true)
val bindAddress = SparkEnv.get.conf.get(CONNECT_GRPC_BINDING_ADDRESS)
val startPort = SparkEnv.get.conf.get(CONNECT_GRPC_BINDING_PORT)
ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN = Option(
System.getenv(ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME))
val sparkConnectService = new SparkConnectService(debugMode)
val protoReflectionService =
if (debugMode) Some(ProtoReflectionService.newInstance()) else None
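Each handler in SparkConnectService repeats the same guard: if a local token was configured at startup, the request's UserContext must carry an identical value. Rendered here in Python for consistency with the earlier sketches (the actual check is the Scala assert shown above); verify_local_auth_token is illustrative, not a Spark API:

```python
import os
from typing import Optional

import pyspark.sql.connect.proto as pb2

# Read once at server startup, mirroring ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN.
EXPECTED_TOKEN: Optional[str] = os.environ.get("SPARK_CONNECT_LOCAL_AUTH_TOKEN")

def verify_local_auth_token(user_context: pb2.UserContext) -> None:
    # Only enforced when the server was launched with a local token; remote
    # deployments leave EXPECTED_TOKEN unset and skip the check entirely.
    if EXPECTED_TOKEN is not None:
        assert user_context.local_auth_token == EXPECTED_TOKEN, (
            "request does not carry the expected local auth token"
        )
```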
