Use non-Variable type for callsites that check type equality (#16325)

Summary:
When Variable and Tensor are merged, the dynamic type of the tensors passed to certain functions will become Variable, and callsites that expect `type()` on those tensors to still return a non-Variable type will hit a type mismatch error.
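
To make the failure concrete, here is a minimal sketch (illustrative only, not code from this PR) of the kind of callsite that breaks: once `tensor.type()` returns the Variable wrapper type, an identity comparison against a non-Variable `Type` such as `at::CPU(at::kByte)` no longer matches.

```cpp
#include <ATen/ATen.h>
#include <stdexcept>

// Sketch of an affected callsite (hypothetical helper, not from this PR).
// After the Variable/Tensor merge, `tensor.type()` yields a VariableType,
// so this comparison against the plain CPUByte Type fails even for a
// genuine CPU uint8 tensor.
void check_byte_tensor(const at::Tensor& tensor) {
  if (tensor.type() != at::CPU(at::kByte)) {
    throw std::runtime_error("expected a torch.ByteTensor");
  }
}
```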

One way to fix this is to use the thread-local guard `at::AutoNonVariableTypeMode` to force `type()` to return the non-Variable type, but ideally we want to limit the use of `at::AutoNonVariableTypeMode` to VariableType.cpp only. The other way, which this PR takes, is to call `at::globalContext().getNonVariableType()` at these callsites to look up the non-Variable type of the tensor directly.
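As a sketch of the rewrite the diff below applies at each callsite (the helper name here is mine, not part of the PR), the comparison is routed through the non-Variable type registry:

```cpp
#include <ATen/ATen.h>

// Hypothetical helper illustrating the pattern used at each callsite below:
// look up the non-Variable Type for the tensor's backend and scalar type
// instead of trusting tensor.type(), which may be a VariableType after the merge.
inline at::Type& non_variable_type_of(const at::Tensor& t) {
  return at::globalContext().getNonVariableType(
      t.type().backend(), t.type().scalarType());
}

bool is_cpu_byte(const at::Tensor& tensor) {
  // Before: tensor.type() == at::CPU(at::kByte)  -- breaks once type() is a VariableType.
  // After:  compare the underlying non-Variable type instead.
  // (The alternative would wrap the check in an at::AutoNonVariableTypeMode
  //  guard, which this PR deliberately avoids outside VariableType.cpp.)
  return non_variable_type_of(tensor) == at::CPU(at::kByte);
}
```
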
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16325

Differential Revision: D14012022

Pulled By: yf225

fbshipit-source-id: 77ef1d2a02f78bff0063bdd72596e34046f1e00d
diff --git a/aten/src/ATen/native/TensorIterator.cpp b/aten/src/ATen/native/TensorIterator.cpp
index 29257aa..f89f047 100644
--- a/aten/src/ATen/native/TensorIterator.cpp
+++ b/aten/src/ATen/native/TensorIterator.cpp
@@ -95,20 +95,21 @@
   if (missing_dtypes || compute_common_dtype_) {
     auto& type = compute_common_type();
     for (auto& op : operands_) {
+      auto& op_tensor_type = at::globalContext().getNonVariableType(op.tensor.type().backend(), op.tensor.type().scalarType());
       if (!op.type) {
         op.type = &type;
       } else if (compute_common_dtype_ && op.type != &type) {
         if (allow_cpu_scalars_ && op.tensor.defined() && op.tensor.dim() == 0 &&
-            type.device_type() == kCUDA && op.tensor.type().device_type() == kCPU) {
+            type.device_type() == kCUDA && op_tensor_type.device_type() == kCPU) {
           // don't cast CPU scalars in CUDA ops that directly support them
-          op.type = &op.tensor.type();
+          op.type = &op_tensor_type;
         } else if (promote_gpu_output_dtypes_ && op.tensor.defined() &&
-            !op.is_output && op.tensor.type().scalarType() == kHalf &&
+            !op.is_output && op_tensor_type.scalarType() == kHalf &&
             type.scalarType() == kFloat && type.device_type() == kCUDA &&
-            op.tensor.type().device_type() == kCUDA) {
+            op_tensor_type.device_type() == kCUDA) {
           // allow input tensor type upcasting for fp16 to fp32 in fused kernel
           // on GPU
-          op.type = &op.tensor.type();
+          op.type = &op_tensor_type;
         } else {
           op.type = &type;
         }
@@ -117,15 +118,16 @@
   }
 
   for (auto& op : operands_) {
-    if (op.tensor.defined() && op.tensor.type() != *op.type) {
+    auto& op_tensor_type = at::globalContext().getNonVariableType(op.tensor.type().backend(), op.tensor.type().scalarType());
+    if (op.tensor.defined() && op_tensor_type != *op.type) {
       if (op.is_output) {
-        AT_ERROR("output with type ", op.tensor.type().toString(),
+        AT_ERROR("output with type ", op_tensor_type.toString(),
                  " doesn't match the desired type ", op.type->toString());
       } else if (op.tensor.dim() == 0) {
         op.tensor = op.tensor.to(*op.type);
       } else {
         AT_ERROR("expected type ", op.type->toString(), " but got ",
-            op.tensor.type().toString());
+            op_tensor_type.toString());
       }
     }
   }
diff --git a/torch/csrc/Generator.cpp b/torch/csrc/Generator.cpp
index 1b0be10..d13032d 100644
--- a/torch/csrc/Generator.cpp
+++ b/torch/csrc/Generator.cpp
@@ -82,8 +82,9 @@
     throw TypeError("expected a torch.ByteTensor, but got %s", Py_TYPE(_new_state)->tp_name);
   }
   auto& tensor = ((THPVariable*)_new_state)->cdata.data();
-  if (tensor.type() != CPU(kByte)) {
-    auto type_name = torch::utils::type_to_string(tensor.type());
+  auto& tensor_type = at::globalContext().getNonVariableType(tensor.type().backend(), tensor.type().scalarType());
+  if (tensor_type != CPU(kByte)) {
+    auto type_name = torch::utils::type_to_string(tensor_type);
     throw TypeError("expected a torch.ByteTensor, but got %s", type_name.c_str());
   }
   THGenerator *generator = THPGenerator_TH_CData(self);
diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp
index f747c6f..ab2037a 100644
--- a/torch/csrc/autograd/python_variable.cpp
+++ b/torch/csrc/autograd/python_variable.cpp
@@ -251,7 +251,8 @@
   auto typeOpt = at::globalContext().getNonVariableTypeOpt(backend, var.type().scalarType());
   if (typeOpt) {
        auto& sparseType = at::globalContext().getNonVariableType(backend, var.type().scalarType());
-       gradIsSparse = grad.type() == sparseType;
+       auto& gradType = at::globalContext().getNonVariableType(grad.type().backend(), grad.type().scalarType());
+       gradIsSparse = gradType == sparseType;
   }
 
   THPUtils_assertRet(-1, grad.type() == var.type() || gradIsSparse,
diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp
index e5577c0..8d1571a 100644
--- a/torch/csrc/cuda/Module.cpp
+++ b/torch/csrc/cuda/Module.cpp
@@ -149,7 +149,8 @@
 PyObject * THCPModule_setRNGState(PyObject *_unused, PyObject *obj)
 {
   HANDLE_TH_ERRORS
-  if (!THPVariable_Check(obj) || THPVariable_UnpackData(obj).type().ID() != at::TypeID::CPUByte) {
+  auto& data_type = THPVariable_Unpack(obj).type();
+  if (!THPVariable_Check(obj) || at::globalContext().getNonVariableType(data_type.backend(), data_type.scalarType()).ID() != at::TypeID::CPUByte) {
     throw TypeError("set_rng_state expects a torch.ByteTensor, but got %s",
         Py_TYPE(obj)->tp_name);
   }
diff --git a/torch/csrc/nn/type_checks.h b/torch/csrc/nn/type_checks.h
index 1f3140c..966e32f 100644
--- a/torch/csrc/nn/type_checks.h
+++ b/torch/csrc/nn/type_checks.h
@@ -11,7 +11,8 @@
 
 inline bool check_type(PyObject* obj, at::TypeID typeID) {
   if (THPVariable_Check(obj)) {
-    return ((THPVariable*)obj)->cdata.data().type().ID() == typeID;
+    auto& data_type = ((THPVariable*)obj)->cdata.type();
+    return at::globalContext().getNonVariableType(data_type.backend(), data_type.scalarType()).ID() == typeID;
   }
   return false;
 }