92 changes: 85 additions & 7 deletions src/spark_history_mcp/models/mcp_types.py
@@ -1,16 +1,92 @@
-from typing import Optional, Sequence
from datetime import datetime
from typing import List, Optional

from pydantic import BaseModel, ConfigDict, Field, field_validator

from spark_history_mcp.models.spark_types import JobData, StageData


class JobSummary(BaseModel):
"""Summary of job execution counts for a SQL query."""
job_id: Optional[int] = Field(None, alias="jobId")
name: str
description: Optional[str] = None
status: str
submission_time: Optional[datetime] = Field(None, alias="submissionTime")
completion_time: Optional[datetime] = Field(None, alias="completionTime")
duration_seconds: Optional[float] = None
succeeded_stage_ids: List[int] = Field(
default_factory=list, alias="succeededStageIds"
)
failed_stage_ids: List[int] = Field(default_factory=list, alias="failedStageIds")
active_stage_ids: List[int] = Field(default_factory=list, alias="activeStageIds")
pending_stage_ids: List[int] = Field(default_factory=list, alias="pendingStageIds")
skipped_stage_ids: List[int] = Field(default_factory=list, alias="skippedStageIds")

-    success_job_ids: Sequence[int] = Field(..., alias="successJobsIds")
-    failed_job_ids: Sequence[int] = Field(..., alias="failedJobsIds")
-    running_job_ids: Sequence[int] = Field(..., alias="runningJobsIds")
-    model_config = ConfigDict(populate_by_name=True)

    model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)

    # Registered as a "before" validator so epoch-millis ints and
    # "...GMT"-suffixed strings from the history server parse cleanly.
    @field_validator("submission_time", "completion_time", mode="before")
    @classmethod
    def parse_datetime(cls, value):
if value is None:
return None
if isinstance(value, (int, float)):
return datetime.fromtimestamp(value / 1000)
if isinstance(value, str) and value.endswith("GMT"):
try:
dt_str = value.replace("GMT", "+0000")
return datetime.strptime(dt_str, "%Y-%m-%dT%H:%M:%S.%f%z")
except ValueError:
pass
return value

@classmethod
def from_job_data(
        cls, job_data: JobData, stages: Optional[List[StageData]] = None
) -> "JobSummary":
"""Create a JobSummary from full JobData and optional stage data."""
duration = None
if job_data.completion_time and job_data.submission_time:
duration = (
job_data.completion_time - job_data.submission_time
).total_seconds()

# Initialize stage ID lists
succeeded_stage_ids = []
failed_stage_ids = []
active_stage_ids = []
pending_stage_ids = []
skipped_stage_ids = []

# Group stage IDs by status if stage data is provided
if stages and job_data.stage_ids:
stage_status_map = {stage.stage_id: stage.status for stage in stages}

for stage_id in job_data.stage_ids:
stage_status = stage_status_map.get(stage_id, "UNKNOWN")
if stage_status == "COMPLETE":
succeeded_stage_ids.append(stage_id)
elif stage_status == "FAILED":
failed_stage_ids.append(stage_id)
elif stage_status == "ACTIVE":
active_stage_ids.append(stage_id)
elif stage_status == "PENDING":
pending_stage_ids.append(stage_id)
elif stage_status == "SKIPPED":
skipped_stage_ids.append(stage_id)

return cls(
job_id=job_data.job_id,
name=job_data.name,
description=job_data.description,
status=job_data.status,
submission_time=job_data.submission_time,
completion_time=job_data.completion_time,
duration_seconds=duration,
succeeded_stage_ids=succeeded_stage_ids,
failed_stage_ids=failed_stage_ids,
active_stage_ids=active_stage_ids,
pending_stage_ids=pending_stage_ids,
skipped_stage_ids=skipped_stage_ids,
)
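
A quick standalone sketch of the grouping above, using (stage_id, status) tuples in place of StageData objects so it runs without the package (the data is invented):

```python
from collections import defaultdict

# (stage_id, status) pairs standing in for StageData; values are made up.
stages = [(1, "COMPLETE"), (2, "FAILED"), (3, "COMPLETE"), (4, "PENDING")]
job_stage_ids = [1, 2, 3, 4, 99]  # 99 has no stage record

stage_status_map = {stage_id: status for stage_id, status in stages}
grouped = defaultdict(list)
for stage_id in job_stage_ids:
    grouped[stage_status_map.get(stage_id, "UNKNOWN")].append(stage_id)

print(dict(grouped))
# {'COMPLETE': [1, 3], 'FAILED': [2], 'PENDING': [4], 'UNKNOWN': [99]}
```

One difference worth noting: from_job_data silently drops stage IDs whose status is unknown, while this sketch buckets them under "UNKNOWN".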


@@ -22,6 +98,8 @@ class SqlQuerySummary(BaseModel):
status: str
submission_time: Optional[str] = Field(None, alias="submissionTime")
plan_description: str = Field(..., alias="planDescription")
-    job_summary: JobSummary = Field(..., alias="jobSummary")
success_job_ids: List[int] = Field(..., alias="successJobIds")
failed_job_ids: List[int] = Field(..., alias="failedJobIds")
running_job_ids: List[int] = Field(..., alias="runningJobIds")

model_config = ConfigDict(populate_by_name=True)
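
The timestamp normalization in JobSummary.parse_datetime can be exercised outside the model as a free function — a minimal sketch, assuming the history server emits epoch milliseconds or ".000GMT"-suffixed strings as the two branches suggest:

```python
from datetime import datetime

def parse_spark_timestamp(value):
    """Mirrors JobSummary.parse_datetime for quick testing."""
    if value is None:
        return None
    if isinstance(value, (int, float)):
        # Epoch milliseconds; fromtimestamp expects seconds.
        return datetime.fromtimestamp(value / 1000)
    if isinstance(value, str) and value.endswith("GMT"):
        try:
            return datetime.strptime(value.replace("GMT", "+0000"),
                                     "%Y-%m-%dT%H:%M:%S.%f%z")
        except ValueError:
            pass  # fall through; pydantic gets a chance to parse it
    return value

print(parse_spark_timestamp(1700000000000))                 # naive local time
print(parse_spark_timestamp("2024-01-15T10:30:00.000GMT"))  # tz-aware, +00:00
```

Note that fromtimestamp returns naive local time; if the intent is to keep everything in UTC, datetime.fromtimestamp(value / 1000, tz=timezone.utc) would be the safer variant.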
72 changes: 57 additions & 15 deletions src/spark_history_mcp/tools/tools.py
@@ -115,18 +115,21 @@ def get_application(app_id: str, server: Optional[str] = None) -> ApplicationInfo

@mcp.tool()
def list_jobs(
-    app_id: str, server: Optional[str] = None, status: Optional[list[str]] = None
-) -> list:
    app_id: str,
    server: Optional[str] = None,
    status: Optional[list[str]] = None,
    limit: int = 50,
) -> List[JobSummary]:
"""
-    Get a list of all jobs for a Spark application.
    Get a list of jobs for a Spark application.

Args:
app_id: The Spark application ID
server: Optional server name to use (uses default if not specified)
-        status: Optional list of job status values to filter by
        status: Optional list of job status values to filter by (running|succeeded|failed|unknown)
        limit: Maximum number of jobs to return (default: 50)

Returns:
-        List of JobData objects for the application
        List of JobSummary objects for the application
"""
ctx = mcp.get_context()
client = get_client_or_default(ctx, server)
@@ -136,7 +139,43 @@ def list_jobs(
if status:
job_statuses = [JobExecutionStatus.from_string(s) for s in status]

-    return client.list_jobs(app_id=app_id, status=job_statuses)
    jobs = client.list_jobs(app_id=app_id, status=job_statuses)

stages = client.list_stages(app_id=app_id, details=False)

job_summaries = [JobSummary.from_job_data(job, stages) for job in jobs]

if limit > 0:
job_summaries = job_summaries[:limit]

return job_summaries
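
Both this tool and list_stages below apply limit client-side, after the full fetch. The slice-with-guard idiom means a non-positive limit disables the cap instead of returning an empty list:

```python
def truncate(items: list, limit: int) -> list:
    # limit > 0 caps the result; 0 or negative means "return everything"
    return items[:limit] if limit > 0 else items

assert truncate([1, 2, 3], 2) == [1, 2]
assert truncate([1, 2, 3], 0) == [1, 2, 3]
```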


@mcp.tool()
def get_job(
app_id: str,
job_id: int,
server: Optional[str] = None,
) -> JobSummary:
"""
Get information about a specific job.

Args:
app_id: The Spark application ID
job_id: The job ID
server: Optional server name to use (uses default if not specified)

Returns:
JobSummary object containing job information with stage IDs grouped by status
"""
ctx = mcp.get_context()
client = get_client_or_default(ctx, server)

job_data = client.get_job(app_id, job_id)

stages = client.list_stages(app_id=app_id, details=False)

return JobSummary.from_job_data(job_data, stages)


@mcp.tool()
@@ -190,6 +229,7 @@ def list_stages(
server: Optional[str] = None,
status: Optional[list[str]] = None,
with_summaries: bool = False,
limit: int = 20,
) -> list:
"""
Get a list of all stages for a Spark application.
@@ -202,6 +242,7 @@ def list_stages(
server: Optional server name to use (uses default if not specified)
status: Optional list of stage status values to filter by
with_summaries: Whether to include summary metrics in the response
limit: Maximum number of stages to return (default: 20)

Returns:
List of StageData objects for the application
@@ -214,12 +255,17 @@
if status:
stage_statuses = [StageStatus.from_string(s) for s in status]

-    return client.list_stages(
    stages = client.list_stages(
app_id=app_id,
status=stage_statuses,
with_summaries=with_summaries,
)

if limit > 0:
stages = stages[:limit]

return stages


@mcp.tool()
def list_slowest_stages(
@@ -939,12 +985,6 @@ def list_slowest_sql_queries(
# Create simplified results without additional API calls. Raw object is too verbose.
simplified_results = []
for execution in slowest_executions:
-        job_summary = JobSummary(
-            success_job_ids=execution.success_job_ids,
-            failed_job_ids=execution.failed_job_ids,
-            running_job_ids=execution.running_job_ids,
-        )

# Handle plan description based on include_plan_description flag
plan_description = ""
if include_plan_description and execution.plan_description:
@@ -961,7 +1001,9 @@
if execution.submission_time
else None,
plan_description=plan_description,
-            job_summary=job_summary,
success_job_ids=execution.success_job_ids,
failed_job_ids=execution.failed_job_ids,
running_job_ids=execution.running_job_ids,
)

simplified_results.append(query_summary)
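
With the nested JobSummary removed, SqlQuerySummary now carries the job-ID lists directly, and populate_by_name=True lets callers use either the Python field names (as in the constructor call above) or the camelCase aliases found in history-server JSON. A pared-down sketch, assuming pydantic v2:

```python
from typing import List, Optional
from pydantic import BaseModel, ConfigDict, Field

class QuerySummarySketch(BaseModel):
    id: int
    status: str
    submission_time: Optional[str] = Field(None, alias="submissionTime")
    success_job_ids: List[int] = Field(..., alias="successJobIds")
    model_config = ConfigDict(populate_by_name=True)

# Either spelling populates the same field:
by_alias = QuerySummarySketch(id=1, status="COMPLETED", successJobIds=[3, 4])
by_name = QuerySummarySketch(id=1, status="COMPLETED", success_job_ids=[3, 4])
assert by_alias == by_name
```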