Define `repr()` on IValues (#32232)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32232

Previously, we were using `operator<<` as the default way of printing
IValue constants during serialization. The semantics of `operator<<`
were ill-defined, and this bit us in particular with strings, which
were printed without quoting.

This PR defines the role of `operator<<`: much like Python `str()`, it
is intended to produce a human-readable-ish representation for
debugging purposes.

This PR also defines a new `repr()` function on IValue that is intended
to produce a valid Python expression that can be used to recreate an
object with the same value. Because not every value can be written as
such an expression, `repr()` is not defined on all IValue kinds
(notably tensors!).

Test Plan: Imported from OSS

Differential Revision: D19417036

Pulled By: suo

fbshipit-source-id: c102d509eaf95a28b6a62280bc99ca6f09603de5
diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp
index 70b9df0..77651af 100644
--- a/aten/src/ATen/core/ivalue.cpp
+++ b/aten/src/ATen/core/ivalue.cpp
@@ -23,7 +23,6 @@
 
 } // namespace ivalue
 
-
 TypePtr IValue::type() const {
   switch(tag) {
     case Tag::None:
@@ -64,36 +63,44 @@
 }
 namespace {
 
-template<class T>
-std::ostream& printList(std::ostream & out, const c10::List<T> &v,
-  const std::string start, const std::string finish) {
+using IValueFormatter = std::function<void(std::ostream&, const IValue&)>;
+
+template <class T>
+std::ostream& printList(
+    std::ostream& out,
+    const T& list,
+    const std::string start,
+    const std::string finish,
+    IValueFormatter formatter) {
   out << start;
-  for(size_t i = 0; i < v.size(); ++i) {
-    if(i > 0)
+  for (size_t i = 0; i < list.size(); ++i) {
+    if (i > 0){
       out << ", ";
-    // make sure we use ivalue printing, and not default printing for the element type
-    out << IValue(v.get(i));
+    }
+    formatter(out, IValue(list[i]));
   }
   out << finish;
   return out;
 }
 
-template<class T>
-std::ostream& printList(std::ostream & out, const std::vector<T> &v,
-  const std::string start, const std::string finish) {
-  out << start;
-  for(size_t i = 0; i < v.size(); ++i) {
-    if(i > 0)
-      out << ", ";
-    // make sure we use ivalue printing, and not default printing for the element type
-    out << IValue(v[i]);
+// Properly disambiguate the type of an empty list
+std::ostream& printMaybeAnnotatedList(
+    std::ostream& out,
+    const IValue& the_list,
+    IValueFormatter formatter) {
+  if (the_list.toGenericListRef().size() == 0) {
+    out << "annotate(" << the_list.type()->python_str() << ", [])";
+  } else {
+    return printList(out, the_list.toGenericListRef(), "[", "]", formatter);
   }
-  out << finish;
   return out;
 }
 
-template<typename Dict>
-std::ostream& printDict(std::ostream& out, const Dict& v) {
+template <typename Dict>
+std::ostream& printDict(
+    std::ostream& out,
+    const Dict& v,
+    IValueFormatter formatter) {
   out << "{";
 
   bool first = true;
@@ -101,17 +108,83 @@
     if (!first) {
       out << ", ";
     }
-    out << pair.key() << ": " << pair.value();
+
+    formatter(out, pair.key());
+    out << ": ";
+    formatter(out, pair.value());
     first = false;
   }
 
   out << "}";
   return out;
 }
+}
 
-} // anonymous namespace
+std::ostream& IValue::repr(
+    std::ostream& out,
+    std::function<bool(std::ostream&, const IValue& v)>
+        customFormatter) const {
+  // First check if the caller has provided a custom formatter. Use that if possible.
+  if (customFormatter(out, *this)) {
+    return out;
+  }
+
+  const IValue& v = *this;
+  auto formatter = [&](std::ostream& out, const IValue& v) {
+    v.repr(out, customFormatter);
+  };
+  switch (v.tag) {
+    case IValue::Tag::None:
+      return out << v.toNone();
+    case IValue::Tag::Double: {
+      double d = v.toDouble();
+      int c = std::fpclassify(d);
+      if (c == FP_NORMAL || c == FP_ZERO) {
+        int64_t i = int64_t(d);
+        if (double(i) == d) {
+          return out << i << ".";
+        }
+      }
+      auto orig_prec = out.precision();
+      return out << std::setprecision(std::numeric_limits<double>::max_digits10)
+                 << v.toDouble() << std::setprecision(orig_prec);
+    }
+    case IValue::Tag::Int:
+      return out << v.toInt();
+    case IValue::Tag::Bool:
+      return out << (v.toBool() ? "True" : "False");
+    case IValue::Tag::Tuple: {
+      const auto& elements = v.toTuple()->elements();
+      const auto& finish = elements.size() == 1 ? ",)" : ")";
+      return printList(out, elements, "(", finish, formatter);
+    }
+    case IValue::Tag::String:
+      c10::printQuotedString(out, v.toStringRef());
+      return out;
+    case IValue::Tag::GenericList: {
+      auto formatter = [&](std::ostream& out, const IValue& v) {
+        v.repr(out, customFormatter);
+      };
+      return printMaybeAnnotatedList(out, *this, formatter);
+    }
+    case IValue::Tag::Device: {
+      std::stringstream device_stream;
+      device_stream << v.toDevice();
+      out << "torch.device(";
+      c10::printQuotedString(out, device_stream.str());
+      return out << ")";
+    }
+    case IValue::Tag::GenericDict:
+      return printDict(out, v.toGenericDict(), formatter);
+    default:
+      TORCH_INTERNAL_ASSERT(false, "repr() not defined on: ", v.tagKind());
+  }
+}
 
 std::ostream& operator<<(std::ostream & out, const IValue & v) {
+  auto formatter = [&](std::ostream& out, const IValue& v) {
+    out << v;
+  };
   switch(v.tag) {
     case IValue::Tag::None:
       return out << v.toNone();
@@ -138,7 +211,7 @@
     case IValue::Tag::Tuple: {
       const auto& elements = v.toTuple()->elements();
       const auto& finish = elements.size() == 1 ? ",)" : ")";
-      return printList(out, elements, "(", finish);
+      return printList(out, elements, "(", finish, formatter);
     }
     case IValue::Tag::String:
       return out << v.toStringRef();
@@ -147,7 +220,7 @@
     case IValue::Tag::Capsule:
       return out << "Capsule";
     case IValue::Tag::GenericList:
-      return printList(out, v.toGenericList(), "[", "]");
+      return printList(out, v.toGenericList(), "[", "]", formatter);
     case IValue::Tag::Future:
       return out << "Future";
     case IValue::Tag::Uninitialized:
@@ -155,7 +228,7 @@
     case IValue::Tag::Device:
       return out << v.toDevice();
     case IValue::Tag::GenericDict:
-      return printDict(out, v.toGenericDict());
+      return printDict(out, v.toGenericDict(), formatter);
     case IValue::Tag::Object:
       // TODO we should attempt to call __str__ if the object defines it.
       auto obj = v.toObject();
diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h
index 956e1f5..05c3ec2 100644
--- a/aten/src/ATen/core/ivalue.h
+++ b/aten/src/ATen/core/ivalue.h
@@ -456,6 +456,26 @@
   /// this is a shallow comparison of two IValues to test the object identity
   bool isSameIdentity(const IValue& rhs) const;
 
+  // Computes the "official" string representation of an IValue. This produces a
+  // TorchScript expression that can be used to recreate an IValue with the same
+  // value (e.g. when we are printing constants in the serializer).
+  //
+  // Callers can use `customFormatter` to override how `repr()` prints out an
+  // IValue. This is useful if you have some other environment where you can
+  // look up values, and you want to print a reference to that environment (like
+  // the serializer's constant table).
+  //
+  // repr() is not necessarily defined on all objects!
+  std::ostream& repr(
+      std::ostream& stream,
+      std::function<bool(std::ostream&, const IValue& v)> customFormatter)
+      const;
+
+  // Computes an "informal" string representation of an IValue. This should be
+  // used for debugging, or servicing `print()`-like functions.
+  // This is different from `repr()` in that there is no expectation that we can
+  // exactly reconstruct an IValue from the output; feel free to use a
+  // concise/pretty form
   CAFFE2_API friend std::ostream& operator<<(
       std::ostream& out,
       const IValue& v);
diff --git a/test/test_jit.py b/test/test_jit.py
index b81e25f..a2d13e9 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -9154,13 +9154,14 @@
     def test_script_module_const(self):
         class M(torch.jit.ScriptModule):
 
-            __constants__ = ['b', 'i', 'c']
+            __constants__ = ['b', 'i', 'c', 's']
 
             def __init__(self):
                 super(M, self).__init__()
                 self.b = False
                 self.i = 1
                 self.c = 3.5
+                self.s = ["hello"]
 
             @torch.jit.script_method
             def forward(self):
diff --git a/torch/csrc/jit/passes/python_print.cpp b/torch/csrc/jit/passes/python_print.cpp
index 0e84b0b..4d64ea9 100644
--- a/torch/csrc/jit/passes/python_print.cpp
+++ b/torch/csrc/jit/passes/python_print.cpp
@@ -777,48 +777,17 @@
     }
   }
 
-  void printMaybeAnnotatedConstantList(
-      std::ostream& stmt,
-      const char* the_type,
-      size_t list_size,
-      const IValue& the_list) {
-    if (list_size == 0) {
-      stmt << "annotate(List[" << the_type << "], [])";
-    } else {
-      stmt << the_list;
-    }
-  }
-
   void printConstant(TaggedStringStream& stmt, const IValue& v) {
-    std::stringstream ss;
-    if (v.isTensor()) {
-      ss << "CONSTANTS.c" << getOrAddTensorConstant(v.toTensor());
-    } else if (v.isString()) {
-      c10::printQuotedString(ss, v.toStringRef());
-    } else if (v.isDevice()) {
-      std::stringstream device_stream;
-      device_stream << v.toDevice();
-      ss << "torch.device(";
-      c10::printQuotedString(ss, device_stream.str());
-      ss << ")";
-    } else if (v.isTensorList()) {
-      ss << "[";
-      const char* delim = "";
-      for (const at::Tensor& t : v.toTensorListRef()) {
-        ss << delim << "CONSTANTS.c" << getOrAddTensorConstant(t);
-        delim = ", ";
+    const auto customFormatter = [&](std::ostream& ss, const IValue& v) {
+      if (v.isTensor()) {
+        ss << "CONSTANTS.c" << getOrAddTensorConstant(v.toTensor());
+        return true;
       }
-      ss << "]";
-    } else if (v.isBoolList()) {
-      printMaybeAnnotatedConstantList(ss, "bool", v.toBoolList().size(), v);
-    } else if (v.isIntList()) {
-      printMaybeAnnotatedConstantList(ss, "int", v.toIntListRef().size(), v);
-    } else if (v.isDoubleList()) {
-      printMaybeAnnotatedConstantList(
-          ss, "float", v.toDoubleListRef().size(), v);
-    } else {
-      ss << v;
-    }
+      return false;
+    };
+
+    std::stringstream ss;
+    v.repr(ss, customFormatter);
     stmt << ss.str();
   }