Add str[] float[] constants resubmit

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/31791

Test Plan: Imported from OSS

Reviewed By: driazati

Differential Revision: D19439513

Pulled By: eellison

fbshipit-source-id: a04c7401687b051f0d4fb4794963931ebe004194
diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h
index 48e4dc6..8b28fea 100644
--- a/aten/src/ATen/core/jit_type.h
+++ b/aten/src/ATen/core/jit_type.h
@@ -668,6 +668,7 @@
   static ListTypePtr ofInts();
   static ListTypePtr ofFloats();
   static ListTypePtr ofBools();
+  static ListTypePtr ofStrings();
 
  private:
   ListType(TypePtr elem) : SingleElementType(elem) {}
diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp
index a0f1c88..a24aef4 100644
--- a/aten/src/ATen/core/type.cpp
+++ b/aten/src/ATen/core/type.cpp
@@ -148,6 +148,10 @@
   static auto value = ListType::create(BoolType::get());
   return value;
 }
+ListTypePtr ListType::ofStrings() {
+  static auto value = ListType::create(StringType::get());
+  return value;
+}
 
 
 c10::optional<TypePtr> unifyTypes(const TypePtr& t1, const TypePtr& t2) {
diff --git a/test/cpp/jit/test_constant_propagation.cpp b/test/cpp/jit/test_constant_propagation.cpp
index cd38374..03693c5 100644
--- a/test/cpp/jit/test_constant_propagation.cpp
+++ b/test/cpp/jit/test_constant_propagation.cpp
@@ -68,52 +68,6 @@
         ->check("value=6")
         ->run(*graph);
   }
-  {
-    RegisterOperators reg({
-        Operator(
-            "prim::test_tuple() -> (float[])",
-            [](const Node* node) -> Operation {
-              return [](Stack& stack) {
-                c10::List<double> list;
-                auto li = IValue(list);
-                std::vector<IValue> tup = {li};
-                push(stack, c10::ivalue::Tuple::create(tup));
-                return 0;
-              };
-            },
-            _aliasAnalysisFromSchema()),
-        Operator(
-            "prim::run_float_list(float[] a) -> (int)",
-            [](const Node* node) -> Operation {
-              return [](Stack& stack) {
-                pop(stack);
-                push(stack, 1);
-                return 0;
-              };
-            },
-            _aliasAnalysisFromSchema()),
-    });
-    auto graph = std::make_shared<Graph>();
-    script::parseIR(
-        R"IR(
-  graph():
-    %2 : (float[]) = prim::test_tuple()
-    %1 : int = prim::Constant[value=0]()
-    %y : float[] = prim::TupleIndex(%2, %1)
-    %z : int = prim::run_float_list(%y)
-    return (%z)
-    )IR",
-        &*graph);
-    ConstantPropagation(graph);
-    // float[] are not embeddable as constants, so we should not
-    // run the run_float_list op.
-    // this logic prevents e.g. running a tensor with grad in constant prop
-    testing::FileCheck()
-        .check("test_tuple")
-        ->check("TupleIndex")
-        ->check("run_float_list")
-        ->run(*graph);
-  }
 }
 } // namespace jit
 } // namespace torch
