| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" |
| |
| #include "mlir/IR/Builders.h" // from @llvm-project |
| #include "mlir/IR/BuiltinTypes.h" // from @llvm-project |
| #include "mlir/IR/Types.h" // from @llvm-project |
| #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" |
| #include "tensorflow/compiler/xla/statusor.h" |
| #include "tensorflow/core/framework/types.pb.h" |
| #include "tensorflow/core/platform/errors.h" |
| #include "tensorflow/lite/schema/schema_generated.h" |
| |
| namespace tflite { |
| |
| using xla::StatusOr; |
| |
| namespace errors = tensorflow::errors; |
| |
| mlir::Type ConvertElementType(tflite::TensorType type, mlir::Builder builder) { |
| switch (type) { |
| case tflite::TensorType_FLOAT16: |
| return builder.getF16Type(); |
| case tflite::TensorType_FLOAT32: |
| return builder.getF32Type(); |
| case tflite::TensorType_FLOAT64: |
| return builder.getF64Type(); |
| case tflite::TensorType_INT32: |
| return builder.getIntegerType(32); |
| case tflite::TensorType_UINT8: |
| return builder.getIntegerType(8, /*isSigned=*/false); |
| case tflite::TensorType_INT64: |
| return builder.getIntegerType(64); |
| case tflite::TensorType_STRING: |
| return mlir::TF::StringType::get(builder.getContext()); |
| case tflite::TensorType_BOOL: |
| return builder.getI1Type(); |
| case tflite::TensorType_INT16: |
| return builder.getIntegerType(16); |
| case tflite::TensorType_COMPLEX64: |
| return mlir::ComplexType::get(builder.getF32Type()); |
| case tflite::TensorType_COMPLEX128: |
| return mlir::ComplexType::get(builder.getF64Type()); |
| case tflite::TensorType_INT8: |
| return builder.getIntegerType(8); |
| case tflite::TensorType_UINT64: |
| return builder.getIntegerType(64, /*isSigned=*/false); |
| } |
| } |
| |
| tensorflow::DataType TflTypeToTfType(tflite::TensorType type) { |
| switch (type) { |
| case tflite::TensorType_BOOL: |
| return tensorflow::DT_BOOL; |
| case tflite::TensorType_COMPLEX64: |
| return tensorflow::DT_COMPLEX64; |
| case tflite::TensorType_COMPLEX128: |
| return tensorflow::DT_COMPLEX128; |
| case tflite::TensorType_FLOAT16: |
| return tensorflow::DT_HALF; |
| case tflite::TensorType_FLOAT32: |
| return tensorflow::DT_FLOAT; |
| case tflite::TensorType_FLOAT64: |
| return tensorflow::DT_DOUBLE; |
| case tflite::TensorType_INT8: |
| return tensorflow::DT_INT8; |
| case tflite::TensorType_INT16: |
| return tensorflow::DT_INT16; |
| case tflite::TensorType_INT32: |
| return tensorflow::DT_INT32; |
| case tflite::TensorType_INT64: |
| return tensorflow::DT_INT64; |
| case tflite::TensorType_STRING: |
| return tensorflow::DT_STRING; |
| case tflite::TensorType_UINT8: |
| return tensorflow::DT_UINT8; |
| case tflite::TensorType_UINT64: |
| return tensorflow::DT_UINT64; |
| } |
| } |
| |
| StatusOr<tflite::TensorType> TfTypeToTflType(tensorflow::DataType type) { |
| switch (type) { |
| case tensorflow::DT_BOOL: |
| return tflite::TensorType_BOOL; |
| case tensorflow::DT_COMPLEX64: |
| return tflite::TensorType_COMPLEX64; |
| case tensorflow::DT_HALF: |
| return tflite::TensorType_FLOAT16; |
| case tensorflow::DT_FLOAT: |
| return tflite::TensorType_FLOAT32; |
| case tensorflow::DT_INT8: |
| return tflite::TensorType_INT8; |
| case tensorflow::DT_INT16: |
| return tflite::TensorType_INT16; |
| case tensorflow::DT_INT32: |
| return tflite::TensorType_INT32; |
| case tensorflow::DT_INT64: |
| return tflite::TensorType_INT64; |
| case tensorflow::DT_STRING: |
| return tflite::TensorType_STRING; |
| case tensorflow::DT_UINT8: |
| return tflite::TensorType_UINT8; |
| default: |
| return errors::InvalidArgument("unsupported tensor data type", type); |
| } |
| } |
| |
| mlir::Type GetShapeStrippedType(mlir::TypeAttr type_attr) { |
| auto type = type_attr.getValue(); |
| auto shaped_type = type.dyn_cast<mlir::ShapedType>(); |
| if (shaped_type) { |
| return shaped_type.getElementType(); |
| } else { |
| return type; |
| } |
| } |
| |
| bool NotFromQuantOpOrSameQuantType(mlir::Value val, mlir::TypeAttr qtype_attr) { |
| auto val_defn_op = val.getDefiningOp(); |
| mlir::TFL::QuantizeOp q_op = |
| llvm::dyn_cast_or_null<mlir::TFL::QuantizeOp>(val_defn_op); |
| if (!q_op) return true; |
| |
| // Ignore shape details - we're really only trying to |
| // check if quantization is the same. |
| auto stripped_src_qtype = GetShapeStrippedType(q_op.qtypeAttr()); |
| auto stripped_qtype = GetShapeStrippedType(qtype_attr); |
| return stripped_src_qtype == stripped_qtype; |
| } |
| |
| } // namespace tflite |