[PyTorch][codemod] Replace immediately-dereferenced expect calls w/expectRef (#50228)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50228

`fastmod -m 'expect(<((at|c10)::)?\w+Type>\(\)\s*)->'
'expectRef${1}.'`
Presuming it builds, this is a safe change: the result of `expect()`
wasn't being saved anywhere, so we didn't need it, so we can take a
reference instead of a new `shared_ptr`.
ghstack-source-id: 119782961

Test Plan: CI

Reviewed By: SplitInfinity

Differential Revision: D25837374

fbshipit-source-id: 86757b70b1520e3dbaa141001e7976400cdd3b08
diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp
index 1223577..6ff7a52 100644
--- a/aten/src/ATen/core/ivalue.cpp
+++ b/aten/src/ATen/core/ivalue.cpp
@@ -417,7 +417,7 @@
     std::ostream& out,
     const IValue& the_list,
     IValueFormatter formatter) {
-  auto list_elem_type = the_list.type()->expect<ListType>()->getElementType();
+  auto list_elem_type = the_list.type()->expectRef<ListType>().getElementType();
   if (the_list.toListRef().size() == 0 ||
       !elementTypeCanBeInferredFromMembers(list_elem_type)) {
     out << "annotate(" << the_list.type()->annotation_str() << ", ";
diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp
index ec1ec97..878d08c 100644
--- a/aten/src/ATen/core/type.cpp
+++ b/aten/src/ATen/core/type.cpp
@@ -252,7 +252,7 @@
 
   // Handle non-container types which do not subtype each other and unify
   if (t1->kind() == TensorType::Kind && t2->kind() == TensorType::Kind) {
-    return t1->expect<TensorType>()->merge(*t2->expect<TensorType>());
+    return t1->expectRef<TensorType>().merge(*t2->expect<TensorType>());
   }
 
   if (t1->isSubtypeOf(NoneType::get()) && !t2->isSubtypeOf(NoneType::get())) {
@@ -1317,7 +1317,7 @@
     TORCH_CHECK(
         (type->kind() == TensorType::Kind) ||
             (type->kind() == OptionalType::Kind &&
-            type->expect<OptionalType>()->getElementType()->kind() ==
+            type->expectRef<OptionalType>().getElementType()->kind() ==
                 TensorType::Kind) ||
             (type->kind() == NoneType::Kind),
         "Expecting parameter or buffer to have either None, Tensor or Optional[Tensor] type, but got: ",
diff --git a/test/cpp/jit/test_jit_type.cpp b/test/cpp/jit/test_jit_type.cpp
index 9462a57..8fd14d2 100644
--- a/test/cpp/jit/test_jit_type.cpp
+++ b/test/cpp/jit/test_jit_type.cpp
@@ -18,7 +18,7 @@
   TORCH_INTERNAL_ASSERT(!tensor->isSubtypeOf(opt_bool_tensor));
   auto unified = unifyTypes(opt_bool_tensor, tensor);
   TORCH_INTERNAL_ASSERT(unified);
-  auto elem = (*unified)->expect<OptionalType>()->getElementType();
+  auto elem = (*unified)->expectRef<OptionalType>().getElementType();
   TORCH_INTERNAL_ASSERT(elem->isSubtypeOf(TensorType::get()));
 
   auto opt_tuple_none_int = OptionalType::create(
diff --git a/test/cpp/jit/test_misc.cpp b/test/cpp/jit/test_misc.cpp
index eaec2fe..3256c89 100644
--- a/test/cpp/jit/test_misc.cpp
+++ b/test/cpp/jit/test_misc.cpp
@@ -504,18 +504,18 @@
   ASSERT_TRUE(IntType::get()->isSubtypeOf(s.arguments()
                                               .at(0)
                                               .type()
-                                              ->expect<ListType>()
-                                              ->getElementType()
-                                              ->expect<ListType>()
-                                              ->getElementType()));
+                                              ->expectRef<ListType>()
+                                              .getElementType()
+                                              ->expectRef<ListType>()
+                                              .getElementType()));
   auto s2 = parseSchema("at::what(int[][] foo) -> ()");
   ASSERT_TRUE(IntType::get()->isSubtypeOf(s2.arguments()
                                               .at(0)
                                               .type()
-                                              ->expect<ListType>()
-                                              ->getElementType()
-                                              ->expect<ListType>()
-                                              ->getElementType()));
+                                              ->expectRef<ListType>()
+                                              .getElementType()
+                                              ->expectRef<ListType>()
+                                              .getElementType()));
 }
 
 TEST(SchemaParserTest, NamedReturns) {
@@ -531,7 +531,7 @@
   // futures
   auto s4 = parseSchema("at::what(Future(int) foo) -> ()");
   ASSERT_TRUE(IntType::get()->isSubtypeOf(
-      s4.arguments().at(0).type()->expect<FutureType>()->getElementType()));
+      s4.arguments().at(0).type()->expectRef<FutureType>().getElementType()));
 }
 
 TEST(SchemaParserTest, AnnotatedAliasSets) {
@@ -1751,7 +1751,7 @@
   });
   ASSERT_NE(guard, nodes.end());
   ASSERT_EQ(
-      guard->input()->type()->expect<TensorType>()->sizes().size(),
+      guard->input()->type()->expectRef<TensorType>().sizes().size(),
       c10::nullopt);
   checkShape(*guard, {2, 3}, false);
   auto is_guard = [](Node* n) { return n->kind() == prim::Guard; };
