Document Any (#11580)
Summary:
Documents the `AnyModule` class in the C++ API.
Also changed the API to be friendlier by default. Calling `AnyModule::forward` used to return an `AnyModule::Value` which you had to call `.get<T>()` on to cast to a concrete type. I changed the name of that `forward` method to `any_forward` and instead made `forward` templated on a `ReturnType` template parameter which you can supply to do the `.get<T>` cast for you automatically. I default this parameter to `torch::Tensor` so that it can often be omitted. So where you used to have to write
```cpp
any_module.forward(...).get<int>();
any_module.forward(...).get<torch::Tensor>();
```
you now write
```cpp
any_module.forward<int>(...);
any_module.forward(...);
```
ebetica ezyang soumith
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11580
Differential Revision: D9798626
Pulled By: goldsborough
fbshipit-source-id: 060b4ea28facaffc417f53b80b846a9dff9acb73
diff --git a/test/cpp/api/any.cpp b/test/cpp/api/any.cpp
index ab044b8..9368d4d 100644
--- a/test/cpp/api/any.cpp
+++ b/test/cpp/api/any.cpp
@@ -22,8 +22,9 @@
}
};
AnyModule any(M{});
- REQUIRE(any.forward().get<int>() == 123);
+ REQUIRE(any.forward<int>() == 123);
}
+
SECTION("int(int)") {
struct M : torch::nn::Module {
int forward(int x) {
@@ -31,8 +32,9 @@
}
};
AnyModule any(M{});
- REQUIRE(any.forward(5).get<int>() == 5);
+ REQUIRE(any.forward<int>(5) == 5);
}
+
SECTION("const char*(const char*)") {
struct M : torch::nn::Module {
const char* forward(const char* x) {
@@ -40,7 +42,7 @@
}
};
AnyModule any(M{});
- REQUIRE(any.forward("hello").get<const char*>() == std::string("hello"));
+ REQUIRE(any.forward<const char*>("hello") == std::string("hello"));
}
SECTION("string(int, const double)") {
@@ -51,7 +53,7 @@
};
AnyModule any(M{});
int x = 4;
- REQUIRE(any.forward(x, 3.14).get<std::string>() == std::string("7"));
+ REQUIRE(any.forward<std::string>(x, 3.14) == std::string("7"));
}
SECTION("Tensor(string, const string&, string&&)") {
@@ -66,8 +68,8 @@
};
AnyModule any(M{});
REQUIRE(
- any.forward(std::string("a"), std::string("ab"), std::string("abc"))
- .get<torch::Tensor>()
+ any.forward(
+ std::string("a"), std::string("ab"), std::string("abc"))
.sum()
.toCInt() == 6);
}
@@ -181,7 +183,7 @@
any.forward<int>(5),
StartsWith("Cannot call forward() on an empty AnyModule"));
}
- SECTION("can move assign differentm modules") {
+ SECTION("can move assign different modules") {
struct M : torch::nn::Module {
std::string forward(int x) {
return std::to_string(x);
@@ -196,10 +198,10 @@
REQUIRE(any.is_empty());
any = std::make_shared<M>();
REQUIRE(!any.is_empty());
- REQUIRE(any.forward(5).get<std::string>() == "5");
+ REQUIRE(any.forward<std::string>(5) == "5");
any = std::make_shared<N>();
REQUIRE(!any.is_empty());
- REQUIRE(any.forward(5.0f).get<int>() == 8);
+ REQUIRE(any.forward<int>(5.0f) == 8);
}
SECTION("constructs from ModuleHolder") {
struct MImpl : torch::nn::Module {
@@ -218,6 +220,10 @@
AnyModule any(M{5});
REQUIRE(any.get<MImpl>().value == 5);
REQUIRE(any.get<M>()->value == 5);
+
+ AnyModule module(Linear(3, 4));
+ std::shared_ptr<Module> ptr = module.ptr();
+ Linear linear(module.get<Linear>());
}
SECTION("converts autograd::Variable to torch::Tensor correctly") {
struct M : torch::nn::Module {
@@ -232,12 +238,10 @@
AnyModule any(M{});
REQUIRE(
any.forward(torch::autograd::Variable(torch::ones(5)))
- .get<torch::Tensor>()
.sum()
.toCFloat() == 5);
// at::Tensors that are not variables work too.
- REQUIRE(
- any.forward(at::ones(5)).get<torch::Tensor>().sum().toCFloat() == 5);
+ REQUIRE(any.forward(at::ones(5)).sum().toCFloat() == 5);
}
}
}
diff --git a/torch/csrc/api/include/torch/nn/modules/any.h b/torch/csrc/api/include/torch/nn/modules/any.h
index 920aea6..6261494 100644
--- a/torch/csrc/api/include/torch/nn/modules/any.h
+++ b/torch/csrc/api/include/torch/nn/modules/any.h
@@ -21,10 +21,87 @@
namespace torch {
namespace nn {
-/// A class to store a type erased module, whose `forward()` method can be
-/// invoked, with dynamic type checking. An `AnyModule` has an empty state, into
-/// which it is default constructed. `is_empty()` can be used to query whether
-/// the `AnyModule` is empty.
+/// Stores a type erased `Module`.
+///
+/// The PyTorch C++ API does not impose an interface on the signature of
+/// `forward()` in `Module` subclasses. This gives you complete freedom to
+/// design your `forward()` methods to your liking. However, this also means
+/// there is no unified base type you could store in order to call `forward()`
+/// polymorphically for any module. This is where the `AnyModule` comes in.
+/// Instead of inheritance, it relies on type erasure for polymorphism.
+///
+/// An `AnyModule` can store any `nn::Module` subclass that provides a
+/// `forward()` method. This `forward()` may accept any types and return any
+/// type. Once stored in an `AnyModule`, you can invoke the underlying module's
+/// `forward()` by calling `AnyModule::forward()` with the arguments you would
+/// supply to the stored module (though see one important limitation below).
+/// Example:
+///
+/// \rst
+/// .. code-block::
+/// struct GenericTrainer {
+/// torch::nn::AnyModule module;
+///
+/// void train(torch::Tensor input) {
+/// module.forward(input);
+/// }
+/// };
+///
+/// GenericTrainer trainer1{torch::nn::Linear(3, 4)};
+/// GenericTrainer trainer2{torch::nn::Conv2d(3, 4, 2)};
+/// \endrst
+///
+/// As `AnyModule` erases the static type of the stored module (and its
+/// `forward()` method) to achieve polymorphism, type checking of arguments is
+/// moved to runtime. That is, passing an argument with an incorrect type to an
+/// `AnyModule` will compile, but throw an exception at runtime:
+///
+/// \rst
+/// .. code-block::
+/// torch::nn::AnyModule module(torch::nn::Linear(3, 4));
+/// // Linear takes a tensor as input, but we are passing an integer.
+/// // This will compile, but throw a `torch::Error` exception at runtime.
+/// module.forward(123);
+/// \endrst
+///
+/// \rst
+/// .. attention::
+/// One noteworthy limitation of `AnyModule` is that its `forward()` method
+/// does not support implicit conversion of argument types. For example, if
+/// the stored module's `forward()` method accepts a `float` and you call
+/// `any_module.forward(3.4)` (where `3.4` is a `double`), this will throw
+/// an exception.
+/// \endrst
+///
+/// The return type of the `AnyModule`'s `forward()` method is controlled via
+/// the first template argument to `AnyModule::forward()`. It defaults to
+/// `torch::Tensor`. To change it, you can write `any_module.forward<int>()`,
+/// for example.
+///
+/// \rst
+/// .. code-block::
+/// torch::nn::AnyModule module(torch::nn::Linear(3, 4));
+/// auto output = module.forward(torch::ones({2, 3}));
+///
+/// struct IntModule : torch::nn::Module {
+/// int forward(int x) { return x; }
+/// };
+/// torch::nn::AnyModule int_module(IntModule{});
+/// int int_output = int_module.forward<int>(5);
+/// \endrst
+///
+/// The only other method an `AnyModule` provides access to on the stored
+/// module is `clone()`. However, you may acquire a handle on the module via
+/// `.ptr()`, which returns a `shared_ptr<nn::Module>`. Further, if you know
+/// the concrete type of the stored module, you can get a concrete handle to it
+/// using `.get<T>()` where `T` is the concrete module type.
+///
+/// \rst
+/// .. code-block::
+/// torch::nn::AnyModule module(torch::nn::Linear(3, 4));
+/// std::shared_ptr<torch::nn::Module> ptr = module.ptr();
+/// torch::nn::Linear linear(module.get<torch::nn::Linear>());
+/// \endrst
class AnyModule {
public:
/// A type-erased value.
@@ -69,7 +146,13 @@
/// returns the return value as an `Value`. Use this method when chaining
/// `AnyModule`s in a loop.
template <typename... ArgumentTypes>
- Value forward(ArgumentTypes&&... arguments);
+ Value any_forward(ArgumentTypes&&... arguments);
+
+ /// Invokes `forward()` on the contained module with the given arguments, and
+ /// casts the returned `Value` to the supplied `ReturnType` (which defaults to
+ /// `torch::Tensor`).
+ template <typename ReturnType = torch::Tensor, typename... ArgumentTypes>
+ ReturnType forward(ArgumentTypes&&... arguments);
/// Attempts to cast the underlying module to the given module type. Throws an
/// exception if the types do not match.
@@ -358,7 +441,7 @@
}
template <typename... ArgumentTypes>
-AnyModule::Value AnyModule::forward(ArgumentTypes&&... arguments) {
+AnyModule::Value AnyModule::any_forward(ArgumentTypes&&... arguments) {
AT_CHECK(!is_empty(), "Cannot call forward() on an empty AnyModule");
std::vector<Value> values;
values.reserve(sizeof...(ArgumentTypes));
@@ -368,6 +451,12 @@
return content_->forward(std::move(values));
}
+template <typename ReturnType, typename... ArgumentTypes>
+ReturnType AnyModule::forward(ArgumentTypes&&... arguments) {
+ return any_forward(std::forward<ArgumentTypes>(arguments)...)
+ .template get<ReturnType>();
+}
+
template <typename T, typename>
T& AnyModule::get() {
AT_CHECK(!is_empty(), "Cannot call get() on an empty AnyModule");
@@ -393,9 +482,9 @@
template <typename T, typename>
std::shared_ptr<T> AnyModule::ptr() const {
AT_CHECK(!is_empty(), "Cannot call ptr() on an empty AnyModule");
- /// Call get() but discard the value, just to do the type checking.
+ // Call get() but discard the value, just to do the type checking.
get_<T>();
- return std::static_pointer_cast<T>(ptr());
+ return std::dynamic_pointer_cast<T>(ptr());
}
inline const std::type_info& AnyModule::type_info() const {
diff --git a/torch/csrc/api/include/torch/nn/modules/sequential.h b/torch/csrc/api/include/torch/nn/modules/sequential.h
index 9f3f7a0..384afd1 100644
--- a/torch/csrc/api/include/torch/nn/modules/sequential.h
+++ b/torch/csrc/api/include/torch/nn/modules/sequential.h
@@ -1,6 +1,7 @@
#pragma once
#include <torch/detail/static.h>
+#include <torch/nn/cloneable.h>
#include <torch/nn/module.h>
#include <torch/nn/modules/any.h>
#include <torch/nn/pimpl.h>
@@ -57,10 +58,11 @@
AT_CHECK(!is_empty(), "Cannot call forward() on an empty Sequential");
auto iterator = modules_.begin();
- auto input = iterator->forward(std::forward<ArgumentTypes>(arguments)...);
+ auto input =
+ iterator->any_forward(std::forward<ArgumentTypes>(arguments)...);
for (++iterator; iterator != modules_.end(); ++iterator) {
- input = iterator->forward(std::move(input));
+ input = iterator->any_forward(std::move(input));
}
// Check the return value and give a nice error message if the requsted