Skip to content

Commit

Permalink
Merge branch 'master' into rest-api-agent
Browse files Browse the repository at this point in the history
  • Loading branch information
kumare3 committed Jan 31, 2025
2 parents 4feef3f + 88ac611 commit 3c9a800
Show file tree
Hide file tree
Showing 86 changed files with 4,520 additions and 496 deletions.
13 changes: 11 additions & 2 deletions .github/workflows/pythonbuild.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@ jobs:
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
os:
- ubuntu-24.04-arm
- ubuntu-latest
- windows-latest
- macos-latest
python-version: ${{fromJson(needs.detect-python-versions.outputs.python-versions)}}
steps:
- uses: actions/checkout@v4
Expand Down Expand Up @@ -78,7 +82,11 @@ jobs:
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
os:
- ubuntu-24.04-arm
- ubuntu-latest
- windows-latest
- macos-latest
python-version: ${{fromJson(needs.detect-python-versions.outputs.python-versions)}}
steps:
- uses: actions/checkout@v4
Expand Down Expand Up @@ -303,6 +311,7 @@ jobs:
- flytekit-huggingface
- flytekit-identity-aware-proxy
- flytekit-inference
- flytekit-k8sdataservice
- flytekit-k8s-pod
- flytekit-kf-mpi
- flytekit-kf-pytorch
Expand Down
10 changes: 9 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.6.9
rev: v0.8.3
hooks:
# Run the linter.
- id: ruff
Expand All @@ -28,3 +28,11 @@ repos:
- id: codespell
additional_dependencies:
- tomli
- repo: https://github.com/jsh9/pydoclint
rev: 0.6.0
hooks:
- id: pydoclint
args:
- --style=google
- --exclude='.git|tests/flytekit/*|tests/'
- --baseline=pydoclint-errors-baseline.txt
1 change: 1 addition & 0 deletions Dockerfile.agent
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ RUN apt-get update && apt-get install build-essential -y \
RUN uv pip install --system --no-cache-dir -U flytekit==$VERSION \
flytekitplugins-airflow==$VERSION \
flytekitplugins-bigquery==$VERSION \
flytekitplugins-k8sdataservice==$VERSION \
flytekitplugins-openai==$VERSION \
flytekitplugins-snowflake==$VERSION \
flytekitplugins-awssagemaker==$VERSION \
Expand Down
12 changes: 12 additions & 0 deletions docs/source/plugins/k8sstatefuldataservice.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
.. _k8sstatefuldataservice:

###################################################
Kubernetes StatefulSet Data Service API reference
###################################################

.. tags:: Integration, DeepLearning, MachineLearning, Kubernetes, GNN

