Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,14 +810,15 @@ class DynamicEntityLaunchCommand(click.RichCommand):

LP_LAUNCHER = "lp"
TASK_LAUNCHER = "task"
WORKFLOW_LAUNCHER = "workflow"

def __init__(self, name: str, h: str, entity_name: str, launcher: str, **kwargs):
super().__init__(name=name, help=h, **kwargs)
self._entity_name = entity_name
self._launcher = launcher
self._entity = None

def _fetch_entity(self, ctx: click.Context) -> typing.Union[FlyteLaunchPlan, FlyteTask]:
def _fetch_entity(self, ctx: click.Context) -> typing.Union[FlyteLaunchPlan, FlyteTask, FlyteWorkflow]:
if self._entity:
return self._entity
run_level_params: RunLevelParams = ctx.obj
Expand All @@ -837,6 +838,12 @@ def _fetch_entity(self, ctx: click.Context) -> typing.Union[FlyteLaunchPlan, Fly
)
)
entity = r.fetch_launch_plan(run_level_params.project, run_level_params.domain, self._entity_name)
elif self._launcher == self.WORKFLOW_LAUNCHER:
parts = self._entity_name.split(":")
if len(parts) == 2:
entity = r.fetch_workflow(run_level_params.project, run_level_params.domain, parts[0], parts[1])
else:
entity = r.fetch_workflow(run_level_params.project, run_level_params.domain, self._entity_name)
else:
parts = self._entity_name.split(":")
if len(parts) == 2:
Expand Down Expand Up @@ -973,13 +980,20 @@ def list_commands(self, ctx):
return []

def get_command(self, ctx, name):
if self._command_name in [self.LAUNCHPLAN_COMMAND, self.WORKFLOW_COMMAND]:
if self._command_name == self.LAUNCHPLAN_COMMAND:
return DynamicEntityLaunchCommand(
name=name,
h=f"Execute a {self._command_name}.",
entity_name=name,
launcher=DynamicEntityLaunchCommand.LP_LAUNCHER,
)
elif self._command_name == self.WORKFLOW_COMMAND:
return DynamicEntityLaunchCommand(
name=name,
h=f"Execute a {self._command_name}.",
entity_name=name,
launcher=DynamicEntityLaunchCommand.WORKFLOW_LAUNCHER,
)
return DynamicEntityLaunchCommand(
name=name,
h=f"Execute a {self._command_name}.",
Expand Down
60 changes: 60 additions & 0 deletions tests/flytekit/unit/cli/pyflyte/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,3 +962,63 @@ def example_task(flag: bool) -> bool:
args, _ = mock_run_remote.call_args
inputs = args[4]['flag']
assert inputs == False


@mock.patch("flytekit.configuration.plugin.FlyteRemote", spec=FlyteRemote)
@mock.patch("flytekit.clis.sdk_in_container.run.run_remote")
def test_remote_workflow(mock_run_remote, mock_remote):
@task()
def example_task(x: int, y: str) -> str:
return f"{x},{y}"

@workflow
def example_workflow(x: int, y: str) -> str:
return example_task(x=x, y=y)

mock_remote_instance = mock.MagicMock()
mock_remote.return_value = mock_remote_instance
mock_remote_instance.fetch_workflow.return_value = example_workflow

runner = CliRunner()
result = runner.invoke(
pyflyte.main,
["run", "remote-workflow", "some_module.example_workflow", "--x", "42", "--y", "hello"],
catch_exceptions=False,
)

assert result.exit_code == 0
mock_remote_instance.fetch_workflow.assert_called_once()
mock_run_remote.assert_called_once()


@mock.patch("flytekit.configuration.plugin.FlyteRemote", spec=FlyteRemote)
@mock.patch("flytekit.clis.sdk_in_container.run.run_remote")
def test_remote_workflow_with_version(mock_run_remote, mock_remote):
@task()
def example_task(x: int) -> int:
return x * 2

@workflow
def example_workflow(x: int) -> int:
return example_task(x=x)

mock_remote_instance = mock.MagicMock()
mock_remote.return_value = mock_remote_instance
mock_remote_instance.fetch_workflow.return_value = example_workflow

runner = CliRunner()
result = runner.invoke(
pyflyte.main,
["run", "remote-workflow", "some_module.example_workflow:v1", "--x", "10"],
catch_exceptions=False,
)

assert result.exit_code == 0
# Verify fetch_workflow was called with the correct arguments (project, domain, name, version)
mock_remote_instance.fetch_workflow.assert_called_once()
call_args = mock_remote_instance.fetch_workflow.call_args[0]
# Should be called with 4 args when version is specified
assert len(call_args) == 4
assert call_args[2] == "some_module.example_workflow"
assert call_args[3] == "v1"
mock_run_remote.assert_called_once()
Loading