Add missing operators for PyText model.
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/29664
Test Plan: Imported from OSS
Differential Revision: D18499601
fbshipit-source-id: 8a38d3d809ee5ef5b73b5a5ce1db612aea680e75
diff --git a/torch/csrc/jit/instruction.cpp b/torch/csrc/jit/instruction.cpp
index 4dc0c5e..42f4950 100644
--- a/torch/csrc/jit/instruction.cpp
+++ b/torch/csrc/jit/instruction.cpp
@@ -70,7 +70,7 @@
bool isOpSupportedInMobile(OpCode op) {
static constexpr OpCode supported_ops_in_mobile[] {
- OP, OPN, LOAD, MOVE, STOREN, STORE, DROP, DROPR, LOADC, JF, LOOP, RET, GET_ATTR, SET_ATTR
+ OP, OPN, LOAD, MOVE, STOREN, STORE, DROP, DROPR, LOADC, JF, JMP, LOOP, RET, GET_ATTR, SET_ATTR
};
for (auto sop : supported_ops_in_mobile) {
diff --git a/torch/csrc/jit/mobile/register_mobile_ops.cpp b/torch/csrc/jit/mobile/register_mobile_ops.cpp
index f4642d6..bf6eb20 100644
--- a/torch/csrc/jit/mobile/register_mobile_ops.cpp
+++ b/torch/csrc/jit/mobile/register_mobile_ops.cpp
@@ -302,7 +302,7 @@
[]() {
})
).op(
- "_prim::ListConstruct.tensor",
+ "_prim::ListConstruct.Tensor",
torch::RegisterOperators::options().catchAllKernel(
[]() {
})
@@ -311,6 +311,143 @@
torch::RegisterOperators::options().catchAllKernel(
[]() {
})
+// Pytext operators
+).op(
+ "_aten::embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor",
+ torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId,
+ [](c10::OperatorKernel* kernel, Stack* stack) {
+ constexpr int N = 5;
+ auto result_ = at::embedding(
+ (std::move(peek(*stack, 0, N))).toTensor(),
+ (std::move(peek(*stack, 1, N))).toTensor(),
+ (std::move(peek(*stack, 2, N))).toInt(),
+ (std::move(peek(*stack, 3, N))).toBool(),
+ (std::move(peek(*stack, 4, N))).toBool()
+ );
+ drop(*stack, N);
+ pack(*stack, std::move(result_));
+ })
+).op(
+ "_aten::dropout(Tensor input, float p, bool train) -> Tensor",
+ torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId,
+ [](c10::OperatorKernel* kernel, Stack* stack) {
+ auto result_ = at::dropout(
+ (std::move(peek(*stack, 0, 3))).toTensor(),
+ (std::move(peek(*stack, 1, 3))).toDouble(),
+ (std::move(peek(*stack, 2, 3))).toBool()
+ );
+ drop(*stack, 3);
+ pack(*stack, std::move(result_));
+ })
+).op(
+ "_aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)",
+ torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId,
+ [](c10::OperatorKernel* kernel, Stack* stack) {
+ auto result_ = ((std::move(peek(*stack, 0, 2))).toTensor()).permute(
+ (std::move(peek(*stack, 1, 2))).toIntListRef()
+ );
+ drop(*stack, 2);
+ pack(*stack, std::move(result_));
+ }).aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)
+).op(
+ "_aten::matmul(Tensor self, Tensor other) -> Tensor",
+ torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId,
+ [](c10::OperatorKernel* kernel, Stack* stack) {
+ auto result_ = at::matmul(
+ (std::move(peek(*stack, 0, 2))).toTensor(),
+ (std::move(peek(*stack, 1, 2))).toTensor()
+ );
+ drop(*stack, 2);
+ pack(*stack, std::move(result_));
+ })
+).op(
+ "_aten::mul.Tensor(Tensor self, Tensor other) -> Tensor",
+ torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId,
+ [](c10::OperatorKernel* kernel, Stack* stack) {
+ auto result_ = at::mul(
+ (std::move(peek(*stack, 0, 2))).toTensor(),
+ (std::move(peek(*stack, 1, 2))).toTensor()
+ );
+ drop(*stack, 2);
+ pack(*stack, std::move(result_));
+ })
+).op(
+ "_aten::tanh(Tensor self) -> Tensor",
+ torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId,
+ [](c10::OperatorKernel* kernel, Stack* stack) {
+ auto result_ = at::tanh(
+ (std::move(peek(*stack, 0, 1))).toTensor()
+ );
+ drop(*stack, 1);
+ pack(*stack, std::move(result_));
+ })
+).op(
+ "_aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)",
+ torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId,
+ [](c10::OperatorKernel* kernel, Stack* stack) {
+ auto result_ = at::max(
+ (std::move(peek(*stack, 0, 3))).toTensor(),
+ (std::move(peek(*stack, 1, 3))).toInt(),
+ (std::move(peek(*stack, 2, 3))).toBool()
+ );
+ drop(*stack, 3);
+ pack(*stack, std::move(result_));
+ })
+).op(
+ "_aten::cat(Tensor[] tensors, int dim=0) -> Tensor",
+ torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId,
+ [](c10::OperatorKernel* kernel, Stack* stack) {
+ auto result_ = at::cat(
+ (std::move(peek(*stack, 0, 2))).toTensorListRef(),
+ (std::move(peek(*stack, 1, 2))).toInt()
+ );
+ drop(*stack, 2);
+ pack(*stack, std::move(result_));
+ })
+).op(
+ "_aten::__is__(t1 self, t2 obj) -> bool",
+ torch::RegisterOperators::options().catchAllKernel(
+ [](c10::OperatorKernel* kernel, Stack* stack) {
+ c10::IValue self, obj;
+ pop(*stack, self, obj);
+ push(*stack, self.isSameIdentity(obj));
+ })
+).op(
+ "_aten::log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor",
+ torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId,
+ [](c10::OperatorKernel* kernel, Stack* stack) {
+ auto result_ = at::log_softmax(
+ (std::move(peek(*stack, 0, 3))).toTensor(),
+ (std::move(peek(*stack, 1, 3))).toInt(),
+ (std::move(peek(*stack, 2, 3))).toOptional<c10::ScalarType>()
+ );
+ drop(*stack, 3);
+ pack(*stack, std::move(result_));
+ })
+).op(
+ "_aten::softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor",
+ torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId,
+ [](c10::OperatorKernel* kernel, Stack* stack) {
+ auto result_ = at::softmax(
+ (std::move(peek(*stack, 0, 3))).toTensor(),
+ (std::move(peek(*stack, 1, 3))).toInt(),
+ (std::move(peek(*stack, 2, 3))).toOptional<c10::ScalarType>()
+ );
+ drop(*stack, 3);
+ pack(*stack, std::move(result_));
+ })
+).op(
+ "_aten::warn() -> void",
+ torch::RegisterOperators::options().catchAllKernel(
+ [](c10::OperatorKernel* kernel, Stack* stack) {
+ drop(*stack, 1);
+ pop(*stack);
+ })
+).op(
+ "_prim::unchecked_cast",
+ torch::RegisterOperators::options().catchAllKernel(
+ []() {
+ })
).op(
"_prim::TupleConstruct",
torch::RegisterOperators::options().catchAllKernel(