Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions examples/advanced/custom_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import flyte

env = flyte.TaskEnvironment("custom_context")


@env.task
async def downstream_task(x: int) -> int:
custom_ctx = flyte.ctx().custom_context
if "increment" not in custom_ctx:
raise ValueError("Expected 'increment' in custom context")
return x + int(custom_ctx["increment"])


@env.task
async def main(x: int) -> int:
vals = []
for i in range(3):
with flyte.custom_context(increment=str(i)):
vals.append(await downstream_task(x))
return sum(vals)


if __name__ == "__main__":
flyte.init_from_config()
print(flyte.run(main, x=10).url)
3 changes: 3 additions & 0 deletions src/flyte/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ._build import build
from ._cache import Cache, CachePolicy, CacheRequest
from ._context import ctx
from ._custom_context import custom_context, get_custom_context
from ._deploy import build_images, deploy
from ._environment import Environment
from ._excepthook import custom_excepthook
Expand Down Expand Up @@ -90,7 +91,9 @@ def version() -> str:
"build_images",
"ctx",
"current_domain",
"custom_context",
"deploy",
"get_custom_context",
"group",
"init",
"init_from_config",
Expand Down
73 changes: 73 additions & 0 deletions src/flyte/_custom_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from __future__ import annotations

from contextlib import contextmanager

from flyte._context import ctx

from ._context import internal_ctx


def get_custom_context() -> dict[str, str]:
"""
Get the current input context. This can be used within a task to retrieve
context metadata that was passed to the action.

Context will automatically propagate to sub-actions.

Example:
```python
import flyte

env = flyte.TaskEnvironment(name="...")

@env.task
def t1():
# context can be retrieved with `get_custom_context`
ctx = flyte.get_custom_context()
print(ctx) # {'project': '...', 'entity': '...'}
```

:return: Dictionary of context key-value pairs
"""
tctx = ctx()
if tctx is None or tctx.custom_context is None:
return {}
return tctx.custom_context


@contextmanager
def custom_context(**context: str):
"""
Synchronous context manager to set input context for tasks spawned within this block.

Example:
```python
import flyte

env = flyte.TaskEnvironment(name="...")

@env.task
def t1():
ctx = flyte.get_custom_context()
print(ctx)

@env.task
def main():
# context can be passed via a context manager
with flyte.custom_context(project="my-project"):
t1() # will have {'project': 'my-project'} as context
```

:param context: Key-value pairs to set as input context
"""
ctx = internal_ctx()
if ctx.data.task_context is None:
yield
return

tctx = ctx.data.task_context
new_tctx = tctx.replace(custom_context={**tctx.custom_context, **context})

