Skip to content

Commit 28e4e0a

Browse files
committed
Use more robust checksum based checks to verify test file downloads
1 parent 12f41e0 commit 28e4e0a

File tree

1 file changed

+50
-6
lines changed

1 file changed

+50
-6
lines changed

scripts/reg_test.py

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import math
1111
import subprocess
1212
import shutil
13+
import hashlib
1314
TATUM_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
1415

1516

@@ -165,13 +166,18 @@ def download_extract_test(args, test_name, test_url):
165166
#A tar file of benchmark files
166167
benchmark_tar = os.path.join(os.path.join(TATUM_ROOT, os.path.basename(test_url)))
167168

168-
get_url(args, test_url, benchmark_tar)
169+
new_tar = get_url(args, test_url, benchmark_tar)
169170

170171
test_files_dir = os.path.join(TATUM_ROOT, "test")
171172

172-
print("Extracting test files to {}".format(test_files_dir))
173-
with tarfile.TarFile.open(benchmark_tar, mode="r|*") as tar_file:
174-
tar_file.extractall(path=test_files_dir)
173+
if new_tar or args.force:
174+
175+
print("Extracting test files to {}".format(test_files_dir))
176+
with tarfile.TarFile.open(benchmark_tar, mode="r|*") as tar_file:
177+
tar_file.extractall(path=test_files_dir)
178+
else:
179+
print("Skipping file extraction".format(test_files_dir))
180+
175181

176182
test_files += glob.glob("{}/{}/*.tatum*".format(test_files_dir, test_name))
177183
else:
@@ -183,14 +189,22 @@ def download_extract_test(args, test_name, test_url):
183189

184190
def get_url(args, url, filename):
    """
    Makes sure filename holds a current copy of the resource at url.

    If filename already exists and args.force is not set, the existing file
    is compared against the published checksum (see check_hash_match); when
    it matches, nothing is fetched.

    Arguments:
        args: parsed command-line options; only args.force is read here
        url: either a URL (anything containing '://') or a local path
        filename: destination path on the local file system

    Returns:
        True if filename was (re-)fetched, False if the existing file was
        kept as-is.
    """
    if not args.force and os.path.exists(filename):
        print("Found existing file {}, checking if hash matches".format(filename))
        file_matches = check_hash_match(args, url, filename)

        if file_matches:
            print("Existing file {} matches, skipping download".format(filename))
            return False
        else:
            print("Existing file {} contents differ, re-downloading".format(filename))

    if '://' in url:
        download_url(url, filename)
    elif os.path.isdir(url):
        #Local directory source: copy the whole tree
        shutil.copytree(url, filename)
    else:
        #Local file source (e.g. a tar archive): copytree() only accepts
        #directories, so a plain file must be copied with copy()
        shutil.copy(url, filename)

    return True
207+
194208
def download_url(url, filename):
195209
"""
196210
Downloads the specified url to filename
@@ -210,5 +224,35 @@ def download_progress_callback(block_num, block_size, expected_size):
210224
if block_num*block_size >= expected_size:
211225
print("")
212226

227+
def check_hash_match(args, url, filename):
    """
    Checks whether the local file matches the SHA256 checksum published
    next to url.

    The expected checksum is read from url + ".sha256"; its first
    whitespace-separated token is taken as the hex digest (the layout
    written by tools such as sha256sum).

    Arguments:
        args: parsed command-line options (not used by this function)
        url: URL (containing '://') or local path of the original resource
        filename: local file whose contents are verified

    Returns:
        True only if a checksum was obtained and it equals the SHA256 of
        filename. False if the checksum could not be fetched, was empty,
        or differs.
    """
    checksum_url = url + ".sha256"
    try:
        if '://' in checksum_url:
            #Close the response explicitly instead of leaking the connection
            with urllib.request.urlopen(checksum_url) as response:
                web_hash = response.read()
        else:
            #Local sources have no URL scheme; urlopen() would reject them
            with open(checksum_url, "rb") as f:
                web_hash = f.read()
    except (urllib.error.URLError, OSError) as e:
        #URLError covers HTTP errors as well as DNS/connection failures;
        #OSError covers a missing or unreadable local checksum file
        print("Failed to find expected SHA256 checksum at {} (reason '{}')".format(checksum_url, e))
        return False

    digest_fields = web_hash.split()
    if not digest_fields:
        print("Checksum file at {} is empty".format(checksum_url))
        return False

    #hexdigest() is lower-case; normalise in case the published digest is not
    web_digest_bytes = digest_fields[0].lower()
    local_digest_bytes = str.encode(hash_file(filename))

    return web_digest_bytes == local_digest_bytes
244+
245+
def hash_file(filepath):
    """
    Returns the SHA256 digest of the file at filepath as a lower-case
    hexadecimal string.

    The file is opened in binary mode and read in BUF_SIZE-byte chunks, so
    the whole file is never held in memory at once.
    """
    BUF_SIZE = 65536

    sha256 = hashlib.sha256()
    with open(filepath, "rb") as f:
        #iter() with a sentinel calls f.read() until it returns b"" (end of file)
        for data in iter(lambda: f.read(BUF_SIZE), b""):
            sha256.update(data)

    return sha256.hexdigest()
256+
213257
if __name__ == "__main__":
214258
main()

0 commit comments

Comments
 (0)