Document Any (#11580)

Summary:
Documents the `AnyModule` class in the C++ API.

I also changed the API to be friendlier by default. Calling `AnyModule::forward` used to return an `AnyModule::Value`, which you then had to cast to a concrete type with `.get<T>()`. I renamed that method to `any_forward` and made `forward` templated on a `ReturnType` parameter, which you can supply to have the `.get<T>()` cast done for you automatically. `ReturnType` defaults to `torch::Tensor`, so it can often be omitted. Where you used to have to write

```cpp
any_module.forward(...).get<int>();
any_module.forward(...).get<torch::Tensor>();
```

you now write

```cpp
any_module.forward<int>(...);
any_module.forward(...);
```
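
The type-erased path is still available when the return type is not known statically, for example when chaining modules: the old method survives under the new name `any_forward`, which returns an `AnyModule::Value` that you cast yourself. A minimal sketch of that usage (the `Linear(3, 4)` module and the input shape are purely illustrative):

```cpp
torch::nn::AnyModule any_module(torch::nn::Linear(3, 4));

// any_forward() returns the type-erased AnyModule::Value, as forward() used to.
torch::nn::AnyModule::Value value =
    any_module.any_forward(torch::ones({2, 3}));

// The .get<T>() cast is now explicit only when you keep the Value around.
torch::Tensor output = value.get<torch::Tensor>();
```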

ebetica ezyang soumith
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11580

Differential Revision: D9798626

Pulled By: goldsborough

fbshipit-source-id: 060b4ea28facaffc417f53b80b846a9dff9acb73
diff --git a/test/cpp/api/any.cpp b/test/cpp/api/any.cpp
index ab044b8..9368d4d 100644
--- a/test/cpp/api/any.cpp
+++ b/test/cpp/api/any.cpp
@@ -22,8 +22,9 @@
       }
     };
     AnyModule any(M{});
-    REQUIRE(any.forward().get<int>() == 123);
+    REQUIRE(any.forward<int>() == 123);
   }
+
   SECTION("int(int)") {
     struct M : torch::nn::Module {
       int forward(int x) {
@@ -31,8 +32,9 @@
       }
     };
     AnyModule any(M{});
-    REQUIRE(any.forward(5).get<int>() == 5);
+    REQUIRE(any.forward<int>(5) == 5);
   }
+
   SECTION("const char*(const char*)") {
     struct M : torch::nn::Module {
       const char* forward(const char* x) {
@@ -40,7 +42,7 @@
       }
     };
     AnyModule any(M{});
-    REQUIRE(any.forward("hello").get<const char*>() == std::string("hello"));
+    REQUIRE(any.forward<const char*>("hello") == std::string("hello"));
   }
 
   SECTION("string(int, const double)") {
@@ -51,7 +53,7 @@
     };
     AnyModule any(M{});
     int x = 4;
-    REQUIRE(any.forward(x, 3.14).get<std::string>() == std::string("7"));
+    REQUIRE(any.forward<std::string>(x, 3.14) == std::string("7"));
   }
 
   SECTION("Tensor(string, const string&, string&&)") {
@@ -66,8 +68,8 @@
     };
     AnyModule any(M{});
     REQUIRE(
-        any.forward(std::string("a"), std::string("ab"), std::string("abc"))
-            .get<torch::Tensor>()
+        any.forward(
+               std::string("a"), std::string("ab"), std::string("abc"))
             .sum()
             .toCInt() == 6);
   }
@@ -181,7 +183,7 @@
         any.forward<int>(5),
         StartsWith("Cannot call forward() on an empty AnyModule"));
   }
-  SECTION("can move assign differentm modules") {
+  SECTION("can move assign different modules") {
     struct M : torch::nn::Module {
       std::string forward(int x) {
         return std::to_string(x);
@@ -196,10 +198,10 @@
     REQUIRE(any.is_empty());
     any = std::make_shared<M>();
     REQUIRE(!any.is_empty());
-    REQUIRE(any.forward(5).get<std::string>() == "5");
+    REQUIRE(any.forward<std::string>(5) == "5");
     any = std::make_shared<N>();
     REQUIRE(!any.is_empty());
-    REQUIRE(any.forward(5.0f).get<int>() == 8);
+    REQUIRE(any.forward<int>(5.0f) == 8);
   }
   SECTION("constructs from ModuleHolder") {
     struct MImpl : torch::nn::Module {
@@ -218,6 +220,10 @@
     AnyModule any(M{5});
     REQUIRE(any.get<MImpl>().value == 5);
     REQUIRE(any.get<M>()->value == 5);
+
+    AnyModule module(Linear(3, 4));
+    std::shared_ptr<Module> ptr = module.ptr();
+    Linear linear(module.get<Linear>());
   }
   SECTION("converts autograd::Variable to torch::Tensor correctly") {
     struct M : torch::nn::Module {
@@ -232,12 +238,10 @@
       AnyModule any(M{});
       REQUIRE(
           any.forward(torch::autograd::Variable(torch::ones(5)))
-              .get<torch::Tensor>()
               .sum()
               .toCFloat() == 5);
       // at::Tensors that are not variables work too.
-      REQUIRE(
-          any.forward(at::ones(5)).get<torch::Tensor>().sum().toCFloat() == 5);
+      REQUIRE(any.forward(at::ones(5)).sum().toCFloat() == 5);
     }
   }
 }
diff --git a/torch/csrc/api/include/torch/nn/modules/any.h b/torch/csrc/api/include/torch/nn/modules/any.h
index 920aea6..6261494 100644
--- a/torch/csrc/api/include/torch/nn/modules/any.h
+++ b/torch/csrc/api/include/torch/nn/modules/any.h
@@ -21,10 +21,87 @@
 namespace torch {
 namespace nn {
 
-/// A class to store a type erased module, whose `forward()` method can be
-/// invoked, with dynamic type checking. An `AnyModule` has an empty state, into
-/// which it is default constructed. `is_empty()` can be used to query whether
-/// the `AnyModule` is empty.
+/// Stores a type-erased `Module`.
+///
+/// The PyTorch C++ API does not impose an interface on the signature of
+/// `forward()` in `Module` subclasses. This gives you complete freedom to
+/// design your `forward()` methods to your liking. However, this also means
+/// there is no unified base type you could store in order to call `forward()`
+/// polymorphically for any module. This is where the `AnyModule` comes in.
+/// Instead of inheritance, it relies on type erasure for polymorphism.
+///
+/// An `AnyModule` can store any `nn::Module` subclass that provides a
+/// `forward()` method. This `forward()` may accept any types and return any
+/// type. Once stored in an `AnyModule`, you can invoke the underlying module's
+/// `forward()` by calling `AnyModule::forward()` with the arguments you would
+/// supply to the stored module (though see one important limitation below).
+/// Example:
+///
+/// \rst
+/// .. code-block:: cpp
+///   struct GenericTrainer {
+///     torch::nn::AnyModule module;
+///
+///     void train(torch::Tensor input) {
+///       module.forward(input);
+///     }
+///   };
+///
+///   GenericTrainer trainer1{torch::nn::Linear(3, 4)};
+///   GenericTrainer trainer2{torch::nn::Conv2d(3, 4, 2)};
+/// \endrst
+///
+/// As `AnyModule` erases the static type of the stored module (and its
+/// `forward()` method) to achieve polymorphism, type checking of arguments is
+/// moved to runtime. That is, passing an argument with an incorrect type to an
+/// `AnyModule` will compile, but throw an exception at runtime:
+///
+/// \rst
+/// .. code-block:: cpp
+///   torch::nn::AnyModule module(torch::nn::Linear(3, 4));
+///   // Linear takes a tensor as input, but we are passing an integer.
+///   // This will compile, but throw a `torch::Error` exception at runtime.
+///   module.forward(123);
+/// \endrst
+///
+/// \rst
+/// .. attention::
+///   One noteworthy limitation of `AnyModule` is that its `forward()` method
+///   does not support implicit conversion of argument types. For example, if
+///   the stored module's `forward()` method accepts a `float` and you call
+///   `any_module.forward(3.4)` (where `3.4` is a `double`), this will throw
+///   an exception.
+/// \endrst
+///
+/// The return type of the `AnyModule`'s `forward()` method is controlled via
+/// the first template argument to `AnyModule::forward()`. It defaults to
+/// `torch::Tensor`. To change it, you can write `any_module.forward<int>()`,
+/// for example.
+///
+/// \rst
+/// .. code-block:: cpp
+///   torch::nn::AnyModule module(torch::nn::Linear(3, 4));
+///   auto output = module.forward(torch::ones({2, 3}));
+///
+///   struct IntModule {
+///     int forward(int x) { return x; }
+///   };
+///   torch::nn::AnyModule int_module(IntModule{});
+///   int output = int_module.forward<int>(5);
+/// \endrst
+///
+/// The only other method an `AnyModule` provides access to on the stored
+/// module is `clone()`. However, you may acquire a handle on the module via
+/// `.ptr()`, which returns a `shared_ptr<nn::Module>`. Further, if you know
+/// the concrete type of the stored module, you can get a concrete handle to it
+/// using `.get<T>()` where `T` is the concrete module type.
+///
+/// \rst
+/// .. code-block:: cpp
+///   torch::nn::AnyModule module(torch::nn::Linear(3, 4));
+///   std::shared_ptr<nn::Module> ptr = module.ptr();
+///   torch::nn::Linear linear(module.get<torch::nn::Linear>());
+/// \endrst
 class AnyModule {
  public:
   /// A type-erased value.
@@ -69,7 +146,13 @@
  /// returns the return value as a `Value`. Use this method when chaining
   /// `AnyModule`s in a loop.
   template <typename... ArgumentTypes>
-  Value forward(ArgumentTypes&&... arguments);
+  Value any_forward(ArgumentTypes&&... arguments);
+
+  /// Invokes `forward()` on the contained module with the given arguments, and
+  /// casts the returned `Value` to the supplied `ReturnType` (which defaults to
+  /// `torch::Tensor`).
+  template <typename ReturnType = torch::Tensor, typename... ArgumentTypes>
+  ReturnType forward(ArgumentTypes&&... arguments);
 
   /// Attempts to cast the underlying module to the given module type. Throws an
   /// exception if the types do not match.
@@ -358,7 +441,7 @@
 }
 
 template <typename... ArgumentTypes>
-AnyModule::Value AnyModule::forward(ArgumentTypes&&... arguments) {
+AnyModule::Value AnyModule::any_forward(ArgumentTypes&&... arguments) {
   AT_CHECK(!is_empty(), "Cannot call forward() on an empty AnyModule");
   std::vector<Value> values;
   values.reserve(sizeof...(ArgumentTypes));
@@ -368,6 +451,12 @@
   return content_->forward(std::move(values));
 }
 
+template <typename ReturnType, typename... ArgumentTypes>
+ReturnType AnyModule::forward(ArgumentTypes&&... arguments) {
+  return any_forward(std::forward<ArgumentTypes>(arguments)...)
+      .template get<ReturnType>();
+}
+
 template <typename T, typename>
 T& AnyModule::get() {
   AT_CHECK(!is_empty(), "Cannot call get() on an empty AnyModule");
@@ -393,9 +482,9 @@
 template <typename T, typename>
 std::shared_ptr<T> AnyModule::ptr() const {
   AT_CHECK(!is_empty(), "Cannot call ptr() on an empty AnyModule");
-  /// Call get() but discard the value, just to do the type checking.
+  // Call get() but discard the value, just to do the type checking.
   get_<T>();
-  return std::static_pointer_cast<T>(ptr());
+  return std::dynamic_pointer_cast<T>(ptr());
 }
 
 inline const std::type_info& AnyModule::type_info() const {
diff --git a/torch/csrc/api/include/torch/nn/modules/sequential.h b/torch/csrc/api/include/torch/nn/modules/sequential.h
index 9f3f7a0..384afd1 100644
--- a/torch/csrc/api/include/torch/nn/modules/sequential.h
+++ b/torch/csrc/api/include/torch/nn/modules/sequential.h
@@ -1,6 +1,7 @@
 #pragma once
 
 #include <torch/detail/static.h>
+#include <torch/nn/cloneable.h>
 #include <torch/nn/module.h>
 #include <torch/nn/modules/any.h>
 #include <torch/nn/pimpl.h>
@@ -57,10 +58,11 @@
     AT_CHECK(!is_empty(), "Cannot call forward() on an empty Sequential");
 
     auto iterator = modules_.begin();
-    auto input = iterator->forward(std::forward<ArgumentTypes>(arguments)...);
+    auto input =
+        iterator->any_forward(std::forward<ArgumentTypes>(arguments)...);
 
     for (++iterator; iterator != modules_.end(); ++iterator) {
-      input = iterator->forward(std::move(input));
+      input = iterator->any_forward(std::move(input));
     }
 
    // Check the return value and give a nice error message if the requested