Skip to content

Commit

Permalink
Update: Optimize download
Browse files Browse the repository at this point in the history
  • Loading branch information
Artrajz committed Nov 27, 2024
1 parent e42797e commit 58be4fd
Showing 1 changed file with 75 additions and 56 deletions.
131 changes: 75 additions & 56 deletions utils/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,49 +2,73 @@
import os
import hashlib
import tarfile
import urllib.request
import zipfile

from tqdm import tqdm
from pathlib import Path
from logger import logger
from py7zr import SevenZipFile


class TqdmUpTo(tqdm):
    """``tqdm`` progress bar that can be advanced from block-count callbacks.

    NOTE(review): the ``(b, bsize, tsize)`` signature looks like the
    ``reporthook`` callback of ``urllib.request.urlretrieve`` -- confirm
    against the call site, which is not visible here.
    """

    def update_to(self, b=1, bsize=1, tsize=None):
        """Move the bar to the absolute position ``b * bsize``.

        Args:
            b: Number of blocks transferred so far.
            bsize: Size of each block.
            tsize: Total size; when not ``None`` it replaces ``self.total``.
        """
        if tsize is not None:
            self.total = tsize
        # ``update`` expects an increment, so subtract what is already shown.
        transferred = b * bsize
        self.update(transferred - self.n)
import requests
from py7zr import SevenZipFile
from tqdm import tqdm
from config import ABS_PATH


def _download_file(url, dest_path, max_retry=1):
    """Download ``url`` to ``dest_path``, resuming a partial local file when possible.

    A HEAD request (following redirects) is sent first to read the remote
    ``Content-Length``. That size is compared with the size of any existing
    local file to decide whether the download is already complete, can be
    resumed with a ``Range`` request, or has to start over.

    Args:
        url: Address of the file to fetch.
        dest_path: Local path the data is written to. Missing parent
            directories are created.
        max_retry: Number of further attempts allowed after a failed attempt.

    Returns:
        A ``(success, message)`` tuple: ``success`` is a bool and ``message``
        a short human-readable description of the outcome.
    """
    logging.info(f"Downloading: {url}")

    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'
    }

    try:
        head = requests.head(url, headers=headers, allow_redirects=True, timeout=10)
    except requests.RequestException as e:
        logging.error(f"Failed to get file size for {url}: {e}")
        return False, f"Request timeout: {e}"

    if head.status_code >= 400:
        logging.error(f"Failed to connect to {url}, status code: {head.status_code}")
        return False, f"Failed to connect, status code: {head.status_code}"

    # A missing or malformed Content-Length means the remote size is unknown.
    try:
        total_size = int(head.headers.get('content-length', 0))
    except (TypeError, ValueError):
        total_size = 0

    file_size = os.path.getsize(dest_path) if os.path.exists(dest_path) else 0

    if total_size <= 0:
        # Without a remote size we cannot tell whether a local file is complete
        # (and 0 == 0 must not be mistaken for "already downloaded"), so any
        # existing data is discarded and the download starts from scratch.
        file_size = 0
    elif file_size == total_size:
        logging.info(f"File {dest_path} already downloaded and complete.")
        return True, "File already downloaded and complete."
    elif file_size > total_size:
        logging.warning(f"Local file size {file_size} exceeds server file size {total_size}. Removing local file.")
        os.remove(dest_path)
        if max_retry <= 0:
            return False, "Local file size exceeds server file size."
        return _download_file(url, dest_path, max_retry=max_retry - 1)

    # Only send a Range header when there is something to resume from.
    resuming = file_size > 0
    if resuming:
        headers['Range'] = f'bytes={file_size}-'

    # dirname is '' for a bare file name, and os.makedirs('') raises.
    dest_dir = os.path.dirname(dest_path)
    if dest_dir:
        os.makedirs(dest_dir, exist_ok=True)

    # relpath raises ValueError when the two paths have no common root
    # (e.g. different drives on Windows); fall back to the URL's last segment.
    try:
        label = os.path.relpath(dest_path, ABS_PATH)
    except ValueError:
        label = url.split('/')[-1]

    chunk_size = 1024 * 1024  # 1MB

    try:
        with requests.get(url, headers=headers, stream=True, timeout=10) as response:
            # Never write an HTTP error page into the destination file.
            response.raise_for_status()

            # A server that ignores Range answers 200 with the whole body;
            # appending that to the partial file would corrupt it.
            if resuming and response.status_code != 206:
                resuming = False

            with open(dest_path, 'ab' if resuming else 'wb') as file, tqdm(
                    total=total_size or None,
                    initial=file_size if resuming else 0,
                    unit='B',
                    unit_scale=True,
                    unit_divisor=1024,
                    desc=f"Downloading: {label}",
            ) as progress:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    if chunk:
                        file.write(chunk)
                        progress.update(len(chunk))

        logging.info(f"Download completed: {dest_path}")
        return True, "Download completed."
    except (requests.RequestException, OSError) as e:
        logging.error(f"Error during downloading {url}: {e}")
        if max_retry > 0:
            logging.info(f"Retrying download ({max_retry} retries left)...")
            return _download_file(url, dest_path, max_retry=max_retry - 1)
        return False, f"Download failed: {e}"


