// blob: a35a6db9ad48f657c8700f135197d0b5d465dd73 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <limits>
#include <type_traits>
#include <vector>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/model.h"
namespace tflite {
namespace ops {
namespace custom {
namespace ragged {
namespace ragged_range {
namespace {
// Positions of the op's input tensors within node->inputs->data.
constexpr int kInputStarts = 0;
constexpr int kInputLimits = 1;
constexpr int kInputDeltas = 2;
// Positions of the op's output tensors within node->outputs->data.
constexpr int kOutputNestedSplits = 0;
constexpr int kOutputDenseValues = 1;
// Wraps a single integer in a one-element TfLiteIntArray obtained from
// TfLiteIntArrayCreate; used here to describe the shape of a 1-D tensor.
TfLiteIntArray* IntArrayFromInt(int x) {
  TfLiteIntArray* const single = TfLiteIntArrayCreate(1);
  single->data[0] = x;
  return single;
}
// Integral implementation of RangeSize (selected by tag dispatch).
// Works entirely in the unsigned counterpart of T so that neither
// `limit - start` nor the magnitude of `delta` can overflow, and so that the
// count is never routed through a floating-point type.
template <typename T, typename SPLITS_TYPE>
SPLITS_TYPE RangeSizeImpl(T start, T limit, T delta, std::true_type) {
  using Unsigned = typename std::make_unsigned<T>::type;
  // |limit - start|, computed modulo 2^N, which is exact for any two T values.
  const Unsigned span =
      limit > start
          ? static_cast<Unsigned>(limit) - static_cast<Unsigned>(start)
          : static_cast<Unsigned>(start) - static_cast<Unsigned>(limit);
  // |delta|, including the most negative value of T.
  const Unsigned step = delta > 0
                            ? static_cast<Unsigned>(delta)
                            : Unsigned{0} - static_cast<Unsigned>(delta);
  // ceil(span / step) without forming `span + step - 1`, which could wrap.
  const Unsigned count = span / step + (span % step != 0 ? 1 : 0);
  const SPLITS_TYPE max_size = std::numeric_limits<SPLITS_TYPE>::max();
  if (static_cast<unsigned long long>(count) >
      static_cast<unsigned long long>(max_size)) {
    return max_size;
  }
  return static_cast<SPLITS_TYPE>(count);
}

// Floating-point implementation of RangeSize (selected by tag dispatch).
template <typename T, typename SPLITS_TYPE>
SPLITS_TYPE RangeSizeImpl(T start, T limit, T delta, std::false_type) {
  const T count = std::ceil(std::abs((limit - start) / delta));
  const SPLITS_TYPE max_size = std::numeric_limits<SPLITS_TYPE>::max();
  // Converting an out-of-range (or NaN) floating value to an integer is
  // undefined, so saturate instead. The comparison is written so that NaN
  // also takes the saturating branch; `max_size` converted to T may round up,
  // hence the strict `<`.
  if (!(count < static_cast<T>(max_size))) {
    return max_size;
  }
  return static_cast<SPLITS_TYPE>(count);
}

// Returns the number of elements in the range that begins at `start`, advances
// by `delta`, and stops before reaching `limit`.
//
// Returns 0 when `delta` points away from `limit` and when `delta` is zero
// (no valid range exists). A count that does not fit in SPLITS_TYPE saturates
// to the largest SPLITS_TYPE value.
template <typename T, typename SPLITS_TYPE>
SPLITS_TYPE RangeSize(T start, T limit, T delta) {
  if (delta == 0) {
    return 0;
  }
  if (((delta > 0) && (limit < start)) || ((delta < 0) && (limit > start))) {
    return 0;
  }
  return RangeSizeImpl<T, SPLITS_TYPE>(start, limit, delta,
                                       std::is_integral<T>());
}
// Evaluates the op for value type `T` and row-splits type `SPLITS_TYPE`.
//
// Reads the starts, limits and deltas inputs. An input holding exactly one
// element is broadcast to every row; any other input must provide one element
// per row. Two outputs are resized and filled:
//   * nested splits: `nrows + 1` offsets, row `r` occupying
//     [splits[r], splits[r + 1]) of the dense values;
//   * dense values: every row's range, concatenated.
//
// Returns kTfLiteError (after reporting through `context`) when the
// non-broadcast inputs disagree in size or do not hold one element per row,
// when a delta is zero, or when the total number of values cannot be
// represented as a tensor dimension.
template <typename T, typename SPLITS_TYPE>
TfLiteStatus EvalT(TfLiteContext* context, TfLiteNode* node) {
  TfLiteTensor& input_starts =
      context->tensors[node->inputs->data[kInputStarts]];
  TfLiteTensor& input_limits =
      context->tensors[node->inputs->data[kInputLimits]];
  TfLiteTensor& input_deltas =
      context->tensors[node->inputs->data[kInputDeltas]];
  // Determine which tensors we need to broadcast.
  const bool broadcast_starts = NumElements(&input_starts) == 1;
  const bool broadcast_limits = NumElements(&input_limits) == 1;
  const bool broadcast_deltas = NumElements(&input_deltas) == 1;
  // nrows (number of output rows) is the size of the non-broadcast inputs,
  // or 1 if all inputs are scalars.
  std::vector<int> in_sizes;
  const TfLiteTensor* const row_inputs[] = {&input_starts, &input_limits,
                                            &input_deltas};
  for (const TfLiteTensor* input : row_inputs) {
    if (NumElements(input) == 1) continue;  // Broadcast to every row.
    // The leading dimension is used as the row count below, so the tensor
    // must really hold that many elements; otherwise indexing by row would
    // read past the end of its buffer (e.g. for shape [3, 0]).
    if (input->dims->size < 1 ||
        NumElements(input) != input->dims->data[0]) {
      context->ReportError(
          context,
          "Invalid argument: starts, limits, and deltas must be scalars or "
          "vectors");
      return kTfLiteError;
    }
    in_sizes.push_back(input->dims->data[0]);
  }
  if (std::adjacent_find(std::begin(in_sizes), std::end(in_sizes),
                         std::not_equal_to<>()) != std::end(in_sizes)) {
    context->ReportError(
        context,
        "Invalid argument: starts, limits, and deltas must have the "
        "same shape");
    return kTfLiteError;
  }
  const SPLITS_TYPE nrows = in_sizes.empty() ? 1 : in_sizes.front();
  const T* starts = GetTensorData<T>(&input_starts);
  const T* limits = GetTensorData<T>(&input_limits);
  const T* deltas = GetTensorData<T>(&input_deltas);

  // First pass: compute the row splits.
  TfLiteTensor& rt_nested_splits_out =
      context->tensors[node->outputs->data[kOutputNestedSplits]];
  TF_LITE_ENSURE_OK(
      context, context->ResizeTensor(context, &rt_nested_splits_out,
                                     IntArrayFromInt(static_cast<int>(nrows + 1))));
  SPLITS_TYPE* rt_nested_splits =
      GetTensorData<SPLITS_TYPE>(&rt_nested_splits_out);
  // The dense values tensor is sized through IntArrayFromInt(int), so the
  // running total must stay representable as a non-negative `int`.
  const SPLITS_TYPE max_values =
      static_cast<SPLITS_TYPE>(std::numeric_limits<int>::max());
  rt_nested_splits[0] = 0;
  for (int row = 0; row < nrows; ++row) {
    const T start = broadcast_starts ? starts[0] : starts[row];
    const T limit = broadcast_limits ? limits[0] : limits[row];
    const T delta = broadcast_deltas ? deltas[0] : deltas[row];
    if (delta == 0) {
      context->ReportError(context, "Invalid argument: Requires delta != 0");
      return kTfLiteError;
    }
    const SPLITS_TYPE row_size =
        RangeSize<T, SPLITS_TYPE>(start, limit, delta);
    // Reject a wrapped (negative) size and any total that would overflow;
    // the subtraction cannot overflow because every split is in
    // [0, max_values].
    if (row_size < 0 || row_size > max_values - rt_nested_splits[row]) {
      context->ReportError(
          context, "Invalid argument: RaggedRange output is too large");
      return kTfLiteError;
    }
    rt_nested_splits[row + 1] = rt_nested_splits[row] + row_size;
  }

  // Second pass: materialize the values of every row.
  const SPLITS_TYPE nvals = rt_nested_splits[nrows];
  TfLiteTensor& rt_dense_values_out =
      context->tensors[node->outputs->data[kOutputDenseValues]];
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, &rt_dense_values_out,
                                          IntArrayFromInt(static_cast<int>(nvals))));
  T* rt_dense_values = GetTensorData<T>(&rt_dense_values_out);
  int value_index = 0;
  for (int row = 0; row < nrows; ++row) {
    const SPLITS_TYPE row_size =
        rt_nested_splits[row + 1] - rt_nested_splits[row];
    T value = broadcast_starts ? starts[0] : starts[row];
    const T delta = broadcast_deltas ? deltas[0] : deltas[row];
    for (SPLITS_TYPE i = 0; i < row_size; ++i) {
      rt_dense_values[value_index++] = value;
      // Only step between elements: advancing past the last one could
      // overflow a signed T even though that value is never stored.
      if (i + 1 < row_size) value += delta;
    }
  }
  return kTfLiteOk;
}
// Selects the EvalT instantiation matching the element type of the dense
// values output, for an already chosen row-splits type. Reports an error and
// returns kTfLiteError for any other element type.
template <typename SPLITS_TYPE>
TfLiteStatus EvalSplitsT(TfLiteContext* context, TfLiteNode* node) {
  const int values_slot = node->outputs->data[kOutputDenseValues];
  const auto values_type = context->tensors[values_slot].type;
  if (values_type == kTfLiteInt32) {
    return EvalT<int32_t, SPLITS_TYPE>(context, node);
  }
  if (values_type == kTfLiteInt64) {
    return EvalT<int64_t, SPLITS_TYPE>(context, node);
  }
  if (values_type == kTfLiteFloat32) {
    return EvalT<float, SPLITS_TYPE>(context, node);
  }
  if (values_type == kTfLiteFloat64) {
    return EvalT<double, SPLITS_TYPE>(context, node);
  }
  context->ReportError(context,
                       "Invalid argument: Not supported VALUES type");
  return kTfLiteError;
}
} // namespace
// Prepare callback: validates the node's arity and marks both outputs as
// dynamic, since their sizes are only determined while evaluating the inputs.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // The kernel indexes node->inputs / node->outputs at fixed positions, so
  // make sure those positions exist before anything dereferences them.
  TF_LITE_ENSURE(context, NumInputs(node) > kInputDeltas);
  TF_LITE_ENSURE(context, NumOutputs(node) > kOutputDenseValues);
  // Set outputs dynamic.
  TfLiteTensor& nested_splits =
      context->tensors[node->outputs->data[kOutputNestedSplits]];
  SetTensorToDynamic(&nested_splits);
  TfLiteTensor& dense_values =
      context->tensors[node->outputs->data[kOutputDenseValues]];
  SetTensorToDynamic(&dense_values);
  return kTfLiteOk;
}
// Eval callback: picks the row-splits integer width from the element type of
// the nested splits output and forwards to EvalSplitsT. Reports an error and
// returns kTfLiteError for any other element type.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const int splits_slot = node->outputs->data[kOutputNestedSplits];
  const auto splits_type = context->tensors[splits_slot].type;
  if (splits_type == kTfLiteInt32) {
    return EvalSplitsT<int32_t>(context, node);
  }
  if (splits_type == kTfLiteInt64) {
    return EvalSplitsT<int64_t>(context, node);
  }
  context->ReportError(context,
                       "Invalid argument: Not supported ROW_SPLITS type");
  return kTfLiteError;
}
} // namespace ragged_range
} // namespace ragged
// Returns the registration record for this op. The record is a function-local
// static, so every call yields the same pointer. No Initialize/Free hooks are
// installed; only Prepare and Eval are provided.
TfLiteRegistration* Register_RAGGED_RANGE() {
  static TfLiteRegistration registration = {
      nullptr /*Initialize*/,
      nullptr /*Free*/,
      ragged::ragged_range::Prepare,
      ragged::ragged_range::Eval,
  };
  return &registration;
}
} // namespace custom
} // namespace ops
} // namespace tflite