diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 0f9f3407fc..5a8f94ef83 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -471,6 +471,8 @@ def setup_execution( checkpointer = SyncCheckpoint(checkpoint_dest=checkpoint_path, checkpoint_src=prev_checkpoint) logger.debug(f"Checkpointer created with source {prev_checkpoint} and dest {checkpoint_path}") + node_index = _compute_array_job_index() + execution_parameters = ExecutionParameters( execution_id=_identifier.WorkflowExecutionIdentifier( project=exe_project, @@ -498,6 +500,7 @@ def setup_execution( output_metadata_prefix=output_metadata_prefix, checkpoint=checkpointer, task_id=_identifier.Identifier(_identifier.ResourceType.TASK, tk_project, tk_domain, tk_name, tk_version), + node_index=node_index, ) metadata = { diff --git a/tests/flytekit/unit/bin/test_python_entrypoint.py b/tests/flytekit/unit/bin/test_python_entrypoint.py index 70e8ea5b4e..a902920dee 100644 --- a/tests/flytekit/unit/bin/test_python_entrypoint.py +++ b/tests/flytekit/unit/bin/test_python_entrypoint.py @@ -37,6 +37,7 @@ from flytekit.models import literals as _literal_models from flytekit.models.core import errors as error_models, execution from flytekit.models.core import execution as execution_models +from flytekit.bin.entrypoint import _compute_array_job_index from flytekit.core.utils import write_proto_to_file from flytekit.models.types import LiteralType, SimpleType @@ -406,6 +407,16 @@ def test_dispatch_execute_system_error(mock_write_to_file, mock_upload_dir, mock assert "some system exception" in ed.error.message assert ed.error.origin == execution_models.ExecutionError.ErrorKind.SYSTEM +@pytest.fixture +def flyte_context(): + """Fixture to set up a mock Flyte context.""" + with mock.patch.object(context_manager.FlyteContext, 'current_context', return_value=mock.Mock()): + yield + +def test_compute_array_job_index(flyte_context): + assert _compute_array_job_index() == 0 + assert _compute_array_job_index(index=1) == 1 + assert _compute_array_job_index(index=2) == 2 def test_setup_disk_prefix(): with setup_execution("qwerty") as ctx: