ModelAnalyer: Dump contents of int32 constant tensors
It gives easier understanding of some ops such as RESHAPE, TRANSPOSE.
```
Op#0 RESHAPE(T#1, T#4[512, 512]) -> [T#5]
Op#1 TRANSPOSE(T#5, T#3[1, 0]) -> [T#6]
```
PiperOrigin-RevId: 446279587
diff --git a/tensorflow/lite/python/analyzer_test.py b/tensorflow/lite/python/analyzer_test.py
index 34a0e02..9d6a01d 100644
--- a/tensorflow/lite/python/analyzer_test.py
+++ b/tensorflow/lite/python/analyzer_test.py
@@ -205,6 +205,36 @@
txt = mock_stdout.getvalue()
self.assertIn('Subgraph#0 main() -> [T#0]', txt)
+ def testTxtWithEinsum(self):
+
+ @tf.function(input_signature=[
+ tf.TensorSpec(shape=[1, 100, 512], dtype=tf.float32),
+ tf.TensorSpec(shape=[512, 8, 64], dtype=tf.float32)
+ ])
+ def func(lhs, rhs):
+ return tf.einsum('ABD,DNH->ABNH', lhs, rhs)
+
+ converter = tf.lite.TFLiteConverter.from_concrete_functions(
+ [func.get_concrete_function()], func)
+ fb_model = converter.convert()
+ mock_stdout = io.StringIO()
+ with test.mock.patch.object(sys, 'stdout', mock_stdout):
+ analyzer.ModelAnalyzer.analyze(model_content=fb_model)
+ txt = mock_stdout.getvalue()
+ self.assertIn('Op#0 RESHAPE(T#1, T#4[512, 512]) -> [T#5]', txt)
+ self.assertIn('Op#1 TRANSPOSE(T#5, T#3[1, 0]) -> [T#6]', txt)
+ self.assertIn('Op#2 FULLY_CONNECTED(T#0, T#6, T#-1) -> [T#7]', txt)
+ self.assertIn('Op#3 RESHAPE(T#7, T#2[1, 100, 8, 64]) -> [T#8]', txt)
+ self.assertIn(
+ 'T#2(einsum/Einsum) shape:[4], type:INT32 RO 16 bytes, '
+ 'data:[1, 100, 8, 64]', txt)
+ self.assertIn(
+ 'T#3(einsum/Einsum2) shape:[2], type:INT32 RO 8 bytes, '
+ 'data:[1, 0]', txt)
+ self.assertIn(
+ 'T#4(einsum/Einsum3) shape:[2], type:INT32 RO 8 bytes, '
+ 'data:[512, 512]', txt)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/lite/python/analyzer_wrapper/model_analyzer.cc b/tensorflow/lite/python/analyzer_wrapper/model_analyzer.cc
index d4609fa..19c8a08 100644
--- a/tensorflow/lite/python/analyzer_wrapper/model_analyzer.cc
+++ b/tensorflow/lite/python/analyzer_wrapper/model_analyzer.cc
@@ -12,6 +12,7 @@
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <algorithm>
#include <memory>
#include <sstream>
#include <string>
@@ -31,14 +32,56 @@
const float kThreshold_zero_buffer_ratio = 10.0f;
constexpr char kSectionSplitter[] =
"---------------------------------------------------------------\n";
+const int kMaxContentDumpCnt = 5;
+
+// Returns string representation of the given tensor data up to 5 elements.
+const std::string get_tensor_data_str(const tflite::Tensor* tensor,
+ const tflite::Model* model) {
+ std::stringstream ss;
+ auto buffer_idx = tensor->buffer();
+ if (buffer_idx != 0 && buffer_idx < model->buffers()->size()) {
+ auto* buffer = model->buffers()->Get(buffer_idx);
+ if (buffer->data() == nullptr) {
+ return "";
+ }
+ ss << "[";
+ if (buffer->data()->size() != 0) {
+ if (tensor->type() == tflite::TensorType_INT32) {
+ auto data = reinterpret_cast<const int32_t*>(buffer->data()->data());
+ int data_cnt = buffer->data()->size() / sizeof(int32_t);
+ for (int i = 0; i < std::min(kMaxContentDumpCnt, data_cnt); ++i) {
+ ss << data[i];
+ if (i != data_cnt - 1) {
+ ss << ", ";
+ }
+ }
+ if (data_cnt > kMaxContentDumpCnt) {
+ ss << "...";
+ }
+ }
+ }
+ ss << "]";
+ }
+ return ss.str();
+}
// Returns string representation of the given tensor of the subgraph.
-const std::string tensor_str(const int tensor_idx, const int subgraph_idx) {
+const std::string tensor_str(const int tensor_idx, const int subgraph_idx,
+ const tflite::Model* model = nullptr) {
std::stringstream ss;
if (subgraph_idx != 0 && tensor_idx != -1)
ss << "T#" << subgraph_idx << "_" << tensor_idx;
else
ss << "T#" << tensor_idx;
+ if (model && tensor_idx != -1) {
+ const SubGraph* subgraph = model->subgraphs()->Get(subgraph_idx);
+ if (subgraph) {
+ auto tensor = subgraph->tensors()->Get(tensor_idx);
+ if (tensor) {
+ ss << get_tensor_data_str(tensor, model);
+ }
+ }
+ }
return ss.str();
}
@@ -94,6 +137,7 @@
auto* buffer = model->buffers()->Get(buffer_idx);
if (buffer->data() && buffer->data()->size() != 0) {
out_stream << " RO " << buffer->data()->size() << " bytes";
+ out_stream << ", data:" << get_tensor_data_str(tensor, model);
stats->buffer_usage[subgraph_idx] += buffer->data()->size();
}
}
@@ -103,7 +147,9 @@
// Dump list of input or output tensors.
void dump_tensor_list(std::stringstream& out_stream,
const flatbuffers::Vector<int32_t>* tensors,
- const int subgraph_idx, bool verbose = false) {
+ const int subgraph_idx,
+ const tflite::Model* model = nullptr,
+ bool verbose = false) {
if (tensors == nullptr) {
return;
}
@@ -112,7 +158,7 @@
if (verbose) {
out_stream << "tensor #" << tensor_idx;
} else {
- out_stream << tensor_str(tensor_idx, subgraph_idx);
+ out_stream << tensor_str(tensor_idx, subgraph_idx, model);
}
if (i != tensors->Length() - 1) {
if (verbose) {
@@ -137,10 +183,10 @@
// Dump the given Operator node.
void dump_node(std::stringstream& out_stream, const int node_no,
const OperatorCode* op_code, const Operator* op,
- const int subgraph_index) {
+ const int subgraph_index, const ::tflite::Model* model) {
out_stream << "Op#" << node_no << " " << get_op_name(op_code);
out_stream << "(";
- dump_tensor_list(out_stream, op->inputs(), subgraph_index);
+ dump_tensor_list(out_stream, op->inputs(), subgraph_index, model);
if (GetBuiltinCode(op_code) == BuiltinOperator_CALL_ONCE) {
out_stream << subgraph_str(
op->builtin_options_as_CallOnceOptions()->init_subgraph_index());
@@ -179,9 +225,11 @@
model->operator_codes()->Get(first_op->opcode_index());
out_stream << "For example, in " << subgraph_str(0) << ", the "
<< get_op_name(first_op_code) << " op takes\n";
- dump_tensor_list(out_stream, first_op->inputs(), 0, /*verbose=*/true);
+ dump_tensor_list(out_stream, first_op->inputs(), 0, nullptr,
+ /*verbose=*/true);
out_stream << " as input and produces ";
- dump_tensor_list(out_stream, first_op->outputs(), 0, /*verbose=*/true);
+ dump_tensor_list(out_stream, first_op->outputs(), 0, nullptr,
+ /*verbose=*/true);
out_stream << " as output.\n\n";
}
}
@@ -359,7 +407,7 @@
const OperatorCode* op_code =
model->operator_codes()->Get(op->opcode_index());
out_stream << " "; // indents for operators
- dump_node(out_stream, /*node_no=*/j, op_code, op, i);
+ dump_node(out_stream, /*node_no=*/j, op_code, op, i, model);
if (check_gpu_compatibility) {
auto status =
CheckGpuDelegateCompatibility(op_code, op, subgraph, model);