Skip to content

Commit

Permalink
feat: add middleauth+https paths indicate CAVE interface (#106)
Browse files Browse the repository at this point in the history
* feat: add middleauth+https paths indicate CAVE interface

* fix: missing parameter

* docs: describe what the heck CAVE is

* fix: don't call replace when proto is None
  • Loading branch information
william-silversmith authored Jul 27, 2024
1 parent 3ae7c76 commit e4b04bf
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 13 deletions.
9 changes: 9 additions & 0 deletions automated_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,15 @@ def okgoogle(url):
'a/username2/b/c/d', None, None
))

def test_middleauth_path_extraction():
    """A middleauth+https cloudpath keeps its compound protocol and splits into host + path."""
    from cloudfiles.paths import extract

    extracted = extract('middleauth+https://example.com/wow/cool/')

    # No explicit format prefix was given, so the default format is expected.
    assert extracted.format == 'precomputed'
    assert extracted.protocol == 'middleauth+https'
    # HTTP-style paths carry a host instead of a bucket.
    assert extracted.bucket is None
    assert extracted.path == 'wow/cool/'
    assert extracted.host == "https://example.com"

@pytest.mark.parametrize("protocol", ('mem', 'file', 's3'))
def test_access_non_cannonical_minimal_path(s3, protocol):
from cloudfiles import CloudFiles, exceptions
Expand Down
3 changes: 2 additions & 1 deletion cloudfiles/cloudfiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from .interfaces import (
FileInterface, HttpInterface,
S3Interface, GoogleCloudStorageInterface,
MemoryInterface
MemoryInterface, CaveInterface,
)

INTERFACES = {
Expand All @@ -54,6 +54,7 @@
'http': HttpInterface,
'https': HttpInterface,
'mem': MemoryInterface,
'middleauth+https': CaveInterface,
}
for alias in ALIASES:
INTERFACES[alias] = S3Interface
Expand Down
45 changes: 37 additions & 8 deletions cloudfiles/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@
from .connectionpools import S3ConnectionPool, GCloudBucketPool, MemoryPool, MEMORY_DATA
from .exceptions import MD5IntegrityError, CompressionError
from .lib import mkdir, sip, md5, validate_s3_multipart_etag
from .secrets import http_credentials, CLOUD_FILES_DIR, CLOUD_FILES_LOCK_DIR
from .secrets import (
http_credentials,
cave_credentials,
CLOUD_FILES_DIR,
CLOUD_FILES_LOCK_DIR,
)

COMPRESSION_EXTENSIONS = ('.gz', '.br', '.zstd','.bz2','.xz')
GZIP_TYPES = (True, 'gzip', 1)
Expand Down Expand Up @@ -731,6 +736,9 @@ def __init__(self, path, secrets=None, request_payer=None, **kwargs):
if secrets and 'user' in secrets and 'password' in secrets:
self.session.auth = (secrets['user'], secrets['password'])

def default_headers(self):
    """
    Return the headers attached to every request made by this interface.

    The base implementation sends no extra headers; a fresh dict is
    returned on each call so callers may add to it (e.g. a Range header).
    """
    return dict()

def get_path_to_file(self, file_path):
    """Join the interface's host, its base path, and file_path into one URL."""
    location = self._path
    return posixpath.join(location.host, location.path, file_path)

