Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/flyte/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ._group import group
from ._image import Image
from ._initialize import current_domain, init, init_from_config
from ._input_context import get_input_context, input_context
from ._map import map
from ._pod import PodTemplate
from ._resources import AMD_GPU, GPU, HABANA_GAUDI, TPU, Device, DeviceClass, Neuron, Resources
Expand Down Expand Up @@ -91,9 +92,11 @@ def version() -> str:
"ctx",
"current_domain",
"deploy",
"get_input_context",
"group",
"init",
"init_from_config",
"input_context",
"map",
"run",
"trace",
Expand Down
120 changes: 120 additions & 0 deletions src/flyte/_input_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from __future__ import annotations

import contextvars
from contextlib import asynccontextmanager, contextmanager
from typing import AsyncIterator, Dict, Iterator

from flyte._context import ctx

# Context variable to store the input context
# This stores both global context (from with_runcontext) and local context (from context managers)
_input_context_var: contextvars.ContextVar[Dict[str, str]] = contextvars.ContextVar("input_context", default={})


def get_input_context() -> Dict[str, str]:
"""
Get the current input context. This can be used within a task to retrieve
context metadata that was passed to the action.

Context will automatically propagate to sub-actions.

Example:
```python
import flyte

env = flyte.TaskEnvironment(name="...")

@env.task
def t1():
# context can be retrieved with `get_input_context`
ctx = flyte.get_input_context()
print(ctx) # {'project': '...', 'entity': '...'}
```

:return: Dictionary of context key-value pairs
"""
# First check if we're in a task context and have input_context set there
task_ctx = ctx()
if task_ctx and task_ctx.input_context:
return task_ctx.input_context.copy()

# Otherwise, check the context variable (for context manager usage)
return _input_context_var.get().copy()


@asynccontextmanager
async def input_context(**context: str) -> AsyncIterator[None]:
"""
Async context manager to set input context for tasks spawned within this block.

Example:
```python
import flyte

env = flyte.TaskEnvironment(name="...")

@env.task
async def t1():
ctx = flyte.get_input_context()
print(ctx)

@env.task
async def main():
# context can be passed via a context manager
async with flyte.input_context(project="my-project"):
await t1() # will have {'project': 'my-project'} as context
```

:param context: Key-value pairs to set as input context
"""
# Start with current context (includes global context if set)
current = _input_context_var.get().copy()

# Merge with code-provided context (code values override existing values)
current.update(context)

# Set the new context
token = _input_context_var.set(current)
try:
yield
finally:
_input_context_var.reset(token)


@contextmanager
def input_context_sync(**context: str) -> Iterator[None]:
"""
Synchronous context manager to set input context for tasks spawned within this block.

Example:
```python
import flyte

env = flyte.TaskEnvironment(name="...")

@env.task
def t1():
ctx = flyte.get_input_context()
print(ctx)

@env.task
def main():
# context can be passed via a context manager
with flyte.input_context_sync(project="my-project"):
t1() # will have {'project': 'my-project'} as context
```

:param context: Key-value pairs to set as input context
"""
# Start with current context (includes global context if set)
current = _input_context_var.get().copy()

# Merge with code-provided context (code values override existing values)
current.update(context)

# Set the new context
token = _input_context_var.set(current)
try:
yield
finally:
_input_context_var.reset(token)
14 changes: 12 additions & 2 deletions src/flyte/_internal/controllers/_local_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from flyte._cache.cache import VersionParameters, cache_from_request
from flyte._cache.local_cache import LocalTaskCache
from flyte._context import internal_ctx
from flyte._input_context import _input_context_var
from flyte._internal.controllers import TraceInfo
from flyte._internal.runtime import convert
from flyte._internal.runtime.entrypoints import direct_dispatch
Expand Down Expand Up @@ -84,7 +85,10 @@ async def submit(self, _task: TaskTemplate, *args, **kwargs) -> Any:
if not tctx:
raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")

inputs = await convert.convert_from_native_to_inputs(_task.native_interface, *args, **kwargs)
# Use context from context manager
current_context = _input_context_var.get()

