Skip to content
81 changes: 81 additions & 0 deletions examples/advanced/artifact_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from typing import Tuple

import flyte
import flyte.artifacts as artifacts

env = flyte.TaskEnvironment("artifact_example")


@env.task
def create_artifact() -> str:
result = "This is my artifact content"
metadata = artifacts.Metadata(
name="my_artifact", version="1.0", description="An example artifact created in create_artifact task"
)
return artifacts.new(result, metadata)


@env.task
def model_artifact() -> str:
result = "This is my model artifact content"
card = artifacts.Card.create_from(
content="<h1>Model Card</h1><p>This is a sample model card.</p>",
format="html",
card_type="model",
)

metadata = artifacts.Metadata.create_model_metadata(
name="my_model_artifact",
version="1.0",
description="An example model artifact created in model_artifact task",
framework="PyTorch",
model_type="Neural Network",
architecture="ResNet50",
task="Image Classification",
modality=("image",),
serial_format="pt",
short_description="A ResNet50 model for image classification tasks.",
card=card,
)
return artifacts.new(result, metadata)


@env.task
def call_artifact() -> Tuple[str, str]:
x = create_artifact()
print(x)
y = model_artifact()
print(y)
return x, y


@env.task
async def use_artifact(v: str) -> str:
print(f"Using artifact with content: {v}")
return f"Artifact used with content: {v}"


@env.task
async def use_multiple_artifacts(v: list[str]) -> str:
print(f"Using multiple artifacts with contents: {v}")
return f"Multiple artifacts used with contents: {v}"


if __name__ == "__main__":
flyte.init()
v = flyte.run(call_artifact)
print(v.outputs())

from flyte.remote import Artifact

artifact_instance = Artifact.get("my_artifact", version="1.0")
v2 = flyte.run(use_artifact, v=artifact_instance)
print(v2.outputs())

artifact_list = [Artifact.get("my_artifact", version="1.0"), Artifact.get("my_artifact", version="1.0")]
v3 = flyte.run(use_multiple_artifacts, v=artifact_list)
print(v3.outputs())

artifact_list_via_prefix = list(Artifact.listall("my_artifact", version="1.0"))
v4 = flyte.run(use_multiple_artifacts, v=artifact_list_via_prefix)
print(v4.outputs())
1 change: 0 additions & 1 deletion src/flyte/_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import rich.repr
from packaging.version import Version


if TYPE_CHECKING:
from flyte import Secret, SecretRequest

Expand Down
16 changes: 14 additions & 2 deletions src/flyte/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,13 +465,25 @@ async def _run_local(self, obj: TaskTemplate[P, R], *args: P.args, **kwargs: P.k
mode="local",
)
with ctx.replace_task_context(tctx):
new_kwargs = {}
if kwargs:
from flyte.remote import Artifact

for k, v in kwargs.items():
if isinstance(v, Artifact):
new_kwargs[k] = v.pb2["data"]
elif isinstance(v, list) and len(v) > 0:
new_kwargs[k] = [item.pb2["data"] if isinstance(item, Artifact) else item for item in v]
else:
new_kwargs[k] = v
print(new_kwargs)
# make the local version always runs on a different thread, returns a wrapped future.
if obj._call_as_synchronous:
fut = controller.submit_sync(obj, *args, **kwargs)
fut = controller.submit_sync(obj, *args, **new_kwargs)
awaitable = asyncio.wrap_future(fut)
outputs = await awaitable
else:
outputs = await controller.submit(obj, *args, **kwargs)
outputs = await controller.submit(obj, *args, **new_kwargs)

