blob: 7b18420bd9941243198224f2417e28b1fbfd3564 [file] [log] [blame]
#include "torch/csrc/onnx/onnx.h"

#include <memory>
#include <utility>
namespace torch { namespace onnx {
// Encodes `*arg` with nanopb's pb_encode_string: the string's size() bytes
// (the terminating NUL is not included) written as a length-delimited value.
template <>
bool micropb_encode<std::string, nullptr>(pb_ostream_t *stream, std::string* arg) {
  const std::string& str = *arg;
  const auto* bytes = reinterpret_cast<const pb_byte_t *>(str.c_str());
  return pb_encode_string(stream, bytes, str.size());
}
// NB: Overloads don't work so great for signed variables. Hope this doesn't
// come up!
//
// Encodes `*arg` as an unsigned varint of its two's-complement bit pattern
// (plain varint, not zigzag).
// NOTE(review): that matches protobuf `int64` fields; a `sint64` field would
// need zigzag encoding instead — confirm no such field is routed here.
template <>
bool micropb_encode<int64_t, nullptr>(pb_ostream_t *stream, int64_t* arg) {
  // The nanopb docs say to pass signed int64 values as their unsigned
  // reinterpretation:
  // https://jpa.kapsi.fi/nanopb/docs/reference.html#pb-encode-varint
  // A value conversion to uint64_t is modular, so it yields exactly the same
  // bits as reading the object through a uint64_t pointer.
  const uint64_t bits = static_cast<uint64_t>(*arg);
  return pb_encode_varint(stream, bits);
}
// Encodes `*arg` as a protobuf fixed32 value via nanopb's pb_encode_fixed32.
template <>
bool micropb_encode<float, nullptr>(pb_ostream_t *stream, float* arg) {
  void* raw = arg;
  return pb_encode_fixed32(stream, raw);
}
// Encodes `*arg` as a protobuf fixed64 value via nanopb's pb_encode_fixed64.
template <>
bool micropb_encode<double, nullptr>(pb_ostream_t *stream, double* arg) {
  void* raw = arg;
  return pb_encode_fixed64(stream, raw);
}
// Encodes `*arg` as a nested submessage described by the nanopb field table
// onnx_TensorShapeProto_Dimension_fields.
// NOTE(review): assumes Dimension has the memory layout nanopb expects for
// that field table (e.g. it is the generated struct) — confirm in the header.
template <>
bool micropb_encode<Dimension, nullptr>(pb_ostream_t *stream, Dimension* arg) {
  void* msg = arg;
  return pb_encode_submessage(stream, onnx_TensorShapeProto_Dimension_fields, msg);
}
// nanopb encode callback: writes the raw element bytes of the at::Tensor that
// `*arg` points to as a single length-delimited field (tag, then the
// elementSizeInBytes() * numel() bytes starting at data_ptr()).
//
// The tensor must be contiguous so that those bytes cover every element.
// Returns false as soon as nanopb fails to write the tag or the payload, so
// the failure reaches the caller of pb_encode instead of being reported as
// success.
//
// TODO: I'm not entirely sure why this can't be in the header...
// NOTE(review): plausibly because a non-inline function defined in a header
// is multiply defined once that header is included by several translation
// units; marking it `inline` would allow moving it — confirm.
bool micropb_callback_string_from_tensor(pb_ostream_t *stream, const pb_field_t *field, void * const *arg) {
  at::Tensor* t = static_cast<at::Tensor*>(*arg);
  JIT_ASSERT(t->is_contiguous());
  const auto* bytes = static_cast<const pb_byte_t*>(t->data_ptr());
  const size_t nbytes = t->type().elementSizeInBytes() * t->numel();
  // Packed array format!
  if (!pb_encode_tag_for_field(stream, field)) {
    return false;
  }
  return pb_encode_string(stream, bytes, nbytes);
}
// Appends a new default-constructed GraphProto to `graphs` and returns a
// non-owning pointer to it.
// NOTE(review): assumes `graphs` holds owning smart pointers (the pointer is
// moved in as a std::unique_ptr) — confirm against the header.
GraphProto* AttributeProto::add_graphs() {
  // Take ownership before touching the vector: emplace_back(raw_ptr) leaks
  // the allocation if growing the vector throws before the element exists.
  std::unique_ptr<GraphProto> owned(new GraphProto());
  GraphProto* ptr = owned.get();
  graphs.emplace_back(std::move(owned));
  return ptr;
}
// Indentation used by the dump() pretty-printers below: one nesting level is
// `indent_multiplier` copies of `indent_char`.
constexpr char indent_char = ' ';
constexpr size_t indent_multiplier = 2;
// Returns the whitespace prefix for nesting level `indent`.
std::string idt(size_t indent) {
  const size_t width = indent_multiplier * indent;
  return std::string(width, indent_char);
}
// Returns a newline followed by the whitespace prefix for nesting level
// `indent`.
std::string nlidt(size_t indent) {
  std::string line_break(1, '\n');
  line_break.append(idt(indent));
  return line_break;
}
void TensorProto::dump(std::ostream& stream, size_t indent) {
stream << "TensorProto shape: [";
for (size_t i = 0; i < dims.size(); ++i) {
stream << *dims[i] << (i == dims.size() - 1 ? "" : " ");
}
stream << "]";
}
void TensorShapeProto::dump(std::ostream& stream, size_t indent) {
for (size_t i=0; i < dims.size(); ++i) {
auto &dim = dims[i];
if (dim->has_dim_value) {
stream << dim->dim_value;
} else {
stream << "?";
}
stream << (i == dims.size() - 1 ? "" : " ");
}
}
// Prints `Tensor dims: ` followed by the dump of `shape`.
// NOTE(review): `indent` is not forwarded; `shape` is dumped using its dump()
// default indent.
void TypeProtoTensor::dump(std::ostream& stream, size_t indent) {
  std::ostream& out = stream << "Tensor dims: ";
  shape->dump(out);
}
// Prints this type by delegating to the dump of `tensor_type`.
// NOTE(review): `tensor_type` is dereferenced unconditionally, so it is
// assumed to be set; `indent` is not forwarded.
void TypeProto::dump(std::ostream& stream, size_t indent) {
  const auto& tensor = tensor_type;
  tensor->dump(stream);
}
// Prints `{name: "<name>", type:` followed by the dump of `type` and a closing
// `}`. `indent` is not forwarded.
void ValueInfoProto::dump(std::ostream& stream, size_t indent) {
  stream << "{name: \"";
  stream << name << "\", type:";
  type->dump(stream);
  stream << "}";
}
// Prints a summary of the attribute: `{ name: '<name>', type: <kind>, ...}`.
// Value kinds are probed in a fixed order (float, int, string, graph, tensor,
// floats, ints, strings, tensors, graphs). Only the first populated kind is
// printed; if none is populated, the type is reported as UNKNOWN. Nested graph
// and tensor dumps are written at `indent + 1`.
// NOTE(review): an empty `s` counts as "not set", so a string attribute whose
// value is "" falls through to the later kinds (or UNKNOWN).
void AttributeProto::dump(std::ostream& stream, size_t indent) {
  stream << "{ name: '" << name << "', type: ";
  const size_t nested = indent + 1;
  const char* sep = "";
  if (proto.has_f) {
    stream << "float, value: " << proto.f;
  } else if (proto.has_i) {
    stream << "int, value: " << proto.i;
  } else if (!s.empty()) {
    stream << "string, value: '" << s << "'";
  } else if (g) {
    stream << "graph, value:\n";
    g->dump(stream, nested);
    stream << nlidt(indent);
  } else if (t) {
    stream << "tensor, value:";
    t->dump(stream, nested);
  } else if (!floats.empty()) {
    stream << "floats, values: [";
    for (const auto& value : floats) {
      stream << sep << *value;
      sep = " ";
    }
    stream << "]";
  } else if (!ints.empty()) {
    stream << "ints, values: [";
    for (const auto& value : ints) {
      stream << sep << *value;
      sep = " ";
    }
    stream << "]";
  } else if (!strings.empty()) {
    stream << "strings, values: [";
    for (const auto& value : strings) {
      stream << sep << "'" << *value << "'";
      sep = " ";
    }
    stream << "]";
  } else if (!tensors.empty()) {
    stream << "tensors, values: [";
    // Tensor entries are written back-to-back with no separator between them.
    for (const auto& tensor : tensors) {
      tensor->dump(stream, nested);
    }
    stream << "]";
  } else if (!graphs.empty()) {
    stream << "graphs, values: [";
    for (const auto& graph : graphs) {
      graph->dump(stream, nested);
    }
    stream << "]";
  } else {
    stream << "UNKNOWN";
  }
  stream << "}";
}
void NodeProto::dump(std::ostream& stream, size_t indent) {
stream << "Node {type: \"" << op_type << "\", inputs: [";
for (size_t i=0; i < inputs.size(); ++i) {
stream << *inputs[i] << (i == inputs.size() - 1 ? "" : ",");
}
stream << "], outputs: [";
for (size_t i=0; i < outputs.size(); ++i) {
stream << *outputs[i] << (i == outputs.size() - 1 ? "" : ",");
}
stream << "], attributes: [";
for (size_t i=0; i < attributes.size(); ++i) {
attributes[i]->dump(stream, indent+1);
stream << (i == attributes.size() - 1 ? "" : ",");
}
stream << "]}";
}
void GraphProto::dump(std::ostream& stream, size_t indent) {
stream << idt(indent) << "GraphProto {" << nlidt(indent+1)
<< "name: \"" << name << "\"" << nlidt(indent+1)
<< "inputs: [";
for (size_t i=0; i < inputs.size(); ++i) {
inputs[i]->dump(stream, indent+2);
stream << (i == inputs.size() - 1 ? "" : ",");
}
stream << "]" << nlidt(indent+1)
<< "outputs: [";
for (size_t i=0; i < outputs.size(); ++i) {
outputs[i]->dump(stream, indent+2);
stream << (i == outputs.size() - 1 ? "" : ",");
}
stream << "]" << nlidt(indent+1)
<< "initializers: [";
for (size_t i=0; i < initializers.size(); ++i) {
initializers[i]->dump(stream, indent+2);
stream << (i == initializers.size() - 1 ? "" : ",");
}
stream << "]" << nlidt(indent+1)
<< "nodes: [" << nlidt(indent+2);
for (size_t i=0; i < nodes.size(); ++i) {
nodes[i]->dump(stream, indent+2);
if (i != nodes.size() - 1) stream << "," << nlidt(indent+2);
}
stream << nlidt(indent+1) << "]\n" << idt(indent) << "}\n";
}
// Prints `OperatorSetIdProto { domain: <domain>}`. `indent` is not used.
void OperatorSetIdProto::dump(std::ostream& stream, size_t indent) {
  stream << "OperatorSetIdProto { domain: ";
  stream << domain;
  stream << "}";
}
// Pretty-prints the model, in this order:
//   - the producer_name, domain and doc_string fields;
//   - the graph at `indent + 2`, when `graph` is set;
//   - the opset_import entries, when there are any;
//   - a closing `}` line.
// NOTE(review): when `graph` is null, no line break follows the doc_string
// line, so the next output continues on that line. Opset entries are also
// written back-to-back with no separator. The format is kept as-is because
// callers may compare it textually — confirm before changing it.
void ModelProto::dump(std::ostream& stream, size_t indent) {
  const std::string field_break = nlidt(indent + 1);
  stream << idt(indent) << "ModelProto {" << field_break;
  stream << "producer_name: \"" << producer_name << "\"" << field_break;
  stream << "domain: \"" << domain << "\"" << field_break;
  stream << "doc_string: \"" << doc_string << "\"";

  if (graph) {
    stream << field_break << "graph:\n";
    graph->dump(stream, indent + 2);
  }

  if (!opset_import.empty()) {
    stream << idt(indent + 1) << "opset_import: [";
    for (const auto& entry : opset_import) {
      entry->dump(stream, indent + 2);
    }
    stream << "],\n";
  }

  stream << idt(indent) << "}\n";
}
}} // namespace onnx