[Static Runtime] Clean up op implementations (#56841)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56841

- Move argument-count checks outside the returned lambda so they run once, at Static Runtime initialization time, rather than on every inference (see the sketch after this list)
- Use `c10::optional` (via `IValue::toOptional<T>()`) where possible, replacing hand-rolled `inputs().size()` and default-value branching
- Support the `to.other` overload, the 5-arg variant of `torch.to` whose second argument is a `Tensor` rather than a dtype (schemas after this list)
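
A condensed sketch of the first two points, for illustration only: `aten::example` is a hypothetical op name, the helpers (`SROperator`, `ProcessedNode`, `create_empty_from`, `fastResizeToZero`) are the Static Runtime utilities used throughout `ops.cpp` below, and `at::native::logit_out` stands in for whichever out-variant kernel the real functor dispatches to.

```
// Hypothetical op used only to illustrate the pattern; the real changes
// to clamp, nan_to_num, logit, etc. are in the ops.cpp hunks below.
REGISTER_OPERATOR_FUNCTOR(aten::example, aten_example, [](Node* n) -> SROperator {
  // Before: DCHECK inside the lambda, re-evaluated on every inference.
  // After: checked once, when Static Runtime prepares the graph.
  TORCH_CHECK(n->inputs().size() == 2);
  return [](ProcessedNode* p_node) {
    const auto& self = p_node->Input(0).toTensor();
    // Before: inputs().size() > 1 ? Input(1).toDouble() : <default>.
    // After: a single toOptional<double>() call; the out-variant kernel
    // applies its own default when the optional is empty.
    const auto eps = p_node->Input(1).toOptional<double>();
    if (p_node->Output(0).isNone()) {
      p_node->Output(0) = create_empty_from(self);
    }
    auto& out = p_node->Output(0).toTensor();
    fastResizeToZero(out);
    at::native::logit_out(self, eps, out);
  };
});
```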
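
For reference, these are the two 5-arg `aten::to` schemas the updated functor now disambiguates (copied from the schema comments in the `ops.cpp` hunk below); it selects between them by checking whether `Input(1)` is a `Tensor`:

```
to.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False,
         MemoryFormat? memory_format=None) -> Tensor
to.other(Tensor self, Tensor other, bool non_blocking=False, bool copy=False,
         MemoryFormat? memory_format=None) -> Tensor
```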

Test Plan:
```
buck run //caffe2/benchmarks/static_runtime:static_runtime_cpptest
buck test mode/opt-clang //caffe2/caffe2/fb/predictor:ptvsc2_predictor_bench_test -- --run-disabled
```

Reviewed By: edvgha

Differential Revision: D27933176

fbshipit-source-id: 49d6249c8784c44146461e286e7a301596172d7c
diff --git a/benchmarks/static_runtime/test_scripts.h b/benchmarks/static_runtime/test_scripts.h
index 6d424b7..4967d88 100644
--- a/benchmarks/static_runtime/test_scripts.h
+++ b/benchmarks/static_runtime/test_scripts.h
@@ -195,6 +195,11 @@
       return torch.to(input, dtype, non_blocking, copy)
 )JIT";
 
+const auto to_script_2 = R"JIT(
+  def forward(self, input: Tensor, other: Tensor, non_blocking: bool, copy: bool, memory_format: int):
+      return torch.to(input, other, non_blocking, copy, memory_format)
+)JIT";
+
 const std::string embedding_bag_default = R"JIT(
   def forward(self, a: Tensor, b: Tensor, c: Tensor):
       return torch.embedding_bag(a, b, c)
diff --git a/benchmarks/static_runtime/test_static_runtime.cc b/benchmarks/static_runtime/test_static_runtime.cc
index f6203ca..66b67a8 100644
--- a/benchmarks/static_runtime/test_static_runtime.cc
+++ b/benchmarks/static_runtime/test_static_runtime.cc
@@ -240,10 +240,13 @@
 TEST(StaticRuntime, IndividualOps_to) {
   auto test_to = [](at::ScalarType b, bool c, bool d, c10::MemoryFormat e) {
     auto a = at::randn({2, 3});
+    auto other = at::randn({2, 3}, b);
     std::vector<IValue> args0{a, b, c, d, e};
     std::vector<IValue> args1{a, b, c, d};
+    std::vector<IValue> args2{a, other, c, d, e};
     testStaticRuntime(to_script_0, args0);
     testStaticRuntime(to_script_1, args1);
+    testStaticRuntime(to_script_2, args2);
   };
 
   test_to(at::ScalarType::Float, true, true, c10::MemoryFormat::Contiguous);
diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp
index a9ece64..5540e7b 100644
--- a/torch/csrc/jit/runtime/static/ops.cpp
+++ b/torch/csrc/jit/runtime/static/ops.cpp
@@ -255,8 +255,8 @@
 REGISTER_OPERATOR_FUNCTOR(aten::clamp, aten_clamp, [](Node* n) -> SROperator {
   return [](ProcessedNode* p_node) {
     const auto& in0_t = p_node->Input(0).toTensor();
-    const auto in1_s = p_node->Input(1).toScalar();
-    const auto in2_s = p_node->Input(2).toScalar();
+    const auto in1_s = p_node->Input(1).toOptional<at::Scalar>();
+    const auto in2_s = p_node->Input(2).toOptional<at::Scalar>();
     if (p_node->Output(0).isNone()) {
       p_node->Output(0) = create_empty_from(in0_t);
     }
@@ -284,15 +284,10 @@
     aten_nan_to_num,
     [](Node* n) -> SROperator {
       return [](ProcessedNode* p_node) {
-        auto input_size = p_node->inputs().size();
         const auto& in0_t = p_node->Input(0).toTensor();
-        const double in1_d = input_size > 1 ? p_node->Input(1).toDouble() : 0;
-        const double in2_d = input_size > 2
-            ? p_node->Input(2).toDouble()
-            : std::numeric_limits<double>::infinity();
-        const double in3_d = input_size > 3
-            ? p_node->Input(3).toDouble()
-            : -std::numeric_limits<double>::infinity();
+        const auto in1_d = p_node->Input(1).toOptional<double>();
+        const auto in2_d = p_node->Input(2).toOptional<double>();
+        const auto in3_d = p_node->Input(3).toOptional<double>();
         if (p_node->Output(0).isNone()) {
           p_node->Output(0) = create_empty_from(in0_t);
         }
@@ -334,28 +329,15 @@
     aten::leaky_relu,
     aten_leaky_relu,
     [](Node* n) -> SROperator {
-      const auto in1 = toIValue(n->inputs()[1]);
-      if (in1) {
-        const auto in1_s = in1->toScalar();
-        return [=](ProcessedNode* p_node) {
-          const auto& in0_t = p_node->Input(0).toTensor();
-          if (p_node->Output(0).isNone()) {
-            p_node->Output(0) = create_empty_from(in0_t);
-          }
-          auto& out_t = p_node->Output(0).toTensor();
-          at::native::leaky_relu_out(in0_t, in1_s, out_t);
-        };
-      } else {
-        return [](ProcessedNode* p_node) {
-          const auto& in0_t = p_node->Input(0).toTensor();
-          const auto in1_s = p_node->Input(1).toScalar();
-          if (p_node->Output(0).isNone()) {
-            p_node->Output(0) = create_empty_from(in0_t);
-          }
-          auto& out_t = p_node->Output(0).toTensor();
-          at::native::leaky_relu_out(in0_t, in1_s, out_t);
-        };
-      }
+      return [](ProcessedNode* p_node) {
+        const auto& in0_t = p_node->Input(0).toTensor();
+        const auto in1_s = p_node->Input(1).toScalar();
+        if (p_node->Output(0).isNone()) {
+          p_node->Output(0) = create_empty_from(in0_t);
+        }
+        auto& out_t = p_node->Output(0).toTensor();
+        at::native::leaky_relu_out(in0_t, in1_s, out_t);
+      };
     });
 
 namespace {
@@ -580,9 +562,8 @@
     }
     auto& out_t = p_node->Output(0).toTensor();
     if (!te->supports(in0_t)) {
-      const auto in0_t = p_node->Input(0).toTensor();
-      const double in1_d =
-          p_node->inputs().size() > 1 ? p_node->Input(1).toDouble() : -1.0;
+      const auto& in0_t = p_node->Input(0).toTensor();
+      const auto in1_d = p_node->Input(1).toOptional<double>();
       fastResizeToZero(out_t);
       at::native::logit_out(in0_t, in1_d, out_t);
     } else {
@@ -758,33 +739,47 @@
     static_runtime::to_copy,
     aten_to_copy,
     [](Node* n) -> SROperator {
+      // support 4- or 5-arg for adindexer/adfinder models
+      TORCH_CHECK(n->inputs().size() == 4 || n->inputs().size() == 5);
+
       return [](ProcessedNode* p_node) {
-        // support 4- or 5-arg for adindexer/adfinder models
-        DCHECK(p_node->inputs().size() >= 4);
-        const auto& in0_t = p_node->Input(0).toTensor();
-        auto in2_i = p_node->Input(2).toBool(); // non_blocking
-        // ignore input 3 (copy)
+        const auto& self = p_node->Input(0).toTensor();
         if (p_node->Output(0).isNone()) {
-          auto in1_i = p_node->Input(1).toScalarType();
-          c10::optional<c10::MemoryFormat> in4_o = c10::nullopt;
-          if (p_node->inputs().size() > 4 && p_node->Input(4).isInt()) {
-            in4_o = p_node->Input(4).toOptional<c10::MemoryFormat>();
+          // handle dtype, layout, and device
+          at::ScalarType dtype;
+          c10::Layout layout = self.layout();
+          c10::Device device = self.device();
+          if (p_node->Input(1).isTensor()) {
+            const auto& other = p_node->Input(1).toTensor();
+            dtype = other.scalar_type();
+            layout = other.layout();
+            device = other.device();
+          } else {
+            dtype = p_node->Input(1).toScalarType();
           }
-          if (in4_o.value_or(c10::MemoryFormat::Preserve) ==
+          // handle memory format
+          c10::optional<c10::MemoryFormat> memory_format = c10::nullopt;
+          if (p_node->inputs().size() == 5) {
+            memory_format = p_node->Input(4).toOptional<c10::MemoryFormat>();
+          }
+          if (memory_format.value_or(c10::MemoryFormat::Preserve) ==
               c10::MemoryFormat::Preserve) {
-            if (in0_t.is_non_overlapping_and_dense()) {
-              in4_o = c10::nullopt;
+            if (self.is_non_overlapping_and_dense()) {
+              memory_format = c10::nullopt;
             } else {
-              in4_o = in0_t.suggest_memory_format();
+              memory_format = self.suggest_memory_format();
             }
           }
           // See Note [Explicit nullopt MemoryFormat argument]
           p_node->Output(0) = at::detail::empty_cpu(
-              {0}, in1_i, in0_t.layout(), in0_t.device(), c10::nullopt, in4_o);
+              {0}, dtype, layout, device, c10::nullopt, memory_format);
         }
+
+        // ignore input 3 (copy)
+        auto non_blocking = p_node->Input(2).toBool();
         auto& out_t = p_node->Output(0).toTensor();
         fastResizeToZero(out_t);
-        at::native::to_copy_out(out_t, in0_t, in2_i);
+        at::native::to_copy_out(out_t, self, non_blocking);
       };
     });
 
@@ -810,8 +805,8 @@
     static_runtime::flatten_copy,
     aten_flatten,
     [](Node* n) -> SROperator {
+      TORCH_CHECK(n->inputs().size() == 3);
       return [](ProcessedNode* p_node) {
-        DCHECK(p_node->inputs().size() == 3);
         const auto& self = p_node->Input(0).toTensor();
         const auto start_dim = p_node->Input(1).toInt();
         const auto end_dim = p_node->Input(2).toInt();
@@ -827,25 +822,29 @@
 REGISTER_OPERATOR_FUNCTOR(aten::sum, aten_sum, [](Node* n) -> SROperator {
   return [](ProcessedNode* p_node) {
     const at::Tensor& self = p_node->Input(0).toTensor();
-    std::vector<int64_t> dim = {};
-    if ((p_node->inputs().size() > 1) && (!p_node->Input(1).isNone())) {
-      dim = p_node->Input(1).toIntList().vec();
+
+    c10::optional<at::ScalarType> dtype = c10::nullopt;
+    if (p_node->inputs().size() == 2) {
+      // sum(Tensor self, *, ScalarType? dtype=None) -> Tensor
+      dtype = p_node->Input(1).toOptional<at::ScalarType>();
     }
+
+    std::vector<int64_t> dim = {};
+    bool keepdim = false;
+    if (p_node->inputs().size() == 4) {
+      // sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *,
+      // ScalarType? dtype=None) -> Tensor
+      dim = p_node->Input(1).toIntList().vec();
+      keepdim = p_node->Input(2).toBool();
+      dtype = p_node->Input(3).toOptional<at::ScalarType>();
+    }
+
     if (p_node->Output(0).isNone()) {
       p_node->Output(0) = create_empty_from(self);
     }
     auto& output = p_node->Output(0).toTensor();
     fastResizeToZero(output);
-    if (p_node->inputs().size() > 2) {
-      at::native::sum_out(
-          self,
-          dim,
-          p_node->Input(2).toBool(),
-          p_node->Input(3).toOptional<at::ScalarType>(),
-          output);
-      return;
-    }
-    at::native::sum_out(self, dim, false /* keep_dim */, c10::nullopt, output);
+    at::native::sum_out(self, dim, keepdim, dtype, output);
   };
 });
 
@@ -1011,14 +1010,18 @@
     return [](ProcessedNode* p_node) {
       DCHECK(p_node->inputs().size() == 5);
       const auto& in0_t = p_node->Input(0).toTensor();
-      const auto in1_i = p_node->Input(1).toScalarType();
       const auto in2_i = p_node->Input(2).toBool();
       const auto in3_i = p_node->Input(3).toBool();
-      if (p_node->Input(4).isNone()) {
-        p_node->Output(0) =
-            at::native::to(in0_t, in1_i, in2_i, in3_i, c10::nullopt);
+      const auto in4_o = p_node->Input(4).toOptional<at::MemoryFormat>();
+      if (p_node->Input(1).isTensor()) {
+        // to.other(Tensor self, Tensor other, bool non_blocking=False, bool
+        // copy=False, MemoryFormat? memory_format=None) -> Tensor
+        const auto in1_t = p_node->Input(1).toTensor();
+        p_node->Output(0) = at::native::to(in0_t, in1_t, in2_i, in3_i, in4_o);
       } else {
-        const auto in4_o = p_node->Input(4).toMemoryFormat();
+        // to.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool
+        // copy=False, MemoryFormat? memory_format=None) -> Tensor
+        const auto in1_i = p_node->Input(1).toScalarType();
         p_node->Output(0) = at::native::to(in0_t, in1_i, in2_i, in3_i, in4_o);
       }
     };