class _LocalRun(Run):
def __init__(self, outputs: Tuple[Any, ...] | Any):
Expand Down
13 changes: 7 additions & 6 deletions src/flyte/_utils/module_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def load_python_modules(path: Path, recursive: bool = False) -> Tuple[List[str],
:return: List of loaded module names, and list of file paths that failed to load
"""
from rich.progress import BarColumn, Progress, TextColumn, TimeElapsedColumn, TimeRemainingColumn

loaded_modules = []
failed_paths = []

Expand All @@ -32,12 +33,12 @@ def load_python_modules(path: Path, recursive: bool = False) -> Tuple[List[str],
python_files = glob.glob(str(path / pattern), recursive=recursive)

with Progress(
TextColumn("[progress.description]{task.description}"),
BarColumn(),
"[progress.percentage]{task.percentage:>3.0f}%",
TimeElapsedColumn(),
TimeRemainingColumn(),
TextColumn("• {task.fields[current_file]}"),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
"[progress.percentage]{task.percentage:>3.0f}%",
TimeElapsedColumn(),
TimeRemainingColumn(),
TextColumn("• {task.fields[current_file]}"),
) as progress:
task = progress.add_task(f"Loading {len(python_files)} files", total=len(python_files), current_file="")
for file_path in python_files:
Expand Down
37 changes: 37 additions & 0 deletions src/flyte/artifacts/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""
Artifacts module

This module provides a wrapper method to mark certain outputs as artifacts with associated metadata.

Usage example:
```python
import flyte.artifacts as artifacts

@env.task
def my_task() -> MyType:
result = MyType(...)
metadata = artifacts.Metadata(name="my_artifact", version="1.0", description="An example artifact")
return artifacts.new(result, metadata)
```

Launching with known artifacts:
```python
flyte.run(main, x=flyte.remote.Artifact.get("name", version="1.0"))
```

Retireve a set of artifacts and pass them as a list
```python
from flyte.remote import Artifact
flyte.run(main, x=[Artifact.get("name1", version="1.0"), Artifact.get("name2", version="2.0")])
```
OR
```python
flyte.run(main, x=flyte.remote.Artifact.list("name_prefix", partition_match="x"))
```
"""

from ._card import Card, CardFormat, CardType
from ._metadata import Metadata
from ._wrapper import Artifact, new

__all__ = ["Artifact", "Card", "CardFormat", "CardType", "Metadata", "new"]
63 changes: 63 additions & 0 deletions src/flyte/artifacts/_card.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from __future__ import annotations

import pathlib
import tempfile
from dataclasses import dataclass
from typing import Literal

import flyte
from flyte import storage, syncify

CardType = Literal["model", "data", "generic"]
CardFormat = Literal["html", "md", "json", "yaml", "csv", "tsv", "png", "jpg", "jpeg"]


@dataclass(frozen=True, kw_only=True)
class Card(object):
uri: str
format: CardFormat = "html"
card_type: CardType = "generic"

@syncify.syncify
@classmethod
async def create_from(
cls,
*,
content: str | None = None,
local_path: pathlib.Path | None = None,
format: CardFormat = "html",
card_type: CardType = "generic",
) -> Card:
"""
Upload a card either from raw content or from a local file path.

:param content: Raw content of the card to be uploaded.
:param local_path: Local file path of the card to be uploaded.
:param format: Format of the card (e.g., 'html', 'md',
'json', 'yaml', 'csv', 'tsv', 'png', 'jpg', 'jpeg').
:param card_type: Type of the card (e.g., 'model', 'data', 'generic').
"""
if content:
with tempfile.NamedTemporaryFile(mode="w", suffix=f".{format}", delete=False) as temp_file:
temp_file.write(content)
temp_path = pathlib.Path(temp_file.name)
return await _upload_card_from_local(temp_path, format=format, card_type=card_type)
if local_path:
return await _upload_card_from_local(local_path, format=format, card_type=card_type)
raise ValueError("Either content or local_path must be provided to upload a card.")


async def _upload_card_from_local(
local_path: pathlib.Path, format: CardFormat = "html", card_type: CardType = "generic"
) -> Card:
# Implement upload. If in task context, upload to current metadata location, if not, upload using control plane.
uri = ""
ctx = flyte.ctx()
if ctx:
output_path = ctx.output_path + "/" + f"{card_type}.{format}"
uri = await storage.put(str(local_path), output_path)
else:
import flyte.remote as remote

_, uri = await remote.upload_file.aio(local_path)
return Card(uri=uri, format=format, card_type=card_type)
54 changes: 54 additions & 0 deletions src/flyte/artifacts/_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from __future__ import annotations

import typing
from dataclasses import dataclass
from typing import Optional, Tuple

from ._card import Card


@dataclass(frozen=True, kw_only=True)
class Metadata:
"""Structured metadata for Flyte artifacts."""

# Core tracking fields
name: str
version: Optional[str] = None
description: Optional[str] = None
data: Optional[typing.Mapping[str, str]] = None
card: Optional[Card] = None

@classmethod
def create_model_metadata(
cls,
*,
name: str,
version: Optional[str] = None,
description: Optional[str] = None,
card: Optional[Card] = None,
framework: Optional[str] = None,
model_type: Optional[str] = None,
architecture: Optional[str] = None,
task: Optional[str] = None,
modality: Tuple[str, ...] = ("text",),
serial_format: str = "safetensors",
short_description: Optional[str] = None,
) -> Metadata:
"""
Helper method to create ModelMetadata. This method sets the data keys specific to models.
"""
return cls(
name=name,
version=version,
description=description,
data={
"framework": framework or "",
"model_type": model_type or "",
"architecture": architecture or "",
"task": task or "",
"modality": ",".join(modality) if modality else "",
"serial_format": serial_format or "",
"short_description": short_description or "",
},
card=card,
)
Loading
Loading