Skip to content

Commit

Permalink
Sagemaker Operator Character limit fix (apache#45551)
Browse files Browse the repository at this point in the history
Co-authored-by: Dirk Kotze <[email protected]>
Co-authored-by: Rudolf Luttich <[email protected]>
Co-authored-by: Rudolf07688 <[email protected]>
  • Loading branch information
4 people authored and HariGS-DB committed Jan 16, 2025
1 parent 8f93b41 commit 0708d81
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,11 @@ def _get_unique_name(
if fail_if_exists:
raise AirflowException(f"A SageMaker {resource_type} with name {name} already exists.")
else:
name = f"{proposed_name}-{time.time_ns()//1000000}"
max_name_len = 63
timestamp = str(
time.time_ns() // 1000000000
) # only keep the relevant datetime (first 10 digits)
name = f"{proposed_name[:max_name_len - len(timestamp) - 1]}-{timestamp}" # we subtract one to make provision for the dash between the truncated name and timestamp
self.log.info("Changed %s name to '%s' to avoid collision.", resource_type, name)
return name

Expand Down
31 changes: 29 additions & 2 deletions providers/tests/amazon/aws/operators/test_sagemaker_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@

EXPECTED_INTEGER_FIELDS: list[list[Any]] = []

MOCK_UNIX_TIME = 1234567890123456789 # reproducible time for testing time.time_ns()


class TestSageMakerBaseOperator:
ERROR_WHEN_RESOURCE_NOT_FOUND = ClientError({"Error": {"Code": "ValidationException"}}, "op")
Expand Down Expand Up @@ -92,6 +94,31 @@ def test_job_renamed(self):
assert describe_mock.call_count == 3
assert re.match("test-[0-9]+$", name)

@patch("airflow.providers.amazon.aws.operators.sagemaker.time.time_ns", return_value=MOCK_UNIX_TIME)
def test_job_name_length(self, _):
describe_mock = MagicMock()
# scenario: The name is longer than 63 characters so we need the function to truncate the name and add a timestamp
describe_mock.side_effect = [None, None, self.ERROR_WHEN_RESOURCE_NOT_FOUND]
name = self.sagemaker._get_unique_job_name(
"ThisNameIsLongerThan64CharactersSoItShouldBeTruncatedWithATimestamp", False, describe_mock
)
assert len(name) <= 63

@patch("airflow.providers.amazon.aws.operators.sagemaker.time.time_ns", return_value=MOCK_UNIX_TIME)
def test_truncated_job_name(self, _):
describe_mock = MagicMock()

describe_mock.side_effect = [None, None, self.ERROR_WHEN_RESOURCE_NOT_FOUND]

# scenario: The name is longer than 63 characters so we need the function to truncate the name and add a timestamp
full_name = "ThisNameIsLongerThan64CharactersSoItShouldBeTruncatedWithATimestamp"

name = self.sagemaker._get_unique_job_name(full_name, False, describe_mock)

base_name, timestamp = name.split("-")
assert base_name == full_name[: len(base_name)]
assert timestamp == str(MOCK_UNIX_TIME)[:10]

def test_job_not_unique_with_fail(self):
with pytest.raises(AirflowException):
self.sagemaker._get_unique_job_name("test", True, lambda _: None)
Expand All @@ -116,7 +143,7 @@ def test_get_unique_name_raises_exception_if_name_exists_when_fail_is_true(self)

assert str(context.value) == "A SageMaker model with name existing_name already exists."

@patch("airflow.providers.amazon.aws.operators.sagemaker.time.time_ns", return_value=3000000)
@patch("airflow.providers.amazon.aws.operators.sagemaker.time.time_ns", return_value=MOCK_UNIX_TIME)
def test_get_unique_name_avoids_name_collision(self, time_mock):
new_name = self.sagemaker._get_unique_name(
"existing_name",
Expand All @@ -126,7 +153,7 @@ def test_get_unique_name_avoids_name_collision(self, time_mock):
resource_type="model",
)

assert new_name == "existing_name-3"
assert new_name == "existing_name-1234567890"

def test_get_unique_name_checks_only_once_when_resource_does_not_exist(self):
describe_func = MagicMock(side_effect=ClientError({"Error": {"Code": "ValidationException"}}, "op"))
Expand Down
16 changes: 10 additions & 6 deletions providers/tests/amazon/aws/operators/test_sagemaker_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@

CONFIG: dict = {"Model": CREATE_MODEL_PARAMS, "Transform": CREATE_TRANSFORM_PARAMS}

MOCK_UNIX_TIME: int = 1234567890123456789 # reproducible time for testing time.time_ns()


class TestSageMakerTransformOperator:
def setup_method(self):
Expand Down Expand Up @@ -182,10 +184,12 @@ def test_execute_without_check_if_job_exists(self, _, __, ___, mock_transform, _
max_ingestion_time=None,
)

@mock.patch( # since it is divided by 1000000, the added timestamp should be 2.
"airflow.providers.amazon.aws.operators.sagemaker.time.time_ns", return_value=2000000
@mock.patch( # since it is divided by 1000000000, the added timestamp should be 1234567890.
"airflow.providers.amazon.aws.operators.sagemaker.time.time_ns", return_value=MOCK_UNIX_TIME
)
@mock.patch.object(
SageMakerHook, "describe_transform_job", return_value={"ModelName": "model_name-1234567890"}
)
@mock.patch.object(SageMakerHook, "describe_transform_job", return_value={"ModelName": "model_name-2"})
@mock.patch.object(
SageMakerHook,
"create_transform_job",
Expand All @@ -200,7 +204,7 @@ def test_execute_without_check_if_job_exists(self, _, __, ___, mock_transform, _
side_effect=[
None,
ClientError({"Error": {"Code": "ValidationException"}}, "op"),
"model_name-2",
"model_name-1234567890",
],
)
@mock.patch.object(sagemaker, "serialize", return_value="")
Expand All @@ -215,9 +219,9 @@ def test_when_model_already_exists_it_should_add_timestamp_to_model_name(
self.sagemaker.execute(None)

mock_describe_model.assert_has_calls(
[mock.call("model_name"), mock.call("model_name-2"), mock.call("model_name-2")]
[mock.call("model_name"), mock.call("model_name-1234567890"), mock.call("model_name-1234567890")]
)
mock_create_model.assert_called_once_with({"ModelName": "model_name-2"})
mock_create_model.assert_called_once_with({"ModelName": "model_name-1234567890"})

@mock.patch.object(SageMakerHook, "describe_transform_job")
@mock.patch.object(SageMakerHook, "create_transform_job")
Expand Down

0 comments on commit 0708d81

Please sign in to comment.