Use non-Variable type for callsites that check type equality (#16325)
Summary:
When Variable and Tensor are merged, the dynamic type of tensors passed to certain functions will become a Variable type, and call sites that expect `type()` on those tensors to still return a non-Variable type will hit type-mismatch errors.
One way to fix this is to use the thread-local guard `at::AutoNonVariableTypeMode` to force `type()` to return the non-Variable type, but ideally the use of `at::AutoNonVariableTypeMode` should be limited to VariableType.cpp. The other way, which this PR takes, is to call `at::globalContext().getNonVariableType()` to look up the tensor's non-Variable type directly.
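To make the contrast concrete, here is a minimal sketch of the two approaches. The helper name `non_variable_type_of`, its signature, and the header choice are illustrative assumptions; the `at::AutoNonVariableTypeMode` guard and the `at::globalContext().getNonVariableType()` call are the APIs discussed above.

```cpp
#include <ATen/ATen.h>

// Illustrative helper (not part of this PR): get the non-Variable at::Type of a
// tensor whose dynamic type may be a Variable type after the Variable/Tensor merge.
at::Type& non_variable_type_of(const at::Tensor& t) {
  // Approach 1 (avoided outside VariableType.cpp): a thread-local guard that
  // forces t.type() itself to return the non-Variable type, e.g.
  //   at::AutoNonVariableTypeMode guard(true);
  //   return t.type();
  //
  // Approach 2 (what this PR uses at the affected call sites): look the
  // non-Variable type up explicitly from the tensor's backend and scalar type.
  return at::globalContext().getNonVariableType(
      t.type().backend(), t.type().scalarType());
}
```

The diff below applies the second pattern at each call site that compares `type()` results for equality.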
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16325
Differential Revision: D14012022
Pulled By: yf225
fbshipit-source-id: 77ef1d2a02f78bff0063bdd72596e34046f1e00d
diff --git a/aten/src/ATen/native/TensorIterator.cpp b/aten/src/ATen/native/TensorIterator.cpp
index 29257aa..f89f047 100644
--- a/aten/src/ATen/native/TensorIterator.cpp
+++ b/aten/src/ATen/native/TensorIterator.cpp
@@ -95,20 +95,21 @@
if (missing_dtypes || compute_common_dtype_) {
auto& type = compute_common_type();
for (auto& op : operands_) {
+ auto& op_tensor_type = at::globalContext().getNonVariableType(op.tensor.type().backend(), op.tensor.type().scalarType());
if (!op.type) {
op.type = &type;
} else if (compute_common_dtype_ && op.type != &type) {
if (allow_cpu_scalars_ && op.tensor.defined() && op.tensor.dim() == 0 &&
- type.device_type() == kCUDA && op.tensor.type().device_type() == kCPU) {
+ type.device_type() == kCUDA && op_tensor_type.device_type() == kCPU) {
// don't cast CPU scalars in CUDA ops that directly support them
- op.type = &op.tensor.type();
+ op.type = &op_tensor_type;
} else if (promote_gpu_output_dtypes_ && op.tensor.defined() &&
- !op.is_output && op.tensor.type().scalarType() == kHalf &&
+ !op.is_output && op_tensor_type.scalarType() == kHalf &&
type.scalarType() == kFloat && type.device_type() == kCUDA &&
- op.tensor.type().device_type() == kCUDA) {
+ op_tensor_type.device_type() == kCUDA) {
// allow input tensor type upcasting for fp16 to fp32 in fused kernel
// on GPU
- op.type = &op.tensor.type();
+ op.type = &op_tensor_type;
} else {
op.type = &type;
}
@@ -117,15 +118,16 @@
}
for (auto& op : operands_) {
- if (op.tensor.defined() && op.tensor.type() != *op.type) {
+ auto& op_tensor_type = at::globalContext().getNonVariableType(op.tensor.type().backend(), op.tensor.type().scalarType());
+ if (op.tensor.defined() && op_tensor_type != *op.type) {
if (op.is_output) {
- AT_ERROR("output with type ", op.tensor.type().toString(),
+ AT_ERROR("output with type ", op_tensor_type.toString(),
" doesn't match the desired type ", op.type->toString());
} else if (op.tensor.dim() == 0) {
op.tensor = op.tensor.to(*op.type);
} else {
AT_ERROR("expected type ", op.type->toString(), " but got ",
- op.tensor.type().toString());
+ op_tensor_type.toString());
}
}
}
diff --git a/torch/csrc/Generator.cpp b/torch/csrc/Generator.cpp
index 1b0be10..d13032d 100644
--- a/torch/csrc/Generator.cpp
+++ b/torch/csrc/Generator.cpp
@@ -82,8 +82,9 @@
throw TypeError("expected a torch.ByteTensor, but got %s", Py_TYPE(_new_state)->tp_name);
}
auto& tensor = ((THPVariable*)_new_state)->cdata.data();
- if (tensor.type() != CPU(kByte)) {
- auto type_name = torch::utils::type_to_string(tensor.type());
+ auto& tensor_type = at::globalContext().getNonVariableType(tensor.type().backend(), tensor.type().scalarType());
+ if (tensor_type != CPU(kByte)) {
+ auto type_name = torch::utils::type_to_string(tensor_type);
throw TypeError("expected a torch.ByteTensor, but got %s", type_name.c_str());
}
THGenerator *generator = THPGenerator_TH_CData(self);
diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp
index f747c6f..ab2037a 100644
--- a/torch/csrc/autograd/python_variable.cpp
+++ b/torch/csrc/autograd/python_variable.cpp
@@ -251,7 +251,8 @@
auto typeOpt = at::globalContext().getNonVariableTypeOpt(backend, var.type().scalarType());
if (typeOpt) {
auto& sparseType = at::globalContext().getNonVariableType(backend, var.type().scalarType());
- gradIsSparse = grad.type() == sparseType;
+ auto& gradType = at::globalContext().getNonVariableType(grad.type().backend(), grad.type().scalarType());
+ gradIsSparse = gradType == sparseType;
}
THPUtils_assertRet(-1, grad.type() == var.type() || gradIsSparse,
diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp
index e5577c0..8d1571a 100644
--- a/torch/csrc/cuda/Module.cpp
+++ b/torch/csrc/cuda/Module.cpp
@@ -149,7 +149,8 @@
PyObject * THCPModule_setRNGState(PyObject *_unused, PyObject *obj)
{
HANDLE_TH_ERRORS
- if (!THPVariable_Check(obj) || THPVariable_UnpackData(obj).type().ID() != at::TypeID::CPUByte) {
+ if (!THPVariable_Check(obj) ||
+     at::globalContext().getNonVariableType(THPVariable_Unpack(obj).type().backend(), THPVariable_Unpack(obj).type().scalarType()).ID() != at::TypeID::CPUByte) {
throw TypeError("set_rng_state expects a torch.ByteTensor, but got %s",
Py_TYPE(obj)->tp_name);
}
diff --git a/torch/csrc/nn/type_checks.h b/torch/csrc/nn/type_checks.h
index 1f3140c..966e32f 100644
--- a/torch/csrc/nn/type_checks.h
+++ b/torch/csrc/nn/type_checks.h
@@ -11,7 +11,8 @@
inline bool check_type(PyObject* obj, at::TypeID typeID) {
if (THPVariable_Check(obj)) {
- return ((THPVariable*)obj)->cdata.data().type().ID() == typeID;
+ auto& data_type = ((THPVariable*)obj)->cdata.type();
+ return at::globalContext().getNonVariableType(data_type.backend(), data_type.scalarType()).ID() == typeID;
}
return false;
}