Skip to content
Closed
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 33 additions & 12 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
__TEST_DATA_FILENAME = "test_data.tar.gz"
__TEST_DATA_URL = "https://github.com/NVIDIA/NeMo/releases/download/v1.0.0rc1/"
__TEST_DATA_SUBDIR = ".data"
__TEST_DATA_CACHE = "/home/TestData/ci/v1.0.0rc1/test_data.tar.gz"


def pytest_addoption(parser):
Expand All @@ -49,6 +50,11 @@
action='store_true',
help="pass that argument to use local test data/skip downloading from URL/GitHub (DEFAULT: False)",
)
parser.addoption(
"--no-test-data-cache",
action='store_true',
help='pass this if you want to avoid using the cached test_data file',
)
parser.addoption(
'--with_downloads',
action='store_true',
Expand Down Expand Up @@ -205,6 +211,15 @@
return False, "k2 needs CUDA to be available in torch."


def get_size_with_fallback(path):
# Get size of local test_data archive.
try:
return getsize(path)
except:
# File does not exist.
return -1


def pytest_configure(config):
"""
Initial configuration of conftest.
Expand All @@ -224,18 +239,23 @@
"markers",
"nightly: runs the nightly test for QA.",
)
skip_cache = config.getoption("--no-test-data-cache", default=False)

# Test dir and archive filepath.
test_dir = join(dirname(__file__), __TEST_DATA_SUBDIR)
test_data_archive = join(dirname(__file__), __TEST_DATA_SUBDIR, __TEST_DATA_FILENAME)

# Get size of local test_data archive.
try:
test_data_local_size = getsize(test_data_archive)
except:
# File does not exist.
test_data_local_size = -1
# if the user does not have the /home/TestData, don't fail and download the file as usual.
if not skip_cache and not os.path.exists(__TEST_DATA_CACHE):
skip_cache = True

if config.option.use_local_test_data:
if not skip_cache:
rmtree(test_dir)
mkdir(test_dir)
with tarfile.open(__TEST_DATA_CACHE) as tar:
tar.extractall(path=test_dir)
elif config.option.use_local_test_data:
test_data_local_size = get_size_with_fallback(test_data_archive)
if test_data_local_size == -1:
pytest.exit("Test data `{}` is not present in the system".format(test_data_archive))
else:
Expand All @@ -244,10 +264,13 @@
__TEST_DATA_FILENAME, test_data_local_size, test_dir
)
)
# untar local test data
extract_data_from_tar(test_dir, test_data_archive, local_data=True)

# Get size of remote test_data archive.
url = None
if not config.option.use_local_test_data:
elif not config.option.use_local_test_data:
test_data_local_size = get_size_with_fallback(test_data_archive)
url = None
try:
url = __TEST_DATA_URL + __TEST_DATA_FILENAME
u = urllib.request.urlopen(url)
Expand Down Expand Up @@ -283,10 +306,8 @@
__TEST_DATA_FILENAME, test_data_local_size, test_dir
)
)

else:
# untar local test data
extract_data_from_tar(test_dir, test_data_archive, local_data=config.option.use_local_test_data)
raise RuntimeError()

if config.option.relax_numba_compat is not None:
from nemo.core.utils import numba_utils
Expand Down
Loading