This repository has been archived by the owner on Nov 8, 2022. It is now read-only.

Commit

Merge branch 'deepti/restructure_downloads' into 'master'
Add model file caching; consolidate S3 URLs.

- Create a single mechanism for getting the path of a model.
- Consolidate all model S3 URLs.
Izsak, Peter committed Aug 28, 2019
2 parents 30591cc + c9cf928 commit 83f9a90
Showing 3 changed files with 378 additions and 6 deletions.
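In practice, the new download flow looks like this (a minimal sketch, mirroring the usage example in PretrainedModel's docstring below):

    from nlp_architect.models.pretrained_models import ChunkerModel

    # The first call constructs the singleton and sets up the cache path;
    # later calls return the same instance.
    chunker = ChunkerModel.get_instance()
    chunker2 = ChunkerModel.get_instance()

    # Downloads the files (or reuses the cache), unzipping archives as needed.
    print('Local file path =', chunker.get_file_path())
    for idx, file_name in enumerate(chunker2.get_model_files()):
        print(str(idx) + ': ' + file_name)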
172 changes: 172 additions & 0 deletions nlp_architect/models/pretrained_models.py
@@ -0,0 +1,172 @@
# ******************************************************************************
# Copyright 2017-2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************

from nlp_architect.utils.io import uncompress_file, zipfile_list
from nlp_architect.utils.file_cache import cached_path

from nlp_architect import LIBRARY_OUT

S3_PREFIX = "https://s3-us-west-2.amazonaws.com/nlp-architect-data/"


class PretrainedModel:
    """
    Generic base class for downloading pre-trained models.

    Usage example:

        chunker = ChunkerModel.get_instance()
        chunker2 = ChunkerModel.get_instance()
        print(chunker, chunker2)
        print("Local file path =", chunker.get_file_path())
        files_models = chunker2.get_model_files()
        for idx, file_name in enumerate(files_models):
            print(str(idx) + ": " + file_name)
    """

    def __init__(self, model_name, sub_path, files):
        if isinstance(self, (BistModel, ChunkerModel, MrcModel, IntentModel, AbsaModel, NerModel)):
            if self._instance is not None:  # pylint: disable=no-member
                raise Exception("This class is a singleton!")
        self.model_name = model_name
        self.base_path = S3_PREFIX + sub_path
        self.files = files
        self.download_path = LIBRARY_OUT / 'pretrained_models' / self.model_name
        self.model_files = []

    @classmethod
    # pylint: disable=no-member
    def get_instance(cls):
        """
        Static instance access method.

        Args:
            cls (Class name): Calling class
        """
        if cls._instance is None:
            cls()  # pylint: disable=no-value-for-parameter
        return cls._instance

    def get_file_path(self):
        """
        Return the local directory path containing the downloaded model files.
        """
        for filename in self.files:
            cached_file_path, need_downloading = cached_path(
                self.base_path + filename, self.download_path)
            if filename.endswith('zip') and need_downloading:
                print('Unzipping...')
                uncompress_file(cached_file_path, outpath=self.download_path)
                print('Done.')
        return self.download_path

    def get_model_files(self):
        """
        Return the individual file names of the downloaded model.
        """
        for file_name in self.files:
            cached_file_path, need_downloading = cached_path(
                self.base_path + file_name, self.download_path)
            if file_name.endswith('zip'):
                if need_downloading:
                    print('Unzipping...')
                    uncompress_file(cached_file_path, outpath=self.download_path)
                    print('Done.')
                self.model_files.extend(zipfile_list(cached_file_path))
            else:
                self.model_files.append(file_name)
        return self.model_files


# Model-specific classes; developers instantiate these where the model is needed.

class BistModel(PretrainedModel):
    """
    Download and process (unzip) the pre-trained BIST model.
    """
    _instance = None
    sub_path = 'models/dep_parse/'
    files = ['bist-pretrained.zip']

    def __init__(self):
        super().__init__('bist', self.sub_path, self.files)
        BistModel._instance = self


class IntentModel(PretrainedModel):
    """
    Download and process (unzip) the pre-trained Intent model.
    """
    _instance = None
    sub_path = 'models/intent/'
    files = ['model_info.dat', 'model.h5']

    def __init__(self):
        super().__init__('intent', self.sub_path, self.files)
        IntentModel._instance = self


class MrcModel(PretrainedModel):
    """
    Download and process (unzip) the pre-trained MRC model.
    """
    _instance = None
    sub_path = 'models/mrc/'
    files = ['mrc_data.zip', 'mrc_model.zip']

    def __init__(self):
        super().__init__('mrc', self.sub_path, self.files)
        MrcModel._instance = self


class NerModel(PretrainedModel):
    """
    Download and process (unzip) the pre-trained NER model.
    """
    _instance = None
    sub_path = 'models/ner/'
    files = ['model_v4.h5', 'model_info_v4.dat']

    def __init__(self):
        super().__init__('ner', self.sub_path, self.files)
        NerModel._instance = self


class AbsaModel(PretrainedModel):
    """
    Download and process (unzip) the pre-trained ABSA model.
    """
    _instance = None
    sub_path = 'models/absa/'
    files = ['rerank_model.h5']

    def __init__(self):
        super().__init__('absa', self.sub_path, self.files)
        AbsaModel._instance = self


class ChunkerModel(PretrainedModel):
    """
    Download and process (unzip) the pre-trained Chunker model.
    """
    _instance = None
    sub_path = 'models/chunker/'
    files = ['model.h5', 'model_info.dat.params']

    def __init__(self):
        super().__init__('chunker', self.sub_path, self.files)
        ChunkerModel._instance = self
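Adding support for a new pre-trained model would follow the same pattern; a sketch with a hypothetical model name, sub-path, and file list (none of these exist in the repository):

    class FooModel(PretrainedModel):
        """
        Download and process (unzip) a hypothetical pre-trained Foo model.
        """
        _instance = None
        sub_path = 'models/foo/'        # hypothetical S3 sub-path
        files = ['foo-pretrained.zip']  # hypothetical file list

        def __init__(self):
            super().__init__('foo', self.sub_path, self.files)
            FooModel._instance = self

Note that PretrainedModel.__init__ guards its singleton check with a hard-coded isinstance tuple, so a real new subclass would also need to be added to that tuple.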
185 changes: 185 additions & 0 deletions nlp_architect/utils/file_cache.py
@@ -0,0 +1,185 @@
# ******************************************************************************
# Copyright 2017-2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
"""
Utilities for working with the local dataset cache.
"""
import os
import logging
import shutil
import tempfile
import json
from urllib.parse import urlparse
from pathlib import Path
from typing import Tuple, Union, IO
from hashlib import sha256

from nlp_architect import LIBRARY_OUT
from nlp_architect.utils.io import load_json_file

import requests

logger = logging.getLogger(__name__) # pylint: disable=invalid-name

MODEL_CACHE = LIBRARY_OUT / 'pretrained_models'


def cached_path(url_or_filename: Union[str, Path], cache_dir: str = None) -> Tuple[str, bool]:
    """
    Given something that might be a URL (or might be a local path),
    determine which. If it's a URL, download the file and cache it, then
    return the path to the cached file and whether it needed downloading.
    If it's already a local path, make sure the file exists and then
    return the path.
    """
    if cache_dir is None:
        cache_dir = MODEL_CACHE
    if isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)

    parsed = urlparse(url_or_filename)

    if parsed.scheme in ('http', 'https'):
        # URL, so get it from the cache (downloading if necessary)
        return get_from_cache(url_or_filename, cache_dir)
    if os.path.exists(url_or_filename):
        # File, and it exists.
        print("File already exists. No further processing needed.")
        return url_or_filename, False
    if parsed.scheme == '':
        # File, but it doesn't exist.
        raise FileNotFoundError("file {} not found".format(url_or_filename))

    # Something unknown
    raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))


def url_to_filename(url: str, etag: str = None) -> str:
    """
    Convert `url` into a hashed filename in a repeatable way.
    If `etag` is specified, append its hash to the url's, delimited
    by a period. Only zip files are hashed; other files keep their
    original basename.
    """
    if url.split('/')[-1].endswith('zip'):
        url_bytes = url.encode('utf-8')
        url_hash = sha256(url_bytes)
        filename = url_hash.hexdigest()
        if etag:
            etag_bytes = etag.encode('utf-8')
            etag_hash = sha256(etag_bytes)
            filename += '.' + etag_hash.hexdigest()
    else:
        filename = url.split('/')[-1]

    return filename
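For illustration, the mapping behaves roughly as follows (the URLs are taken from the model classes above; the etag value is made up):

    zip_url = 'https://s3-us-west-2.amazonaws.com/nlp-architect-data/models/dep_parse/bist-pretrained.zip'
    url_to_filename(zip_url)            # 64-char sha256 hex digest of the URL
    url_to_filename(zip_url, 'abc123')  # '<url digest>.<etag digest>'

    h5_url = 'https://s3-us-west-2.amazonaws.com/nlp-architect-data/models/ner/model_v4.h5'
    url_to_filename(h5_url)             # 'model_v4.h5' -- basename kept as-is

Non-zip files keep human-readable names in the cache, while zip archives get collision-safe hashed names.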


def filename_to_url(filename: str, cache_dir: str = None) -> Tuple[str, str]:
    """
    Return the url and etag (which may be ``None``) stored for `filename`.
    Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist.
    """
    if cache_dir is None:
        cache_dir = MODEL_CACHE

    cache_path = os.path.join(cache_dir, filename)
    if not os.path.exists(cache_path):
        raise FileNotFoundError("file {} not found".format(cache_path))

    meta_path = cache_path + '.json'
    if not os.path.exists(meta_path):
        raise FileNotFoundError("file {} not found".format(meta_path))

    with open(meta_path) as meta_file:
        metadata = json.load(meta_file)
    url = metadata['url']
    etag = metadata['etag']

    return url, etag


def http_get(url: str, temp_file: IO) -> None:
    req = requests.get(url, stream=True)
    for chunk in req.iter_content(chunk_size=1024):
        if chunk:  # filter out keep-alive chunks
            temp_file.write(chunk)


def get_from_cache(url: str, cache_dir: str = None) -> Tuple[str, bool]:
    """
    Given a URL, look for the corresponding file in the local cache.
    If it's not there, download it. Return the path to the cached file
    and whether it needed downloading.
    """
    if cache_dir is None:
        cache_dir = MODEL_CACHE

    os.makedirs(cache_dir, exist_ok=True)

    response = requests.head(url, allow_redirects=True)
    if response.status_code != 200:
        raise IOError("HEAD request failed for url {} with status code {}"
                      .format(url, response.status_code))
    etag = response.headers.get("ETag")

    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    need_downloading = True

    if os.path.exists(cache_path):
        # check whether the etag has changed, comparing with the stored metadata
        if url.split('/')[-1].endswith('zip'):
            meta_path = cache_path + '.json'
        else:
            meta_path = cache_path + '_meta_.json'
        meta = load_json_file(meta_path)
        if meta['etag'] == etag:
            print('File already present in cache.')
            need_downloading = False

    if need_downloading:
        print("File not present in cache, or etag changed.")
        # Download to a temporary file, then copy to the cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with tempfile.NamedTemporaryFile() as temp_file:
            logger.info("%s not found in cache, downloading to %s", url, temp_file.name)

            # GET file object
            http_get(url, temp_file)

            # we are copying the file before closing it, so flush to avoid truncation
            temp_file.flush()
            # shutil.copyfileobj() starts at the current position, so go to the start
            temp_file.seek(0)

            logger.info("copying %s to cache at %s", temp_file.name, cache_path)
            with open(cache_path, 'wb') as cache_file:
                shutil.copyfileobj(temp_file, cache_file)

            logger.info("creating metadata file for %s", cache_path)
            meta = {'url': url, 'etag': etag}
            if url.split('/')[-1].endswith('zip'):
                meta_path = cache_path + '.json'
            else:
                meta_path = cache_path + '_meta_.json'
            with open(meta_path, 'w') as meta_file:
                json.dump(meta, meta_file)

            logger.info("removing temp file %s", temp_file.name)

    return cache_path, need_downloading
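Assuming the default cache directory, the on-disk layout after a download looks roughly like this (names illustrative):

    # <LIBRARY_OUT>/pretrained_models/<sha256(url)>              zip payload (hashed name)
    # <LIBRARY_OUT>/pretrained_models/<sha256(url)>.json         {'url': ..., 'etag': ...}
    # <LIBRARY_OUT>/pretrained_models/model_v4.h5                non-zip file keeps its basename
    # <LIBRARY_OUT>/pretrained_models/model_v4.h5_meta_.json     metadata for non-zip files

One caveat worth noting: filename_to_url only reads metadata from cache_path + '.json', so it works for the hashed zip entries but not for non-zip files, whose metadata lands in '..._meta_.json'.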
27 changes: 21 additions & 6 deletions nlp_architect/utils/io.py
@@ -73,18 +73,33 @@ def uncompress_file(filepath: str or os.PathLike, outpath='.'):
         outpath (str): path to extract to
     """
     filepath = str(filepath)
-    if filepath.endswith('.zip'):
-        with zipfile.ZipFile(filepath) as z:
-            z.extractall(outpath)
-    elif filepath.endswith('.gz'):
+    if filepath.endswith('.gz'):
         if os.path.isdir(outpath):
             raise ValueError('output path for gzip must be a file')
         with gzip.open(filepath, 'rb') as fp:
             file_content = fp.read()
         with open(outpath, 'wb') as fp:
             fp.write(file_content)
-    else:
-        raise ValueError('Unsupported archive provided. Method supports only .zip/.gz files.')
+        return None
+    # Treat everything else as a zip archive: cached model zips are stored
+    # under SHA-hashed names (etag and url), so they have no .zip extension.
+    # raise ValueError('Unsupported archive provided. Method supports only .zip/.gz files.')
+    with zipfile.ZipFile(filepath) as z:
+        z.extractall(outpath)
+        return [x for x in z.namelist() if not (x.startswith('__MACOSX') or x.endswith('/'))]
+
+
+def zipfile_list(filepath: str or os.PathLike):
+    """
+    List the files inside a given zip file
+
+    Args:
+        filepath (str): path to file
+
+    Returns:
+        String list of filenames
+    """
+    with zipfile.ZipFile(filepath) as z:
+        return [x for x in z.namelist() if not (x.startswith('__MACOSX') or x.endswith('/'))]


 def gzip_str(g_str):
