Support named tuple return from operators on JIT (#16253)

Summary:
Fixes: https://github.com/pytorch/pytorch/issues/16233

The following changes are made:
- Modify `TupleType` to store optional field names
- Modify schema matching to return fill in those field names when creating  `TupleType` as return type.
- Modify codegen of JIT to copy field names to schema string
- Modify `SchemaParser` to set field names of returned schema.
- Modify `SimpleValue::attr` to emit tuple indexing for named tuple.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16253

Reviewed By: ezyang

Differential Revision: D13954298

Pulled By: zdevito

fbshipit-source-id: 247d483d78a0c9c12d1ba36e1f1ec6c3f1a3007b
diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h
index abaa41d..f23bae4 100644
--- a/aten/src/ATen/core/jit_type.h
+++ b/aten/src/ATen/core/jit_type.h
@@ -8,6 +8,8 @@
 #include <c10/util/TypeList.h>
 #include <caffe2/core/common.h>
 
+#include <c10/util/Optional.h>
+
 #include <memory>
 #include <iostream>
 #include <type_traits>
@@ -629,10 +631,11 @@
 
 struct TupleType;
 using TupleTypePtr = std::shared_ptr<TupleType>;
+using OptNameList = c10::optional<std::vector<std::string>>;
 // This type represents a Tuple
 struct CAFFE2_API TupleType : public Type {
-  static TupleTypePtr create(std::vector<TypePtr> types) {
-    return TupleTypePtr(new TupleType( std::move(types) )); // NOLINT(modernize-make-shared)
+  static TupleTypePtr create(std::vector<TypePtr> types, OptNameList names=c10::nullopt) {
+    return TupleTypePtr(new TupleType(std::move(types), std::move(names))); // NOLINT(modernize-make-shared)
   }
   DEFINE_IS_SUBCLASS(TupleType);
   at::ArrayRef<TypePtr> elements() const {
@@ -641,13 +644,25 @@
   bool operator==(const Type& rhs) const override {
     return compare(rhs, [](const TypePtr a, const TypePtr b) {
       return *a == *b;
-    });
+    }) && names_ == rhs.expect<TupleType>()->names_;
+    // `compare` guarantees that rhs is always a TupleType, so the
+    // dynamic_cast above always success.
   }
-  bool isSubtypeOf(const TypePtr rhs) const override {
+  bool isSubtypeOf(const TypePtr rhs_) const override {
+    if (Type::isSubtypeOf(rhs_))
+      return true;
+    auto rhs = rhs_->cast<TupleType>();
+    if (!rhs)
+      return false;
+    // unnamed tuple is not a subtype of nametuple
+    if (!hasNames() && rhs->hasNames())
+      return false;
+    // namedtuple may be a subtype of unnamed tuple
+    bool names_match = !rhs->hasNames() || names() == rhs->names();
     // co-variant rules for tuples
-    return compare(*rhs, [](const TypePtr a, const TypePtr b) {
+    return names_match && compare(*rhs, [](const TypePtr a, const TypePtr b) {
       return a->isSubtypeOf(b);
-    }) || Type::isSubtypeOf(rhs);
+    });
   }
   bool requires_grad() const override {
     return std::any_of(elements_.begin(), elements_.end(),
@@ -678,6 +693,12 @@
   bool hasFreeVariables() const override {
     return has_free_variables_;
   }
+  bool hasNames() const {
+    return names_.has_value();
+  }
+  const std::vector<std::string> &names() const {
+    return names_.value();
+  }
 
   at::ArrayRef<TypePtr> containedTypes() const override {
     return elements_;
@@ -688,9 +709,10 @@
 
   static const TypeKind Kind = TypeKind::TupleType;
 private:
-  TupleType(std::vector<TypePtr> elements_)
+  TupleType(std::vector<TypePtr> elements_, OptNameList names)
   : Type(TypeKind::TupleType)
-  , elements_(std::move(elements_)) {
+  , elements_(std::move(elements_))
+  , names_(std::move(names)) {
     has_free_variables_ =
         std::any_of(elements_.begin(), elements_.end(), [](TypePtr v) {
           return v->hasFreeVariables();
@@ -710,8 +732,10 @@
     }
     return true;
   }
+
   std::vector<TypePtr> elements_;
   bool has_free_variables_;
+  OptNameList names_;
 };
 
 struct NumberType;
diff --git a/test/test_jit.py b/test/test_jit.py
index 2caf025..051b4ea 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -8377,6 +8377,23 @@
         self.checkScriptRaisesRegex(test_indexing_out_of_bounds_pos, (), Exception,
                                     "out of range")
 
+    def test_namedtuple_attr(self):
+        def f(x):
+            return x.max(dim=1).indices + torch.max(x, dim=1).indices
+
+        self.checkScript(f, (torch.rand(20, 20, 20),), optimize=True)
+
+        with self.assertRaisesRegex(RuntimeError, "Unknown attribute to named tuple"):
+            @torch.jit.script
+            def g1(x):
+                return x.max(dim=1).unknown_symbol
+
+        with self.assertRaisesRegex(RuntimeError, "Getting attributes of tuples is not supported"):
+            @torch.jit.script
+            def g2(x):
+                print((x, x, x).__doc__)
+                return x
+
     def test_tuple_slicing(self):
         def tuple_slice(a):
             if bool(a):
diff --git a/tools/jit/gen_jit_dispatch.py b/tools/jit/gen_jit_dispatch.py
index 97b2ccc..cfc979b 100644
--- a/tools/jit/gen_jit_dispatch.py
+++ b/tools/jit/gen_jit_dispatch.py
@@ -482,7 +482,9 @@
     if len(decl['returns']) == 1:
         ret_list = jit_type_of(decl['returns'][0])
     else:
-        ret_list = '({})'.format(', '.join(jit_type_of(r) for r in decl['returns']))
+        def type_maybe_field(r):
+            return '{} {}'.format(jit_type_of(r), r['field_name']) if 'field_name' in r else jit_type_of(r)
+        ret_list = '({})'.format(', '.join(type_maybe_field(r) for r in decl['returns']))
     name = decl['name'] if not is_out_variant(decl) else decl['name'][:-4]
     constructed_string = 'aten::{}({}) -> {}'.format(name, arg_list, ret_list)
     return match_signature(decl, constructed_string, should_match_schema)
diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp
index f189a77..40d16fd 100644
--- a/torch/csrc/jit/ir.cpp
+++ b/torch/csrc/jit/ir.cpp
@@ -1195,9 +1195,9 @@
   return n;
 }
 
-Node* Graph::createTuple(at::ArrayRef<Value*> values) {
+Node* Graph::createTuple(at::ArrayRef<Value*> values, c10::OptNameList field_names) {
   auto types = fmap(values, [](Value* v) { return v->type(); });
-  auto tt = TupleType::create(std::move(types));
+  auto tt = TupleType::create(std::move(types), std::move(field_names));
   auto n = create(prim::TupleConstruct, values);
   n->output()->setType(tt);
   return n;
diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h
index 29d70e5..430d385 100644
--- a/torch/csrc/jit/ir.h
+++ b/torch/csrc/jit/ir.h
@@ -1042,7 +1042,7 @@
   TORCH_API Node* createUndefined();
   TORCH_API Node* createFusionGroup();
   TORCH_API Node* createDifferentiableSubgraph();
-  TORCH_API Node* createTuple(at::ArrayRef<Value*> values);
+  TORCH_API Node* createTuple(at::ArrayRef<Value*> values, c10::OptNameList field_names=c10::nullopt);
   TORCH_API Node* createTupleUnpack(Value* v);
   TORCH_API Node* createTupleIndex(Value* tup, int64_t index);
   TORCH_API Node* createTupleSlice(Value* tup, int64_t beg, int64_t end);
diff --git a/torch/csrc/jit/operator.cpp b/torch/csrc/jit/operator.cpp
index 8fa2841..c3fc314 100644
--- a/torch/csrc/jit/operator.cpp
+++ b/torch/csrc/jit/operator.cpp
@@ -214,11 +214,11 @@
       alias_info = std::move(container);
     }
     if (is_return) {
-      // optionally named return values
+      // optionally field names in return values
       if (L.cur().kind == TK_IDENT) {
         name = L.next().text();
       } else {
-        name = "ret" + std::to_string(idx);
+        name = "";
       }
     } else {
       name = L.expect(TK_IDENT).text();
diff --git a/torch/csrc/jit/script/schema_matching.cpp b/torch/csrc/jit/script/schema_matching.cpp
index 97404bc..6eefb74 100644
--- a/torch/csrc/jit/script/schema_matching.cpp
+++ b/torch/csrc/jit/script/schema_matching.cpp
@@ -299,19 +299,31 @@
       return c10::nullopt;
     }
   }
-  auto return_types = fmap(schema.returns(), [&](const Argument& r) {
+
+  const auto &returns = schema.returns();
+  auto return_types = fmap(returns, [&](const Argument& r) {
     return evalTypeVariables(r.type(), type_env);
   });
-  return MatchedSchema{std::move(positional_inputs), std::move(return_types)};
+  // Codegen does not support return of namedtuples with undefined field names.
+  // Therefore, either all or none returns has field names.
+  bool return_has_field_names = std::all_of(returns.begin(), returns.end(),
+    [&](const Argument& r) { return r.name().length() > 0; });
+  c10::OptNameList return_field_names = c10::nullopt;
+  if (return_has_field_names) {
+    return_field_names = fmap(returns, [&](const Argument& r) {
+      return r.name();
+    });
+  }
+  return MatchedSchema{std::move(positional_inputs), std::move(return_types), std::move(return_field_names)};
 }
 
 // pack outputs of a function following python rules. If there is a single value
 // return a SimpleValue, otherwise pack all the values into a Tuple.
-Value* packOutputs(Graph& g, at::ArrayRef<Value*> values) {
+Value* packOutputs(Graph& g, at::ArrayRef<Value*> values, c10::OptNameList field_names) {
   if (values.size() == 1) {
     return values[0];
   }
-  return g.insertNode(g.createTuple(values))->output();
+  return g.insertNode(g.createTuple(values, std::move(field_names)))->output();
 }
 
 // Given a successful match between operator schema and symbol, emit a node
@@ -332,7 +344,7 @@
   // otherwise schema and dispatch are not in sync
   getOperation(n);
 
-  return packOutputs(graph, n->outputs());
+  return packOutputs(graph, n->outputs(), matched_schema.return_field_names);
 }
 
 static std::string prefixLine(
diff --git a/torch/csrc/jit/script/schema_matching.h b/torch/csrc/jit/script/schema_matching.h
index d93c832..4d11f6e 100644
--- a/torch/csrc/jit/script/schema_matching.h
+++ b/torch/csrc/jit/script/schema_matching.h
@@ -18,6 +18,7 @@
 struct MatchedSchema {
   std::vector<Value*> inputs;
   std::vector<TypePtr> return_types;
+  c10::OptNameList return_field_names;
 };
 
 TORCH_API c10::optional<MatchedSchema> tryMatchSchema(
diff --git a/torch/csrc/jit/script/sugared_value.cpp b/torch/csrc/jit/script/sugared_value.cpp
index fe021e9..b103459 100644
--- a/torch/csrc/jit/script/sugared_value.cpp
+++ b/torch/csrc/jit/script/sugared_value.cpp
@@ -87,6 +87,20 @@
   if (getValue()->type()->isSubtypeOf(NumberType::get())) {
     throw ErrorReport(loc) << "Cannot call methods on numbers";
   }
+  if (getValue()->type()->kind() == TypeKind::TupleType) {
+    auto tuple_type = getValue()->type()->expect<TupleType>();
+    if (!tuple_type->hasNames()) {
+      throw ErrorReport(loc) << "Getting attributes of tuples is not supported";
+    }
+    auto names = tuple_type->names();
+    for (int i = 0; i < names.size(); i++) {
+      if (names[i] == field) {
+        auto r = m.graph()->insertNode(m.graph()->createTupleIndex(getValue(), i))->output();
+        return std::make_shared<SimpleValue>(r);
+      }
+    }
+    throw ErrorReport(loc) << "Unknown attribute to named tuple";
+  }
   return std::make_shared<BuiltinFunction>(
       Symbol::aten(field), NamedValue(loc, "self", value));
 }