Add micro CAST kernel to the build
diff --git a/tensorflow/lite/micro/kernels/BUILD b/tensorflow/lite/micro/kernels/BUILD
index b40e92e..79eaa14 100644
--- a/tensorflow/lite/micro/kernels/BUILD
+++ b/tensorflow/lite/micro/kernels/BUILD
@@ -107,6 +107,7 @@
"hard_swish.cc",
"add.cc",
"arg_min_max.cc",
+ "cast.cc",
"ceil.cc",
"circular_buffer.cc",
"comparisons.cc",
diff --git a/tensorflow/lite/micro/kernels/cast.cc b/tensorflow/lite/micro/kernels/cast.cc
index a7f9e68..02fbd47 100644
--- a/tensorflow/lite/micro/kernels/cast.cc
+++ b/tensorflow/lite/micro/kernels/cast.cc
@@ -12,15 +12,12 @@
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <algorithm>
#include <complex>
#include "tensorflow/lite/c/common.h"
-#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
-#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
-#include "tensorflow/lite/kernels/op_macros.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
namespace tflite {
namespace {
@@ -35,13 +32,7 @@
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
- TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
- TF_LITE_ENSURE_TYPES_EQ(context, output->type, input->type);
- TF_LITE_ENSURE_EQ(context, output->bytes, input->bytes);
- TF_LITE_ENSURE_EQ(context, output->dims->size, input->dims->size);
- for (int i = 0; i < output->dims->size; ++i) {
- TF_LITE_ENSURE_EQ(context, output->dims->data[i], input->dims->data[i]);
- }
+
return kTfLiteOk;
}
@@ -69,60 +60,45 @@
TfLiteStatus copyToTensor(TfLiteContext* context, const FromT* in,
- TfLiteTensor* out, int num_elements) {
+ TfLiteEvalTensor* out, int num_elements) {
switch (out->type) {
- case kTfLiteInt64:
- copyCast(in, out->data.i64, num_elements);
- break;
- case kTfLiteInt32:
- copyCast(in, out->data.i32, num_elements);
- break;
- case kTfLiteUInt8:
- copyCast(in, out->data.uint8, num_elements);
+ case kTfLiteInt8:
+ copyCast(in, out->data.int8, num_elements);
break;
case kTfLiteFloat32:
- copyCast(in, GetTensorData<float>(out), num_elements);
+ copyCast(in, tflite::micro::GetTensorData<float>(out), num_elements);
break;
- case kTfLiteBool:
- copyCast(in, out->data.b, num_elements);
- break;
case kTfLiteComplex64:
copyCast(in, reinterpret_cast<std::complex<float>*>(out->data.c64),
num_elements);
break;
default:
// Unsupported type.
- TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
+ TF_LITE_KERNEL_LOG(context, "Output type %s (%d) not supported.",
TfLiteTypeGetName(out->type), out->type);
}
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- const TfLiteTensor* input;
- TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
- TfLiteTensor* output;
- TF_LITE_ENSURE_OK(context,
- GetOutputSafe(context, node, kOutputTensor, &output));
- const int num_elements = NumElements(input);
- TF_LITE_ENSURE_EQ(context, num_elements, NumElements(output));
+ const TfLiteEvalTensor* input =
+ tflite::micro::GetEvalInput(context, node, kInputTensor);
+ TfLiteEvalTensor* output =
+ tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+ int num_elements = MatchingFlatSize(tflite::micro::GetTensorShape(input),
+ tflite::micro::GetTensorShape(output));
+
switch (input->type) {
- case kTfLiteInt64:
- return copyToTensor(context, input->data.i64, output, num_elements);
- case kTfLiteInt32:
- return copyToTensor(context, input->data.i32, output, num_elements);
- case kTfLiteUInt8:
- return copyToTensor(context, input->data.uint8, output, num_elements);
+ case kTfLiteInt8:
+ return copyToTensor(context, input->data.int8, output, num_elements);
case kTfLiteFloat32:
- return copyToTensor(context, GetTensorData<float>(input), output,
- num_elements);
+ return copyToTensor(context, tflite::micro::GetTensorData<float>(input),
+ output, num_elements);
- case kTfLiteBool:
- return copyToTensor(context, input->data.b, output, num_elements);
case kTfLiteComplex64:
return copyToTensor(
context, reinterpret_cast<std::complex<float>*>(input->data.c64),
output, num_elements);
default:
// Unsupported type.
- TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
+ TF_LITE_KERNEL_LOG(context, "Input type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
}
return kTfLiteOk;
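For context, the hunks above only cover the regions of cast.cc that change; the patch still relies on pieces of the file the diff does not show. Below is a minimal sketch of those pieces, reconstructed from the calls visible above (the copyCast helpers and the Register_CAST hook declared in micro_ops.h further down). The exact bodies are assumptions, not part of this patch, and they are written to sit inside cast.cc with the includes and namespaces from the first hunk.

// Sketch (assumed): per-element conversion helpers inside the anonymous
// namespace. The generic version converts with static_cast; complex sources
// need their own overload plus a complex64 -> complex64 specialization.
template <typename FromT, typename ToT>
void copyCast(const FromT* in, ToT* out, int num_elements) {
  for (int i = 0; i < num_elements; ++i) {
    out[i] = static_cast<ToT>(in[i]);
  }
}

template <typename ToT>
void copyCast(const std::complex<float>* in, ToT* out, int num_elements) {
  for (int i = 0; i < num_elements; ++i) {
    out[i] = static_cast<ToT>(std::abs(in[i]));  // magnitude for real targets
  }
}

template <>
void copyCast(const std::complex<float>* in, std::complex<float>* out,
              int num_elements) {
  for (int i = 0; i < num_elements; ++i) {
    out[i] = in[i];  // identity copy
  }
}

// Sketch (assumed): registration hook, defined after the anonymous namespace
// closes but still inside namespace tflite, wiring Prepare and Eval together.
TfLiteRegistration Register_CAST() {
  return {/*init=*/nullptr,
          /*free=*/nullptr,
          /*prepare=*/Prepare,
          /*invoke=*/Eval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}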
diff --git a/tensorflow/lite/micro/kernels/cast_test.cc b/tensorflow/lite/micro/kernels/cast_test.cc
index 2d83ee7..b9ffd29 100644
--- a/tensorflow/lite/micro/kernels/cast_test.cc
+++ b/tensorflow/lite/micro/kernels/cast_test.cc
@@ -75,7 +75,7 @@
- const float golden[] = {
+ const std::complex<float> golden[] = {
std::complex<float>(1.0f, 11.0f), std::complex<float>(2.0f, 12.0f),
std::complex<float>(3.0f, 13.0f), std::complex<float>(4.0f, 14.0f),
std::complex<float>(5.0f, 15.0f),
std::complex<float>(6.0f, 16.0f)};
tflite::testing::TestCastComplex64ToComplex64(input_dims, input_values, golden, output_data);
}
diff --git a/tensorflow/lite/micro/kernels/micro_ops.h b/tensorflow/lite/micro/kernels/micro_ops.h
index a93c827..165430a 100644
--- a/tensorflow/lite/micro/kernels/micro_ops.h
+++ b/tensorflow/lite/micro/kernels/micro_ops.h
@@ -31,6 +31,7 @@
// (https://abseil.io/tips/130). Any new ops (or cleanup of existing ops) should
// have their Register function declarations in the tflite namespace.
+TfLiteRegistration Register_CAST();
TfLiteRegistration Register_CONV_2D();
TfLiteRegistration Register_DEPTHWISE_CONV_2D();
TfLiteRegistration Register_EXP();
diff --git a/tensorflow/lite/micro/micro_mutable_op_resolver.h b/tensorflow/lite/micro/micro_mutable_op_resolver.h
index c850a38..c59fb3c 100644
--- a/tensorflow/lite/micro/micro_mutable_op_resolver.h
+++ b/tensorflow/lite/micro/micro_mutable_op_resolver.h
@@ -138,6 +138,10 @@
ParsePool);
}
+ TfLiteStatus AddCast() {
+ return AddBuiltin(BuiltinOperator_CAST, Register_CAST(), ParseCast);
+ }
+
TfLiteStatus AddCeil() {
return AddBuiltin(BuiltinOperator_CEIL, tflite::ops::micro::Register_CEIL(),
ParseCeil);
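With the AddCast() entry point above, an application registers CAST the same way as any other builtin before constructing the interpreter. A minimal usage sketch follows; the resolver size, the second op, and the function name are illustrative assumptions rather than part of this patch.

#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"

// Sketch only: the template argument is the number of ops the model needs.
TfLiteStatus RegisterOps(tflite::MicroMutableOpResolver<2>& resolver) {
  if (resolver.AddCast() != kTfLiteOk) return kTfLiteError;  // new in this patch
  if (resolver.AddFullyConnected() != kTfLiteOk) return kTfLiteError;
  return kTfLiteOk;
}

The populated resolver is then handed to the MicroInterpreter as usual; no other application-side changes are needed.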
diff --git a/tensorflow/lite/micro/tools/make/Makefile b/tensorflow/lite/micro/tools/make/Makefile
index e0009ad..4809574 100644
--- a/tensorflow/lite/micro/tools/make/Makefile
+++ b/tensorflow/lite/micro/tools/make/Makefile
@@ -259,6 +259,7 @@
tensorflow/lite/micro/kernels/activations_test.cc \
tensorflow/lite/micro/kernels/add_test.cc \
tensorflow/lite/micro/kernels/arg_min_max_test.cc \
+tensorflow/lite/micro/kernels/cast_test.cc \
tensorflow/lite/micro/kernels/ceil_test.cc \
tensorflow/lite/micro/kernels/circular_buffer_test.cc \
tensorflow/lite/micro/kernels/comparisons_test.cc \
@@ -304,6 +305,7 @@
tensorflow/lite/micro/kernels/activations.cc \
tensorflow/lite/micro/kernels/add.cc \
tensorflow/lite/micro/kernels/arg_min_max.cc \
+tensorflow/lite/micro/kernels/cast.cc \
tensorflow/lite/micro/kernels/ceil.cc \
tensorflow/lite/micro/kernels/circular_buffer.cc \
tensorflow/lite/micro/kernels/comparisons.cc \