Expand All @@ -749,7 +757,8 @@ def put_file(self, file_path, content, content_type,
@retry
def head(self, file_path):
    """
    Issue a HEAD request for file_path and return the response headers.

    The interface's default headers (e.g. authorization supplied by a
    subclass) are sent with the request. An error status is raised via
    resp.raise_for_status().
    """
    key = self.get_path_to_file(file_path)
    headers = self.default_headers()
    with self.session.head(key, headers=headers) as resp:
        resp.raise_for_status()
        return resp.headers

Expand All @@ -761,13 +770,14 @@ def size(self, file_path):
def get_file(self, file_path, start=None, end=None, part_size=None):
key = self.get_path_to_file(file_path)

headers = self.default_headers()
if start is not None or end is not None:
start = int(start) if start is not None else 0
end = int(end - 1) if end is not None else ''
headers = { "Range": "bytes={}-{}".format(start, end) }
resp = self.session.get(key, headers=headers)
else:
resp = self.session.get(key)
headers["Range"] = f"bytes={start}-{end}"

resp = self.session.get(key, headers=headers)

if resp.status_code in (404, 403):
return (None, None, None, None)
resp.close()
Expand All @@ -788,7 +798,8 @@ def get_file(self, file_path, start=None, end=None, part_size=None):
@retry
def exists(self, file_path):
    """
    Report whether a GET request for file_path succeeds.

    Returns resp.ok, so any non-success status (not only 404)
    reads as "does not exist".
    """
    key = self.get_path_to_file(file_path)
    headers = self.default_headers()
    # stream=True: presumably avoids downloading the body when only the
    # status is needed -- NOTE(review): confirm against the requests docs.
    with self.session.get(key, stream=True, headers=headers) as resp:
        return resp.ok

def files_exist(self, file_paths):
Expand All @@ -805,11 +816,15 @@ def _list_files_google(self, prefix, flat):
if prefix and prefix[-1] != '/':
prefix += '/'

headers = self.default_headers()

@retry
def request(token):
nonlocal headers
results = self.session.get(
f"https://storage.googleapis.com/storage/v1/b/{bucket}/o",
params={ "prefix": prefix, "pageToken": token },
headers=headers,
)
results.raise_for_status()
results.close()
Expand All @@ -832,12 +847,13 @@ def _list_files_apache(self, prefix, flat):
baseurl = posixpath.join(self._path.host, self._path.path)

directories = ['']
headers = self.default_headers()

while directories:
directory = directories.pop()
url = posixpath.join(baseurl, directory)

resp = requests.get(url)
resp = requests.get(url, headers=headers)
resp.raise_for_status()

if 'text/html' not in resp.headers["Content-Type"]:
Expand Down Expand Up @@ -1200,3 +1216,16 @@ def release_connection(self):
with S3_BUCKET_POOL_LOCK:
pool = S3_POOL[S3ConnectionPoolParams(service, self._path.bucket, self._request_payer)]
pool.release_connection(self._conn)

class CaveInterface(HttpInterface):
    """
    HTTP(S) interface that authenticates with a CAVE bearer token.

    CAVE is an internal system that powers proofreading
    systems in Seung Lab. If you have no idea what this
    is, don't worry about it.
    see: https://github.com/CAVEconnectome
    """
    def default_headers(self):
        """
        Return the parent's default headers plus an Authorization header.

        cave_credentials() yields None when no secret is available. In
        that case (or when the secret has no 'token' entry) the request
        is sent without Authorization rather than failing with an opaque
        TypeError/KeyError; the server's response then reports the
        missing authentication.
        """
        headers = dict(super().default_headers())
        cred = cave_credentials()
        token = cred.get('token') if cred else None
        if token:
            headers["Authorization"] = f"Bearer {token}"
        return headers
19 changes: 15 additions & 4 deletions cloudfiles/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
# Protocol aliases; starts empty and is filled in elsewhere.
ALIASES = {}

# Protocols understood out of the box. The "<auth>+https" entries are
# compound protocols: an authentication scheme layered over plain https.
BASE_ALLOWED_PROTOCOLS = [
    'gs', 'file', 's3',
    'http', 'https', 'mem',
    'middleauth+https', 'ngauth+https',
]

# Working copy (a separate list object) so additions to it do not
# modify BASE_ALLOWED_PROTOCOLS.
ALLOWED_PROTOCOLS = list(BASE_ALLOWED_PROTOCOLS)
ALLOWED_FORMATS = [
Expand Down Expand Up @@ -69,7 +70,13 @@ def cloudpath_error(cloudpath):
def mkregexp():
    """
    Build the regular expression source that matches an optional
    "<format>://" prefix followed by an optional "<protocol>://" prefix.

    The format is captured in the named group 'fmt' and the protocol in
    the named group 'proto'. The alternatives come from ALLOWED_FORMATS
    and ALLOWED_PROTOCOLS as they are at call time.

    Returns:
        str: the uncompiled pattern.
    """
    # Function-scope import keeps this block self-contained.
    import re

    fmt_capture = r'|'.join(ALLOWED_FORMATS)
    fmt_capture = "(?:(?P<fmt>{})://)".format(fmt_capture)

    # Protocol names are literals, not patterns. Compound protocols such
    # as 'middleauth+https' contain '+', and aliases may contain other
    # metacharacters, so escape each one in full.
    proto_capture = r'|'.join(re.escape(proto) for proto in ALLOWED_PROTOCOLS)
    proto_capture = "(?:(?P<proto>{})://)".format(proto_capture)

    return "{}?{}?".format(fmt_capture, proto_capture)
Expand Down Expand Up @@ -292,8 +299,12 @@ def extract_format_protocol(cloudpath:str, allow_defaults=True) -> tuple:
proto = m.group('proto')
endpoint = None

if proto in ('http', 'https'):
cloudpath = proto + "://" + cloudpath
tmp_proto = None
if proto is not None:
tmp_proto = proto.replace("middleauth+", "").replace("ngauth+", "")

if tmp_proto in ('http', 'https'):
cloudpath = tmp_proto + "://" + cloudpath
parse = urllib.parse.urlparse(cloudpath)
endpoint = parse.scheme + "://" + parse.netloc
cloudpath = cloudpath.replace(endpoint, '', 1)
Expand Down
18 changes: 18 additions & 0 deletions cloudfiles/secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,24 @@ def aws_credentials(bucket = '', service = 'aws', skip_files=False):
AWS_CREDENTIALS_CACHE[service][bucket] = aws_credentials
return aws_credentials

# Process-wide cache of the parsed CAVE secret; None until loaded.
CAVE_CREDENTIALS = None
def cave_credentials():
    """
    Return the parsed contents of the 'cave-secret.json' secret file.

    A truthy cached value is returned without touching the disk.
    Otherwise the file is read if it exists; if it does not, the cache
    is reset to None and None is returned.
    """
    global CAVE_CREDENTIALS
    secret_file = secretpath('cave-secret.json')

    if CAVE_CREDENTIALS:
        return CAVE_CREDENTIALS

    if not os.path.exists(secret_file):
        CAVE_CREDENTIALS = None
        return CAVE_CREDENTIALS

    with open(secret_file, 'rt') as handle:
        CAVE_CREDENTIALS = json.load(handle)

    return CAVE_CREDENTIALS


HTTP_CREDENTIALS = None
def http_credentials():
global HTTP_CREDENTIALS
Expand Down

0 comments on commit e4b04bf

Please sign in to comment.