Replace torch.is_tensor usages with isinstance checks. (#38062)
Summary:
`is_tensor` doesn't really have a reason to exist anymore (other than
backwards compatibility) and is worse for typechecking with mypy (see
gh-32824). Given that it may not be obvious what the fix is once mypy
gives an error, make the change in a number of places at once, and add
a note on this to the `is_tensor` docstring.
Recommending an isinstance check instead has been done for quite a
while, e.g. https://github.com/pytorch/pytorch/pull/7769#discussion_r190458971
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38062
Differential Revision: D21470963
Pulled By: ezyang
fbshipit-source-id: 98dd60d32ca0650abd2de21910b541d32b0eea41
diff --git a/test/distributed/test_c10d.py b/test/distributed/test_c10d.py
index a301ca2..7215674 100644
--- a/test/distributed/test_c10d.py
+++ b/test/distributed/test_c10d.py
@@ -2629,7 +2629,7 @@
# Run `forward` function with torch.no_grad()
with torch.no_grad():
output = model(input)
- self.assertTrue(torch.is_tensor(output))
+ self.assertTrue(isinstance(output, torch.Tensor))
# No parameter should have their gradient set.
check_no_grads()
diff --git a/test/onnx/test_pytorch_onnx_caffe2.py b/test/onnx/test_pytorch_onnx_caffe2.py
index b3bbb6b..494ebb0 100644
--- a/test/onnx/test_pytorch_onnx_caffe2.py
+++ b/test/onnx/test_pytorch_onnx_caffe2.py
@@ -125,7 +125,7 @@
cuda_model = model.cuda()
# input might be nested - we want to move everything to GPU
cuda_input = function._nested_map(
- lambda o: isinstance(o, Variable) or torch.is_tensor(o),
+ lambda o: isinstance(o, Variable) or isinstance(o, torch.Tensor),
lambda o: o.cuda())(input)
return cuda_model, cuda_input
diff --git a/test/test_autograd.py b/test/test_autograd.py
index e9ab495..b7ef750 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -4351,15 +4351,15 @@
class TestAutogradFunctional(TestCase):
def _assert_same_struct(self, res, base):
# base and res should be Tensors or tuple of Tensors with the same size
- if torch.is_tensor(base):
- self.assertTrue(torch.is_tensor(res))
+ if isinstance(base, torch.Tensor):
+ self.assertTrue(isinstance(res, torch.Tensor))
self.assertEqual(base.size(), res.size())
elif isinstance(base, tuple):
self.assertTrue(isinstance(res, tuple))
self.assertEqual(len(base), len(res))
for el_base, el_res in zip(base, res):
- self.assertTrue(torch.is_tensor(el_base))
- self.assertTrue(torch.is_tensor(el_res))
+ self.assertTrue(isinstance(el_base, torch.Tensor))
+ self.assertTrue(isinstance(el_res, torch.Tensor))
self.assertEqual(el_base.size(), el_res.size())
else:
# Wrong base
@@ -4374,22 +4374,22 @@
# - tuple, Tensor: res[i][k][l] = (base1[i][k], base2[l])
# - Tensor, tuple: res[i][j][l] = (base1[i], base2[j][l])
# - Tensor, Tensor: res[k][l] = (base1[k], base2[l])
- if torch.is_tensor(base1) and torch.is_tensor(base2):
- self.assertTrue(torch.is_tensor(res))
+ if isinstance(base1, torch.Tensor) and isinstance(base2, torch.Tensor):
+ self.assertTrue(isinstance(res, torch.Tensor))
self.assertEqual(res.size(), base1.size() + base2.size())
- elif isinstance(base1, tuple) and torch.is_tensor(base2):
+ elif isinstance(base1, tuple) and isinstance(base2, torch.Tensor):
self.assertTrue(isinstance(res, tuple))
self.assertEqual(len(res), len(base1))
for el_res, el_base1 in zip(res, base1):
- self.assertTrue(torch.is_tensor(el_res))
- self.assertTrue(torch.is_tensor(el_base1))
+ self.assertTrue(isinstance(el_res, torch.Tensor))
+ self.assertTrue(isinstance(el_base1, torch.Tensor))
self.assertEqual(el_res.size(), el_base1.size() + base2.size())
- elif torch.is_tensor(base1) and isinstance(base2, tuple):
+ elif isinstance(base1, torch.Tensor) and isinstance(base2, tuple):
self.assertTrue(isinstance(res, tuple))
self.assertEqual(len(res), len(base2))
for el_res, el_base2 in zip(res, base2):
- self.assertTrue(torch.is_tensor(el_res))
- self.assertTrue(torch.is_tensor(el_base2))
+ self.assertTrue(isinstance(el_res, torch.Tensor))
+ self.assertTrue(isinstance(el_base2, torch.Tensor))
self.assertEqual(el_res.size(), base1.size() + el_base2.size())
elif isinstance(base1, tuple) and isinstance(base2, tuple):
self.assertTrue(isinstance(res, tuple))
@@ -4398,8 +4398,8 @@
self.assertTrue(isinstance(el_res, tuple))
self.assertEqual(len(res), len(base2))
for el_el_res, el_base2 in zip(el_res, base2):
- self.assertTrue(torch.is_tensor(el_el_res))
- self.assertTrue(torch.is_tensor(el_base2))
+ self.assertTrue(isinstance(el_el_res, torch.Tensor))
+ self.assertTrue(isinstance(el_base2, torch.Tensor))
self.assertEqual(el_el_res.size(), el_base1.size() + el_base2.size())
else:
# Wrong bases
diff --git a/test/test_torch.py b/test/test_torch.py
index 06df202..2f076a3 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -13735,7 +13735,7 @@
def _test_unique_with_expects(self, device, dtype, f, x, expected_unique, expected_inverse, expected_counts, additional_shape):
def ensure_tuple(x):
- if torch.is_tensor(x):
+ if isinstance(x, torch.Tensor):
return (x,)
return x
@@ -13764,7 +13764,7 @@
return # CPU does not have half support
def ensure_tuple(x):
- if torch.is_tensor(x):
+ if isinstance(x, torch.Tensor):
return (x,)
return x
@@ -17718,13 +17718,13 @@
# Converts CPU tensors to device tensors
device_tensor = cpu_tensor.to(dtype=dtype, device=device)
- device_args = [arg.to(device=device) if torch.is_tensor(arg) else arg for arg in cpu_args]
+ device_args = [arg.to(device=device) if isinstance(arg, torch.Tensor) else arg for arg in cpu_args]
# Converts float device tensors to half/bfloat16 when the dtype is half/bfloat16
# Note: CPU half tensors don't support many operations.
if dtype in {torch.half, torch.bfloat16}:
device_args = [arg.to(dtype=dtype) if
- (torch.is_tensor(arg) and arg.dtype == torch.float) else arg
+ (isinstance(arg, torch.Tensor) and arg.dtype == torch.float) else arg
for arg in device_args]
# Runs the tensor op on CPU and device
diff --git a/torch/__init__.py b/torch/__init__.py
index b5a2c10..8ef0449 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -171,6 +171,11 @@
def is_tensor(obj):
r"""Returns True if `obj` is a PyTorch tensor.
+ Note that this function is simply doing ``isinstance(obj, Tensor)``.
+ Using that ``isinstance`` check is better for typechecking with mypy,
+ and more explicit - so it's recommended to use that instead of
+ ``is_tensor``.
+
Args:
obj (Object): Object to test
"""
diff --git a/torch/autograd/functional.py b/torch/autograd/functional.py
index 484e1e8..34c515f 100644
--- a/torch/autograd/functional.py
+++ b/torch/autograd/functional.py
@@ -11,7 +11,7 @@
is_inp_tuple = False
for i, el in enumerate(inp):
- if not torch.is_tensor(el):
+ if not isinstance(el, torch.Tensor):
if is_inp_tuple:
raise TypeError("The {} given to {} must be either a Tensor or a tuple of Tensors but the"
" value at index {} has type {}.".format(arg_name, fn_name, i, type(el)))
@@ -64,7 +64,7 @@
def _grad_postprocess(inputs, create_graph):
# Postprocess the generated Tensors to avoid returning Tensors with history when the user did not
# request it.
- if torch.is_tensor(inputs[0]):
+ if isinstance(inputs[0], torch.Tensor):
if not create_graph:
return tuple(inp.detach() for inp in inputs)
else:
@@ -540,7 +540,7 @@
is_out_tuple, t_out = _as_tuple(out, "outputs of the user-provided function", "hessian")
_check_requires_grad(t_out, "outputs", strict=strict)
- if is_out_tuple or not torch.is_tensor(out):
+ if is_out_tuple or not isinstance(out, torch.Tensor):
raise RuntimeError("The function given to hessian should return a single Tensor")
if out.nelement() != 1:
@@ -621,7 +621,7 @@
is_outputs_tuple, outputs = _as_tuple(outputs, "outputs of the user-provided function", "vhp")
_check_requires_grad(outputs, "outputs", strict=strict)
- if is_outputs_tuple or not torch.is_tensor(outputs[0]):
+ if is_outputs_tuple or not isinstance(outputs[0], torch.Tensor):
raise RuntimeError("The function given to vhp should return a single Tensor")
if outputs[0].nelement() != 1:
@@ -718,7 +718,7 @@
is_outputs_tuple, outputs = _as_tuple(outputs, "outputs of the user-provided function", "hvp")
_check_requires_grad(outputs, "outputs", strict=strict)
- if is_outputs_tuple or not torch.is_tensor(outputs[0]):
+ if is_outputs_tuple or not isinstance(outputs[0], torch.Tensor):
raise RuntimeError("The function given to hvp should return a single Tensor")
if outputs[0].nelement() != 1:
diff --git a/torch/distributions/utils.py b/torch/distributions/utils.py
index 00527d6..65636ab3 100644
--- a/torch/distributions/utils.py
+++ b/torch/distributions/utils.py
@@ -20,15 +20,15 @@
ValueError: if any of the values is not a `numbers.Number` or
`torch.*Tensor` instance
"""
- if not all(torch.is_tensor(v) or isinstance(v, Number) for v in values):
+ if not all(isinstance(v, torch.Tensor) or isinstance(v, Number) for v in values):
raise ValueError('Input arguments must all be instances of numbers.Number or torch.tensor.')
- if not all(map(torch.is_tensor, values)):
+ if not all([isinstance(v, torch.Tensor) for v in values]):
options = dict(dtype=torch.get_default_dtype())
for value in values:
- if torch.is_tensor(value):
+ if isinstance(value, torch.Tensor):
options = dict(dtype=value.dtype, device=value.device)
break
- values = [v if torch.is_tensor(v) else torch.tensor(v, **options)
+ values = [v if isinstance(v, torch.Tensor) else torch.tensor(v, **options)
for v in values]
return torch.broadcast_tensors(*values)
diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py
index 8deb24d..efa50b2 100644
--- a/torch/nn/modules/rnn.py
+++ b/torch/nn/modules/rnn.py
@@ -109,7 +109,7 @@
return
for w in self._flat_weights:
- if not torch.is_tensor(w):
+ if not isinstance(w, Tensor):
return
# Short-circuits if any tensor in self._flat_weights is not acceptable to cuDNN
# or the tensors in _flat_weights are of different dtypes
@@ -117,7 +117,7 @@
first_fw = self._flat_weights[0]
dtype = first_fw.dtype
for fw in self._flat_weights:
- if (not torch.is_tensor(fw.data) or not (fw.data.dtype == dtype) or
+ if (not isinstance(fw.data, Tensor) or not (fw.data.dtype == dtype) or
not fw.data.is_cuda or
not torch.backends.cudnn.is_acceptable(fw.data)):
return