[PyTorch][codemod] Replace immediately-dereferenced expect calls w/ expectRef (#50228)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50228
`fastmod -m 'expect(<((at|c10)::)?\w+Type>\(\)\s*)->' 'expectRef${1}.'`
Presuming it builds, this is a safe change: the result of `expect()`
was never stored anywhere, so the new `shared_ptr` it returned was
discarded immediately; taking a reference via `expectRef()` gives the
same access without the extra refcount traffic.
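As a minimal, self-contained sketch of the pattern (the `Type` and
`TensorType` classes below are simplified stand-ins for illustration,
not the real c10 definitions):

```cpp
#include <cassert>
#include <memory>

// Simplified stand-ins for c10::Type / c10::TensorType, for
// illustration only; the real classes live in aten/src/ATen/core.
struct Type : std::enable_shared_from_this<Type> {
  virtual ~Type() = default;

  // expect<T>() hands back a new shared_ptr, so every call bumps the
  // refcount even when the caller dereferences the result immediately.
  template <typename T>
  std::shared_ptr<T> expect() {
    auto p = std::dynamic_pointer_cast<T>(shared_from_this());
    assert(p && "expected a different type");
    return p;
  }

  // expectRef<T>() hands back a plain reference: no temporary
  // shared_ptr, no refcount traffic.
  template <typename T>
  T& expectRef() {
    auto* p = dynamic_cast<T*>(this);
    assert(p && "expected a different type");
    return *p;
  }
};

struct TensorType : Type {
  bool isSummarized() const { return false; }
};

int main() {
  std::shared_ptr<Type> t = std::make_shared<TensorType>();
  // Before the codemod: a temporary shared_ptr is created,
  // dereferenced, and destroyed on the spot.
  bool before = t->expect<TensorType>()->isSummarized();
  // After the codemod: same result, no temporary shared_ptr.
  bool after = t->expectRef<TensorType>().isSummarized();
  return before == after ? 0 : 1;
}
```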
ghstack-source-id: 119782961
Test Plan: CI
Reviewed By: SplitInfinity
Differential Revision: D25837374
fbshipit-source-id: 86757b70b1520e3dbaa141001e7976400cdd3b08
diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp
index 1223577..6ff7a52 100644
--- a/aten/src/ATen/core/ivalue.cpp
+++ b/aten/src/ATen/core/ivalue.cpp
@@ -417,7 +417,7 @@
std::ostream& out,
const IValue& the_list,
IValueFormatter formatter) {
- auto list_elem_type = the_list.type()->expect<ListType>()->getElementType();
+ auto list_elem_type = the_list.type()->expectRef<ListType>().getElementType();
if (the_list.toListRef().size() == 0 ||
!elementTypeCanBeInferredFromMembers(list_elem_type)) {
out << "annotate(" << the_list.type()->annotation_str() << ", ";
diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp
index ec1ec97..878d08c 100644
--- a/aten/src/ATen/core/type.cpp
+++ b/aten/src/ATen/core/type.cpp
@@ -252,7 +252,7 @@
// Handle non-container types which do not subtype each other and unify
if (t1->kind() == TensorType::Kind && t2->kind() == TensorType::Kind) {
- return t1->expect<TensorType>()->merge(*t2->expect<TensorType>());
+ return t1->expectRef<TensorType>().merge(*t2->expect<TensorType>());
}
if (t1->isSubtypeOf(NoneType::get()) && !t2->isSubtypeOf(NoneType::get())) {
@@ -1317,7 +1317,7 @@
TORCH_CHECK(
(type->kind() == TensorType::Kind) ||
(type->kind() == OptionalType::Kind &&
- type->expect<OptionalType>()->getElementType()->kind() ==
+ type->expectRef<OptionalType>().getElementType()->kind() ==
TensorType::Kind) ||
(type->kind() == NoneType::Kind),
"Expecting parameter or buffer to have either None, Tensor or Optional[Tensor] type, but got: ",
diff --git a/test/cpp/jit/test_jit_type.cpp b/test/cpp/jit/test_jit_type.cpp
index 9462a57..8fd14d2 100644
--- a/test/cpp/jit/test_jit_type.cpp
+++ b/test/cpp/jit/test_jit_type.cpp
@@ -18,7 +18,7 @@
TORCH_INTERNAL_ASSERT(!tensor->isSubtypeOf(opt_bool_tensor));
auto unified = unifyTypes(opt_bool_tensor, tensor);
TORCH_INTERNAL_ASSERT(unified);
- auto elem = (*unified)->expect<OptionalType>()->getElementType();
+ auto elem = (*unified)->expectRef<OptionalType>().getElementType();
TORCH_INTERNAL_ASSERT(elem->isSubtypeOf(TensorType::get()));
auto opt_tuple_none_int = OptionalType::create(
diff --git a/test/cpp/jit/test_misc.cpp b/test/cpp/jit/test_misc.cpp
index eaec2fe..3256c89 100644
--- a/test/cpp/jit/test_misc.cpp
+++ b/test/cpp/jit/test_misc.cpp
@@ -504,18 +504,18 @@
ASSERT_TRUE(IntType::get()->isSubtypeOf(s.arguments()
.at(0)
.type()
- ->expect<ListType>()
- ->getElementType()
- ->expect<ListType>()
- ->getElementType()));
+ ->expectRef<ListType>()
+ .getElementType()
+ ->expectRef<ListType>()
+ .getElementType()));
auto s2 = parseSchema("at::what(int[][] foo) -> ()");
ASSERT_TRUE(IntType::get()->isSubtypeOf(s2.arguments()
.at(0)
.type()
- ->expect<ListType>()
- ->getElementType()
- ->expect<ListType>()
- ->getElementType()));
+ ->expectRef<ListType>()
+ .getElementType()
+ ->expectRef<ListType>()
+ .getElementType()));
}
TEST(SchemaParserTest, NamedReturns) {
@@ -531,7 +531,7 @@
// futures
auto s4 = parseSchema("at::what(Future(int) foo) -> ()");
ASSERT_TRUE(IntType::get()->isSubtypeOf(
- s4.arguments().at(0).type()->expect<FutureType>()->getElementType()));
+ s4.arguments().at(0).type()->expectRef<FutureType>().getElementType()));
}
TEST(SchemaParserTest, AnnotatedAliasSets) {
@@ -1751,7 +1751,7 @@
});
ASSERT_NE(guard, nodes.end());
ASSERT_EQ(
- guard->input()->type()->expect<TensorType>()->sizes().size(),
+ guard->input()->type()->expectRef<TensorType>().sizes().size(),
c10::nullopt);
checkShape(*guard, {2, 3}, false);
auto is_guard = [](Node* n) { return n->kind() == prim::Guard; };
diff --git a/torch/csrc/jit/codegen/cuda/partition.cpp b/torch/csrc/jit/codegen/cuda/partition.cpp
index 88aaa3e..63ec915 100644
--- a/torch/csrc/jit/codegen/cuda/partition.cpp
+++ b/torch/csrc/jit/codegen/cuda/partition.cpp
@@ -20,7 +20,7 @@
// not tensor type, return false as the op is not outputing scalar.
return c10::nullopt;
}
- return value->type()->expect<TensorType>()->device();
+ return value->type()->expectRef<TensorType>().device();
}
static c10::optional<c10::Device> getDevice(const Node* node) {
diff --git a/torch/csrc/jit/codegen/fuser/codegen.cpp b/torch/csrc/jit/codegen/fuser/codegen.cpp
index eadc52d..522d819 100644
--- a/torch/csrc/jit/codegen/fuser/codegen.cpp
+++ b/torch/csrc/jit/codegen/fuser/codegen.cpp
@@ -91,7 +91,7 @@
return "double";
} else if (t->kind() == TypeKind::BoolType) {
return "bool";
- } else if (auto scalar_type = t->expect<TensorType>()->scalarType()) {
+ } else if (auto scalar_type = t->expectRef<TensorType>().scalarType()) {
return calcScalarTypeName(*scalar_type);
}
// something went wrong with the type analysis during shape propagation
@@ -118,7 +118,7 @@
} else if (t->kind() == TypeKind::NoneType) {
// Support None value for optional arguments like memory format
return vn;
- } else if (auto scalar_type = t->expect<TensorType>()->scalarType()) {
+ } else if (auto scalar_type = t->expectRef<TensorType>().scalarType()) {
if (*scalar_type != outtype) {
return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")";
}
@@ -261,7 +261,7 @@
} else {
size_t i = 0;
- auto outtype = n->output()->type()->expect<TensorType>()->scalarType();
+ auto outtype = n->output()->type()->expectRef<TensorType>().scalarType();
TORCH_INTERNAL_ASSERT(outtype);
for (auto in : n->inputs()) {
diff --git a/torch/csrc/jit/codegen/fuser/compiler.cpp b/torch/csrc/jit/codegen/fuser/compiler.cpp
index e49a6a6..e3e2382 100644
--- a/torch/csrc/jit/codegen/fuser/compiler.cpp
+++ b/torch/csrc/jit/codegen/fuser/compiler.cpp
@@ -260,7 +260,7 @@
sizes.at(o->node()->i(attr::dim)) *= o->node()->inputs().size();
}
- auto scalar_type = o->type()->expect<TensorType>()->scalarType();
+ auto scalar_type = o->type()->expectRef<TensorType>().scalarType();
TORCH_INTERNAL_ASSERT(scalar_type);
auto type = TensorType::createContiguous(*scalar_type, device, sizes);
output_desc.emplace_back(type);
diff --git a/torch/csrc/jit/frontend/ir_emitter.cpp b/torch/csrc/jit/frontend/ir_emitter.cpp
index 0dd84e4..954648f 100644
--- a/torch/csrc/jit/frontend/ir_emitter.cpp
+++ b/torch/csrc/jit/frontend/ir_emitter.cpp
@@ -2976,7 +2976,7 @@
return std::make_shared<SimpleValue>(
graph
->insertNode(graph->createList(
- type->expect<ListType>()->getElementType(), {}))
+ type->expectRef<ListType>().getElementType(), {}))
->output());
}
// list(iter) desugars to [_elem for _elem in iter]
@@ -3376,7 +3376,7 @@
TypePtr elem_type = TensorType::get();
if (type_hint) {
if (type_hint->kind() == TypeKind::ListType) {
- elem_type = type_hint->expect<ListType>()->getElementType();
+ elem_type = type_hint->expectRef<ListType>().getElementType();
} else {
// If the type hint was not a List[T] throw an error
throw ErrorReport(tree)
diff --git a/torch/csrc/jit/frontend/schema_matching.cpp b/torch/csrc/jit/frontend/schema_matching.cpp
index 86c1c2f..977cc3d 100644
--- a/torch/csrc/jit/frontend/schema_matching.cpp
+++ b/torch/csrc/jit/frontend/schema_matching.cpp
@@ -72,7 +72,7 @@
if (convertibleToList(value->type(), unwrapOptional(concrete_type))) {
auto unpacked = createTupleUnpack(value);
auto elem_type =
- unwrapOptional(concrete_type)->expect<ListType>()->getElementType();
+ unwrapOptional(concrete_type)->expectRef<ListType>().getElementType();
value = graph.insertNode(graph.createList(elem_type, unpacked))->output();
}
@@ -340,8 +340,9 @@
// The actual cannot already be a list
if (actual_type->kind() != TypeKind::ListType &&
!convertibleToList(actual_type, unwrapOptional(arg.type()))) {
- auto formal_type =
- unwrapOptional(arg.type())->expect<ListType>()->getElementType();
+ auto formal_type = unwrapOptional(arg.type())
+ ->expectRef<ListType>()
+ .getElementType();
Value* list = tryCreateList(
formal_type,
diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp
index 1ca0f48..45d4ece 100644
--- a/torch/csrc/jit/ir/alias_analysis.cpp
+++ b/torch/csrc/jit/ir/alias_analysis.cpp
@@ -63,7 +63,7 @@
}
case TypeKind::TupleType: {
std::vector<TypePtr> mutable_types;
- for (const auto& elem : type->expect<TupleType>()->elements()) {
+ for (const auto& elem : type->expectRef<TupleType>().elements()) {
if (auto mut_elem = getMutableType(elem)) {
mutable_types.push_back(*mut_elem);
}
@@ -511,7 +511,7 @@
case prim::GetAttr:
if (isFrozen_ && node->kind() == prim::GetAttr) {
auto& ty = node->input()->type();
- if (ty->expect<ClassType>()->is_module()) {
+ if (ty->expectRef<ClassType>().is_module()) {
return analyzeCreator(node);
}
}
diff --git a/torch/csrc/jit/passes/clear_undefinedness.cpp b/torch/csrc/jit/passes/clear_undefinedness.cpp
index b235cbb..568441d 100644
--- a/torch/csrc/jit/passes/clear_undefinedness.cpp
+++ b/torch/csrc/jit/passes/clear_undefinedness.cpp
@@ -10,7 +10,7 @@
o->setType(TensorType::get());
} else if (
o->type()->kind() == ListType::Kind &&
- o->type()->expect<ListType>()->getElementType()->kind() ==
+ o->type()->expectRef<ListType>().getElementType()->kind() ==
TensorType::Kind) {
o->setType(ListType::create(TensorType::get()));
}
diff --git a/torch/csrc/jit/passes/decompose_ops.cpp b/torch/csrc/jit/passes/decompose_ops.cpp
index ad8dcf3..d7ca569 100644
--- a/torch/csrc/jit/passes/decompose_ops.cpp
+++ b/torch/csrc/jit/passes/decompose_ops.cpp
@@ -39,7 +39,7 @@
if (!input->type()->isSubtypeOf(TensorType::get())) {
return false;
}
- auto device = input->type()->expect<TensorType>()->device();
+ auto device = input->type()->expectRef<TensorType>().device();
// As of now, we do the decomposition for batchnorm/layernorm on GPU device
// only
if (!device || (*device).is_cpu()) {
diff --git a/torch/csrc/jit/passes/freeze_module.cpp b/torch/csrc/jit/passes/freeze_module.cpp
index 38684d2..18e90a2 100644
--- a/torch/csrc/jit/passes/freeze_module.cpp
+++ b/torch/csrc/jit/passes/freeze_module.cpp
@@ -156,7 +156,7 @@
Module& attrModule,
std::shared_ptr<Graph>& graph) {
if (!input->type()->cast<InterfaceType>() &&
- !input->type()->expect<ClassType>()->is_module()) {
+ !input->type()->expectRef<ClassType>().is_module()) {
return false;
}
@@ -425,7 +425,7 @@
if (!findConstantAttr(input, name, attrModule, graph)) {
GRAPH_DEBUG(
input->type()->cast<InterfaceType>() ||
- input->type()->expect<ClassType>()->is_module()
+ input->type()->expectRef<ClassType>().is_module()
? "attribute: " + name + " is mutable."
: "");
continue;
diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp
index f09f945..8243a2c 100644
--- a/torch/csrc/jit/passes/graph_fuser.cpp
+++ b/torch/csrc/jit/passes/graph_fuser.cpp
@@ -177,7 +177,7 @@
if (!v->type()->isSubtypeOf(TensorType::get())) {
return true;
}
- auto device = v->type()->expect<TensorType>()->device();
+ auto device = v->type()->expectRef<TensorType>().device();
if (!device) {
return !strict_fuser_check;
}
diff --git a/torch/csrc/jit/passes/graph_rewrite_helper.cpp b/torch/csrc/jit/passes/graph_rewrite_helper.cpp
index 8072174..34d7fd6 100644
--- a/torch/csrc/jit/passes/graph_rewrite_helper.cpp
+++ b/torch/csrc/jit/passes/graph_rewrite_helper.cpp
@@ -9,7 +9,7 @@
namespace graph_rewrite_helper {
std::string getFuncName(Value* func_value) {
- auto func = func_value->type()->expect<FunctionType>()->function();
+ auto func = func_value->type()->expectRef<FunctionType>().function();
const auto& qname = func->qualname();
const auto& name = qname.qualifiedName();
auto rdot_idx = name.rfind('.');
diff --git a/torch/csrc/jit/passes/guard_elimination.cpp b/torch/csrc/jit/passes/guard_elimination.cpp
index c26404d..4f3a96b 100644
--- a/torch/csrc/jit/passes/guard_elimination.cpp
+++ b/torch/csrc/jit/passes/guard_elimination.cpp
@@ -242,7 +242,7 @@
size_t i = 0;
for (auto input : n->inputs()) {
if ((input->node()->kind() == prim::Guard &&
- !input->type()->expect<TensorType>()->isSummarized()) ||
+ !input->type()->expectRef<TensorType>().isSummarized()) ||
input->node()->kind() == prim::Constant ||
(allow_numbers && input->type()->isSubtypeOf(NumberType::get())) ||
except.count(i) != 0) {
@@ -377,7 +377,7 @@
case aten::conv3d:
return checkInputs(n, std::unordered_set<size_t>{2, 6}, false);
case aten::slice:
- return !n->input(0)->type()->expect<TensorType>()->isSummarized() &&
+ return !n->input(0)->type()->expectRef<TensorType>().isSummarized() &&
// check that the dimension argument is constant
n->input(1)->node()->kind() == prim::Constant &&
// the start offset is constant
@@ -389,7 +389,7 @@
case aten::max_pool1d:
case aten::max_pool2d:
case aten::max_pool3d:
- return !n->input(0)->type()->expect<TensorType>()->isSummarized() &&
+ return !n->input(0)->type()->expectRef<TensorType>().isSummarized() &&
// check that the kernel size is constant
n->input(1)->node()->kind() == prim::Constant &&
// check that the stride is constant
@@ -402,7 +402,7 @@
n->input(5)->node()->kind() == prim::Constant;
case aten::unsqueeze:
// check that the dimension argument is constant
- return !n->input(0)->type()->expect<TensorType>()->isSummarized() &&
+ return !n->input(0)->type()->expectRef<TensorType>().isSummarized() &&
n->input(1)->node()->kind() == prim::Constant;
case aten::cat:
// check that the dimension argument is constant
@@ -427,8 +427,8 @@
// aten::size is effectively a constant
if (asize->input()
->type()
- ->expect<TensorType>()
- ->sizes()
+ ->expectRef<TensorType>()
+ .sizes()
.concrete_sizes()) {
return true;
}
diff --git a/torch/csrc/jit/passes/onnx/peephole.cpp b/torch/csrc/jit/passes/onnx/peephole.cpp
index 8e23024..d488201 100644
--- a/torch/csrc/jit/passes/onnx/peephole.cpp
+++ b/torch/csrc/jit/passes/onnx/peephole.cpp
@@ -138,14 +138,14 @@
// Not all broadcasts are supported by ONNX broadcast.
c10::optional<size_t> axis = fusibleExpandTo(
unexpanded_input->type()
- ->expect<TensorType>()
- ->sizes()
+ ->expectRef<TensorType>()
+ .sizes()
.concrete_sizes()
.value(), // from
n->output()
->type()
- ->expect<TensorType>()
- ->sizes()
+ ->expectRef<TensorType>()
+ .sizes()
.concrete_sizes()
.value()); // to
if (axis == c10::nullopt)
diff --git a/torch/csrc/jit/passes/peephole.cpp b/torch/csrc/jit/passes/peephole.cpp
index fda3230..f3786c5 100644
--- a/torch/csrc/jit/passes/peephole.cpp
+++ b/torch/csrc/jit/passes/peephole.cpp
@@ -425,11 +425,11 @@
// Attempts to find a matrix with a defined scalar type to type as
auto* type_as_mat = mat1;
- if (!type_as_mat->type()->expect<TensorType>()->scalarType()) {
+ if (!type_as_mat->type()->expectRef<TensorType>().scalarType()) {
type_as_mat = mat2;
}
auto mat_scalar_type =
- type_as_mat->type()->expect<TensorType>()->scalarType();
+ type_as_mat->type()->expectRef<TensorType>().scalarType();
// we can't use type_as if we don't know the target type (mm), the
// bias needs to be coerced to
diff --git a/torch/csrc/jit/passes/quantization/helper.cpp b/torch/csrc/jit/passes/quantization/helper.cpp
index 1549b4f..0692d4c 100644
--- a/torch/csrc/jit/passes/quantization/helper.cpp
+++ b/torch/csrc/jit/passes/quantization/helper.cpp
@@ -532,7 +532,7 @@
std::shared_ptr<Graph> getCallFunctionGraph(Node* n) {
auto* func_node = n->input(0)->node();
- auto func = func_node->output()->type()->expect<FunctionType>()->function();
+ auto func = func_node->output()->type()->expectRef<FunctionType>().function();
TORCH_CHECK(
func->isGraphFunction(), "Quantization only works for graph function");
return func->graph();
diff --git a/torch/csrc/jit/passes/update_differentiable_graph_requires_grad.cpp b/torch/csrc/jit/passes/update_differentiable_graph_requires_grad.cpp
index 5013dcf..c9aca34 100644
--- a/torch/csrc/jit/passes/update_differentiable_graph_requires_grad.cpp
+++ b/torch/csrc/jit/passes/update_differentiable_graph_requires_grad.cpp
@@ -20,8 +20,8 @@
n->ty_(
attr::profiled_type,
n->ty(attr::profiled_type)
- ->expect<TensorType>()
- ->withRequiresGrad(new_requires_grad));
+ ->expectRef<TensorType>()
+ .withRequiresGrad(new_requires_grad));
}
for (Block* b : n->blocks()) {
UpdateDifferentiableGraphRequiresGrad(b, new_requires_grad);
diff --git a/torch/csrc/jit/python/pybind_utils.cpp b/torch/csrc/jit/python/pybind_utils.cpp
index 19de83d..da16a67 100644
--- a/torch/csrc/jit/python/pybind_utils.cpp
+++ b/torch/csrc/jit/python/pybind_utils.cpp
@@ -78,7 +78,7 @@
return static_cast<int64_t>(stream->cdata);
}
case TypeKind::ListType: {
- const auto& elem_type = type->expect<ListType>()->getElementType();
+ const auto& elem_type = type->expectRef<ListType>().getElementType();
switch (elem_type->kind()) {
// allows single int/float to be broadcasted to a fixed size list
case TypeKind::IntType:
@@ -127,7 +127,7 @@
// return an IValue() to denote a NoneType
return {};
}
- return toIValue(obj, type->expect<OptionalType>()->getElementType());
+ return toIValue(obj, type->expectRef<OptionalType>().getElementType());
}
case TypeKind::ClassType: {
auto classType = type->expect<ClassType>();
diff --git a/torch/csrc/jit/python/python_custom_class.cpp b/torch/csrc/jit/python/python_custom_class.cpp
index a4120d4..521b93f 100644
--- a/torch/csrc/jit/python/python_custom_class.cpp
+++ b/torch/csrc/jit/python/python_custom_class.cpp
@@ -31,7 +31,7 @@
py::class_<ScriptClass>(m, "ScriptClass")
.def("__call__", &ScriptClass::__call__)
.def_property_readonly("__doc__", [](const ScriptClass& self) {
- return self.class_type_.type_->expect<ClassType>()->doc_string();
+ return self.class_type_.type_->expectRef<ClassType>().doc_string();
});
// This function returns a ScriptClass that wraps the constructor
diff --git a/torch/csrc/jit/python/python_ir.cpp b/torch/csrc/jit/python/python_ir.cpp
index 5542f53..ab8545d 100644
--- a/torch/csrc/jit/python/python_ir.cpp
+++ b/torch/csrc/jit/python/python_ir.cpp
@@ -434,7 +434,7 @@
.VS(requires_grad)
.def(
"requiresGrad",
- [](Value& n) { n.type()->expect<TensorType>()->requiresGrad(); })
+ [](Value& n) { n.type()->expectRef<TensorType>().requiresGrad(); })
.def("toIValue", [](Value& n) { return toIValue(&n); })
.def("type", [](Value& v) { return v.type(); });
#undef VS
@@ -686,7 +686,7 @@
.def(
"dim",
[](Type& t) {
- auto vshape = t.shared_from_this()->expect<TensorType>()->sizes();
+ auto vshape = t.shared_from_this()->expectRef<TensorType>().sizes();
return vshape.size() ? py::cast(*vshape.size())
: py::cast<py::none>(Py_None);
})
@@ -694,7 +694,7 @@
"undefined",
[](Type& t) {
auto undef =
- t.shared_from_this()->expect<TensorType>()->undefined();
+ t.shared_from_this()->expectRef<TensorType>().undefined();
return undef.has_value() ? py::cast(*undef)
: py::cast<py::none>(Py_None);
})
@@ -732,13 +732,13 @@
"contiguous",
[](Type& t) {
return std::static_pointer_cast<Type>(
- t.expect<TensorType>()->contiguous());
+ t.expectRef<TensorType>().contiguous());
})
.def(
"scalarType",
[](Type& t) {
auto scalar_type =
- t.shared_from_this()->expect<TensorType>()->scalarType();
+ t.shared_from_this()->expectRef<TensorType>().scalarType();
return (scalar_type) ? toString(*scalar_type) : nullptr;
})
.def(
diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp
index f6f6fbf..92a0611 100644
--- a/torch/csrc/jit/python/python_sugared_value.cpp
+++ b/torch/csrc/jit/python/python_sugared_value.cpp
@@ -631,7 +631,7 @@
// Check if it's a property.
auto prop =
- concreteType_->getJitType()->expect<ClassType>()->getProperty(field);
+ concreteType_->getJitType()->expectRef<ClassType>().getProperty(field);
if (prop) {
return MethodValue(self_, prop->getter->name())
.call(loc, m, {}, {}, /*n_binders=*/1);
@@ -647,7 +647,8 @@
throw ErrorReport(loc)
<< "Module '"
- << concreteType_->getJitType()->expect<ClassType>()->name()->name() << "'"
+ << concreteType_->getJitType()->expectRef<ClassType>().name()->name()
+ << "'"
<< " has no attribute '" << field << "' " << hint;
}
diff --git a/torch/csrc/jit/runtime/interpreter.cpp b/torch/csrc/jit/runtime/interpreter.cpp
index 2d18d6f..5e6c9a9 100644
--- a/torch/csrc/jit/runtime/interpreter.cpp
+++ b/torch/csrc/jit/runtime/interpreter.cpp
@@ -847,7 +847,7 @@
void emitTupleConstruct(Node* node) {
bool named =
- node->output()->type()->expect<TupleType>()->name().has_value();
+ node->output()->type()->expectRef<TupleType>().name().has_value();
if (named) {
emitContainerConstruct(NAMED_TUPLE_CONSTRUCT, node);
} else {
@@ -938,7 +938,7 @@
break;
case prim::CallFunction:
emitCall(
- node->inputs().at(0)->type()->expect<FunctionType>()->function(),
+ node->inputs().at(0)->type()->expectRef<FunctionType>().function(),
node->inputs().slice(1));
break;
case prim::CallMethod:
diff --git a/torch/csrc/jit/runtime/register_c10_ops.cpp b/torch/csrc/jit/runtime/register_c10_ops.cpp
index 4e1a4fb..e31c13a 100644
--- a/torch/csrc/jit/runtime/register_c10_ops.cpp
+++ b/torch/csrc/jit/runtime/register_c10_ops.cpp
@@ -46,7 +46,7 @@
node->addInput(none);
continue;
} else {
- type = type->expect<OptionalType>()->getElementType();
+ type = type->expectRef<OptionalType>().getElementType();
}
}
if (type->isSubtypeOf(TensorType::get())) {
@@ -67,7 +67,7 @@
} else if (type->kind() == TypeKind::NumberType) {
tracer::addInputs(node, args[i].name().c_str(), iter->toScalar());
} else if (type->kind() == TypeKind::ListType) {
- const auto& elem_type = type->expect<ListType>()->getElementType();
+ const auto& elem_type = type->expectRef<ListType>().getElementType();
if (elem_type->isSubtypeOf(TensorType::get())) {
AT_ASSERT(iter->isTensorList());
auto list = iter->toTensorVector();
@@ -134,7 +134,7 @@
AT_ASSERT(iter->isTensor());
tracer::addOutput(node, iter->toTensor());
} else if (type->kind() == TypeKind::ListType) {
- const auto& elem_type = type->expect<ListType>()->getElementType();
+ const auto& elem_type = type->expectRef<ListType>().getElementType();
if (elem_type->isSubtypeOf(TensorType::get())) {
AT_ASSERT(iter->isTensorList());
tracer::addOutput(node, iter->toTensorList());
diff --git a/torch/csrc/jit/serialization/python_print.cpp b/torch/csrc/jit/serialization/python_print.cpp
index 8b348de..6b61b7c 100644
--- a/torch/csrc/jit/serialization/python_print.cpp
+++ b/torch/csrc/jit/serialization/python_print.cpp
@@ -880,10 +880,10 @@
return true;
}
- if (v.isTuple() && v.type()->expect<TupleType>()->schema()) {
+ if (v.isTuple() && v.type()->expectRef<TupleType>().schema()) {
// print the namedtuple constructor and let rest of tuple printing
// continue
- ss << v.type()->expect<TupleType>()->annotation_str(type_printer_);
+ ss << v.type()->expectRef<TupleType>().annotation_str(type_printer_);
}
return false;
};
@@ -981,7 +981,7 @@
} break;
case prim::TupleConstruct: {
if (auto qualname =
- node->output()->type()->expect<TupleType>()->name()) {
+ node->output()->type()->expectRef<TupleType>().name()) {
stmt << node->output()->type()->annotation_str(type_printer_);
}
printValueList(