diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index 6f07fdd275..442046762a 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -301,6 +301,7 @@ jobs: - flytekit-data-fsspec - flytekit-dbt - flytekit-deck-standard + - flytekit-dgxc-lepton # TODO: remove dolt plugin - https://github.com/flyteorg/flyte/issues/5350 # flytekit-dolt - flytekit-duckdb diff --git a/plugins/flytekit-dgxc-lepton/.dockerignore b/plugins/flytekit-dgxc-lepton/.dockerignore new file mode 100644 index 0000000000..a9f064d303 --- /dev/null +++ b/plugins/flytekit-dgxc-lepton/.dockerignore @@ -0,0 +1,14 @@ +# Ignore deployment files from Docker build +deployment/ +examples/ +tests/ +*.md +.git +.gitignore +__pycache__/ +*.pyc +*.pyo +*.egg-info/ +build/ +dist/ +.pytest_cache/ diff --git a/plugins/flytekit-dgxc-lepton/Dockerfile.connector b/plugins/flytekit-dgxc-lepton/Dockerfile.connector new file mode 100644 index 0000000000..7a22f947f5 --- /dev/null +++ b/plugins/flytekit-dgxc-lepton/Dockerfile.connector @@ -0,0 +1,30 @@ +# Dockerfile for the Lepton connector +# This creates a standalone connector service for handling Lepton endpoint operations + +FROM ghcr.io/flyteorg/flytekit:py3.12-latest + +# Switch to root to handle file permissions +USER root + +# Install git for leptonai installation +RUN apt-get update && apt-get install -y git && apt-get clean && rm -rf /var/lib/apt/lists/* + +# Copy and install our plugin +COPY --chown=flytekit:flytekit . /home/flytekit/dgxc-lepton-plugin +WORKDIR /home/flytekit/dgxc-lepton-plugin + +# Clean any existing build artifacts +RUN rm -rf *.egg-info build dist + +# Install the plugin +RUN pip install . + +# Switch back to flytekit user +USER flytekit + +# Expose connector service port +EXPOSE 8000 + +# Start the connector service +# The connector service will handle Lepton endpoint lifecycle operations via gRPC +CMD ["pyflyte", "serve", "connector", "--port", "8000", "--modules", "flytekitplugins.dgxc_lepton"] diff --git a/plugins/flytekit-dgxc-lepton/README.md b/plugins/flytekit-dgxc-lepton/README.md new file mode 100644 index 0000000000..b3fd08a5bb --- /dev/null +++ b/plugins/flytekit-dgxc-lepton/README.md @@ -0,0 +1,375 @@ +# Flytekit DGXC Lepton Plugin + +A Flytekit plugin for deploying and managing AI inference endpoints on Lepton AI infrastructure from within Flyte workflows.
+ +## Overview + +This plugin provides: +- **Unified Task API** for deployment and management of Lepton AI endpoints +- **Type-safe configuration** with consolidated dataclasses and IDE support +- **Multiple endpoint engines**: VLLM, SGLang, NIM, and custom containers +- **Unified configuration classes** for scaling, environment, and mounts + +## Installation + +```bash +pip install flytekitplugins-dgxc-lepton +``` + +## Quick Start + +```python +from flytekit import workflow +from flytekitplugins.dgxc_lepton import ( + lepton_endpoint_deployment_task, lepton_endpoint_deletion_task, LeptonEndpointConfig, + EndpointEngineConfig, EnvironmentConfig, ScalingConfig +) + +@workflow +def inference_workflow() -> str: + """Deploy Llama model using VLLM and return endpoint URL.""" + + # Complete configuration in one place + config = LeptonEndpointConfig( + endpoint_name="my-llama-endpoint", + resource_shape="gpu.1xh200", + node_group="your-node-group", + endpoint_config=EndpointEngineConfig.vllm( + checkpoint_path="meta-llama/Llama-3.1-8B-Instruct", + served_model_name="llama-3.1-8b-instruct", + ), + environment=EnvironmentConfig.create( + LOG_LEVEL="INFO", + secrets={"HF_TOKEN": "hf-secret"} + ), + scaling=ScalingConfig.traffic(min_replicas=1, max_replicas=2), + ) + + # Deploy endpoint and return URL + return lepton_endpoint_deployment_task(config=config) +``` + +## API Reference + +### Core Components + +#### `lepton_endpoint_deployment_task(config: LeptonEndpointConfig) -> str` +Main function for Lepton AI endpoint deployment. + +**Parameters:** +- `config`: Complete endpoint configuration +- `task_name`: Optional custom task name + +**Returns:** +- Endpoint URL for successful deployment + +#### `lepton_endpoint_deletion_task(endpoint_name: str, ...) -> str` +Function for Lepton AI endpoint deletion. + +**Parameters:** +- `endpoint_name`: Name of the endpoint to delete +- `task_name`: Optional custom task name + +**Returns:** +- Success message confirming deletion + +#### `LeptonEndpointConfig` +Unified configuration for all Lepton endpoint operations. + +**Required Fields:** +- `endpoint_name`: Name of the endpoint +- `resource_shape`: Hardware resource specification (e.g., "gpu.1xh200") +- `node_group`: Target node group for deployment +- `endpoint_config`: Engine-specific configuration + +**Optional Fields:** +- `scaling`: Auto-scaling configuration +- `environment`: Environment variables and secrets +- `mounts`: Storage mount configurations +- `api_token`/`api_token_secret`: Authentication +- `image_pull_secrets`: Container registry secrets +- `endpoint_readiness_timeout`: Deployment timeout + +### Endpoint Engine Configuration + +#### `EndpointEngineConfig` +Unified configuration for different inference engines. 
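+
+Every factory method below returns the same frozen `EndpointEngineConfig` dataclass; the result is passed as `endpoint_config` on `LeptonEndpointConfig`. A minimal sketch of the shared shape, using the defaults this plugin ships for the `vllm()` factory:
+
+```python
+from flytekitplugins.dgxc_lepton import EndpointEngineConfig, EndpointType
+
+engine = EndpointEngineConfig.vllm(served_model_name="llama-3.1-8b-instruct")
+assert engine.endpoint_type == EndpointType.VLLM
+assert engine.image == "vllm/vllm-openai:latest"  # factory default image
+engine.to_dict()  # engine parameters handed to the connector, e.g. {"image": ..., "port": 8000, ...}
+```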
+ + +##### VLLM Engine +```python +EndpointEngineConfig.vllm( + image="vllm/vllm-openai:latest", + checkpoint_path="meta-llama/Llama-3.1-8B-Instruct", + served_model_name="default-model", + tensor_parallel_size=1, + pipeline_parallel_size=1, + data_parallel_size=1, + extra_args="--max-model-len 4096", + port=8000 +) +``` + +##### SGLang Engine +```python +EndpointEngineConfig.sglang( + image="lmsysorg/sglang:latest", + checkpoint_path="meta-llama/Llama-3.1-8B-Instruct", + tensor_parallel_size=1, + data_parallel_size=1, + extra_args="--context-length 4096", + port=30000 +) +``` + +##### NVIDIA NIM +```python +EndpointEngineConfig.nim( + image="nvcr.io/nim/nvidia/llama-3_3-nemotron-super-49b-v1_5:latest", + port=8000 +) +``` + +##### Custom Container +```python +EndpointEngineConfig.custom( + image="python:3.11-slim", + command=["/bin/bash", "-c", "python3 -m http.server 8080"], + port=8080 +) +``` + +### Scaling Configuration + +#### `ScalingConfig` +Unified auto-scaling configuration with enforced single strategy. + + +##### Traffic-based Scaling +```python +ScalingConfig.traffic( + min_replicas=1, + max_replicas=5, + timeout=1800 # Scale down after 30 min of no traffic +) +``` + +##### GPU Utilization Scaling +```python +ScalingConfig.gpu( + target_utilization=80, # Target 80% GPU utilization + min_replicas=1, + max_replicas=10 +) +``` + +##### QPM (Queries Per Minute) Scaling +```python +ScalingConfig.qpm( + target_qpm=100.5, # Target queries per minute + min_replicas=2, + max_replicas=8 +) +``` + +### Environment Configuration + +#### `EnvironmentConfig` +Unified configuration for environment variables and secrets. + +**Factory Methods:** + +##### Environment Variables Only +```python +EnvironmentConfig.from_env( + LOG_LEVEL="DEBUG", + MODEL_PATH="/models", + CUDA_VISIBLE_DEVICES="0,1" +) +``` + +##### Secrets Only +```python +EnvironmentConfig.from_secrets( + HF_TOKEN="hf-secret", + NGC_API_KEY="ngc-secret" +) +``` + +##### Mixed Configuration +```python +EnvironmentConfig.create( + LOG_LEVEL="INFO", + MODEL_PATH="/models", + secrets={ + "HF_TOKEN": "hf-secret", + "NGC_API_KEY": "ngc-secret" + } +) +``` + +### Mount Configuration + +#### `MountReader` +Simplified NFS mount configuration. 
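+
+Each path pair becomes one NFS mount entry: a 2-tuple is `(cache_path, mount_path)`, and an optional third element set to `False` disables that mount. A small sketch of the structure `to_dict()` produces for the connector (`lepton-shared-fs` is the default storage name):
+
+```python
+from flytekitplugins.dgxc_lepton import MountReader
+
+mounts = MountReader.node_nfs(("/shared-storage/models", "/opt/models"))
+mounts.to_dict()
+# [{"enabled": True, "cache_path": "/shared-storage/models",
+#   "mount_path": "/opt/models", "storage_source": "node-nfs:lepton-shared-fs"}]
+```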
+ +```python +MountReader.node_nfs( + ("/shared-storage/models", "/opt/models"), + ("/shared-storage/data", "/opt/data"), + ("/shared-storage/logs", "/opt/logs", False), # Disabled mount + storage_name="production-nfs" # Custom storage name +) +``` + +## Complete Examples + +### VLLM Deployment with Auto-scaling + +```python +from flytekit import workflow +from flytekitplugins.dgxc_lepton import ( + lepton_endpoint_deployment_task, LeptonEndpointConfig, + EndpointEngineConfig, EnvironmentConfig, ScalingConfig, MountReader +) + +@workflow +def deploy_vllm_with_scaling() -> str: + """Deploy VLLM with traffic-based auto-scaling.""" + + config = LeptonEndpointConfig( + endpoint_name="vllm-llama-3.1-8b", + resource_shape="gpu.1xh200", + node_group="inference-nodes", + endpoint_config=EndpointEngineConfig.vllm( + checkpoint_path="meta-llama/Llama-3.1-8B-Instruct", + served_model_name="llama-3.1-8b-instruct", + tensor_parallel_size=1, + extra_args="--max-model-len 8192 --enable-chunked-prefill" + ), + environment=EnvironmentConfig.create( + LOG_LEVEL="INFO", + CUDA_VISIBLE_DEVICES="0", + secrets={"HF_TOKEN": "hf-secret"} + ), + scaling=ScalingConfig.traffic( + min_replicas=1, + max_replicas=3, + timeout=1800 + ), + mounts=MountReader.node_nfs( + ("/shared-storage/models", "/opt/models"), + ("/shared-storage/cache", "/root/.cache") + ), + api_token_secret="lepton-api-token", + image_pull_secrets=["hf-secret"], + endpoint_readiness_timeout=600 + ) + + return lepton_endpoint_deployment_task(config=config) +``` + +### NIM Deployment with QPM Scaling + +```python +@workflow +def deploy_nim_with_qpm_scaling() -> str: + """Deploy NVIDIA NIM with QPM-based scaling.""" + + config = LeptonEndpointConfig( + endpoint_name="nemotron-super-reasoning", + resource_shape="gpu.1xh200", + node_group="nim-nodes", + endpoint_config=EndpointEngineConfig.nim( + image="nvcr.io/nim/nvidia/llama-3_3-nemotron-super-49b-v1_5:latest" + ), + environment=EnvironmentConfig.create( + OMPI_ALLOW_RUN_AS_ROOT="1", + secrets={"NGC_API_KEY": "ngc-secret"} + ), + scaling=ScalingConfig.qpm( + target_qpm=2.5, + min_replicas=1, + max_replicas=3 + ), + image_pull_secrets=["ngc-secret"], + api_token="UNIQUE_ENDPOINT_TOKEN" + ) + + return lepton_endpoint_deployment_task(config=config) +``` + +### Custom Container Deployment + +```python +@workflow +def deploy_custom_service() -> str: + """Deploy custom inference service.""" + + config = LeptonEndpointConfig( + endpoint_name="custom-inference-api", + resource_shape="cpu.large", + node_group="cpu-nodes", + endpoint_config=EndpointEngineConfig.custom( + image="my-registry/inference-api:v1.0", + command=["python", "app.py"], + port=8080 + ), + environment=EnvironmentConfig.from_env( + LOG_LEVEL="DEBUG", + API_VERSION="v1", + WORKERS="4" + ), + scaling=ScalingConfig.gpu( + target_utilization=70, + min_replicas=2, + max_replicas=6 + ) + ) + + return lepton_endpoint_deployment_task(config=config) +``` +## Configuration Requirements + +Replace these placeholders with your actual values: +- ``: Your Kubernetes node group for GPU workloads +- ``: Your NGC registry pull secret name +- `/shared-storage/model-cache/*`: Your shared storage paths for model caching +- `NGC_API_KEY`: Your NGC API key secret name +- `HUGGING_FACE_HUB_TOKEN_read`: Your HuggingFace token secret name + +## Monitoring & Debugging + +```bash +# Monitor connector logs +kubectl logs -n flyte deployment/lepton-connector --follow + +# Check Lepton console (URLs auto-generated in Flyte execution view) + +# List recent executions 
+pyflyte get executions -p flytesnacks -d development --limit 5 + +## Development + +### Running Tests + +```bash +pytest tests/test_lepton.py -v +``` + +### Plugin Registration + +The plugin automatically registers with Flytekit's dynamic plugin loading system: + +```python +# Automatic registration enables this usage pattern +task = LeptonEndpointDeploymentTask(config=config) +``` + +## Support + +For issues, questions, or contributions, please refer to the Flytekit documentation and Lepton AI platform documentation. + +## License + +This plugin follows the same license as Flytekit. diff --git a/plugins/flytekit-dgxc-lepton/examples/basic_inference.py b/plugins/flytekit-dgxc-lepton/examples/basic_inference.py new file mode 100644 index 0000000000..e07ec17955 --- /dev/null +++ b/plugins/flytekit-dgxc-lepton/examples/basic_inference.py @@ -0,0 +1,62 @@ +""" +Example: Basic Inference using Task API + +This example demonstrates the new task-based API for Lepton AI +endpoint lifecycle management, showing different workflow patterns. + +""" + +from flytekitplugins.dgxc_lepton import ( + EndpointEngineConfig, + EnvironmentConfig, + LeptonEndpointConfig, + ScalingConfig, + lepton_endpoint_deletion_task, + lepton_endpoint_deployment_task, +) + +from flytekit import workflow + + +@workflow +def basic_inference_workflow() -> str: + """Simple workflow using the unified task API.""" + + # Complete configuration in one place + config = LeptonEndpointConfig( + endpoint_name="basic-inference-endpoint", + resource_shape="cpu.small", + node_group="", # Replace with your actual node group name + endpoint_config=EndpointEngineConfig.custom( + image="python:3.11-slim", + command=["/bin/bash", "-c", "python3 -m http.server 8080 --bind 0.0.0.0"], + port=8080, + ), + api_token="BASIC_ENDPOINT_TOKEN", + scaling=ScalingConfig.traffic(min_replicas=1, max_replicas=1, timeout=1800), + environment=EnvironmentConfig.from_env( + LOG_LEVEL="INFO", + SERVER_MODE="production", + ), + endpoint_readiness_timeout=300, + ) + + return lepton_endpoint_deployment_task(config=config, task_name="basic-inference-endpoint-v3") + + +@workflow +def basic_inference_cleanup_workflow() -> str: + """Cleanup workflow to delete the deployed endpoint.""" + + # Much simpler deletion - just the endpoint name! + return lepton_endpoint_deletion_task( + endpoint_name="basic-inference-endpoint", task_name="basic-inference-cleanup-v2" + ) + + +if __name__ == "__main__": + # Local execution examples + + # Example: Deploy endpoint + result = basic_inference_workflow() + print(f"Endpoint URL: {result}") diff --git a/plugins/flytekit-dgxc-lepton/examples/nemotron_nim_example.py b/plugins/flytekit-dgxc-lepton/examples/nemotron_nim_example.py new file mode 100644 index 0000000000..da08b4b376 --- /dev/null +++ b/plugins/flytekit-dgxc-lepton/examples/nemotron_nim_example.py @@ -0,0 +1,52 @@ +""" +Example: NVIDIA NIM using Task API + +This example demonstrates deploying NVIDIA NIM models using the +task API for maximum flexibility and type safety. 
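+
+Replace the empty node_group, NGC_API_KEY secret name, and image_pull_secrets
+entries below with the values from your Lepton workspace before registering
+this workflow.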
+""" + +from flytekitplugins.dgxc_lepton import ( + EndpointEngineConfig, + EnvironmentConfig, + LeptonEndpointConfig, + MountReader, + ScalingConfig, + lepton_endpoint_deployment_task, +) + +from flytekit import workflow + + +@workflow +def nemotron_super_workflow() -> str: + """Deploy Nemotron Super for advanced AI reasoning.""" + + # Complete configuration in one place + config = LeptonEndpointConfig( + endpoint_name="nemotron-super-reasoning", + resource_shape="gpu.1xh200", + node_group="", # Replace with your actual GPU node group + endpoint_config=EndpointEngineConfig.nim( + image="nvcr.io/nim/nvidia/llama-3_3-nemotron-super-49b-v1_5:latest", + ), + api_token="UNIQUE_ENDPOINT_TOKEN", + scaling=ScalingConfig.qpm(target_qpm=2.5, min_replicas=1, max_replicas=3), + environment=EnvironmentConfig.create( + OMPI_ALLOW_RUN_AS_ROOT="1", + OMPI_ALLOW_RUN_AS_ROOT_CONFIRM="1", + NIM_MODEL_NAME="nvidia/llama-3_3-nemotron-super-49b-v1_5", + SERVED_MODEL_NAME="nvidia/llama-3_3-nemotron-super-49b-v1_5", + secrets={ + "NGC_API_KEY": "", + "HF_TOKEN": "HUGGING_FACE_HUB_TOKEN_read", + }, + ), + initial_delay_seconds=5000, + image_pull_secrets=[""], + mounts=MountReader.node_nfs( + ("/nim-cache", "/opt/nim/.cache"), + ("/test-datasets", "/opt/nim/datasets"), + ), + ) + + return lepton_endpoint_deployment_task(config=config, task_name="nemotron_super_workflow") diff --git a/plugins/flytekit-dgxc-lepton/examples/vllm_example.py b/plugins/flytekit-dgxc-lepton/examples/vllm_example.py new file mode 100644 index 0000000000..cfe9132cbb --- /dev/null +++ b/plugins/flytekit-dgxc-lepton/examples/vllm_example.py @@ -0,0 +1,92 @@ +""" +Example: vLLM Deployment using Task API with Parameterized Functions + +This example demonstrates how the task API enables clean, reusable +vLLM deployments with parameterized functions for different environments. +""" + +from typing import Optional + +from flytekitplugins.dgxc_lepton import ( + EndpointEngineConfig, + EnvironmentConfig, + LeptonEndpointConfig, + MountReader, + ScalingConfig, + lepton_endpoint_deployment_task, +) + +from flytekit import workflow + + +def deploy_vllm_model( + endpoint_name: str, + resource_shape: str, + node_group: str, + mounts: MountReader, + checkpoint_path: str = "meta-llama/Llama-3.1-8B-Instruct", + served_model_name: Optional[str] = None, + enable_secrets: bool = True, +) -> str: + """Parameterized function to deploy vLLM models with configurable infrastructure. 
+ + Args: + endpoint_name (str): Name for the deployed endpoint + resource_shape (str): GPU configuration (e.g., "gpu.1xh200", "gpu.1xa10") + node_group (str): Kubernetes node group for deployment + mounts (MountReader): Mount configuration for model storage + checkpoint_path (str): HuggingFace model path + served_model_name (Optional[str]): Name to serve the model as (defaults to endpoint_name) + enable_secrets (bool): Whether to include HuggingFace token secret + + Returns: + str: URL of the deployed vLLM endpoint + """ + # Use endpoint_name as served_model_name if not provided + if served_model_name is None: + served_model_name = endpoint_name + + # Complete configuration in one place + config = LeptonEndpointConfig( + endpoint_name=endpoint_name, + resource_shape=resource_shape, + node_group=node_group, + endpoint_config=EndpointEngineConfig.vllm( + image="vllm/vllm-openai:latest", + checkpoint_path=checkpoint_path, + served_model_name=served_model_name, + tensor_parallel_size=1, + pipeline_parallel_size=1, + data_parallel_size=1, + extra_args="--gpu-memory-utilization 0.95 --trust-remote-code", + ), + api_token="VLLM_ENDPOINT_TOKEN", + scaling=ScalingConfig.traffic(min_replicas=1, max_replicas=2, timeout=3600, scale_to_zero=True), + environment=EnvironmentConfig.create( + HF_HOME="/opt/vllm/.cache", + VLLM_WORKER_MULTIPROC_METHOD="spawn", + secrets={"HF_TOKEN": "HUGGING_FACE_HUB_TOKEN_read_updated"} if enable_secrets else None, + ), + mounts=mounts, + ) + + return lepton_endpoint_deployment_task(config=config, task_name="deploy_vllm_model_v3") + + +@workflow +def vllm_inference_workflow() -> str: + """Deploy Llama 3.1 8B model using vLLM with production configuration.""" + + # Configure production mounts + model_cache_mounts = MountReader.node_nfs( + ("/vllm-models-cache", "/opt/vllm/.cache"), + ) + + return deploy_vllm_model( + endpoint_name="vllm-llama-inference", + resource_shape="gpu.1xh200", + node_group="", # Replace with your actual GPU node group + mounts=model_cache_mounts, + checkpoint_path="meta-llama/Llama-3.1-8B-Instruct", + served_model_name="llama-3.1-8b-instruct", + ) diff --git a/plugins/flytekit-dgxc-lepton/flytekitplugins/__init__.py b/plugins/flytekit-dgxc-lepton/flytekitplugins/__init__.py new file mode 100644 index 0000000000..9874acdb41 --- /dev/null +++ b/plugins/flytekit-dgxc-lepton/flytekitplugins/__init__.py @@ -0,0 +1 @@ +# Namespace package diff --git a/plugins/flytekit-dgxc-lepton/flytekitplugins/dgxc_lepton/__init__.py b/plugins/flytekit-dgxc-lepton/flytekitplugins/dgxc_lepton/__init__.py new file mode 100644 index 0000000000..3917144e1b --- /dev/null +++ b/plugins/flytekit-dgxc-lepton/flytekitplugins/dgxc_lepton/__init__.py @@ -0,0 +1,56 @@ +""" +.. currentmodule:: flytekitplugins.dgxc_lepton + +This package contains things that are useful when extending Flytekit for Lepton AI integration. + +.. autosummary:: + :template: custom.rst + :toctree: generated/ + + lepton_endpoint_deployment_task + lepton_endpoint_deletion_task + LeptonEndpointConfig + LeptonEndpointDeploymentTask + LeptonEndpointDeletionTask + EndpointType + EnvironmentConfig + MountReader + ScalingConfig + ScalingType + EndpointEngineConfig +""" + +# Clean imports with consolidated classes +# Import connector module to trigger connector registration (connectors are not part of public API) +from . 
import connector # noqa: F401 +from .config import ( + EndpointEngineConfig, + EndpointType, + EnvironmentConfig, + LeptonEndpointConfig, + MountReader, + ScalingConfig, + ScalingType, +) +from .task import ( + LeptonEndpointDeletionTask, + LeptonEndpointDeploymentTask, + lepton_endpoint_deletion_task, + lepton_endpoint_deployment_task, +) + +__all__ = [ + # Task API + "lepton_endpoint_deployment_task", + "lepton_endpoint_deletion_task", + "LeptonEndpointConfig", + "LeptonEndpointDeploymentTask", + "LeptonEndpointDeletionTask", + "EndpointType", + # Configuration classes + "EnvironmentConfig", + "MountReader", + "ScalingConfig", + "ScalingType", + "EndpointEngineConfig", +] diff --git a/plugins/flytekit-dgxc-lepton/flytekitplugins/dgxc_lepton/config.py b/plugins/flytekit-dgxc-lepton/flytekitplugins/dgxc_lepton/config.py new file mode 100644 index 0000000000..e79a64ee57 --- /dev/null +++ b/plugins/flytekit-dgxc-lepton/flytekitplugins/dgxc_lepton/config.py @@ -0,0 +1,503 @@ +"""Configuration classes for Lepton AI inference endpoints.""" + +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, List, Optional, Union + + +class EndpointType(Enum): + """Supported endpoint types for Lepton AI.""" + + CUSTOM = "custom" + VLLM = "vllm" + NIM = "nim" + SGLANG = "sglang" + + +class ScalingType(Enum): + """Supported scaling types for Lepton AI.""" + + TRAFFIC = "traffic" + GPU = "gpu" + QPM = "qpm" + + +@dataclass(frozen=True) +class ScalingConfig: + """Unified scaling configuration that enforces only one scaling type. + + This ensures that only one scaling strategy can be used at a time, + preventing conflicting scaling configurations. + + Attributes: + scaling_type (ScalingType): The type of scaling strategy to use + min_replicas (int): Minimum number of replicas + max_replicas (int): Maximum number of replicas + timeout (Optional[int]): Timeout in seconds for traffic-based scaling + scale_to_zero (Optional[bool]): Whether to allow scaling to zero replicas + target_utilization (Optional[int]): Target GPU utilization percentage for GPU-based scaling + target_qpm (Optional[float]): Target queries per minute for QPM-based scaling + + Examples: + # Traffic-based scaling + ScalingConfig.traffic(min_replicas=1, max_replicas=3, timeout=1800) + + # GPU-based scaling + ScalingConfig.gpu(target_utilization=80, min_replicas=1, max_replicas=5) + + # QPM-based scaling + ScalingConfig.qpm(target_qpm=100.0, min_replicas=2, max_replicas=4) + """ + + scaling_type: ScalingType + min_replicas: int = 1 + max_replicas: int = 2 + + # Traffic scaling parameters + timeout: Optional[int] = None + scale_to_zero: Optional[bool] = None + + # GPU scaling parameters + target_utilization: Optional[int] = None + + # QPM scaling parameters + target_qpm: Optional[float] = None + + def __post_init__(self): + """Validate scaling configuration parameters.""" + if self.scaling_type == ScalingType.TRAFFIC: + if self.timeout is None: + object.__setattr__(self, "timeout", 3600) # Default timeout + if self.scale_to_zero is None: + object.__setattr__(self, "scale_to_zero", False) + elif self.scaling_type == ScalingType.GPU: + if self.target_utilization is None: + raise ValueError("target_utilization is required for GPU scaling") + if not 1 <= self.target_utilization <= 100: + raise ValueError("target_utilization must be between 1 and 100") + elif self.scaling_type == ScalingType.QPM: + if self.target_qpm is None: + raise ValueError("target_qpm is required for QPM scaling") + if self.target_qpm <= 0: + raise 
ValueError("target_qpm must be positive") + + def to_dict(self) -> Dict[str, Any]: + """Convert to Lepton scaling configuration.""" + if self.scaling_type == ScalingType.TRAFFIC: + return { + "scale_down": {"no_traffic_timeout": self.timeout, "scale_from_zero": self.scale_to_zero}, + "target_gpu_utilization_percentage": 0, # Disabled + "target_throughput": {"qpm": 0, "paths": [], "methods": []}, + } + elif self.scaling_type == ScalingType.GPU: + return { + "scale_down": {"no_traffic_timeout": 0, "scale_from_zero": False}, + "target_gpu_utilization_percentage": self.target_utilization, + "target_throughput": {"qpm": 0, "paths": [], "methods": []}, + } + elif self.scaling_type == ScalingType.QPM: + return { + "scale_down": {"no_traffic_timeout": 0, "scale_from_zero": False}, + "target_gpu_utilization_percentage": 0, # Disabled + "target_throughput": {"qpm": self.target_qpm, "paths": [], "methods": []}, + } + else: + raise ValueError(f"Unknown scaling type: {self.scaling_type}") + + def get_replica_config(self) -> Dict[str, int]: + """Get replica configuration.""" + return { + "min_replicas": self.min_replicas, + "max_replicas": self.max_replicas, + } + + @classmethod + def traffic( + cls, min_replicas: int = 1, max_replicas: int = 2, timeout: int = 3600, scale_to_zero: bool = False + ) -> "ScalingConfig": + """Create traffic-based scaling configuration.""" + return cls( + scaling_type=ScalingType.TRAFFIC, + min_replicas=min_replicas, + max_replicas=max_replicas, + timeout=timeout, + scale_to_zero=scale_to_zero, + ) + + @classmethod + def gpu(cls, target_utilization: int, min_replicas: int = 1, max_replicas: int = 3) -> "ScalingConfig": + """Create GPU utilization-based scaling configuration.""" + return cls( + scaling_type=ScalingType.GPU, + min_replicas=min_replicas, + max_replicas=max_replicas, + target_utilization=target_utilization, + ) + + @classmethod + def qpm(cls, target_qpm: float, min_replicas: int = 1, max_replicas: int = 3) -> "ScalingConfig": + """Create queries-per-minute based scaling configuration.""" + return cls( + scaling_type=ScalingType.QPM, min_replicas=min_replicas, max_replicas=max_replicas, target_qpm=target_qpm + ) + + +@dataclass(frozen=True) +class EnvironmentConfig: + """Unified environment variable configuration for Lepton deployments. + + Handles both regular environment variables and secret references. + """ + + env_vars: Dict[str, str] = None + secrets: Dict[str, str] = None + + def __post_init__(self): + if self.env_vars is None: + object.__setattr__(self, "env_vars", {}) + if self.secrets is None: + object.__setattr__(self, "secrets", {}) + + def to_dict(self) -> Dict[str, Any]: + """Convert to Lepton-compatible environment variables.""" + result = {} + + # Add regular environment variables + result.update(self.env_vars) + + # Add secret references + for env_var, secret_name in self.secrets.items(): + result[env_var] = {"value_from": {"secret_name_ref": secret_name}} + + return result + + @classmethod + def create( + cls, env_vars: Optional[Dict[str, str]] = None, secrets: Optional[Dict[str, str]] = None, **kwargs + ) -> "EnvironmentConfig": + """Create EnvironmentConfig with environment variables and secrets. 
+ + Args: + env_vars (Optional[Dict[str, str]]): Regular environment variables + secrets (Optional[Dict[str, str]]): Environment variables from secrets (env_var -> secret_name) + **kwargs: Additional environment variables (treated as regular env vars) + + Returns: + EnvironmentConfig: Configured environment instance + """ + final_env_vars = {} + if env_vars: + final_env_vars.update(env_vars) + final_env_vars.update(kwargs) + + return cls(env_vars=final_env_vars, secrets=secrets or {}) + + @classmethod + def from_secrets(cls, **secret_mapping: str) -> "EnvironmentConfig": + """Create EnvironmentConfig with only secrets.""" + return cls(env_vars={}, secrets=secret_mapping) + + @classmethod + def from_env(cls, **env_vars: str) -> "EnvironmentConfig": + """Create EnvironmentConfig with only environment variables.""" + return cls(env_vars=env_vars, secrets={}) + + +@dataclass(frozen=True) +class EndpointEngineConfig: + """Unified endpoint configuration that enforces only one endpoint type. + + This class consolidates all endpoint-specific configurations into a single, + type-safe configuration that ensures only one endpoint type can be specified. + """ + + endpoint_type: EndpointType + image: str + port: int = 8000 + + # VLLM-specific parameters + checkpoint_path: Optional[str] = None + served_model_name: Optional[str] = None + tensor_parallel_size: Optional[int] = None + pipeline_parallel_size: Optional[int] = None + data_parallel_size: Optional[int] = None + extra_args: Optional[str] = None + + # Custom container parameters + command: Optional[List[str]] = None + + def __post_init__(self): + """Validate configuration based on endpoint type.""" + if self.endpoint_type == EndpointType.VLLM: + # Set VLLM defaults if not provided + if self.checkpoint_path is None: + object.__setattr__(self, "checkpoint_path", "meta-llama/Llama-3.1-8B-Instruct") + if self.served_model_name is None: + object.__setattr__(self, "served_model_name", "default-model") + if self.tensor_parallel_size is None: + object.__setattr__(self, "tensor_parallel_size", 1) + if self.pipeline_parallel_size is None: + object.__setattr__(self, "pipeline_parallel_size", 1) + if self.data_parallel_size is None: + object.__setattr__(self, "data_parallel_size", 1) + if self.extra_args is None: + object.__setattr__(self, "extra_args", "") + + elif self.endpoint_type == EndpointType.SGLANG: + # Set SGLang defaults if not provided + if self.checkpoint_path is None: + object.__setattr__(self, "checkpoint_path", "meta-llama/Llama-3.1-8B-Instruct") + if self.served_model_name is None: + object.__setattr__(self, "served_model_name", "default-model") + if self.tensor_parallel_size is None: + object.__setattr__(self, "tensor_parallel_size", 1) + if self.data_parallel_size is None: + object.__setattr__(self, "data_parallel_size", 1) + if self.extra_args is None: + object.__setattr__(self, "extra_args", "") + + elif self.endpoint_type == EndpointType.CUSTOM: + # Custom containers use different default port + if self.port == 8000: # Only change if using default + object.__setattr__(self, "port", 8080) + + def to_dict(self) -> Dict[str, Any]: + """Convert to endpoint parameters based on endpoint type.""" + base_dict = {"image": self.image, "port": self.port} + + if self.endpoint_type in [EndpointType.VLLM, EndpointType.SGLANG]: + base_dict.update( + { + "checkpoint_path": self.checkpoint_path, + "served_model_name": self.served_model_name, + "tensor_parallel_size": self.tensor_parallel_size, + "data_parallel_size": self.data_parallel_size, + 
"extra_args": self.extra_args, + } + ) + + # VLLM-specific parameter + if self.endpoint_type == EndpointType.VLLM: + base_dict["pipeline_parallel_size"] = self.pipeline_parallel_size + + elif self.endpoint_type == EndpointType.CUSTOM: + if self.command: + base_dict["command"] = self.command + + # NIM uses base configuration as-is + + return base_dict + + @classmethod + def vllm( + cls, + image: str = "vllm/vllm-openai:latest", + checkpoint_path: str = "meta-llama/Llama-3.1-8B-Instruct", + served_model_name: str = "default-model", + tensor_parallel_size: int = 1, + pipeline_parallel_size: int = 1, + data_parallel_size: int = 1, + extra_args: str = "", + port: int = 8000, + ) -> "EndpointEngineConfig": + """Create VLLM endpoint configuration.""" + return cls( + endpoint_type=EndpointType.VLLM, + image=image, + port=port, + checkpoint_path=checkpoint_path, + served_model_name=served_model_name, + tensor_parallel_size=tensor_parallel_size, + pipeline_parallel_size=pipeline_parallel_size, + data_parallel_size=data_parallel_size, + extra_args=extra_args, + ) + + @classmethod + def sglang( + cls, + image: str = "lmsysorg/sglang:latest", + checkpoint_path: str = "meta-llama/Llama-3.1-8B-Instruct", + served_model_name: str = "default-model", + tensor_parallel_size: int = 1, + data_parallel_size: int = 1, + extra_args: str = "", + port: int = 8000, + ) -> "EndpointEngineConfig": + """Create SGLang endpoint configuration.""" + return cls( + endpoint_type=EndpointType.SGLANG, + image=image, + port=port, + checkpoint_path=checkpoint_path, + served_model_name=served_model_name, + tensor_parallel_size=tensor_parallel_size, + data_parallel_size=data_parallel_size, + extra_args=extra_args, + ) + + @classmethod + def nim(cls, image: str, port: int = 8000) -> "EndpointEngineConfig": + """Create NIM endpoint configuration.""" + return cls(endpoint_type=EndpointType.NIM, image=image, port=port) + + @classmethod + def custom(cls, image: str, command: Optional[List[str]] = None, port: int = 8080) -> "EndpointEngineConfig": + """Create custom container endpoint configuration.""" + return cls(endpoint_type=EndpointType.CUSTOM, image=image, port=port, command=command) + + +@dataclass(frozen=True) +class MountReader: + """Mount configuration for Lepton deployments.""" + + mounts: List[Dict[str, Any]] + + def to_dict(self) -> List[Dict[str, Any]]: + """Convert to Lepton-compatible mount configuration.""" + return self.mounts + + @classmethod + def node_nfs(cls, *path_pairs, storage_name: str = "lepton-shared-fs") -> "MountReader": + """Create NFS mounts from path pairs. + + Args: + *path_pairs: Tuples of (cache_path, mount_path) or (cache_path, mount_path, enabled) + storage_name (str): NFS storage identifier + + Returns: + MountReader: Configured instance + + Raises: + ValueError: If path_pair format is invalid + """ + mounts = [] + for path_pair in path_pairs: + if len(path_pair) == 2: + cache_path, mount_path = path_pair + enabled = True + elif len(path_pair) == 3: + cache_path, mount_path, enabled = path_pair + else: + raise ValueError( + f"Path pair must be (cache_path, mount_path) or (cache_path, mount_path, enabled), got {path_pair}" + ) + + mounts.append( + { + "enabled": enabled, + "cache_path": cache_path, + "mount_path": mount_path, + "storage_source": f"node-nfs:{storage_name}", + } + ) + + return cls(mounts=mounts) + + +@dataclass(frozen=True) +class LeptonEndpointConfig: + """Complete configuration for Lepton AI endpoint deployment. 
+ + This single configuration class follows standard Flytekit plugin patterns + by containing all parameters needed for endpoint deployment in one place. + + """ + + # Required parameters + endpoint_name: str + resource_shape: str + node_group: str + + # Unified endpoint engine configuration + endpoint_config: EndpointEngineConfig + + # Optional configurations + scaling: Optional[ScalingConfig] = None + environment: Optional[Union[EnvironmentConfig, Dict[str, str]]] = None + mounts: Optional[Union[MountReader, List[Dict[str, Any]]]] = None + + # Authentication + api_token: Optional[str] = None + api_token_secret: Optional[str] = None + + # Health configuration + initial_delay_seconds: Optional[int] = None + liveness_timeout: Optional[int] = None + readiness_delay: Optional[int] = None + + # Container and deployment configuration + image_pull_secrets: Optional[List[str]] = None + endpoint_readiness_timeout: Optional[int] = None + + def __post_init__(self): + """Validate the configuration after initialization.""" + # Validate required fields for deployment operations + if not self.endpoint_name or not self.resource_shape or not self.node_group: + raise ValueError("endpoint_name, resource_shape, and node_group are required for deployment") + + # Validate API token configuration + if self.api_token and self.api_token_secret: + raise ValueError("Cannot specify both api_token and api_token_secret") + + # Validate that endpoint_config type matches the endpoint_config.endpoint_type + # (The EndpointConfig class handles its own internal validation) + + def to_dict(self) -> Dict[str, Any]: + """Convert the configuration to a dictionary for the connector.""" + config = { + "deployment_type": self.endpoint_config.endpoint_type.value, + "endpoint_name": self.endpoint_name, + "resource_shape": self.resource_shape, + "node_group": self.node_group, + "operation": "deploy", + } + + # Add endpoint configuration + config.update(self.endpoint_config.to_dict()) + + # Build environment variables from unified environment configuration + if self.environment: + env_vars_dict = ( + self.environment.to_dict() if isinstance(self.environment, EnvironmentConfig) else self.environment + ) + if env_vars_dict: + config["env_vars"] = env_vars_dict + + # Add mounts + if self.mounts: + config["mounts"] = self.mounts.to_dict() if hasattr(self.mounts, "to_dict") else self.mounts + + # Add API tokens + if self.api_token: + config["api_tokens"] = [{"value": self.api_token}] + elif self.api_token_secret: + config["api_tokens"] = [{"value_from": {"secret_name_ref": self.api_token_secret}}] + + # Add scaling configuration + if self.scaling: + config["auto_scaler"] = self.scaling.to_dict() + config.update(self.scaling.get_replica_config()) + + # Add health configuration + health_fields = [self.initial_delay_seconds, self.liveness_timeout, self.readiness_delay] + if any(health_fields): + health_config = {"enable_collection": True} + if self.initial_delay_seconds or self.liveness_timeout: + liveness = {} + if self.initial_delay_seconds: + liveness["initial_delay_seconds"] = self.initial_delay_seconds + if self.liveness_timeout: + liveness["timeout_seconds"] = self.liveness_timeout + health_config["liveness"] = liveness + if self.readiness_delay: + health_config["readiness"] = {"initial_delay_seconds": self.readiness_delay} + config["health_config"] = health_config + + # Add other fields + if self.image_pull_secrets: + config["image_pull_secrets"] = self.image_pull_secrets + if self.endpoint_readiness_timeout: + 
config["endpoint_readiness_timeout"] = self.endpoint_readiness_timeout + + return config diff --git a/plugins/flytekit-dgxc-lepton/flytekitplugins/dgxc_lepton/connector.py b/plugins/flytekit-dgxc-lepton/flytekitplugins/dgxc_lepton/connector.py new file mode 100644 index 0000000000..a2bccf3f2d --- /dev/null +++ b/plugins/flytekit-dgxc-lepton/flytekitplugins/dgxc_lepton/connector.py @@ -0,0 +1,594 @@ +"""Optimized Lepton AI connector with improved performance and reduced redundancy.""" + +import asyncio +import re +import time +from dataclasses import dataclass +from functools import lru_cache +from typing import Any, Dict, Optional + +from flyteidl.admin.agent_pb2 import GetTaskLogsResponse +from flyteidl.core.execution_pb2 import TaskExecution, TaskLog +from leptonai.api.v1.deployment import DeploymentAPI +from leptonai.api.v1.log import LogAPI +from leptonai.api.v1.types.common import Metadata +from leptonai.api.v1.types.deployment import LeptonDeployment, LeptonDeploymentUserSpec +from leptonai.api.v2.client import APIClient + +from flytekit import current_context +from flytekit.extend.backend.base_connector import AsyncConnectorBase, ConnectorRegistry, Resource, ResourceMeta +from flytekit.extend.backend.utils import convert_to_flyte_phase +from flytekit.loggers import logger +from flytekit.models.literals import LiteralMap +from flytekit.models.task import TaskTemplate + + +def log_info(message: str) -> None: + """Optimized unified logging helper.""" + logger.info(message) + print(message, flush=True) # More efficient than separate flush call + + +@dataclass +class LeptonMetadata(ResourceMeta): + deployment_name: str + + +class BaseLeptonConnector(AsyncConnectorBase): + """Base class for Lepton connectors with shared functionality.""" + + # Optimized class constants + ORIGIN_URL = "https://gateway.dgxc-lepton.nvidia.com" + + # Consolidated status mapping with frozenset for faster lookup + _PENDING_STATES = frozenset(["starting", "scaling", "updating", "not ready", "notready"]) + _SUCCESS_STATES = frozenset(["ready", "deleting", "stopping", "terminating", "stopped", "terminated"]) + _FAILED_STATES = frozenset(["unknown", "unk"]) + + # Pre-compiled regex for better performance + NAME_SANITIZE_REGEX = re.compile(r"[^a-z0-9\-]") + + @staticmethod + def _get_deployment_command(deployment_type: str, config: Dict[str, Any]) -> list: + """Generate deployment command based on type and config.""" + if deployment_type == "vllm": + return [ + "vllm", + "serve", + config.get("checkpoint_path", "meta-llama/Llama-3.1-8B-Instruct"), + f"--tensor-parallel-size={int(config.get('tensor_parallel_size', 1))}", + f"--pipeline-parallel-size={int(config.get('pipeline_parallel_size', 1))}", + f"--data-parallel-size={int(config.get('data_parallel_size', 1))}", + f"--port={int(config.get('port', 8000))}", + f"--served-model-name={config.get('served_model_name', 'default-model')}", + ] + (config.get("extra_args", "").split() if config.get("extra_args") else []) + elif deployment_type == "sglang": + return [ + "python3", + "-m", + "sglang.launch_server", + f"--model-path={config.get('checkpoint_path', 'meta-llama/Llama-3.1-8B-Instruct')}", + "--host=0.0.0.0", + f"--port={int(config.get('port', 8000))}", + f"--served-model-name={config.get('served_model_name', 'default-model')}", + f"--tp={int(config.get('tensor_parallel_size', 1))}", + f"--dp={int(config.get('data_parallel_size', 1))}", + ] + (config.get("extra_args", "").split() if config.get("extra_args") else []) + elif deployment_type == "custom": + return 
config.get( + "command", ["/bin/bash", "-c", f"python3 -m http.server {int(config.get('port', 8080))} --bind 0.0.0.0"] + ) + else: # nim or unknown + return [] + + def __init__(self, task_type_name: str, connector_name: str): + super().__init__(task_type_name=task_type_name, metadata_type=LeptonMetadata) + self.name = connector_name + # Single client with lazy-loaded API surfaces + self._client: Optional[APIClient] = None + self._deployment_api: Optional[DeploymentAPI] = None + self._log_api: Optional[LogAPI] = None + self._dashboard_base_url: Optional[str] = None + + @lru_cache(maxsize=128) + def _sanitize_endpoint_name(self, name: str) -> str: + """Optimized endpoint name sanitization with caching.""" + sanitized = self.NAME_SANITIZE_REGEX.sub("-", name.lower()) + + # Ensure valid start/end characters + if not sanitized or not sanitized[0].isalpha(): + sanitized = "ep-" + sanitized + sanitized = sanitized.rstrip("-") + if not sanitized or not sanitized[-1].isalnum(): + sanitized = sanitized + "x" + + # Efficient truncation with deterministic suffix + if len(sanitized) > 36: + return sanitized[:33] + f"{hash(name) % 1000:03d}" + return sanitized + + def _get_dashboard_base_url(self) -> str: + """Get dashboard base URL dynamically from workspace origin URL""" + if self._dashboard_base_url is not None: + return self._dashboard_base_url + + # Get origin URL from client + try: + origin_url = ( + self.client.workspace_origin_url if hasattr(self.client, "workspace_origin_url") else self.ORIGIN_URL + ) + except Exception: + origin_url = self.ORIGIN_URL + + # Convert gateway URL to dashboard URL + if "gateway.dgxc-lepton.nvidia.com" in origin_url: + self._dashboard_base_url = origin_url.replace( + "gateway.dgxc-lepton.nvidia.com", "dashboard.dgxc-lepton.nvidia.com" + ) + elif "gateway.lepton.ai" in origin_url: + self._dashboard_base_url = origin_url.replace("gateway.lepton.ai", "dashboard.lepton.ai") + else: + # Fallback to hardcoded URL if pattern doesn't match + self._dashboard_base_url = "https://dashboard.dgxc-lepton.nvidia.com" + log_info(f"Warning: Using fallback dashboard URL for unknown origin: {origin_url}") + + log_info(f"Dynamic dashboard base URL: {self._dashboard_base_url}") + return self._dashboard_base_url + + def _build_logs_url(self, deployment_name: str, workspace_id: str) -> str: + """Build dynamic logs URL for deployment""" + dashboard_base = self._get_dashboard_base_url() + return f"{dashboard_base}/workspace/{workspace_id}/compute/deployments/detail/{deployment_name}/replicas/list#/deployment/{deployment_name}/logs" + + @property + def client(self) -> APIClient: + """Lazy-loaded, cached API client""" + if self._client is not None: + return self._client + + import os + + workspace_id = os.environ.get("LEPTON_WORKSPACE_ID") + token = os.environ.get("LEPTON_TOKEN") + + # Fallback to Flyte secrets + if not workspace_id or not token: + try: + secrets = current_context().secrets + workspace_id = workspace_id or secrets.get("lepton_workspace_id") + token = token or secrets.get("lepton_token") + except Exception: + pass + + if not workspace_id or not token: + raise ValueError("Missing Lepton credentials: LEPTON_WORKSPACE_ID and LEPTON_TOKEN") + + log_info(f"Creating Lepton API client for workspace: {workspace_id}") + self._client = APIClient(workspace_id=workspace_id, auth_token=token, workspace_origin_url=self.ORIGIN_URL) + return self._client + + @property + def deployment_api(self) -> DeploymentAPI: + """Lazy-loaded deployment API""" + if self._deployment_api is None: + 
self._deployment_api = DeploymentAPI(self.client) + return self._deployment_api + + @property + def log_api(self) -> LogAPI: + """Lazy-loaded log API""" + if self._log_api is None: + self._log_api = LogAPI(self.client) + return self._log_api + + def _map_status_to_phase(self, status: str) -> TaskExecution.Phase: + """Optimized status mapping using frozensets for faster lookup.""" + if not status: + return convert_to_flyte_phase("pending") + + status_lower = status.lower() + if status_lower in self._PENDING_STATES: + return convert_to_flyte_phase("pending") + elif status_lower in self._SUCCESS_STATES: + return convert_to_flyte_phase("succeeded") + elif status_lower in self._FAILED_STATES: + return convert_to_flyte_phase("failed") + else: + return convert_to_flyte_phase("pending") # Default fallback + + def _build_spec(self, config: Dict[str, Any]) -> Dict[str, Any]: + """Optimized spec generation with reduced complexity.""" + deployment_type = config.get("deployment_type", "custom") + port = config.get("port", 8080 if deployment_type == "custom" else 8000) + + log_info(f"Building spec for {deployment_type}, port: {port}") + + # Base container spec + container = {"image": config.get("image", "python:3.11-slim"), "ports": [{"container_port": port}]} + + # Add command using optimized method + command = self._get_deployment_command(deployment_type, config) + if command: + container["command"] = command + log_info(f"Generated command: {' '.join(command[:3])}...") + + # Build environment variables efficiently + envs = [] + + # Auto-derived environment variables + if config.get("served_model_name"): + envs.append({"name": "SERVED_MODEL_NAME", "value": config["served_model_name"]}) + if config.get("port"): + envs.append({"name": "MODEL_PORT", "value": str(port)}) + if deployment_type in ("vllm", "sglang") and config.get("checkpoint_path"): + envs.append({"name": "MODEL_PATH", "value": config["checkpoint_path"]}) + if deployment_type in ("vllm", "sglang"): + envs.append({"name": "TENSOR_PARALLEL_SIZE", "value": str(config.get("tensor_parallel_size", 1))}) + if deployment_type == "nim" and config.get("served_model_name"): + envs.append({"name": "NIM_MODEL_NAME", "value": config["served_model_name"]}) + + # User-defined environment variables + for key, value in config.get("env_vars", {}).items(): + env_var = {"name": key} + if isinstance(value, dict) and "value_from" in value: + env_var["value_from"] = value["value_from"] + else: + env_var["value"] = str(value) + envs.append(env_var) + + # Build API tokens + api_tokens = [] + for token_config in config.get("api_tokens", []): + if isinstance(token_config, dict): + if "value" in token_config: + api_tokens.append({"value": str(token_config["value"])}) + elif "value_from" in token_config: + api_tokens.append({"value_from": token_config["value_from"]}) + else: + api_tokens.append({"value": str(token_config)}) + + # Assemble final spec more efficiently + spec = { + "container": container, + "resource_requirement": { + "resource_shape": config.get("resource_shape", "cpu.small"), + "min_replicas": config.get("min_replicas", 1), + "max_replicas": config.get("max_replicas", 1), + }, + "envs": envs, + "health": config.get("health_config", {}), + "log": {"enable_collection": config.get("log_config", {}).get("enable_collection", True)}, + "metrics": {}, + "routing_policy": {}, + "enable_rdma": config.get("enable_rdma", False), + "user_security_context": {}, + } + + # Add optional configurations + if config.get("auto_scaler"): + spec["auto_scaler"] = 
config["auto_scaler"] + if config.get("queue_config"): + spec["queue_config"] = config["queue_config"] + if config.get("image_pull_secrets"): + spec["image_pull_secrets"] = config["image_pull_secrets"] + if api_tokens: + spec["api_tokens"] = api_tokens + + # Node group affinity + if config.get("node_group"): + spec["resource_requirement"]["affinity"] = {"allowed_dedicated_node_groups": [config["node_group"]]} + + # Storage mounts - simplified handling + mounts_config = config.get("mounts") + if mounts_config: + mounts_list = ( + mounts_config.get("mounts") + if isinstance(mounts_config, dict) and "mounts" in mounts_config + else mounts_config + if isinstance(mounts_config, list) + else [] + ) + + if mounts_list: + spec["mounts"] = [ + { + "path": mount["cache_path"], + "from": mount.get("storage_source", "node-nfs:lepton-shared-fs"), + "mount_path": mount["mount_path"], + "mount_options": mount.get("mount_options", {}), + } + for mount in mounts_list + if isinstance(mount, dict) and mount.get("enabled", True) + ] + + return spec + + def _extract_config_from_inputs(self, task_template: TaskTemplate, inputs: Optional[LiteralMap]) -> Dict[str, Any]: + """Extract config from inputs using Flytekit's proper TypeEngine approach""" + if not inputs or not inputs.literals: + return {} + + try: + # Use the proper Flytekit pattern from Perian and Snowflake connectors + from flytekit import FlyteContextManager + from flytekit.core.type_engine import TypeEngine + + ctx = FlyteContextManager.current_context() + literal_types = task_template.interface.inputs + + # Convert literals to native Python objects using TypeEngine + native_inputs = TypeEngine.literal_map_to_kwargs(ctx, inputs, literal_types=literal_types) + log_info(f"Native inputs from TypeEngine: {list(native_inputs.keys())}") + + config = {} + + # Handle the request dataclass specially + if "request" in native_inputs: + request_data = native_inputs["request"] + log_info(f"Found request data: {type(request_data)}") + + # If it's a dataclass instance, extract its fields + if hasattr(request_data, "__dataclass_fields__"): + log_info(f"Request is a dataclass with fields: {list(request_data.__dataclass_fields__.keys())}") + # Extract all fields from the dataclass instance + for field_name in request_data.__dataclass_fields__.keys(): + value = getattr(request_data, field_name) + if value is not None: + config[field_name] = value + log_info(f"Extracted {field_name} from dataclass: {value}") + + log_info( + f"Successfully extracted {len([v for v in config.values() if v is not None])} parameters from dataclass" + ) + elif isinstance(request_data, dict): + log_info(f"Request is a dict with keys: {list(request_data.keys())}") + # Extract all fields from the dict + for param, value in request_data.items(): + if value is not None: + config[param] = value + log_info(f"Extracted {param} from dict: {value}") + else: + # Store as-is if it's neither a dataclass nor dict + config["request"] = request_data + log_info(f"Stored request as-is: {type(request_data)}") + + # Add any other top-level inputs + for key, value in native_inputs.items(): + if key != "request" and value is not None: + config[key] = value + log_info(f"Added top-level input {key}: {value}") + + log_info(f"Final config keys: {list(config.keys())}") + return config + + except Exception as e: + log_info(f"Error extracting config from inputs: {e}") + import traceback + + log_info(f"Traceback: {traceback.format_exc()}") + # Fallback to empty config + return {} + + async def create( + self, 
task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs + ) -> LeptonMetadata: + """Handle endpoint operations - to be overridden by subclasses""" + log_info(f"Processing task: {task_template.id}") + + # Efficient config merging + task_config = task_template.custom if isinstance(task_template.custom, dict) else {} + input_config = self._extract_config_from_inputs(task_template, inputs) + config = {**task_config, **input_config} + + log_info(f"Config keys: {list(config.keys())}") + + # Delegate to specific implementation + return await self._handle_operation(config, task_template) + + async def _handle_operation(self, config: Dict[str, Any], task_template: TaskTemplate) -> LeptonMetadata: + """To be implemented by subclasses""" + raise NotImplementedError("Subclasses must implement _handle_operation") + + async def _handle_deployment(self, config: Dict[str, Any], task_template: TaskTemplate) -> LeptonMetadata: + """Handle endpoint deployment""" + # Generate name and spec + endpoint_name = config.get("endpoint_name") or f"{task_template.id.name}-{int(time.time())}" + endpoint_name = self._sanitize_endpoint_name(endpoint_name) + + log_info(f"Creating: {endpoint_name}") + + try: + spec = self._build_spec(config) + + # Debug logging for node group + log_info(f"Node group in config: {config.get('node_group')}") + log_info(f"Resource shape in config: {config.get('resource_shape')}") + if "resource_requirement" in spec and "affinity" in spec["resource_requirement"]: + log_info(f"Affinity in spec: {spec['resource_requirement']['affinity']}") + else: + log_info("No affinity found in spec") + log_info(f"Full spec resource_requirement: {spec.get('resource_requirement', {})}") + + # Create deployment object + metadata = Metadata(id_=endpoint_name, name=endpoint_name) + user_spec = LeptonDeploymentUserSpec(**spec) + deployment_obj = LeptonDeployment(metadata=metadata, spec=user_spec) + + # Execute in thread pool (SDK is sync) + await asyncio.get_event_loop().run_in_executor(None, lambda: self.deployment_api.create(deployment_obj)) + + log_info(f"Created: {endpoint_name}") + return LeptonMetadata(deployment_name=endpoint_name) + + except Exception as e: + log_info(f"Failed: {e}") + raise RuntimeError(f"Failed to create Lepton endpoint: {e}") + + async def _handle_deletion(self, config: Dict[str, Any], task_template: TaskTemplate) -> LeptonMetadata: + """Handle endpoint deletion""" + endpoint_name = config.get("endpoint_name") or task_template.id.name + endpoint_name = self._sanitize_endpoint_name(endpoint_name) + + log_info(f"Deleting: {endpoint_name}") + + try: + # Execute deletion in thread pool (SDK is sync) + await asyncio.get_event_loop().run_in_executor(None, lambda: self.deployment_api.delete(endpoint_name)) + + log_info(f"Deleted: {endpoint_name}") + return LeptonMetadata(deployment_name=endpoint_name) + + except Exception as e: + log_info(f"Failed to delete: {e}") + raise RuntimeError(f"Failed to delete Lepton endpoint: {e}") + + async def get(self, resource_meta: LeptonMetadata, **kwargs) -> Resource: + """Optimized status checking""" + deployment_name = resource_meta.deployment_name + log_info(f"Checking: {deployment_name}") + + try: + deployment = await asyncio.get_event_loop().run_in_executor( + None, lambda: self.deployment_api.get(deployment_name) + ) + + if not deployment: + return Resource(phase=convert_to_flyte_phase("failed")) + + # Extract state efficiently + state = "unknown" + if deployment.status and deployment.status.state: + state_str = 
str(deployment.status.state) + state = state_str.split(".")[-1].lower() if "LeptonDeploymentState." in state_str else state_str.lower() + + log_info(f"State: {state}") + flyte_phase = self._map_status_to_phase(state) + + # Build dynamic log links + workspace_id = self.client.workspace_id + logs_url = self._build_logs_url(deployment_name, workspace_id) + log_links = [TaskLog(uri=logs_url, name="Lepton Console")] + + # Get recent logs to embed in message for better visibility + log_message = f"Status: {state}" + try: + # Call our existing get_logs method to get recent logs + log_response = self.get_logs(resource_meta, **kwargs) + if log_response.logs: + # Get last few log lines for preview + recent_logs = log_response.logs[-3:] # Last 3 lines + if recent_logs: + log_preview = " | ".join([line.strip() for line in recent_logs if line.strip()]) + log_message = f"Status: {state} | Recent: {log_preview[:200]}..." # Truncate for UI + except Exception: + # If log retrieval fails, just use basic status message + pass + + # Return appropriate response + if flyte_phase == TaskExecution.SUCCEEDED: + # Get endpoint URL + external_endpoint = "Unknown URL" + if deployment.status and deployment.status.endpoint and deployment.status.endpoint.external_endpoint: + external_endpoint = deployment.status.endpoint.external_endpoint + + log_info(f"Ready: {external_endpoint}") + return Resource( + phase=flyte_phase, + message=f"Ready: {external_endpoint} | {log_message}", + log_links=log_links, + outputs={"o0": external_endpoint}, + ) + else: + log_info(f"Not ready: {state}") + return Resource( + phase=flyte_phase, + message=log_message, + log_links=log_links, + ) + + except Exception as e: + log_info(f"Error: {e}") + + # For deletion operations, a 404 error means successful deletion + if "404" in str(e) and "not found" in str(e).lower(): + log_info(f"Deletion confirmed: {deployment_name} not found (404)") + return Resource( + phase=TaskExecution.SUCCEEDED, + message=f"Successfully deleted: {deployment_name}", + outputs={"o0": f"Successfully deleted endpoint: {deployment_name}"}, + ) + + return Resource(phase=convert_to_flyte_phase("failed")) + + async def delete(self, resource_meta: LeptonMetadata, **kwargs): + """Optimized deletion with idempotency""" + deployment_name = resource_meta.deployment_name + + try: + log_info(f"Deleting: {deployment_name}") + + await asyncio.get_event_loop().run_in_executor(None, lambda: self.deployment_api.delete(deployment_name)) + + log_info(f"Deleted: {deployment_name}") + + except Exception as e: + error_msg = str(e).lower() + # Idempotent deletion + if any(term in error_msg for term in ["not found", "does not exist", "404"]): + log_info(f"Already deleted: {deployment_name}") + else: + logger.error(f"Delete failed: {e}") + raise RuntimeError(f"Failed to delete: {e}") + + def get_logs(self, resource_meta: LeptonMetadata, **kwargs) -> GetTaskLogsResponse: + """Optimized log retrieval""" + deployment_name = resource_meta.deployment_name + + try: + log_info(f"Getting logs: {deployment_name}") + + logs_result = self.log_api.get_log(deployment=deployment_name, num=1000) + log_lines = logs_result.split("\n") if isinstance(logs_result, str) else [] + + # Build consistent console URL using same logic as get() method + workspace_id = self.client.workspace_id + console_url = self._build_logs_url(deployment_name, workspace_id) + + return GetTaskLogsResponse( + logs=log_lines, token="", log_links=[TaskLog(name="Lepton Console", uri=console_url)] + ) + + except Exception as e: + 
logger.error(f"Log retrieval failed: {e}") + return GetTaskLogsResponse(logs=[], token="", log_links=[]) + + +class LeptonEndpointDeploymentConnector(BaseLeptonConnector): + """Connector for Lepton endpoint deployment operations.""" + + def __init__(self): + super().__init__( + task_type_name="lepton_endpoint_deployment_task", connector_name="Lepton Endpoint Deployment Connector" + ) + + async def _handle_operation(self, config: Dict[str, Any], task_template: TaskTemplate) -> LeptonMetadata: + """Handle deployment operations""" + return await self._handle_deployment(config, task_template) + + +class LeptonEndpointDeletionConnector(BaseLeptonConnector): + """Connector for Lepton endpoint deletion operations.""" + + def __init__(self): + super().__init__( + task_type_name="lepton_endpoint_deletion_task", connector_name="Lepton Endpoint Deletion Connector" + ) + + async def _handle_operation(self, config: Dict[str, Any], task_template: TaskTemplate) -> LeptonMetadata: + """Handle deletion operations""" + return await self._handle_deletion(config, task_template) + + +# Register both connectors +ConnectorRegistry.register(LeptonEndpointDeploymentConnector()) +ConnectorRegistry.register(LeptonEndpointDeletionConnector()) diff --git a/plugins/flytekit-dgxc-lepton/flytekitplugins/dgxc_lepton/task.py b/plugins/flytekit-dgxc-lepton/flytekitplugins/dgxc_lepton/task.py new file mode 100644 index 0000000000..63750412bc --- /dev/null +++ b/plugins/flytekit-dgxc-lepton/flytekitplugins/dgxc_lepton/task.py @@ -0,0 +1,176 @@ +"""Lepton AI endpoint task implementation.""" + +import os +import warnings +from typing import Any, Dict, Optional + +from flytekit.configuration import SerializationSettings +from flytekit.core.base_task import PythonTask +from flytekit.core.interface import Interface +from flytekit.extend import TaskPlugins +from flytekit.extend.backend.base_connector import AsyncConnectorExecutorMixin + +from .config import LeptonEndpointConfig + + +class LeptonEndpointDeploymentTask(AsyncConnectorExecutorMixin, PythonTask): + """Task for Lepton AI endpoint deployment operations. + + This task follows standard Flytekit patterns by using a single configuration + class that contains all necessary parameters. 
+ + Args: + config (LeptonEndpointConfig): Complete Lepton configuration + **kwargs: Additional task parameters + """ + + _TASK_TYPE = "lepton_endpoint_deployment_task" + + def __init__(self, config: LeptonEndpointConfig, **kwargs): + # Validate that we're running on Lepton platform + self._validate_lepton_platform() + + # Create interface - no inputs needed since config has everything + interface = Interface( + inputs={}, + outputs={"o0": str}, + ) + + # Fall back to the default task name when the caller does not provide one + task_name = kwargs.pop("name", None) or "lepton_deployment_task" + + super().__init__( + name=task_name, + task_type=self._TASK_TYPE, + task_config=config.to_dict(), + interface=interface, + **kwargs, + ) + + def _validate_lepton_platform(self): + """Validate that this task can only run on Lepton platform.""" + # This could be enhanced to check environment variables or other platform indicators + # For now, we'll add a basic check that can be expanded + import os + + # Check for Lepton-specific environment indicators + lepton_indicators = [ + "LEPTON_WORKSPACE_ID", + "LEPTON_API_TOKEN", + "LEPTON_PLATFORM", + "DGXC_LEPTON_PLATFORM", # Custom indicator + ] + + # In production, you might want to be more strict + # For development/testing, we'll be more permissive + has_lepton_indicator = any(os.getenv(indicator) for indicator in lepton_indicators) + + # Allow override for development + if os.getenv("LEPTON_PLATFORM_VALIDATION", "true").lower() == "false": + return + + if not has_lepton_indicator: + # This is a warning rather than hard error to allow development + # In production, you might want to raise an exception + import warnings + + warnings.warn( + "Lepton endpoint tasks are designed to run on Lepton platform. " + "Set DGXC_LEPTON_PLATFORM=true or disable validation with " + "LEPTON_PLATFORM_VALIDATION=false for development.", + UserWarning, + ) + + def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: + """Return custom task attributes for serialization.""" + return self.task_config + + +def lepton_endpoint_deployment_task(config: LeptonEndpointConfig, task_name: Optional[str] = None) -> str: + """Function for Lepton AI endpoint deployment. + + Args: + config (LeptonEndpointConfig): Complete Lepton configuration including endpoint details + task_name (Optional[str]): Optional custom task name + + Returns: + str: Endpoint URL for successful deployment + """ + # Create and execute the task + task = LeptonEndpointDeploymentTask(config=config, name=task_name) + return task() + + +class LeptonEndpointDeletionTask(AsyncConnectorExecutorMixin, PythonTask): + """Task for Lepton AI endpoint deletion operations. + + This task only requires an endpoint name for simple deletions.
+ + Args: + endpoint_name (str): Name of the endpoint to delete + **kwargs: Additional task parameters + """ + + _TASK_TYPE = "lepton_endpoint_deletion_task" + + def __init__(self, endpoint_name: str, **kwargs): + self._validate_lepton_platform() + + # Build minimal config for deletion - only endpoint name needed + task_config = { + "endpoint_name": endpoint_name, + } + + interface = Interface( + inputs={}, + outputs={"o0": str}, + ) + + task_name = kwargs.pop("name", None) or "lepton_deletion_task" + + super().__init__( + name=task_name, + task_type=self._TASK_TYPE, + task_config=task_config, + interface=interface, + **kwargs, + ) + + def _validate_lepton_platform(self): + """Validate that this task can only run on Lepton platform.""" + lepton_indicators = ["LEPTON_WORKSPACE_ID", "LEPTON_API_TOKEN", "LEPTON_PLATFORM", "DGXC_LEPTON_PLATFORM"] + + has_lepton_indicator = any(os.getenv(indicator) for indicator in lepton_indicators) + + if os.getenv("LEPTON_PLATFORM_VALIDATION", "true").lower() == "false": + return + + if not has_lepton_indicator: + warnings.warn( + "Lepton tasks are designed to run on Lepton platform. " + "Set DGXC_LEPTON_PLATFORM=true or disable validation with " + "LEPTON_PLATFORM_VALIDATION=false for development.", + UserWarning, + ) + + def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: + """Return custom task attributes for serialization.""" + return self.task_config + + +def lepton_endpoint_deletion_task(endpoint_name: str, task_name: Optional[str] = None) -> str: + """Function for Lepton AI endpoint deletion. + + Args: + endpoint_name (str): Name of the endpoint to delete + task_name (Optional[str]): Optional custom task name + + Returns: + str: Success message confirming deletion + """ + task = LeptonEndpointDeletionTask(endpoint_name=endpoint_name, name=task_name) + return task() + + +# Register the Lepton endpoint plugins with Flytekit's dynamic plugin loading system +TaskPlugins.register_pythontask_plugin(LeptonEndpointConfig, LeptonEndpointDeploymentTask) diff --git a/plugins/flytekit-dgxc-lepton/setup.py b/plugins/flytekit-dgxc-lepton/setup.py new file mode 100644 index 0000000000..8a507b1ca3 --- /dev/null +++ b/plugins/flytekit-dgxc-lepton/setup.py @@ -0,0 +1,35 @@ +from setuptools import setup + +PLUGIN_NAME = "dgxc-lepton" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = ["flytekit>=1.9.1,<2.0.0", "leptonai"] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + version=__version__, + author="Anshul Jindal", + author_email="ansjindal@nvidia.com", + description="DGXC Lepton Flytekit plugin for inference endpoints", + long_description="Flytekit DGXC Lepton Plugin - AI inference endpoints using Lepton AI infrastructure", + long_description_content_type="text/markdown", + packages=["flytekitplugins.dgxc_lepton"], + install_requires=plugin_requires, + python_requires=">=3.9", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], + entry_points={"flytekit.plugins": ["dgxc_lepton=flytekitplugins.dgxc_lepton"]}, +) diff --git a/plugins/flytekit-dgxc-lepton/tests/__init__.py
b/plugins/flytekit-dgxc-lepton/tests/__init__.py new file mode 100644 index 0000000000..66173aec46 --- /dev/null +++ b/plugins/flytekit-dgxc-lepton/tests/__init__.py @@ -0,0 +1 @@ +# Test package diff --git a/plugins/flytekit-dgxc-lepton/tests/test_lepton.py b/plugins/flytekit-dgxc-lepton/tests/test_lepton.py new file mode 100644 index 0000000000..935a0ef482 --- /dev/null +++ b/plugins/flytekit-dgxc-lepton/tests/test_lepton.py @@ -0,0 +1,392 @@ +import asyncio +import pytest +from collections import OrderedDict +from unittest.mock import MagicMock, patch + +import grpc +from flyteidl.core.execution_pb2 import TaskExecution +from flytekitplugins.dgxc_lepton import ( + LeptonEndpointConfig, + LeptonEndpointDeploymentTask, + LeptonEndpointDeletionTask, + EndpointType, + EndpointEngineConfig, + EnvironmentConfig, + ScalingConfig, + MountReader, + lepton_endpoint_deployment_task, + lepton_endpoint_deletion_task, +) + +from flytekit import Resources +from flytekit.configuration import DefaultImages, ImageConfig, SerializationSettings +from flytekit.extend import get_serializable +from flytekit.extend.backend.base_connector import ConnectorRegistry +from flytekit.models.literals import Literal, LiteralMap, Primitive, Scalar + + +class TestLeptonEndpointConfig: + """Test the unified LeptonEndpointConfig class.""" + + def test_basic_vllm_config(self): + """Test basic VLLM endpoint configuration.""" + config = LeptonEndpointConfig( + endpoint_name="test-vllm", + resource_shape="gpu.1xh100", + node_group="test-group", + endpoint_config=EndpointEngineConfig.vllm( + checkpoint_path="meta-llama/Llama-3.1-8B-Instruct" + ), + api_token="test-token", + ) + + config_dict = config.to_dict() + assert config_dict["endpoint_name"] == "test-vllm" + assert config_dict["resource_shape"] == "gpu.1xh100" + assert config_dict["node_group"] == "test-group" + assert config_dict["deployment_type"] == "vllm" + assert config_dict["checkpoint_path"] == "meta-llama/Llama-3.1-8B-Instruct" + assert config_dict["api_tokens"] == [{"value": "test-token"}] + + def test_nim_config_with_secrets(self): + """Test NIM endpoint with secrets.""" + config = LeptonEndpointConfig( + endpoint_name="test-nim", + resource_shape="gpu.1xh200", + node_group="test-group", + endpoint_config=EndpointEngineConfig.nim( + image="nvcr.io/nim/nvidia/llama-3_3-nemotron-super-49b-v1_5:latest" + ), + environment=EnvironmentConfig.create( + OMPI_ALLOW_RUN_AS_ROOT="1", + secrets={"NGC_API_KEY": "ngc-secret"} + ), + scaling=ScalingConfig.qpm(target_qpm=2.5, min_replicas=1, max_replicas=3), + image_pull_secrets=["ngc-secret"], + ) + + config_dict = config.to_dict() + assert config_dict["deployment_type"] == "nim" + assert config_dict["image"] == "nvcr.io/nim/nvidia/llama-3_3-nemotron-super-49b-v1_5:latest" + assert config_dict["image_pull_secrets"] == ["ngc-secret"] + assert "env_vars" in config_dict + assert config_dict["env_vars"]["OMPI_ALLOW_RUN_AS_ROOT"] == "1" + assert config_dict["env_vars"]["NGC_API_KEY"]["value_from"]["secret_name_ref"] == "ngc-secret" + + def test_custom_config_with_scaling(self): + """Test custom endpoint with traffic scaling.""" + config = LeptonEndpointConfig( + endpoint_name="test-custom", + resource_shape="cpu.small", + node_group="test-group", + endpoint_config=EndpointEngineConfig.custom( + image="python:3.11-slim", + command=["/bin/bash", "-c", "python3 -m http.server 8080"], + port=8080, + ), + scaling=ScalingConfig.traffic(min_replicas=1, max_replicas=2, timeout=1800), + 
environment=EnvironmentConfig.from_env(LOG_LEVEL="INFO"), + ) + + config_dict = config.to_dict() + assert config_dict["deployment_type"] == "custom" + assert config_dict["image"] == "python:3.11-slim" + assert config_dict["command"] == ["/bin/bash", "-c", "python3 -m http.server 8080"] + assert config_dict["port"] == 8080 + assert "auto_scaler" in config_dict + assert config_dict["min_replicas"] == 1 + assert config_dict["max_replicas"] == 2 + + def test_deletion_config(self): + """Test deletion configuration - now handled by separate deletion task.""" + # Note: Deletion is now handled by lepton_endpoint_deletion_task() + # which only needs endpoint_name, not a full config + + # This test now just verifies we can create a minimal config for testing + config = LeptonEndpointConfig( + endpoint_name="test-delete", + resource_shape="cpu.small", + node_group="test-group", + endpoint_config=EndpointEngineConfig.custom(image="dummy"), + ) + + config_dict = config.to_dict() + assert config_dict["operation"] == "deploy" # Always deploy for LeptonEndpointConfig + assert config_dict["endpoint_name"] == "test-delete" + + def test_validation_errors(self): + """Test configuration validation.""" + # Test missing required fields for deployment + with pytest.raises(ValueError, match="endpoint_name, resource_shape, and node_group are required"): + LeptonEndpointConfig( + endpoint_name="", + resource_shape="gpu.1xh100", + node_group="test-group", + endpoint_config=EndpointEngineConfig.vllm(), + ) + + # Test both api_token and api_token_secret specified + with pytest.raises(ValueError, match="Cannot specify both api_token and api_token_secret"): + LeptonEndpointConfig( + endpoint_name="test", + resource_shape="gpu.1xh100", + node_group="test-group", + endpoint_config=EndpointEngineConfig.vllm(), + api_token="token", + api_token_secret="secret", + ) + + +class TestEndpointEngineConfig: + """Test the unified EndpointEngineConfig class.""" + + def test_vllm_factory_method(self): + """Test VLLM factory method.""" + config = EndpointEngineConfig.vllm( + checkpoint_path="meta-llama/Llama-3.1-8B-Instruct", + tensor_parallel_size=2, + ) + + assert config.endpoint_type == EndpointType.VLLM + assert config.checkpoint_path == "meta-llama/Llama-3.1-8B-Instruct" + assert config.tensor_parallel_size == 2 + assert config.pipeline_parallel_size == 1 # default + + config_dict = config.to_dict() + assert config_dict["checkpoint_path"] == "meta-llama/Llama-3.1-8B-Instruct" + assert config_dict["tensor_parallel_size"] == 2 + assert config_dict["pipeline_parallel_size"] == 1 + assert config_dict["data_parallel_size"] == 1 + + def test_sglang_factory_method(self): + """Test SGLang factory method.""" + config = EndpointEngineConfig.sglang( + checkpoint_path="meta-llama/Llama-3.1-8B-Instruct", + ) + + assert config.endpoint_type == EndpointType.SGLANG + config_dict = config.to_dict() + assert "pipeline_parallel_size" not in config_dict # SGLang doesn't use this + assert config_dict["data_parallel_size"] == 1 + + def test_nim_factory_method(self): + """Test NIM factory method.""" + config = EndpointEngineConfig.nim( + image="nvcr.io/nim/nvidia/llama-3_3-nemotron-super-49b-v1_5:latest" + ) + + assert config.endpoint_type == EndpointType.NIM + assert config.image == "nvcr.io/nim/nvidia/llama-3_3-nemotron-super-49b-v1_5:latest" + + def test_custom_factory_method(self): + """Test custom factory method.""" + config = EndpointEngineConfig.custom( + image="python:3.11-slim", + command=["/bin/bash", "-c", "python3 -m http.server 8080"], + 
port=8080, + ) + + assert config.endpoint_type == EndpointType.CUSTOM + assert config.image == "python:3.11-slim" + assert config.command == ["/bin/bash", "-c", "python3 -m http.server 8080"] + assert config.port == 8080 + + +class TestScalingConfig: + """Test the unified ScalingConfig class.""" + + def test_traffic_scaling(self): + """Test traffic-based scaling.""" + config = ScalingConfig.traffic(min_replicas=1, max_replicas=3, timeout=1800) + + assert config.scaling_type.value == "traffic" + assert config.min_replicas == 1 + assert config.max_replicas == 3 + assert config.timeout == 1800 + + config_dict = config.to_dict() + assert config_dict["scale_down"]["no_traffic_timeout"] == 1800 + assert config_dict["target_gpu_utilization_percentage"] == 0 + + def test_gpu_scaling(self): + """Test GPU utilization scaling.""" + config = ScalingConfig.gpu(target_utilization=80, min_replicas=1, max_replicas=5) + + assert config.scaling_type.value == "gpu" + assert config.target_utilization == 80 + + config_dict = config.to_dict() + assert config_dict["target_gpu_utilization_percentage"] == 80 + + def test_qpm_scaling(self): + """Test QPM-based scaling.""" + config = ScalingConfig.qpm(target_qpm=100.5, min_replicas=2, max_replicas=4) + + assert config.scaling_type.value == "qpm" + assert config.target_qpm == 100.5 + + config_dict = config.to_dict() + assert config_dict["target_throughput"]["qpm"] == 100.5 + + def test_scaling_validation(self): + """Test scaling configuration validation.""" + # Test invalid GPU utilization + with pytest.raises(ValueError, match="target_utilization must be between 1 and 100"): + ScalingConfig.gpu(target_utilization=150) + + # Test negative QPM + with pytest.raises(ValueError, match="target_qpm must be positive"): + ScalingConfig.qpm(target_qpm=-5.0) + + +class TestEnvironmentConfig: + """Test the unified EnvironmentConfig class.""" + + def test_env_vars_only(self): + """Test environment variables only.""" + config = EnvironmentConfig.from_env(LOG_LEVEL="DEBUG", MODEL_PATH="/models") + + config_dict = config.to_dict() + assert config_dict["LOG_LEVEL"] == "DEBUG" + assert config_dict["MODEL_PATH"] == "/models" + + def test_secrets_only(self): + """Test secrets only.""" + config = EnvironmentConfig.from_secrets(HF_TOKEN="hf-secret", NGC_API_KEY="ngc-secret") + + config_dict = config.to_dict() + assert config_dict["HF_TOKEN"]["value_from"]["secret_name_ref"] == "hf-secret" + assert config_dict["NGC_API_KEY"]["value_from"]["secret_name_ref"] == "ngc-secret" + + def test_mixed_config(self): + """Test mixed environment variables and secrets.""" + config = EnvironmentConfig.create( + LOG_LEVEL="DEBUG", + MODEL_PATH="/models", + secrets={"HF_TOKEN": "hf-secret"} + ) + + config_dict = config.to_dict() + assert config_dict["LOG_LEVEL"] == "DEBUG" + assert config_dict["MODEL_PATH"] == "/models" + assert config_dict["HF_TOKEN"]["value_from"]["secret_name_ref"] == "hf-secret" + + +class TestMountReader: + """Test the simplified MountReader class.""" + + def test_node_nfs_basic(self): + """Test basic NFS mount configuration.""" + mounts = MountReader.node_nfs( + ("/shared-storage/models", "/opt/models"), + ("/shared-storage/data", "/opt/data"), + ) + + mount_list = mounts.to_dict() + assert len(mount_list) == 2 + assert mount_list[0]["cache_path"] == "/shared-storage/models" + assert mount_list[0]["mount_path"] == "/opt/models" + assert mount_list[0]["enabled"] is True + assert mount_list[0]["storage_source"] == "node-nfs:lepton-shared-fs" + + def test_node_nfs_with_disabled(self): 
+ """Test NFS mount with disabled mount.""" + mounts = MountReader.node_nfs( + ("/shared-storage/models", "/opt/models"), + ("/shared-storage/logs", "/opt/logs", False), # Disabled + ) + + mount_list = mounts.to_dict() + assert mount_list[0]["enabled"] is True + assert mount_list[1]["enabled"] is False + + def test_node_nfs_custom_storage(self): + """Test NFS mount with custom storage name.""" + mounts = MountReader.node_nfs( + ("/prod-storage/models", "/opt/models"), + storage_name="production-nfs" + ) + + mount_list = mounts.to_dict() + assert mount_list[0]["storage_source"] == "node-nfs:production-nfs" + + +class TestLeptonEndpointDeploymentTask: + """Test the unified LeptonEndpointDeploymentTask class.""" + + def test_task_creation(self): + """Test task creation with configuration.""" + config = LeptonEndpointConfig( + endpoint_name="test-endpoint", + resource_shape="gpu.1xh100", + node_group="test-group", + endpoint_config=EndpointEngineConfig.vllm(), + ) + + task = LeptonEndpointDeploymentTask(config=config, name="test-task") + + assert task.name == "test-task" + assert task._TASK_TYPE == "lepton_endpoint_deployment_task" + assert task.task_config == config.to_dict() + + def test_lepton_endpoint_deployment_task_function(self): + """Test the lepton_endpoint_deployment_task convenience function.""" + config = LeptonEndpointConfig( + endpoint_name="test-endpoint", + resource_shape="gpu.1xh100", + node_group="test-group", + endpoint_config=EndpointEngineConfig.custom(image="test-image"), + ) + + # Test that the function creates a task (without executing it) + task = LeptonEndpointDeploymentTask(config=config, name="test-function") + + # Verify the task was created correctly + assert task.name == "test-function" + assert task._TASK_TYPE == "lepton_endpoint_deployment_task" + assert task.task_config == config.to_dict() + + def test_lepton_endpoint_deletion_task_function(self): + """Test the lepton_endpoint_deletion_task convenience function.""" + # Test that the function creates a deletion task (without executing it) + task = LeptonEndpointDeletionTask(endpoint_name="test-endpoint", name="test-deletion") + + # Verify the task was created correctly + assert task.name == "test-deletion" + assert task._TASK_TYPE == "lepton_endpoint_deletion_task" + assert task.task_config == {"endpoint_name": "test-endpoint"} + + +class TestConnectorIntegration: + """Test connector integration with new configuration.""" + + @patch('flytekitplugins.dgxc_lepton.connector.APIClient') + def test_connector_with_new_config(self, mock_api_client): + """Test that connector works with new configuration format.""" + # Mock the API client + mock_client = MagicMock() + mock_api_client.return_value = mock_client + + # Create configuration + config = LeptonEndpointConfig( + endpoint_name="test-connector", + resource_shape="gpu.1xh100", + node_group="test-group", + endpoint_config=EndpointEngineConfig.vllm( + checkpoint_path="meta-llama/Llama-3.1-8B-Instruct" + ), + api_token="test-token", + ) + + # Test that config converts properly + config_dict = config.to_dict() + + # Verify expected structure + assert config_dict["deployment_type"] == "vllm" + assert config_dict["endpoint_name"] == "test-connector" + assert config_dict["api_tokens"] == [{"value": "test-token"}] + assert config_dict["checkpoint_path"] == "meta-llama/Llama-3.1-8B-Instruct" + + +if __name__ == "__main__": + pytest.main([__file__])