Skip to content

Commit 27edbcd

Browse files
Use filter arg to safely extract archives (#1862)
* use filter to safely extract archives * update release notes * update actions * add tests * fix action * final test
1 parent 762e08f commit 27edbcd

File tree

7 files changed

+97
-8
lines changed

7 files changed

+97
-8
lines changed

.github/workflows/pull_request_check.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ jobs:
77
name: pull request check
88
runs-on: ubuntu-latest
99
steps:
10-
- uses: nearform/github-action-check-linked-issues@v1
10+
- uses: nearform-actions/github-action-check-linked-issues@v1
1111
id: check-linked-issues
1212
with:
1313
exclude-branches: "release_v**, backport_v**, main, latest-dep-update-**, min-dep-update-**, dependabot/**"

.github/workflows/release_notes_updated.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ jobs:
1212
- name: Check for development branch
1313
id: branch
1414
shell: python
15+
env:
16+
REF: ${{ github.event.pull_request.head.ref }}
1517
run: |
1618
from re import compile
1719
main = '^main$'
@@ -21,7 +23,7 @@ jobs:
2123
min_dep_update = '^min-dep-update-[a-f0-9]{7}$'
2224
regex = main, release, backport, dep_update, min_dep_update
2325
patterns = list(map(compile, regex))
24-
ref = "${{ github.event.pull_request.head.ref }}"
26+
ref = "$REF"
2527
is_dev = not any(pattern.match(ref) for pattern in patterns)
2628
print('::set-output name=is_dev::' + str(is_dev))
2729
- if: ${{ steps.branch.outputs.is_dev == 'True' }}

docs/source/release_notes.rst

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@ Release Notes
66
Future Release
77
==============
88
* Enhancements
9+
* Add support for Python 3.12 :pr:`1855`
910
* Fixes
1011
* Changes
11-
* Add support for Python 3.12 :pr:`1855`
12-
* Drop support for using Woodwork with Dask or Pyspark dataframes (:pr:`1857`)
12+
* Drop support for using Woodwork with Dask or Pyspark dataframes :pr:`1857`
13+
* Use ``filter`` arg in call to ``tarfile.extractall`` to safely deserialize DataFrames :pr:`1862`
1314
* Documentation Changes
1415
* Testing Changes
1516

woodwork/deserializers/deserializer_base.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import tarfile
33
import tempfile
44
import warnings
5+
from inspect import getfullargspec
56
from itertools import zip_longest
67
from pathlib import Path
78

@@ -125,7 +126,12 @@ def read_from_s3(self, profile_name):
125126

126127
use_smartopen(tar_filepath, self.path, transport_params)
127128
with tarfile.open(str(tar_filepath)) as tar:
128-
tar.extractall(path=tmpdir)
129+
if "filter" in getfullargspec(tar.extractall).kwonlyargs:
130+
tar.extractall(path=tmpdir, filter="data")
131+
else:
132+
raise RuntimeError(
133+
"Please upgrade your Python version to the latest patch release to allow for safe extraction of the EntitySet archive.",
134+
)
129135
self.read_path = os.path.join(
130136
tmpdir,
131137
self.typing_info["loading_info"]["location"],

woodwork/deserializers/parquet_deserializer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import tarfile
44
import tempfile
5+
from inspect import getfullargspec
56
from pathlib import Path
67

78
import pandas as pd
@@ -61,7 +62,12 @@ def read_from_s3(self, profile_name):
6162

6263
use_smartopen(tar_filepath, self.path, transport_params)
6364
with tarfile.open(str(tar_filepath)) as tar:
64-
tar.extractall(path=tmpdir)
65+
if "filter" in getfullargspec(tar.extractall).kwonlyargs:
66+
tar.extractall(path=tmpdir, filter="data")
67+
else:
68+
raise RuntimeError(
69+
"Please upgrade your Python version to the latest patch release to allow for safe extraction of the EntitySet archive.",
70+
)
6571

6672
self.read_path = os.path.join(tmpdir, self.data_subdirectory, self.filename)
6773

woodwork/deserializers/utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import tarfile
44
import tempfile
5+
from inspect import getfullargspec
56
from pathlib import Path
67

78
from woodwork.deserializers import (
@@ -99,7 +100,12 @@ def read_table_typing_information(path, typing_info_filename, profile_name):
99100

100101
use_smartopen(file_path, path, transport_params)
101102
with tarfile.open(str(file_path)) as tar:
102-
tar.extractall(path=tmpdir)
103+
if "filter" in getfullargspec(tar.extractall).kwonlyargs:
104+
tar.extractall(path=tmpdir, filter="data")
105+
else:
106+
raise RuntimeError(
107+
"Please upgrade your Python version to the latest patch release to allow for safe extraction of the EntitySet archive.",
108+
)
103109

104110
file = os.path.join(tmpdir, typing_info_filename)
105111
with open(file, "r") as file:

woodwork/tests/accessor/test_serialization.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
import shutil
44
import warnings
5-
from unittest.mock import patch
5+
from unittest.mock import MagicMock, patch
66

77
import boto3
88
import pandas as pd
@@ -662,6 +662,35 @@ def test_to_csv_S3(sample_df, s3_client, s3_bucket, profile_name):
662662
assert sample_df.ww.schema == deserialized_df.ww.schema
663663

664664

665+
@patch("woodwork.deserializers.utils.getfullargspec")
666+
def test_to_csv_S3_errors_if_python_version_unsafe(
667+
mock_inspect,
668+
sample_df,
669+
s3_client,
670+
s3_bucket,
671+
):
672+
mock_response = MagicMock()
673+
mock_response.kwonlyargs = []
674+
mock_inspect.return_value = mock_response
675+
sample_df.ww.init(
676+
name="test_data",
677+
index="id",
678+
semantic_tags={"id": "tag1"},
679+
logical_types={"age": Ordinal(order=[25, 33, 57])},
680+
)
681+
sample_df.ww.to_disk(
682+
TEST_S3_URL,
683+
format="csv",
684+
encoding="utf-8",
685+
engine="python",
686+
profile_name=None,
687+
)
688+
make_public(s3_client, s3_bucket)
689+
690+
with pytest.raises(RuntimeError, match="Please upgrade your Python version"):
691+
read_woodwork_table(TEST_S3_URL, profile_name=None)
692+
693+
665694
@pytest.mark.parametrize("profile_name", [None, False])
666695
def test_serialize_s3_pickle(sample_df, s3_client, s3_bucket, profile_name):
667696
sample_df.ww.init()
@@ -673,6 +702,23 @@ def test_serialize_s3_pickle(sample_df, s3_client, s3_bucket, profile_name):
673702
assert sample_df.ww.schema == deserialized_df.ww.schema
674703

675704

705+
@patch("woodwork.deserializers.deserializer_base.getfullargspec")
706+
def test_serialize_s3_pickle_errors_if_python_version_unsafe(
707+
mock_inspect,
708+
sample_df,
709+
s3_client,
710+
s3_bucket,
711+
):
712+
mock_response = MagicMock()
713+
mock_response.kwonlyargs = []
714+
mock_inspect.return_value = mock_response
715+
sample_df.ww.init()
716+
sample_df.ww.to_disk(TEST_S3_URL, format="pickle", profile_name=None)
717+
make_public(s3_client, s3_bucket)
718+
with pytest.raises(RuntimeError, match="Please upgrade your Python version"):
719+
read_woodwork_table(TEST_S3_URL, profile_name=None)
720+
721+
676722
@pytest.mark.parametrize("profile_name", [None, False])
677723
def test_serialize_s3_parquet(sample_df, s3_client, s3_bucket, profile_name):
678724
sample_df.ww.init()
@@ -688,6 +734,28 @@ def test_serialize_s3_parquet(sample_df, s3_client, s3_bucket, profile_name):
688734
assert sample_df.ww.schema == deserialized_df.ww.schema
689735

690736

737+
@patch("woodwork.deserializers.parquet_deserializer.getfullargspec")
738+
def test_serialize_s3_parquet_errors_if_python_version_unsafe(
739+
mock_inspect,
740+
sample_df,
741+
s3_client,
742+
s3_bucket,
743+
):
744+
mock_response = MagicMock()
745+
mock_response.kwonlyargs = []
746+
mock_inspect.return_value = mock_response
747+
sample_df.ww.init()
748+
sample_df.ww.to_disk(TEST_S3_URL, format="parquet", profile_name=None)
749+
make_public(s3_client, s3_bucket)
750+
751+
with pytest.raises(RuntimeError, match="Please upgrade your Python version"):
752+
read_woodwork_table(
753+
TEST_S3_URL,
754+
filename="data.parquet",
755+
profile_name=None,
756+
)
757+
758+
691759
def create_test_credentials(test_path):
692760
with open(test_path, "w+") as f:
693761
f.write("[test]\n")

0 commit comments

Comments
 (0)