Skip to content

Commit b337274

Browse files
committed
Add remote workflow execution support to pyflyte run
This commit adds support for executing remote workflows via the `pyflyte run remote-workflow` command, similar to the existing `remote-task` and `remote-launchplan` commands. Changes include: - Add WORKFLOW_LAUNCHER constant to DynamicEntityLaunchCommand - Update _fetch_entity() to handle workflow launcher type - Add workflow version support (name:version syntax) - Update get_command() in RemoteEntityGroup to create workflow launcher - Add comprehensive unit tests for remote workflow execution
1 parent ff4c79c commit b337274

File tree

2 files changed

+76
-2
lines changed

2 files changed

+76
-2
lines changed

flytekit/clis/sdk_in_container/run.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -810,14 +810,15 @@ class DynamicEntityLaunchCommand(click.RichCommand):
810810

811811
LP_LAUNCHER = "lp"
812812
TASK_LAUNCHER = "task"
813+
WORKFLOW_LAUNCHER = "workflow"
813814

814815
def __init__(self, name: str, h: str, entity_name: str, launcher: str, **kwargs):
815816
super().__init__(name=name, help=h, **kwargs)
816817
self._entity_name = entity_name
817818
self._launcher = launcher
818819
self._entity = None
819820

820-
def _fetch_entity(self, ctx: click.Context) -> typing.Union[FlyteLaunchPlan, FlyteTask]:
821+
def _fetch_entity(self, ctx: click.Context) -> typing.Union[FlyteLaunchPlan, FlyteTask, FlyteWorkflow]:
821822
if self._entity:
822823
return self._entity
823824
run_level_params: RunLevelParams = ctx.obj
@@ -837,6 +838,12 @@ def _fetch_entity(self, ctx: click.Context) -> typing.Union[FlyteLaunchPlan, Fly
837838
)
838839
)
839840
entity = r.fetch_launch_plan(run_level_params.project, run_level_params.domain, self._entity_name)
841+
elif self._launcher == self.WORKFLOW_LAUNCHER:
842+
parts = self._entity_name.split(":")
843+
if len(parts) == 2:
844+
entity = r.fetch_workflow(run_level_params.project, run_level_params.domain, parts[0], parts[1])
845+
else:
846+
entity = r.fetch_workflow(run_level_params.project, run_level_params.domain, self._entity_name)
840847
else:
841848
parts = self._entity_name.split(":")
842849
if len(parts) == 2:
@@ -973,13 +980,20 @@ def list_commands(self, ctx):
973980
return []
974981

975982
def get_command(self, ctx, name):
976-
if self._command_name in [self.LAUNCHPLAN_COMMAND, self.WORKFLOW_COMMAND]:
983+
if self._command_name == self.LAUNCHPLAN_COMMAND:
977984
return DynamicEntityLaunchCommand(
978985
name=name,
979986
h=f"Execute a {self._command_name}.",
980987
entity_name=name,
981988
launcher=DynamicEntityLaunchCommand.LP_LAUNCHER,
982989
)
990+
elif self._command_name == self.WORKFLOW_COMMAND:
991+
return DynamicEntityLaunchCommand(
992+
name=name,
993+
h=f"Execute a {self._command_name}.",
994+
entity_name=name,
995+
launcher=DynamicEntityLaunchCommand.WORKFLOW_LAUNCHER,
996+
)
983997
return DynamicEntityLaunchCommand(
984998
name=name,
985999
h=f"Execute a {self._command_name}.",

tests/flytekit/unit/cli/pyflyte/test_run.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -962,3 +962,63 @@ def example_task(flag: bool) -> bool:
962962
args, _ = mock_run_remote.call_args
963963
inputs = args[4]['flag']
964964
assert inputs == False
965+
966+
967+
@mock.patch("flytekit.configuration.plugin.FlyteRemote", spec=FlyteRemote)
968+
@mock.patch("flytekit.clis.sdk_in_container.run.run_remote")
969+
def test_remote_workflow(mock_run_remote, mock_remote):
970+
@task()
971+
def example_task(x: int, y: str) -> str:
972+
return f"{x},{y}"
973+
974+
@workflow
975+
def example_workflow(x: int, y: str) -> str:
976+
return example_task(x=x, y=y)
977+
978+
mock_remote_instance = mock.MagicMock()
979+
mock_remote.return_value = mock_remote_instance
980+
mock_remote_instance.fetch_workflow.return_value = example_workflow
981+
982+
runner = CliRunner()
983+
result = runner.invoke(
984+
pyflyte.main,
985+
["run", "remote-workflow", "some_module.example_workflow", "--x", "42", "--y", "hello"],
986+
catch_exceptions=False,
987+
)
988+
989+
assert result.exit_code == 0
990+
mock_remote_instance.fetch_workflow.assert_called_once()
991+
mock_run_remote.assert_called_once()
992+
993+
994+
@mock.patch("flytekit.configuration.plugin.FlyteRemote", spec=FlyteRemote)
995+
@mock.patch("flytekit.clis.sdk_in_container.run.run_remote")
996+
def test_remote_workflow_with_version(mock_run_remote, mock_remote):
997+
@task()
998+
def example_task(x: int) -> int:
999+
return x * 2
1000+
1001+
@workflow
1002+
def example_workflow(x: int) -> int:
1003+
return example_task(x=x)
1004+
1005+
mock_remote_instance = mock.MagicMock()
1006+
mock_remote.return_value = mock_remote_instance
1007+
mock_remote_instance.fetch_workflow.return_value = example_workflow
1008+
1009+
runner = CliRunner()
1010+
result = runner.invoke(
1011+
pyflyte.main,
1012+
["run", "remote-workflow", "some_module.example_workflow:v1", "--x", "10"],
1013+
catch_exceptions=False,
1014+
)
1015+
1016+
assert result.exit_code == 0
1017+
# Verify fetch_workflow was called with the correct arguments (project, domain, name, version)
1018+
mock_remote_instance.fetch_workflow.assert_called_once()
1019+
call_args = mock_remote_instance.fetch_workflow.call_args[0]
1020+
# Should be called with 4 args when version is specified
1021+
assert len(call_args) == 4
1022+
assert call_args[2] == "some_module.example_workflow"
1023+
assert call_args[3] == "v1"
1024+
mock_run_remote.assert_called_once()

0 commit comments

Comments
 (0)