def verify_md5(file_path, expected_md5):
Expand Down Expand Up @@ -89,47 +113,43 @@ def extract_file(file_path, destination=None):

def download_file(urls, target_path, extract_destination=None, expected_md5=None, expected_sha256=None):
if os.path.exists(target_path):
success_msg = "File already exists and verified successfully!"
if expected_md5 is not None:
success, message = verify_md5(Path(target_path), expected_md5)
if not success:
os.remove(target_path)
return False, message
if success:
return True, success_msg

if expected_sha256 is not None:
success, message = verify_sha256(Path(target_path), expected_sha256)
if not success:
os.remove(target_path)
return False, message
if success:
return True, success_msg

# If it's a compressed file and the target_path already exists, skip the download
if extract_destination and target_path.endswith(('.zip', '.tar.gz', '.tar.bz2', '.7z')):
extract_file(target_path, extract_destination)
os.remove(target_path)

return True, "File already exists and verified successfully!"
return True, success_msg

is_download = False
for url in urls:
try:
_download_file(url, target_path)
is_download = True
break
is_download, _ = _download_file(url, target_path)
if is_download:
break
except Exception as error:
logger.error(f"downloading from URL {url}: {error}")
logging.error(f"downloading from URL {url}: {error}")

if not is_download:
return False, "Error downloading from all provided URLs."

if expected_md5 is not None:
success, message = verify_md5(Path(target_path), expected_md5)
if not success:
os.remove(target_path)
return False, message

if expected_sha256 is not None:
success, message = verify_sha256(Path(target_path), expected_sha256)
if not success:
os.remove(target_path)
return False, message

# If it's a compressed file, extract it
Expand All @@ -141,14 +161,13 @@ def download_file(urls, target_path, extract_destination=None, expected_md5=None


if __name__ == "__main__":
    # Manual smoke test: fetch one file and print the (success, message) result.
    # Imported for its side effects only -- presumably configures logging
    # output; confirm against the project's ``logger`` module.
    import logger  # noqa: F401

    URL = [
        "https://hf-mirror.com/hfl/chinese-roberta-wwm-ext-large/resolve/main/pytorch_model.bin",
    ]
    TARGET_PATH = r"E:\work\vits-simple-api\data\bert\chinese-roberta-wwm-ext-large/pytorch_model1.bin"
    EXPECTED_MD5 = None
    EXTRACT_DESTINATION = None

    # download_file's third positional parameter is extract_destination, not
    # expected_md5, so pass both by keyword to avoid swapping them.
    print(download_file(
        URL,
        TARGET_PATH,
        extract_destination=EXTRACT_DESTINATION,
        expected_md5=EXPECTED_MD5,
    ))

0 comments on commit 58be4fd

Please sign in to comment.