diff --git a/torch/csrc/jit/codegen/cuda/partition.cpp b/torch/csrc/jit/codegen/cuda/partition.cpp
index 88aaa3e..63ec915 100644
--- a/torch/csrc/jit/codegen/cuda/partition.cpp
+++ b/torch/csrc/jit/codegen/cuda/partition.cpp
@@ -20,7 +20,7 @@
     // not tensor type, return false as the op is not outputing scalar.
     return c10::nullopt;
   }
-  return value->type()->expect<TensorType>()->device();
+  return value->type()->expectRef<TensorType>().device();
 }
 
 static c10::optional<c10::Device> getDevice(const Node* node) {
diff --git a/torch/csrc/jit/codegen/fuser/codegen.cpp b/torch/csrc/jit/codegen/fuser/codegen.cpp
index eadc52d..522d819 100644
--- a/torch/csrc/jit/codegen/fuser/codegen.cpp
+++ b/torch/csrc/jit/codegen/fuser/codegen.cpp
@@ -91,7 +91,7 @@
     return "double";
   } else if (t->kind() == TypeKind::BoolType) {
     return "bool";
-  } else if (auto scalar_type = t->expect<TensorType>()->scalarType()) {
+  } else if (auto scalar_type = t->expectRef<TensorType>().scalarType()) {
     return calcScalarTypeName(*scalar_type);
   }
   // something went wrong with the type analysis during shape propagation
@@ -118,7 +118,7 @@
   } else if (t->kind() == TypeKind::NoneType) {
     // Support None value for optional arguments like memory format
     return vn;
-  } else if (auto scalar_type = t->expect<TensorType>()->scalarType()) {
+  } else if (auto scalar_type = t->expectRef<TensorType>().scalarType()) {
     if (*scalar_type != outtype) {
       return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")";
     }
@@ -261,7 +261,7 @@
   } else {
     size_t i = 0;
 
-    auto outtype = n->output()->type()->expect<TensorType>()->scalarType();
+    auto outtype = n->output()->type()->expectRef<TensorType>().scalarType();
     TORCH_INTERNAL_ASSERT(outtype);
 
     for (auto in : n->inputs()) {
diff --git a/torch/csrc/jit/codegen/fuser/compiler.cpp b/torch/csrc/jit/codegen/fuser/compiler.cpp
index e49a6a6..e3e2382 100644
--- a/torch/csrc/jit/codegen/fuser/compiler.cpp
+++ b/torch/csrc/jit/codegen/fuser/compiler.cpp
@@ -260,7 +260,7 @@
       sizes.at(o->node()->i(attr::dim)) *= o->node()->inputs().size();
     }
 
-    auto scalar_type = o->type()->expect<TensorType>()->scalarType();
+    auto scalar_type = o->type()->expectRef<TensorType>().scalarType();
     TORCH_INTERNAL_ASSERT(scalar_type);
     auto type = TensorType::createContiguous(*scalar_type, device, sizes);
     output_desc.emplace_back(type);
diff --git a/torch/csrc/jit/frontend/ir_emitter.cpp b/torch/csrc/jit/frontend/ir_emitter.cpp
index 0dd84e4..954648f 100644
--- a/torch/csrc/jit/frontend/ir_emitter.cpp
+++ b/torch/csrc/jit/frontend/ir_emitter.cpp
@@ -2976,7 +2976,7 @@
           return std::make_shared<SimpleValue>(
               graph
                   ->insertNode(graph->createList(
-                      type->expect<ListType>()->getElementType(), {}))
+                      type->expectRef<ListType>().getElementType(), {}))
                   ->output());
         }
         // list(iter) desugars to [_elem for _elem in iter]
