from __future__ import absolute_import, division, print_function, unicode_literals
import errno
import hashlib
import os
import re
import shutil
import sys
import tempfile
import torch
import warnings
import zipfile

from urllib.request import urlopen, Request
from urllib.parse import urlparse  # noqa: F401

try:
    from tqdm.auto import tqdm  # automatically select proper tqdm submodule if available
except ImportError:
    try:
        from tqdm import tqdm
    except ImportError:
        # fake tqdm if it's not installed
        class tqdm(object):  # type: ignore

            def __init__(self, total=None, disable=False,
                         unit=None, unit_scale=None, unit_divisor=None):
                self.total = total
                self.disable = disable
                self.n = 0
                # ignore unit, unit_scale, unit_divisor; they're just for real tqdm

            def update(self, n):
                if self.disable:
                    return

                self.n += n
                if self.total is None:
                    sys.stderr.write("\r{0:.1f} bytes".format(self.n))
                else:
                    sys.stderr.write("\r{0:.1f}%".format(100 * self.n / float(self.total)))
                sys.stderr.flush()

            def __enter__(self):
                return self

            def __exit__(self, exc_type, exc_val, exc_tb):
                if self.disable:
                    return

                sys.stderr.write('\n')

# matches bfd8deac from resnet18-bfd8deac.pth
HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')
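# For example (illustrative, using the filename from the comment above):
#   HASH_REGEX.search('resnet18-bfd8deac.pth').group(1) -> 'bfd8deac'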

MASTER_BRANCH = 'master'
ENV_TORCH_HOME = 'TORCH_HOME'
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR = '~/.cache'
VAR_DEPENDENCY = 'dependencies'
MODULE_HUBCONF = 'hubconf.py'
READ_DATA_CHUNK = 8192
_hub_dir = None


# Copied from tools/shared/module_loader to be included in torch package
def import_module(name, path):
    if sys.version_info >= (3, 5):
        import importlib.util
        spec = importlib.util.spec_from_file_location(name, path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module
    else:
        from importlib.machinery import SourceFileLoader
        return SourceFileLoader(name, path).load_module()
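# Illustrative usage (the path here is hypothetical):
#   hubconf = import_module('hubconf', '/some/local/repo/hubconf.py')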


def _remove_if_exists(path):
    if os.path.exists(path):
        if os.path.isfile(path):
            os.remove(path)
        else:
            shutil.rmtree(path)


def _git_archive_link(repo_owner, repo_name, branch):
    return 'https://github.com/{}/{}/archive/{}.zip'.format(repo_owner, repo_name, branch)
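# For instance (illustrative):
#   _git_archive_link('pytorch', 'vision', 'master')
#   -> 'https://github.com/pytorch/vision/archive/master.zip'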


def _load_attr_from_module(module, func_name):
    # Check if callable is defined in the module
    if func_name not in dir(module):
        return None
    return getattr(module, func_name)


def _get_torch_home():
    torch_home = os.path.expanduser(
        os.getenv(ENV_TORCH_HOME,
                  os.path.join(os.getenv(ENV_XDG_CACHE_HOME,
                                         DEFAULT_CACHE_DIR), 'torch')))
    return torch_home
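# With neither $TORCH_HOME nor $XDG_CACHE_HOME set, this resolves to the expanded
# form of '~/.cache/torch' (illustrative; the actual value depends on the environment).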


def _parse_repo_info(github):
    branch = MASTER_BRANCH
    if ':' in github:
        repo_info, branch = github.split(':')
    else:
        repo_info = github
    repo_owner, repo_name = repo_info.split('/')
    return repo_owner, repo_name, branch
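# Illustrative examples (repo names borrowed from the docstrings below):
#   _parse_repo_info('pytorch/vision')     -> ('pytorch', 'vision', 'master')
#   _parse_repo_info('pytorch/vision:hub') -> ('pytorch', 'vision', 'hub')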


def _get_cache_or_reload(github, force_reload, verbose=True):
    # Setup hub_dir to save downloaded files
    hub_dir = get_dir()
    if not os.path.exists(hub_dir):
        os.makedirs(hub_dir)
    # Parse github repo information
    repo_owner, repo_name, branch = _parse_repo_info(github)
    # Github allows branch names containing a slash '/', which would be confused
    # with a path separator on both Linux and Windows. Backslashes are not allowed
    # in Github branch names, so there is no need to worry about them.
    normalized_br = branch.replace('/', '_')
    # Github renames the folder repo-v1.x.x to repo-1.x.x, and the repo name is only
    # known after downloading the zip file and inspecting it.
    # To check whether a cached repo exists, we need to normalize folder names.
    repo_dir = os.path.join(hub_dir, '_'.join([repo_owner, repo_name, normalized_br]))

    use_cache = (not force_reload) and os.path.exists(repo_dir)

    if use_cache:
        if verbose:
            sys.stderr.write('Using cache found in {}\n'.format(repo_dir))
    else:
        cached_file = os.path.join(hub_dir, normalized_br + '.zip')
        _remove_if_exists(cached_file)

        url = _git_archive_link(repo_owner, repo_name, branch)
        sys.stderr.write('Downloading: \"{}\" to {}\n'.format(url, cached_file))
        download_url_to_file(url, cached_file, progress=False)

        with zipfile.ZipFile(cached_file) as cached_zipfile:
            extracted_repo_name = cached_zipfile.infolist()[0].filename
            extracted_repo = os.path.join(hub_dir, extracted_repo_name)
            _remove_if_exists(extracted_repo)
            # Unzip the code and rename the base folder
            cached_zipfile.extractall(hub_dir)

        _remove_if_exists(cached_file)
        _remove_if_exists(repo_dir)
        shutil.move(extracted_repo, repo_dir)  # rename the repo

    return repo_dir
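# Illustrative: for github='pytorch/vision:hub' the repo is cached under
#   <hub_dir>/pytorch_vision_hub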


def _check_module_exists(name):
    if sys.version_info >= (3, 4):
        import importlib.util
        return importlib.util.find_spec(name) is not None
    elif sys.version_info >= (3, 3):
        # Special case for python3.3
        import importlib
        return importlib.find_loader(name) is not None
    else:
        # NB: Python2.7 imp.find_module() doesn't respect PEP 302,
        # it cannot find a package installed as .egg(zip) file.
        # Here we use the workaround from:
        # https://stackoverflow.com/questions/28962344/imp-find-module-which-supports-zipped-eggs?lq=1
        # Also imp doesn't handle hierarchical module names (names containing dots).
        try:
            # 1. Try imp.find_module(), which searches sys.path, but does
            # not respect PEP 302 import hooks.
            import imp
            result = imp.find_module(name)
            if result:
                return True
        except ImportError:
            pass
        path = sys.path
        for item in path:
            # 2. Scan path for import hooks. sys.path_importer_cache maps
            # path items to optional "importer" objects, that implement
            # find_module() etc. Note that path must be a subset of
            # sys.path for this to work.
            importer = sys.path_importer_cache.get(item)
            if importer:
                try:
                    result = importer.find_module(name, [item])
                    if result:
                        return True
                except ImportError:
                    pass
        return False
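# For example, _check_module_exists('torch') is True in this environment, while a
# made-up name such as _check_module_exists('no_such_package_xyz') returns False.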


def _check_dependencies(m):
    dependencies = _load_attr_from_module(m, VAR_DEPENDENCY)

    if dependencies is not None:
        missing_deps = [pkg for pkg in dependencies if not _check_module_exists(pkg)]
        if len(missing_deps):
            raise RuntimeError('Missing dependencies: {}'.format(', '.join(missing_deps)))


def _load_entry_from_hubconf(m, model):
    if not isinstance(model, str):
        raise ValueError('Invalid input: model should be a string of the entrypoint name')

    # Note that if a missing dependency is imported at the top level of hubconf, it will
    # throw before this function. It's a chicken-and-egg situation where we have to
    # load hubconf to know what the dependencies are, but importing hubconf requires
    # the missing package. This is fine, Python will throw a proper error message for users.
    _check_dependencies(m)

    func = _load_attr_from_module(m, model)

    if func is None or not callable(func):
        raise RuntimeError('Cannot find callable {} in hubconf'.format(model))

    return func
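# Sketch of a minimal hubconf.py that satisfies the checks above (the entrypoint
# name and its arguments are hypothetical):
#
#   dependencies = ['torch']
#
#   def my_model(pretrained=False, **kwargs):
#       """Docstring shown by torch.hub.help()."""
#       model = ...  # construct and optionally load weights into an nn.Module
#       return model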


def get_dir():
    r"""
    Get the Torch Hub cache directory used for storing downloaded models & weights.

    If :func:`~torch.hub.set_dir` is not called, the default path is ``$TORCH_HOME/hub``, where
    the environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``.
    ``$XDG_CACHE_HOME`` follows the X Desktop Group (XDG) specification of the Linux
    filesystem layout, with a default value of ``~/.cache`` if the environment
    variable is not set.
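
    Example:
        >>> # Illustrative; the exact path depends on your environment
        >>> hub_dir = torch.hub.get_dir()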
    """
    # Issue warning to move data if old env is set
    if os.getenv('TORCH_HUB'):
        warnings.warn('TORCH_HUB is deprecated, please use env TORCH_HOME instead')

    if _hub_dir is not None:
        return _hub_dir
    return os.path.join(_get_torch_home(), 'hub')


def set_dir(d):
    r"""
    Optionally set the Torch Hub directory used to save downloaded models & weights.

    Args:
        d (string): path to a local folder to save downloaded models & weights.
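
    Example:
        >>> # Hypothetical custom cache location
        >>> torch.hub.set_dir('/tmp/hub_cache')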
    """
    global _hub_dir
    _hub_dir = d


def list(github, force_reload=False):
    r"""
    List all callable entrypoints available in the repo specified by `github`.

    Args:
        github (string): a string with format "repo_owner/repo_name[:tag_name]" with an optional
            tag/branch. The default branch is `master` if not specified.
            Example: 'pytorch/vision[:hub]'
        force_reload (bool, optional): whether to discard the existing cache and force a fresh download.
            Default is `False`.
    Returns:
        entrypoints: a list of available entrypoint names

    Example:
        >>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True)
    """
    repo_dir = _get_cache_or_reload(github, force_reload, True)

    sys.path.insert(0, repo_dir)

    hub_module = import_module(MODULE_HUBCONF, repo_dir + '/' + MODULE_HUBCONF)

    sys.path.remove(repo_dir)

    # We treat functions that start with '_' as internal helpers and skip them
    entrypoints = [f for f in dir(hub_module) if callable(getattr(hub_module, f)) and not f.startswith('_')]

    return entrypoints


def help(github, model, force_reload=False):
    r"""
    Show the docstring of entrypoint `model`.

    Args:
        github (string): a string with format "repo_owner/repo_name[:tag_name]" with an optional
            tag/branch. The default branch is `master` if not specified.
            Example: 'pytorch/vision[:hub]'
        model (string): the name of an entrypoint defined in the repo's hubconf.py
        force_reload (bool, optional): whether to discard the existing cache and force a fresh download.
            Default is `False`.
    Example:
        >>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True))
    """
    repo_dir = _get_cache_or_reload(github, force_reload, True)

    sys.path.insert(0, repo_dir)

    hub_module = import_module(MODULE_HUBCONF, repo_dir + '/' + MODULE_HUBCONF)

    sys.path.remove(repo_dir)

    entry = _load_entry_from_hubconf(hub_module, model)

    return entry.__doc__


# Ideally this should be `def load(github, model, *args, force_reload=False, **kwargs):`,
# but Python2 raises a syntax error for it. We have to skip force_reload in the function
# signature here and detect it in kwargs instead.
# TODO: fix it after Python2 EOL
def load(github, model, *args, **kwargs):
    r"""
    Load a model from a github repo, with pretrained weights.

    Args:
        github (string): a string with format "repo_owner/repo_name[:tag_name]" with an optional
            tag/branch. The default branch is `master` if not specified.
            Example: 'pytorch/vision[:hub]'
        model (string): the name of an entrypoint defined in the repo's hubconf.py
        *args (optional): the corresponding args for callable `model`.
        force_reload (bool, optional): whether to force a fresh download of the github repo unconditionally.
            Default is `False`.
        verbose (bool, optional): If False, mute messages about hitting local caches. Note that the message
            about the first download cannot be muted.
            Default is `True`.
        **kwargs (optional): the corresponding kwargs for callable `model`.

    Returns:
        a single model with corresponding pretrained weights.

    Example:
        >>> model = torch.hub.load('pytorch/vision', 'resnet50', pretrained=True)
    """
    force_reload = kwargs.pop('force_reload', False)
    verbose = kwargs.pop('verbose', True)

    repo_dir = _get_cache_or_reload(github, force_reload, verbose)

    sys.path.insert(0, repo_dir)

    hub_module = import_module(MODULE_HUBCONF, repo_dir + '/' + MODULE_HUBCONF)

    entry = _load_entry_from_hubconf(hub_module, model)

    model = entry(*args, **kwargs)

    sys.path.remove(repo_dir)

    return model


def download_url_to_file(url, dst, hash_prefix=None, progress=True):
    r"""Download object at the given URL to a local path.

    Args:
        url (string): URL of the object to download
        dst (string): Full path where object will be saved, e.g. `/tmp/temporary_file`
        hash_prefix (string, optional): If not None, the SHA256 hash of the downloaded file should start
            with `hash_prefix`.
            Default: None
        progress (bool, optional): whether or not to display a progress bar to stderr
            Default: True

    Example:
        >>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file')

    """
    file_size = None
    # Send the request with a User-Agent header so servers can identify torch.hub traffic
    req = Request(url, headers={"User-Agent": "torch.hub"})
    u = urlopen(req)
    meta = u.info()
    if hasattr(meta, 'getheaders'):
        # Python 2 httplib message API
        content_length = meta.getheaders("Content-Length")
    else:
        # Python 3 email.message.Message API
        content_length = meta.get_all("Content-Length")
    if content_length is not None and len(content_length) > 0:
        file_size = int(content_length[0])

    # We deliberately save it in a temp file and move it after the
    # download is complete. This prevents a local working checkpoint
    # from being overridden by a broken download.
    dst = os.path.expanduser(dst)
    dst_dir = os.path.dirname(dst)
    f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)

    try:
        if hash_prefix is not None:
            sha256 = hashlib.sha256()
        with tqdm(total=file_size, disable=not progress,
                  unit='B', unit_scale=True, unit_divisor=1024) as pbar:
            while True:
                buffer = u.read(8192)
                if len(buffer) == 0:
                    break
                f.write(buffer)
                if hash_prefix is not None:
                    sha256.update(buffer)
                pbar.update(len(buffer))

        f.close()
        if hash_prefix is not None:
            digest = sha256.hexdigest()
            if digest[:len(hash_prefix)] != hash_prefix:
                raise RuntimeError('invalid hash value (expected "{}", got "{}")'
                                   .format(hash_prefix, digest))
        shutil.move(f.name, dst)
    finally:
        f.close()
        if os.path.exists(f.name):
            os.remove(f.name)
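# Illustrative: verifying the download against the hash embedded in the filename
# (URL and hash taken from the docstring example above; the destination path is hypothetical):
#   download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth',
#                        '/tmp/resnet18-5c106cde.pth', hash_prefix='5c106cde')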


def _download_url_to_file(url, dst, hash_prefix=None, progress=True):
    warnings.warn('torch.hub._download_url_to_file has been renamed to '
                  'torch.hub.download_url_to_file to be a public API, '
                  '_download_url_to_file will be removed after the 1.3 release')
    download_url_to_file(url, dst, hash_prefix, progress)


def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None):
    r"""Loads the Torch serialized object at the given URL.

    If the downloaded file is a zip file, it will be automatically
    decompressed.

    If the object is already present in `model_dir`, it's deserialized and
    returned.
    The default value of `model_dir` is ``<hub_dir>/checkpoints`` where
    `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`.

    Args:
        url (string): URL of the object to download
        model_dir (string, optional): directory in which to save the object
        map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load)
        progress (bool, optional): whether or not to display a progress bar to stderr.
            Default: True
        check_hash (bool, optional): If True, the filename part of the URL should follow the naming convention
            ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
            digits of the SHA256 hash of the contents of the file. The hash is used to
            ensure unique names and to verify the contents of the file.
            Default: False
        file_name (string, optional): name for the downloaded file. Filename from `url` will be used if not set.

    Example:
        >>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')

    """
    # Issue warning to move data if old env is set
    if os.getenv('TORCH_MODEL_ZOO'):
        warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')

    if model_dir is None:
        hub_dir = get_dir()
        model_dir = os.path.join(hub_dir, 'checkpoints')

    try:
        os.makedirs(model_dir)
    except OSError as e:
        if e.errno == errno.EEXIST:
            # Directory already exists, ignore.
            pass
        else:
            # Unexpected OSError, re-raise.
            raise

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    if file_name is not None:
        filename = file_name
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = HASH_REGEX.search(filename).group(1) if check_hash else None
        download_url_to_file(url, cached_file, hash_prefix, progress=progress)

    # Note: extractall() overwrites existing files by default, so no cleanup is needed beforehand.
    # We deliberately don't handle tarfile here since our legacy serialization format was in tar.
    # E.g. resnet18-5c106cde.pth which is widely used.
    if zipfile.is_zipfile(cached_file):
        with zipfile.ZipFile(cached_file) as cached_zipfile:
            members = cached_zipfile.infolist()
            if len(members) != 1:
                raise RuntimeError('Only one file (not a directory) is allowed in the zipfile')
            cached_zipfile.extractall(model_dir)
            extracted_name = members[0].filename
            cached_file = os.path.join(model_dir, extracted_name)

    return torch.load(cached_file, map_location=map_location)