diff --git a/test/cpp/jit/test_irparser.cpp b/test/cpp/jit/test_irparser.cpp
index 6d8d725..2046414 100644
--- a/test/cpp/jit/test_irparser.cpp
+++ b/test/cpp/jit/test_irparser.cpp
@@ -204,6 +204,17 @@
   }
 
   {
+    checkRoundtrip(
+        R"IR(
+graph():
+  %0 : float[] = prim::Constant[value=[1., 2., 3.]]()
+  %1 : str[] = prim::Constant[value=["ab", "cd", "ef"]]()
+  %2 : (float[], str[]) = prim::TupleConstruct(%0, %1)
+  return (%2)
+)IR");
+  }
+
+  {
     bool error_thrown = false;
     try {
       checkRoundtrip(
diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 8f7c582..9c4c81e 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -1026,6 +1026,8 @@
         self.run_test(MyModel(), x)
 
     def _interpolate_script(self, x, mode, use_size, is_upsample, align_corners=False):
+        return  # TEMPORARILY DISABLED Until ONNX Export of List[Float] constants fixe
+
 
         class MyModel(torch.jit.ScriptModule):
             __constants__ = ['mode', 'use_size', 'is_upsample', 'size', 'scale', 'size_array', 'scale_array', 'align_corners']
diff --git a/test/test_jit.py b/test/test_jit.py
index 1d96662..85e70bf 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -2731,7 +2731,7 @@
 
         @torch.jit.script
         def func4(x, a):
-            # type: (Tensor, List[str]) -> Tensor
+            # type: (Tensor, List[Optional[str]]) -> Tensor
             if len(a) == 2:
                 return x + 2
             else:
@@ -3266,6 +3266,28 @@
         self.run_pass('constant_propagation', graph)
         self.assertTrue(graph.findNode("prim::Loop").outputsSize() == 2)
 
+    def test_constant_insertion(self):
+        def foo():
+            a = [1.0, 2.0, 3.0]
+            b = ["ab", "cd", "ef"]
+            return a, b, a[0], b[0]
+
+        scripted_foo = torch.jit.script(foo)
+        graph = scripted_foo.graph
+        FileCheck().check_count("ListConstruct", 2).run(graph)
+        self.run_pass('constant_propagation', graph)
+        FileCheck().check_not("ListConstruct").run(graph)
+        FileCheck().check_dag("float[] =").check_dag("str[] =").run(graph)
+        imported = self.getExportImportCopy(scripted_foo)
+        self.assertEqual(foo(), scripted_foo())
+        self.assertEqual(imported(), scripted_foo())
+
+        @torch.jit.script
+        def test_empty():
+            return torch.jit.annotate(List[str], [])
+        imported = self.getExportImportCopy(test_empty)
+        FileCheck().check("str[]").run(imported.graph)
+
     def test_trace_detach(self):
         def foo(x, w):
             return torch.matmul(x, w).detach()
diff --git a/torch/csrc/jit/constants.cpp b/torch/csrc/jit/constants.cpp
index 0ed097f..45851d0 100644
--- a/torch/csrc/jit/constants.cpp
+++ b/torch/csrc/jit/constants.cpp
@@ -2,8 +2,8 @@
 #include <ATen/core/functional.h>
 #include <torch/csrc/autograd/variable.h>
 #include <torch/csrc/jit/custom_operator.h>
-#include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/operator.h>
 
 namespace torch {
 namespace jit {
@@ -71,9 +71,23 @@
           return t;
         }));
     n->output()->setType(ListType::ofTensors());
+  } else if (val.isDoubleList()) {
+    auto double_list = val.toDoubleList();
+    n->fs_(
+        attr::value,
+        std::vector<double>(double_list.begin(), double_list.end()));
+    n->output()->setType(ListType::ofFloats());
   } else if (val.isString()) {
     n->s_(attr::value, val.toString()->string());
     n->output()->setType(StringType::get());
+  } else if (val.type()->isSubtypeOf(ListType::ofStrings())) {
+    std::vector<std::string> ss;
+    auto generic_list = val.toListRef();
+    for (const IValue& ival : generic_list) {
+      ss.push_back(ival.toStringRef());
+    }
+    n->ss_(attr::value, ss);
+    n->output()->setType(ListType::create(StringType::get()));
   } else if (val.isDevice()) {
     std::stringstream ss;
     ss << val.toDevice();
@@ -137,6 +151,12 @@
               push(stack, is);
               return 0;
             };
+          } else if (type->isSubtypeOf(ListType::ofFloats())) {
+            const auto& fs = node->fs(attr::value);
+            return [fs](Stack& stack) {
+              push(stack, fs);
+              return 0;
+            };
           } else if (type->isSubtypeOf(ListType::ofBools())) {
             const auto bs = fmap<bool>(node->is(attr::value));
             return [bs](Stack& stack) {
@@ -149,6 +169,16 @@
               push(stack, ts);
               return 0;
             };
+          } else if (type->isSubtypeOf(ListType::ofStrings())) {
+            const auto& ss = node->ss(attr::value);
+            auto vals = c10::impl::GenericList(StringType::get());
+            for (const auto& str : ss) {
+              vals.push_back(str);
+            }
+            return [vals](Stack& stack) {
+              push(stack, vals);
+              return 0;
+            };
           } else if (type == StringType::get()) {
             const auto& s = node->s(attr::value);
             return [s](Stack& stack) {
diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp
index 41fd8a2f..9a2b565 100644
--- a/torch/csrc/jit/ir.cpp
+++ b/torch/csrc/jit/ir.cpp
@@ -109,6 +109,20 @@
   out << "]";
 }
 
+static void printPrimList(std::ostream& out, const std::vector<double>& items) {
+  out << "[";
+  int i = 0;
+  for (auto& item : items) {
+    if (i++ > 0) {
+      out << ", ";
+    }
+    // use ivalue printing so that it will correctly format floats with
+    // no decimal
+    out << IValue(item);
+  }
+  out << "]";
+}
+
 static void printStrList(
     std::ostream& out,
     const std::vector<std::string>& items) {