From d4fa04aebb99d90d911d71072c1cbc100b3427cf Mon Sep 17 00:00:00 2001 From: Atharva Kulkarni Date: Sun, 13 Oct 2024 17:33:49 +0530 Subject: [PATCH 1/2] added node index to current_context Signed-off-by: Atharva --- flytekit/bin/entrypoint.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 0f9f3407fc..5ff290834b 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -470,7 +470,9 @@ def setup_execution( if checkpoint_path is not None: checkpointer = SyncCheckpoint(checkpoint_dest=checkpoint_path, checkpoint_src=prev_checkpoint) logger.debug(f"Checkpointer created with source {prev_checkpoint} and dest {checkpoint_path}") - + + node_index = _compute_array_job_index() + execution_parameters = ExecutionParameters( execution_id=_identifier.WorkflowExecutionIdentifier( project=exe_project, @@ -498,6 +500,7 @@ def setup_execution( output_metadata_prefix=output_metadata_prefix, checkpoint=checkpointer, task_id=_identifier.Identifier(_identifier.ResourceType.TASK, tk_project, tk_domain, tk_name, tk_version), + node_index=node_index, ) metadata = { From 6f06ad879580832d3cc7029b51db79c70c54c5a1 Mon Sep 17 00:00:00 2001 From: Atharva Date: Sun, 20 Oct 2024 00:44:52 +0530 Subject: [PATCH 2/2] Fixed formatting issues and added unit test Signed-off-by: Atharva --- flytekit/bin/entrypoint.py | 4 ++-- tests/flytekit/unit/bin/test_python_entrypoint.py | 11 +++++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 5ff290834b..5a8f94ef83 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -470,9 +470,9 @@ def setup_execution( if checkpoint_path is not None: checkpointer = SyncCheckpoint(checkpoint_dest=checkpoint_path, checkpoint_src=prev_checkpoint) logger.debug(f"Checkpointer created with source {prev_checkpoint} and dest {checkpoint_path}") - + node_index = _compute_array_job_index() - + execution_parameters = ExecutionParameters( execution_id=_identifier.WorkflowExecutionIdentifier( project=exe_project, diff --git a/tests/flytekit/unit/bin/test_python_entrypoint.py b/tests/flytekit/unit/bin/test_python_entrypoint.py index 70e8ea5b4e..a902920dee 100644 --- a/tests/flytekit/unit/bin/test_python_entrypoint.py +++ b/tests/flytekit/unit/bin/test_python_entrypoint.py @@ -37,6 +37,7 @@ from flytekit.models import literals as _literal_models from flytekit.models.core import errors as error_models, execution from flytekit.models.core import execution as execution_models +from flytekit.bin.entrypoint import _compute_array_job_index from flytekit.core.utils import write_proto_to_file from flytekit.models.types import LiteralType, SimpleType @@ -406,6 +407,16 @@ def test_dispatch_execute_system_error(mock_write_to_file, mock_upload_dir, mock assert "some system exception" in ed.error.message assert ed.error.origin == execution_models.ExecutionError.ErrorKind.SYSTEM +@pytest.fixture +def flyte_context(): + """Fixture to set up a mock Flyte context.""" + with mock.patch.object(context_manager.FlyteContext, 'current_context', return_value=mock.Mock()): + yield + +def test_compute_array_job_index(flyte_context): + assert _compute_array_job_index() == 0 + assert _compute_array_job_index(index=1) == 1 + assert _compute_array_job_index(index=2) == 2 def test_setup_disk_prefix(): with setup_execution("qwerty") as ctx: