From 0708d8158275d2d008adae953aa46ae40f23ba9c Mon Sep 17 00:00:00 2001 From: dirkrkotzeml Date: Wed, 15 Jan 2025 19:47:27 +0200 Subject: [PATCH] Sagemaker Operator Character limit fix (#45551) Co-authored-by: Dirk Kotze Co-authored-by: Rudolf Luttich Co-authored-by: Rudolf07688 <43000341+Rudolf07688@users.noreply.github.com> --- .../amazon/aws/operators/sagemaker.py | 6 +++- .../aws/operators/test_sagemaker_base.py | 31 +++++++++++++++++-- .../aws/operators/test_sagemaker_transform.py | 16 ++++++---- 3 files changed, 44 insertions(+), 9 deletions(-) diff --git a/providers/src/airflow/providers/amazon/aws/operators/sagemaker.py b/providers/src/airflow/providers/amazon/aws/operators/sagemaker.py index 3559f9fdf13aa..76432ae7f3bda 100644 --- a/providers/src/airflow/providers/amazon/aws/operators/sagemaker.py +++ b/providers/src/airflow/providers/amazon/aws/operators/sagemaker.py @@ -165,7 +165,11 @@ def _get_unique_name( if fail_if_exists: raise AirflowException(f"A SageMaker {resource_type} with name {name} already exists.") else: - name = f"{proposed_name}-{time.time_ns()//1000000}" + max_name_len = 63 + timestamp = str( + time.time_ns() // 1000000000 + ) # only keep the relevant datetime (first 10 digits) + name = f"{proposed_name[:max_name_len - len(timestamp) - 1]}-{timestamp}" # we subtract one to make provision for the dash between the truncated name and timestamp self.log.info("Changed %s name to '%s' to avoid collision.", resource_type, name) return name diff --git a/providers/tests/amazon/aws/operators/test_sagemaker_base.py b/providers/tests/amazon/aws/operators/test_sagemaker_base.py index f44e49ffc290f..3ae0c2d95ed9e 100644 --- a/providers/tests/amazon/aws/operators/test_sagemaker_base.py +++ b/providers/tests/amazon/aws/operators/test_sagemaker_base.py @@ -49,6 +49,8 @@ EXPECTED_INTEGER_FIELDS: list[list[Any]] = [] +MOCK_UNIX_TIME = 1234567890123456789 # reproducible time for testing time.time_ns() + class TestSageMakerBaseOperator: ERROR_WHEN_RESOURCE_NOT_FOUND = ClientError({"Error": {"Code": "ValidationException"}}, "op") @@ -92,6 +94,31 @@ def test_job_renamed(self): assert describe_mock.call_count == 3 assert re.match("test-[0-9]+$", name) + @patch("airflow.providers.amazon.aws.operators.sagemaker.time.time_ns", return_value=MOCK_UNIX_TIME) + def test_job_name_length(self, _): + describe_mock = MagicMock() + # scenario: The name is longer than 63 characters so we need the function to truncate the name and add a timestamp + describe_mock.side_effect = [None, None, self.ERROR_WHEN_RESOURCE_NOT_FOUND] + name = self.sagemaker._get_unique_job_name( + "ThisNameIsLongerThan64CharactersSoItShouldBeTruncatedWithATimestamp", False, describe_mock + ) + assert len(name) <= 63 + + @patch("airflow.providers.amazon.aws.operators.sagemaker.time.time_ns", return_value=MOCK_UNIX_TIME) + def test_truncated_job_name(self, _): + describe_mock = MagicMock() + + describe_mock.side_effect = [None, None, self.ERROR_WHEN_RESOURCE_NOT_FOUND] + + # scenario: The name is longer than 63 characters so we need the function to truncate the name and add a timestamp + full_name = "ThisNameIsLongerThan64CharactersSoItShouldBeTruncatedWithATimestamp" + + name = self.sagemaker._get_unique_job_name(full_name, False, describe_mock) + + base_name, timestamp = name.split("-") + assert base_name == full_name[: len(base_name)] + assert timestamp == str(MOCK_UNIX_TIME)[:10] + def test_job_not_unique_with_fail(self): with pytest.raises(AirflowException): self.sagemaker._get_unique_job_name("test", True, lambda _: None) @@ -116,7 +143,7 @@ def test_get_unique_name_raises_exception_if_name_exists_when_fail_is_true(self) assert str(context.value) == "A SageMaker model with name existing_name already exists." - @patch("airflow.providers.amazon.aws.operators.sagemaker.time.time_ns", return_value=3000000) + @patch("airflow.providers.amazon.aws.operators.sagemaker.time.time_ns", return_value=MOCK_UNIX_TIME) def test_get_unique_name_avoids_name_collision(self, time_mock): new_name = self.sagemaker._get_unique_name( "existing_name", @@ -126,7 +153,7 @@ def test_get_unique_name_avoids_name_collision(self, time_mock): resource_type="model", ) - assert new_name == "existing_name-3" + assert new_name == "existing_name-1234567890" def test_get_unique_name_checks_only_once_when_resource_does_not_exist(self): describe_func = MagicMock(side_effect=ClientError({"Error": {"Code": "ValidationException"}}, "op")) diff --git a/providers/tests/amazon/aws/operators/test_sagemaker_transform.py b/providers/tests/amazon/aws/operators/test_sagemaker_transform.py index 2452558ec4228..9f3bac20fb479 100644 --- a/providers/tests/amazon/aws/operators/test_sagemaker_transform.py +++ b/providers/tests/amazon/aws/operators/test_sagemaker_transform.py @@ -68,6 +68,8 @@ CONFIG: dict = {"Model": CREATE_MODEL_PARAMS, "Transform": CREATE_TRANSFORM_PARAMS} +MOCK_UNIX_TIME: int = 1234567890123456789 # reproducible time for testing time.time_ns() + class TestSageMakerTransformOperator: def setup_method(self): @@ -182,10 +184,12 @@ def test_execute_without_check_if_job_exists(self, _, __, ___, mock_transform, _ max_ingestion_time=None, ) - @mock.patch( # since it is divided by 1000000, the added timestamp should be 2. - "airflow.providers.amazon.aws.operators.sagemaker.time.time_ns", return_value=2000000 + @mock.patch( # since it is divided by 1000000000, the added timestamp should be 1234567890. + "airflow.providers.amazon.aws.operators.sagemaker.time.time_ns", return_value=MOCK_UNIX_TIME + ) + @mock.patch.object( + SageMakerHook, "describe_transform_job", return_value={"ModelName": "model_name-1234567890"} ) - @mock.patch.object(SageMakerHook, "describe_transform_job", return_value={"ModelName": "model_name-2"}) @mock.patch.object( SageMakerHook, "create_transform_job", @@ -200,7 +204,7 @@ def test_execute_without_check_if_job_exists(self, _, __, ___, mock_transform, _ side_effect=[ None, ClientError({"Error": {"Code": "ValidationException"}}, "op"), - "model_name-2", + "model_name-1234567890", ], ) @mock.patch.object(sagemaker, "serialize", return_value="") @@ -215,9 +219,9 @@ def test_when_model_already_exists_it_should_add_timestamp_to_model_name( self.sagemaker.execute(None) mock_describe_model.assert_has_calls( - [mock.call("model_name"), mock.call("model_name-2"), mock.call("model_name-2")] + [mock.call("model_name"), mock.call("model_name-1234567890"), mock.call("model_name-1234567890")] ) - mock_create_model.assert_called_once_with({"ModelName": "model_name-2"}) + mock_create_model.assert_called_once_with({"ModelName": "model_name-1234567890"}) @mock.patch.object(SageMakerHook, "describe_transform_job") @mock.patch.object(SageMakerHook, "create_transform_job")