[Static Runtime] Clean up op implementations (#56841)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56841
- Move arg checks outside the lambda so they run at Static Runtime initialization time rather than on every inference
- Use `c10::optional` where possible instead of defaulting missing arguments by hand
- Support the `to.other` overload of `torch.to`, the 5-arg variant whose second argument is a `Tensor`
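For context, a minimal sketch of the registration pattern this cleanup converges on, using `aten::logit` as a stand-in (the arity check and the stripped-down body are illustrative only; the actual functors, including `logit`'s NNC fast path, are in the diff below). The check in the outer lambda runs once, when Static Runtime builds its op list, while everything in the returned `SROperator` runs on every inference:
```
REGISTER_OPERATOR_FUNCTOR(aten::logit, aten_logit, [](Node* n) -> SROperator {
  // Checked at Static Runtime initialization time, not per inference.
  TORCH_CHECK(n->inputs().size() == 2);
  return [](ProcessedNode* p_node) {
    const auto& in0_t = p_node->Input(0).toTensor();
    // `eps` may be None in the graph; toOptional<> forwards that as nullopt
    // instead of picking a default based on the input count.
    const auto in1_d = p_node->Input(1).toOptional<double>();
    if (p_node->Output(0).isNone()) {
      p_node->Output(0) = create_empty_from(in0_t);
    }
    auto& out_t = p_node->Output(0).toTensor();
    fastResizeToZero(out_t);
    at::native::logit_out(in0_t, in1_d, out_t);
  };
});
```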
Test Plan:
```
buck run //caffe2/benchmarks/static_runtime:static_runtime_cpptest
buck test mode/opt-clang //caffe2/caffe2/fb/predictor:ptvsc2_predictor_bench_test -- --run-disabled
```
Reviewed By: edvgha
Differential Revision: D27933176
fbshipit-source-id: 49d6249c8784c44146461e286e7a301596172d7c
diff --git a/benchmarks/static_runtime/test_scripts.h b/benchmarks/static_runtime/test_scripts.h
index 6d424b7..4967d88 100644
--- a/benchmarks/static_runtime/test_scripts.h
+++ b/benchmarks/static_runtime/test_scripts.h
@@ -195,6 +195,11 @@
return torch.to(input, dtype, non_blocking, copy)
)JIT";
+const auto to_script_2 = R"JIT(
+ def forward(self, input: Tensor, other: Tensor, non_blocking: bool, copy: bool, memory_format: int):
+ return torch.to(input, other, non_blocking, copy, memory_format)
+)JIT";
+
const std::string embedding_bag_default = R"JIT(
def forward(self, a: Tensor, b: Tensor, c: Tensor):
return torch.embedding_bag(a, b, c)
diff --git a/benchmarks/static_runtime/test_static_runtime.cc b/benchmarks/static_runtime/test_static_runtime.cc
index f6203ca..66b67a8 100644
--- a/benchmarks/static_runtime/test_static_runtime.cc
+++ b/benchmarks/static_runtime/test_static_runtime.cc
@@ -240,10 +240,13 @@
TEST(StaticRuntime, IndividualOps_to) {
auto test_to = [](at::ScalarType b, bool c, bool d, c10::MemoryFormat e) {
auto a = at::randn({2, 3});
+ auto other = at::randn({2, 3}, b);
std::vector<IValue> args0{a, b, c, d, e};
std::vector<IValue> args1{a, b, c, d};
+ std::vector<IValue> args2{a, other, c, d, e};
testStaticRuntime(to_script_0, args0);
testStaticRuntime(to_script_1, args1);
+ testStaticRuntime(to_script_2, args2);
};
test_to(at::ScalarType::Float, true, true, c10::MemoryFormat::Contiguous);
diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp
index a9ece64..5540e7b 100644
--- a/torch/csrc/jit/runtime/static/ops.cpp
+++ b/torch/csrc/jit/runtime/static/ops.cpp
@@ -255,8 +255,8 @@
REGISTER_OPERATOR_FUNCTOR(aten::clamp, aten_clamp, [](Node* n) -> SROperator {
return [](ProcessedNode* p_node) {
const auto& in0_t = p_node->Input(0).toTensor();
- const auto in1_s = p_node->Input(1).toScalar();
- const auto in2_s = p_node->Input(2).toScalar();
+ const auto in1_s = p_node->Input(1).toOptional<at::Scalar>();
+ const auto in2_s = p_node->Input(2).toOptional<at::Scalar>();
if (p_node->Output(0).isNone()) {
p_node->Output(0) = create_empty_from(in0_t);
}
@@ -284,15 +284,10 @@
aten_nan_to_num,
[](Node* n) -> SROperator {
return [](ProcessedNode* p_node) {
- auto input_size = p_node->inputs().size();
const auto& in0_t = p_node->Input(0).toTensor();
- const double in1_d = input_size > 1 ? p_node->Input(1).toDouble() : 0;
- const double in2_d = input_size > 2
- ? p_node->Input(2).toDouble()
- : std::numeric_limits<double>::infinity();
- const double in3_d = input_size > 3
- ? p_node->Input(3).toDouble()
- : -std::numeric_limits<double>::infinity();
+ const auto in1_d = p_node->Input(1).toOptional<double>();
+ const auto in2_d = p_node->Input(2).toOptional<double>();
+ const auto in3_d = p_node->Input(3).toOptional<double>();
if (p_node->Output(0).isNone()) {
p_node->Output(0) = create_empty_from(in0_t);
}
@@ -334,28 +329,15 @@
aten::leaky_relu,
aten_leaky_relu,
[](Node* n) -> SROperator {
- const auto in1 = toIValue(n->inputs()[1]);
- if (in1) {
- const auto in1_s = in1->toScalar();
- return [=](ProcessedNode* p_node) {
- const auto& in0_t = p_node->Input(0).toTensor();
- if (p_node->Output(0).isNone()) {
- p_node->Output(0) = create_empty_from(in0_t);
- }
- auto& out_t = p_node->Output(0).toTensor();
- at::native::leaky_relu_out(in0_t, in1_s, out_t);
- };
- } else {
- return [](ProcessedNode* p_node) {
- const auto& in0_t = p_node->Input(0).toTensor();
- const auto in1_s = p_node->Input(1).toScalar();
- if (p_node->Output(0).isNone()) {
- p_node->Output(0) = create_empty_from(in0_t);
- }
- auto& out_t = p_node->Output(0).toTensor();
- at::native::leaky_relu_out(in0_t, in1_s, out_t);
- };
- }
+ return [](ProcessedNode* p_node) {
+ const auto& in0_t = p_node->Input(0).toTensor();
+ const auto in1_s = p_node->Input(1).toScalar();
+ if (p_node->Output(0).isNone()) {
+ p_node->Output(0) = create_empty_from(in0_t);
+ }
+ auto& out_t = p_node->Output(0).toTensor();
+ at::native::leaky_relu_out(in0_t, in1_s, out_t);
+ };
});
namespace {
@@ -580,9 +562,8 @@
}
auto& out_t = p_node->Output(0).toTensor();
if (!te->supports(in0_t)) {
- const auto in0_t = p_node->Input(0).toTensor();
- const double in1_d =
- p_node->inputs().size() > 1 ? p_node->Input(1).toDouble() : -1.0;
+ const auto& in0_t = p_node->Input(0).toTensor();
+ const auto in1_d = p_node->Input(1).toOptional<double>();
fastResizeToZero(out_t);
at::native::logit_out(in0_t, in1_d, out_t);
} else {
@@ -758,33 +739,47 @@
static_runtime::to_copy,
aten_to_copy,
[](Node* n) -> SROperator {
+ // support 4- or 5-arg for adindexer/adfinder models
+ TORCH_CHECK(n->inputs().size() == 4 || n->inputs().size() == 5);
+
return [](ProcessedNode* p_node) {
- // support 4- or 5-arg for adindexer/adfinder models
- DCHECK(p_node->inputs().size() >= 4);
- const auto& in0_t = p_node->Input(0).toTensor();
- auto in2_i = p_node->Input(2).toBool(); // non_blocking
- // ignore input 3 (copy)
+ const auto& self = p_node->Input(0).toTensor();
if (p_node->Output(0).isNone()) {
- auto in1_i = p_node->Input(1).toScalarType();
- c10::optional<c10::MemoryFormat> in4_o = c10::nullopt;
- if (p_node->inputs().size() > 4 && p_node->Input(4).isInt()) {
- in4_o = p_node->Input(4).toOptional<c10::MemoryFormat>();
+ // handle dtype, layout, and device
+ at::ScalarType dtype;
+ c10::Layout layout = self.layout();
+ c10::Device device = self.device();
+ if (p_node->Input(1).isTensor()) {
+ const auto& other = p_node->Input(1).toTensor();
+ dtype = other.scalar_type();
+ layout = other.layout();
+ device = other.device();
+ } else {
+ dtype = p_node->Input(1).toScalarType();
}
- if (in4_o.value_or(c10::MemoryFormat::Preserve) ==
+ // handle memory format
+ c10::optional<c10::MemoryFormat> memory_format = c10::nullopt;
+ if (p_node->inputs().size() == 5) {
+ memory_format = p_node->Input(4).toOptional<c10::MemoryFormat>();
+ }
+ if (memory_format.value_or(c10::MemoryFormat::Preserve) ==
c10::MemoryFormat::Preserve) {
- if (in0_t.is_non_overlapping_and_dense()) {
- in4_o = c10::nullopt;
+ if (self.is_non_overlapping_and_dense()) {
+ memory_format = c10::nullopt;
} else {
- in4_o = in0_t.suggest_memory_format();
+ memory_format = self.suggest_memory_format();
}
}
// See Note [Explicit nullopt MemoryFormat argument]
p_node->Output(0) = at::detail::empty_cpu(
- {0}, in1_i, in0_t.layout(), in0_t.device(), c10::nullopt, in4_o);
+ {0}, dtype, layout, self.device(), c10::nullopt, memory_format);
}
+
+ // ignore input 3 (copy)
+ auto non_blocking = p_node->Input(2).toBool();
auto& out_t = p_node->Output(0).toTensor();
fastResizeToZero(out_t);
- at::native::to_copy_out(out_t, in0_t, in2_i);
+ at::native::to_copy_out(out_t, self, non_blocking);
};
});
@@ -810,8 +805,8 @@
static_runtime::flatten_copy,
aten_flatten,
[](Node* n) -> SROperator {
+ TORCH_CHECK(n->inputs().size() == 3);
return [](ProcessedNode* p_node) {
- DCHECK(p_node->inputs().size() == 3);
const auto& self = p_node->Input(0).toTensor();
const auto start_dim = p_node->Input(1).toInt();
const auto end_dim = p_node->Input(2).toInt();
@@ -827,25 +822,29 @@
REGISTER_OPERATOR_FUNCTOR(aten::sum, aten_sum, [](Node* n) -> SROperator {
return [](ProcessedNode* p_node) {
const at::Tensor& self = p_node->Input(0).toTensor();
- std::vector<int64_t> dim = {};
- if ((p_node->inputs().size() > 1) && (!p_node->Input(1).isNone())) {
- dim = p_node->Input(1).toIntList().vec();
+
+ c10::optional<at::ScalarType> dtype = c10::nullopt;
+ if (p_node->inputs().size() == 2) {
+ // sum(Tensor self, *, ScalarType? dtype=None) -> Tensor
+ dtype = p_node->Input(1).toOptional<at::ScalarType>();
}
+
+ std::vector<int64_t> dim = {};
+ bool keepdim = false;
+ if (p_node->inputs().size() == 4) {
+ // sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *,
+ // ScalarType? dtype=None) -> Tensor
+ dim = p_node->Input(1).toIntList().vec();
+ keepdim = p_node->Input(2).toBool();
+ dtype = p_node->Input(3).toOptional<at::ScalarType>();
+ }
+
if (p_node->Output(0).isNone()) {
p_node->Output(0) = create_empty_from(self);
}
auto& output = p_node->Output(0).toTensor();
fastResizeToZero(output);
- if (p_node->inputs().size() > 2) {
- at::native::sum_out(
- self,
- dim,
- p_node->Input(2).toBool(),
- p_node->Input(3).toOptional<at::ScalarType>(),
- output);
- return;
- }
- at::native::sum_out(self, dim, false /* keep_dim */, c10::nullopt, output);
+ at::native::sum_out(self, dim, keepdim, dtype, output);
};
});
@@ -1011,14 +1010,18 @@
return [](ProcessedNode* p_node) {
DCHECK(p_node->inputs().size() == 5);
const auto& in0_t = p_node->Input(0).toTensor();
- const auto in1_i = p_node->Input(1).toScalarType();
const auto in2_i = p_node->Input(2).toBool();
const auto in3_i = p_node->Input(3).toBool();
- if (p_node->Input(4).isNone()) {
- p_node->Output(0) =
- at::native::to(in0_t, in1_i, in2_i, in3_i, c10::nullopt);
+ const auto in4_o = p_node->Input(4).toOptional<at::MemoryFormat>();
+ if (p_node->Input(1).isTensor()) {
+ // to.other(Tensor self, Tensor other, bool non_blocking=False, bool
+ // copy=False, MemoryFormat? memory_format=None) -> Tensor
+ const auto& in1_t = p_node->Input(1).toTensor();
+ p_node->Output(0) = at::native::to(in0_t, in1_t, in2_i, in3_i, in4_o);
} else {
- const auto in4_o = p_node->Input(4).toMemoryFormat();
+ // to.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool
+ // copy=False, MemoryFormat? memory_format=None) -> Tensor
+ const auto in1_i = p_node->Input(1).toScalarType();
p_node->Output(0) = at::native::to(in0_t, in1_i, in2_i, in3_i, in4_o);
}
};