Add operator<< for at::Type
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
diff --git a/aten/src/ATen/Formatting.cpp b/aten/src/ATen/Formatting.cpp
index 353158a..b7aaabd 100644
--- a/aten/src/ATen/Formatting.cpp
+++ b/aten/src/ATen/Formatting.cpp
@@ -49,6 +49,10 @@
return out << toString(t);
}
+std::ostream& operator<<(std::ostream & out, const Type& t) {
+ return out << t.toString();
+}
+
static std::tuple<double, int64_t> __printFormat(std::ostream& stream, const Tensor& self) {
auto size = self.numel();
if(size == 0) {
diff --git a/aten/src/ATen/Formatting.h b/aten/src/ATen/Formatting.h
index 37dfb3e..fe496a1 100644
--- a/aten/src/ATen/Formatting.h
+++ b/aten/src/ATen/Formatting.h
@@ -9,6 +9,7 @@
AT_API std::ostream& operator<<(std::ostream & out, IntList list);
AT_API std::ostream& operator<<(std::ostream & out, Backend b);
AT_API std::ostream& operator<<(std::ostream & out, ScalarType t);
+AT_API std::ostream& operator<<(std::ostream & out, const Type & t);
AT_API std::ostream& print(std::ostream& stream, const Tensor & tensor, int64_t linesize);
static inline std::ostream& operator<<(std::ostream & out, const Tensor & t) {
return print(out,t,80);