Support named tuple return from operators on JIT (#16253)
Summary:
Fixes: https://github.com/pytorch/pytorch/issues/16233
The following changes are made:
- Modify `TupleType` to store optional field names
- Modify schema matching to return fill in those field names when creating `TupleType` as return type.
- Modify codegen of JIT to copy field names to schema string
- Modify `SchemaParser` to set field names of returned schema.
- Modify `SimpleValue::attr` to emit tuple indexing for named tuple.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16253
Reviewed By: ezyang
Differential Revision: D13954298
Pulled By: zdevito
fbshipit-source-id: 247d483d78a0c9c12d1ba36e1f1ec6c3f1a3007b
diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h
index abaa41d..f23bae4 100644
--- a/aten/src/ATen/core/jit_type.h
+++ b/aten/src/ATen/core/jit_type.h
@@ -8,6 +8,8 @@
#include <c10/util/TypeList.h>
#include <caffe2/core/common.h>
+#include <c10/util/Optional.h>
+
#include <memory>
#include <iostream>
#include <type_traits>
@@ -629,10 +631,11 @@
struct TupleType;
using TupleTypePtr = std::shared_ptr<TupleType>;
+using OptNameList = c10::optional<std::vector<std::string>>;
// This type represents a Tuple
struct CAFFE2_API TupleType : public Type {
- static TupleTypePtr create(std::vector<TypePtr> types) {
- return TupleTypePtr(new TupleType( std::move(types) )); // NOLINT(modernize-make-shared)
+ static TupleTypePtr create(std::vector<TypePtr> types, OptNameList names=c10::nullopt) {
+ return TupleTypePtr(new TupleType(std::move(types), std::move(names))); // NOLINT(modernize-make-shared)
}
DEFINE_IS_SUBCLASS(TupleType);
at::ArrayRef<TypePtr> elements() const {
@@ -641,13 +644,25 @@
bool operator==(const Type& rhs) const override {
return compare(rhs, [](const TypePtr a, const TypePtr b) {
return *a == *b;
- });
+ }) && names_ == rhs.expect<TupleType>()->names_;
+ // `compare` guarantees that rhs is always a TupleType, so the
+ // dynamic_cast above always success.
}
- bool isSubtypeOf(const TypePtr rhs) const override {
+ bool isSubtypeOf(const TypePtr rhs_) const override {
+ if (Type::isSubtypeOf(rhs_))
+ return true;
+ auto rhs = rhs_->cast<TupleType>();
+ if (!rhs)
+ return false;
+ // unnamed tuple is not a subtype of nametuple
+ if (!hasNames() && rhs->hasNames())
+ return false;
+ // namedtuple may be a subtype of unnamed tuple
+ bool names_match = !rhs->hasNames() || names() == rhs->names();
// co-variant rules for tuples
- return compare(*rhs, [](const TypePtr a, const TypePtr b) {
+ return names_match && compare(*rhs, [](const TypePtr a, const TypePtr b) {
return a->isSubtypeOf(b);
- }) || Type::isSubtypeOf(rhs);
+ });
}
bool requires_grad() const override {
return std::any_of(elements_.begin(), elements_.end(),
@@ -678,6 +693,12 @@
bool hasFreeVariables() const override {
return has_free_variables_;
}
+ bool hasNames() const {
+ return names_.has_value();
+ }
+ const std::vector<std::string> &names() const {
+ return names_.value();
+ }
at::ArrayRef<TypePtr> containedTypes() const override {
return elements_;
@@ -688,9 +709,10 @@
static const TypeKind Kind = TypeKind::TupleType;
private:
- TupleType(std::vector<TypePtr> elements_)
+ TupleType(std::vector<TypePtr> elements_, OptNameList names)
: Type(TypeKind::TupleType)
- , elements_(std::move(elements_)) {
+ , elements_(std::move(elements_))
+ , names_(std::move(names)) {
has_free_variables_ =
std::any_of(elements_.begin(), elements_.end(), [](TypePtr v) {
return v->hasFreeVariables();
@@ -710,8 +732,10 @@
}
return true;
}
+
std::vector<TypePtr> elements_;
bool has_free_variables_;
+ OptNameList names_;
};
struct NumberType;
diff --git a/test/test_jit.py b/test/test_jit.py
index 2caf025..051b4ea 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -8377,6 +8377,23 @@
self.checkScriptRaisesRegex(test_indexing_out_of_bounds_pos, (), Exception,
"out of range")
+ def test_namedtuple_attr(self):
+ def f(x):
+ return x.max(dim=1).indices + torch.max(x, dim=1).indices
+
+ self.checkScript(f, (torch.rand(20, 20, 20),), optimize=True)
+
+ with self.assertRaisesRegex(RuntimeError, "Unknown attribute to named tuple"):
+ @torch.jit.script
+ def g1(x):
+ return x.max(dim=1).unknown_symbol
+
+ with self.assertRaisesRegex(RuntimeError, "Getting attributes of tuples is not supported"):
+ @torch.jit.script
+ def g2(x):
+ print((x, x, x).__doc__)
+ return x
+
def test_tuple_slicing(self):
def tuple_slice(a):
if bool(a):
diff --git a/tools/jit/gen_jit_dispatch.py b/tools/jit/gen_jit_dispatch.py
index 97b2ccc..cfc979b 100644
--- a/tools/jit/gen_jit_dispatch.py
+++ b/tools/jit/gen_jit_dispatch.py
@@ -482,7 +482,9 @@
if len(decl['returns']) == 1:
ret_list = jit_type_of(decl['returns'][0])
else:
- ret_list = '({})'.format(', '.join(jit_type_of(r) for r in decl['returns']))
+ def type_maybe_field(r):
+ return '{} {}'.format(jit_type_of(r), r['field_name']) if 'field_name' in r else jit_type_of(r)
+ ret_list = '({})'.format(', '.join(type_maybe_field(r) for r in decl['returns']))
name = decl['name'] if not is_out_variant(decl) else decl['name'][:-4]
constructed_string = 'aten::{}({}) -> {}'.format(name, arg_list, ret_list)
return match_signature(decl, constructed_string, should_match_schema)
diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp
index f189a77..40d16fd 100644
--- a/torch/csrc/jit/ir.cpp
+++ b/torch/csrc/jit/ir.cpp
@@ -1195,9 +1195,9 @@
return n;
}
-Node* Graph::createTuple(at::ArrayRef<Value*> values) {
+Node* Graph::createTuple(at::ArrayRef<Value*> values, c10::OptNameList field_names) {
auto types = fmap(values, [](Value* v) { return v->type(); });
- auto tt = TupleType::create(std::move(types));
+ auto tt = TupleType::create(std::move(types), std::move(field_names));
auto n = create(prim::TupleConstruct, values);
n->output()->setType(tt);
return n;
diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h
index 29d70e5..430d385 100644
--- a/torch/csrc/jit/ir.h
+++ b/torch/csrc/jit/ir.h
@@ -1042,7 +1042,7 @@
TORCH_API Node* createUndefined();
TORCH_API Node* createFusionGroup();
TORCH_API Node* createDifferentiableSubgraph();
- TORCH_API Node* createTuple(at::ArrayRef<Value*> values);
+ TORCH_API Node* createTuple(at::ArrayRef<Value*> values, c10::OptNameList field_names=c10::nullopt);
TORCH_API Node* createTupleUnpack(Value* v);
TORCH_API Node* createTupleIndex(Value* tup, int64_t index);
TORCH_API Node* createTupleSlice(Value* tup, int64_t beg, int64_t end);
diff --git a/torch/csrc/jit/operator.cpp b/torch/csrc/jit/operator.cpp
index 8fa2841..c3fc314 100644
--- a/torch/csrc/jit/operator.cpp
+++ b/torch/csrc/jit/operator.cpp
@@ -214,11 +214,11 @@
alias_info = std::move(container);
}
if (is_return) {
- // optionally named return values
+ // optionally field names in return values
if (L.cur().kind == TK_IDENT) {
name = L.next().text();
} else {
- name = "ret" + std::to_string(idx);
+ name = "";
}
} else {
name = L.expect(TK_IDENT).text();
diff --git a/torch/csrc/jit/script/schema_matching.cpp b/torch/csrc/jit/script/schema_matching.cpp
index 97404bc..6eefb74 100644
--- a/torch/csrc/jit/script/schema_matching.cpp
+++ b/torch/csrc/jit/script/schema_matching.cpp
@@ -299,19 +299,31 @@
return c10::nullopt;
}
}
- auto return_types = fmap(schema.returns(), [&](const Argument& r) {
+
+ const auto &returns = schema.returns();
+ auto return_types = fmap(returns, [&](const Argument& r) {
return evalTypeVariables(r.type(), type_env);
});
- return MatchedSchema{std::move(positional_inputs), std::move(return_types)};
+ // Codegen does not support return of namedtuples with undefined field names.
+ // Therefore, either all or none returns has field names.
+ bool return_has_field_names = std::all_of(returns.begin(), returns.end(),
+ [&](const Argument& r) { return r.name().length() > 0; });
+ c10::OptNameList return_field_names = c10::nullopt;
+ if (return_has_field_names) {
+ return_field_names = fmap(returns, [&](const Argument& r) {
+ return r.name();
+ });
+ }
+ return MatchedSchema{std::move(positional_inputs), std::move(return_types), std::move(return_field_names)};
}
// pack outputs of a function following python rules. If there is a single value
// return a SimpleValue, otherwise pack all the values into a Tuple.
-Value* packOutputs(Graph& g, at::ArrayRef<Value*> values) {
+Value* packOutputs(Graph& g, at::ArrayRef<Value*> values, c10::OptNameList field_names) {
if (values.size() == 1) {
return values[0];
}
- return g.insertNode(g.createTuple(values))->output();
+ return g.insertNode(g.createTuple(values, std::move(field_names)))->output();
}
// Given a successful match between operator schema and symbol, emit a node
@@ -332,7 +344,7 @@
// otherwise schema and dispatch are not in sync
getOperation(n);
- return packOutputs(graph, n->outputs());
+ return packOutputs(graph, n->outputs(), matched_schema.return_field_names);
}
static std::string prefixLine(
diff --git a/torch/csrc/jit/script/schema_matching.h b/torch/csrc/jit/script/schema_matching.h
index d93c832..4d11f6e 100644
--- a/torch/csrc/jit/script/schema_matching.h
+++ b/torch/csrc/jit/script/schema_matching.h
@@ -18,6 +18,7 @@
struct MatchedSchema {
std::vector<Value*> inputs;
std::vector<TypePtr> return_types;
+ c10::OptNameList return_field_names;
};
TORCH_API c10::optional<MatchedSchema> tryMatchSchema(
diff --git a/torch/csrc/jit/script/sugared_value.cpp b/torch/csrc/jit/script/sugared_value.cpp
index fe021e9..b103459 100644
--- a/torch/csrc/jit/script/sugared_value.cpp
+++ b/torch/csrc/jit/script/sugared_value.cpp
@@ -87,6 +87,20 @@
if (getValue()->type()->isSubtypeOf(NumberType::get())) {
throw ErrorReport(loc) << "Cannot call methods on numbers";
}
+ if (getValue()->type()->kind() == TypeKind::TupleType) {
+ auto tuple_type = getValue()->type()->expect<TupleType>();
+ if (!tuple_type->hasNames()) {
+ throw ErrorReport(loc) << "Getting attributes of tuples is not supported";
+ }
+ auto names = tuple_type->names();
+ for (int i = 0; i < names.size(); i++) {
+ if (names[i] == field) {
+ auto r = m.graph()->insertNode(m.graph()->createTupleIndex(getValue(), i))->output();
+ return std::make_shared<SimpleValue>(r);
+ }
+ }
+ throw ErrorReport(loc) << "Unknown attribute to named tuple";
+ }
return std::make_shared<BuiltinFunction>(
Symbol::aten(field), NamedValue(loc, "self", value));
}