Add str[] and float[] constants (resubmit)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/31791
Test Plan: Imported from OSS
Reviewed By: driazati
Differential Revision: D19439513
Pulled By: eellison
fbshipit-source-id: a04c7401687b051f0d4fb4794963931ebe004194
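
For illustration, a rough usage sketch of what this enables (not part of the patch; it mirrors the new test_constant_insertion test below and assumes the torch._C._jit_pass_constant_propagation and torch.testing.FileCheck helpers): after constant propagation, float and string list literals are embedded directly as float[] / str[] prim::Constant nodes instead of being rebuilt with ListConstruct at runtime.

    import torch
    from torch.testing import FileCheck

    @torch.jit.script
    def foo():
        a = [1.0, 2.0, 3.0]     # expected to fold into a float[] prim::Constant
        b = ["ab", "cd", "ef"]  # expected to fold into a str[] prim::Constant
        return a, b, a[0], b[0]

    graph = foo.graph
    torch._C._jit_pass_constant_propagation(graph)
    FileCheck().check_not("ListConstruct").run(graph)
    FileCheck().check_dag("float[] =").check_dag("str[] =").run(graph)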
diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h
index 48e4dc6..8b28fea 100644
--- a/aten/src/ATen/core/jit_type.h
+++ b/aten/src/ATen/core/jit_type.h
@@ -668,6 +668,7 @@
static ListTypePtr ofInts();
static ListTypePtr ofFloats();
static ListTypePtr ofBools();
+ static ListTypePtr ofStrings();
private:
ListType(TypePtr elem) : SingleElementType(elem) {}
diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp
index a0f1c88..a24aef4 100644
--- a/aten/src/ATen/core/type.cpp
+++ b/aten/src/ATen/core/type.cpp
@@ -148,6 +148,10 @@
static auto value = ListType::create(BoolType::get());
return value;
}
+ListTypePtr ListType::ofStrings() {
+ static auto value = ListType::create(StringType::get());
+ return value;
+}
c10::optional<TypePtr> unifyTypes(const TypePtr& t1, const TypePtr& t2) {
diff --git a/test/cpp/jit/test_constant_propagation.cpp b/test/cpp/jit/test_constant_propagation.cpp
index cd38374..03693c5 100644
--- a/test/cpp/jit/test_constant_propagation.cpp
+++ b/test/cpp/jit/test_constant_propagation.cpp
@@ -68,52 +68,6 @@
->check("value=6")
->run(*graph);
}
- {
- RegisterOperators reg({
- Operator(
- "prim::test_tuple() -> (float[])",
- [](const Node* node) -> Operation {
- return [](Stack& stack) {
- c10::List<double> list;
- auto li = IValue(list);
- std::vector<IValue> tup = {li};
- push(stack, c10::ivalue::Tuple::create(tup));
- return 0;
- };
- },
- _aliasAnalysisFromSchema()),
- Operator(
- "prim::run_float_list(float[] a) -> (int)",
- [](const Node* node) -> Operation {
- return [](Stack& stack) {
- pop(stack);
- push(stack, 1);
- return 0;
- };
- },
- _aliasAnalysisFromSchema()),
- });
- auto graph = std::make_shared<Graph>();
- script::parseIR(
- R"IR(
- graph():
- %2 : (float[]) = prim::test_tuple()
- %1 : int = prim::Constant[value=0]()
- %y : float[] = prim::TupleIndex(%2, %1)
- %z : int = prim::run_float_list(%y)
- return (%z)
- )IR",
- &*graph);
- ConstantPropagation(graph);
- // float[] are not embeddable as constants, so we should not
- // run the run_float_list op.
- // this logic prevents e.g. running a tensor with grad in constant prop
- testing::FileCheck()
- .check("test_tuple")
- ->check("TupleIndex")
- ->check("run_float_list")
- ->run(*graph);
- }
}
} // namespace jit
} // namespace torch
diff --git a/test/cpp/jit/test_irparser.cpp b/test/cpp/jit/test_irparser.cpp
index 6d8d725..2046414 100644
--- a/test/cpp/jit/test_irparser.cpp
+++ b/test/cpp/jit/test_irparser.cpp
@@ -204,6 +204,17 @@
}
{
+ checkRoundtrip(
+ R"IR(
+graph():
+ %0 : float[] = prim::Constant[value=[1., 2., 3.]]()
+ %1 : str[] = prim::Constant[value=["ab", "cd", "ef"]]()
+ %2 : (float[], str[]) = prim::TupleConstruct(%0, %1)
+ return (%2)
+)IR");
+ }
+
+ {
bool error_thrown = false;
try {
checkRoundtrip(
diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 8f7c582..9c4c81e 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -1026,6 +1026,8 @@
self.run_test(MyModel(), x)
def _interpolate_script(self, x, mode, use_size, is_upsample, align_corners=False):
+        return  # TEMPORARILY DISABLED until ONNX export of List[float] constants is fixed
+
class MyModel(torch.jit.ScriptModule):
__constants__ = ['mode', 'use_size', 'is_upsample', 'size', 'scale', 'size_array', 'scale_array', 'align_corners']
diff --git a/test/test_jit.py b/test/test_jit.py
index 1d96662..85e70bf 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -2731,7 +2731,7 @@
@torch.jit.script
def func4(x, a):
- # type: (Tensor, List[str]) -> Tensor
+ # type: (Tensor, List[Optional[str]]) -> Tensor
if len(a) == 2:
return x + 2
else:
@@ -3266,6 +3266,28 @@
self.run_pass('constant_propagation', graph)
self.assertTrue(graph.findNode("prim::Loop").outputsSize() == 2)
+ def test_constant_insertion(self):
+ def foo():
+ a = [1.0, 2.0, 3.0]
+ b = ["ab", "cd", "ef"]
+ return a, b, a[0], b[0]
+
+ scripted_foo = torch.jit.script(foo)
+ graph = scripted_foo.graph
+ FileCheck().check_count("ListConstruct", 2).run(graph)
+ self.run_pass('constant_propagation', graph)
+ FileCheck().check_not("ListConstruct").run(graph)
+ FileCheck().check_dag("float[] =").check_dag("str[] =").run(graph)
+ imported = self.getExportImportCopy(scripted_foo)
+ self.assertEqual(foo(), scripted_foo())
+ self.assertEqual(imported(), scripted_foo())
+
+ @torch.jit.script
+ def test_empty():
+ return torch.jit.annotate(List[str], [])
+ imported = self.getExportImportCopy(test_empty)
+ FileCheck().check("str[]").run(imported.graph)
+
def test_trace_detach(self):
def foo(x, w):
return torch.matmul(x, w).detach()
diff --git a/torch/csrc/jit/constants.cpp b/torch/csrc/jit/constants.cpp
index 0ed097f..45851d0 100644
--- a/torch/csrc/jit/constants.cpp
+++ b/torch/csrc/jit/constants.cpp
@@ -2,8 +2,8 @@
#include <ATen/core/functional.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/custom_operator.h>
-#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/operator.h>
namespace torch {
namespace jit {
@@ -71,9 +71,23 @@
return t;
}));
n->output()->setType(ListType::ofTensors());
+ } else if (val.isDoubleList()) {
+ auto double_list = val.toDoubleList();
+ n->fs_(
+ attr::value,
+ std::vector<double>(double_list.begin(), double_list.end()));
+ n->output()->setType(ListType::ofFloats());
} else if (val.isString()) {
n->s_(attr::value, val.toString()->string());
n->output()->setType(StringType::get());
+ } else if (val.type()->isSubtypeOf(ListType::ofStrings())) {
+ std::vector<std::string> ss;
+ auto generic_list = val.toListRef();
+ for (const IValue& ival : generic_list) {
+ ss.push_back(ival.toStringRef());
+ }
+ n->ss_(attr::value, ss);
+ n->output()->setType(ListType::create(StringType::get()));
} else if (val.isDevice()) {
std::stringstream ss;
ss << val.toDevice();
@@ -137,6 +151,12 @@
push(stack, is);
return 0;
};
+ } else if (type->isSubtypeOf(ListType::ofFloats())) {
+ const auto& fs = node->fs(attr::value);
+ return [fs](Stack& stack) {
+ push(stack, fs);
+ return 0;
+ };
} else if (type->isSubtypeOf(ListType::ofBools())) {
const auto bs = fmap<bool>(node->is(attr::value));
return [bs](Stack& stack) {
@@ -149,6 +169,16 @@
push(stack, ts);
return 0;
};
+ } else if (type->isSubtypeOf(ListType::ofStrings())) {
+ const auto& ss = node->ss(attr::value);
+ auto vals = c10::impl::GenericList(StringType::get());
+ for (const auto& str : ss) {
+ vals.push_back(str);
+ }
+ return [vals](Stack& stack) {
+ push(stack, vals);
+ return 0;
+ };
} else if (type == StringType::get()) {
const auto& s = node->s(attr::value);
return [s](Stack& stack) {
diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp
index 41fd8a2f..9a2b565 100644
--- a/torch/csrc/jit/ir.cpp
+++ b/torch/csrc/jit/ir.cpp
@@ -109,6 +109,20 @@
out << "]";
}
+static void printPrimList(std::ostream& out, const std::vector<double>& items) {
+ out << "[";
+ int i = 0;
+ for (auto& item : items) {
+ if (i++ > 0) {
+ out << ", ";
+ }
+    // use IValue printing so that floats with no fractional part still
+    // print as floats (e.g. "1." rather than "1")
+ out << IValue(item);
+ }
+ out << "]";
+}
+
static void printStrList(
std::ostream& out,
const std::vector<std::string>& items) {