@@ -3376,7 +3376,7 @@
         TypePtr elem_type = TensorType::get();
         if (type_hint) {
           if (type_hint->kind() == TypeKind::ListType) {
-            elem_type = type_hint->expect<ListType>()->getElementType();
+            elem_type = type_hint->expectRef<ListType>().getElementType();
           } else {
             // If the type hint was not a List[T] throw an error
             throw ErrorReport(tree)
diff --git a/torch/csrc/jit/frontend/schema_matching.cpp b/torch/csrc/jit/frontend/schema_matching.cpp
index 86c1c2f..977cc3d 100644
--- a/torch/csrc/jit/frontend/schema_matching.cpp
+++ b/torch/csrc/jit/frontend/schema_matching.cpp
@@ -72,7 +72,7 @@
     if (convertibleToList(value->type(), unwrapOptional(concrete_type))) {
       auto unpacked = createTupleUnpack(value);
       auto elem_type =
-          unwrapOptional(concrete_type)->expect<ListType>()->getElementType();
+          unwrapOptional(concrete_type)->expectRef<ListType>().getElementType();
       value = graph.insertNode(graph.createList(elem_type, unpacked))->output();
     }
 
@@ -340,8 +340,9 @@
         // The actual cannot already be a list
         if (actual_type->kind() != TypeKind::ListType &&
             !convertibleToList(actual_type, unwrapOptional(arg.type()))) {
-          auto formal_type =
-              unwrapOptional(arg.type())->expect<ListType>()->getElementType();
+          auto formal_type = unwrapOptional(arg.type())
+                                 ->expectRef<ListType>()
+                                 .getElementType();
 
           Value* list = tryCreateList(
               formal_type,
diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp
index 1ca0f48..45d4ece 100644
--- a/torch/csrc/jit/ir/alias_analysis.cpp
+++ b/torch/csrc/jit/ir/alias_analysis.cpp
@@ -63,7 +63,7 @@
       }
       case TypeKind::TupleType: {
         std::vector<TypePtr> mutable_types;
-        for (const auto& elem : type->expect<TupleType>()->elements()) {
+        for (const auto& elem : type->expectRef<TupleType>().elements()) {
           if (auto mut_elem = getMutableType(elem)) {
             mutable_types.push_back(*mut_elem);
           }
@@ -511,7 +511,7 @@
     case prim::GetAttr:
       if (isFrozen_ && node->kind() == prim::GetAttr) {
         auto& ty = node->input()->type();
-        if (ty->expect<ClassType>()->is_module()) {
+        if (ty->expectRef<ClassType>().is_module()) {
           return analyzeCreator(node);
         }
       }
diff --git a/torch/csrc/jit/passes/clear_undefinedness.cpp b/torch/csrc/jit/passes/clear_undefinedness.cpp
index b235cbb..568441d 100644
--- a/torch/csrc/jit/passes/clear_undefinedness.cpp
+++ b/torch/csrc/jit/passes/clear_undefinedness.cpp
@@ -10,7 +10,7 @@
     o->setType(TensorType::get());
   } else if (
       o->type()->kind() == ListType::Kind &&
-      o->type()->expect<ListType>()->getElementType()->kind() ==
+      o->type()->expectRef<ListType>().getElementType()->kind() ==
           TensorType::Kind) {
     o->setType(ListType::create(TensorType::get()));
   }
diff --git a/torch/csrc/jit/passes/decompose_ops.cpp b/torch/csrc/jit/passes/decompose_ops.cpp
index ad8dcf3..d7ca569 100644
--- a/torch/csrc/jit/passes/decompose_ops.cpp
+++ b/torch/csrc/jit/passes/decompose_ops.cpp
@@ -39,7 +39,7 @@
   if (!input->type()->isSubtypeOf(TensorType::get())) {
     return false;
   }
-  auto device = input->type()->expect<TensorType>()->device();
+  auto device = input->type()->expectRef<TensorType>().device();
   // As of now, we do the decomposition for batchnorm/layernorm on GPU device
   // only
   if (!device || (*device).is_cpu()) {
diff --git a/torch/csrc/jit/passes/freeze_module.cpp b/torch/csrc/jit/passes/freeze_module.cpp
index 38684d2..18e90a2 100644
--- a/torch/csrc/jit/passes/freeze_module.cpp
+++ b/torch/csrc/jit/passes/freeze_module.cpp
@@ -156,7 +156,7 @@
       Module& attrModule,
       std::shared_ptr<Graph>& graph) {
     if (!input->type()->cast<InterfaceType>() &&
-        !input->type()->expect<ClassType>()->is_module()) {
+        !input->type()->expectRef<ClassType>().is_module()) {
       return false;
     }
 
@@ -425,7 +425,7 @@
           if (!findConstantAttr(input, name, attrModule, graph)) {
             GRAPH_DEBUG(
                 input->type()->cast<InterfaceType>() ||
-                        input->type()->expect<ClassType>()->is_module()
+                        input->type()->expectRef<ClassType>().is_module()
                     ? "attribute: " + name + " is mutable."
                     : "");
             continue;
diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp
index f09f945..8243a2c 100644
--- a/torch/csrc/jit/passes/graph_fuser.cpp
+++ b/torch/csrc/jit/passes/graph_fuser.cpp
@@ -177,7 +177,7 @@
     if (!v->type()->isSubtypeOf(TensorType::get())) {
       return true;
     }
-    auto device = v->type()->expect<TensorType>()->device();
+    auto device = v->type()->expectRef<TensorType>().device();
     if (!device) {
       return !strict_fuser_check;
     }
diff --git a/torch/csrc/jit/passes/graph_rewrite_helper.cpp b/torch/csrc/jit/passes/graph_rewrite_helper.cpp
index 8072174..34d7fd6 100644
--- a/torch/csrc/jit/passes/graph_rewrite_helper.cpp
+++ b/torch/csrc/jit/passes/graph_rewrite_helper.cpp
@@ -9,7 +9,7 @@
 namespace graph_rewrite_helper {
 
 std::string getFuncName(Value* func_value) {
-  auto func = func_value->type()->expect<FunctionType>()->function();
+  auto func = func_value->type()->expectRef<FunctionType>().function();
   const auto& qname = func->qualname();
   const auto& name = qname.qualifiedName();
   auto rdot_idx = name.rfind('.');
diff --git a/torch/csrc/jit/passes/guard_elimination.cpp b/torch/csrc/jit/passes/guard_elimination.cpp
index c26404d..4f3a96b 100644
--- a/torch/csrc/jit/passes/guard_elimination.cpp
+++ b/torch/csrc/jit/passes/guard_elimination.cpp
@@ -242,7 +242,7 @@
     size_t i = 0;
     for (auto input : n->inputs()) {
       if ((input->node()->kind() == prim::Guard &&
-           !input->type()->expect<TensorType>()->isSummarized()) ||
+           !input->type()->expectRef<TensorType>().isSummarized()) ||
           input->node()->kind() == prim::Constant ||
           (allow_numbers && input->type()->isSubtypeOf(NumberType::get())) ||
           except.count(i) != 0) {
@@ -377,7 +377,7 @@
       case aten::conv3d:
         return checkInputs(n, std::unordered_set<size_t>{2, 6}, false);
       case aten::slice:
-        return !n->input(0)->type()->expect<TensorType>()->isSummarized() &&
+        return !n->input(0)->type()->expectRef<TensorType>().isSummarized() &&
             // check that the dimension argument is constant
             n->input(1)->node()->kind() == prim::Constant &&
             // the start offset is constant
@@ -389,7 +389,7 @@
       case aten::max_pool1d:
       case aten::max_pool2d:
       case aten::max_pool3d:
-        return !n->input(0)->type()->expect<TensorType>()->isSummarized() &&
+        return !n->input(0)->type()->expectRef<TensorType>().isSummarized() &&
             // check that the kernel size is constant
             n->input(1)->node()->kind() == prim::Constant &&
             // check that the stride is constant
@@ -402,7 +402,7 @@
             n->input(5)->node()->kind() == prim::Constant;
       case aten::unsqueeze:
         // check that the dimension argument is constant
-        return !n->input(0)->type()->expect<TensorType>()->isSummarized() &&
+        return !n->input(0)->type()->expectRef<TensorType>().isSummarized() &&
             n->input(1)->node()->kind() == prim::Constant;
       case aten::cat:
         // check that the dimension argument is constant
@@ -427,8 +427,8 @@
             // aten::size is effectively a constant
             if (asize->input()
                     ->type()
-                    ->expect<TensorType>()
-                    ->sizes()
+                    ->expectRef<TensorType>()
+                    .sizes()
                     .concrete_sizes()) {
               return true;
             }
diff --git a/torch/csrc/jit/passes/onnx/peephole.cpp b/torch/csrc/jit/passes/onnx/peephole.cpp
index 8e23024..d488201 100644
--- a/torch/csrc/jit/passes/onnx/peephole.cpp
+++ b/torch/csrc/jit/passes/onnx/peephole.cpp
@@ -138,14 +138,14 @@
       // Not all broadcasts are supported by ONNX broadcast.
       c10::optional<size_t> axis = fusibleExpandTo(
           unexpanded_input->type()
-              ->expect<TensorType>()
-              ->sizes()
+              ->expectRef<TensorType>()
+              .sizes()
               .concrete_sizes()
               .value(), // from
           n->output()
               ->type()
-              ->expect<TensorType>()
-              ->sizes()
+              ->expectRef<TensorType>()
+              .sizes()
               .concrete_sizes()
               .value()); // to
       if (axis == c10::nullopt)
diff --git a/torch/csrc/jit/passes/peephole.cpp b/torch/csrc/jit/passes/peephole.cpp
index fda3230..f3786c5 100644
--- a/torch/csrc/jit/passes/peephole.cpp
+++ b/torch/csrc/jit/passes/peephole.cpp
@@ -425,11 +425,11 @@
 
             // Attempts to find a matrix with a defined scalar type to type as
             auto* type_as_mat = mat1;
-            if (!type_as_mat->type()->expect<TensorType>()->scalarType()) {
+            if (!type_as_mat->type()->expectRef<TensorType>().scalarType()) {
               type_as_mat = mat2;
             }
             auto mat_scalar_type =
-                type_as_mat->type()->expect<TensorType>()->scalarType();
+                type_as_mat->type()->expectRef<TensorType>().scalarType();
 
             // we can't use type_as if we don't know the target type (mm), the
             // bias needs to be coerced to
diff --git a/torch/csrc/jit/passes/quantization/helper.cpp b/torch/csrc/jit/passes/quantization/helper.cpp
index 1549b4f..0692d4c 100644
--- a/torch/csrc/jit/passes/quantization/helper.cpp
+++ b/torch/csrc/jit/passes/quantization/helper.cpp
@@ -532,7 +532,7 @@
 
 std::shared_ptr<Graph> getCallFunctionGraph(Node* n) {
   auto* func_node = n->input(0)->node();
-  auto func = func_node->output()->type()->expect<FunctionType>()->function();
+  auto func = func_node->output()->type()->expectRef<FunctionType>().function();
   TORCH_CHECK(
       func->isGraphFunction(), "Quantization only works for graph function");
   return func->graph();
diff --git a/torch/csrc/jit/passes/update_differentiable_graph_requires_grad.cpp b/torch/csrc/jit/passes/update_differentiable_graph_requires_grad.cpp
index 5013dcf..c9aca34 100644
--- a/torch/csrc/jit/passes/update_differentiable_graph_requires_grad.cpp
+++ b/torch/csrc/jit/passes/update_differentiable_graph_requires_grad.cpp
@@ -20,8 +20,8 @@
       n->ty_(
           attr::profiled_type,
           n->ty(attr::profiled_type)
-              ->expect<TensorType>()
-              ->withRequiresGrad(new_requires_grad));
+              ->expectRef<TensorType>()
+              .withRequiresGrad(new_requires_grad));
     }
     for (Block* b : n->blocks()) {
       UpdateDifferentiableGraphRequiresGrad(b, new_requires_grad);
diff --git a/torch/csrc/jit/python/pybind_utils.cpp b/torch/csrc/jit/python/pybind_utils.cpp
index 19de83d..da16a67 100644
--- a/torch/csrc/jit/python/pybind_utils.cpp
+++ b/torch/csrc/jit/python/pybind_utils.cpp
@@ -78,7 +78,7 @@
       return static_cast<int64_t>(stream->cdata);
     }
     case TypeKind::ListType: {
-      const auto& elem_type = type->expect<ListType>()->getElementType();
+      const auto& elem_type = type->expectRef<ListType>().getElementType();
       switch (elem_type->kind()) {
         // allows single int/float to be broadcasted to a fixed size list
         case TypeKind::IntType:
@@ -127,7 +127,7 @@
         // return an IValue() to denote a NoneType
         return {};
       }
-      return toIValue(obj, type->expect<OptionalType>()->getElementType());
+      return toIValue(obj, type->expectRef<OptionalType>().getElementType());
     }
     case TypeKind::ClassType: {
       auto classType = type->expect<ClassType>();
diff --git a/torch/csrc/jit/python/python_custom_class.cpp b/torch/csrc/jit/python/python_custom_class.cpp
index a4120d4..521b93f 100644
--- a/torch/csrc/jit/python/python_custom_class.cpp
+++ b/torch/csrc/jit/python/python_custom_class.cpp
@@ -31,7 +31,7 @@
   py::class_<ScriptClass>(m, "ScriptClass")
       .def("__call__", &ScriptClass::__call__)
       .def_property_readonly("__doc__", [](const ScriptClass& self) {
-        return self.class_type_.type_->expect<ClassType>()->doc_string();
+        return self.class_type_.type_->expectRef<ClassType>().doc_string();
       });
 
   // This function returns a ScriptClass that wraps the constructor
diff --git a/torch/csrc/jit/python/python_ir.cpp b/torch/csrc/jit/python/python_ir.cpp
index 5542f53..ab8545d 100644
--- a/torch/csrc/jit/python/python_ir.cpp
+++ b/torch/csrc/jit/python/python_ir.cpp
@@ -434,7 +434,7 @@
       .VS(requires_grad)
       .def(
           "requiresGrad",
-          [](Value& n) { n.type()->expect<TensorType>()->requiresGrad(); })
+          [](Value& n) { n.type()->expectRef<TensorType>().requiresGrad(); })
       .def("toIValue", [](Value& n) { return toIValue(&n); })
       .def("type", [](Value& v) { return v.type(); });
 #undef VS
@@ -686,7 +686,7 @@
       .def(
           "dim",
           [](Type& t) {
-            auto vshape = t.shared_from_this()->expect<TensorType>()->sizes();
+            auto vshape = t.shared_from_this()->expectRef<TensorType>().sizes();
             return vshape.size() ? py::cast(*vshape.size())
                                  : py::cast<py::none>(Py_None);
           })
@@ -694,7 +694,7 @@
           "undefined",
           [](Type& t) {
             auto undef =
-                t.shared_from_this()->expect<TensorType>()->undefined();
+                t.shared_from_this()->expectRef<TensorType>().undefined();
             return undef.has_value() ? py::cast(*undef)
                                      : py::cast<py::none>(Py_None);
           })
@@ -732,13 +732,13 @@
           "contiguous",
           [](Type& t) {
             return std::static_pointer_cast<Type>(
-                t.expect<TensorType>()->contiguous());
+                t.expectRef<TensorType>().contiguous());
           })
       .def(
           "scalarType",
           [](Type& t) {
             auto scalar_type =
-                t.shared_from_this()->expect<TensorType>()->scalarType();
+                t.shared_from_this()->expectRef<TensorType>().scalarType();
             return (scalar_type) ? toString(*scalar_type) : nullptr;
           })
       .def(
diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp
index f6f6fbf..92a0611 100644
--- a/torch/csrc/jit/python/python_sugared_value.cpp
+++ b/torch/csrc/jit/python/python_sugared_value.cpp
@@ -631,7 +631,7 @@
 
   // Check if it's a property.
   auto prop =
-      concreteType_->getJitType()->expect<ClassType>()->getProperty(field);
+      concreteType_->getJitType()->expectRef<ClassType>().getProperty(field);
   if (prop) {
     return MethodValue(self_, prop->getter->name())
         .call(loc, m, {}, {}, /*n_binders=*/1);
@@ -647,7 +647,8 @@
 
   throw ErrorReport(loc)
       << "Module '"
-      << concreteType_->getJitType()->expect<ClassType>()->name()->name() << "'"
+      << concreteType_->getJitType()->expectRef<ClassType>().name()->name()
+      << "'"
       << " has no attribute '" << field << "' " << hint;
 }
 
diff --git a/torch/csrc/jit/runtime/interpreter.cpp b/torch/csrc/jit/runtime/interpreter.cpp
index 2d18d6f..5e6c9a9 100644
--- a/torch/csrc/jit/runtime/interpreter.cpp
+++ b/torch/csrc/jit/runtime/interpreter.cpp
@@ -847,7 +847,7 @@
 
   void emitTupleConstruct(Node* node) {
     bool named =
-        node->output()->type()->expect<TupleType>()->name().has_value();
+        node->output()->type()->expectRef<TupleType>().name().has_value();
     if (named) {
       emitContainerConstruct(NAMED_TUPLE_CONSTRUCT, node);
     } else {
@@ -938,7 +938,7 @@
         break;
       case prim::CallFunction:
         emitCall(
-            node->inputs().at(0)->type()->expect<FunctionType>()->function(),
+            node->inputs().at(0)->type()->expectRef<FunctionType>().function(),
             node->inputs().slice(1));
         break;
       case prim::CallMethod:
diff --git a/torch/csrc/jit/runtime/register_c10_ops.cpp b/torch/csrc/jit/runtime/register_c10_ops.cpp
index 4e1a4fb..e31c13a 100644
--- a/torch/csrc/jit/runtime/register_c10_ops.cpp
+++ b/torch/csrc/jit/runtime/register_c10_ops.cpp
@@ -46,7 +46,7 @@
             node->addInput(none);
             continue;
           } else {
-            type = type->expect<OptionalType>()->getElementType();
+            type = type->expectRef<OptionalType>().getElementType();
           }
         }
         if (type->isSubtypeOf(TensorType::get())) {
@@ -67,7 +67,7 @@
         } else if (type->kind() == TypeKind::NumberType) {
           tracer::addInputs(node, args[i].name().c_str(), iter->toScalar());
         } else if (type->kind() == TypeKind::ListType) {
-          const auto& elem_type = type->expect<ListType>()->getElementType();
+          const auto& elem_type = type->expectRef<ListType>().getElementType();
           if (elem_type->isSubtypeOf(TensorType::get())) {
             AT_ASSERT(iter->isTensorList());
             auto list = iter->toTensorVector();
@@ -134,7 +134,7 @@
           AT_ASSERT(iter->isTensor());
           tracer::addOutput(node, iter->toTensor());
         } else if (type->kind() == TypeKind::ListType) {
-          const auto& elem_type = type->expect<ListType>()->getElementType();
+          const auto& elem_type = type->expectRef<ListType>().getElementType();
           if (elem_type->isSubtypeOf(TensorType::get())) {
             AT_ASSERT(iter->isTensorList());
             tracer::addOutput(node, iter->toTensorList());
diff --git a/torch/csrc/jit/serialization/python_print.cpp b/torch/csrc/jit/serialization/python_print.cpp
index 8b348de..6b61b7c 100644
--- a/torch/csrc/jit/serialization/python_print.cpp
+++ b/torch/csrc/jit/serialization/python_print.cpp
@@ -880,10 +880,10 @@
         return true;
       }
 
-      if (v.isTuple() && v.type()->expect<TupleType>()->schema()) {
+      if (v.isTuple() && v.type()->expectRef<TupleType>().schema()) {
         // print the namedtuple constructor and let rest of tuple printing
         // continue
-        ss << v.type()->expect<TupleType>()->annotation_str(type_printer_);
+        ss << v.type()->expectRef<TupleType>().annotation_str(type_printer_);
       }
       return false;
     };
@@ -981,7 +981,7 @@
       } break;
       case prim::TupleConstruct: {
         if (auto qualname =
-                node->output()->type()->expect<TupleType>()->name()) {
+                node->output()->type()->expectRef<TupleType>().name()) {
           stmt << node->output()->type()->annotation_str(type_printer_);
         }
         printValueList(