Skip to content

Commit a6303a4

Browse files
Fix: secure tar extraction issue (aws#5587)
* Fix: secure tar extraction issue
* Update unit test
1 parent 1210ac1 commit a6303a4

File tree

2 files changed

+10 -9 lines changed

2 files changed

+10 -9 lines changed

sagemaker-mlops/src/sagemaker/mlops/workflow/_repack_model.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -57,7 +57,8 @@ def _is_bad_path(path, base):
5757
bool: True if the path is not rooted under the base directory, False otherwise.
5858
"""
5959
# joinpath will ignore base if path is absolute
60-
return not _get_resolved_path(joinpath(base, path)).startswith(base)
60+
resolved = _get_resolved_path(joinpath(base, path))
61+
return os.path.commonpath([resolved, base]) != base
6162

6263

6364
def _is_bad_link(info, base):
@@ -77,19 +78,18 @@ def _is_bad_link(info, base):
7778
return _is_bad_path(info.linkname, base=tip)
7879

7980

80-
def _get_safe_members(members):
81+
def _get_safe_members(members, base):
8182
"""A generator that yields members that are safe to extract.
8283
8384
It filters out bad paths and bad links.
8485
8586
Args:
8687
members (list): A list of members to check.
88+
base (str): The base directory for extraction.
8789
8890
Yields:
8991
tarfile.TarInfo: The tar file info.
9092
"""
91-
base = _get_resolved_path("")
92-
9393
for file_info in members:
9494
if _is_bad_path(file_info.name, base):
9595
logger.error("%s is blocked (illegal path)", file_info.name)
@@ -120,7 +120,8 @@ def custom_extractall_tarfile(tar, extract_path):
120120
if hasattr(tarfile, "data_filter"):
121121
tar.extractall(path=extract_path, filter="data")
122122
else:
123-
tar.extractall(path=extract_path, members=_get_safe_members(tar))
123+
base = _get_resolved_path(extract_path)
124+
tar.extractall(path=extract_path, members=_get_safe_members(tar.getmembers(), base))
124125

125126

126127
def repack(inference_script, model_archive, source_dir=None): # pragma: no cover

sagemaker-mlops/tests/unit/workflow/test_repack_model.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -105,7 +105,7 @@ def test_get_safe_members_all_safe():
105105
mock_member2.islnk = Mock(return_value=False)
106106

107107
members = [mock_member1, mock_member2]
108-
safe_members = list(_get_safe_members(members))
108+
safe_members = list(_get_safe_members(members, "/tmp/extract"))
109109

110110
assert len(safe_members) == 2
111111
assert mock_member1 in safe_members
@@ -128,7 +128,7 @@ def test_get_safe_members_filters_bad_path():
128128
mock_is_bad.side_effect = lambda name, base: name == "/etc/passwd"
129129

130130
members = [mock_member_safe, mock_member_bad]
131-
safe_members = list(_get_safe_members(members))
131+
safe_members = list(_get_safe_members(members, "/tmp/extract"))
132132

133133
assert len(safe_members) == 1
134134
assert mock_member_safe in safe_members
@@ -152,7 +152,7 @@ def test_get_safe_members_filters_bad_symlink():
152152
mock_is_bad_link.return_value = True
153153

154154
members = [mock_member_safe, mock_member_symlink]
155-
safe_members = list(_get_safe_members(members))
155+
safe_members = list(_get_safe_members(members, "/tmp/extract"))
156156

157157
assert len(safe_members) == 1
158158
assert mock_member_safe in safe_members
@@ -176,7 +176,7 @@ def test_get_safe_members_filters_bad_hardlink():
176176
mock_is_bad_link.return_value = True
177177

178178
members = [mock_member_safe, mock_member_hardlink]
179-
safe_members = list(_get_safe_members(members))
179+
safe_members = list(_get_safe_members(members, "/tmp/extract"))
180180

181181
assert len(safe_members) == 1
182182
assert mock_member_safe in safe_members

0 commit comments

Comments (0)