/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This program runs the tflite model specified in --model with random inputs.
// For string tensors, the input is filled with a fixed string.
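//
// Example invocation (the binary name is illustrative):
//   <this_binary> --model=/path/to/model.tflite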
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <memory>
#include <string>
#include <vector>
#include <glog/logging.h>
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/model_builder.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/tools/command_line_flags.h"
void FillRandomString(tflite::DynamicBuffer* buffer,
                      const TfLiteIntArray* dim_array,
                      const std::function<std::string()>& random_func) {
  int num_elements = 1;
  for (int i = 0; i < dim_array->size; ++i) {
    num_elements *= dim_array->data[i];
  }
  for (int i = 0; i < num_elements; ++i) {
    const std::string str = random_func();
    buffer->AddString(str.data(), str.length());
  }
}
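// Loads the model at `filename`, fills its input tensors with random (or
// fixed-string) data, runs one inference, and logs each output tensor's type.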
void RunWithRandomInputs(const std::string& filename) {
  std::unique_ptr<tflite::FlatBufferModel> model =
      tflite::FlatBufferModel::BuildFromFile(filename.c_str());
  if (!model) {
    LOG(FATAL) << "Could not load TFLite model from: " << filename;
  }

  // Build the interpreter.
  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) {
    LOG(FATAL) << "Could not initialize interpreter for TFLite model.";
  }
  // Allocate memory for all tensors.
  if (interpreter->AllocateTensors() != kTfLiteOk) {
    LOG(FATAL) << "Could not allocate tensors.";
  }
  // Fill each input tensor with random data.
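  // `sample` owns the buffers backing the non-string inputs so that they stay
  // alive through Invoke().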
  std::vector<std::vector<uint8_t>> sample;
  for (int tensor_idx : interpreter->inputs()) {
    auto tensor = interpreter->tensor(tensor_idx);
    if (tensor->type == kTfLiteString) {
      tflite::DynamicBuffer buffer;
      FillRandomString(&buffer, tensor->dims, []() {
        return "we're having some friends over Saturday to hang out in the "
               "yard";
      });
      buffer.WriteToTensor(tensor, /*new_shape=*/nullptr);
    } else {
      std::vector<uint8_t> data(tensor->bytes);
      for (auto it = data.begin(); it != data.end(); ++it) {
        *it = random();
      }
      sample.push_back(std::move(data));
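      // Redirect the tensor to read from our buffer. std::vector's heap
      // storage is stable across growth of `sample` (elements are moved, not
      // copied), so pointers taken in earlier iterations remain valid.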
      tensor->data.raw = reinterpret_cast<char*>(sample.back().data());
    }
  }
  // Run inference.
  if (interpreter->Invoke() != kTfLiteOk) {
    LOG(FATAL) << "Failed to run the model.";
  }
  // Log the type of each output tensor.
  for (int tensor_idx : interpreter->outputs()) {
    auto tensor = interpreter->tensor(tensor_idx);
    LOG(INFO) << "Output type: " << TfLiteTypeGetName(tensor->type);
  }
}
int main(int argc, char** argv) {
  // Parse flags to get the model filename.
  std::string filename;
  std::vector<tflite::Flag> flag_list{tflite::Flag::CreateFlag(
      "model", &filename, "The tflite model to run sample inference on.",
      tflite::Flag::kRequired)};
  if (!tflite::Flags::Parse(&argc, const_cast<const char**>(argv),
                            flag_list)) {
    LOG(FATAL) << "Could not parse command line flags.";
  }
  tensorflow::port::InitMain(argv[0], &argc, &argv);

  // Run the model with random inputs.
  RunWithRandomInputs(filename);
  return 0;
}