[net_runner] Get shape info from qtensors (#34321)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34321
Mostly cosmetic as we can infer the shape anyway. It can remove a lot of the noise in the log though.
Note that weight sharing doesn't work yet. I'll add another diff to address this.
Reviewed By: houseroad
Differential Revision: D20290841
fbshipit-source-id: fe6f9b60d05dbe150af15b5d9d7a69fd902e12cc
diff --git a/caffe2/utils/proto_utils.cc b/caffe2/utils/proto_utils.cc
index c1d6458..9bcdc7b 100644
--- a/caffe2/utils/proto_utils.cc
+++ b/caffe2/utils/proto_utils.cc
@@ -10,11 +10,11 @@
#include <google/protobuf/io/coded_stream.h>
#ifndef CAFFE2_USE_LITE_PROTO
-#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
+#include <google/protobuf/text_format.h>
#else
#include <google/protobuf/io/zero_copy_stream_impl_lite.h>
-#endif // !CAFFE2_USE_LITE_PROTO
+#endif // !CAFFE2_USE_LITE_PROTO
#include "caffe2/core/logging.h"
@@ -104,7 +104,9 @@
public:
explicit IfstreamInputStream(const string& filename)
: ifs_(filename.c_str(), std::ios::in | std::ios::binary) {}
- ~IfstreamInputStream() { ifs_.close(); }
+ ~IfstreamInputStream() {
+ ifs_.close();
+ }
int Read(void* buffer, int size) {
if (!ifs_) {
@@ -117,7 +119,7 @@
private:
std::ifstream ifs_;
};
-} // namespace
+} // namespace
C10_EXPORT string ProtoDebugString(const MessageLite& proto) {
string serialized = proto.SerializeAsString();
@@ -158,17 +160,17 @@
LOG(FATAL) << "Not implemented yet.";
}
-#else // CAFFE2_USE_LITE_PROTO
+#else // CAFFE2_USE_LITE_PROTO
// Full protocol buffer.
+using ::google::protobuf::Message;
+using ::google::protobuf::io::CodedInputStream;
+using ::google::protobuf::io::CodedOutputStream;
using ::google::protobuf::io::FileInputStream;
using ::google::protobuf::io::FileOutputStream;
using ::google::protobuf::io::ZeroCopyInputStream;
-using ::google::protobuf::io::CodedInputStream;
using ::google::protobuf::io::ZeroCopyOutputStream;
-using ::google::protobuf::io::CodedOutputStream;
-using ::google::protobuf::Message;
namespace TextFormat {
C10_EXPORT bool ParseFromString(const string& spec, Message* proto) {
@@ -178,8 +180,7 @@
auto num_replaced = c10::ReplaceAll(bc_spec, "cuda_gpu_id", "device_id");
if (num_replaced) {
LOG(ERROR) << "Your model was serialized in Protobuf TextFormat and "
- << "it has "
- << num_replaced
+ << "it has " << num_replaced
<< " places using the deprecated field name 'cuda_gpu_id'!\n"
<< spec
<< "\nPlease re-export your model in Protobuf binary format "
@@ -187,7 +188,8 @@
}
}
- return ::google::protobuf::TextFormat::ParseFromString(std::move(bc_spec), proto);
+ return ::google::protobuf::TextFormat::ParseFromString(
+ std::move(bc_spec), proto);
}
} // namespace TextFormat
@@ -226,7 +228,7 @@
C10_EXPORT bool ReadProtoFromBinaryFile(
const char* filename,
MessageLite* proto) {
-#if defined (_MSC_VER) // for MSC compiler binary flag needs to be specified
+#if defined(_MSC_VER) // for MSC compiler binary flag needs to be specified
int fd = open(filename, O_RDONLY | O_BINARY);
#else
int fd = open(filename, O_RDONLY);
@@ -259,7 +261,7 @@
close(fd);
}
-#endif // CAFFE2_USE_LITE_PROTO
+#endif // CAFFE2_USE_LITE_PROTO
C10_EXPORT ArgumentHelper::ArgumentHelper(const OperatorDef& def) {
for (auto& arg : def.arg()) {
@@ -274,8 +276,7 @@
ProtoDebugString(def));
} else {
LOG(WARNING) << "Duplicated argument name [" << arg.name()
- << "] found in operator def: "
- << ProtoDebugString(def);
+ << "] found in operator def: " << ProtoDebugString(def);
}
}
arg_map_[arg.name()] = arg;
@@ -286,7 +287,9 @@
for (auto& arg : netdef.arg()) {
CAFFE_ENFORCE(
arg_map_.count(arg.name()) == 0,
- "Duplicated argument name [", arg.name(), "] found in net def: ",
+ "Duplicated argument name [",
+ arg.name(),
+ "] found in net def: ",
ProtoDebugString(netdef));
arg_map_[arg.name()] = arg;
}
@@ -303,7 +306,7 @@
bool SupportsLosslessConversion(const InputType& value) {
return static_cast<InputType>(static_cast<TargetType>(value)) == value;
}
-}
+} // namespace
bool operator==(const TensorProto& l, const TensorProto& r) {
return l.SerializeAsString() == r.SerializeAsString();
}
@@ -312,6 +315,14 @@
output << n.SerializeAsString();
return output;
}
+bool operator==(const QTensorProto& l, const QTensorProto& r) {
+ return l.SerializeAsString() == r.SerializeAsString();
+}
+
+std::ostream& operator<<(std::ostream& output, const QTensorProto& n) {
+ output << n.SerializeAsString();
+ return output;
+}
bool operator==(const NetDef& l, const NetDef& r) {
return l.SerializeAsString() == r.SerializeAsString();
}
@@ -412,6 +423,7 @@
INSTANTIATE_GET_REPEATED_ARGUMENT(string, strings, false)
INSTANTIATE_GET_REPEATED_ARGUMENT(NetDef, nets, false)
INSTANTIATE_GET_REPEATED_ARGUMENT(TensorProto, tensors, false)
+INSTANTIATE_GET_REPEATED_ARGUMENT(QTensorProto, qtensors, false)
#undef INSTANTIATE_GET_REPEATED_ARGUMENT
#define CAFFE2_MAKE_SINGULAR_ARGUMENT(T, fieldname) \