Add IMethod interface
Summary:
Expose IMethod interface, which provides a unified interface to either script or python methods backed by torchscript or torchdeploy.
IMethod provides a way to depend on a torch method without depending on a particular runtime implementation such as torchscript or python/deploy.
Test Plan: add unit tests.
Reviewed By: suo
Differential Revision: D29463455
fbshipit-source-id: 903391d9af9fbdd8fcdb096c1a136ec6ac153b7c
diff --git a/test/cpp/api/imethod.cpp b/test/cpp/api/imethod.cpp
new file mode 100644
index 0000000..6ed54c9
--- /dev/null
+++ b/test/cpp/api/imethod.cpp
@@ -0,0 +1,44 @@
+// (c) Facebook, Inc. and its affiliates. Confidential and proprietary.
+
+#include <gtest/gtest.h>
+#include <torch/csrc/deploy/deploy.h>
+#include <torch/script.h>
+#include <torch/torch.h>
+
+using namespace ::testing;
+using namespace caffe2;
+
+TEST(IMethodTest, CallMethod) {
+ auto script_model = torch::jit::load(getenv("SIMPLE_JIT"));
+ auto script_method = script_model.get_method("forward");
+
+ torch::deploy::InterpreterManager manager(3);
+ torch::deploy::Package p = manager.load_package(getenv("SIMPLE"));
+ auto py_model = p.load_pickle("model", "model.pkl");
+ torch::deploy::PythonMethodWrapper py_method(py_model, "forward");
+
+ auto input = torch::ones({10, 20});
+ auto output_py = py_method({input});
+ auto output_script = script_method({input});
+ EXPECT_TRUE(output_py.isTensor());
+ EXPECT_TRUE(output_script.isTensor());
+ auto output_py_tensor = output_py.toTensor();
+ auto output_script_tensor = output_script.toTensor();
+
+ EXPECT_TRUE(output_py_tensor.equal(output_script_tensor));
+ EXPECT_EQ(output_py_tensor.numel(), 200);
+}
+
+TEST(IMethodTest, GetArgumentNames) {
+ auto script_model = torch::jit::load(getenv("SIMPLE_JIT"));
+ auto script_method = script_model.get_method("forward");
+
+ torch::deploy::InterpreterManager manager(3);
+ torch::deploy::Package p = manager.load_package(getenv("SIMPLE"));
+ auto py_model = p.load_pickle("model", "model.pkl");
+ torch::deploy::PythonMethodWrapper py_method(py_model, "forward");
+
+ // TODO(whc) implement and test these
+ EXPECT_THROW(script_method.getArgumentNames(), std::runtime_error);
+ EXPECT_THROW(py_method.getArgumentNames(), std::runtime_error);
+}
diff --git a/torch/csrc/api/include/torch/imethod.h b/torch/csrc/api/include/torch/imethod.h
new file mode 100644
index 0000000..68581a3
--- /dev/null
+++ b/torch/csrc/api/include/torch/imethod.h
@@ -0,0 +1,37 @@
+#pragma once
+#include <ATen/core/ivalue.h>
+
+namespace torch {
+
+class IMethod {
+ /*
+ IMethod provides a portable interface for torch methods, whether
+ they are backed by torchscript or python/deploy.
+
+ This is helpful since torchscript methods provide additional information
+ (e.g. FunctionSchema, Graph) which aren't available in pure python methods.
+
+ Higher level APIs should prefer depending on this interface rather
+ than a specific implementation of it, to promote portability and reuse, and
+ avoid unintentional dependencies on e.g. script methods.
+
+ Note: This API is experimental, and may evolve.
+ */
+ public:
+ using IValueList = std::vector<c10::IValue>;
+ using IValueMap = std::unordered_map<std::string, at::IValue>;
+
+ virtual ~IMethod() = default;
+
+ virtual c10::IValue operator()(
+ std::vector<c10::IValue> args,
+ const IValueMap& kwargs = IValueMap()) = 0;
+
+ // Returns an ordered list of argument names, possible in both
+ // script and python methods. This is a more portable dependency
+ // than a ScriptMethod FunctionSchema, which has more information
+ // than can be generally expected from a python method.
+ virtual std::vector<std::string> getArgumentNames() = 0;
+};
+
+} // namespace torch
diff --git a/torch/csrc/deploy/deploy.h b/torch/csrc/deploy/deploy.h
index 640f3da..48bff73 100644
--- a/torch/csrc/deploy/deploy.h
+++ b/torch/csrc/deploy/deploy.h
@@ -2,6 +2,7 @@
// NOLINTNEXTLINE(modernize-deprecated-headers)
#include <assert.h>
#include <c10/util/irange.h>
+#include <torch/csrc/api/include/torch/imethod.h>
#include <torch/csrc/deploy/interpreter/interpreter_impl.h>
#include <torch/csrc/jit/serialization/import.h>
#include <fstream>
@@ -231,6 +232,35 @@
friend struct InterpreterManager;
};
+class PythonMethodWrapper : public torch::IMethod {
+ // PythonMethodWrapper is a more specific instance of a
+ // ReplicatedObj which represents a python method, and
+ // is therefore callable and has argument names accessible.
+ public:
+ PythonMethodWrapper(
+ torch::deploy::ReplicatedObj& model,
+ std::string method_name)
+ : model_(std::move(model)), method_name_(std::move(method_name)) {}
+
+ c10::IValue operator()(
+ std::vector<c10::IValue> args,
+ const IValueMap& kwargs = IValueMap()) override {
+ // TODO(whc) ideally, pickle the method itself as replicatedobj, to skip
+ // this lookup each time
+ auto model_session = model_.acquire_session();
+ auto method = model_session.self.attr(method_name_.c_str());
+ return method.call_kwargs(args, kwargs).toIValue();
+ }
+
+ std::vector<std::string> getArgumentNames() override {
+ throw std::runtime_error("getArgumentNames not yet implemented");
+ }
+
+ private:
+ torch::deploy::ReplicatedObj model_;
+ std::string method_name_;
+};
+
struct TORCH_API Package {
// shorthand for getting the object as a pickle resource in the package
ReplicatedObj load_pickle(
diff --git a/torch/csrc/jit/api/method.h b/torch/csrc/jit/api/method.h
index 692c08f..fe12635 100644
--- a/torch/csrc/jit/api/method.h
+++ b/torch/csrc/jit/api/method.h
@@ -3,6 +3,7 @@
#include <ATen/core/function.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/stack.h>
+#include <torch/csrc/api/include/torch/imethod.h>
#include <torch/csrc/jit/api/function_impl.h>
namespace torch {
@@ -18,7 +19,7 @@
// ...
// Note: because Method/Module are exposed to python these
// classes use python method naming conventions
-struct TORCH_API Method {
+struct TORCH_API Method : public torch::IMethod {
Method(ObjectPtr owner, Function* function);
// the module that contains this method.
@@ -30,7 +31,7 @@
c10::IValue operator()(
std::vector<c10::IValue> stack,
- const Kwargs& kwargs = Kwargs());
+ const Kwargs& kwargs = Kwargs()) override;
// Run method async. Invocation on this function would invokes a JIT
// interpreter that executes ops inline, one by one, on caller's thread. A
@@ -57,6 +58,10 @@
return function_->get_executor();
}
+ std::vector<std::string> getArgumentNames() override {
+ throw std::runtime_error("getArgumentNames not yet implemented");
+ }
+
Function& function() const {
return *function_;
}