/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file contains the logic of Android model wrapper generation.
//
// At the beginning are the helper functions handling metadata and the code
// writer.
//
// Code is generated in the `Generate{FOO}` functions. The Gradle and Manifest
// files are simple. The wrapper file generation is more complex, so we divided
// it into several sub-functions.
//
// The structure of the wrapper file looks like:
//
// [ imports ]
// [ class ]
// [ inner "Outputs" class ]
// [ inner "Metadata" class ]
// [ APIs ] ( including ctors, public APIs and private APIs )
//
// We tried to write most of it in a "template-generation" way. `CodeWriter`
// does the job of a template renderer. To avoid repeatedly setting the token
// values, the helper functions `SetCodeWriterWith{Foo}Info` set them from info
// structures (`TensorInfo` and `ModelInfo`) - the info structures are
// intermediate data structures between the metadata (represented in
// FlatBuffers) and the generated code.
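//
// A minimal sketch of the template style (illustrative only; the token names
// below are real, the values are made up):
//
//   CodeWriter writer(err);
//   writer.SetTokenValue("NAME", "image");
//   writer.SetTokenValue("PROCESSOR_TYPE", "ImageProcessor");
//   writer.Append("private {{PROCESSOR_TYPE}} {{NAME}}Preprocessor;");
//   // renders: private ImageProcessor imagePreprocessor;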
#include "tensorflow_lite_support/codegen/android_java_generator.h"
#include <ctype.h>
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "tensorflow_lite_support/codegen/code_generator.h"
#include "tensorflow_lite_support/codegen/metadata_helper.h"
#include "tensorflow_lite_support/codegen/utils.h"
#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
namespace tflite {
namespace support {
namespace codegen {
namespace {
using details_android_java::ModelInfo;
using details_android_java::TensorInfo;
// Helper class to emit a braced code block into the generated code.
// Its ctor and dtor simulate an enter/exit scheme, like `with` in Python.
class AsBlock {
public:
AsBlock(CodeWriter* code_writer, const std::string& before,
bool trailing_blank_line = false)
: code_writer_(code_writer), trailing_blank_line_(trailing_blank_line) {
code_writer_->AppendNoNewLine(before);
code_writer_->Append(" {");
code_writer_->Indent();
}
~AsBlock() {
code_writer_->Outdent();
code_writer_->Append("}");
if (trailing_blank_line_) {
code_writer_->NewLine();
}
}
private:
CodeWriter* code_writer_;
bool trailing_blank_line_;
};
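// A usage sketch of AsBlock (illustrative only; assumes `writer` is a
// CodeWriter*):
//
//   {
//     auto block = AsBlock(writer, "public class Foo");
//     writer->Append("int x;");
//   }
//
// emits:
//
//   public class Foo {
//     int x;
//   }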
// Forward-declare the functions so that their definitions can follow a
// logical order.
bool GenerateWrapperClass(CodeWriter*, const ModelInfo&, ErrorReporter*);
bool GenerateWrapperImports(CodeWriter*, const ModelInfo&, ErrorReporter*);
bool GenerateWrapperInputs(CodeWriter*, const ModelInfo&, ErrorReporter*);
bool GenerateWrapperOutputs(CodeWriter*, const ModelInfo&, ErrorReporter*);
bool GenerateWrapperMetadata(CodeWriter*, const ModelInfo&, ErrorReporter*);
bool GenerateWrapperAPI(CodeWriter*, const ModelInfo&, ErrorReporter*);
std::string GetModelVersionedName(const ModelMetadata* metadata) {
std::string model_name = "MyModel";
if (metadata->name() != nullptr && !(metadata->name()->str().empty())) {
model_name = metadata->name()->str();
}
std::string model_version = "unknown";
if (metadata->version() != nullptr && !(metadata->version()->str().empty())) {
model_version = metadata->version()->str();
}
return model_name + " (Version: " + model_version + ")";
}
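// For example, metadata with name "MobileNet" and version "v1" yields
// "MobileNet (Version: v1)"; absent fields fall back to "MyModel" and
// "unknown".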
TensorInfo CreateTensorInfo(const TensorMetadata* metadata,
const std::string& name, bool is_input, int index,
ErrorReporter* err) {
TensorInfo tensor_info;
std::string tensor_identifier = is_input ? "input" : "output";
tensor_identifier += " " + std::to_string(index);
tensor_info.associated_axis_label_index = FindAssociatedFile(
metadata, AssociatedFileType_TENSOR_AXIS_LABELS, tensor_identifier, err);
tensor_info.associated_value_label_index = FindAssociatedFile(
metadata, AssociatedFileType_TENSOR_VALUE_LABELS, tensor_identifier, err);
if (is_input && (tensor_info.associated_axis_label_index >= 0 ||
tensor_info.associated_value_label_index >= 0)) {
err->Warning(
"Found label file on input tensor (%s). Label files for input "
"tensors are not supported yet. The file will be ignored.",
tensor_identifier.c_str());
}
if (tensor_info.associated_axis_label_index >= 0 &&
tensor_info.associated_value_label_index >= 0) {
err->Warning(
"Found both axis label file and value label file for tensor (%s), "
"which is not supported. Only the axis label file will be used.",
tensor_identifier.c_str());
}
tensor_info.is_input = is_input;
tensor_info.name = SnakeCaseToCamelCase(name);
tensor_info.upper_camel_name = tensor_info.name;
tensor_info.upper_camel_name[0] = toupper(tensor_info.upper_camel_name[0]);
tensor_info.normalization_unit =
FindNormalizationUnit(metadata, tensor_identifier, err);
if (metadata->content() != nullptr &&
metadata->content()->content_properties() != nullptr) {
// Infer the tensor wrapper type from the content metadata.
if (metadata->content()->content_properties_type() ==
ContentProperties_ImageProperties) {
if (metadata->content()
->content_properties_as_ImageProperties()
->color_space() == ColorSpaceType_RGB) {
tensor_info.content_type = "image";
tensor_info.wrapper_type = "TensorImage";
tensor_info.processor_type = "ImageProcessor";
return tensor_info;
} else {
err->Warning(
"Found non-RGB image on tensor (%s). Codegen does not support it "
"yet, and will treat it as a plain numeric tensor.",
tensor_identifier.c_str());
}
}
}
tensor_info.content_type = "tensor";
tensor_info.wrapper_type = "TensorBuffer";
tensor_info.processor_type = "TensorProcessor";
return tensor_info;
}
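// For example, an input tensor named "image1" carrying RGB ImageProperties is
// wrapped as a "TensorImage" with an "ImageProcessor"; all other content
// falls back to "TensorBuffer" with a "TensorProcessor".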
ModelInfo CreateModelInfo(const ModelMetadata* metadata,
const std::string& package_name,
const std::string& model_class_name,
const std::string& model_asset_path,
ErrorReporter* err) {
ModelInfo model_info;
if (!CodeGenerator::VerifyMetadata(metadata, err)) {
// TODO(b/150116380): Create dummy model info.
err->Error("Validating metadata failed.");
return model_info;
}
model_info.package_name = package_name;
model_info.model_class_name = model_class_name;
model_info.model_asset_path = model_asset_path;
model_info.model_versioned_name = GetModelVersionedName(metadata);
const auto* graph = metadata->subgraph_metadata()->Get(0);
auto names = CodeGenerator::NameInputsAndOutputs(
graph->input_tensor_metadata(), graph->output_tensor_metadata());
std::vector<std::string> input_tensor_names = std::move(names.first);
std::vector<std::string> output_tensor_names = std::move(names.second);
for (int i = 0; i < input_tensor_names.size(); i++) {
model_info.inputs.push_back(
CreateTensorInfo(graph->input_tensor_metadata()->Get(i),
input_tensor_names[i], true, i, err));
// Prepend the separator before every element except the first, so the
// generated list reads "a, b" rather than ", ab".
if (i > 0) {
model_info.inputs_list += ", ";
model_info.input_type_param_list += ", ";
}
model_info.inputs_list += model_info.inputs[i].name;
model_info.input_type_param_list +=
model_info.inputs[i].wrapper_type + " " + model_info.inputs[i].name;
}
for (int i = 0; i < output_tensor_names.size(); i++) {
model_info.outputs.push_back(
CreateTensorInfo(graph->output_tensor_metadata()->Get(i),
output_tensor_names[i], false, i, err));
if (i > 0) {
model_info.postprocessor_type_param_list += ", ";
model_info.postprocessors_list += ", ";
}
model_info.postprocessors_list +=
model_info.outputs[i].name + "Postprocessor";
model_info.postprocessor_type_param_list +=
model_info.outputs[i].processor_type + " " +
model_info.outputs[i].name + "Postprocessor";
}
return model_info;
}
void SetCodeWriterWithTensorInfo(CodeWriter* code_writer,
const TensorInfo& tensor_info) {
code_writer->SetTokenValue("NAME", tensor_info.name);
code_writer->SetTokenValue("NAME_U", tensor_info.upper_camel_name);
code_writer->SetTokenValue("CONTENT_TYPE", tensor_info.content_type);
code_writer->SetTokenValue("WRAPPER_TYPE", tensor_info.wrapper_type);
std::string wrapper_name = tensor_info.wrapper_type;
wrapper_name[0] = tolower(wrapper_name[0]);
code_writer->SetTokenValue("WRAPPER_NAME", wrapper_name);
code_writer->SetTokenValue("PROCESSOR_TYPE", tensor_info.processor_type);
code_writer->SetTokenValue("NORMALIZATION_UNIT",
std::to_string(tensor_info.normalization_unit));
code_writer->SetTokenValue(
"ASSOCIATED_AXIS_LABEL_INDEX",
std::to_string(tensor_info.associated_axis_label_index));
code_writer->SetTokenValue(
"ASSOCIATED_VALUE_LABEL_INDEX",
std::to_string(tensor_info.associated_value_label_index));
}
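// For instance, a TensorInfo named "inputImage" with wrapper type
// "TensorImage" yields NAME=inputImage, NAME_U=InputImage (set in
// CreateTensorInfo) and WRAPPER_NAME=tensorImage.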
void SetCodeWriterWithModelInfo(CodeWriter* code_writer,
const ModelInfo& model_info) {
code_writer->SetTokenValue("PACKAGE", model_info.package_name);
code_writer->SetTokenValue("MODEL_PATH", model_info.model_asset_path);
code_writer->SetTokenValue("MODEL_CLASS_NAME", model_info.model_class_name);
// Extra tokens whose values are themselves partially generated (in
// CreateModelInfo).
code_writer->SetTokenValue("INPUT_TYPE_PARAM_LIST",
model_info.input_type_param_list);
code_writer->SetTokenValue("INPUTS_LIST", model_info.inputs_list);
code_writer->SetTokenValue("POSTPROCESSORS_LIST",
model_info.postprocessors_list);
code_writer->SetTokenValue("POSTPROCESSOR_TYPE_PARAM_LIST",
model_info.postprocessor_type_param_list);
}
constexpr char JAVA_DEFAULT_PACKAGE[] = "default";
std::string ConvertPackageToPath(const std::string& package) {
if (package == JAVA_DEFAULT_PACKAGE) {
return "";
}
std::string path = package;
std::replace(path.begin(), path.end(), '.', '/');
return path;
}
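// For example, "com.example.app" becomes "com/example/app"; the special
// package "default" maps to the empty path (the Java default package).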
bool IsImageUsed(const ModelInfo& model) {
for (const auto& input : model.inputs) {
if (input.content_type == "image") {
return true;
}
}
for (const auto& output : model.outputs) {
if (output.content_type == "image") {
return true;
}
}
return false;
}
// The following functions generate the Java wrapper code for a model.
bool GenerateWrapperFileContent(CodeWriter* code_writer, const ModelInfo& model,
ErrorReporter* err) {
code_writer->Append("// Generated by TFLite Support.");
code_writer->Append("package {{PACKAGE}};");
code_writer->NewLine();
if (!GenerateWrapperImports(code_writer, model, err)) {
err->Error("Fail to generate imports for wrapper class.");
return false;
}
if (!GenerateWrapperClass(code_writer, model, err)) {
err->Error("Fail to generate wrapper class.");
return false;
}
code_writer->NewLine();
return true;
}
bool GenerateWrapperImports(CodeWriter* code_writer, const ModelInfo& model,
ErrorReporter* err) {
const std::string support_pkg = "org.tensorflow.lite.support.";
std::vector<std::string> imports{
"android.content.Context",
"java.io.IOException",
"java.nio.ByteBuffer",
"java.nio.FloatBuffer",
"java.util.Arrays",
"java.util.HashMap",
"java.util.List",
"java.util.Map",
"org.tensorflow.lite.DataType",
"org.tensorflow.lite.Tensor",
"org.tensorflow.lite.Tensor.QuantizationParams",
support_pkg + "common.FileUtil",
support_pkg + "common.TensorProcessor",
support_pkg + "common.ops.CastOp",
support_pkg + "common.ops.DequantizeOp",
support_pkg + "common.ops.NormalizeOp",
support_pkg + "common.ops.QuantizeOp",
support_pkg + "label.Category",
support_pkg + "label.TensorLabel",
support_pkg + "metadata.MetadataExtractor",
support_pkg + "metadata.schema.NormalizationOptions",
support_pkg + "model.Model",
support_pkg + "tensorbuffer.TensorBuffer",
};
if (IsImageUsed(model)) {
for (const auto& target :
{"image.ImageProcessor", "image.TensorImage", "image.ops.ResizeOp",
"image.ops.ResizeOp.ResizeMethod"}) {
imports.push_back(support_pkg + target);
}
}
std::sort(imports.begin(), imports.end());
for (const auto& target : imports) {
code_writer->SetTokenValue("TARGET", target);
code_writer->Append("import {{TARGET}};");
}
code_writer->NewLine();
return true;
}
bool GenerateWrapperClass(CodeWriter* code_writer, const ModelInfo& model,
ErrorReporter* err) {
code_writer->SetTokenValue("MODEL_VERSIONED_NAME",
model.model_versioned_name);
code_writer->Append(
R"(/** Wrapper class of model {{MODEL_VERSIONED_NAME}} */)");
const auto code_block =
AsBlock(code_writer, "public class {{MODEL_CLASS_NAME}}");
code_writer->Append(R"(private final Metadata metadata;
private final Model model;
private static final String MODEL_NAME = "{{MODEL_PATH}}";)");
for (const auto& tensor : model.inputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->Append("private {{PROCESSOR_TYPE}} {{NAME}}Preprocessor;");
}
for (const auto& tensor : model.outputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->Append("private {{PROCESSOR_TYPE}} {{NAME}}Postprocessor;");
}
code_writer->NewLine();
if (!GenerateWrapperOutputs(code_writer, model, err)) {
err->Error("Failed to generate output classes");
return false;
}
code_writer->NewLine();
if (!GenerateWrapperMetadata(code_writer, model, err)) {
err->Error("Failed to generate the metadata class");
return false;
}
code_writer->NewLine();
if (!GenerateWrapperAPI(code_writer, model, err)) {
err->Error("Failed to generate the common APIs");
return false;
}
return true;
}
bool GenerateWrapperOutputs(CodeWriter* code_writer, const ModelInfo& model,
ErrorReporter* err) {
code_writer->Append("/** Output wrapper of {@link {{MODEL_CLASS_NAME}}} */");
auto class_block = AsBlock(code_writer, "public static class Outputs");
for (const auto& tensor : model.outputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->Append("private final {{WRAPPER_TYPE}} {{NAME}};");
if (tensor.associated_axis_label_index >= 0) {
code_writer->Append("private final List<String> {{NAME}}Labels;");
}
code_writer->Append(
"private final {{PROCESSOR_TYPE}} {{NAME}}Postprocessor;");
}
// Getters
for (const auto& tensor : model.outputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->NewLine();
if (tensor.associated_axis_label_index >= 0) {
if (tensor.content_type == "tensor") {
code_writer->Append(
R"(public List<Category> get{{NAME_U}}AsCategoryList() {
return new TensorLabel({{NAME}}Labels, postprocess{{NAME_U}}({{NAME}})).getCategoryList();
})");
} else { // image
err->Warning(
"Axis label for images is not supported. The labels will "
"be ignored.");
}
} else { // no label
code_writer->Append(
R"(public {{WRAPPER_TYPE}} get{{NAME_U}}As{{WRAPPER_TYPE}}() {
return postprocess{{NAME_U}}({{NAME}});
})");
}
}
code_writer->NewLine();
{
const auto ctor_block = AsBlock(
code_writer,
"Outputs(Metadata metadata, {{POSTPROCESSOR_TYPE_PARAM_LIST}})");
for (const auto& tensor : model.outputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
if (tensor.content_type == "image") {
code_writer->Append(
R"({{NAME}} = new TensorImage(metadata.get{{NAME_U}}Type());
{{NAME}}.load(TensorBuffer.createFixedSize(metadata.get{{NAME_U}}Shape(), metadata.get{{NAME_U}}Type()));)");
} else { // FEATURE, UNKNOWN
code_writer->Append(
"{{NAME}} = "
"TensorBuffer.createFixedSize(metadata.get{{NAME_U}}Shape(), "
"metadata.get{{NAME_U}}Type());");
}
if (tensor.associated_axis_label_index >= 0) {
code_writer->Append("{{NAME}}Labels = metadata.get{{NAME_U}}Labels();");
}
code_writer->Append(
"this.{{NAME}}Postprocessor = {{NAME}}Postprocessor;");
}
}
code_writer->NewLine();
{
const auto get_buffer_block =
AsBlock(code_writer, "Map<Integer, Object> getBuffer()");
code_writer->Append("Map<Integer, Object> outputs = new HashMap<>();");
for (int i = 0; i < model.outputs.size(); i++) {
SetCodeWriterWithTensorInfo(code_writer, model.outputs[i]);
code_writer->SetTokenValue("ID", std::to_string(i));
code_writer->Append("outputs.put({{ID}}, {{NAME}}.getBuffer());");
}
code_writer->Append("return outputs;");
}
for (const auto& tensor : model.outputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->NewLine();
{
auto processor_block =
AsBlock(code_writer,
"private {{WRAPPER_TYPE}} "
"postprocess{{NAME_U}}({{WRAPPER_TYPE}} {{WRAPPER_NAME}})");
code_writer->Append(
"return {{NAME}}Postprocessor.process({{WRAPPER_NAME}});");
}
}
return true;
}
bool GenerateWrapperMetadata(CodeWriter* code_writer, const ModelInfo& model,
ErrorReporter* err) {
code_writer->Append(
"/** Metadata accessors of {@link {{MODEL_CLASS_NAME}}} */");
const auto class_block = AsBlock(code_writer, "public static class Metadata");
for (const auto& tensor : model.inputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->Append(R"(private final int[] {{NAME}}Shape;
private final DataType {{NAME}}DataType;
private final QuantizationParams {{NAME}}QuantizationParams;)");
if (tensor.normalization_unit >= 0) {
code_writer->Append(R"(private final float[] {{NAME}}Mean;
private final float[] {{NAME}}Stddev;)");
}
}
for (const auto& tensor : model.outputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->Append(R"(private final int[] {{NAME}}Shape;
private final DataType {{NAME}}DataType;
private final QuantizationParams {{NAME}}QuantizationParams;)");
if (tensor.normalization_unit >= 0) {
code_writer->Append(R"(private final float[] {{NAME}}Mean;
private final float[] {{NAME}}Stddev;)");
}
if (tensor.associated_axis_label_index >= 0 ||
tensor.associated_value_label_index >= 0) {
code_writer->Append("private final List<String> {{NAME}}Labels;");
}
}
code_writer->NewLine();
{
const auto ctor_block = AsBlock(
code_writer,
"public Metadata(ByteBuffer buffer, Model model) throws IOException");
code_writer->Append(
"MetadataExtractor extractor = new MetadataExtractor(buffer);");
for (int i = 0; i < model.inputs.size(); i++) {
SetCodeWriterWithTensorInfo(code_writer, model.inputs[i]);
code_writer->SetTokenValue("ID", std::to_string(i));
code_writer->Append(
R"(Tensor {{NAME}}Tensor = model.getInputTensor({{ID}});
{{NAME}}Shape = {{NAME}}Tensor.shape();
{{NAME}}DataType = {{NAME}}Tensor.dataType();
{{NAME}}QuantizationParams = {{NAME}}Tensor.quantizationParams();)");
if (model.inputs[i].normalization_unit >= 0) {
code_writer->Append(
R"(NormalizationOptions {{NAME}}NormalizationOptions =
(NormalizationOptions) extractor.getInputTensorMetadata({{ID}}).processUnits({{NORMALIZATION_UNIT}}).options(new NormalizationOptions());
FloatBuffer {{NAME}}MeanBuffer = {{NAME}}NormalizationOptions.meanAsByteBuffer().asFloatBuffer();
{{NAME}}Mean = new float[{{NAME}}MeanBuffer.limit()];
{{NAME}}MeanBuffer.get({{NAME}}Mean);
FloatBuffer {{NAME}}StddevBuffer = {{NAME}}NormalizationOptions.stdAsByteBuffer().asFloatBuffer();
{{NAME}}Stddev = new float[{{NAME}}StddevBuffer.limit()];
{{NAME}}StddevBuffer.get({{NAME}}Stddev);)");
}
}
for (int i = 0; i < model.outputs.size(); i++) {
SetCodeWriterWithTensorInfo(code_writer, model.outputs[i]);
code_writer->SetTokenValue("ID", std::to_string(i));
code_writer->Append(
R"(Tensor {{NAME}}Tensor = model.getOutputTensor({{ID}});
{{NAME}}Shape = {{NAME}}Tensor.shape();
{{NAME}}DataType = {{NAME}}Tensor.dataType();
{{NAME}}QuantizationParams = {{NAME}}Tensor.quantizationParams();)");
if (model.outputs[i].normalization_unit >= 0) {
code_writer->Append(
R"(NormalizationOptions {{NAME}}NormalizationOptions =
(NormalizationOptions) extractor.getOutputTensorMetadata({{ID}}).processUnits({{NORMALIZATION_UNIT}}).options(new NormalizationOptions());
FloatBuffer {{NAME}}MeanBuffer = {{NAME}}NormalizationOptions.meanAsByteBuffer().asFloatBuffer();
{{NAME}}Mean = new float[{{NAME}}MeanBuffer.limit()];
{{NAME}}MeanBuffer.get({{NAME}}Mean);
FloatBuffer {{NAME}}StddevBuffer = {{NAME}}NormalizationOptions.stdAsByteBuffer().asFloatBuffer();
{{NAME}}Stddev = new float[{{NAME}}StddevBuffer.limit()];
{{NAME}}StddevBuffer.get({{NAME}}Stddev);)");
}
if (model.outputs[i].associated_axis_label_index >= 0) {
code_writer->Append(R"(String {{NAME}}LabelsFileName =
extractor.getOutputTensorMetadata({{ID}}).associatedFiles({{ASSOCIATED_AXIS_LABEL_INDEX}}).name();
{{NAME}}Labels = FileUtil.loadLabels(extractor.getAssociatedFile({{NAME}}LabelsFileName));)");
} else if (model.outputs[i].associated_value_label_index >= 0) {
code_writer->Append(R"(String {{NAME}}LabelsFileName =
extractor.getOutputTensorMetadata({{ID}}).associatedFiles({{ASSOCIATED_VALUE_LABEL_INDEX}}).name();
{{NAME}}Labels = FileUtil.loadLabels(extractor.getAssociatedFile({{NAME}}LabelsFileName));)");
}
}
}
for (const auto& tensor : model.inputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->Append(R"(
public int[] get{{NAME_U}}Shape() {
return Arrays.copyOf({{NAME}}Shape, {{NAME}}Shape.length);
}
public DataType get{{NAME_U}}Type() {
return {{NAME}}DataType;
}
public QuantizationParams get{{NAME_U}}QuantizationParams() {
return {{NAME}}QuantizationParams;
})");
if (tensor.normalization_unit >= 0) {
code_writer->Append(R"(
public float[] get{{NAME_U}}Mean() {
return Arrays.copyOf({{NAME}}Mean, {{NAME}}Mean.length);
}
public float[] get{{NAME_U}}Stddev() {
return Arrays.copyOf({{NAME}}Stddev, {{NAME}}Stddev.length);
})");
}
}
for (const auto& tensor : model.outputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->Append(R"(
public int[] get{{NAME_U}}Shape() {
return Arrays.copyOf({{NAME}}Shape, {{NAME}}Shape.length);
}
public DataType get{{NAME_U}}Type() {
return {{NAME}}DataType;
}
public QuantizationParams get{{NAME_U}}QuantizationParams() {
return {{NAME}}QuantizationParams;
})");
if (tensor.normalization_unit >= 0) {
code_writer->Append(R"(
public float[] get{{NAME_U}}Mean() {
return Arrays.copyOf({{NAME}}Mean, {{NAME}}Mean.length);
}
public float[] get{{NAME_U}}Stddev() {
return Arrays.copyOf({{NAME}}Stddev, {{NAME}}Stddev.length);
})");
}
if (tensor.associated_axis_label_index >= 0 ||
tensor.associated_value_label_index >= 0) {
code_writer->Append(R"(
public List<String> get{{NAME_U}}Labels() {
return {{NAME}}Labels;
})");
}
}
return true;
}
bool GenerateWrapperAPI(CodeWriter* code_writer, const ModelInfo& model,
ErrorReporter* err) {
code_writer->Append(R"(public Metadata getMetadata() {
return metadata;
}
)");
code_writer->Append(R"(/**
* Creates interpreter and loads associated files if needed.
*
* @throws IOException if an I/O error occurs when loading the tflite model.
*/
public static {{MODEL_CLASS_NAME}} newInstance(Context context) throws IOException {
return newInstance(context, MODEL_NAME, new Model.Options.Builder().build());
}
/**
* Creates interpreter and loads associated files if needed, loading a different model with the
* same input / output structure as the original one.
*
* @throws IOException if an I/O error occurs when loading the tflite model.
*/
public static {{MODEL_CLASS_NAME}} newInstance(Context context, String modelPath) throws IOException {
return newInstance(context, modelPath, new Model.Options.Builder().build());
}
/**
* Creates interpreter and loads associated files if needed, with running options configured.
*
* @throws IOException if an I/O error occurs when loading the tflite model.
*/
public static {{MODEL_CLASS_NAME}} newInstance(Context context, Model.Options runningOptions) throws IOException {
return newInstance(context, MODEL_NAME, runningOptions);
}
/**
* Creates interpreter for a user-specified model.
*
* @throws IOException if an I/O error occurs when loading the tflite model.
*/
public static {{MODEL_CLASS_NAME}} newInstance(Context context, String modelPath, Model.Options runningOptions) throws IOException {
Model model = Model.createModel(context, modelPath, runningOptions);
Metadata metadata = new Metadata(model.getData(), model);
{{MODEL_CLASS_NAME}} instance = new {{MODEL_CLASS_NAME}}(model, metadata);)");
for (const auto& tensor : model.inputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->Append(
R"( instance.reset{{NAME_U}}Preprocessor(
instance.buildDefault{{NAME_U}}Preprocessor());)");
}
for (const auto& tensor : model.outputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->Append(
R"( instance.reset{{NAME_U}}Postprocessor(
instance.buildDefault{{NAME_U}}Postprocessor());)");
}
code_writer->Append(R"( return instance;
}
)");
// Pre, post processor setters
for (const auto& tensor : model.inputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->Append(R"(
public void reset{{NAME_U}}Preprocessor({{PROCESSOR_TYPE}} processor) {
{{NAME}}Preprocessor = processor;
})");
}
for (const auto& tensor : model.outputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->Append(R"(
public void reset{{NAME_U}}Postprocessor({{PROCESSOR_TYPE}} processor) {
{{NAME}}Postprocessor = processor;
})");
}
// Process method
code_writer->Append(R"(
/** Triggers the model. */
public Outputs process({{INPUT_TYPE_PARAM_LIST}}) {
Outputs outputs = new Outputs(metadata, {{POSTPROCESSORS_LIST}});
Object[] inputBuffers = preprocessInputs({{INPUTS_LIST}});
model.run(inputBuffers, outputs.getBuffer());
return outputs;
}
/** Closes the model. */
public void close() {
model.close();
}
)");
{
auto block =
AsBlock(code_writer,
"private {{MODEL_CLASS_NAME}}(Model model, Metadata metadata)");
code_writer->Append(R"(this.model = model;
this.metadata = metadata;)");
}
for (const auto& tensor : model.inputs) {
code_writer->NewLine();
SetCodeWriterWithTensorInfo(code_writer, tensor);
auto block = AsBlock(
code_writer,
"private {{PROCESSOR_TYPE}} buildDefault{{NAME_U}}Preprocessor()");
code_writer->Append(
"{{PROCESSOR_TYPE}}.Builder builder = new "
"{{PROCESSOR_TYPE}}.Builder()");
if (tensor.content_type == "image") {
code_writer->Append(R"( .add(new ResizeOp(
metadata.get{{NAME_U}}Shape()[1],
metadata.get{{NAME_U}}Shape()[2],
ResizeMethod.NEAREST_NEIGHBOR)))");
}
if (tensor.normalization_unit >= 0) {
code_writer->Append(
R"( .add(new NormalizeOp(metadata.get{{NAME_U}}Mean(), metadata.get{{NAME_U}}Stddev())))");
}
code_writer->Append(
R"( .add(new QuantizeOp(
metadata.get{{NAME_U}}QuantizationParams().getZeroPoint(),
metadata.get{{NAME_U}}QuantizationParams().getScale()))
.add(new CastOp(metadata.get{{NAME_U}}Type()));
return builder.build();)");
}
for (const auto& tensor : model.outputs) {
code_writer->NewLine();
SetCodeWriterWithTensorInfo(code_writer, tensor);
auto block = AsBlock(
code_writer,
"private {{PROCESSOR_TYPE}} buildDefault{{NAME_U}}Postprocessor()");
code_writer->AppendNoNewLine(
R"({{PROCESSOR_TYPE}}.Builder builder = new {{PROCESSOR_TYPE}}.Builder()
.add(new DequantizeOp(
metadata.get{{NAME_U}}QuantizationParams().getZeroPoint(),
metadata.get{{NAME_U}}QuantizationParams().getScale())))");
if (tensor.normalization_unit >= 0) {
code_writer->AppendNoNewLine(R"(
.add(new NormalizeOp(metadata.get{{NAME_U}}Mean(), metadata.get{{NAME_U}}Stddev())))");
}
code_writer->Append(R"(;
return builder.build();)");
}
code_writer->NewLine();
{
const auto block =
AsBlock(code_writer,
"private Object[] preprocessInputs({{INPUT_TYPE_PARAM_LIST}})");
CodeWriter param_list_gen(err);
for (const auto& tensor : model.inputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->Append("{{NAME}} = {{NAME}}Preprocessor.process({{NAME}});");
SetCodeWriterWithTensorInfo(&param_list_gen, tensor);
param_list_gen.AppendNoNewLine("{{NAME}}.getBuffer(), ");
}
// Remove the trailing ", ".
param_list_gen.Backspace(2);
code_writer->AppendNoNewLine("return new Object[] {");
code_writer->AppendNoNewLine(param_list_gen.ToString());
code_writer->Append("};");
}
return true;
}
bool GenerateBuildGradleContent(CodeWriter* code_writer,
const ModelInfo& model_info) {
code_writer->Append(R"(buildscript {
repositories {
google()
jcenter()
}
dependencies {
classpath 'com.android.tools.build:gradle:3.2.1'
}
}
allprojects {
repositories {
google()
jcenter()
flatDir {
dirs 'libs'
}
}
}
apply plugin: 'com.android.library'
android {
compileSdkVersion 29
defaultConfig {
targetSdkVersion 29
versionCode 1
versionName "1.0"
}
aaptOptions {
noCompress "tflite"
}
compileOptions {
sourceCompatibility = '1.8'
targetCompatibility = '1.8'
}
lintOptions {
abortOnError false
}
}
configurations {
libMetadata
}
dependencies {
libMetadata 'org.tensorflow:tensorflow-lite-support:0.0.0-experimental-metadata-monolithic'
}
task downloadLibs(type: Sync) {
from configurations.libMetadata
into "$buildDir/libs"
rename 'tensorflow-lite-support-0.0.0-experimental-metadata-monolithic.jar', "tensorflow-lite-support-metadata.jar"
}
preBuild.dependsOn downloadLibs
dependencies {
compileOnly 'org.checkerframework:checker-qual:2.5.8'
api 'org.tensorflow:tensorflow-lite:0.0.0-nightly'
api 'org.tensorflow:tensorflow-lite-support:0.0.0-nightly'
api files("$buildDir/libs/tensorflow-lite-support-metadata.jar")
implementation 'org.apache.commons:commons-compress:1.19'
})");
return true;
}
bool GenerateAndroidManifestContent(CodeWriter* code_writer,
const ModelInfo& model_info) {
code_writer->Append(R"(<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="{{PACKAGE}}">
</manifest>)");
return true;
}
bool GenerateDocContent(CodeWriter* code_writer, const ModelInfo& model_info) {
code_writer->Append("# {{MODEL_CLASS_NAME}} Usage");
// TODO(b/158651848) Generate imports for TFLS util types like TensorImage.
code_writer->AppendNoNewLine(R"(
```
import {{PACKAGE}}.{{MODEL_CLASS_NAME}};
// 1. Initialize the Model
{{MODEL_CLASS_NAME}} model = null;
try {
model = {{MODEL_CLASS_NAME}}.newInstance(context); // android.content.Context
} catch (IOException e) {
e.printStackTrace();
}
if (model != null) {
// 2. Set the inputs)");
for (const auto& t : model_info.inputs) {
SetCodeWriterWithTensorInfo(code_writer, t);
if (t.content_type == "image") {
code_writer->Append(R"(
// Prepare tensor "{{NAME}}" from a Bitmap with ARGB_8888 format.
Bitmap bitmap = ...;
TensorImage {{NAME}} = TensorImage.fromBitmap(bitmap);
// Alternatively, load the input tensor "{{NAME}}" from pixel values.
// Check out TensorImage documentation to load other image data structures.
// int[] pixelValues = ...;
// int[] shape = ...;
// TensorImage {{NAME}} = new TensorImage();
// {{NAME}}.load(pixelValues, shape);)");
} else {
code_writer->Append(R"(
// Prepare input tensor "{{NAME}}" from an array.
// Check out TensorBuffer documentation to load other data structures.
TensorBuffer {{NAME}} = ...;
int[] values = ...;
int[] shape = ...;
{{NAME}}.load(values, shape);)");
}
}
code_writer->Append(R"(
// 3. Run the model
{{MODEL_CLASS_NAME}}.Outputs outputs = model.process({{INPUTS_LIST}});)");
code_writer->Append(R"(
// 4. Retrieve the results)");
for (const auto& t : model_info.outputs) {
SetCodeWriterWithTensorInfo(code_writer, t);
if (t.associated_axis_label_index >= 0) {
code_writer->SetTokenValue("WRAPPER_TYPE", "List<Category>");
code_writer->Append(
" List<Category> {{NAME}} = "
"outputs.get{{NAME_U}}AsCategoryList();");
} else {
code_writer->Append(
" {{WRAPPER_TYPE}} {{NAME}} = "
"outputs.get{{NAME_U}}As{{WRAPPER_TYPE}}();");
}
}
code_writer->Append(R"(}
```)");
return true;
}
GenerationResult::File GenerateWrapperFile(const std::string& module_root,
const ModelInfo& model_info,
ErrorReporter* err) {
const auto java_path = JoinPath(module_root, "src/main/java");
const auto package_path =
JoinPath(java_path, ConvertPackageToPath(model_info.package_name));
const auto file_path =
JoinPath(package_path, model_info.model_class_name + JAVA_EXT);
CodeWriter code_writer(err);
code_writer.SetIndentString(" ");
SetCodeWriterWithModelInfo(&code_writer, model_info);
if (!GenerateWrapperFileContent(&code_writer, model_info, err)) {
err->Error("Generating Java wrapper content failed.");
}
const auto java_file = code_writer.ToString();
return GenerationResult::File{file_path, java_file};
}
GenerationResult::File GenerateBuildGradle(const std::string& module_root,
const ModelInfo& model_info,
ErrorReporter* err) {
const auto file_path = JoinPath(module_root, "build.gradle");
CodeWriter code_writer(err);
SetCodeWriterWithModelInfo(&code_writer, model_info);
if (!GenerateBuildGradleContent(&code_writer, model_info)) {
err->Error("Generating build.gradle failed.");
}
const auto content = code_writer.ToString();
return GenerationResult::File{file_path, content};
}
GenerationResult::File GenerateAndroidManifest(const std::string& module_root,
const ModelInfo& model_info,
ErrorReporter* err) {
const auto file_path = JoinPath(module_root, "src/main/AndroidManifest.xml");
CodeWriter code_writer(err);
SetCodeWriterWithModelInfo(&code_writer, model_info);
if (!GenerateAndroidManifestContent(&code_writer, model_info)) {
err->Error("Generating AndroidManifest.xml failed.");
}
return GenerationResult::File{file_path, code_writer.ToString()};
}
GenerationResult::File GenerateDoc(const std::string& module_root,
const ModelInfo& model_info,
ErrorReporter* err) {
std::string lower = model_info.model_class_name;
for (int i = 0; i < lower.length(); i++) {
lower[i] = std::tolower(lower[i]);
}
const auto file_path = JoinPath(module_root, lower + ".md");
CodeWriter code_writer(err);
SetCodeWriterWithModelInfo(&code_writer, model_info);
if (!GenerateDocContent(&code_writer, model_info)) {
err->Error("Generating doc failed.");
}
return GenerationResult::File{file_path, code_writer.ToString()};
}
} // namespace
AndroidJavaGenerator::AndroidJavaGenerator(const std::string& module_root)
: CodeGenerator(), module_root_(module_root) {}
GenerationResult AndroidJavaGenerator::Generate(
const Model* model, const std::string& package_name,
const std::string& model_class_name, const std::string& model_asset_path) {
GenerationResult result;
if (model == nullptr) {
err_.Error(
"Cannot read model from the buffer. Codegen will generate nothing.");
return result;
}
const ModelMetadata* metadata = GetMetadataFromModel(model);
if (metadata == nullptr) {
err_.Error(
"Cannot find TFLite Metadata in the model. Codegen will generate "
"nothing.");
return result;
}
details_android_java::ModelInfo model_info = CreateModelInfo(
metadata, package_name, model_class_name, model_asset_path, &err_);
result.files.push_back(GenerateWrapperFile(module_root_, model_info, &err_));
result.files.push_back(GenerateBuildGradle(module_root_, model_info, &err_));
result.files.push_back(
GenerateAndroidManifest(module_root_, model_info, &err_));
result.files.push_back(GenerateDoc(module_root_, model_info, &err_));
return result;
}
GenerationResult AndroidJavaGenerator::Generate(
const char* model_storage, const std::string& package_name,
const std::string& model_class_name, const std::string& model_asset_path) {
const Model* model = GetModel(model_storage);
return Generate(model, package_name, model_class_name, model_asset_path);
}
std::string AndroidJavaGenerator::GetErrorMessage() {
return err_.GetMessage();
}
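// A minimal usage sketch of the generator (illustrative only; `model_buffer`
// is assumed to point at the bytes of a metadata-populated .tflite
// flatbuffer):
//
//   AndroidJavaGenerator generator("generated_module");
//   GenerationResult result = generator.Generate(
//       model_buffer, "com.example.app", "MyClassifier", "my_model.tflite");
//   // result.files holds the wrapper .java file, build.gradle,
//   // AndroidManifest.xml and a usage doc; GetErrorMessage() reports any
//   // warnings or errors.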
} // namespace codegen
} // namespace support
} // namespace tflite