Revert D30731191: [pytorch][PR] Torchhub: rewrite commit hash check to avoid using unnecessary GitHub API credits
Test Plan: revert-hammer
Differential Revision:
D30731191 (https://github.com/pytorch/pytorch/commit/f9bf144a0c5e3627f5fafb256cebf1f02152ab0c)
Original commit changeset: d1ee7c2ef259
fbshipit-source-id: 5c7207f66c5354ce7b9ac2594e4f5b8307619b0c
diff --git a/test/test_utils.py b/test/test_utils.py
index 2aa72cb..34f4406 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -702,8 +702,8 @@
def test_load_commit_from_forked_repo(self):
with self.assertRaisesRegex(
ValueError,
- 'Torchhub tried to look for 4e2c216'):
- torch.hub.load('pytorch/vision:4e2c216', 'resnet18', force_reload=True)
+ 'If it\'s a commit from a forked repo'):
+ model = torch.hub.load('pytorch/vision:4e2c216', 'resnet18', force_reload=True)
class TestHipify(TestCase):
def test_import_hipify(self):
diff --git a/torch/hub.py b/torch/hub.py
index cc91ae1..82287d8 100644
--- a/torch/hub.py
+++ b/torch/hub.py
@@ -10,7 +10,6 @@
import warnings
import zipfile
-from urllib.error import HTTPError
from urllib.request import urlopen, Request
from urllib.parse import urlparse # noqa: F401
@@ -119,83 +118,30 @@
with urlopen(url) as r:
return r.read().decode(r.headers.get_content_charset('utf-8'))
-def _branch_belongs_to_repo(repo_owner, repo_name, branch):
- # Return True if either:
- # - branch corresponds to a branch name in the repo
- # - branch corresponds to a tag name in the repo
- # - branch corresponds to a commit that has an associated tag in the repo. Technically,
- # this is an undocumented feature: we explicitly tell users that they should not pass
- # commit hashes.
- def find_in_refs(ref_kind):
- # We limit the search to 5k branches / tags, which should be more than enough
- for page in range(1, 50):
- url = f'https://api.github.com/repos/{repo_owner}/{repo_name}/{ref_kind}?per_page=100&page={page}'
- response = json.loads(_read_url(Request(url, headers=headers)))
- # Empty response means no more data to process
- if not response:
- return False
- for br in response:
- if br['name'] == branch or br['commit']['sha'].startswith(branch):
- return True
-
- # Note: this should never be executed, unless a repo really has 5k+ branches/tags
- warnings.warn(
- f"Torchhub tried to validate {branch} and looked at 5000 {ref_kind}, "
- "but couldn't find a match."
- )
- return False
-
+def _validate_not_a_forked_repo(repo_owner, repo_name, branch):
+ # Use urlopen to avoid depending on local git.
headers = {'Accept': 'application/vnd.github.v3+json'}
token = os.environ.get(ENV_GITHUB_TOKEN)
if token is not None:
headers['Authorization'] = f'token {token}'
- return any(find_in_refs(ref_kind) for ref_kind in ('branches', 'tags'))
+ for url_prefix in (
+ f'https://api.github.com/repos/{repo_owner}/{repo_name}/branches',
+ f'https://api.github.com/repos/{repo_owner}/{repo_name}/tags'):
+ page = 0
+ while True:
+ page += 1
+ url = f'{url_prefix}?per_page=100&page={page}'
+ response = json.loads(_read_url(Request(url, headers=headers)))
+ # Empty response means no more data to process
+ if not response:
+ break
+ for br in response:
+ if br['name'] == branch or br['commit']['sha'].startswith(branch):
+ return
-
-def _validate_branch(repo_owner, repo_name, branch):
- # Here we try to make sure that the branch isn't a potentially malicious commit.
- # This is important because in GitHub the download URL
- # f'https://github.com/{repo_owner}/{repo_name}/archive/{commit_hash}.zip'
- # may actually exist *even if* commit_hash is from a different repo_owner / fork.
- # For example this URL exists:
- # https://github.com/pytorch/vision/archive/8949c7011facf7801fdf077cc3e4ecd8f0940c7e.zip
- # even though the commit hash doesn't come from pytorch/vision. It actually comes from a fork
- # https://github.com/NicolasHug/vision/commit/8949c7011facf7801fdf077cc3e4ecd8f0940c7e
- # So we want avoid downloading code that can come from a potentionally malicious fork.
-
- try:
- # The above issue only exists for commit hashes. So if we know that ``branch``
- # isn't a commit hash, we can return early. This should avoid many GitHub API calls.
- # If we can't convert branch to a hex int we know for sure it's not a commit hash
- int(branch, 16)
- except ValueError:
- return
- # Note: here it's still possible that the branch param corresponds to a branch name
- # or a tag name, so we need to check for those as well.
-
- try:
- branch_found = _branch_belongs_to_repo(repo_owner, repo_name, branch)
- except HTTPError as e:
- if e.code == 403 and "rate limit exceeded" in str(e):
- raise ValueError(
- f"Torchhub was unable to verify that {branch} is indeed part of the {repo_owner}/{repo_name} repo, "
- f"as it received a GitHub API rate limit error. Consider setting the {ENV_GITHUB_TOKEN} env variable "
- "to a GitHub token that has more API credits. "
- "If you're absolutely sure about what you are doing, you can bypass this security check by setting "
- "skip_validation=True."
- ) from e
- else:
- raise
-
- if not branch_found:
- raise ValueError(
- f"Torchhub tried to look for {branch} in the {repo_owner}/{repo_name} repo, but couldn't find it. "
- "We error now to avoid downloading and executing code that is potentially malicious. "
- "Perhaps you specified the wrong repo owner, or a wrong branch / tag? If you're absolutely "
- "sure about what you are doing, you can bypass this security check by setting "
- "skip_validation=True."
- )
+ raise ValueError(f'Cannot find {branch} in https://github.com/{repo_owner}/{repo_name}. '
+ 'If it\'s a commit from a forked repo, please call hub.load() with forked repo directly.')
def _get_cache_or_reload(github, force_reload, verbose=True, skip_validation=False):
@@ -224,7 +170,7 @@
else:
# Validate the tag/branch is from the original repo instead of a forked repo
if not skip_validation:
- _validate_branch(repo_owner, repo_name, branch)
+ _validate_not_a_forked_repo(repo_owner, repo_name, branch)
cached_file = os.path.join(hub_dir, normalized_br + '.zip')
_remove_if_exists(cached_file)