.. automodule:: flytekitplugins.k8sdataservice
:no-members:
:no-inherited-members:
:no-special-members:
2 changes: 1 addition & 1 deletion flytekit/bin/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def _dispatch_execute(
exc_str = get_traceback_str(e)
output_file_dict[error_file_name] = _error_models.ErrorDocument(
_error_models.ContainerError(
code="USER",
code=e.error_code,
message=exc_str,
kind=kind,
origin=_execution_models.ExecutionError.ErrorKind.USER,
Expand Down
3 changes: 2 additions & 1 deletion flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,8 +544,9 @@ def run_remote(
if run_level_params.wait_execution:
msg += " Waiting to complete..."
p = Progress(TimeElapsedColumn(), TextColumn(msg), transient=True)
t = p.add_task("exec")
t = p.add_task("exec", visible=False)
with p:
p.update(t, visible=True)
p.start_task(t)
execution = remote.execute(
entity,
Expand Down
19 changes: 13 additions & 6 deletions flytekit/clis/sdk_in_container/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ def serve(ctx: click.Context):
type=int,
help="Grpc port for the agent service",
)
@click.option(
"--prometheus_port",
default="9090",
is_flag=False,
type=int,
help="Prometheus port for the agent service",
)
@click.option(
"--worker",
default="10",
Expand All @@ -45,20 +52,20 @@ def serve(ctx: click.Context):
"for testing.",
)
@click.pass_context
def agent(_: click.Context, port, worker, timeout):
def agent(_: click.Context, port, prometheus_port, worker, timeout):
"""
Start a grpc server for the agent service.
"""
import asyncio

asyncio.run(_start_grpc_server(port, worker, timeout))
asyncio.run(_start_grpc_server(port, prometheus_port, worker, timeout))


async def _start_grpc_server(port: int, worker: int, timeout: int):
async def _start_grpc_server(port: int, prometheus_port: int, worker: int, timeout: int):
from flytekit.extend.backend.agent_service import AgentMetadataService, AsyncAgentService, SyncAgentService

click.secho("🚀 Starting the agent service...")
_start_http_server()
_start_http_server(prometheus_port)
print_agents_metadata()

server = grpc.aio.server(futures.ThreadPoolExecutor(max_workers=worker))
Expand All @@ -73,12 +80,12 @@ async def _start_grpc_server(port: int, worker: int, timeout: int):
await server.wait_for_termination(timeout)


def _start_http_server():
def _start_http_server(prometheus_port: int):
try:
from prometheus_client import start_http_server

click.secho("Starting up the server to expose the prometheus metrics...")
start_http_server(9090)
start_http_server(prometheus_port)
except ImportError as e:
click.secho(f"Failed to start the prometheus server with error {e}", fg="red")

Expand Down
12 changes: 6 additions & 6 deletions flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ def __init__(
**kwargs,
)

self.sub_node_metadata: NodeMetadata = super().construct_node_metadata()
self.sub_node_metadata._name = self.name

@property
def name(self) -> str:
return self._name
Expand All @@ -137,16 +140,13 @@ def python_interface(self):
return self._collection_interface

def construct_node_metadata(self) -> NodeMetadata:
# TODO: add support for other Flyte entities
"""
This returns metadata for the parent ArrayNode, not the sub-node getting mapped over
"""
return NodeMetadata(
name=self.name,
)

def construct_sub_node_metadata(self) -> NodeMetadata:
nm = super().construct_node_metadata()
nm._name = self.name
return nm

@property
def min_success_ratio(self) -> Optional[float]:
return self._min_success_ratio
Expand Down
6 changes: 5 additions & 1 deletion flytekit/core/context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pathlib
import signal
import tempfile
import threading
import traceback
import typing
from contextlib import contextmanager
Expand Down Expand Up @@ -994,7 +995,10 @@ def main_signal_handler(signum: int, frame: FrameType):
handler(signum, frame)
exit(1)

signal.signal(signal.SIGINT, main_signal_handler)
# This initialize function is also called by other threads (since the context manager lives in a ContextVar)
# so we should not run this if we're not the main thread.
if threading.current_thread().name == threading.main_thread().name:
signal.signal(signal.SIGINT, main_signal_handler)

# Note we use the SdkWorkflowExecution object purely for formatting into the ex:project:domain:name format users
# are already acquainted with
Expand Down
29 changes: 29 additions & 0 deletions flytekit/core/data_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@

Uploadable = typing.Union[str, os.PathLike, pathlib.Path, bytes, io.BufferedReader, io.BytesIO, io.StringIO]

# Default chunk size, in bytes, that flytekit uses when writing to S3 and GCS; it is applied when put() is called
# on those filesystems. Defaults to 25 MiB (25 * 2**20); set the _F_P_WRITE_CHUNK_SIZE environment variable to override.
_WRITE_SIZE_CHUNK_BYTES = int(os.environ.get("_F_P_WRITE_CHUNK_SIZE", "26214400")) # 25 * 2**20


def s3_setup_args(s3_cfg: configuration.S3Config, anonymous: bool = False) -> Dict[str, Any]:
kwargs: Dict[str, Any] = {
Expand Down Expand Up @@ -108,6 +112,27 @@ def get_fsspec_storage_options(
return {}


def get_additional_fsspec_call_kwargs(protocol: typing.Union[str, tuple], method_name: str) -> Dict[str, Any]:
    """
    Return extra keyword arguments to pass when a method of an fsspec filesystem is invoked.

    These are different from the setup args functions defined above. Those kwargs are applied when asking fsspec
    to create the filesystem. These kwargs returned here are for when the filesystem's methods are invoked.

    :param protocol: s3, gs, gcs, etc. Either a single protocol string or the tuple of protocol aliases exposed by
        the filesystem; when a tuple is given, every alias is considered, not only the first one.
    :param method_name: Pass in the __name__ of the fsspec.filesystem function. _'s will be ignored.
    :return: A dict of keyword arguments for the call. It is empty unless the method is ``put`` and the protocol
        refers to S3 or GCS, in which case it contains ``chunksize``.
    """
    kwargs: Dict[str, Any] = {}
    method_name = method_name.replace("_", "")

    # Normalize to a tuple of aliases. Filesystems may list their aliases in any order (e.g. ("gs", "gcs") or
    # ("gcs", "gs")), so looking only at the first element can miss a match.
    if isinstance(protocol, (tuple, list)):
        aliases = tuple(protocol)
    else:
        aliases = (protocol,)

    # For s3fs and gcsfs, we feel the default chunksize of 50MB is too big.
    # Re-evaluate these kwargs when we move off of s3fs to obstore.
    chunked_protocols = ("s3", "s3a", "gs", "gcs")
    if method_name == "put" and any(alias in chunked_protocols for alias in aliases):
        kwargs["chunksize"] = _WRITE_SIZE_CHUNK_BYTES

    return kwargs


@decorator
def retry_request(func, *args, **kwargs):
# TODO: Remove this method once s3fs has a new release. https://github.com/fsspec/s3fs/pull/865
Expand Down Expand Up @@ -353,6 +378,10 @@ async def _put(self, from_path: str, to_path: str, recursive: bool = False, **kw
if "metadata" not in kwargs:
kwargs["metadata"] = {}
kwargs["metadata"].update(self._execution_metadata)

additional_kwargs = get_additional_fsspec_call_kwargs(file_system.protocol, file_system.put.__name__)
kwargs.update(additional_kwargs)

if isinstance(file_system, AsyncFileSystem):
dst = await file_system._put(from_path, to_path, recursive=recursive, **kwargs) # pylint: disable=W0212
else:
Expand Down
84 changes: 52 additions & 32 deletions flytekit/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,57 @@ def run_entity(self) -> Any:
def metadata(self) -> _workflow_model.NodeMetadata:
return self._metadata

def _override_node_metadata(
    self,
    name: Optional[str],
    timeout: Optional[Union[int, datetime.timedelta]] = None,
    retries: Optional[int] = None,
    interruptible: typing.Optional[bool] = None,
    cache: typing.Optional[bool] = None,
    cache_version: typing.Optional[str] = None,
    cache_serialize: typing.Optional[bool] = None,
):
    """
    Apply metadata overrides to this node.

    When the node's entity is an ArrayNodeMapTask, the overrides are written to the entity's
    ``sub_node_metadata`` (the mapped sub-node) instead of this node's own metadata.

    :param name: New node name; left unchanged when None.
    :param timeout: Timeout as int seconds or a datetime.timedelta. None sets the timeout to an empty timedelta.
    :param retries: Number of retries; left unchanged when None.
    :param interruptible: Interruptible flag; left unchanged when None.
    :param cache: Whether the node is cacheable; left unchanged when None.
    :param cache_version: Cache version; left unchanged when None.
    :param cache_serialize: Whether cache access is serialized; left unchanged when None.
    :raises ValueError: If timeout is neither None, an int, nor a datetime.timedelta.
    """
    # NOTE(review): imported inside the method, presumably to avoid a circular import - confirm.
    from flytekit.core.array_node_map_task import ArrayNodeMapTask

    if isinstance(self.flyte_entity, ArrayNodeMapTask):
        # override the sub-node's metadata
        node_metadata = self.flyte_entity.sub_node_metadata
    else:
        node_metadata = self._metadata

    # Unlike the other fields, timeout is always written: None resets it to an empty timedelta.
    if timeout is None:
        node_metadata._timeout = datetime.timedelta()
    elif isinstance(timeout, int):
        node_metadata._timeout = datetime.timedelta(seconds=timeout)
    elif isinstance(timeout, datetime.timedelta):
        node_metadata._timeout = timeout
    else:
        raise ValueError("timeout should be duration represented as either a datetime.timedelta or int seconds")

    if retries is not None:
        assert_not_promise(retries, "retries")
        node_metadata._retries = _literal_models.RetryStrategy(retries)

    if interruptible is not None:
        assert_not_promise(interruptible, "interruptible")
        node_metadata._interruptible = interruptible

    if name is not None:
        node_metadata._name = name

    if cache is not None:
        assert_not_promise(cache, "cache")
        node_metadata._cacheable = cache

    if cache_version is not None:
        assert_not_promise(cache_version, "cache_version")
        node_metadata._cache_version = cache_version

    if cache_serialize is not None:
        assert_not_promise(cache_serialize, "cache_serialize")
        node_metadata._cache_serializable = cache_serialize

def with_overrides(
self,
node_name: Optional[str] = None,
Expand Down Expand Up @@ -174,27 +225,6 @@ def with_overrides(
assert_no_promises_in_resources(resources)
self._resources = resources

if timeout is None:
self._metadata._timeout = datetime.timedelta()
elif isinstance(timeout, int):
self._metadata._timeout = datetime.timedelta(seconds=timeout)
elif isinstance(timeout, datetime.timedelta):
self._metadata._timeout = timeout
else:
raise ValueError("timeout should be duration represented as either a datetime.timedelta or int seconds")
if retries is not None:
assert_not_promise(retries, "retries")
self._metadata._retries = (
_literal_models.RetryStrategy(0) if retries is None else _literal_models.RetryStrategy(retries)
)

if interruptible is not None:
assert_not_promise(interruptible, "interruptible")
self._metadata._interruptible = interruptible

if name is not None:
self._metadata._name = name

if task_config is not None:
logger.warning("This override is beta. We may want to revisit this in the future.")
if not isinstance(task_config, type(self.run_entity._task_config)):
Expand All @@ -209,17 +239,7 @@ def with_overrides(
assert_not_promise(accelerator, "accelerator")
self._extended_resources = tasks_pb2.ExtendedResources(gpu_accelerator=accelerator.to_flyte_idl())

if cache is not None:
assert_not_promise(cache, "cache")
self._metadata._cacheable = cache

if cache_version is not None:
assert_not_promise(cache_version, "cache_version")
self._metadata._cache_version = cache_version

if cache_serialize is not None:
assert_not_promise(cache_serialize, "cache_serialize")
self._metadata._cache_serializable = cache_serialize
self._override_node_metadata(name, timeout, retries, interruptible, cache, cache_version, cache_serialize)

return self

Expand Down
5 changes: 1 addition & 4 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import asyncio
import collections
import datetime
import typing
Expand Down Expand Up @@ -1436,9 +1435,7 @@ async def async_flyte_entity_call_handler(
# for both nested eager, async, and sync tasks, submit to the informer.
if not ctx.worker_queue:
raise AssertionError("Worker queue missing, must be set when trying to execute tasks in an eager workflow")
loop = asyncio.get_running_loop()
fut = ctx.worker_queue.add(loop, entity, input_kwargs=kwargs)
result = await fut
result = await ctx.worker_queue.add(entity, input_kwargs=kwargs)
return result

# eager local execution, and all other call patterns are handled by the sync version
Expand Down
2 changes: 1 addition & 1 deletion flytekit/core/python_function_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,7 @@ async def run_with_backend(self, **kwargs):
base_error = ee

html = cast(Controller, ctx.worker_queue).render_html()
Deck("eager workflow", html)
Deck("Eager Executions", html)

if base_error:
# now have to fail this eager task, because we don't want it to show up as succeeded.
Expand Down
Loading

0 comments on commit 3c9a800

Please sign in to comment.