blob: acf9780a5bfcc28ca79415edd4eb96e287217bc0 [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/concatenation.h"
#include <stdint.h>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
namespace tflite {
namespace ops {
namespace builtin {
namespace concatenation {
// This file has two implementation of Concatenation.
enum KernelType {
kReference,
kGenericOptimized,
};
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
int axis = params->axis;
int num_inputs = node->inputs->size;
// The number of dimensions of the input tensors must match, and all
// dimensions except 'axis' must be equal.
const TfLiteTensor* t0 = GetInput(context, node, 0);
TfLiteType input_type = t0->type;
if (axis < 0) axis += t0->dims->size;
TF_LITE_ENSURE(context, axis >= 0);
TF_LITE_ENSURE(context, axis < t0->dims->size);
// TODO(ahentz): These are limitations of our implementation that could be
// removed with a bit of effort.
TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone);
TF_LITE_ENSURE(context,
input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
input_type == kTfLiteInt8 || input_type == kTfLiteInt16 ||
input_type == kTfLiteInt32 || input_type == kTfLiteInt64);
// Output dimensions will match input dimensions, except 'axis', which
// will be the sum of inputs
int sum_axis = t0->dims->data[axis];
for (int i = 1; i < num_inputs; ++i) {
const TfLiteTensor* t = GetInput(context, node, i);
TF_LITE_ENSURE_EQ(context, t->dims->size, t0->dims->size);
TF_LITE_ENSURE_EQ(context, t->type, input_type);
for (int d = 0; d < t0->dims->size; ++d) {
if (d == axis) {
sum_axis += t->dims->data[axis];
} else {
TF_LITE_ENSURE_EQ(context, t->dims->data[d], t0->dims->data[d]);
}
}
}
TfLiteIntArray* output_size = TfLiteIntArrayCreate(t0->dims->size);
for (int d = 0; d < t0->dims->size; ++d) {
output_size->data[d] = (d == axis) ? sum_axis : t0->dims->data[d];
}
TfLiteTensor* output = GetOutput(context, node, 0);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, input_type);
if (input_type == kTfLiteInt8) {
// Make sure there is no re-scaling needed for Int8 quantized kernel. This
// is a restriction we introduced to Int8 kernels.
VectorOfTensors<int8_t> all_inputs(*context, *node->inputs);
for (int i = 0; i < node->inputs->size; ++i) {
const TfLiteTensor* t = GetInput(context, node, i);
TF_LITE_ENSURE_EQ(context, t->params.scale, output->params.scale);
TF_LITE_ENSURE_EQ(context, t->params.zero_point,
output->params.zero_point);
}
}
if (input_type == kTfLiteInt16) {
// Make sure there all Int16 inputs have a null zero-point.
for (int i = 0; i < node->inputs->size; ++i) {
const TfLiteTensor* t = GetInput(context, node, i);
TF_LITE_ENSURE_EQ(context, t->params.zero_point, 0);
}
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
}
return context->ResizeTensor(context, output, output_size);
}
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
int axis = params->axis;
TfLiteTensor* output = GetOutput(context, node, 0);
if (axis < 0) axis += output->dims->size;
// TODO(ahentz): Creating 'all_inputs' below is not very efficient. We should
// allocate and populate these during Prepare().
// TODO(ycling): Activation function parameter is ignored. For now we dont have
// a model with a Concatenation with fused activation function.
#define TF_LITE_CONCATENATION(scalar) \
{ \
VectorOfTensors<scalar> all_inputs(*context, *node->inputs); \
tflite::ConcatenationParams op_params; \
op_params.axis = axis; \
op_params.inputs_count = node->inputs->size; \
if (kernel_type == kReference) { \
reference_ops::Concatenation(op_params, all_inputs.shapes(), \
all_inputs.data(), GetTensorShape(output), \
GetTensorData<scalar>(output)); \
} else { \
optimized_ops::Concatenation(op_params, all_inputs.shapes(), \
all_inputs.data(), GetTensorShape(output), \
GetTensorData<scalar>(output)); \
} \
}
#define TF_LITE_CONCATENATION_QUANTIZED() \
{ \
VectorOfQuantizedTensors all_inputs(*context, *node->inputs); \
tflite::ConcatenationParams op_params; \
op_params.axis = axis; \
op_params.input_zeropoint = all_inputs.zero_point(); \
op_params.input_scale = all_inputs.scale(); \
op_params.inputs_count = node->inputs->size; \
op_params.output_zeropoint = output->params.zero_point; \
op_params.output_scale = output->params.scale; \
if (kernel_type == kReference) { \
reference_ops::ConcatenationWithScaling( \
op_params, all_inputs.shapes(), all_inputs.data(), \
GetTensorShape(output), GetTensorData<uint8>(output)); \
} else { \
optimized_ops::ConcatenationWithScaling( \
op_params, all_inputs.shapes(), all_inputs.data(), \
GetTensorShape(output), GetTensorData<uint8>(output)); \
} \
}
switch (output->type) { // Already know in/outtypes are same.
case kTfLiteFloat32:
TF_LITE_CONCATENATION(float);
break;
case kTfLiteInt32:
TF_LITE_CONCATENATION(int32);
break;
case kTfLiteUInt8:
TF_LITE_CONCATENATION_QUANTIZED();
break;
case kTfLiteInt8:
TF_LITE_CONCATENATION(int8_t);
break;
case kTfLiteInt64:
TF_LITE_CONCATENATION(int64_t);
break;
case kTfLiteInt16:
TF_LITE_CONCATENATION(int16_t);
break;
default:
context->ReportError(context, "Type '%s' is not supported currently.",
TfLiteTypeGetName(output->type));
return kTfLiteError;
}
#undef TF_LITE_CONCATENATION_QUANTIZED
#undef TF_LITE_CONCATENATION
return kTfLiteOk;
}
#undef TF_LITE_MACRO_DISPATCH
} // namespace concatenation
TfLiteRegistration* Register_CONCATENATION_REF() {
static TfLiteRegistration r = {
nullptr, nullptr, concatenation::Prepare,
concatenation::Eval<concatenation::kReference>};
return &r;
}
TfLiteRegistration* Register_CONCATENATION_GENERIC_OPT() {
static TfLiteRegistration r = {
nullptr, nullptr, concatenation::Prepare,
concatenation::Eval<concatenation::kGenericOptimized>};
return &r;
}
TfLiteRegistration* Register_CONCATENATION() {
// TODO(ahentz): It turns out the two versions of Concatenation are almost
// identical, so we should consider removing one.
return Register_CONCATENATION_GENERIC_OPT();
}
} // namespace builtin
} // namespace ops
} // namespace tflite