blob: bcf6689d9ef68c29718fd1a693a7bf1fd7cfbe81 [file] [log] [blame]
torch.hub
===================================
Pytorch Hub is a pre-trained model repository designed to facilitate research reproducibility.
Publishing models
-----------------
Pytorch Hub supports publishing pre-trained models(model definitions and pre-trained weights)
to a github repository by adding a simple ``hubconf.py`` file;
``hubconf.py`` can have multiple entrypoints. Each entrypoint is defined as a python function
(example: a pre-trained model you want to publish).
::
def entrypoint_name(*args, **kwargs):
# args & kwargs are optional, for models which take positional/keyword arguments.
...
How to implement an entrypoint?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Here is a code snippet from pytorch/vision repository, which specifies an entrypoint
for ``resnet18`` model. You can see a full script in
`pytorch/vision repo <https://github.com/pytorch/vision/blob/master/hubconf.py>`_
::
dependencies = ['torch']
def resnet18(pretrained=False, **kwargs):
"""
Resnet18 model
pretrained (bool): kwargs, load pretrained weights into the model
"""
# Call the model in the repo
from torchvision.models.resnet import resnet18 as _resnet18
model = _resnet18(pretrained=pretrained, **kwargs)
return model
- ``dependencies`` variable is a **list** of package names required to to run the model.
- ``pretrained`` controls whether to load the pre-trained weights provided by repo owners.
- ``args`` and ``kwargs`` are passed along to the real callable function.
- Docstring of the function works as a help message. It explains what does the model do and what
are the allowed positional/keyword arguments. It's highly recommended to add a few examples here.
- Entrypoint function should **ALWAYS** return a model(nn.module).
- Pretrained weights can either be stored local in the github repo, or loadable by
``torch.hub.load_state_dict_from_url()``. In the example above ``torchvision.models.resnet.resnet18``
handles ``pretrained``, alternatively you can put the following logic in the entrypoint.
::
if kwargs.get('pretrained', False):
# For checkpoint saved in local repo
model.load_state_dict(<path_to_saved_checkpoint>)
# For checkpoint saved elsewhere
checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))
Important Notice
^^^^^^^^^^^^^^^^
- The published models should be at least in a branch/tag. It can't be a random commit.
Loading models from Hub
-----------------------
Pytorch Hub provides convenient APIs to explore all available models in hub through ``torch.hub.list()``,
show docstring and examples through ``torch.hub.help()`` and load the pre-trained models using ``torch.hub.load()``
.. automodule:: torch.hub
.. autofunction:: list
.. autofunction:: help
.. autofunction:: load
Where are my downloaded models saved?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
The locations are used in the order of
- hub_dir: user specified path. It can be set in the following ways:
- Calling ``hub.set_dir(<PATH_TO_HUB_DIR>)``
- ``$TORCH_HOME/hub``, if environment variable ``TORCH_HOME`` is set.
- ``$XDG_CACHE_HOME/torch/hub``, if environment variable ``XDG_CACHE_HOME` is set.
- ``~/.cache/torch/hub``
.. autofunction:: set_dir
Caching logic
^^^^^^^^^^^^^
By default, we don't clean up files after loading it. Hub uses the cache by default if it already exists in ``hub_dir``.
Users can force a reload by calling ``hub.load(..., force_reload=True)``. This will delete
the existing github folder and downloaded weights, reinitialize a fresh download. This is useful
when updates are published to the same branch, users can keep up with the latest release.