|  | # mypy: ignore-errors | 
|  |  | 
|  | from typing import Any, Iterable | 
|  | from .version import __version__ as internal_version | 
|  | from ._vendor.packaging.version import Version, InvalidVersion | 
|  |  | 
|  | __all__ = ['TorchVersion'] | 
|  |  | 
|  |  | 
class TorchVersion(str):
    """A string with magic powers to compare to both Version and iterables!

    Prior to 1.10.0 torch.__version__ was stored as a str and so many did
    comparisons against torch.__version__ as if it were a str. In order not to
    break them we have TorchVersion, which masquerades as a str while also
    having the ability to compare against both packaging.version.Version as
    well as tuples of values, e.g. (1, 2, 1).

    If either operand cannot be parsed as a version (e.g. 'parrot'), the
    comparison falls back to the plain ``str`` comparison.

    Examples:
        Comparing a TorchVersion object to a Version object::

            TorchVersion('1.10.0a') > Version('1.10.0a')

        Comparing a TorchVersion object to a Tuple object::

            TorchVersion('1.10.0a') > (1, 2)    # 1.2
            TorchVersion('1.10.0a') > (1, 2, 1) # 1.2.1

        Comparing a TorchVersion object against a string::

            TorchVersion('1.10.0a') > '1.2'
            TorchVersion('1.10.0a') > '1.2.1'
    """

    def _convert_to_version(self, inp: Any) -> Any:
        """Coerce ``inp`` into a ``Version`` so it can be compared with ``self``.

        Args:
            inp: A ``Version`` (returned unchanged), a version string, or an
                iterable of version components such as ``(1, 20, 1)``.

        Returns:
            The corresponding ``Version`` object.

        Raises:
            InvalidVersion: if ``inp`` is none of the supported kinds, or if
                it does not form a parseable version.
        """
        if isinstance(inp, Version):
            return inp
        elif isinstance(inp, str):
            return Version(inp)
        elif isinstance(inp, Iterable):
            # Ideally this should work for most cases by attempting to group
            # the version tuple, assuming the tuple looks like
            # (MAJOR, MINOR, ?PATCH)
            # Examples:
            #   * (1,)        -> Version("1")
            #   * (1, 20)     -> Version("1.20")
            #   * (1, 20, 1)  -> Version("1.20.1")
            # Note: a bare int such as (1) is not iterable and is rejected.
            return Version('.'.join(str(item) for item in inp))
        else:
            raise InvalidVersion(inp)

    def _cmp_wrapper(self, cmp: Any, method: str) -> bool:
        """Apply the rich-comparison dunder ``method`` to ``self`` and ``cmp``.

        Both operands are first parsed as versions. If that fails with
        ``InvalidVersion``, the same ``str`` comparison is used instead. For
        a non-str ``cmp`` (such as a tuple), that ``str`` method may return
        ``NotImplemented``, which lets Python try the reflected operation.

        Args:
            cmp: The right-hand operand (Version, str, or iterable).
            method: Name of the comparison dunder, e.g. ``"__lt__"``.
        """
        try:
            return getattr(Version(self), method)(self._convert_to_version(cmp))
        except InvalidVersion:
            # Fall back to regular string comparison if dealing with an invalid
            # version like 'parrot'. Every other exception propagates.
            return getattr(super(), method)(cmp)
|  |  | 
|  |  | 
# Route the rich comparison operators through ``TorchVersion._cmp_wrapper`` so
# they compare parsed versions rather than raw characters. ``__ne__`` is
# included so that ``!=`` stays the negation of ``==``. Without it the
# inherited ``str.__ne__`` would be used, and e.g.
# ``TorchVersion("1.10") != "1.10.0"`` would be True even though ``==`` is
# also True. ``method=cmp_method`` binds the loop value at definition time;
# a plain closure would late-bind, and every operator would dispatch to the
# last method name in the list.
for cmp_method in ["__gt__", "__lt__", "__eq__", "__ne__", "__ge__", "__le__"]:
    setattr(
        TorchVersion,
        cmp_method,
        lambda x, y, method=cmp_method: x._cmp_wrapper(y, method),
    )

# The module's version string, wrapped so it compares as a version.
__version__ = TorchVersion(internal_version)