Define `repr()` on IValues (#32232)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32232
Previously, we were using `operator<<` as the default way of printing
IValue constants during serialization. The semantics of `operator<<`
were ill-defined; and this bit us in particular with strings and lack of
quoting.
This PR defines the role of `operator<<`: much like Python `str()`, it
is intended to produce a human-readable-ish representation for
debugging purposes.
This PR also defines a new `repr()` function on IValue that is intended
to produce a valid Python expression that can be used to recreate an
object with the same value. `repr()` is not defined on all IValue kinds
(notably tensors!) for this reason.
Test Plan: Imported from OSS
Differential Revision: D19417036
Pulled By: suo
fbshipit-source-id: c102d509eaf95a28b6a62280bc99ca6f09603de5
diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp
index 70b9df0..77651af 100644
--- a/aten/src/ATen/core/ivalue.cpp
+++ b/aten/src/ATen/core/ivalue.cpp
@@ -23,7 +23,6 @@
} // namespace ivalue
-
TypePtr IValue::type() const {
switch(tag) {
case Tag::None:
@@ -64,36 +63,44 @@
}
namespace {
-template<class T>
-std::ostream& printList(std::ostream & out, const c10::List<T> &v,
- const std::string start, const std::string finish) {
+using IValueFormatter = std::function<void(std::ostream&, const IValue&)>;
+
+template <class T>
+std::ostream& printList(
+ std::ostream& out,
+ const T& list,
+ const std::string start,
+ const std::string finish,
+ IValueFormatter formatter) {
out << start;
- for(size_t i = 0; i < v.size(); ++i) {
- if(i > 0)
+ for (size_t i = 0; i < list.size(); ++i) {
+ if (i > 0){
out << ", ";
- // make sure we use ivalue printing, and not default printing for the element type
- out << IValue(v.get(i));
+ }
+ formatter(out, IValue(list[i]));
}
out << finish;
return out;
}
-template<class T>
-std::ostream& printList(std::ostream & out, const std::vector<T> &v,
- const std::string start, const std::string finish) {
- out << start;
- for(size_t i = 0; i < v.size(); ++i) {
- if(i > 0)
- out << ", ";
- // make sure we use ivalue printing, and not default printing for the element type
- out << IValue(v[i]);
+// Properly disambiguate the type of an empty list
+std::ostream& printMaybeAnnotatedList(
+ std::ostream& out,
+ const IValue& the_list,
+ IValueFormatter formatter) {
+ if (the_list.toGenericListRef().size() == 0) {
+ out << "annotate(" << the_list.type()->python_str() << ", [])";
+ } else {
+ return printList(out, the_list.toGenericListRef(), "[", "]", formatter);
}
- out << finish;
return out;
}
-template<typename Dict>
-std::ostream& printDict(std::ostream& out, const Dict& v) {
+template <typename Dict>
+std::ostream& printDict(
+ std::ostream& out,
+ const Dict& v,
+ IValueFormatter formatter) {
out << "{";
bool first = true;
@@ -101,17 +108,83 @@
if (!first) {
out << ", ";
}
- out << pair.key() << ": " << pair.value();
+
+ formatter(out, pair.key());
+ out << ": ";
+ formatter(out, pair.value());
first = false;
}
out << "}";
return out;
}
+}
-} // anonymous namespace
+std::ostream& IValue::repr(
+ std::ostream& out,
+ std::function<bool(std::ostream&, const IValue& v)>
+ customFormatter) const {
+ // First check if the caller has provided a custom formatter. Use that if possible.
+ if (customFormatter(out, *this)) {
+ return out;
+ }
+
+ const IValue& v = *this;
+ auto formatter = [&](std::ostream& out, const IValue& v) {
+ v.repr(out, customFormatter);
+ };
+ switch (v.tag) {
+ case IValue::Tag::None:
+ return out << v.toNone();
+ case IValue::Tag::Double: {
+ double d = v.toDouble();
+ int c = std::fpclassify(d);
+ if (c == FP_NORMAL || c == FP_ZERO) {
+ int64_t i = int64_t(d);
+ if (double(i) == d) {
+ return out << i << ".";
+ }
+ }
+ auto orig_prec = out.precision();
+ return out << std::setprecision(std::numeric_limits<double>::max_digits10)
+ << v.toDouble() << std::setprecision(orig_prec);
+ }
+ case IValue::Tag::Int:
+ return out << v.toInt();
+ case IValue::Tag::Bool:
+ return out << (v.toBool() ? "True" : "False");
+ case IValue::Tag::Tuple: {
+ const auto& elements = v.toTuple()->elements();
+ const auto& finish = elements.size() == 1 ? ",)" : ")";
+ return printList(out, elements, "(", finish, formatter);
+ }
+ case IValue::Tag::String:
+ c10::printQuotedString(out, v.toStringRef());
+ return out;
+ case IValue::Tag::GenericList: {
+ auto formatter = [&](std::ostream& out, const IValue& v) {
+ v.repr(out, customFormatter);
+ };
+ return printMaybeAnnotatedList(out, *this, formatter);
+ }
+ case IValue::Tag::Device: {
+ std::stringstream device_stream;
+ device_stream << v.toDevice();
+ out << "torch.device(";
+ c10::printQuotedString(out, device_stream.str());
+ return out << ")";
+ }
+ case IValue::Tag::GenericDict:
+ return printDict(out, v.toGenericDict(), formatter);
+ default:
+ TORCH_INTERNAL_ASSERT(false, "repr() not defined on: ", v.tagKind());
+ }
+}
std::ostream& operator<<(std::ostream & out, const IValue & v) {
+ auto formatter = [&](std::ostream& out, const IValue& v) {
+ out << v;
+ };
switch(v.tag) {
case IValue::Tag::None:
return out << v.toNone();
@@ -138,7 +211,7 @@
case IValue::Tag::Tuple: {
const auto& elements = v.toTuple()->elements();
const auto& finish = elements.size() == 1 ? ",)" : ")";
- return printList(out, elements, "(", finish);
+ return printList(out, elements, "(", finish, formatter);
}
case IValue::Tag::String:
return out << v.toStringRef();
@@ -147,7 +220,7 @@
case IValue::Tag::Capsule:
return out << "Capsule";
case IValue::Tag::GenericList:
- return printList(out, v.toGenericList(), "[", "]");
+ return printList(out, v.toGenericList(), "[", "]", formatter);
case IValue::Tag::Future:
return out << "Future";
case IValue::Tag::Uninitialized:
@@ -155,7 +228,7 @@
case IValue::Tag::Device:
return out << v.toDevice();
case IValue::Tag::GenericDict:
- return printDict(out, v.toGenericDict());
+ return printDict(out, v.toGenericDict(), formatter);
case IValue::Tag::Object:
// TODO we should attempt to call __str__ if the object defines it.
auto obj = v.toObject();
diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h
index 956e1f5..05c3ec2 100644
--- a/aten/src/ATen/core/ivalue.h
+++ b/aten/src/ATen/core/ivalue.h
@@ -456,6 +456,26 @@
/// this is a shallow comparison of two IValues to test the object identity
bool isSameIdentity(const IValue& rhs) const;
+ // Computes the "official" string representation of an IValue. This produces a
+ // TorchScript expression that can be used to recreate an IValue with the same
+ // value (e.g. when we are printing constants in the serializer).
+ //
+ // Callers can use `customFormatter` to override how `repr()` prints out an
+ // IValue. This is useful if you have some other environment where you can
+ // look up values, and you want to print a reference to that environment (like
+ // the serializer's constant table).
+ //
+ // repr() is not necessarily defined on all objects!
+ std::ostream& repr(
+ std::ostream& stream,
+ std::function<bool(std::ostream&, const IValue& v)> customFormatter)
+ const;
+
+ // Computes an "informal" string representation of an IValue. This should be
+ // used for debugging, or servicing `print()`-like functions.
+ // This is different from `repr()` in that there is no expectation that we can
+ // exactly reconstruct an IValue from the output; feel free to use a
+ // concise/pretty form
CAFFE2_API friend std::ostream& operator<<(
std::ostream& out,
const IValue& v);
diff --git a/test/test_jit.py b/test/test_jit.py
index b81e25f..a2d13e9 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -9154,13 +9154,14 @@
def test_script_module_const(self):
class M(torch.jit.ScriptModule):
- __constants__ = ['b', 'i', 'c']
+ __constants__ = ['b', 'i', 'c', 's']
def __init__(self):
super(M, self).__init__()
self.b = False
self.i = 1
self.c = 3.5
+ self.s = ["hello"]
@torch.jit.script_method
def forward(self):
diff --git a/torch/csrc/jit/passes/python_print.cpp b/torch/csrc/jit/passes/python_print.cpp
index 0e84b0b..4d64ea9 100644
--- a/torch/csrc/jit/passes/python_print.cpp
+++ b/torch/csrc/jit/passes/python_print.cpp
@@ -777,48 +777,17 @@
}
}
- void printMaybeAnnotatedConstantList(
- std::ostream& stmt,
- const char* the_type,
- size_t list_size,
- const IValue& the_list) {
- if (list_size == 0) {
- stmt << "annotate(List[" << the_type << "], [])";
- } else {
- stmt << the_list;
- }
- }
-
void printConstant(TaggedStringStream& stmt, const IValue& v) {
- std::stringstream ss;
- if (v.isTensor()) {
- ss << "CONSTANTS.c" << getOrAddTensorConstant(v.toTensor());
- } else if (v.isString()) {
- c10::printQuotedString(ss, v.toStringRef());
- } else if (v.isDevice()) {
- std::stringstream device_stream;
- device_stream << v.toDevice();
- ss << "torch.device(";
- c10::printQuotedString(ss, device_stream.str());
- ss << ")";
- } else if (v.isTensorList()) {
- ss << "[";
- const char* delim = "";
- for (const at::Tensor& t : v.toTensorListRef()) {
- ss << delim << "CONSTANTS.c" << getOrAddTensorConstant(t);
- delim = ", ";
+ const auto customFormatter = [&](std::ostream& ss, const IValue& v) {
+ if (v.isTensor()) {
+ ss << "CONSTANTS.c" << getOrAddTensorConstant(v.toTensor());
+ return true;
}
- ss << "]";
- } else if (v.isBoolList()) {
- printMaybeAnnotatedConstantList(ss, "bool", v.toBoolList().size(), v);
- } else if (v.isIntList()) {
- printMaybeAnnotatedConstantList(ss, "int", v.toIntListRef().size(), v);
- } else if (v.isDoubleList()) {
- printMaybeAnnotatedConstantList(
- ss, "float", v.toDoubleListRef().size(), v);
- } else {
- ss << v;
- }
+ return false;
+ };
+
+ std::stringstream ss;
+ v.repr(ss, customFormatter);
stmt << ss.str();
}