with ctx.replace_task_context(new_tctx):
yield
# Exit the context and restore the previous context
1 change: 1 addition & 0 deletions src/flyte/_internal/controllers/_local_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ async def get_action_outputs(
tctx = ctx.data.task_context
if not tctx:
raise flyte.errors.NotInTaskContextError("BadContext", "Task context not initialized")

converted_inputs = convert.Inputs.empty()
if _interface.inputs:
converted_inputs = await convert.convert_from_native_to_inputs(_interface, *args, **kwargs)
Expand Down
1 change: 1 addition & 0 deletions src/flyte/_internal/controllers/remote/_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ async def get_action_outputs(

func_name = _func.__name__
invoke_seq_num = self.generate_task_call_sequence(_func, current_action_id)

inputs = await convert.convert_from_native_to_inputs(_interface, *args, **kwargs)
serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
inputs_hash = convert.generate_inputs_hash_from_proto(inputs.proto_inputs)
Expand Down
29 changes: 26 additions & 3 deletions src/flyte/_internal/runtime/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import flyte.errors
import flyte.storage as storage
from flyte._context import ctx
from flyte.models import ActionID, NativeInterface, TaskContext
from flyte.types import TypeEngine, TypeTransformerFailedError

Expand All @@ -25,6 +26,11 @@ class Inputs:
def empty(cls) -> "Inputs":
return cls(proto_inputs=common_pb2.Inputs())

@property
def context(self) -> Dict[str, str]:
"""Get the context as a dictionary."""
return {kv.key: kv.value for kv in self.proto_inputs.context}


@dataclass(frozen=True)
class Outputs:
Expand Down Expand Up @@ -102,15 +108,30 @@ def is_optional_type(tp) -> bool:
return NoneType in get_args(tp) # fastest check


async def convert_from_native_to_inputs(interface: NativeInterface, *args, **kwargs) -> Inputs:
async def convert_from_native_to_inputs(
interface: NativeInterface, *args, custom_context: Dict[str, str] | None = None, **kwargs
) -> Inputs:
kwargs = interface.convert_to_kwargs(*args, **kwargs)

missing = [key for key in interface.required_inputs() if key not in kwargs]
if missing:
raise ValueError(f"Missing required inputs: {', '.join(missing)}")

# Read custom_context from TaskContext if available (inside task execution)
# Otherwise use the passed parameter (for remote run initiation)
context_kvs = None
tctx = ctx()
if tctx and tctx.custom_context:
# Inside a task - read from TaskContext
context_to_use = tctx.custom_context
context_kvs = [literals_pb2.KeyValuePair(key=k, value=v) for k, v in context_to_use.items()]
elif custom_context:
# Remote run initiation
context_kvs = [literals_pb2.KeyValuePair(key=k, value=v) for k, v in custom_context.items()]

if len(interface.inputs) == 0:
return Inputs.empty()
# Handle context even for empty inputs
return Inputs(proto_inputs=common_pb2.Inputs(context=context_kvs))

# fill in defaults if missing
type_hints: Dict[str, type] = {}
Expand Down Expand Up @@ -144,10 +165,12 @@ async def convert_from_native_to_inputs(interface: NativeInterface, *args, **kwa
for k, v in already_converted_kwargs.items():
copied_literals[k] = v
literal_map = literals_pb2.LiteralMap(literals=copied_literals)

# Make sure we the interface, not literal_map or kwargs, because those may have a different order
return Inputs(
proto_inputs=common_pb2.Inputs(
literals=[common_pb2.NamedLiteral(name=k, value=literal_map.literals[k]) for k in interface.inputs.keys()]
literals=[common_pb2.NamedLiteral(name=k, value=literal_map.literals[k]) for k in interface.inputs.keys()],
context=context_kvs,
)
)

Expand Down
11 changes: 10 additions & 1 deletion src/flyte/_internal/runtime/taskrunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,14 @@ async def convert_and_run(
in a context tree.
"""
ctx = internal_ctx()

# Load inputs first to get context
if input_path:
inputs = await load_inputs(input_path, path_rewrite_config=raw_data_path.path_rewrite)

# Extract context from inputs
custom_context = inputs.context if inputs else {}

tctx = TaskContext(
action=action,
checkpoints=checkpoints,
Expand All @@ -142,9 +150,10 @@ async def convert_and_run(
report=flyte.report.Report(name=action.name),
mode="remote" if not ctx.data.task_context else ctx.data.task_context.mode,
interactive_mode=interactive_mode,
custom_context=custom_context,
)

with ctx.replace_task_context(tctx):
inputs = await load_inputs(input_path, path_rewrite_config=raw_data_path.path_rewrite) if input_path else inputs
inputs_kwargs = await convert_inputs_to_native(inputs, task.native_interface)
out, err = await run_task(tctx=tctx, controller=controller, task=task, inputs=inputs_kwargs)
if err is not None:
Expand Down
18 changes: 16 additions & 2 deletions src/flyte/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def __init__(
log_level: int | None = None,
disable_run_cache: bool = False,
queue: Optional[str] = None,
custom_context: Dict[str, str] | None = None,
):
from flyte._tools import ipython_check

Expand Down Expand Up @@ -124,6 +125,7 @@ def __init__(
self._log_level = log_level
self._disable_run_cache = disable_run_cache
self._queue = queue
self._custom_context = custom_context or {}

@requires_initialization
async def _run_remote(self, obj: TaskTemplate[P, R, F] | LazyEntity, *args: P.args, **kwargs: P.kwargs) -> Run:
Expand All @@ -149,7 +151,9 @@ async def _run_remote(self, obj: TaskTemplate[P, R, F] | LazyEntity, *args: P.ar
if isinstance(obj, LazyEntity):
task = await obj.fetch.aio()
task_spec = task.pb2.spec
inputs = await convert_from_native_to_inputs(task.interface, *args, **kwargs)
inputs = await convert_from_native_to_inputs(
task.interface, *args, custom_context=self._custom_context, **kwargs
)
version = task.pb2.task_id.version
code_bundle = None
else:
Expand Down Expand Up @@ -205,7 +209,9 @@ async def _run_remote(self, obj: TaskTemplate[P, R, F] | LazyEntity, *args: P.ar
root_dir=cfg.root_dir,
)
task_spec = translate_task_to_wire(obj, s_ctx)
inputs = await convert_from_native_to_inputs(obj.native_interface, *args, **kwargs)
inputs = await convert_from_native_to_inputs(
obj.native_interface, *args, custom_context=self._custom_context, **kwargs
)

env = self._env_vars or {}
if env.get("LOG_LEVEL") is None:
Expand Down Expand Up @@ -412,6 +418,7 @@ async def _run_task() -> Tuple[Any, Optional[Exception]]:
compiled_image_cache=image_cache,
run_base_dir=run_base_dir,
report=flyte.report.Report(name=action.name),
custom_context=self._custom_context,
)
async with ctx.replace_task_context(tctx):
return await run_task(tctx=tctx, controller=controller, task=obj, inputs=inputs)
Expand Down Expand Up @@ -463,7 +470,9 @@ async def _run_local(self, obj: TaskTemplate[P, R, F], *args: P.args, **kwargs:
compiled_image_cache=None,
report=Report(name=action.name),
mode="local",
custom_context=self._custom_context,
)

with ctx.replace_task_context(tctx):
# make the local version always runs on a different thread, returns a wrapped future.
if obj._call_as_synchronous:
Expand Down Expand Up @@ -576,6 +585,7 @@ def with_runcontext(
log_level: int | None = None,
disable_run_cache: bool = False,
queue: Optional[str] = None,
custom_context: Dict[str, str] | None = None,
) -> _Runner:
"""
Launch a new run with the given parameters as the context.
Expand Down Expand Up @@ -620,6 +630,9 @@ async def example_task(x: int, y: str) -> str:
set using `flyte.init()`
:param disable_run_cache: Optional If true, the run cache will be disabled. This is useful for testing purposes.
:param queue: Optional The queue to use for the run. This is used to specify the cluster to use for the run.
:param custom_context: Optional global input context to pass to the task. This will be available via
get_custom_context() within the task and will automatically propagate to sub-tasks.
Acts as base/default values that can be overridden by context managers in the code.

:return: runner
"""
Expand Down Expand Up @@ -648,6 +661,7 @@ async def example_task(x: int, y: str) -> str:
log_level=log_level,
disable_run_cache=disable_run_cache,
queue=queue,
custom_context=custom_context,
)


Expand Down
3 changes: 3 additions & 0 deletions src/flyte/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,8 @@ class TaskContext:
:param action: The action ID of the current execution. This is always set, within a run.
:param version: The version of the executed task. This is set when the task is executed by an action and will be
set on all sub-actions.
:param custom_context: Context metadata for the action. If an action receives context, it'll automatically pass it
to any actions it spawns. Context will not be used for cache key computation.
"""

action: ActionID
Expand All @@ -211,6 +213,7 @@ class TaskContext:
data: Dict[str, Any] = field(default_factory=dict)
mode: Literal["local", "remote", "hybrid"] = "remote"
interactive_mode: bool = False
custom_context: Dict[str, str] = field(default_factory=dict)

def replace(self, **kwargs) -> TaskContext:
if "data" in kwargs:
Expand Down
Loading
Loading