| import torch |
| |
| import hashlib |
| import os |
| import re |
| import shutil |
| import sys |
| import tempfile |
| |
| try: |
| from requests.utils import urlparse |
| from requests import get as urlopen |
| requests_available = True |
| except ImportError: |
| requests_available = False |
| if sys.version_info[0] == 2: |
| from urlparse import urlparse # noqa f811 |
| from urllib2 import urlopen # noqa f811 |
| else: |
| from urllib.request import urlopen |
| from urllib.parse import urlparse |
| try: |
| from tqdm import tqdm |
| except ImportError: |
| tqdm = None # defined below |
| |
| # Captures the run of lowercase hex digits between a '-' and the following '.', e.g. bfd8deac from resnet18-bfd8deac.pth |
| HASH_REGEX = re.compile(r'-([a-f0-9]*)\.') |
| |
| |
def load_url(url, model_dir=None, map_location=None, progress=True):
    r"""Loads the Torch serialized object at the given URL.

    If the object is already present in `model_dir`, it's deserialized and
    returned without being downloaded again. The filename part of the URL
    should follow the naming convention ``filename-<sha256>.ext`` where
    ``<sha256>`` is the first eight or more digits of the SHA256 hash of the
    contents of the file. The hash is used to ensure unique names and to
    verify the contents of the file. If the filename carries no such hash,
    the file is still downloaded but its contents are not verified.

    The default value of `model_dir` is ``$TORCH_HOME/models`` where
    ``$TORCH_HOME`` defaults to ``~/.torch``. The default directory can be
    overridden with the ``$TORCH_MODEL_ZOO`` environment variable.

    Args:
        url (string): URL of the object to download
        model_dir (string, optional): directory in which to save the object
        map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load)
        progress (bool, optional): whether or not to display a progress bar to stderr

    Returns:
        The object returned by ``torch.load`` for the cached file.

    Raises:
        OSError: if `model_dir` cannot be created.

    Example:
        >>> state_dict = torch.utils.model_zoo.load_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')

    """
    if model_dir is None:
        torch_home = os.path.expanduser(os.getenv('TORCH_HOME', '~/.torch'))
        model_dir = os.getenv('TORCH_MODEL_ZOO', os.path.join(torch_home, 'models'))

    # Create the directory unconditionally and tolerate "already exists":
    # checking os.path.exists() first races with concurrent callers.
    try:
        os.makedirs(model_dir)
    except OSError:
        if not os.path.isdir(model_dir):
            raise

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        # A filename without a "-<hex>." component yields no match; pass None
        # so the download proceeds without hash verification instead of
        # failing with AttributeError on ``None.group``.
        match = HASH_REGEX.search(filename)
        hash_prefix = match.group(1) if match else None
        _download_url_to_file(url, cached_file, hash_prefix, progress=progress)

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only pass URLs that point to trusted content.
    return torch.load(cached_file, map_location=map_location)
| |
| |
def _download_url_to_file(url, dst, hash_prefix, progress):
    """Download ``url`` to ``dst``, optionally verifying a SHA256 prefix.

    The payload is streamed into a temporary file created in the destination
    directory and is moved to ``dst`` only after the download has finished
    and the hash check (if any) has passed. On any failure the temporary file
    is removed, so ``dst`` is never left holding partial content.

    Args:
        url (string): URL to fetch
        dst (string): path the downloaded file is moved to on success
        hash_prefix (string or None): expected leading hex digits of the
            SHA256 digest of the content; ``None`` skips verification
        progress (bool): whether to display a progress bar

    Raises:
        RuntimeError: if the SHA256 digest of the downloaded content does not
            start with ``hash_prefix``.
    """
    if requests_available:
        response = urlopen(url, stream=True)
    else:
        response = urlopen(url)

    tmp = None
    try:
        if requests_available:
            # requests does not raise on 4xx/5xx by itself; without this an
            # error page would be written to the cache.
            response.raise_for_status()
            stream = response.raw
            content_length = response.headers.get("Content-Length")
        else:
            stream = response
            meta = response.info()
            if hasattr(meta, 'getheaders'):
                values = meta.getheaders("Content-Length")
            else:
                values = meta.get_all("Content-Length")
            content_length = values[0] if values else None

        # Servers may omit Content-Length (e.g. chunked transfer); the total
        # is then unknown rather than an error.
        file_size = int(content_length) if content_length is not None else None

        sha256 = hashlib.sha256() if hash_prefix is not None else None

        # Keep the temporary file next to the destination so the final move
        # stays on one filesystem.
        dst_dir = os.path.dirname(os.path.abspath(dst))
        tmp = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)

        with tqdm(total=file_size, disable=not progress) as pbar:
            while True:
                buffer = stream.read(8192)
                if not buffer:
                    break
                tmp.write(buffer)
                if sha256 is not None:
                    sha256.update(buffer)
                pbar.update(len(buffer))

        tmp.close()
        if sha256 is not None:
            digest = sha256.hexdigest()
            if digest[:len(hash_prefix)] != hash_prefix:
                raise RuntimeError('invalid hash value (expected "{}", got "{}")'
                                   .format(hash_prefix, digest))
        shutil.move(tmp.name, dst)
    finally:
        # Always release the connection, and drop the temporary file if it
        # was not moved into place.
        response.close()
        if tmp is not None:
            tmp.close()
            if os.path.exists(tmp.name):
                os.remove(tmp.name)
| |
| |
if tqdm is None:
    # fake tqdm if it's not installed
    class tqdm(object):
        """Minimal stand-in for ``tqdm.tqdm`` that reports progress on stderr.

        Only the pieces used in this file are provided: construction with
        ``total`` and ``disable``, ``update(n)``, and use as a context manager.

        Args:
            total (number, optional): expected final count; when it is
                ``None`` or ``0`` the running count is printed instead of a
                percentage
            disable (bool, optional): when true, nothing is written
        """

        def __init__(self, total=None, disable=False):
            self.total = total
            self.disable = disable
            self.n = 0

        def update(self, n):
            """Advance the counter by ``n`` and redraw the progress line."""
            if self.disable:
                return

            self.n += n
            if self.total:
                sys.stderr.write("\r{0:.1f}%".format(100 * self.n / float(self.total)))
            else:
                # A percentage is undefined for an unknown or zero total, so
                # show the running count rather than raising.
                sys.stderr.write("\r{0}".format(self.n))
            sys.stderr.flush()

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            if self.disable:
                return

            # Finish the line that update() keeps rewriting with '\r'.
            sys.stderr.write('\n')