Commit 1edae37

Provide a basic authentication token when running Spark Connect server locally

HyukjinKwon committed Feb 11, 2025
1 parent cea79dc, commit 1edae37

Showing 12 changed files with 337 additions and 209 deletions.
4 changes: 4 additions & 0 deletions python/pyspark/sql/connect/client/artifact.py
@@ -174,9 +174,13 @@ def __init__(
channel: grpc.Channel,
metadata: Iterable[Tuple[str, str]],
):
from pyspark.sql.connect.client import SparkConnectClient

self._user_context = proto.UserContext()
if user_id is not None:
self._user_context.user_id = user_id
if SparkConnectClient._local_auth_token is not None:
self._user_context.local_auth_token = SparkConnectClient._local_auth_token
self._stub = grpc_lib.SparkConnectServiceStub(channel)
self._session_id = session_id
self._metadata = metadata
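Note: the token is read off `SparkConnectClient` as a class-level attribute, so the artifact manager picks it up without any constructor plumbing. A hedged sketch of that sharing pattern follows; `Client` and `ArtifactManager` here are illustrative stand-ins, not the real pyspark classes.

```python
# Sketch of the class-attribute sharing used above: the token is set once,
# process-wide, and collaborators read it lazily at request-build time.
from typing import Optional


class Client:
    _local_auth_token: Optional[str] = None  # set by the session bootstrap


class ArtifactManager:
    def build_user_context(self, user_id: Optional[str]) -> dict:
        ctx: dict = {}
        if user_id is not None:
            ctx["user_id"] = user_id
        # Read at call time, so a token set after construction is still seen.
        if Client._local_auth_token is not None:
            ctx["local_auth_token"] = Client._local_auth_token
        return ctx


Client._local_auth_token = "example-token"
assert ArtifactManager().build_user_context("u1")["local_auth_token"] == "example-token"
```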
18 changes: 18 additions & 0 deletions python/pyspark/sql/connect/client/core.py
@@ -591,6 +591,8 @@ class SparkConnectClient(object):
Conceptually the remote spark session that communicates with the server
"""

_local_auth_token: Optional[str] = None

def __init__(
self,
connection: Union[str, ChannelBuilder],
@@ -1123,6 +1125,8 @@ def execute_command(
req = self._execute_plan_request_with_metadata()
if self._user_id:
req.user_context.user_id = self._user_id
if self._local_auth_token:
req.user_context.local_auth_token = self._local_auth_token
req.plan.command.CopyFrom(command)
data, _, metrics, observed_metrics, properties = self._execute_and_fetch(
req, observations or {}
Expand All @@ -1149,6 +1153,8 @@ def execute_command_as_iterator(
req = self._execute_plan_request_with_metadata()
if self._user_id:
req.user_context.user_id = self._user_id
if self._local_auth_token:
req.user_context.local_auth_token = self._local_auth_token
req.plan.command.CopyFrom(command)
for response in self._execute_and_fetch_as_iterator(req, observations or {}):
if isinstance(response, dict):
@@ -1217,6 +1223,8 @@ def _execute_plan_request_with_metadata(self) -> pb2.ExecutePlanRequest:
req.client_observed_server_side_session_id = self._server_session_id
if self._user_id:
req.user_context.user_id = self._user_id
if self._local_auth_token:
req.user_context.local_auth_token = self._local_auth_token
return req

def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest:
@@ -1227,6 +1235,8 @@ def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest:
req.client_type = self._builder.userAgent
if self._user_id:
req.user_context.user_id = self._user_id
if self._local_auth_token:
req.user_context.local_auth_token = self._local_auth_token
return req

def _analyze(self, method: str, **kwargs: Any) -> AnalyzeResult:
@@ -1591,6 +1601,8 @@ def _config_request_with_metadata(self) -> pb2.ConfigRequest:
req.client_type = self._builder.userAgent
if self._user_id:
req.user_context.user_id = self._user_id
if self._local_auth_token:
req.user_context.local_auth_token = self._local_auth_token
return req

def get_configs(self, *keys: str) -> Tuple[Optional[str], ...]:
@@ -1667,6 +1679,8 @@ def _interrupt_request(
)
if self._user_id:
req.user_context.user_id = self._user_id
if self._local_auth_token:
req.user_context.local_auth_token = self._local_auth_token
return req

def interrupt_all(self) -> Optional[List[str]]:
@@ -1711,6 +1725,8 @@ def release_session(self) -> None:
req.client_type = self._builder.userAgent
if self._user_id:
req.user_context.user_id = self._user_id
if self._local_auth_token:
req.user_context.local_auth_token = self._local_auth_token
try:
for attempt in self._retrying():
with attempt:
@@ -1810,6 +1826,8 @@ def _fetch_enriched_error(self, info: "ErrorInfo") -> Optional[pb2.FetchErrorDet
req.client_observed_server_side_session_id = self._server_session_id
if self._user_id:
req.user_context.user_id = self._user_id
if self._local_auth_token:
req.user_context.local_auth_token = self._local_auth_token

try:
return self._stub.FetchErrorDetails(req, metadata=self._builder.metadata())
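Note: the same three-line guard is stamped onto every request type in this file. The repetition suggests a small helper; a hedged sketch of that refactor follows (the helper is hypothetical, not part of this commit).

```python
# Hypothetical helper collapsing the guard repeated in every request builder
# above; `req` is any request proto carrying a `user_context` field
# (ExecutePlanRequest, AnalyzePlanRequest, ConfigRequest, ...).
from typing import Any, Optional


def stamp_user_context(req: Any, user_id: Optional[str], local_auth_token: Optional[str]) -> None:
    """Copy identity fields onto a request proto's user_context."""
    if user_id:
        req.user_context.user_id = user_id
    if local_auth_token:
        req.user_context.local_auth_token = local_auth_token
```

Each call site above (`execute_command`, `_analyze_plan_request_with_metadata`, `_config_request_with_metadata`, and so on) would then reduce to a single line.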
2 changes: 2 additions & 0 deletions python/pyspark/sql/connect/plan.py
@@ -633,6 +633,8 @@ def __del__(self) -> None:
req = session.client._execute_plan_request_with_metadata()
if session.client._user_id:
req.user_context.user_id = session.client._user_id
if session.client._local_auth_token:
req.user_context.local_auth_token = session.client._local_auth_token
req.plan.command.CopyFrom(command)

for attempt in session.client._retrying():
354 changes: 177 additions & 177 deletions python/pyspark/sql/connect/proto/base_pb2.py

Large diffs are not rendered by default.

27 changes: 26 additions & 1 deletion python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -104,9 +104,15 @@ class UserContext(google.protobuf.message.Message):

USER_ID_FIELD_NUMBER: builtins.int
USER_NAME_FIELD_NUMBER: builtins.int
LOCAL_AUTH_TOKEN_FIELD_NUMBER: builtins.int
EXTENSIONS_FIELD_NUMBER: builtins.int
user_id: builtins.str
user_name: builtins.str
local_auth_token: builtins.str
"""(Optional)
Authentication token. This is used internally only for local execution.
"""
@property
def extensions(
self,
@@ -123,14 +129,33 @@
*,
user_id: builtins.str = ...,
user_name: builtins.str = ...,
local_auth_token: builtins.str | None = ...,
extensions: collections.abc.Iterable[google.protobuf.any_pb2.Any] | None = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
"_local_auth_token", b"_local_auth_token", "local_auth_token", b"local_auth_token"
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"extensions", b"extensions", "user_id", b"user_id", "user_name", b"user_name"
"_local_auth_token",
b"_local_auth_token",
"extensions",
b"extensions",
"local_auth_token",
b"local_auth_token",
"user_id",
b"user_id",
"user_name",
b"user_name",
],
) -> None: ...
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_local_auth_token", b"_local_auth_token"]
) -> typing_extensions.Literal["local_auth_token"] | None: ...

global___UserContext = UserContext

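Note: the `_local_auth_token` literals in `HasField`/`ClearField`/`WhichOneof` come from proto3's `optional`, which wraps the field in a synthetic oneof named `_local_auth_token` to track presence. A hedged sketch of the resulting semantics, assuming a pyspark build that includes this commit:

```python
from pyspark.sql.connect import proto

ctx = proto.UserContext(user_id="u1")
assert not ctx.HasField("local_auth_token")   # unset: presence is tracked
ctx.local_auth_token = ""                     # even an empty string counts as set
assert ctx.HasField("local_auth_token")
assert ctx.WhichOneof("_local_auth_token") == "local_auth_token"
ctx.ClearField("local_auth_token")
assert ctx.WhichOneof("_local_auth_token") is None
```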
18 changes: 18 additions & 0 deletions python/pyspark/sql/connect/session.py
@@ -40,6 +40,7 @@
TYPE_CHECKING,
ClassVar,
)
import uuid

import numpy as np
import pandas as pd
@@ -893,13 +894,20 @@ def stop(self) -> None:
if self is getattr(SparkSession._active_session, "session", None):
SparkSession._active_session.session = None

# It should be `None` always on stop.
SparkConnectClient._local_auth_token = None

if "SPARK_LOCAL_REMOTE" in os.environ:
# When local mode is in use, follow the regular Spark session's
# behavior by terminating the Spark Connect server,
# meaning that you can stop local mode, and restart the Spark Connect
# client with a different remote address.
if PySparkSession._activeSession is not None:
try:
getattr(
PySparkSession._activeSession._jvm,
"org.apache.spark.sql.connect.common.config.ConnectCommon",
).setLocalAuthToken(None)
PySparkSession._activeSession.stop()
except Exception as e:
logger.warn(
@@ -1060,6 +1068,8 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
overwrite_conf["spark.connect.grpc.binding.port"] = "0"

origin_remote = os.environ.get("SPARK_REMOTE", None)
SparkConnectClient._local_auth_token = str(uuid.uuid4())
os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"] = SparkConnectClient._local_auth_token
try:
if origin_remote is not None:
# So SparkSubmit thinks no remote is set in order to
@@ -1072,6 +1082,13 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
conf.setAll(list(overwrite_conf.items())).setAll(list(default_conf.items()))
PySparkSession(SparkContext.getOrCreate(conf))

# In Python local mode, session.stop does not terminate JVM itself
# so we can't control it via environment variable.
getattr(
SparkContext._jvm,
"org.apache.spark.sql.connect.common.config.ConnectCommon",
).setLocalAuthToken(SparkConnectClient._local_auth_token)

# Lastly only keep runtime configurations because other configurations are
# disallowed to set in the regular Spark Connect session.
utl = SparkContext._jvm.PythonSQLUtils # type: ignore[union-attr]
@@ -1084,6 +1101,7 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
if origin_remote is not None:
os.environ["SPARK_REMOTE"] = origin_remote
del os.environ["SPARK_LOCAL_CONNECT"]
del os.environ["SPARK_CONNECT_LOCAL_AUTH_TOKEN"]
else:
raise PySparkRuntimeError(
errorClass="SESSION_OR_CONTEXT_EXISTS",
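Note: putting the Python-side pieces together, the lifecycle is: generate a fresh UUID per local server launch, export it through `SPARK_CONNECT_LOCAL_AUTH_TOKEN` so the forked server inherits it, then clear both the env var and the client-side copy on stop. A condensed, hedged sketch of that shape (not the actual pyspark control flow):

```python
# Condensed sketch of the token lifecycle implemented above.
import os
import uuid

ENV_NAME = "SPARK_CONNECT_LOCAL_AUTH_TOKEN"


def launch_with_token(launch_server) -> str:
    """Generate a per-launch secret and expose it to the child server only."""
    token = str(uuid.uuid4())
    os.environ[ENV_NAME] = token   # inherited by the forked server process
    try:
        launch_server()            # e.g. spin up the local Connect server
    finally:
        del os.environ[ENV_NAME]   # mirrors the cleanup in _start_connect_server
    return token
```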
4 changes: 4 additions & 0 deletions sql/connect/common/src/main/protobuf/spark/connect/base.proto
@@ -49,6 +49,10 @@ message Plan {
message UserContext {
string user_id = 1;
string user_name = 2;
// (Optional)
//
// Authentication token. This is used internally only for local execution.
optional string local_auth_token = 3;

// To extend the existing user context message that is used to identify incoming requests,
// Spark Connect leverages the Any protobuf type that can be used to inject arbitrary other

@@ -48,6 +48,7 @@ import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toLiteral
import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, SparkConnectClient, SparkResult}
import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
import org.apache.spark.sql.connect.common.config.ConnectCommon
import org.apache.spark.sql.internal.{SessionState, SharedState, SqlApiConf}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.StructType
@@ -627,6 +628,19 @@ class SparkSession private[sql] (
}
allocator.close()
SparkSession.onSessionClose(this)
SparkSession.server.synchronized {
// It should be `None` always on stop.
ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN = None
if (SparkSession.server.isDefined) {
// When local mode is in use, follow the regular Spark session's
// behavior by terminating the Spark Connect server,
// meaning that you can stop local mode, and restart the Spark Connect
// client with a different remote address.
new ProcessBuilder(SparkSession.maybeConnectStopScript.get.toString)
.start()
SparkSession.server = None
}
}
}

/** @inheritdoc */
@@ -679,6 +693,10 @@ object SparkSession extends SparkSessionCompanion with Logging {
private val MAX_CACHED_SESSIONS = 100
private val planIdGenerator = new AtomicLong
private var server: Option[Process] = None
private val maybeConnectStartScript =
Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))
private val maybeConnectStopScript =
Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "stop-connect-server.sh"))
private[sql] val sparkOptions = sys.props.filter { p =>
p._1.startsWith("spark.") && p._2.nonEmpty
}.toMap
@@ -712,37 +730,45 @@
}
}

val maybeConnectScript =
Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))

if (server.isEmpty &&
(remoteString.exists(_.startsWith("local")) ||
(remoteString.isDefined && isAPIModeConnect)) &&
maybeConnectScript.exists(Files.exists(_))) {
server = Some {
val args =
Seq(maybeConnectScript.get.toString, "--master", remoteString.get) ++ sparkOptions
.filter(p => !p._1.startsWith("spark.remote"))
.flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
val pb = new ProcessBuilder(args: _*)
// So don't exclude spark-sql jar in classpath
pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
pb.start()
}

// Let the server start. We will directly request to set the configurations
// and this sleep makes less noisy with retries.
Thread.sleep(2000L)
System.setProperty("spark.remote", "sc://localhost")

// scalastyle:off runtimeaddshutdownhook
Runtime.getRuntime.addShutdownHook(new Thread() {
override def run(): Unit = if (server.isDefined) {
new ProcessBuilder(maybeConnectScript.get.toString)
.start()
server.synchronized {
if (server.isEmpty &&
(remoteString.exists(_.startsWith("local")) ||
(remoteString.isDefined && isAPIModeConnect)) &&
maybeConnectStartScript.exists(Files.exists(_))) {
server = Some {
ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN =
Option(java.util.UUID.randomUUID().toString())
val args =
Seq(
maybeConnectStartScript.get.toString,
"--master",
remoteString.get) ++ sparkOptions
.filter(p => !p._1.startsWith("spark.remote"))
.flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
val pb = new ProcessBuilder(args: _*)
// So don't exclude spark-sql jar in classpath
pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
pb.environment()
.put(
ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME,
ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN.get)
pb.start()
}
})
// scalastyle:on runtimeaddshutdownhook

// Let the server start. We will directly request to set the configurations
// and this sleep makes less noisy with retries.
Thread.sleep(2000L)
System.setProperty("spark.remote", "sc://localhost")

// scalastyle:off runtimeaddshutdownhook
Runtime.getRuntime.addShutdownHook(new Thread() {
override def run(): Unit = if (server.isDefined) {
new ProcessBuilder(maybeConnectStopScript.get.toString)
.start()
}
})
// scalastyle:on runtimeaddshutdownhook
}
}
}
f
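Note: on the Scala side the secret travels the same way, but through `ProcessBuilder`'s environment rather than `os.environ`. A hedged Python rendering of that launch sequence, assuming `SPARK_HOME` points at a distribution shipping `start-connect-server.sh` (flags other than `--master` omitted):

```python
# Hedged rendering of the ProcessBuilder launch above.
import os
import subprocess
import uuid

token = str(uuid.uuid4())
env = dict(os.environ)
env.pop("SPARK_REMOTE", None)                    # so spark-sql stays on the classpath
env["SPARK_CONNECT_LOCAL_AUTH_TOKEN"] = token    # hand the secret to the server

script = os.path.join(os.environ["SPARK_HOME"], "sbin", "start-connect-server.sh")
subprocess.Popen([script, "--master", "local[*]"], env=env)
```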

@@ -745,6 +745,7 @@ object SparkConnectClient {
if (userName != null) {
builder.setUserName(userName)
}
ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN.foreach(builder.setLocalAuthToken)
builder.build()
}


@@ -21,4 +21,10 @@ private[sql] object ConnectCommon {
val CONNECT_GRPC_PORT_MAX_RETRIES: Int = 0
val CONNECT_GRPC_MAX_MESSAGE_SIZE: Int = 128 * 1024 * 1024
val CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT: Int = 1024
// Set only when we locally run Spark Connect server.
val CONNECT_LOCAL_AUTH_TOKEN_ENV_NAME = "SPARK_CONNECT_LOCAL_AUTH_TOKEN"
var CONNECT_LOCAL_AUTH_TOKEN: Option[String] = None

// For Python testing.
def setLocalAuthToken(token: String): Unit = CONNECT_LOCAL_AUTH_TOKEN = Option(token)
}
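Note: `setLocalAuthToken` exists so Python tests can reach this JVM-side holder over Py4J, exactly as `session.py` does above. A minimal usage sketch, assuming an active local-mode `SparkContext` (this relies on the internal `_jvm` handle shown in the diff):

```python
# Minimal Py4J usage sketch, mirroring the getattr(...) calls in session.py.
from pyspark import SparkContext

jvm = SparkContext._jvm  # assumes an active local-mode SparkContext
getattr(
    jvm, "org.apache.spark.sql.connect.common.config.ConnectCommon"
).setLocalAuthToken("example-token")
```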

@@ -30,6 +30,7 @@ import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse}
import org.apache.spark.connect.proto.AddArtifactsResponse.ArtifactSummary
import org.apache.spark.sql.artifact.ArtifactManager
import org.apache.spark.sql.connect.common.config.ConnectCommon
import org.apache.spark.sql.connect.utils.ErrorUtils
import org.apache.spark.sql.util.ArtifactUtils
import org.apache.spark.util.Utils
@@ -52,6 +53,8 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr
private var holder: SessionHolder = _

override def onNext(req: AddArtifactsRequest): Unit = try {
ConnectCommon.CONNECT_LOCAL_AUTH_TOKEN.foreach(k =>
assert(k == req.getUserContext.getLocalAuthToken))
if (this.holder == null) {
val previousSessionId = req.hasClientObservedServerSideSessionId match {
case true => Some(req.getClientObservedServerSideSessionId)
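Note: the handler gates every `AddArtifactsRequest`: when a local token is configured, the request's `UserContext` must echo it back. A hedged sketch of that check in Python; the function name is illustrative, and where the Scala code uses a plain `assert`, this sketch uses a constant-time comparison, the usual hardening for token checks.

```python
# Hedged sketch of the request gate applied in onNext above.
import hmac
from typing import Optional


def verify_local_auth_token(expected: Optional[str], presented: str) -> None:
    """No-op when no local token is configured; otherwise require an exact echo."""
    if expected is not None and not hmac.compare_digest(expected, presented):
        raise PermissionError("invalid local auth token")
```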