inputs = await convert.convert_from_native_to_inputs(_task.native_interface, *args, input_context=current_context, **kwargs)
inputs_hash = convert.generate_inputs_hash_from_proto(inputs.proto_inputs)
task_interface = transform_native_to_typed_interface(_task.interface)

Expand Down Expand Up @@ -186,9 +190,15 @@ async def get_action_outputs(
tctx = ctx.data.task_context
if not tctx:
raise flyte.errors.NotInTaskContextError("BadContext", "Task context not initialized")

# Propagate context from current task to sub-tasks, merging with context manager context
current_context = tctx.input_context.copy()
# Merge with context from context manager (if any)
current_context.update(_input_context_var.get())

converted_inputs = convert.Inputs.empty()
if _interface.inputs:
converted_inputs = await convert.convert_from_native_to_inputs(_interface, *args, **kwargs)
converted_inputs = await convert.convert_from_native_to_inputs(_interface, *args, input_context=current_context, **kwargs)
assert converted_inputs

inputs_hash = convert.generate_inputs_hash_from_proto(converted_inputs.proto_inputs)
Expand Down
19 changes: 16 additions & 3 deletions src/flyte/_internal/controllers/remote/_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import flyte.storage as storage
from flyte._code_bundle import build_pkl_bundle
from flyte._context import internal_ctx
from flyte._input_context import _input_context_var
from flyte._internal.controllers import TraceInfo
from flyte._internal.controllers.remote._action import Action
from flyte._internal.controllers.remote._core import Controller
Expand Down Expand Up @@ -176,7 +177,12 @@ async def _submit(self, _task_call_seq: int, _task: TaskTemplate, *args, **kwarg
upload_from_dataplane_base_path=tctx.run_base_dir,
)

inputs = await convert.convert_from_native_to_inputs(_task.native_interface, *args, **kwargs)
# Propagate context from current task to sub-tasks, merging with context manager context
current_context = _input_context_var.get()

inputs = await convert.convert_from_native_to_inputs(
_task.native_interface, *args, input_context=current_context, **kwargs
)

root_dir = Path(code_bundle.destination).absolute() if code_bundle else Path.cwd()
# Don't set output path in sec context because node executor will set it
Expand Down Expand Up @@ -377,7 +383,11 @@ async def get_action_outputs(

func_name = _func.__name__
invoke_seq_num = self.generate_task_call_sequence(_func, current_action_id)
inputs = await convert.convert_from_native_to_inputs(_interface, *args, **kwargs)

# Propagate context from current task to traced functions
current_context = _input_context_var.get()

inputs = await convert.convert_from_native_to_inputs(_interface, *args, input_context=current_context, **kwargs)
serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
inputs_hash = convert.generate_inputs_hash_from_proto(inputs.proto_inputs)

Expand Down Expand Up @@ -496,7 +506,10 @@ async def _submit_task_ref(self, invoke_seq_num: int, _task: TaskDetails, *args,
native_interface = _task.interface
pb_interface = _task.pb2.spec.task_template.interface

inputs = await convert.convert_from_native_to_inputs(native_interface, *args, **kwargs)
# Propagate context from current task to task references
current_context = _input_context_var.get()

inputs = await convert.convert_from_native_to_inputs(native_interface, *args, input_context=current_context, **kwargs)
inputs_hash = convert.generate_inputs_hash_from_proto(inputs.proto_inputs)
sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
tctx, task_name, inputs_hash, invoke_seq_num
Expand Down
20 changes: 17 additions & 3 deletions src/flyte/_internal/runtime/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ class Inputs:
def empty(cls) -> "Inputs":
return cls(proto_inputs=common_pb2.Inputs())

@property
def context(self) -> Dict[str, str]:
"""Get the context as a dictionary."""
return {kv.key: kv.value for kv in self.proto_inputs.context}


@dataclass(frozen=True)
class Outputs:
Expand Down Expand Up @@ -102,15 +107,22 @@ def is_optional_type(tp) -> bool:
return NoneType in get_args(tp) # fastest check


async def convert_from_native_to_inputs(interface: NativeInterface, *args, **kwargs) -> Inputs:
async def convert_from_native_to_inputs(
interface: NativeInterface, *args, input_context: Dict[str, str] | None = None, **kwargs
) -> Inputs:
kwargs = interface.convert_to_kwargs(*args, **kwargs)

missing = [key for key in interface.required_inputs() if key not in kwargs]
if missing:
raise ValueError(f"Missing required inputs: {', '.join(missing)}")

context_kvs = None
if input_context:
context_kvs = [literals_pb2.KeyValuePair(key=k, value=v) for k, v in input_context.items()]

if len(interface.inputs) == 0:
return Inputs.empty()
# Handle context even for empty inputs
return Inputs(proto_inputs=common_pb2.Inputs(context=context_kvs))

# fill in defaults if missing
type_hints: Dict[str, type] = {}
Expand Down Expand Up @@ -144,10 +156,12 @@ async def convert_from_native_to_inputs(interface: NativeInterface, *args, **kwa
for k, v in already_converted_kwargs.items():
copied_literals[k] = v
literal_map = literals_pb2.LiteralMap(literals=copied_literals)

# Make sure we the interface, not literal_map or kwargs, because those may have a different order
return Inputs(
proto_inputs=common_pb2.Inputs(
literals=[common_pb2.NamedLiteral(name=k, value=literal_map.literals[k]) for k in interface.inputs.keys()]
literals=[common_pb2.NamedLiteral(name=k, value=literal_map.literals[k]) for k in interface.inputs.keys()],
context=context_kvs,
)
)

Expand Down
38 changes: 29 additions & 9 deletions src/flyte/_internal/runtime/taskrunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import flyte.report
from flyte._context import internal_ctx
from flyte._input_context import _input_context_var
from flyte._internal.imagebuild.image_builder import ImageCache
from flyte._logging import log, logger
from flyte._task import TaskTemplate
Expand Down Expand Up @@ -129,6 +130,14 @@ async def convert_and_run(
in a context tree.
"""
ctx = internal_ctx()

# Load inputs first to get context
if input_path:
inputs = await load_inputs(input_path, path_rewrite_config=raw_data_path.path_rewrite)

# Extract context from inputs
input_context = inputs.context if inputs else {}

tctx = TaskContext(
action=action,
checkpoints=checkpoints,
Expand All @@ -142,16 +151,27 @@ async def convert_and_run(
report=flyte.report.Report(name=action.name),
mode="remote" if not ctx.data.task_context else ctx.data.task_context.mode,
interactive_mode=interactive_mode,
input_context=input_context,
)
with ctx.replace_task_context(tctx):
inputs = await load_inputs(input_path, path_rewrite_config=raw_data_path.path_rewrite) if input_path else inputs
inputs_kwargs = await convert_inputs_to_native(inputs, task.native_interface)
out, err = await run_task(tctx=tctx, controller=controller, task=task, inputs=inputs_kwargs)
if err is not None:
return None, convert_from_native_to_error(err)
if task.report:
await flyte.report.flush.aio()
return await convert_from_native_to_outputs(out, task.native_interface, task.name), None

# Set input context so child tasks can inherit it
# This initializes the context var with global context from with_runcontext
context_token = None
if input_context:
context_token = _input_context_var.set(input_context.copy())

try:
with ctx.replace_task_context(tctx):
inputs_kwargs = await convert_inputs_to_native(inputs, task.native_interface)
out, err = await run_task(tctx=tctx, controller=controller, task=task, inputs=inputs_kwargs)
if err is not None:
return None, convert_from_native_to_error(err)
if task.report:
await flyte.report.flush.aio()
return await convert_from_native_to_outputs(out, task.native_interface, task.name), None
finally:
if context_token is not None:
_input_context_var.reset(context_token)


async def extract_download_run_upload(
Expand Down
Loading
Loading