Add IMethod interface

Summary:
Expose IMethod interface, which provides a unified interface to either script or python methods backed by torchscript or torchdeploy.

IMethod provides a way to depend on a torch method without depending on a particular runtime implementation such as torchscript or python/deploy.

Test Plan: add unit tests.

Reviewed By: suo

Differential Revision: D29463455

fbshipit-source-id: 903391d9af9fbdd8fcdb096c1a136ec6ac153b7c
diff --git a/test/cpp/api/imethod.cpp b/test/cpp/api/imethod.cpp
new file mode 100644
index 0000000..6ed54c9
--- /dev/null
+++ b/test/cpp/api/imethod.cpp
@@ -0,0 +1,44 @@
+// (c) Facebook, Inc. and its affiliates. Confidential and proprietary.
+
+#include <gtest/gtest.h>
+#include <torch/csrc/deploy/deploy.h>
+#include <torch/script.h>
+#include <torch/torch.h>
+
+using namespace ::testing;
+using namespace caffe2;
+
+TEST(IMethodTest, CallMethod) {
+  auto script_model = torch::jit::load(getenv("SIMPLE_JIT"));
+  auto script_method = script_model.get_method("forward");
+
+  torch::deploy::InterpreterManager manager(3);
+  torch::deploy::Package p = manager.load_package(getenv("SIMPLE"));
+  auto py_model = p.load_pickle("model", "model.pkl");
+  torch::deploy::PythonMethodWrapper py_method(py_model, "forward");
+
+  auto input = torch::ones({10, 20});
+  auto output_py = py_method({input});
+  auto output_script = script_method({input});
+  EXPECT_TRUE(output_py.isTensor());
+  EXPECT_TRUE(output_script.isTensor());
+  auto output_py_tensor = output_py.toTensor();
+  auto output_script_tensor = output_script.toTensor();
+
+  EXPECT_TRUE(output_py_tensor.equal(output_script_tensor));
+  EXPECT_EQ(output_py_tensor.numel(), 200);
+}
+
+TEST(IMethodTest, GetArgumentNames) {
+  auto script_model = torch::jit::load(getenv("SIMPLE_JIT"));
+  auto script_method = script_model.get_method("forward");
+
+  torch::deploy::InterpreterManager manager(3);
+  torch::deploy::Package p = manager.load_package(getenv("SIMPLE"));
+  auto py_model = p.load_pickle("model", "model.pkl");
+  torch::deploy::PythonMethodWrapper py_method(py_model, "forward");
+
+  // TODO(whc) implement and test these
+  EXPECT_THROW(script_method.getArgumentNames(), std::runtime_error);
+  EXPECT_THROW(py_method.getArgumentNames(), std::runtime_error);
+}
diff --git a/torch/csrc/api/include/torch/imethod.h b/torch/csrc/api/include/torch/imethod.h
new file mode 100644
index 0000000..68581a3
--- /dev/null
+++ b/torch/csrc/api/include/torch/imethod.h
@@ -0,0 +1,37 @@
+#pragma once
+#include <ATen/core/ivalue.h>
+
+namespace torch {
+
+class IMethod {
+  /*
+  IMethod provides a portable interface for torch methods, whether
+  they are backed by torchscript or python/deploy.
+
+  This is helpful since torchscript methods provide additional information
+  (e.g. FunctionSchema, Graph) which aren't available in pure python methods.
+
+  Higher level APIs should prefer depending on this interface rather
+  than a specific implementation of it, to promote portability and reuse, and
+  avoid unintentional dependencies on e.g. script methods.
+
+  Note: This API is experimental, and may evolve.
+  */
+ public:
+  using IValueList = std::vector<c10::IValue>;
+  using IValueMap = std::unordered_map<std::string, at::IValue>;
+
+  virtual ~IMethod() = default;
+
+  virtual c10::IValue operator()(
+      std::vector<c10::IValue> args,
+      const IValueMap& kwargs = IValueMap()) = 0;
+
+  // Returns an ordered list of argument names, possible in both
+  // script and python methods.  This is a more portable dependency
+  // than a ScriptMethod FunctionSchema, which has more information
+  // than can be generally expected from a python method.
+  virtual std::vector<std::string> getArgumentNames() = 0;
+};
+
+} // namespace torch
diff --git a/torch/csrc/deploy/deploy.h b/torch/csrc/deploy/deploy.h
index 640f3da..48bff73 100644
--- a/torch/csrc/deploy/deploy.h
+++ b/torch/csrc/deploy/deploy.h
@@ -2,6 +2,7 @@
 // NOLINTNEXTLINE(modernize-deprecated-headers)
 #include <assert.h>
 #include <c10/util/irange.h>
+#include <torch/csrc/api/include/torch/imethod.h>
 #include <torch/csrc/deploy/interpreter/interpreter_impl.h>
 #include <torch/csrc/jit/serialization/import.h>
 #include <fstream>
@@ -231,6 +232,35 @@
   friend struct InterpreterManager;
 };
 
+class PythonMethodWrapper : public torch::IMethod {
+  // PythonMethodWrapper is a more specific instance of a
+  // ReplicatedObj which represents a python method, and
+  // is therefore callable and has argument names accessible.
+ public:
+  PythonMethodWrapper(
+      torch::deploy::ReplicatedObj& model,
+      std::string method_name)
+      : model_(std::move(model)), method_name_(std::move(method_name)) {}
+
+  c10::IValue operator()(
+      std::vector<c10::IValue> args,
+      const IValueMap& kwargs = IValueMap()) override {
+    // TODO(whc) ideally, pickle the method itself as replicatedobj, to skip
+    // this lookup each time
+    auto model_session = model_.acquire_session();
+    auto method = model_session.self.attr(method_name_.c_str());
+    return method.call_kwargs(args, kwargs).toIValue();
+  }
+
+  std::vector<std::string> getArgumentNames() override {
+    throw std::runtime_error("getArgumentNames not yet implemented");
+  }
+
+ private:
+  torch::deploy::ReplicatedObj model_;
+  std::string method_name_;
+};
+
 struct TORCH_API Package {
   // shorthand for getting the object as a pickle resource in the package
   ReplicatedObj load_pickle(
diff --git a/torch/csrc/jit/api/method.h b/torch/csrc/jit/api/method.h
index 692c08f..fe12635 100644
--- a/torch/csrc/jit/api/method.h
+++ b/torch/csrc/jit/api/method.h
@@ -3,6 +3,7 @@
 #include <ATen/core/function.h>
 #include <ATen/core/ivalue.h>
 #include <ATen/core/stack.h>
+#include <torch/csrc/api/include/torch/imethod.h>
 #include <torch/csrc/jit/api/function_impl.h>
 
 namespace torch {
@@ -18,7 +19,7 @@
 //     ...
 // Note: because Method/Module are exposed to python these
 // classes use python method naming conventions
-struct TORCH_API Method {
+struct TORCH_API Method : public torch::IMethod {
   Method(ObjectPtr owner, Function* function);
 
   // the module that contains this method.
@@ -30,7 +31,7 @@
 
   c10::IValue operator()(
       std::vector<c10::IValue> stack,
-      const Kwargs& kwargs = Kwargs());
+      const Kwargs& kwargs = Kwargs()) override;
 
   // Run method async. Invocation on this function would invokes a JIT
   // interpreter that executes ops inline, one by one, on caller's thread. A
@@ -57,6 +58,10 @@
     return function_->get_executor();
   }
 
+  std::vector<std::string> getArgumentNames() override {
+    throw std::runtime_error("getArgumentNames not yet implemented");
+  }
+
   Function& function() const {
     return *function_;
   }