plug caffe2 into jit (#16331)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16331
Temporary measure to enable caffe2 ops in pytorch
Reviewed By: smessmer
Differential Revision: D13740752
fbshipit-source-id: 2d9383574d42ce84ee471aba32eeb4f5a0cc7a4c
diff --git a/caffe2/core/operator.cc b/caffe2/core/operator.cc
index 9df184c..7b7e3dd 100644
--- a/caffe2/core/operator.cc
+++ b/caffe2/core/operator.cc
@@ -332,7 +332,7 @@
CAFFE_ENFORCE(
fn_wrap,
"Operator not registered with FunctionSchema constructor.",
- name);
+ name.toUnqualString());
auto fn = fn_wrap->getSchema();
auto op = caffe2::FunctionSchemaOperatorRegistry()->Create(
name.toUnqualString(), fn, inputs, outputs);
diff --git a/caffe2/core/operator.h b/caffe2/core/operator.h
index 54cae81..5641902 100644
--- a/caffe2/core/operator.h
+++ b/caffe2/core/operator.h
@@ -1098,7 +1098,7 @@
C10_REGISTER_CLASS(FunctionSchemaOperatorRegistry, name, impl) \
struct FunctionSchemaStorageBase##name : public FunctionSchemaStorageBase { \
c10::FunctionSchema getSchema() override { \
- return c10::FunctionSchema(#name, inputs, outputs); \
+ return c10::FunctionSchema("caffe2::" #name, inputs, outputs); \
} \
}; \
C10_REGISTER_CLASS( \
diff --git a/torch/csrc/jit/custom_operator.cpp b/torch/csrc/jit/custom_operator.cpp
new file mode 100644
index 0000000..05c0ef6
--- /dev/null
+++ b/torch/csrc/jit/custom_operator.cpp
@@ -0,0 +1,59 @@
+#include <jit/custom_operator.h>
+
+namespace torch {
+namespace jit {
+
+Operator createOperatorFromC2(
+ const std::string& name
+ ) {
+ auto symbolic_name = c10::Symbol::fromQualString("caffe2::" + name);
+ auto fn_wrap = caffe2::FunctionSchemaRegistry()->Create(symbolic_name.toUnqualString());
+ CAFFE_ENFORCE(
+ fn_wrap,
+ "Operator not registered with FunctionSchema constructor.",
+ name);
+ auto fn = fn_wrap->getSchema();
+
+ return Operator(fn, [symbolic_name, fn](Stack& stack) {
+ const auto input_size = fn.arguments().size();
+ const auto output_size = fn.returns().size();
+ std::vector<c10::IValue> inputs;
+ for (auto i = 0; i < input_size; ++i) {
+ auto input = pop(stack);
+ // Tensors come in as variables but need to be unwrapped
+ if (input.isTensor()) {
+ input = torch::autograd::Variable(input.toTensor()).data();
+ }
+ inputs.emplace(inputs.begin(), std::move(input));
+ }
+
+ // We use a temporary stack for arguments passed into RunOperator
+ std::list<c10::IValue> outputs_real;
+ std::vector<c10::IValue*> outputs;
+ for (auto i = 0; i < output_size; ++i) {
+ if (TensorType::get() == fn.returns()[i].type()) {
+ caffe2::Tensor tensor(caffe2::CPU);
+ auto at_tensor = at::Tensor(c10::C10Tensor(std::move(tensor)));
+ outputs_real.emplace_back(c10::IValue(at_tensor));
+ } else {
+ outputs_real.emplace_back(c10::IValue());
+ }
+ outputs.emplace_back(&outputs_real.back());
+ }
+
+ caffe2::RunOperator(symbolic_name, inputs, outputs);
+
+ // We need to convert tensors back into variables
+ for (auto& t : outputs_real) {
+ if (t.isTensor()) {
+ push(stack, c10::IValue(torch::autograd::make_variable(t.toTensor())));
+ } else {
+ push(stack, std::move(t));
+ }
+ }
+
+ return 0;
+ });
+}
+
+}} // torch::jit
diff --git a/torch/csrc/jit/custom_operator.h b/torch/csrc/jit/custom_operator.h
index 4d82bed..265bf7a 100644
--- a/torch/csrc/jit/custom_operator.h
+++ b/torch/csrc/jit/custom_operator.h
@@ -6,6 +6,8 @@
#include <torch/csrc/utils/variadic.h>
#include <ATen/core/function_schema.h>
+#include <caffe2/core/operator.h>
+
#include <c10/util/Metaprogramming.h>
#include <c10/util/TypeList.h>
@@ -258,6 +260,8 @@
});
}
+Operator createOperatorFromC2(const std::string& name);
+
/// Registration class for new operators. Effectively calls
/// `torch::jit::registerOperator` for every supplied operator, but allows doing
/// so in the global scope when a `RegisterOperators` object is assigned to a
@@ -279,6 +283,14 @@
op(name, std::forward<Implementation>(implementation));
}
+ /// Requires declaration of the FunctionSchema with
+ /// REGISTER_FUNCTION_SCHEMA_OPERATOR(name, ...)
+ static RegisterOperators&& Caffe2Operator(const std::string& name) {
+ auto r = RegisterOperators();
+ registerOperator(createOperatorFromC2(name));
+ return std::move(r);
+ }
+
/// Creates a new operator from a name and implementation function (function
/// pointer or function object/lambda) using `torch::jit::createOperator`, and
/// then registers the operator.
diff --git a/torch/csrc/jit/interpreter.cpp b/torch/csrc/jit/interpreter.cpp
index 6d1fa9c..36aca85 100644
--- a/torch/csrc/jit/interpreter.cpp
+++ b/torch/csrc/jit/interpreter.cpp
@@ -15,6 +15,8 @@
#include <torch/csrc/jit/script/jit_exception.h>
#include <ATen/core/thread_pool.h>
+#include <caffe2/core/operator.h>
+
#include <exception>
#include <iostream>
#include <memory>