Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 43d4a5f

Browse files
committedNov 21, 2024·
fix: adhere to wikipedia bot policy
1 parent 50e2665 commit 43d4a5f

File tree

1 file changed

+18
-12
lines changed

1 file changed

+18
-12
lines changed
 

‎trapdata/ml/utils.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,17 @@
88
import os
99
import pathlib
1010
import re
11-
import requests
1211
import tempfile
1312
import time
1413
import urllib.error
15-
from urllib.parse import urlparse
1614
from dataclasses import dataclass
1715
from typing import TYPE_CHECKING, Optional
16+
from urllib.parse import urlparse
1817

1918
import pandas as pd
2019
import PIL.Image
2120
import PIL.ImageFile
21+
import requests
2222
import torch
2323
import torchvision
2424

@@ -29,6 +29,10 @@
2929

3030
PIL.ImageFile.LOAD_TRUNCATED_IMAGES = True
3131

32+
# This is polite and required by some hosts
33+
# see: https://foundation.wikimedia.org/wiki/Policy:User-Agent_policy
34+
USER_AGENT = "AntennaInsectDataPlatform/1.0 (https://insectai.org)"
35+
3236

3337
def get_device(device_str=None) -> torch.device:
3438
"""
@@ -51,7 +55,8 @@ def get_or_download_file(
5155
Fetch a file from a URL or local path. If the path is a URL, download the file.
5256
If the URL has already been downloaded, return the existing local path.
5357
If the path is a local path, return the path.
54-
>>> filepath = get_or_download_file("https://example.uk/images/31-20230919033000-snapshot.jpg?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=451d406b7eb1113e1bb05c083ce51481%2F20240429%2F")
58+
>>> filepath = get_or_download_file(
59+
"https://example.uk/images/31-20230919033000-snapshot.jpg?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=451d406b7eb1113e1bb05c083ce51481%2F20240429%2F")
5560
>>> filepath.name
5661
'31-20230919033000-snapshot.jpg'
5762
>>> filepath = get_or_download_file("/home/user/images/31-20230919033000-snapshot.jpg")
@@ -60,10 +65,10 @@ def get_or_download_file(
6065
"""
6166
if not path_or_url:
6267
raise Exception("Specify a URL or path to fetch file from.")
63-
68+
6469
destination_dir = destination_dir or os.environ.get("LOCAL_WEIGHTS_PATH")
6570
fname = pathlib.Path(urlparse(path_or_url).path).name
66-
71+
6772
if destination_dir:
6873
destination_dir = pathlib.Path(destination_dir)
6974
if prefix:
@@ -78,22 +83,23 @@ def get_or_download_file(
7883
raise Exception(
7984
"No destination directory specified by LOCAL_WEIGHTS_PATH or app settings."
8085
)
81-
86+
8287
if local_filepath and local_filepath.exists():
8388
logger.info(f"Using existing {local_filepath}")
8489
return local_filepath
8590
else:
8691
logger.info(f"Downloading {path_or_url} to {local_filepath}")
87-
92+
8893
# Check if the path is a URL
89-
if path_or_url.startswith(('http://', 'https://')):
90-
response = requests.get(path_or_url, stream=True)
94+
if path_or_url.startswith(("http://", "https://")):
95+
headers = {"User-Agent": USER_AGENT}
96+
response = requests.get(path_or_url, stream=True, headers=headers)
9197
response.raise_for_status() # Raise an exception for HTTP errors
92-
93-
with open(local_filepath, 'wb') as f:
98+
99+
with open(local_filepath, "wb") as f:
94100
for chunk in response.iter_content(chunk_size=8192):
95101
f.write(chunk)
96-
102+
97103
logger.info(f"Downloaded to {local_filepath}")
98104
return local_filepath
99105
else:

0 commit comments

Comments
 (0)
Please sign in to comment.