diff --git a/flytekit/clis/sdk_in_container/pyflyte.py b/flytekit/clis/sdk_in_container/pyflyte.py index a492e1cba8..a3c02745b4 100644 --- a/flytekit/clis/sdk_in_container/pyflyte.py +++ b/flytekit/clis/sdk_in_container/pyflyte.py @@ -66,8 +66,10 @@ def main(ctx, pkgs: typing.List[str], config: str, verbose: int): if config: ctx.obj[CTX_CONFIG_FILE] = config cfg = configuration.ConfigFile(config) + # Gate this log on verbosity to ensure proper output format when using the --quiet flag in pyflyte register # Set here so that if someone has Config.auto() in their user code, the config here will get used. - if FLYTECTL_CONFIG_ENV_VAR in os.environ: + if FLYTECTL_CONFIG_ENV_VAR in os.environ and verbose > 0: + # Log when verbose > 0 to prevent breaking output format for pyflyte register's quiet or summary-format flag logger.info( f"Config file arg {config} will override env var {FLYTECTL_CONFIG_ENV_VAR}: {os.environ[FLYTECTL_CONFIG_ENV_VAR]}" ) diff --git a/flytekit/clis/sdk_in_container/register.py b/flytekit/clis/sdk_in_container/register.py index 0bf56fed69..ebdea57abe 100644 --- a/flytekit/clis/sdk_in_container/register.py +++ b/flytekit/clis/sdk_in_container/register.py @@ -34,6 +34,9 @@ the root of your project, it finds the first folder that does not have a ``__init__.py`` file. """ +_original_secho = click.secho +_original_log_level = logger.level + @click.command("register", help=_register_help) @project_option_dec @@ -159,6 +162,20 @@ help="Skip errors during registration. This is useful when registering multiple packages and you want to skip " "errors for some packages.", ) +@click.option( + "--summary-format", + "-f", + required=False, + type=click.Choice(["json", "yaml"], case_sensitive=False), + default=None, + help="Output format for registration summary. Lists registered workflows, tasks, and launch plans. 'json' and 'yaml' supported.", +) +@click.option( + "--quiet", + is_flag=True, + default=False, + help="Suppress output messages, only displaying errors.", +) @click.argument("package-or-module", type=click.Path(exists=True, readable=True, resolve_path=True), nargs=-1) @click.pass_context def register( @@ -181,12 +198,25 @@ def register( resource_requests: typing.Optional[Resources], resource_limits: typing.Optional[Resources], skip_errors: bool, + summary_format: typing.Optional[str], + quiet: bool, ): """ see help """ + + if summary_format is not None: + quiet = True + + if quiet: + # Mute all secho output through monkey patching + click.secho = lambda *args, **kw: None + # Only emit logs at ERROR or CRITICAL level + logger.setLevel("ERROR") +
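For context, hypothetical invocations of the two new flags (the package path is illustrative):

```bash
# Suppress all click output; only ERROR-level logs are emitted
pyflyte register --quiet ./workflows

# Emit a machine-readable summary; per the code above this implies --quiet
pyflyte register --summary-format json ./workflows
```

Based on the result dicts assembled in `repo.register` below, a `json` summary is a list of entries of the form `{"id": ..., "type": ..., "version": ..., "status": "success" | "failed" | "skipped"}`.

 # Set the relevant copy option if non_fast is set, this enables the individual file listing behavior # that the copy flag uses.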
+ if non_fast: click.secho("The --non-fast flag is deprecated, please use --copy none instead", fg="yellow") if "--copy" in sys.argv: @@ -214,42 +244,49 @@ def register( "Missing argument 'PACKAGE_OR_MODULE...', at least one PACKAGE_OR_MODULE is required but multiple can be passed", ) - # Use extra images in the config file if that file exists - config_file = ctx.obj.get(constants.CTX_CONFIG_FILE) - if config_file: - image_config = patch_image_config(config_file, image_config) - - click.secho( - f"Running pyflyte register from {os.getcwd()} " - f"with images {image_config} " - f"and image destination folder {destination_dir} " - f"on {len(package_or_module)} package(s) {package_or_module}", - dim=True, - ) - - # Create and save FlyteRemote, - remote = get_and_save_remote_with_click_context(ctx, project, domain, data_upload_location="flyte://data") - click.secho(f"Registering against {remote.config.platform.endpoint}") - repo.register( - project, - domain, - image_config, - output, - destination_dir, - service_account, - raw_data_prefix, - version, - deref_symlinks, - copy_style=copy, - package_or_module=package_or_module, - remote=remote, - env=env, - default_resources=ResourceSpec( - requests=resource_requests or Resources(), limits=resource_limits or Resources() - ), - dry_run=dry_run, - activate_launchplans=activate_launchplans, - skip_errors=skip_errors, - show_files=show_files, - verbosity=ctx.obj[constants.CTX_VERBOSE], - ) + try: + # Use extra images in the config file if that file exists + config_file = ctx.obj.get(constants.CTX_CONFIG_FILE) + if config_file: + image_config = patch_image_config(config_file, image_config) + + click.secho( + f"Running pyflyte register from {os.getcwd()} " + f"with images {image_config} " + f"and image destination folder {destination_dir} " + f"on {len(package_or_module)} package(s) {package_or_module}", + dim=True, + ) + + # Create and save FlyteRemote, + remote = get_and_save_remote_with_click_context(ctx, project, domain, data_upload_location="flyte://data") + click.secho(f"Registering against {remote.config.platform.endpoint}") + repo.register( + project, + domain, + image_config, + output, + destination_dir, + service_account, + raw_data_prefix, + version, + deref_symlinks, + copy_style=copy, + package_or_module=package_or_module, + remote=remote, + env=env, + summary_format=summary_format, + quiet=quiet, + default_resources=ResourceSpec( + requests=resource_requests or Resources(), limits=resource_limits or Resources() + ), + dry_run=dry_run, + activate_launchplans=activate_launchplans, + skip_errors=skip_errors, + show_files=show_files, + verbosity=ctx.obj[constants.CTX_VERBOSE], + ) + finally: + # Restore original secho + click.secho = _original_secho + logger.setLevel(_original_log_level) diff --git a/flytekit/clis/sdk_in_container/utils.py b/flytekit/clis/sdk_in_container/utils.py index fe906d7e5d..641f8f43fa 100644 --- a/flytekit/clis/sdk_in_container/utils.py +++ b/flytekit/clis/sdk_in_container/utils.py @@ -74,7 +74,10 @@ def validate_package(ctx, param, values): pkgs.extend(val.split(",")) else: pkgs.append(val) - logger.debug(f"Using packages: {pkgs}") + + if ctx.params.get("verbose", 0) > 0: + # Log when verbose > 0 to prevent breaking output format for pyflyte register's quiet or summary-format flag + logger.debug(f"Using packages: {pkgs}") return pkgs diff --git a/flytekit/image_spec/noop_builder 2.py b/flytekit/image_spec/noop_builder 2.py new file mode 100644 index 0000000000..ee2dc0c664 --- /dev/null +++
b/flytekit/image_spec/noop_builder 2.py @@ -0,0 +1,31 @@ +from flytekit.image_spec.image_spec import ImageSpec, ImageSpecBuilder + + +class NoOpBuilder(ImageSpecBuilder): + """Noop image builder.""" + + builder_type = "noop" + + def should_build(self, image_spec: ImageSpec) -> bool: + """ + The build_image function of NoOpBuilder does not actually build a Docker image. + Since no Docker build process occurs, we do not need to check for Docker daemon + or existing images. Therefore, should_build should always return True. + + Args: + image_spec (ImageSpec): Image specification + + Returns: + bool: Always returns True + """ + return True + + def build_image(self, image_spec: ImageSpec) -> str: + if not isinstance(image_spec.base_image, str): + msg = "base_image must be a string to use the noop image builder" + raise ValueError(msg) + + import click + + click.secho(f"Using image: {image_spec.base_image}", fg="blue") + return image_spec.base_image diff --git a/flytekit/models/concurrency 2.py b/flytekit/models/concurrency 2.py new file mode 100644 index 0000000000..75e44bdbeb --- /dev/null +++ b/flytekit/models/concurrency 2.py @@ -0,0 +1,62 @@ +from flyteidl.admin import launch_plan_pb2 as _launch_plan_idl + +from flytekit.models import common as _common + + +class ConcurrencyLimitBehavior(object): + SKIP = _launch_plan_idl.CONCURRENCY_LIMIT_BEHAVIOR_SKIP + + @classmethod + def enum_to_string(cls, val): + """ + :param int val: + :rtype: Text + """ + if val == cls.SKIP: + return "SKIP" + else: + return "" + + +class ConcurrencyPolicy(_common.FlyteIdlEntity): + """ + Defines the concurrency policy for a launch plan. + """ + + def __init__(self, max_concurrency: int, behavior: ConcurrencyLimitBehavior = None): + self._max_concurrency = max_concurrency + self._behavior = behavior if behavior is not None else ConcurrencyLimitBehavior.SKIP + + @property + def max_concurrency(self) -> int: + """ + Maximum number of concurrent workflows allowed. + """ + return self._max_concurrency + + @property + def behavior(self) -> ConcurrencyLimitBehavior: + """ + Policy behavior when concurrency limit is reached. 
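+ Defaults to SKIP when no behavior is provided.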
+ """ + return self._behavior + + def to_flyte_idl(self) -> _launch_plan_idl.ConcurrencyPolicy: + """ + :rtype: flyteidl.admin.launch_plan_pb2.ConcurrencyPolicy + """ + return _launch_plan_idl.ConcurrencyPolicy( + max=self.max_concurrency, + behavior=self.behavior, + ) + + @classmethod + def from_flyte_idl(cls, pb2_object: _launch_plan_idl.ConcurrencyPolicy) -> "ConcurrencyPolicy": + """ + :param flyteidl.admin.launch_plan_pb2.ConcurrencyPolicy pb2_object: + :rtype: ConcurrencyPolicy + """ + return cls( + max_concurrency=pb2_object.max, + behavior=pb2_object.behavior, + ) diff --git a/flytekit/tools/repo.py b/flytekit/tools/repo.py index 975afdb445..79fcb20b14 100644 --- a/flytekit/tools/repo.py +++ b/flytekit/tools/repo.py @@ -1,12 +1,15 @@ import asyncio import functools +import json import os import tarfile import tempfile import typing +from contextlib import contextmanager from pathlib import Path import click +import yaml from rich import print as rprint from flytekit.configuration import FastSerializationSettings, ImageConfig, SerializationSettings @@ -24,6 +27,9 @@ from flytekit.tools.serialize_helpers import get_registrable_entities, persist_registrable_entities from flytekit.tools.translator import FlyteControlPlaneEntity, Options +_original_secho = click.secho +_original_log_level = logger.level + class NoSerializableEntitiesError(Exception): pass @@ -239,6 +245,20 @@ def print_registration_status( rprint(f"[{color}]{state_ind} {name}: {i.name} (Failed)") +@contextmanager +def temporary_secho(): + """ + Temporarily restores the original click.secho function. + Useful when you need to temporarily disable quiet mode. + """ + current_secho = click.secho + try: + click.secho = _original_secho + yield + finally: + click.secho = current_secho + + def register( project: str, domain: str, @@ -253,7 +273,9 @@ def register( remote: FlyteRemote, copy_style: CopyFileDetection, env: typing.Optional[typing.Dict[str, str]], + summary_format: typing.Optional[str], default_resources: typing.Optional[ResourceSpec], + quiet: bool = False, dry_run: bool = False, activate_launchplans: bool = False, skip_errors: bool = False, @@ -264,127 +286,170 @@ def register( Temporarily, for fast register, specify both the fast arg as well as copy_style. fast == True with copy_style == None means use the old fast register tar'ring method. 
""" - detected_root = find_common_root(package_or_module) - click.secho(f"Detected Root {detected_root}, using this to create deployable package...", fg="yellow") - - # Create serialization settings - # Todo: Rely on default Python interpreter for now, this will break custom Spark containers - serialization_settings = SerializationSettings( - project=project, - domain=domain, - version=version, - image_config=image_config, - fast_serialization_settings=None, # should probably add incomplete fast settings - env=env, - default_resources=default_resources, - ) - if not version and copy_style == CopyFileDetection.NO_COPY: - click.secho("Version is required.", fg="red") - return + # Mute all secho output through monkey patching + if quiet: + click.secho = lambda *args, **kw: None + logger.setLevel("ERROR") + + try: + detected_root = find_common_root(package_or_module) + click.secho(f"Detected Root {detected_root}, using this to create deployable package...", fg="yellow") + + # Create serialization settings + # Todo: Rely on default Python interpreter for now, this will break custom Spark containers + serialization_settings = SerializationSettings( + project=project, + domain=domain, + version=version, + image_config=image_config, + fast_serialization_settings=None, # should probably add incomplete fast settings + env=env, + default_resources=default_resources, + ) + + if not version and copy_style == CopyFileDetection.NO_COPY: + click.secho("Version is required.", fg="red") + return - b = serialization_settings.new_builder() - serialization_settings = b.build() + b = serialization_settings.new_builder() + serialization_settings = b.build() - options = Options.default_from(k8s_service_account=service_account, raw_data_prefix=raw_data_prefix) + options = Options.default_from(k8s_service_account=service_account, raw_data_prefix=raw_data_prefix) - # Load all the entities - FlyteContextManager.push_context(remote.context) - serialization_settings.git_repo = _get_git_repo_url(str(detected_root)) - pkgs_and_modules = list_packages_and_modules(detected_root, list(package_or_module)) + # Load all the entities + FlyteContextManager.push_context(remote.context) + serialization_settings.git_repo = _get_git_repo_url(str(detected_root)) + pkgs_and_modules = list_packages_and_modules(detected_root, list(package_or_module)) - # NB: The change here is that the loading of user code _cannot_ depend on fast register information (the computed - # version, upload native url, hash digest, etc.). - serialize_load_only(pkgs_and_modules, serialization_settings, str(detected_root)) + # NB: The change here is that the loading of user code _cannot_ depend on fast register information (the computed + # version, upload native url, hash digest, etc.). 
+ serialize_load_only(pkgs_and_modules, serialization_settings, str(detected_root)) - # Fast registration is handled after module loading - if copy_style != CopyFileDetection.NO_COPY: - md5_bytes, native_url = remote.fast_package( - detected_root, - deref_symlinks, - output, - options=fast_registration.FastPackageOptions([], copy_style=copy_style, show_files=show_files), + # Fast registration is handled after module loading + if copy_style != CopyFileDetection.NO_COPY: + md5_bytes, native_url = remote.fast_package( + detected_root, + deref_symlinks, + output, + options=fast_registration.FastPackageOptions([], copy_style=copy_style, show_files=show_files), + ) + # update serialization settings from fast register output + fast_serialization_settings = FastSerializationSettings( + enabled=True, + destination_dir=destination_dir, + distribution_location=native_url, + ) + serialization_settings.fast_serialization_settings = fast_serialization_settings + if not version: + images, pod_templates = [], [] + for entity in FlyteEntities.entities.copy(): + if isinstance(entity, PythonTask): + images.extend(FlyteRemote._get_image_names(entity)) + images.extend(FlyteRemote._get_pod_template_hash(entity)) + version = remote._version_from_hash( + md5_bytes, serialization_settings, service_account, raw_data_prefix, *images, *pod_templates + ) # noqa + serialization_settings.version = version + click.secho(f"Computed version is {version}", fg="yellow") + + registrable_entities = serialize_get_control_plane_entities( + serialization_settings, str(detected_root), options, is_registration=True ) - # update serialization settings from fast register output - fast_serialization_settings = FastSerializationSettings( - enabled=True, - destination_dir=destination_dir, - distribution_location=native_url, + + click.secho( + f"Serializing and registering {len(registrable_entities)} flyte entities", + fg="green", ) - serialization_settings.fast_serialization_settings = fast_serialization_settings - if not version: - images, pod_templates = [], [] - for entity in FlyteEntities.entities.copy(): - if isinstance(entity, PythonTask): - images.extend(FlyteRemote._get_image_names(entity)) - images.extend(FlyteRemote._get_pod_template_hash(entity)) - version = remote._version_from_hash( - md5_bytes, serialization_settings, service_account, raw_data_prefix, *images, *pod_templates - ) # noqa - serialization_settings.version = version - click.secho(f"Computed version is {version}", fg="yellow") - - registrable_entities = serialize_get_control_plane_entities( - serialization_settings, str(detected_root), options, is_registration=True - ) - click.secho( - f"Serializing and registering {len(registrable_entities)} flyte entities", - fg="green", - ) - FlyteContextManager.pop_context() - if len(registrable_entities) == 0: - click.secho("No Flyte entities were detected. Aborting!", fg="red") - return + FlyteContextManager.pop_context() + if len(registrable_entities) == 0: + click.secho("No Flyte entities were detected. 
Aborting!", fg="red") + return - def _raw_register(cp_entity: FlyteControlPlaneEntity): - is_lp = False - if isinstance(cp_entity, launch_plan.LaunchPlan): - og_id = cp_entity.id - is_lp = True - else: - og_id = cp_entity.template.id - try: - if not dry_run: - try: - i = remote.raw_register( - cp_entity, serialization_settings, version=version, create_default_launchplan=False - ) - console_url = remote.generate_console_url(i) - print_activation_message = False - if is_lp: - if activate_launchplans: - remote.activate_launchplan(i) - print_activation_message = True - if cp_entity.should_auto_activate: - print_activation_message = True - print_registration_status( - i, console_url=console_url, verbosity=verbosity, activation=print_activation_message - ) - - except Exception as e: - if not skip_errors: - raise e - print_registration_status(og_id, success=False) + def _raw_register(cp_entity: FlyteControlPlaneEntity): + is_lp = False + if isinstance(cp_entity, launch_plan.LaunchPlan): + og_id = cp_entity.id + is_lp = True else: - print_registration_status(og_id, dry_run=True) - except RegistrationSkipped: - print_registration_status(og_id, success=False) - - async def _register(entities: typing.List[task.TaskSpec]): - loop = asyncio.get_running_loop() - tasks = [] - for entity in entities: - tasks.append(loop.run_in_executor(None, functools.partial(_raw_register, entity))) - await asyncio.gather(*tasks) - return - - # concurrent register - cp_task_entities = list(filter(lambda x: isinstance(x, task.TaskSpec), registrable_entities)) - asyncio.run(_register(cp_task_entities)) - # serial register - cp_other_entities = list(filter(lambda x: not isinstance(x, task.TaskSpec), registrable_entities)) - for entity in cp_other_entities: - _raw_register(entity) - - click.secho(f"Successfully registered {len(registrable_entities)} entities", fg="green") + og_id = cp_entity.template.id + + result = { + "id": og_id.name, + "type": og_id.resource_type_name(), + "version": og_id.version, + "status": "skipped", # default status + } + + try: + if not dry_run: + try: + i = remote.raw_register( + cp_entity, serialization_settings, version=version, create_default_launchplan=False + ) + console_url = remote.generate_console_url(i) + print_activation_message = False + if is_lp: + if activate_launchplans: + remote.activate_launchplan(i) + print_activation_message = True + if cp_entity.should_auto_activate: + print_activation_message = True + if not quiet: + print_registration_status( + i, console_url=console_url, verbosity=verbosity, activation=print_activation_message + ) + result["status"] = "success" + + except Exception as e: + if not skip_errors: + raise e + if not quiet: + print_registration_status(og_id, success=False) + result["status"] = "failed" + else: + if not quiet: + print_registration_status(og_id, dry_run=True) + except RegistrationSkipped: + if not quiet: + print_registration_status(og_id, success=False) + result["status"] = "skipped" + + return result + + async def _register(entities: typing.List[task.TaskSpec]): + loop = asyncio.get_running_loop() + tasks = [] + for entity in entities: + tasks.append(loop.run_in_executor(None, functools.partial(_raw_register, entity))) + results = await asyncio.gather(*tasks) + return results + + # concurrent register + cp_task_entities = list(filter(lambda x: isinstance(x, task.TaskSpec), registrable_entities)) + task_results = asyncio.run(_register(cp_task_entities)) + # serial register + cp_other_entities = list(filter(lambda x: not isinstance(x, task.TaskSpec), 
registrable_entities)) + other_results = [] + for entity in cp_other_entities: + other_results.append(_raw_register(entity)) + + all_results = task_results + other_results + + click.secho(f"Successfully registered {len(registrable_entities)} entities", fg="green") + + if summary_format is not None: + supported_format = {"json", "yaml"} + if summary_format not in supported_format: + raise ValueError(f"Unsupported file format: {summary_format}") + + with temporary_secho(): + if summary_format == "json": + click.secho(json.dumps(all_results, indent=2)) + elif summary_format == "yaml": + click.secho(yaml.dump(all_results)) + finally: + # Restore original secho + click.secho = _original_secho + logger.setLevel(_original_log_level) diff --git a/plugins/community/__init__ 2.py b/plugins/community/__init__ 2.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_agent 2.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_agent 2.py new file mode 100644 index 0000000000..d254ec5960 --- /dev/null +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_agent 2.py @@ -0,0 +1,143 @@ +import re +from typing import Optional + +from flyteidl.core.execution_pb2 import TaskExecution +from typing_extensions import Annotated + +from flytekit import FlyteContextManager, kwtypes +from flytekit.core import context_manager +from flytekit.core.data_persistence import FileAccessProvider +from flytekit.core.type_engine import TypeEngine +from flytekit.extend.backend.base_agent import ( + AgentRegistry, + Resource, + SyncAgentBase, +) +from flytekit.models.literals import LiteralMap +from flytekit.models.task import TaskTemplate + +from .boto3_mixin import Boto3AgentMixin, CustomException + + +# https://github.com/flyteorg/flyte/issues/4505 +def convert_floats_with_no_fraction_to_ints(data): + if isinstance(data, dict): + for key, value in data.items(): + data[key] = convert_floats_with_no_fraction_to_ints(value) + elif isinstance(data, list): + for i, item in enumerate(data): + data[i] = convert_floats_with_no_fraction_to_ints(item) + elif isinstance(data, float) and data.is_integer(): + return int(data) + return data + + +class BotoAgent(SyncAgentBase): + """A general purpose boto3 agent that can be used to call any boto3 method.""" + + name = "Boto Agent" + + def __init__(self): + super().__init__(task_type_name="boto") + + async def do( + self, + task_template: TaskTemplate, + output_prefix: str, + inputs: Optional[LiteralMap] = None, + **kwargs, + ) -> Resource: + custom = task_template.custom + + service = custom.get("service") + raw_config = custom.get("config") + convert_floats_with_no_fraction_to_ints(raw_config) + config = raw_config + region = custom.get("region") + method = custom.get("method") + images = custom.get("images") + + boto3_object = Boto3AgentMixin(service=service, region=region) + + result = None + try: + result, idempotence_token = await boto3_object._call( + method=method, + config=config, + images=images, + inputs=inputs, + ) + except CustomException as e: + original_exception = e.original_exception + error_code = original_exception.response["Error"]["Code"] + error_message = original_exception.response["Error"]["Message"] + + if error_code == "ValidationException" and "Cannot create already existing" in error_message: + arn = re.search( + r"arn:aws:[a-zA-Z0-9\-]+:[a-zA-Z0-9\-]+:\d+:[a-zA-Z0-9\-\/]+", + error_message, + ).group(0) + if arn: + 
arn_result = None + if method == "create_model": + arn_result = {"ModelArn": arn} + elif method == "create_endpoint_config": + arn_result = {"EndpointConfigArn": arn} + + return Resource( + phase=TaskExecution.SUCCEEDED, + outputs={ + "result": arn_result if arn_result else {"result": f"Entity already exists {arn}."}, + "idempotence_token": e.idempotence_token, + }, + ) + else: + return Resource( + phase=TaskExecution.SUCCEEDED, + outputs={ + "result": {"result": "Entity already exists."}, + "idempotence_token": e.idempotence_token, + }, + ) + else: + # Re-raise the exception if it's not the specific error we're handling + raise e + except Exception as e: + raise e + + outputs = {"result": {"result": None}} + if result: + truncated_result = None + if method == "create_model": + truncated_result = {"ModelArn": result.get("ModelArn")} + elif method == "create_endpoint_config": + truncated_result = {"EndpointConfigArn": result.get("EndpointConfigArn")} + + ctx = FlyteContextManager.current_context() + builder = ctx.with_file_access( + FileAccessProvider( + local_sandbox_dir=ctx.file_access.local_sandbox_dir, + raw_output_prefix=output_prefix, + data_config=ctx.file_access.data_config, + ) + ) + with context_manager.FlyteContextManager.with_context(builder) as new_ctx: + outputs = { + "result": TypeEngine.to_literal( + new_ctx, + truncated_result if truncated_result else result, + Annotated[dict, kwtypes(allow_pickle=True)], + TypeEngine.to_literal_type(dict), + ), + "idempotence_token": TypeEngine.to_literal( + new_ctx, + idempotence_token, + str, + TypeEngine.to_literal_type(str), + ), + } + + return Resource(phase=TaskExecution.SUCCEEDED, outputs=outputs) + + +AgentRegistry.register(BotoAgent()) diff --git a/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent 2.py b/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent 2.py new file mode 100644 index 0000000000..ff34f7a580 --- /dev/null +++ b/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent 2.py @@ -0,0 +1,97 @@ +import datetime +from dataclasses import dataclass +from typing import Dict, Optional + +from flyteidl.core.execution_pb2 import TaskExecution, TaskLog +from google.cloud import bigquery + +from flytekit import FlyteContextManager, StructuredDataset, logger +from flytekit.core.type_engine import TypeEngine +from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta +from flytekit.extend.backend.utils import convert_to_flyte_phase +from flytekit.models.literals import LiteralMap +from flytekit.models.task import TaskTemplate + +pythonTypeToBigQueryType: Dict[type, str] = { + # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#data_type_sizes + list: "ARRAY", + bool: "BOOL", + bytes: "BYTES", + datetime.datetime: "DATETIME", + float: "FLOAT64", + int: "INT64", + str: "STRING", +} + + +@dataclass +class BigQueryMetadata(ResourceMeta): + job_id: str + project: str + location: str + + +class BigQueryAgent(AsyncAgentBase): + name = "Bigquery Agent" + + def __init__(self): + super().__init__(task_type_name="bigquery_query_job_task", metadata_type=BigQueryMetadata) + + def create( + self, + task_template: TaskTemplate, + inputs: Optional[LiteralMap] = None, + **kwargs, + ) -> BigQueryMetadata: + job_config = None + if inputs: + ctx = FlyteContextManager.current_context() + python_interface_inputs = { + name: TypeEngine.guess_python_type(lt.type) for name, lt in task_template.interface.inputs.items() + } + native_inputs = TypeEngine.literal_map_to_kwargs(ctx, 
inputs, python_interface_inputs) + logger.info(f"Create BigQuery job config with inputs: {native_inputs}") + job_config = bigquery.QueryJobConfig( + query_parameters=[ + bigquery.ScalarQueryParameter(name, pythonTypeToBigQueryType[python_interface_inputs[name]], val) + for name, val in native_inputs.items() + ] + ) + + custom = task_template.custom + project = custom["ProjectID"] + location = custom["Location"] + client = bigquery.Client(project=project, location=location) + query_job = client.query(task_template.sql.statement, job_config=job_config) + + return BigQueryMetadata(job_id=str(query_job.job_id), location=location, project=project) + + def get(self, resource_meta: BigQueryMetadata, **kwargs) -> Resource: + client = bigquery.Client() + log_link = TaskLog( + uri=f"https://console.cloud.google.com/bigquery?project={resource_meta.project}&j=bq:{resource_meta.location}:{resource_meta.job_id}&page=queryresults", + name="BigQuery Console", + ) + + job = client.get_job(resource_meta.job_id, resource_meta.project, resource_meta.location) + if job.errors: + logger.error("failed to run BigQuery job with error:", job.errors.__str__()) + return Resource(phase=TaskExecution.FAILED, message=job.errors.__str__(), log_links=[log_link]) + + cur_phase = convert_to_flyte_phase(str(job.state)) + res = None + + if cur_phase == TaskExecution.SUCCEEDED: + dst = job.destination + if dst: + output_location = f"bq://{dst.project}:{dst.dataset_id}.{dst.table_id}" + res = {"results": StructuredDataset(uri=output_location)} + + return Resource(phase=cur_phase, message=str(job.state), log_links=[log_link], outputs=res) + + def delete(self, resource_meta: BigQueryMetadata, **kwargs): + client = bigquery.Client() + client.cancel_job(resource_meta.job_id, resource_meta.project, resource_meta.location) + + +AgentRegistry.register(BigQueryAgent()) diff --git a/plugins/flytekit-dgxc-lepton/.dockerignore 2 b/plugins/flytekit-dgxc-lepton/.dockerignore 2 new file mode 100644 index 0000000000..a9f064d303 --- /dev/null +++ b/plugins/flytekit-dgxc-lepton/.dockerignore 2 @@ -0,0 +1,14 @@ +# Ignore deployment files from Docker build +deployment/ +examples/ +tests/ +*.md +.git +.gitignore +__pycache__/ +*.pyc +*.pyo +*.egg-info/ +build/ +dist/ +.pytest_cache/ diff --git a/plugins/flytekit-dgxc-lepton/Dockerfile 2.connector b/plugins/flytekit-dgxc-lepton/Dockerfile 2.connector new file mode 100644 index 0000000000..7a22f947f5 --- /dev/null +++ b/plugins/flytekit-dgxc-lepton/Dockerfile 2.connector @@ -0,0 +1,30 @@ +# Dockerfile for Lepton Agent +# This creates a standalone agent service for handling Lepton endpoint operations + +FROM ghcr.io/flyteorg/flytekit:py3.12-latest + +# Switch to root to handle file permissions +USER root + +# Install git for leptonai installation +RUN apt-get update && apt-get install -y git && apt-get clean && rm -rf /var/lib/apt/lists/* + +# Copy and install our plugin +COPY --chown=flytekit:flytekit . /home/flytekit/dgxc-lepton-plugin +WORKDIR /home/flytekit/dgxc-lepton-plugin + +# Clean any existing build artifacts +RUN rm -rf *.egg-info build dist + +# Install the plugin +RUN pip install . 
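+ +# (Illustrative) Build and run this connector image locally; the file name and tag are examples: +# docker build -f "Dockerfile 2.connector" -t lepton-connector . +# docker run -p 8000:8000 lepton-connector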
+ +# Switch back to flytekit user +USER flytekit + +# Expose agent service port +EXPOSE 8000 + +# Start the connector service +# The connector service will handle Lepton endpoint lifecycle operations via gRPC +CMD ["pyflyte", "serve", "connector", "--port", "8000", "--modules", "flytekitplugins.dgxc_lepton"] diff --git a/plugins/flytekit-dgxc-lepton/README 2.md b/plugins/flytekit-dgxc-lepton/README 2.md new file mode 100644 index 0000000000..b3fd08a5bb --- /dev/null +++ b/plugins/flytekit-dgxc-lepton/README 2.md @@ -0,0 +1,375 @@ +# Flytekit DGXC Lepton Plugin + +A professional Flytekit plugin that enables seamless deployment and management of AI inference endpoints using Lepton AI infrastructure within Flyte workflows. + +## Overview + +This plugin provides: +- **Unified Task API** for deployment and management of Lepton AI endpoints +- **Type-safe configuration** with consolidated dataclasses and IDE support +- **Multiple endpoint engines**: VLLM, SGLang, NIM, and custom containers +- **Unified configuration classes** for scaling, environment, and mounts + +## Installation + +```bash +pip install flytekitplugins-dgxc-lepton +``` + +## Quick Start + +```python +from flytekit import workflow +from flytekitplugins.dgxc_lepton import ( + lepton_endpoint_deployment_task, lepton_endpoint_deletion_task, LeptonEndpointConfig, + EndpointEngineConfig, EnvironmentConfig, ScalingConfig +) + +@workflow +def inference_workflow() -> str: + """Deploy Llama model using VLLM and return endpoint URL.""" + + # Complete configuration in one place + config = LeptonEndpointConfig( + endpoint_name="my-llama-endpoint", + resource_shape="gpu.1xh200", + node_group="your-node-group", + endpoint_config=EndpointEngineConfig.vllm( + checkpoint_path="meta-llama/Llama-3.1-8B-Instruct", + served_model_name="llama-3.1-8b-instruct", + ), + environment=EnvironmentConfig.create( + LOG_LEVEL="INFO", + secrets={"HF_TOKEN": "hf-secret"} + ), + scaling=ScalingConfig.traffic(min_replicas=1, max_replicas=2), + ) + + # Deploy endpoint and return URL + return lepton_endpoint_deployment_task(config=config) +``` + +## API Reference + +### Core Components + +#### `lepton_endpoint_deployment_task(config: LeptonEndpointConfig) -> str` +Main function for Lepton AI endpoint deployment. + +**Parameters:** +- `config`: Complete endpoint configuration +- `task_name`: Optional custom task name + +**Returns:** +- Endpoint URL for successful deployment + +#### `lepton_endpoint_deletion_task(endpoint_name: str, ...) -> str` +Function for Lepton AI endpoint deletion. + +**Parameters:** +- `endpoint_name`: Name of the endpoint to delete +- `task_name`: Optional custom task name + +**Returns:** +- Success message confirming deletion + +#### `LeptonEndpointConfig` +Unified configuration for all Lepton endpoint operations. + +**Required Fields:** +- `endpoint_name`: Name of the endpoint +- `resource_shape`: Hardware resource specification (e.g., "gpu.1xh200") +- `node_group`: Target node group for deployment +- `endpoint_config`: Engine-specific configuration + +**Optional Fields:** +- `scaling`: Auto-scaling configuration +- `environment`: Environment variables and secrets +- `mounts`: Storage mount configurations +- `api_token`/`api_token_secret`: Authentication +- `image_pull_secrets`: Container registry secrets +- `endpoint_readiness_timeout`: Deployment timeout + +### Endpoint Engine Configuration + +#### `EndpointEngineConfig` +Unified configuration for different inference engines. 
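+Each engine has a dedicated factory method, shown in the subsections below. Since the examples in this README only cover deployment, here is a minimal teardown sketch as well, using `lepton_endpoint_deletion_task` from the Core Components section (the endpoint name is illustrative):
+
+```python
+from flytekit import workflow
+from flytekitplugins.dgxc_lepton import lepton_endpoint_deletion_task
+
+@workflow
+def teardown_workflow() -> str:
+    # Returns a success message confirming deletion
+    return lepton_endpoint_deletion_task(endpoint_name="my-llama-endpoint")
+```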
+ + +##### VLLM Engine +```python +EndpointEngineConfig.vllm( + image="vllm/vllm-openai:latest", + checkpoint_path="meta-llama/Llama-3.1-8B-Instruct", + served_model_name="default-model", + tensor_parallel_size=1, + pipeline_parallel_size=1, + data_parallel_size=1, + extra_args="--max-model-len 4096", + port=8000 +) +``` + +##### SGLang Engine +```python +EndpointEngineConfig.sglang( + image="lmsysorg/sglang:latest", + checkpoint_path="meta-llama/Llama-3.1-8B-Instruct", + tensor_parallel_size=1, + data_parallel_size=1, + extra_args="--context-length 4096", + port=30000 +) +``` + +##### NVIDIA NIM +```python +EndpointEngineConfig.nim( + image="nvcr.io/nim/nvidia/llama-3_3-nemotron-super-49b-v1_5:latest", + port=8000 +) +``` + +##### Custom Container +```python +EndpointEngineConfig.custom( + image="python:3.11-slim", + command=["/bin/bash", "-c", "python3 -m http.server 8080"], + port=8080 +) +``` + +### Scaling Configuration + +#### `ScalingConfig` +Unified auto-scaling configuration with enforced single strategy. + + +##### Traffic-based Scaling +```python +ScalingConfig.traffic( + min_replicas=1, + max_replicas=5, + timeout=1800 # Scale down after 30 min of no traffic +) +``` + +##### GPU Utilization Scaling +```python +ScalingConfig.gpu( + target_utilization=80, # Target 80% GPU utilization + min_replicas=1, + max_replicas=10 +) +``` + +##### QPM (Queries Per Minute) Scaling +```python +ScalingConfig.qpm( + target_qpm=100.5, # Target queries per minute + min_replicas=2, + max_replicas=8 +) +``` + +### Environment Configuration + +#### `EnvironmentConfig` +Unified configuration for environment variables and secrets. + +**Factory Methods:** + +##### Environment Variables Only +```python +EnvironmentConfig.from_env( + LOG_LEVEL="DEBUG", + MODEL_PATH="/models", + CUDA_VISIBLE_DEVICES="0,1" +) +``` + +##### Secrets Only +```python +EnvironmentConfig.from_secrets( + HF_TOKEN="hf-secret", + NGC_API_KEY="ngc-secret" +) +``` + +##### Mixed Configuration +```python +EnvironmentConfig.create( + LOG_LEVEL="INFO", + MODEL_PATH="/models", + secrets={ + "HF_TOKEN": "hf-secret", + "NGC_API_KEY": "ngc-secret" + } +) +``` + +### Mount Configuration + +#### `MountReader` +Simplified NFS mount configuration. 
+ +```python +MountReader.node_nfs( + ("/shared-storage/models", "/opt/models"), + ("/shared-storage/data", "/opt/data"), + ("/shared-storage/logs", "/opt/logs", False), # Disabled mount + storage_name="production-nfs" # Custom storage name +) +``` + +## Complete Examples + +### VLLM Deployment with Auto-scaling + +```python +from flytekit import workflow +from flytekitplugins.dgxc_lepton import ( + lepton_endpoint_deployment_task, LeptonEndpointConfig, + EndpointEngineConfig, EnvironmentConfig, ScalingConfig, MountReader +) + +@workflow +def deploy_vllm_with_scaling() -> str: + """Deploy VLLM with traffic-based auto-scaling.""" + + config = LeptonEndpointConfig( + endpoint_name="vllm-llama-3.1-8b", + resource_shape="gpu.1xh200", + node_group="inference-nodes", + endpoint_config=EndpointEngineConfig.vllm( + checkpoint_path="meta-llama/Llama-3.1-8B-Instruct", + served_model_name="llama-3.1-8b-instruct", + tensor_parallel_size=1, + extra_args="--max-model-len 8192 --enable-chunked-prefill" + ), + environment=EnvironmentConfig.create( + LOG_LEVEL="INFO", + CUDA_VISIBLE_DEVICES="0", + secrets={"HF_TOKEN": "hf-secret"} + ), + scaling=ScalingConfig.traffic( + min_replicas=1, + max_replicas=3, + timeout=1800 + ), + mounts=MountReader.node_nfs( + ("/shared-storage/models", "/opt/models"), + ("/shared-storage/cache", "/root/.cache") + ), + api_token_secret="lepton-api-token", + image_pull_secrets=["hf-secret"], + endpoint_readiness_timeout=600 + ) + + return lepton_endpoint_deployment_task(config=config) +``` + +### NIM Deployment with QPM Scaling + +```python +@workflow +def deploy_nim_with_qpm_scaling() -> str: + """Deploy NVIDIA NIM with QPM-based scaling.""" + + config = LeptonEndpointConfig( + endpoint_name="nemotron-super-reasoning", + resource_shape="gpu.1xh200", + node_group="nim-nodes", + endpoint_config=EndpointEngineConfig.nim( + image="nvcr.io/nim/nvidia/llama-3_3-nemotron-super-49b-v1_5:latest" + ), + environment=EnvironmentConfig.create( + OMPI_ALLOW_RUN_AS_ROOT="1", + secrets={"NGC_API_KEY": "ngc-secret"} + ), + scaling=ScalingConfig.qpm( + target_qpm=2.5, + min_replicas=1, + max_replicas=3 + ), + image_pull_secrets=["ngc-secret"], + api_token="UNIQUE_ENDPOINT_TOKEN" + ) + + return lepton_endpoint_deployment_task(config=config) +``` + +### Custom Container Deployment + +```python +@workflow +def deploy_custom_service() -> str: + """Deploy custom inference service.""" + + config = LeptonEndpointConfig( + endpoint_name="custom-inference-api", + resource_shape="cpu.large", + node_group="cpu-nodes", + endpoint_config=EndpointEngineConfig.custom( + image="my-registry/inference-api:v1.0", + command=["python", "app.py"], + port=8080 + ), + environment=EnvironmentConfig.from_env( + LOG_LEVEL="DEBUG", + API_VERSION="v1", + WORKERS="4" + ), + scaling=ScalingConfig.gpu( + target_utilization=70, + min_replicas=2, + max_replicas=6 + ) + ) + + return lepton_endpoint_deployment_task(config=config) +``` + +## Configuration Requirements + +Replace these placeholders with your actual values: +- `<node-group>`: Your Kubernetes node group for GPU workloads +- `<ngc-registry-secret>`: Your NGC registry pull secret name +- `/shared-storage/model-cache/*`: Your shared storage paths for model caching +- `NGC_API_KEY`: Your NGC API key secret name +- `HUGGING_FACE_HUB_TOKEN_read`: Your HuggingFace token secret name + +## Monitoring & Debugging + +```bash +# Monitor connector logs +kubectl logs -n flyte deployment/lepton-connector --follow + +# Check Lepton console (URLs auto-generated in Flyte execution view) + +# List recent executions
+pyflyte get executions -p flytesnacks -d development --limit 5 +``` + +## Development + +### Running Tests + +```bash +pytest tests/test_lepton.py -v +``` + +### Plugin Registration + +The plugin automatically registers with Flytekit's dynamic plugin loading system: + +```python +# Automatic registration enables this usage pattern +task = LeptonEndpointDeploymentTask(config=config) +``` + +## Support + +For issues, questions, or contributions, please refer to the Flytekit documentation and Lepton AI platform documentation. + +## License + +This plugin follows the same license as Flytekit. diff --git a/plugins/flytekit-dgxc-lepton/setup 2.py b/plugins/flytekit-dgxc-lepton/setup 2.py new file mode 100644 index 0000000000..8a507b1ca3 --- /dev/null +++ b/plugins/flytekit-dgxc-lepton/setup 2.py @@ -0,0 +1,35 @@ +from setuptools import setup + +PLUGIN_NAME = "dgxc-lepton" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = ["flytekit>=1.9.1,<2.0.0", "leptonai"] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + version=__version__, + author="Anshul Jindal", + author_email="ansjindal@nvidia.com", + description="DGXC Lepton Flytekit plugin for inference endpoints", + long_description="Flytekit DGXC Lepton Plugin - AI inference endpoints using Lepton AI infrastructure", + long_description_content_type="text/markdown", + packages=["flytekitplugins.dgxc_lepton"], + install_requires=plugin_requires, + python_requires=">=3.9", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], + entry_points={"flytekit.plugins": ["dgxc_lepton=flytekitplugins.dgxc_lepton"]}, +) diff --git a/plugins/flytekit-k8sdataservice/flytekitplugins/k8sdataservice/agent 2.py b/plugins/flytekit-k8sdataservice/flytekitplugins/k8sdataservice/agent 2.py new file mode 100644 index 0000000000..199db535bd --- /dev/null +++ b/plugins/flytekit-k8sdataservice/flytekitplugins/k8sdataservice/agent 2.py @@ -0,0 +1,98 @@ +from dataclasses import dataclass +from typing import Optional + +from flyteidl.core.execution_pb2 import TaskExecution +from flytekitplugins.k8sdataservice.k8s.manager import K8sManager +from flytekitplugins.k8sdataservice.task import DataServiceConfig + +from flytekit import logger +from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta +from flytekit.models.literals import LiteralMap +from flytekit.models.task import TaskTemplate + + +@dataclass +class DataServiceMetadata(ResourceMeta): + dataservice_config: DataServiceConfig + name: str + + +class DataServiceAgent(AsyncAgentBase): + name = "K8s DataService Async Agent" + + def __init__(self): + self.k8s_manager = K8sManager() + super().__init__(task_type_name="dataservicetask", metadata_type=DataServiceMetadata) + self.config = None + + def create( + self, task_template: TaskTemplate, output_prefix: str, inputs: Optional[LiteralMap] = None, **kwargs + ) -> DataServiceMetadata: + graph_engine_config = task_template.custom + self.k8s_manager.set_configs(graph_engine_config) + logger.info(f"Loaded agent config file {self.config}") + existing_release_name =
graph_engine_config.get("ExistingReleaseName", None) + logger.info(f"The existing data service release name is {existing_release_name}") + + name = "" + if existing_release_name is None or existing_release_name == "": + logger.info("Creating K8s data service resources...") + name = self.k8s_manager.create_data_service() + logger.info(f'Data service {name} with image {graph_engine_config["Image"]} completed') + else: + name = existing_release_name + logger.info(f"User configs to use the existing data service release name: {name}.") + + dataservice_config = DataServiceConfig( + Name=graph_engine_config.get("Name", None), + Image=graph_engine_config["Image"], + Command=graph_engine_config["Command"], + Cluster=graph_engine_config["Cluster"], + ExistingReleaseName=graph_engine_config.get("ExistingReleaseName", None), + ) + metadata = DataServiceMetadata( + dataservice_config=dataservice_config, + name=name, + ) + logger.info(f"Created DataService metadata {metadata}") + return metadata + + def get(self, resource_meta: DataServiceMetadata) -> Resource: + logger.info("K8s Data Service get is called") + data = resource_meta.dataservice_config + data_dict = data.__dict__ if isinstance(data, DataServiceConfig) else data + logger.info(f"The data_dict is {data_dict}") + self.k8s_manager.set_configs(data_dict) + name = data.Name + logger.info(f"Get the stateful set name {name}") + + k8s_status = self.k8s_manager.check_stateful_set_status(name) + flyte_state = None + if k8s_status in ["failed", "timeout", "timedout", "canceled", "skipped", "internal_error"]: + flyte_state = TaskExecution.FAILED + elif k8s_status in ["done", "succeeded", "success"]: + flyte_state = TaskExecution.SUCCEEDED + elif k8s_status in ["running", "terminating", "pending"]: + flyte_state = TaskExecution.RUNNING + else: + logger.error(f"Unrecognized state: {k8s_status}") + outputs = { + "data_service_name": name, + } + # TODO: Add logs for StatefulSet. 
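+ # If k8s_status matched none of the branches above, flyte_state remains None + # and the Resource below is returned without a phase.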
+ return Resource(phase=flyte_state, outputs=outputs) + + def delete(self, resource_meta: DataServiceMetadata): + logger.info("DataService delete is called") + data = resource_meta.dataservice_config + + data_dict = data.__dict__ if isinstance(data, DataServiceConfig) else data + self.k8s_manager.set_configs(data_dict) + + name = resource_meta.name + logger.info(f"To delete the DataService (e.g., StatefulSet and Service) with name {name}") + self.k8s_manager.delete_stateful_set(name) + self.k8s_manager.delete_service(name) + + +AgentRegistry.register(DataServiceAgent()) diff --git a/plugins/flytekit-k8sdataservice/tests/k8sdataservice/test_agent 2.py b/plugins/flytekit-k8sdataservice/tests/k8sdataservice/test_agent 2.py new file mode 100644 index 0000000000..0db9877c26 --- /dev/null +++ b/plugins/flytekit-k8sdataservice/tests/k8sdataservice/test_agent 2.py @@ -0,0 +1,319 @@ +import json +from dataclasses import asdict +from datetime import timedelta +from unittest.mock import patch, MagicMock +import grpc +from google.protobuf import json_format +from flytekitplugins.k8sdataservice.task import DataServiceConfig +from flytekitplugins.k8sdataservice.agent import DataServiceMetadata +from google.protobuf.struct_pb2 import Struct +from flyteidl.core.execution_pb2 import TaskExecution +import flytekit.models.interface as interface_models +from flytekit.extend.backend.base_agent import AgentRegistry +from flytekit.interfaces.cli_identifiers import Identifier +from flytekit.models import literals, task, types +from flytekit.models.core.identifier import ResourceType +from flytekit.models.task import TaskTemplate + + +cmd = ["command", "args"] + + +def create_test_task_metadata() -> task.TaskMetadata: + return task.TaskMetadata( + discoverable= True, + runtime=task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), + timeout=timedelta(days=1), + retries=literals.RetryStrategy(3), + interruptible=True, + discovery_version="0.1.1b0", + deprecated_error_message="This is deprecated!", + cache_serializable=True, + pod_template_name="A", + cache_ignore_input_vars=(), + ) + + +def create_test_setup(original_name: str = "gnn-1234", existing_release_name: str = "gnn-2345"): + task_id = Identifier( + resource_type=ResourceType.TASK, project="project", domain="domain", name="name", version="version" + ) + task_metadata = create_test_task_metadata() + s = Struct() + if existing_release_name != "": + s.update({ + "Name": original_name, + "Image": "image", + "Command": cmd, + "Cluster": "ei-dev2", + "ExistingReleaseName": existing_release_name, + }) + else: + s.update({ + "Name": original_name, + "Image": "image", + "Command": cmd, + "Cluster": "ei-dev2", + }) + task_config = json_format.MessageToDict(s) + return task_id, task_metadata, task_config + + +@patch("flytekitplugins.k8sdataservice.agent.K8sManager.create_data_service", return_value="gnn-1234") +@patch("flytekitplugins.k8sdataservice.agent.K8sManager.check_stateful_set_status", return_value="succeeded") +@patch("flytekitplugins.k8sdataservice.agent.K8sManager.delete_stateful_set") +@patch("flytekitplugins.k8sdataservice.agent.K8sManager.delete_service") +def test_gnn_agent(mock_delete_service, mock_delete_stateful_set, mock_check_status, mock_create_data_service): + agent = AgentRegistry.get_agent("dataservicetask") + task_id, task_metadata, task_config = create_test_setup(existing_release_name="") + int_type = types.LiteralType(types.SimpleType.INTEGER) + interfaces = interface_models.TypedInterface( + { + "a": 
interface_models.Variable(int_type, "description1"), + "b": interface_models.Variable(int_type, "description2"), + }, + {}, + ) + task_inputs = literals.LiteralMap( + { + "a": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1))), + "b": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1))), + }, + ) + + dummy_template = TaskTemplate( + id=task_id, + custom=task_config, + metadata=task_metadata, + interface=interfaces, + type="dataservicetask", + ) + + expected_resource_metadata = DataServiceMetadata( + dataservice_config=DataServiceConfig(Name="gnn-1234", Image="image", Command=cmd, Cluster="ei-dev2"), + name="gnn-1234") + # Test create method + res_resource_metadata = agent.create(dummy_template, task_inputs) + assert res_resource_metadata == expected_resource_metadata + mock_create_data_service.assert_called_once() + + # Test get method + res = agent.get(res_resource_metadata) + assert res.phase == TaskExecution.SUCCEEDED + assert res.outputs.get("data_service_name") == 'gnn-1234' + mock_check_status.assert_called_once_with("gnn-1234") + + # Test delete method + agent.delete(res_resource_metadata) + mock_delete_stateful_set.assert_called_once_with("gnn-1234") + mock_delete_service.assert_called_once_with("gnn-1234") + + +@patch("flytekitplugins.k8sdataservice.agent.K8sManager.create_data_service", return_value="gnn-1234") +@patch("flytekitplugins.k8sdataservice.agent.K8sManager.check_stateful_set_status", return_value="succeeded") +@patch("flytekitplugins.k8sdataservice.agent.K8sManager.delete_stateful_set") +@patch("flytekitplugins.k8sdataservice.agent.K8sManager.delete_service") +def test_gnn_agent_reuse_data_service(mock_delete_service, mock_delete_stateful_set, mock_check_status, mock_create_data_service): + agent = AgentRegistry.get_agent("dataservicetask") + task_id, task_metadata, task_config = create_test_setup(original_name="gnn-2345", existing_release_name="gnn-2345") + + int_type = types.LiteralType(types.SimpleType.INTEGER) + interfaces = interface_models.TypedInterface( + { + "a": interface_models.Variable(int_type, "description1"), + "b": interface_models.Variable(int_type, "description2"), + }, + {}, + ) + task_inputs = literals.LiteralMap( + { + "a": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1))), + "b": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1))), + }, + ) + + dummy_template = TaskTemplate( + id=task_id, + custom=task_config, + metadata=task_metadata, + interface=interfaces, + type="dataservicetask", + ) + + expected_resource_metadata = DataServiceMetadata( + dataservice_config=DataServiceConfig( + Name="gnn-2345", Image="image", Command=cmd, Cluster="ei-dev2", ExistingReleaseName="gnn-2345"), + name="gnn-2345") + + # Test create method; create_data_service should not have been called + res_resource_metadata = agent.create(dummy_template, task_inputs) + assert res_resource_metadata == expected_resource_metadata + mock_create_data_service.assert_not_called() + + # Test get method + res = agent.get(res_resource_metadata) + assert res.phase == TaskExecution.SUCCEEDED + mock_check_status.assert_called_once_with("gnn-2345") + + # Test delete method + agent.delete(res_resource_metadata) + mock_delete_stateful_set.assert_called_once_with("gnn-2345") + mock_delete_service.assert_called_once_with("gnn-2345") + + +@patch("flytekitplugins.k8sdataservice.agent.K8sManager.create_data_service", return_value="gnn-1234")
+@patch("flytekitplugins.k8sdataservice.agent.K8sManager.check_stateful_set_status", return_value="running") +@patch("flytekitplugins.k8sdataservice.agent.K8sManager.delete_stateful_set") +@patch("flytekitplugins.k8sdataservice.agent.K8sManager.delete_service") +def test_gnn_agent_status(mock_delete_service, mock_delete_stateful_set, mock_check_status, mock_create_data_service): + agent = AgentRegistry.get_agent("dataservicetask") + task_id, task_metadata, task_config = create_test_setup(original_name="gnn-2345", existing_release_name="gnn-2345") + + int_type = types.LiteralType(types.SimpleType.INTEGER) + interfaces = interface_models.TypedInterface( + { + "a": interface_models.Variable(int_type, "description1"), + "b": interface_models.Variable(int_type, "description2"), + }, + {}, + ) + task_inputs = literals.LiteralMap( + { + "a": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1))), + "b": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1))), + }, + ) + + dummy_template = TaskTemplate( + id=task_id, + custom=task_config, + metadata=task_metadata, + interface=interfaces, + type="dataservicetask", + ) + + expected_resource_metadata = DataServiceMetadata( + dataservice_config=DataServiceConfig( + Name="gnn-2345", Image="image", Command=cmd, Cluster="ei-dev2", ExistingReleaseName="gnn-2345"), + name="gnn-2345") + # Test create method, and create_data_service should have not been called + res_resource_metadata = agent.create(dummy_template, task_inputs) + assert res_resource_metadata == expected_resource_metadata + mock_create_data_service.assert_not_called() + + # Test get method + res = agent.get(res_resource_metadata) + assert res.phase == TaskExecution.RUNNING + mock_check_status.assert_called_once_with("gnn-2345") + + # # Test delete methods are not called + agent.delete(res_resource_metadata) + mock_delete_stateful_set.assert_called_once_with("gnn-2345") + mock_delete_service.assert_called_once_with("gnn-2345") + + +@patch("flytekitplugins.k8sdataservice.agent.K8sManager.create_data_service", return_value="gnn-1234") +@patch("flytekitplugins.k8sdataservice.agent.K8sManager.check_stateful_set_status", return_value="succeeded") +@patch("flytekitplugins.k8sdataservice.agent.K8sManager.delete_stateful_set") +@patch("flytekitplugins.k8sdataservice.agent.K8sManager.delete_service") +def test_gnn_agent_no_configmap(mock_delete_service, mock_delete_stateful_set, mock_check_status, mock_create_data_service): + agent = AgentRegistry.get_agent("dataservicetask") + task_id, task_metadata, task_config = create_test_setup(original_name="gnn-2345", existing_release_name="gnn-2345") + + int_type = types.LiteralType(types.SimpleType.INTEGER) + interfaces = interface_models.TypedInterface( + { + "a": interface_models.Variable(int_type, "description1"), + "b": interface_models.Variable(int_type, "description2"), + }, + {}, + ) + task_inputs = literals.LiteralMap( + { + "a": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1))), + "b": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1))), + }, + ) + + dummy_template = TaskTemplate( + id=task_id, + custom=task_config, + metadata=task_metadata, + interface=interfaces, + type="dataservicetask", + ) + + expected_resource_metadata = DataServiceMetadata( + dataservice_config=DataServiceConfig( + Name="gnn-2345", Image="image", Command=cmd, Cluster="ei-dev2", ExistingReleaseName="gnn-2345"), + name="gnn-2345") + + # Test create method, and 
create_data_service should not have been called + res_resource_metadata = agent.create(dummy_template, task_inputs) + assert res_resource_metadata == expected_resource_metadata + mock_create_data_service.assert_not_called() + + # Test get method + res = agent.get(res_resource_metadata) + assert res.phase == TaskExecution.SUCCEEDED + mock_check_status.assert_called_once_with("gnn-2345") + + # Test delete method + agent.delete(res_resource_metadata) + mock_delete_stateful_set.assert_called_once_with("gnn-2345") + mock_delete_service.assert_called_once_with("gnn-2345") + + +@patch("flytekitplugins.k8sdataservice.agent.K8sManager.create_data_service", return_value="gnn-1234") +@patch("flytekitplugins.k8sdataservice.agent.K8sManager.check_stateful_set_status", return_value="pending") +@patch("flytekitplugins.k8sdataservice.agent.K8sManager.delete_stateful_set") +@patch("flytekitplugins.k8sdataservice.agent.K8sManager.delete_service") +def test_gnn_agent_status_failed(mock_delete_service, mock_delete_stateful_set, mock_check_status, mock_create_data_service): + agent = AgentRegistry.get_agent("dataservicetask") + task_id, task_metadata, task_config = create_test_setup(original_name="gnn-2345", existing_release_name="gnn-2345") + + int_type = types.LiteralType(types.SimpleType.INTEGER) + interfaces = interface_models.TypedInterface( + { + "a": interface_models.Variable(int_type, "description1"), + "b": interface_models.Variable(int_type, "description2"), + }, + {}, + ) + task_inputs = literals.LiteralMap( + { + "a": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1))), + "b": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1))), + }, + ) + + dummy_template = TaskTemplate( + id=task_id, + custom=task_config, + metadata=task_metadata, + interface=interfaces, + type="dataservicetask", + ) + + expected_resource_metadata = DataServiceMetadata( + dataservice_config=DataServiceConfig( + Name="gnn-2345", Image="image", Command=cmd, Cluster="ei-dev2", ExistingReleaseName="gnn-2345"), + name="gnn-2345") + + # Test create method; create_data_service should not have been called + res_resource_metadata = agent.create(dummy_template, task_inputs) + assert res_resource_metadata == expected_resource_metadata + mock_create_data_service.assert_not_called() + + # Test get method + res = agent.get(res_resource_metadata) + assert res.phase == TaskExecution.RUNNING + mock_check_status.assert_called_once_with("gnn-2345") + + mock_check_status.return_value = "failed" + res = agent.get(res_resource_metadata) + assert res.phase == TaskExecution.FAILED + + # Test delete method + agent.delete(res_resource_metadata) + mock_delete_stateful_set.assert_called_once_with("gnn-2345") + mock_delete_service.assert_called_once_with("gnn-2345") diff --git a/plugins/flytekit-neptune/dev-requirements 2.in b/plugins/flytekit-neptune/dev-requirements 2.in new file mode 100644 index 0000000000..ac57c9d501 --- /dev/null +++ b/plugins/flytekit-neptune/dev-requirements 2.in @@ -0,0 +1,2 @@ +neptune-scale>=0.13.0 +neptune>=1.10.4 diff --git a/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent 2.py b/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent 2.py new file mode 100644 index 0000000000..e4f24baa5a --- /dev/null +++ b/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent 2.py @@ -0,0 +1,53 @@ +import asyncio +import logging +from typing import Optional + +from flyteidl.core.execution_pb2 import TaskExecution + +from flytekit import FlyteContextManager,
lazy_module +from flytekit.core.type_engine import TypeEngine +from flytekit.extend.backend.base_agent import AgentRegistry, Resource, SyncAgentBase +from flytekit.extend.backend.utils import get_agent_secret +from flytekit.models.literals import LiteralMap +from flytekit.models.task import TaskTemplate + +openai = lazy_module("openai") + +TIMEOUT_SECONDS = 10 +OPENAI_API_KEY = "FLYTE_OPENAI_API_KEY" + + +class ChatGPTAgent(SyncAgentBase): + name = "ChatGPT Agent" + + def __init__(self): + super().__init__(task_type_name="chatgpt") + + async def do( + self, + task_template: TaskTemplate, + inputs: Optional[LiteralMap] = None, + **kwargs, + ) -> Resource: + ctx = FlyteContextManager.current_context() + input_python_value = TypeEngine.literal_map_to_kwargs(ctx, inputs, {"message": str}) + message = input_python_value["message"] + + custom = task_template.custom + custom["chatgpt_config"]["messages"] = [{"role": "user", "content": message}] + client = openai.AsyncOpenAI( + organization=custom["openai_organization"], + api_key=get_agent_secret(secret_key=OPENAI_API_KEY), + ) + + logger = logging.getLogger("httpx") + logger.setLevel(logging.WARNING) + + completion = await asyncio.wait_for(client.chat.completions.create(**custom["chatgpt_config"]), TIMEOUT_SECONDS) + message = completion.choices[0].message.content + outputs = {"o0": message} + + return Resource(phase=TaskExecution.SUCCEEDED, outputs=outputs) + + +AgentRegistry.register(ChatGPTAgent()) diff --git a/plugins/flytekit-perian/flytekitplugins/perian_job/agent 2.py b/plugins/flytekit-perian/flytekitplugins/perian_job/agent 2.py new file mode 100644 index 0000000000..5c7927deca --- /dev/null +++ b/plugins/flytekit-perian/flytekitplugins/perian_job/agent 2.py @@ -0,0 +1,261 @@ +import base64 +import shlex +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +from flyteidl.core.execution_pb2 import TaskExecution +from perian import ( + AcceleratorQueryInput, + ApiClient, + Configuration, + CpuQueryInput, + CreateJobRequest, + DockerRegistryCredentials, + DockerRunParameters, + InstanceTypeQueryInput, + JobApi, + JobStatus, + MemoryQueryInput, + Name, + OSStorageConfig, + ProviderQueryInput, + RegionQueryInput, + Size, +) + +from flytekit import current_context +from flytekit.core.type_engine import TypeEngine +from flytekit.exceptions.base import FlyteException +from flytekit.exceptions.user import FlyteUserException +from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta +from flytekit.loggers import logger +from flytekit.models.literals import LiteralMap +from flytekit.models.task import TaskTemplate + +PERIAN_API_URL = "https://api.perian.cloud" + + +@dataclass +class PerianMetadata(ResourceMeta): + """Metadata for PERIAN jobs""" + + job_id: str + + +class PerianAgent(AsyncAgentBase): + """Flyte Agent for executing tasks on PERIAN Job Platform""" + + name = "Perian Agent" + + def __init__(self): + logger.info("Initializing Perian agent") + super().__init__(task_type_name="perian_task", metadata_type=PerianMetadata) + + async def create( + self, + task_template: TaskTemplate, + inputs: Optional[LiteralMap], + output_prefix: Optional[str], + **kwargs, + ) -> PerianMetadata: + logger.info("Creating new Perian job") + ctx = current_context() + literal_types = task_template.interface.inputs + input_kwargs = ( + TypeEngine.literal_map_to_kwargs(ctx, inputs, literal_types=literal_types) if inputs.literals else None + ) + config = Configuration(host=PERIAN_API_URL) + 
job_request = self._build_create_job_request(task_template, input_kwargs) + with ApiClient(config) as api_client: + api_instance = JobApi(api_client) + response = api_instance.create_job( + create_job_request=job_request, + _headers=self._build_headers(), + ) + if response.status_code != 200: + raise FlyteException(f"Failed to create Perian job: {response.text}") + + return PerianMetadata(job_id=response.id) + + def get(self, resource_meta: PerianMetadata, **kwargs) -> Resource: + job_id = resource_meta.job_id + logger.info("Getting Perian job status: %s", job_id) + config = Configuration(host=PERIAN_API_URL) + with ApiClient(config) as api_client: + api_instance = JobApi(api_client) + response = api_instance.get_job_by_id( + job_id=str(job_id), + _headers=self._build_headers(), + ) + if response.status_code != 200: + raise FlyteException(f"Failed to get Perian job status: {response.text}") + if not response.jobs: + raise FlyteException(f"Perian job not found: {job_id}") + job = response.jobs[0] + + return Resource( + phase=self._perian_job_status_to_flyte_phase(job.status), + message=job.logs, + ) + + def delete(self, resource_meta: PerianMetadata, **kwargs): + job_id = resource_meta.job_id + logger.info("Cancelling Perian job: %s", job_id) + config = Configuration(host=PERIAN_API_URL) + with ApiClient(config) as api_client: + api_instance = JobApi(api_client) + response = api_instance.cancel_job( + job_id=str(job_id), + _headers=self._build_headers(), + ) + if response.status_code != 200: + raise FlyteException(f"Failed to cancel Perian job: {response.text}") + + def _build_create_job_request( + self, task_template: TaskTemplate, inputs: Optional[Dict[str, Any]] + ) -> CreateJobRequest: + params = task_template.custom + secrets = current_context().secrets + + # Build instance type requirements + reqs = InstanceTypeQueryInput() + if params.get("cores"): + reqs.cpu = CpuQueryInput(cores=int(params["cores"])) + if params.get("memory"): + reqs.ram = MemoryQueryInput(size=Size(params["memory"])) + if any([params.get("accelerators"), params.get("accelerator_type")]): + reqs.accelerator = AcceleratorQueryInput() + if params.get("accelerators"): + reqs.accelerator.no = int(params["accelerators"]) + if params.get("accelerator_type"): + reqs.accelerator.name = Name(params["accelerator_type"]) + if params.get("country_code"): + reqs.region = RegionQueryInput(location=params["country_code"]) + if params.get("provider"): + reqs.provider = ProviderQueryInput(name_short=params["provider"]) + + docker_run = self._read_storage_credentials() + + docker_registry = None + try: + dr_url = secrets.get(key="docker_registry_url") + dr_username = secrets.get(key="docker_registry_username") + dr_password = secrets.get(key="docker_registry_password") + if any([dr_url, dr_username, dr_password]): + docker_registry = DockerRegistryCredentials( + url=dr_url, + username=dr_username, + password=dr_password, + ) + except ValueError: + pass + + container = task_template.container + if container: + image = container.image + else: + image = params["image"] + if ":" in image: + docker_run.image_name, docker_run.image_tag = image.rsplit(":", 1) + else: + docker_run.image_name = image + + if container: + command = container.args + else: + command = self._render_command_template(params["command"], inputs) + if command: + docker_run.command = shlex.join(command) + + if params.get("environment"): + if docker_run.env_variables: + docker_run.env_variables.update(params["environment"]) + else: + docker_run.env_variables = 
params["environment"] + + storage_config = None + if params.get("os_storage_size"): + storage_config = OSStorageConfig(size=int(params["os_storage_size"])) + + return CreateJobRequest( + auto_failover_instance_type=True, + requirements=reqs, + docker_run_parameters=docker_run, + docker_registry_credentials=docker_registry, + os_storage_config=storage_config, + ) + + def _render_command_template(self, command: List[str], inputs: Optional[Dict[str, Any]]) -> List[str]: + if not inputs: + return command + rendered_command = [] + for c in command: + for key, val in inputs.items(): + c = c.replace("{{.inputs." + key + "}}", str(val)) + rendered_command.append(c) + return rendered_command + + def _read_storage_credentials(self) -> DockerRunParameters: + secrets = current_context().secrets + docker_run = DockerRunParameters() + # AWS + try: + aws_access_key_id = secrets.get(key="aws_access_key_id") + aws_secret_access_key = secrets.get(key="aws_secret_access_key") + docker_run.secrets = { + "AWS_ACCESS_KEY_ID": aws_access_key_id, + "AWS_SECRET_ACCESS_KEY": aws_secret_access_key, + } + return docker_run + except ValueError: + pass + # GCP + try: + creds_file = "/data/gcp-credentials.json" # to be mounted in the container + google_application_credentials = secrets.get(key="google_application_credentials") + docker_run.secrets = { + "GOOGLE_APPLICATION_CREDENTIALS": creds_file, + } + docker_run.container_files = [ + { + "path": creds_file, + "base64_content": base64.b64encode(google_application_credentials.encode()).decode(), + } + ] + return docker_run + except ValueError: + pass + + raise FlyteUserException( + "To access the Flyte storage bucket, `aws_access_key_id` and `aws_secret_access_key` for AWS " + "or `google_application_credentials` for GCP must be provided in the secrets" + ) + + def _build_headers(self) -> dict: + secrets = current_context().secrets + org = secrets.get(key="perian_organization") + token = secrets.get(key="perian_token") + if not org or not token: + raise FlyteUserException("perian_organization and perian_token must be provided in the secrets") + return { + "X-PERIAN-AUTH-ORG": org, + "Authorization": "Bearer " + token, + } + + def _perian_job_status_to_flyte_phase(self, status: JobStatus) -> TaskExecution.Phase: + status_map = { + JobStatus.QUEUED: TaskExecution.QUEUED, + JobStatus.INITIALIZING: TaskExecution.INITIALIZING, + JobStatus.RUNNING: TaskExecution.RUNNING, + JobStatus.DONE: TaskExecution.SUCCEEDED, + JobStatus.SERVERERROR: TaskExecution.FAILED, + JobStatus.USERERROR: TaskExecution.FAILED, + JobStatus.CANCELLED: TaskExecution.ABORTED, + } + if status == JobStatus.UNDEFINED: + raise FlyteException("Undefined Perian job status") + return status_map[status] + + +# To register the Perian agent +AgentRegistry.register(PerianAgent()) diff --git a/tests/flytekit/integration/remote/workflows/basic/dataclass_with_optional_wf 2.py b/tests/flytekit/integration/remote/workflows/basic/dataclass_with_optional_wf 2.py new file mode 100644 index 0000000000..2448f2a237 --- /dev/null +++ b/tests/flytekit/integration/remote/workflows/basic/dataclass_with_optional_wf 2.py @@ -0,0 +1,39 @@ +from dataclasses import dataclass +from typing import Optional + +from flytekit import task, workflow + + +@dataclass +class MyDataClassWithOptional: + foo: dict[str, float] + bar: Optional[dict] = None + baz: Optional[dict[str, float]] = None + qux: Optional[dict[str, int]] = None + + +@task +def t1(in_dataclass: MyDataClassWithOptional) -> MyDataClassWithOptional: + return in_dataclass + + 
+@workflow +def wf(in_dataclass: MyDataClassWithOptional) -> MyDataClassWithOptional: + return t1(in_dataclass=in_dataclass) # type: ignore + + +@dataclass +class MyParentDataClass: + child: MyDataClassWithOptional + a: Optional[dict[str, float]] = None + b: Optional[MyDataClassWithOptional] = None + + +@task +def t2(in_dataclass: MyParentDataClass) -> MyParentDataClass: + return in_dataclass + + +@workflow +def wf_nested_dc(in_dataclass: MyParentDataClass) -> MyParentDataClass: + return t2(in_dataclass=in_dataclass) # type: ignore diff --git a/tests/flytekit/unit/cli/pyflyte/test_grpc_verbosity 2.py b/tests/flytekit/unit/cli/pyflyte/test_grpc_verbosity 2.py new file mode 100644 index 0000000000..449b8da14d --- /dev/null +++ b/tests/flytekit/unit/cli/pyflyte/test_grpc_verbosity 2.py @@ -0,0 +1,48 @@ +import os + + +def test_grpc_verbosity_set_on_import(): + """ + Test that GRPC_VERBOSITY is set to NONE if not already present in the environment + when the SDK container module is imported. + """ + original_value = os.environ.get("GRPC_VERBOSITY", None) + + try: + if "GRPC_VERBOSITY" in os.environ: + del os.environ["GRPC_VERBOSITY"] + + import importlib + import flytekit.clis.sdk_in_container + importlib.reload(flytekit.clis.sdk_in_container) + + assert "GRPC_VERBOSITY" in os.environ + assert os.environ["GRPC_VERBOSITY"] == "NONE" + + finally: + if original_value is not None: + os.environ["GRPC_VERBOSITY"] = original_value + elif "GRPC_VERBOSITY" in os.environ: + del os.environ["GRPC_VERBOSITY"] + + +def test_grpc_verbosity_not_overridden(): + """ + Test that GRPC_VERBOSITY is not overridden if already set in the environment. + """ + original_value = os.environ.get("GRPC_VERBOSITY", None) + + try: + os.environ["GRPC_VERBOSITY"] = "INFO" + + import importlib + import flytekit.clis.sdk_in_container + importlib.reload(flytekit.clis.sdk_in_container) + + assert os.environ["GRPC_VERBOSITY"] == "INFO" + + finally: + if original_value is not None: + os.environ["GRPC_VERBOSITY"] = original_value + elif "GRPC_VERBOSITY" in os.environ: + del os.environ["GRPC_VERBOSITY"] diff --git a/tests/flytekit/unit/cli/pyflyte/test_package.py b/tests/flytekit/unit/cli/pyflyte/test_package.py index 72dca288eb..3915453e16 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_package.py +++ b/tests/flytekit/unit/cli/pyflyte/test_package.py @@ -1,6 +1,7 @@ import os import shutil +import rich_click as click from click.testing import CliRunner from flyteidl.admin import task_pb2 @@ -177,7 +178,11 @@ def test_package(): def test_pkgs(): - pp = flytekit.clis.sdk_in_container.utils.validate_package(None, None, ["a.b", "a.c,b.a", "cc.a"]) + ctx = click.Context(click.Command('test')) + ctx.obj = dict() + ctx.obj["verbose"] = 0 + + pp = flytekit.clis.sdk_in_container.utils.validate_package(ctx, None, ["a.b", "a.c,b.a", "cc.a"]) assert pp == ["a.b", "a.c", "b.a", "cc.a"] diff --git a/tests/flytekit/unit/cli/pyflyte/test_register.py b/tests/flytekit/unit/cli/pyflyte/test_register.py index ec14aa8227..35e8604d15 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_register.py +++ b/tests/flytekit/unit/cli/pyflyte/test_register.py @@ -1,6 +1,8 @@ import os import shutil import subprocess +import json +import yaml import mock import pytest @@ -15,6 +17,7 @@ from flytekit.core import context_manager from flytekit.core.context_manager import FlyteContextManager from flytekit.remote.remote import FlyteRemote +from flytekit.loggers import logging sample_file_contents = """ from flytekit import task, workflow @@ -163,3 +166,108 @@ def 
test_non_fast_register_require_version(mock_client, mock_remote):
     result = runner.invoke(pyflyte.main, ["register", "--non-fast", "core3"])
     assert result.exit_code == 1
     shutil.rmtree("core3")
+
+
+@mock.patch("flytekit.configuration.plugin.FlyteRemote", spec=FlyteRemote)
+@mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient)
+def test_register_registration_summary_json(mock_client, mock_remote):
+    ctx = FlyteContextManager.current_context()
+    mock_remote._client = mock_client
+    mock_remote.return_value.context = ctx
+    mock_remote.return_value._version_from_hash.return_value = "dummy_version_from_hash"
+    mock_remote.return_value.fast_package.return_value = "dummy_md5_bytes", "dummy_native_url"
+    runner = CliRunner()
+    context_manager.FlyteEntities.entities.clear()
+
+    with runner.isolated_filesystem():
+        out = subprocess.run(["git", "init"], capture_output=True)
+        assert out.returncode == 0
+        os.makedirs("core5", exist_ok=True)
+        with open(os.path.join("core5", "sample.py"), "w") as f:
+            f.write(sample_file_contents)
+            f.close()
+
+        result = runner.invoke(
+            pyflyte.main,
+            ["register", "--summary-format", "json", "core5"]
+        )
+        assert result.exit_code == 0
+        try:
+            summary_data = json.loads(result.output)
+        except json.JSONDecodeError as e:
+            pytest.fail(f"Failed to parse registration summary JSON: {e}")
+        except Exception as e:
+            pytest.fail(f"Unexpected error while parsing registration summary: {e}")
+        assert isinstance(summary_data, list)
+        assert len(summary_data) > 0
+        for entry in summary_data:
+            assert "id" in entry
+            assert "type" in entry
+            assert "version" in entry
+            assert "status" in entry
+        shutil.rmtree("core5")
+
+@mock.patch("flytekit.configuration.plugin.FlyteRemote", spec=FlyteRemote)
+@mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient)
+def test_register_registration_summary_yaml(mock_client, mock_remote):
+    ctx = FlyteContextManager.current_context()
+    mock_remote._client = mock_client
+    mock_remote.return_value.context = ctx
+    mock_remote.return_value._version_from_hash.return_value = "dummy_version_from_hash"
+    mock_remote.return_value.fast_package.return_value = "dummy_md5_bytes", "dummy_native_url"
+    runner = CliRunner()
+    context_manager.FlyteEntities.entities.clear()
+
+    with runner.isolated_filesystem():
+        out = subprocess.run(["git", "init"], capture_output=True)
+        assert out.returncode == 0
+        os.makedirs("core6", exist_ok=True)
+        with open(os.path.join("core6", "sample.py"), "w") as f:
+            f.write(sample_file_contents)
+            f.close()
+
+        result = runner.invoke(
+            pyflyte.main,
+            ["register", "--summary-format", "yaml", "core6"]
+        )
+        assert result.exit_code == 0
+        try:
+            summary_data = yaml.safe_load(result.output)
+        except yaml.YAMLError as e:
+            pytest.fail(f"Failed to parse YAML output: {e}")
+        assert isinstance(summary_data, list)
+        assert len(summary_data) > 0
+        for entry in summary_data:
+            assert "id" in entry
+            assert "type" in entry
+            assert "version" in entry
+            assert "status" in entry
+
+        shutil.rmtree("core6")
+
+
+@mock.patch("flytekit.configuration.plugin.FlyteRemote", spec=FlyteRemote)
+@mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient)
+def test_register_quiet(mock_client, mock_remote):
+    ctx = FlyteContextManager.current_context()
+    mock_remote._client = mock_client
+    mock_remote.return_value.context = ctx
+    mock_remote.return_value._version_from_hash.return_value = "dummy_version_from_hash"
+
mock_remote.return_value.fast_package.return_value = "dummy_md5_bytes", "dummy_native_url" + runner = CliRunner() + context_manager.FlyteEntities.entities.clear() + with runner.isolated_filesystem(): + out = subprocess.run(["git", "init"], capture_output=True) + assert out.returncode == 0 + os.makedirs("core7", exist_ok=True) + with open(os.path.join("core7", "sample.py"), "w") as f: + f.write(sample_file_contents) + f.close() + result = runner.invoke( + pyflyte.main, + ["register", "--quiet", "core7"] + ) + assert result.exit_code == 0 + assert result.output == "" + + shutil.rmtree("core7") diff --git a/tests/flytekit/unit/extend/test_agent 2.py b/tests/flytekit/unit/extend/test_agent 2.py new file mode 100644 index 0000000000..040396db98 --- /dev/null +++ b/tests/flytekit/unit/extend/test_agent 2.py @@ -0,0 +1,523 @@ +import typing +from collections import OrderedDict +from dataclasses import dataclass +from unittest.mock import MagicMock, patch + +import grpc +import pytest +from flyteidl.admin.agent_pb2 import ( + Agent, + CreateRequestHeader, + CreateTaskRequest, + DeleteTaskRequest, + ExecuteTaskSyncRequest, + GetAgentRequest, + GetTaskRequest, + ListAgentsRequest, + ListAgentsResponse, + TaskCategory, DeleteTaskResponse, +) +from flyteidl.core.execution_pb2 import TaskExecution, TaskLog +from flyteidl.core.identifier_pb2 import ResourceType + +from flytekit import PythonFunctionTask, task +from flytekit.configuration import ( + FastSerializationSettings, + Image, + ImageConfig, + SerializationSettings, +) +from flytekit.core.base_task import PythonTask, kwtypes +from flytekit.core.interface import Interface +from flytekit.exceptions.system import FlyteAgentNotFound +from flytekit.extend.backend.agent_service import ( + AgentMetadataService, + AsyncAgentService, + SyncAgentService, +) +from flytekit.extend.backend.base_agent import ( + AgentRegistry, + AsyncAgentBase, + AsyncAgentExecutorMixin, + Resource, + ResourceMeta, + SyncAgentBase, + SyncAgentExecutorMixin, + is_terminal_phase, + render_task_template, +) +from flytekit.extend.backend.utils import convert_to_flyte_phase, get_agent_secret +from flytekit.models import literals +from flytekit.models.core.identifier import ( + Identifier, + NodeExecutionIdentifier, + TaskExecutionIdentifier, + WorkflowExecutionIdentifier, +) +from flytekit.models.literals import LiteralMap +from flytekit.models.security import Identity +from flytekit.models.task import TaskExecutionMetadata, TaskTemplate +from flytekit.tools.translator import get_serializable +from flytekit.utils.asyn import loop_manager + +dummy_id = "dummy_id" + + +@dataclass +class DummyMetadata(ResourceMeta): + job_id: str + output_path: typing.Optional[str] = None + task_name: typing.Optional[str] = None + + +class DummyAgent(AsyncAgentBase): + name = "Dummy Agent" + + def __init__(self): + super().__init__(task_type_name="dummy", metadata_type=DummyMetadata) + + def create(self, task_template: TaskTemplate, inputs: typing.Optional[LiteralMap], **kwargs) -> DummyMetadata: + return DummyMetadata(job_id=dummy_id) + + def get(self, resource_meta: DummyMetadata, **kwargs) -> Resource: + return Resource( + phase=TaskExecution.SUCCEEDED, + log_links=[TaskLog(name="console", uri="localhost:3000")], + custom_info={"custom": "info", "num": 1}, + ) + + def delete(self, resource_meta: DummyMetadata, **kwargs): + ... 
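+
+# A hypothetical round-trip sketch for readers (assuming the dataclass-based
+# serialization that ResourceMeta provides): encode()/decode() is how the agent
+# service below passes resource metadata between requests, e.g.
+#   meta = DummyMetadata(job_id="abc")
+#   assert DummyMetadata.decode(meta.encode()) == meta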
+ + +class AsyncDummyAgent(AsyncAgentBase): + name = "Async Dummy Agent" + + def __init__(self): + super().__init__(task_type_name="async_dummy", metadata_type=DummyMetadata) + + async def create( + self, + task_template: TaskTemplate, + inputs: typing.Optional[LiteralMap] = None, + output_prefix: typing.Optional[str] = None, + task_execution_metadata: typing.Optional[TaskExecutionMetadata] = None, + **kwargs, + ) -> DummyMetadata: + output_path = f"{output_prefix}/{dummy_id}" if output_prefix else None + task_name = task_execution_metadata.task_execution_id.task_id.name if task_execution_metadata else "default" + return DummyMetadata(job_id=dummy_id, output_path=output_path, task_name=task_name) + + async def get(self, resource_meta: DummyMetadata, **kwargs) -> Resource: + return Resource( + phase=TaskExecution.SUCCEEDED, + log_links=[TaskLog(name="console", uri="localhost:3000")], + custom_info={"custom": "info", "num": 1}, + ) + + async def delete(self, resource_meta: DummyMetadata, **kwargs): + ... + + +class MockOpenAIAgent(SyncAgentBase): + name = "mock openAI Agent" + + def __init__(self): + super().__init__(task_type_name="openai") + + def do( + self, + task_template: TaskTemplate, + inputs: typing.Optional[LiteralMap] = None, + **kwargs, + ) -> Resource: + assert inputs.literals["a"].scalar.primitive.integer == 1 + return Resource(phase=TaskExecution.SUCCEEDED, outputs={"o0": 1}) + + +class MockAsyncOpenAIAgent(SyncAgentBase): + name = "mock async openAI Agent" + + def __init__(self): + super().__init__(task_type_name="async_openai") + + async def do(self, task_template: TaskTemplate, inputs: LiteralMap = None, **kwargs) -> Resource: + assert inputs.literals["a"].scalar.primitive.integer == 1 + return Resource(phase=TaskExecution.SUCCEEDED, outputs={"o0": 1}) + + +def get_task_template(task_type: str) -> TaskTemplate: + @task + def simple_task(i: int): + print(i) + + default_img = Image(name="default", fqn="test", tag="tag") + serialization_settings = SerializationSettings( + project="project", + domain="domain", + version="version", + env={"FOO": "baz"}, + image_config=ImageConfig(default_image=default_img, images=[default_img]), + fast_serialization_settings=FastSerializationSettings(enabled=True), + ) + serialized = get_serializable(OrderedDict(), serialization_settings, simple_task) + serialized.template._type = task_type + return serialized.template + + +task_inputs = literals.LiteralMap( + { + "a": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1))), + }, +) + +task_execution_metadata = TaskExecutionMetadata( + task_execution_id=TaskExecutionIdentifier( + task_id=Identifier(ResourceType.TASK, "project", "domain", "name", "version"), + node_execution_id=NodeExecutionIdentifier("node_id", WorkflowExecutionIdentifier("project", "domain", "name")), + retry_attempt=1, + ), + namespace="namespace", + labels={"label_key": "label_val"}, + annotations={"annotation_key": "annotation_val"}, + k8s_service_account="k8s service account", + environment_variables={"env_var_key": "env_var_val"}, + identity=Identity(execution_identity="task executor"), +) + + +def test_dummy_agent(): + AgentRegistry.register(DummyAgent(), override=True) + agent = AgentRegistry.get_agent("dummy") + template = get_task_template("dummy") + metadata = DummyMetadata(job_id=dummy_id) + assert agent.create(template, task_inputs) == DummyMetadata(job_id=dummy_id) + resource = agent.get(metadata) + assert resource.phase == TaskExecution.SUCCEEDED + assert resource.log_links[0].name == 
"console" + assert resource.log_links[0].uri == "localhost:3000" + assert resource.custom_info["custom"] == "info" + assert resource.custom_info["num"] == 1 + assert agent.delete(metadata) is None + + class DummyTask(AsyncAgentExecutorMixin, PythonFunctionTask): + def __init__(self, **kwargs): + super().__init__(task_type="dummy", **kwargs) + + t = DummyTask(task_config={}, task_function=lambda: None, container_image="dummy") + t.execute() + + t._task_type = "non-exist-type" + with pytest.raises(Exception, match="Cannot find agent for task category: non-exist-type."): + t.execute() + + +@pytest.mark.parametrize( + "agent,consume_metadata", + [(DummyAgent(), False), (AsyncDummyAgent(), True)], + ids=["sync", "async"], +) +@pytest.mark.asyncio +async def test_async_agent_service(agent, consume_metadata): + AgentRegistry.register(agent, override=True) + service = AsyncAgentService() + ctx = MagicMock(spec=grpc.ServicerContext) + + inputs_proto = task_inputs.to_flyte_idl() + output_prefix = "/tmp" + metadata_bytes = ( + DummyMetadata( + job_id=dummy_id, + output_path=f"{output_prefix}/{dummy_id}", + task_name=task_execution_metadata.task_execution_id.task_id.name, + ).encode() + if consume_metadata + else DummyMetadata(job_id=dummy_id).encode() + ) + + tmp = get_task_template(agent.task_category.name).to_flyte_idl() + task_category = TaskCategory(name=agent.task_category.name, version=0) + req = CreateTaskRequest( + inputs=inputs_proto, + template=tmp, + output_prefix=output_prefix, + task_execution_metadata=task_execution_metadata.to_flyte_idl(), + ) + + res = await service.CreateTask(req, ctx) + assert res.resource_meta == metadata_bytes + res = await service.GetTask(GetTaskRequest(task_category=task_category, resource_meta=metadata_bytes), ctx) + assert res.resource.phase == TaskExecution.SUCCEEDED + res = await service.DeleteTask( + DeleteTaskRequest(task_category=task_category, resource_meta=metadata_bytes), + ctx, + ) + assert res == DeleteTaskResponse() + + agent_metadata = AgentRegistry.get_agent_metadata(agent.name) + assert agent_metadata.supported_task_types[0] == agent.task_category.name + assert agent_metadata.supported_task_categories[0].name == agent.task_category.name + + with pytest.raises(FlyteAgentNotFound): + AgentRegistry.get_agent_metadata("non-exist-namr") + + +def test_register_agent(): + agent = DummyAgent() + AgentRegistry.register(agent, override=True) + assert AgentRegistry.get_agent("dummy").name == agent.name + + with pytest.raises(ValueError, match="Duplicate agent for task type: dummy_v0"): + AgentRegistry.register(agent) + + with pytest.raises(FlyteAgentNotFound): + AgentRegistry.get_agent("non-exist-type") + + agents = AgentRegistry.list_agents() + assert len(agents) >= 1 + + +@pytest.mark.asyncio +async def test_agent_metadata_service(): + agent = DummyAgent() + AgentRegistry.register(agent, override=True) + + ctx = MagicMock(spec=grpc.ServicerContext) + metadata_service = AgentMetadataService() + res = await metadata_service.ListAgents(ListAgentsRequest(), ctx) + assert isinstance(res, ListAgentsResponse) + res = await metadata_service.GetAgent(GetAgentRequest(name="Dummy Agent"), ctx) + assert res.agent.name == agent.name + assert res.agent.supported_task_types[0] == agent.task_category.name + assert res.agent.supported_task_categories[0].name == agent.task_category.name + + +def test_openai_agent(): + AgentRegistry.register(MockOpenAIAgent(), override=True) + + class OpenAITask(SyncAgentExecutorMixin, PythonTask): + def __init__(self, **kwargs): + 
super().__init__( + task_type="openai", + interface=Interface(inputs=kwtypes(a=int), outputs=kwtypes(o0=int)), + **kwargs, + ) + + t = OpenAITask(task_config={}, name="openai task") + res = t(a=1) + assert res == 1 + + +def test_async_openai_agent(): + AgentRegistry.register(MockAsyncOpenAIAgent(), override=True) + + class OpenAITask(SyncAgentExecutorMixin, PythonTask): + def __init__(self, **kwargs): + super().__init__( + task_type="async_openai", + interface=Interface(inputs=kwtypes(a=int), outputs=kwtypes(o0=int)), + **kwargs, + ) + + t = OpenAITask(task_config={}, name="openai task") + res = t(a=1) + assert res == 1 + + +async def get_request_iterator(task_type: str): + inputs_proto = task_inputs.to_flyte_idl() + template = get_task_template(task_type).to_flyte_idl() + header = CreateRequestHeader(template=template, output_prefix="/tmp") + yield ExecuteTaskSyncRequest(header=header) + yield ExecuteTaskSyncRequest(inputs=inputs_proto) + + +@pytest.mark.asyncio +async def test_sync_agent_service(): + AgentRegistry.register(MockOpenAIAgent(), override=True) + ctx = MagicMock(spec=grpc.ServicerContext) + + service = SyncAgentService() + res = await service.ExecuteTaskSync(get_request_iterator("openai"), ctx).__anext__() + assert res.header.resource.phase == TaskExecution.SUCCEEDED + assert res.header.resource.outputs.literals["o0"].scalar.primitive.integer == 1 + + +@pytest.mark.asyncio +async def test_sync_agent_service_with_asyncio(): + AgentRegistry.register(MockAsyncOpenAIAgent(), override=True) + AgentRegistry.register(DummyAgent(), override=True) + ctx = MagicMock(spec=grpc.ServicerContext) + + service = SyncAgentService() + res = await service.ExecuteTaskSync(get_request_iterator("async_openai"), ctx).__anext__() + assert res.header.resource.phase == TaskExecution.SUCCEEDED + assert res.header.resource.outputs.literals["o0"].scalar.primitive.integer == 1 + + +def test_is_terminal_phase(): + assert is_terminal_phase(TaskExecution.SUCCEEDED) + assert is_terminal_phase(TaskExecution.ABORTED) + assert is_terminal_phase(TaskExecution.FAILED) + assert not is_terminal_phase(TaskExecution.RUNNING) + + +def test_convert_to_flyte_phase(): + assert convert_to_flyte_phase("FAILED") == TaskExecution.FAILED + assert convert_to_flyte_phase("TIMEOUT") == TaskExecution.FAILED + assert convert_to_flyte_phase("TIMEDOUT") == TaskExecution.FAILED + assert convert_to_flyte_phase("CANCELED") == TaskExecution.FAILED + assert convert_to_flyte_phase("SKIPPED") == TaskExecution.FAILED + assert convert_to_flyte_phase("INTERNAL_ERROR") == TaskExecution.FAILED + + assert convert_to_flyte_phase("DONE") == TaskExecution.SUCCEEDED + assert convert_to_flyte_phase("SUCCEEDED") == TaskExecution.SUCCEEDED + assert convert_to_flyte_phase("SUCCESS") == TaskExecution.SUCCEEDED + + assert convert_to_flyte_phase("RUNNING") == TaskExecution.RUNNING + assert convert_to_flyte_phase("TERMINATING") == TaskExecution.RUNNING + + assert convert_to_flyte_phase("PENDING") == TaskExecution.INITIALIZING + + invalid_state = "INVALID_STATE" + with pytest.raises(Exception, match=f"Unrecognized state: {invalid_state.lower()}"): + convert_to_flyte_phase(invalid_state) + + +@patch("flytekit.current_context") +def test_get_agent_secret(mocked_context): + mocked_context.return_value.secrets.get.return_value = "mocked token" + assert get_agent_secret("mocked key") == "mocked token" + + +def test_render_task_template(): + template = get_task_template("dummy") + tt = render_task_template(template, "s3://becket") + assert tt.container.args == [ + 
"pyflyte-fast-execute", + "--additional-distribution", + "{{ .remote_package_path }}", + "--dest-dir", + "{{ .dest_dir }}", + "--", + "pyflyte-execute", + "--inputs", + "s3://becket/inputs.pb", + "--output-prefix", + "s3://becket", + "--raw-output-data-prefix", + "s3://becket/raw_output", + "--checkpoint-path", + "s3://becket/checkpoint_output", + "--prev-checkpoint", + "s3://becket/prev_checkpoint", + "--resolver", + "flytekit.core.python_auto_container.default_task_resolver", + "--", + "task-module", + "test_agent", + "task-name", + "simple_task", + ] + + +@pytest.fixture +def sample_agents(): + async_agent = Agent( + name="Sensor", + is_sync=False, + supported_task_categories=[TaskCategory(name="sensor", version=0)], + ) + sync_agent = Agent( + name="ChatGPT Agent", + is_sync=True, + supported_task_categories=[TaskCategory(name="chatgpt", version=0)], + ) + return [async_agent, sync_agent] + + +def test_resource_type(): + o = Resource( + phase=TaskExecution.SUCCEEDED, + ) + v = loop_manager.run_sync(o.to_flyte_idl) + assert v + assert v.phase == TaskExecution.SUCCEEDED + assert len(v.log_links) == 0 + assert v.message == "" + assert len(v.outputs.literals) == 0 + assert len(v.custom_info) == 0 + + o2 = Resource.from_flyte_idl(v) + assert o2 + + o = Resource( + phase=TaskExecution.SUCCEEDED, + log_links=[TaskLog(name="console", uri="localhost:3000")], + message="foo", + outputs={"o0": 1}, + custom_info={"custom": "info", "num": 1}, + ) + v = loop_manager.run_sync(o.to_flyte_idl) + assert v + assert v.phase == TaskExecution.SUCCEEDED + assert v.log_links[0].name == "console" + assert v.log_links[0].uri == "localhost:3000" + assert v.message == "foo" + assert v.outputs.literals["o0"].scalar.primitive.integer == 1 + assert v.custom_info["custom"] == "info" + assert v.custom_info["num"] == 1 + + o2 = Resource.from_flyte_idl(v) + assert o2.phase == o.phase + assert list(o2.log_links) == list(o.log_links) + assert o2.message == o.message + # round-tripping creates a literal map out of outputs + assert o2.outputs.literals["o0"].scalar.primitive.integer == 1 + assert o2.custom_info == o.custom_info + + +def test_agent_complex_type(): + @dataclass + class Foo: + val: str + + class FooAgent(SyncAgentBase): + def __init__(self) -> None: + super().__init__(task_type_name="foo") + + def do( + self, + task_template: TaskTemplate, + inputs: typing.Optional[LiteralMap] = None, + **kwargs: typing.Any, + ) -> Resource: + return Resource( + phase=TaskExecution.SUCCEEDED, outputs={"foos": [Foo(val="a"), Foo(val="b")], "has_foos": True} + ) + + AgentRegistry.register(FooAgent(), override=True) + + class FooTask(SyncAgentExecutorMixin, PythonTask): # type: ignore + _TASK_TYPE = "foo" + + def __init__(self, name: str, **kwargs: typing.Any) -> None: + task_config: dict[str, typing.Any] = {} + + outputs = {"has_foos": bool, "foos": typing.Optional[typing.List[Foo]]} + + super().__init__( + task_type=self._TASK_TYPE, + name=name, + task_config=task_config, + interface=Interface(outputs=outputs), + **kwargs, + ) + + def get_custom(self, settings: SerializationSettings) -> typing.Dict[str, typing.Any]: + return {} + + foo_task = FooTask(name="foo_task") + res = foo_task() + assert res.has_foos + assert res.foos[1].val == "b" diff --git a/tests/flytekit/unit/models/test_concurrency 2.py b/tests/flytekit/unit/models/test_concurrency 2.py new file mode 100644 index 0000000000..1b1a2f82b0 --- /dev/null +++ b/tests/flytekit/unit/models/test_concurrency 2.py @@ -0,0 +1,41 @@ +from flyteidl.admin import launch_plan_pb2 as 
_launch_plan_idl + +from flytekit.models.concurrency import ConcurrencyLimitBehavior, ConcurrencyPolicy + + +def test_concurrency_limit_behavior(): + assert ConcurrencyLimitBehavior.SKIP == _launch_plan_idl.CONCURRENCY_LIMIT_BEHAVIOR_SKIP + + # Test enum to string conversion + assert ConcurrencyLimitBehavior.enum_to_string(ConcurrencyLimitBehavior.SKIP) == "SKIP" + assert ConcurrencyLimitBehavior.enum_to_string(999) == "" + + +def test_concurrency_policy_serialization(): + policy = ConcurrencyPolicy(max_concurrency=1, behavior=ConcurrencyLimitBehavior.SKIP) + + assert policy.max_concurrency == 1 + assert policy.behavior == ConcurrencyLimitBehavior.SKIP + + # Test serialization to protobuf + pb = policy.to_flyte_idl() + assert isinstance(pb, _launch_plan_idl.ConcurrencyPolicy) + assert pb.max == 1 + assert pb.behavior == _launch_plan_idl.CONCURRENCY_LIMIT_BEHAVIOR_SKIP + + # Test deserialization from protobuf + policy2 = ConcurrencyPolicy.from_flyte_idl(pb) + assert policy2.max_concurrency == 1 + assert policy2.behavior == ConcurrencyLimitBehavior.SKIP + + +def test_concurrency_policy_with_different_max(): + # Test with a higher max value + policy = ConcurrencyPolicy(max_concurrency=5, behavior=ConcurrencyLimitBehavior.SKIP) + assert policy.max_concurrency == 5 + + pb = policy.to_flyte_idl() + assert pb.max == 5 + + policy2 = ConcurrencyPolicy.from_flyte_idl(pb) + assert policy2.max_concurrency == 5
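
A closing usage note on the register flags exercised above: the JSON summary test parses result.output directly, so stdout is expected to carry nothing but the summary document. Below is a minimal consumer sketch (illustrative only; the package name `my_pkg` is a placeholder, and the `id`/`type`/`version`/`status` keys are the ones asserted by the tests):

    import json
    import subprocess

    # Register a package and capture the machine-readable summary from stdout.
    proc = subprocess.run(
        ["pyflyte", "register", "--summary-format", "json", "my_pkg"],
        capture_output=True,
        text=True,
        check=True,
    )
    summary = json.loads(proc.stdout)
    for entry in summary:
        # Each entry describes one registered task, workflow, or launch plan.
        print(f'{entry["type"]} {entry["id"]} (version {entry["version"]}): {entry["status"]}')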