PR #44851: TensorFlow and TensorFlow Lite to TOSA Legalizations
(This re-applies the change without some pattern-exclusion debugging code in legalize_tfl which was tripping the thread sanitizer)
PiperOrigin-RevId: 344859502
Change-Id: I88d0dca8bdaf93a646f92362e8896fd2e11a5306
diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD
index 26cb27e..f839acd 100644
--- a/tensorflow/compiler/mlir/BUILD
+++ b/tensorflow/compiler/mlir/BUILD
@@ -112,6 +112,8 @@
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
"//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo",
"//tensorflow/compiler/mlir/tfjs:tensorflow_js_passes",
+ "//tensorflow/compiler/mlir/tosa:tf_tosa_passes",
+ "//tensorflow/compiler/mlir/tosa:tfl_tosa_passes",
],
)
diff --git a/tensorflow/compiler/mlir/tosa/BUILD b/tensorflow/compiler/mlir/tosa/BUILD
index 68f8bf5..8082228 100644
--- a/tensorflow/compiler/mlir/tosa/BUILD
+++ b/tensorflow/compiler/mlir/tosa/BUILD
@@ -3,6 +3,9 @@
# https://developer.mlplatform.org/w/tosa/
# https://github.com/llvm/llvm-project/blob/master/mlir/docs/Dialects/TOSA.md
+load("//tensorflow:tensorflow.bzl", "get_compatible_with_cloud")
+load("//third_party/mlir:tblgen.bzl", "gentbl")
+
package(
default_visibility = [":internal"],
licenses = ["notice"], # Apache 2.0
@@ -10,9 +13,9 @@
package_group(
name = "internal",
- includes = [
- ],
+ includes = ["//third_party/mlir:subpackages"],
packages = [
+ "//tensorflow/compiler/mlir/...",
],
)
@@ -25,11 +28,304 @@
],
)
-# This can be removed once the package contains real targets.
-# It is merely needed to satisfy Bazel's need to expand exclusions
-# to something.
+config_setting(
+ name = "enable-build",
+ values = {"define": "build-tosa=true"},
+ visibility = ["//visibility:public"],
+)
+
+filegroup(
+ name = "tosa_ops_td_files",
+ srcs = [
+ "@llvm-project//mlir:TdFiles",
+ ],
+ # TODO: Switch to pruned list of TD files once build file changes land.
+ # srcs = [
+ # "@llvm-project//mlir:TosaDialectTdFiles",
+ # ],
+ compatible_with = get_compatible_with_cloud(),
+)
+
+gentbl(
+ name = "tosa_pass_inc_gen",
+ compatible_with = get_compatible_with_cloud(),
+ tbl_outs = [
+ (
+ "-gen-pass-decls -name LegalizeTosa",
+ "transforms/passes.h.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "transforms/passes.td",
+ td_srcs = [
+ "@llvm-project//mlir:PassBaseTdFiles",
+ ],
+)
+
+gentbl(
+ name = "tosa_legalize_tf_inc_gen",
+ compatible_with = get_compatible_with_cloud(),
+ tbl_outs = [
+ (
+ "-gen-rewriters",
+ "transforms/tf_legalize_patterns.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "transforms/tf_legalize_patterns.td",
+ td_srcs = [
+ ":tosa_ops_td_files",
+ "@llvm-project//mlir:StdOpsTdFiles",
+ "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
+ ],
+)
+
+gentbl(
+ name = "tosa_legalize_tfl_inc_gen",
+ compatible_with = get_compatible_with_cloud(),
+ tbl_outs = [
+ (
+ "-gen-rewriters",
+ "transforms/tfl_legalize_patterns.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "transforms/tfl_legalize_patterns.td",
+ td_srcs = [
+ ":tosa_ops_td_files",
+ "@llvm-project//mlir:StdOpsTdFiles",
+ "//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files",
+ ],
+)
+
cc_library(
- name = "placeholder",
- srcs = [],
- hdrs = [],
+ name = "tosa_legalize_tf",
+ srcs = [
+ "transforms/legalize_tf.cc",
+ "transforms/tf_legalize_patterns.inc",
+ ],
+ hdrs = [
+ "transforms/legalize_common.h",
+ "transforms/legalize_utils.h",
+ "transforms/passes.h",
+ "@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h",
+ ],
+ compatible_with = get_compatible_with_cloud(),
+ deps = [
+ ":tosa_legalize_tf_inc_gen",
+ ":tosa_pass_inc_gen",
+ "//tensorflow/compiler/mlir/tensorflow",
+ "//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
+ "//tensorflow/compiler/mlir/tensorflow:tensorflow_all_ops_inc_gen",
+ "//tensorflow/compiler/mlir/tensorflow:translate_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/kernels:conv_grad_shape_utils",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/memory",
+ "@flatbuffers",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:Analysis",
+ "@llvm-project//mlir:Dialect",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Parser",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:QuantOps",
+ "@llvm-project//mlir:StandardOps",
+ "@llvm-project//mlir:Support",
+ "@llvm-project//mlir:TosaDialect",
+ "@llvm-project//mlir:TransformUtils",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "tosa_legalize_tfl",
+ srcs = [
+ "transforms/legalize_tfl.cc",
+ "transforms/tfl_legalize_patterns.inc",
+ ],
+ hdrs = [
+ "transforms/legalize_common.h",
+ "transforms/legalize_utils.h",
+ "transforms/passes.h",
+ "//tensorflow/compiler/mlir/lite/quantization:quantization_traits.h",
+ "@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h",
+ ],
+ compatible_with = get_compatible_with_cloud(),
+ deps = [
+ ":tosa_legalize_tfl_inc_gen",
+ ":tosa_pass_inc_gen",
+ "//tensorflow/compiler/mlir/lite:tensorflow_lite",
+ "//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_inc_gen",
+ "//tensorflow/compiler/mlir/lite:validators",
+ "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
+ "//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
+ "//tensorflow/compiler/mlir/tensorflow:translate_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/kernels:conv_grad_shape_utils",
+ "//tensorflow/lite/schema:schema_fbs",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/memory",
+ "@flatbuffers",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:Analysis",
+ "@llvm-project//mlir:Dialect",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Parser",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:QuantOps",
+ "@llvm-project//mlir:StandardOps",
+ "@llvm-project//mlir:Support",
+ "@llvm-project//mlir:TosaDialect",
+ "@llvm-project//mlir:TransformUtils",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "tosa_legalize_common",
+ srcs = [
+ "transforms/legalize_common.cc",
+ "transforms/legalize_utils.cc",
+ "transforms/tf_legalize_patterns.inc",
+ ],
+ hdrs = [
+ "transforms/legalize_common.h",
+ "transforms/legalize_utils.h",
+ "@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h",
+ ],
+ compatible_with = get_compatible_with_cloud(),
+ deps = [
+ "//tensorflow/compiler/mlir/lite:tensorflow_lite",
+ "//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_inc_gen",
+ "//tensorflow/compiler/mlir/lite:validators",
+ "//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
+ "//tensorflow/compiler/mlir/tensorflow:tensorflow_all_ops_inc_gen",
+ "//tensorflow/compiler/mlir/tensorflow:translate_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/kernels:conv_grad_shape_utils",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/memory",
+ "@flatbuffers",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:Analysis",
+ "@llvm-project//mlir:Dialect",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Parser",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:QuantOps",
+ "@llvm-project//mlir:StandardOps",
+ "@llvm-project//mlir:Support",
+ "@llvm-project//mlir:TosaDialect",
+ "@llvm-project//mlir:TransformUtils",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "tosa_fuse_bias_tf",
+ srcs = [
+ "transforms/fuse_bias_tf.cc",
+ ],
+ hdrs = [
+ "transforms/passes.h",
+ "@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h",
+ ],
+ compatible_with = get_compatible_with_cloud(),
+ deps = [
+ ":tosa_legalize_common",
+ ":tosa_pass_inc_gen",
+ "//tensorflow/compiler/mlir/tensorflow",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:StandardOps",
+ "@llvm-project//mlir:Support",
+ "@llvm-project//mlir:TosaDialect",
+ "@llvm-project//mlir:TransformUtils",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "tosa_convert_tfl_uint8",
+ srcs = [
+ "transforms/convert_tfl_uint8.cc",
+ ],
+ hdrs = [
+ "transforms/passes.h",
+ "@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h",
+ ],
+ compatible_with = get_compatible_with_cloud(),
+ deps = [
+ ":tosa_legalize_common",
+ ":tosa_pass_inc_gen",
+ "//tensorflow/compiler/mlir/lite:tensorflow_lite",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:QuantOps",
+ "@llvm-project//mlir:StandardOps",
+ "@llvm-project//mlir:Support",
+ "@llvm-project//mlir:TosaDialect",
+ "@llvm-project//mlir:TransformUtils",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "tosa_pipelines",
+ srcs = [
+ "tosa_passpipes.cc",
+ ],
+ hdrs = [
+ "tosa_passpipes.h",
+ "transforms/passes.h",
+ "transforms/register_passes.h",
+ ],
+ compatible_with = get_compatible_with_cloud(),
+ deps = [
+ ":tosa_pass_inc_gen",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:TosaDialect",
+ "@llvm-project//mlir:TransformUtils",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "tf_tosa_passes",
+ srcs = [
+ "tf_tosa_pipeline.cc",
+ ],
+ hdrs = [
+ ],
+ compatible_with = get_compatible_with_cloud(),
+ deps = [
+ ":tosa_fuse_bias_tf",
+ ":tosa_legalize_common",
+ ":tosa_legalize_tf",
+ ":tosa_pipelines",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "tfl_tosa_passes",
+ srcs = [
+ "tfl_tosa_pipeline.cc",
+ ],
+ hdrs = [
+ ],
+ compatible_with = get_compatible_with_cloud(),
+ deps = [
+ ":tosa_convert_tfl_uint8",
+ ":tosa_legalize_common",
+ ":tosa_legalize_tfl",
+ ":tosa_pipelines",
+ ],
+ alwayslink = 1,
)
diff --git a/tensorflow/compiler/mlir/tosa/g3doc/legalization.md b/tensorflow/compiler/mlir/tosa/g3doc/legalization.md
new file mode 100644
index 0000000..95726ae
--- /dev/null
+++ b/tensorflow/compiler/mlir/tosa/g3doc/legalization.md
@@ -0,0 +1,4241 @@
+# TOSA Lowerings
+
+## Introduction
+
+### Overview
+
+This document provides pseudo-code lowerings from TensorFlow and TensorFlow Lite
+MLIR Dialects (https://www.tensorflow.org/mlir/dialects) to the TOSA Dialect
+(https://mlir.llvm.org/docs/Dialects/TOSA/).
+
+The documentation is a work-in-progress: sections with missing legalizations are
+in the process of being written.
+
+## Syntax
+
+The pseudo-code syntax used in this document is described below.
+
+### Value
+
+In pseudo-code, a symbol starting with "%" indicates a value. A value is
+evaluated by an operator at run time, and an operator consumes (and can only
+consume) a list of values as operands. Note that a value's tensor type is
+determined at compile time; only the evaluation happens at run time. One can
+easily construct a data-flow subgraph by looking at the producer/consumer
+relationships.
+
+### Tensor Type
+
+Tensor type is an attribute determined by legalization at compile time,
+describing the shape and element data type. It is written as tensor<shape,
+dtype>, or shorthanded as tensor<%t.type>.
+
+### Operator Prototype
+
+In pseudocode, a TOSA operator is prototyped in the following format:
+
+%<output\_value> = tosa.<OPERATOR>(%<input\_value>)
+{<attribute> = …} : (tensor<input\_shape, input\_type>, …) ->
+tensor<output\_shape, output\_type>
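+
+For example, a hypothetical element-wise absolute value on a rank-2
+floating-point tensor would be written:
+
+```
+%output = tosa.ABS(%input) : (tensor<{2, 3}, float>) -> tensor<{2, 3}, float>
+```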
+
+### Value Attributes
+
+For the purposes of brevity and clarity in this document, the pseudocode allows
+the following notation on value attributes.
+
+Shorthand | Description
+------------------- | ---------------------------------------------------
+`%t.shape` | Shape vector for the tensor
+`%t.shape[i]` | Size of dimension i for the tensor
+`%t.rank` | Rank of the tensor
+`%t.dtype` | Datatype of the tensor
+`%t.dtype.scale` | Quantized scaling parameter (double)
+`%t.dtype.zp` | Quantized zero-point (int64)
+`%t.dtype.signed` | Boolean indicating the type is signed
+`%t.dtype.num_bits` | Number of bits in the datatype
+`%t.num_elements` | Number of elements in the tensor
+`%t.type` | Tuple of `tensor<%t.shape, %t.dtype>`
+`%t.size` | For tensor lists: the number of tensors in the list
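+
+For example, given a hypothetical %t of type tensor<{1, 8, 8, 16}, int8>:
+
+```
+%t.rank         = 4
+%t.shape[3]     = 16
+%t.num_elements = 1024
+```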
+
+### Tensor Dimension Shorthand
+
+Where the TOSA Specification allows the use of named dimensions, the following
+names may be used.
+
+Name | Description
+---- | --------------------
+`N` | Batch dimension
+`H` | Height dimension
+`W` | Width dimension
+`C` | Channel dimension
+`M` | Depthwise multiplier
+
+Each of these may be prefixed with `I` for the input dimension or `O` for the
+output dimension or `K` for kernel dimensions.
+
+## Common Legalization Functions
+
+The following pseudocode helper functions are used to canonicalize arguments
+from different frameworks to the TOSA dialect.
+
+### .as_constant(): Matched as Constant
+
+Wherever %tensor.as_constant() is specified, a constant vector will be created
+to hold the value in the %tensor at compile time. This only succeeds if %tensor
+is fed by a constant type operator. If constant matching fails, the lowering
+will fail and be terminated.
+
+### apply_rank_broadcast()
+
+```
+// Applies a TOSA-lowering broadcast to tensor 'a' with respect
+// to sibling tensor 'b'.
+//
+// The resulting tensors will have matching ranks. TOSA broadcast
+// operators accept degenerate broadcasting (1 vs non-1)
+Value apply_rank_broadcast(Value %a, Value %b) {
+
+ if (%a.rank < %b.rank) {
+
+ auto new_a_shape = [1] * %b.rank
+
+ if (%a.rank <= 1) {
+      new_a_shape[%b.rank - 1] = %a.shape[0]
+ %out = tosa.RESHAPE(%a, new_a_shape)
+ return %out
+ }
+
+ // Working from the right on both tensors, try to match all of a's
+ // dimensions to b
+ int adim = %a.rank - 1
+    for (int bdim = %b.rank - 1; bdim >= 0 && adim >= 0; bdim--) {
+      if (%a.shape[adim] == %b.shape[bdim] ||
+          %a.shape[adim] == 1 ||
+          %b.shape[bdim] == 1) {
+        new_a_shape[bdim] = %a.shape[adim]
+        adim--
+      }
+    }
+
+    assert(adim == -1)
+    assert(product(%a.shape) == product(new_a_shape))
+
+    %out = tosa.RESHAPE(%a) {new_shape=new_a_shape} : (tensor<%a.type>) -> tensor<new_a_shape, %a.dtype>
+ } else {
+ %out = %a
+ }
+
+ return %out;
+}
+```
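+
+As an illustrative example, a rank-1 tensor broadcast against a rank-3 sibling
+only gains leading unit dimensions; the element count is unchanged:
+
+```
+// %a : tensor<{4}, float>, %b : tensor<{2, 3, 4}, float>
+%out = apply_rank_broadcast(%a, %b)   // %out : tensor<{1, 1, 4}, float>
+```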
+
+### get_padding_values_from_explicit_pad_attr()
+
+```
+vector<int64_t> get_padding_values_from_explicit_pad_attr(vector<int64_t> explicit_pad,
+ tensorflow::TensorFormat data_format_tf)
+{
+ int64_t pad_before, pad_after
+ vector<int64_t> computed_paddings
+
+ for (int i = 0; i < 2; i++) {
+ int64_t dim = GetTensorSpatialDimIndex(4, data_format_tf, i)
+ pad_before = explicit_pad[dim * 2]
+ pad_after = explicit_pad[dim * 2 + 1]
+ computed_paddings.push_back(pad_before)
+ computed_paddings.push_back(pad_after)
+ }
+
+ return computed_paddings
+}
+```
+
+### get_padding_values_from_pad_type()
+
+Calculate explicit padding array based on pad type
+
+```
+vector<int64_t> get_padding_values_from_pad_type(tensorflow::Padding padding, tensorflow::TensorFormat data_format,
+                                                 uint32_t first_filter_spatial_dim, type input_type, type filter_type,
+                                                 vector strides, vector dilations)
+{
+    assert(padding != tensorflow::Padding::EXPLICIT);
+
+    vector<int64_t> computed_paddings;
+
+ // Padding over H and W dimensions
+ for (int i = 0; i < 2; i++) {
+ int ifm_dim = get_tensor_spatial_dim_index(4, data_format, i);
+
+ int filter_dim = first_filter_spatial_dim + i;
+
+ int dim_dilation = dilations[ifm_dim];
+ int dim_stride = strides[ifm_dim];
+
+ int64_t op_size, pad_before_tf, pad_after_tf;
+
+ tensorflow::GetWindowedOutputSizeVerboseV2(input_type.shape[ifm_dim], filter_type.shape[filter_dim],
+ dim_dilation, dim_stride, padding,
+ // Outputs
+ &op_size, &pad_before_tf, &pad_after_tf);
+ computed_paddings.push_back(pad_before_tf);
+ computed_paddings.push_back(pad_after_tf);
+ }
+
+ return computed_paddings;
+}
+```
+
+### positive_axis()
+
+```
+// Canonicalize scalar axis attributes to a scalar positive axis attribute
+int32_t positive_axis(int32_t axis, int32_t rank)
+{
+ if (axis < 0)
+ axis += rank;
+
+ return axis;
+}
+```
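+
+For example:
+
+```
+positive_axis(-1, 4)   // returns 3: the last axis of a rank-4 tensor
+positive_axis(2, 4)    // returns 2: already positive, unchanged
+```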
+
+### compute_scale_32()
+
+```
+void compute_scale_32(double scale, int32_t& multiplier, int32_t& shift)
+{
+  /* Generates a 32-bit multiplier and right-shift value, where the mantissa is
+     in [-1.0, -0.5] or [0.5, 1.0], such that
+     scale = multiplier * 2^-shift and multiplier = mantissa * 2^31 */
+
+ const double mantissa = std::frexp(scale, &shift);
+ auto shifted_m = std::round(mantissa * (int64_t(1) << 31));
+
+  assert(shifted_m <= (int64_t(1) << 31)); // can't be greater than 1.0
+ if (shifted_m == (int64_t(1) << 31)) {
+ shifted_m /= 2;
+ shift++;
+ }
+  // TOSA expects the right shift to be positive, and embeds (1 << 31) into
+  // the right shift bits
+ shift = (-shift) + 31;
+
+ assert(shifted_m <= std::numeric_limits<int32_t>::max());
+
+ multiplier = static_cast<int32_t>(shifted_m);
+
+}
+```
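+
+As a worked example, a scale of 1/256 gives frexp() a mantissa of 0.5 and an
+exponent of -7, so:
+
+```
+// compute_scale_32(1.0 / 256.0, multiplier, shift):
+//   multiplier = round(0.5 * 2^31) = 1073741824
+//   shift      = -(-7) + 31       = 38
+// check: 1073741824 * 2^-38 = 2^30 * 2^-38 = 2^-8 = 1/256
+```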
+
+### lower_batch_to_space_nd_op()
+
+```
+Value lower_batch_to_space_nd_op(Value %input, Value %block_shape, Value %crops, shape_t output_shape)
+{
+
+  // Note: %block below denotes the constant value of %block_shape.
+  vector <size_t> block_shape(%block_shape.rank)
+  vector <std::pair<size_t, size_t>> crops_arr
+
+  size_t remaining_shape_rank = %input.rank - %block.rank - 1
+  size_t crops_dim = %crops.shape[0]
+
+  for (int i = 0; i < crops_dim; i++) {
+    crops_arr[i] = std::make_pair(%crops.as_constant()[i * 2 + 0],
+                                  %crops.as_constant()[i * 2 + 1])
+  }
+
+ // Step 1: Reshape input to
+ // [block_shape[0],
+ // ...
+ // [block_shape[M-1],
+ // [batch / prod(block_shape)]
+ // [input_shape[1],
+ // ...
+ // [input_shape[N-1]
+
+ vector <size_t> a1_shape(%block.rank + %input.rank)
+
+ for (int i = 0; i < %block.rank; i++) {
+ a1_shape[i] = %block.shape[i]
+ }
+
+  a1_shape[%block.rank] = %input.shape[0] / %block.num_elements
+
+ for (int i = 1; i < %input.rank; i++) {
+ a1_shape[i + %block.rank] = %input.shape[i]
+ }
+
+ // Step 2. Permute to shape:
+ // [ batch / prod(block_shape) ],
+ // [ input_shape[1] ], [ block_shape[0] ]
+ // ...
+ // [ input_shape[M] ], [ block_shape[M-1]
+ // + remaining_input_shapes input_shape[M+1 .. N-1]
+ vector <size_t> a2_perm(%block.rank + %input.rank)
+
+ a2_perm[0] = %block.rank
+ for (int i = 0; i < %block.rank; i++) {
+ a2_perm[1 + i * 2 + 0] = %block.rank + 1 + i
+ a2_perm[1 + i * 2 + 1] = i
+ }
+
+ // Step 3. Reshape to
+ // [ batch / prod(block_shape) ],
+ // [input_shape[1] * block_shape[0] ],
+ // ..
+ // [input_shape[M * block_shape[M-1],
+ // + remaining input shapes [input_shape[M+1.. N-1]]
+ vector <size_t> a3_shape(%input.rank)
+
+  a3_shape[0] = %input.shape[0] / %block.num_elements
+ for (int i = 0; i < %block.rank; i++) {
+ a3_shape[i + 1] = %input.shape[i + 1] * %block.shape[i]
+ }
+
+  for (int i = 0; i < remaining_shape_rank; i++) {
+    a3_shape[1 + %block.rank + i] = %input.shape[%block.rank + 1 + i]
+  }
+
+ // Step 4 Crop the start/end dimensions using slice
+ vector <size_t> a4_begin(%input.rank), a4_size(%input.rank)
+
+ for (int i = 0; i < %input.rank; i++) {
+    if (i == 0 || i > crops_dim) {
+ a4_begin[i] = 0
+ a4_size[i] = output_shape[i]
+ } else {
+      a4_begin[i] = crops_arr[i - 1].first
+      a4_size[i] = output_shape[i]
+ }
+ }
+
+ %a1_reshape = tosa.RESHAPE(%input) {new_shape=a1_shape} : (tensor<%input.type>) -> tensor<a1_shape, %input.dtype>
+ %a2_transpose = tosa.TRANSPOSE(%a1_reshape) {perms=a2_perm} : (tensor<%a1_reshape.type>) -> tensor<%a2_transpose.type>
+ %a3_reshape = tosa.RESHAPE(%a2_transpose) {new_shape=a3_shape} : (tensor<%a2_transpose.type>) -> tensor<a3_shape, %input.dtype>
+ %output = tosa.SLICE(%a3_reshape) {begin=a4_begin, size=a4_size} : (tensor<%a3_reshape.type>) -> tensor<a4_size, %input.dtype>
+
+ return %output
+}
+```
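+
+A shape-only walk-through (assuming zero crops): with %input :
+tensor<{4, 2, 2, 1}> and block shape {2, 2}, the steps produce:
+
+```
+a1_shape = {2, 2, 1, 2, 2, 1}   // block dims, batch / prod(block), input[1..]
+a2_perm  = {2, 3, 0, 4, 1, 5}   // batch/prod(block) first, then (spatial, block) pairs, then C
+a3_shape = {1, 4, 4, 1}         // N, H * block[0], W * block[1], C
+```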
+
+### lower_concatv2_op()
+
+```
+Value lower_concatv2_op(Value %values, int32_t axis)
+{
+  int32_t tosa_axis = positive_axis(axis, %values:0.rank)
+
+ assert(%values.size >= 2)
+
+ // Convert scalar inputs to a tensor
+  if (%values:0.rank == 0) {
+ for (int i = 0; i < %values.size; i++) {
+      %values:i = tosa.RESHAPE(%values:i) {new_shape={1}} : (tensor<{}, %values:i.dtype>) -> tensor<{1}, %values:i.dtype>
+ }
+ }
+
+ // TODO: rescale
+
+ %concat_op = tosa.CONCAT(%values:0, %values:1) {axis=tosa_axis} : (tensor<%values:0.type>, tensor<%values:1.type>) -> tensor<%concat_op.type>
+
+ for (int i = 2; i < %values.size; i++) {
+ // TODO: rescale
+ %concat_op = tosa.CONCAT(%concat_op, %values:i) {axis=tosa_axis} : (tensor<%concat_op.type>, tensor<%values:i.type>) -> tensor<%concat_op.type>
+ }
+
+ return %concat_op
+}
+```
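+
+For example, concatenating three tensors of shape {2, 3} along axis 1 chains
+two CONCAT operators:
+
+```
+%c0 = tosa.CONCAT(%values:0, %values:1) {axis=1} : (tensor<{2, 3}, float>, tensor<{2, 3}, float>) -> tensor<{2, 6}, float>
+%output = tosa.CONCAT(%c0, %values:2) {axis=1} : (tensor<{2, 6}, float>, tensor<{2, 3}, float>) -> tensor<{2, 9}, float>
+```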
+
+### lower_depth_to_space_op()
+
+```
+Value lower_depth_to_space_op(Value %input, size_t block_size[], Format_t data_format)
+{
+ assert(data_format == 'NHWC')
+
+ vector <size_t> a2_shape = {%input.shape[0],
+ %input.shape[1],
+ %input.shape[2],
+ block_size[0],
+ block_size[1],
+ %input.shape[3] / (block_size[0] * block_size[1])}
+
+ vector <size_t> a4_shape = {%input.shape[0],
+ %input.shape[1] * block_size[0],
+ %input.shape[2] * block_size[1],
+ %input.shape[3] / (block_size[0] * block_size[1])}
+
+  %a2_reshape = tosa.RESHAPE(%input) {new_shape=a2_shape} : (tensor<%input.type>) -> tensor<a2_shape, %input.dtype>
+ %a3_transpose = tosa.TRANSPOSE(%a2_reshape) {perms={0, 1, 3, 2, 4, 5}} : (tensor<%a2_reshape.type>) -> tensor<%a3_transpose.type>
+ %output = tosa.RESHAPE(%a3_transpose) {new_shape=a4_shape} : (tensor<%a3_transpose.type>) -> tensor<a4_shape, %input.dtype>
+
+ return %output
+}
+```
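+
+Shape-only example with block_size = {2, 2} and %input : tensor<{1, 2, 2, 8}, float>:
+
+```
+a2_shape = {1, 2, 2, 2, 2, 2}   // C split into block[0] * block[1] * C/4
+// TRANSPOSE with perms {0, 1, 3, 2, 4, 5} interleaves H/W with the block factors
+a4_shape = {1, 4, 4, 2}         // N, H * block[0], W * block[1], C / 4
+```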
+
+### lower_elu_op()
+
+```
+Value lower_elu_op(Value %value)
+{
+ // elu(x) = x < 0 ? (exp(x) - 1) : x
+ // Create constants for 0/1 and reshape to match the rank
+ // of %value
+  %one_const = tosa.CONST() {value={1.0}} : () -> tensor<{1}, float>
+  %zero_const = tosa.CONST() {value={0.0}} : () -> tensor<{1}, float>
+
+ vector bcast_shape
+ for (int i = 0; i < %value.rank; i++) {
+ bcast_shape.push_back(1)
+ }
+
+ %one_reshape = tosa.RESHAPE(%one_const) {new_shape=bcast_shape} : (tensor<%one_const.type>) -> tensor<%one_reshape.type>
+ %zero_reshape = tosa.RESHAPE(%zero_const) {new_shape=bcast_shape} : (tensor<%zero_const.type>) -> tensor<%zero_reshape.type>
+
+ %exp_in = tosa.EXP(%value) : (tensor<%value.type>) -> tensor<%exp_in.type>
+ %sub = tosa.SUB(%exp_in, %one_reshape) : (tensor<%exp_in.type>, tensor<%one_reshape.type>) -> tensor<%sub.type>
+ %ge = tosa.GREATER_EQUAL(%value, %zero_reshape) : (tensor<%value.type>, tensor<%zero_reshape.type>) -> tensor<%value.shape, bool>
+  %output = tosa.SELECT(%ge, %value, %sub) : (tensor<%ge.type>, tensor<%value.type>, tensor<%sub.type>) -> tensor<%output.type>
+ return %output
+}
+```
+
+### lower_expand_dims()
+
+```
+Value lower_expand_dims(Value %input, int32_t axis)
+{
+ vector<size_t> reshape_dims
+
+  if (axis < 0 || axis >= %input.rank) {
+    // Insert at the end of the tensor
+    for (int i = 0; i < %input.rank; i++) {
+      reshape_dims.push_back(%input.shape[i])
+    }
+    reshape_dims.push_back(1)
+ } else {
+ for (int i= 0 ; i < %input.rank; i++) {
+ if (i == axis) {
+ reshape_dims.push_back(1)
+ }
+ reshape_dims.push_back(%input.shape[i])
+ }
+ }
+
+  %output = tosa.RESHAPE(%input) {new_shape=reshape_dims} : (tensor<%input.type>) -> tensor<%output.type>
+ return %output
+}
+```
+
+### lower_fake_quant_op()
+
+```
+Value lower_fake_quant_op(Value %inputs, type output_type, double min, double max,
+ int64_t num_bits, bool narrow_range)
+{
+ assert(num_bits == 8 || num_bits == 16)
+
+ int64_t qmax = (1L << (num_bits - 1)) - 1;
+ int64_t qmin = -(1L << (num_bits - 1))
+
+ if (narrow_range) {
+ qmin = qmin + 1
+ }
+
+ double scale = (max - min) / double(qmax - qmin)
+
+ int64_t zeropoint = (int64_t)std::round((-min) / scale + double(qmin))
+
+  %quantized = lower_quantized_op(%inputs.type, %inputs, 1.0 / scale, zeropoint)
+
+  %dequantized = lower_dequantized_op(output_type, %quantized, scale, zeropoint)
+
+ return %dequantized
+}
+```
+
+### lower_floor_div()
+
+```
+Value lower_floor_div(Value %lhs, Value %rhs)
+{
+ %recip = tosa.RECIPROCAL(%rhs) : (tensor<%rhs.type>) -> tensor<%recip.type>
+ %mul = tosa.MUL(%lhs, %recip) : (tensor<%lhs.type>, tensor<%recip.type>) -> tensor<%mul.type>
+ %output = tosa.FLOOR(%mul) : (tensor<%mul.type>) -> tensor<%output.type>
+
+ return %output
+}
+```
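+
+A quick sanity check of the sequence (floating-point operands assumed, since
+tosa.RECIPROCAL is a floating-point operator):
+
+```
+// x =  7.0, y = 2.0: FLOOR( 7.0 * 0.5) = FLOOR( 3.5) =  3.0
+// x = -7.0, y = 2.0: FLOOR(-7.0 * 0.5) = FLOOR(-3.5) = -4.0  (flooring semantics)
+```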
+
+### lower_floor_mod()
+
+```
+Value lower_floor_mod(Value %lhs, Value %rhs)
+{
+ %recip = tosa.RECIPROCAL(%rhs) : (tensor<%rhs.type>) -> tensor<%recip.type>
+ %mul = tosa.MUL(%lhs, %recip) : (tensor<%lhs.type>, tensor<%recip.type>) -> tensor<%mul.type>
+ %floor = tosa.FLOOR(%mul) : (tensor<%mul.type>) -> tensor<%floor.type>
+ %output = tosa.SUB(%mul, %floor) : (tensor<%mul.type>, tensor<%floor.type>) -> tensor<%output.type>
+ return %output
+}
+```
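+
+Tracing the sequence with x = -7.0, y = 2.0:
+
+```
+// recip = 0.5, mul = -3.5, floor = -4.0
+// output = -3.5 - (-4.0) = 0.5, the fractional part of x / y
+```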
+
+### lower_quantized_op()
+
+```
+Value lower_quantized_op(type output_type, Value %inputs, double scale, int64_t zeropoint)
+{
+ // TODO: fill in this function
+}
+```
+
+### lower_dequantized_op()
+
+```
+Value lower_dequantized_op(type output_type, Value %inputs, double scale, int64_t zeropoint)
+{
+ // TODO: fill in this function
+}
+```
+
+### lower_log_softmax_op()
+
+```
+Value lower_log_softmax_op(Value %logits)
+{
+ %op1 = tosa.EXP(%logits) : (tensor<%logits.type>) -> tensor<%op1.type>
+  %op2 = tosa.REDUCE_SUM(%op1) {axis=(%logits.rank-1)} : (tensor<%op1.type>) -> tensor<%op2.type>
+ %op3 = tosa.RECIPROCAL(%op2) : (tensor<%op2.type>) -> tensor<%op3.type>
+ %op4 = tosa.MUL(%op1, %op3) : (tensor<%op1.type>, tensor<%op3.type>) -> tensor<%op4.type>
+ %op5 = tosa.LOG(%op4) : (tensor<%op4.type>) -> tensor<%op5.type>
+
+ return %op5
+}
+```
+
+### lower_pack_op()
+
+```
+Value lower_pack_op(Value %input[], size_t axis)
+{
+ size_t concat_axis = positive_axis(axis)
+
+ size_t input_tensor_rank = %input[0].rank
+
+ // Convert any rank 0 to rank 1 with reshape
+ if (input_tensor_rank == 0) {
+ for (int i = 0; i < %input.size; i++) {
+ %input[i] = tosa.RESHAPE(%input[i], {1})
+ }
+ }
+
+ vector<size_t> output_shape
+ for (int i = 0; i < input_tensor_rank; i++) {
+    output_shape.push_back(%input[0].shape[i])
+ }
+
+ output_shape[concat_axis] = output_shape[concat_axis] * %input.size
+
+ // First pair of tensors
+ %concat = tosa.CONCAT(%input[0], %input[1]) {axis=concat_axis} : (tensor<%input[0].type>, tensor<%input[1].type>) -> tensor<%concat.type>
+
+ // Remaining tensors
+ for (int i = 2; i < %input.size; i++) {
+ %concat = tosa.CONCAT(%concat, %input[i]) {axis=concat_axis} : (tensor<%concat.type>, tensor<%input[i].type>) -> tensor<%concat.type>
+ }
+
+  if (input_tensor_rank == 0) {
+    // No reshape needed for rank 0, already done
+    %output = %concat
+  } else {
+
+    %reshape = tosa.RESHAPE(%concat) {new_shape=output_shape} : (tensor<%concat.type>) -> tensor<%reshape.type>
+
+    if (concat_axis == input_tensor_rank) {
+      // Output shape is [A, B, C, .. n] in this case,
+      // need to reshape to [N, A, B, C, ..] with perm [1, 2, 3, .. 0]
+      concat_axis = 0
+
+      vector <size_t> perms
+      for (int i = 0; i < %input[0].rank; i++)
+        perms.push_back(i + 1)
+      perms.push_back(0)
+
+      %output = tosa.TRANSPOSE(%reshape) {perms=perms} : (tensor<%reshape.type>) -> tensor<%output.type>
+    } else {
+      %output = %reshape
+    }
+  }
+
+ return %output
+}
+```
+
+### lower_reduce_op()
+
+```
+Value lower_reduce_op<tosa_op_t OP>(Value %input, shape_t output_shape, Value %axes, bool keep_dims)
+{
+
+ vector axes_vec = %axes.as_constant();
+
+ // Special case of no axes means no transformation
+ if (axes_vec.size() == 0) {
+ return tosa.IDENTITY(%input) : (%input.type) -> %output.type
+ }
+
+ shape_t shape = %input.shape;
+ %output = %input;
+
+ // TODO: rescaling on quantized types
+ for (int i = 0; i < axes_vec.size(); i++) {
+ int32_t axis = positive_axis(axes_vec[i], %input.rank);
+
+ shape[axis] = 1;
+ %output = tosa.OP(%output) {axis=axis} : (tensor<%output.type>) -> tensor<shape, %output.dtype>
+ }
+
+ // TODO: Rescale
+ if (!keep_dims) {
+ %output = tosa.RESHAPE(%output) {new_shape=output_shape} : (tensor<%output.type>) -> tensor<output_shape, %output.dtype>
+ }
+
+ return %output;
+}
+```
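+
+As an illustrative expansion, reducing a tensor<{2, 4, 4, 8}, bool> over axes
+{1, 2} with keep_dims=false (e.g. for tf.All) becomes:
+
+```
+%r0 = tosa.REDUCE_ALL(%input) {axis=1} : (tensor<{2, 4, 4, 8}, bool>) -> tensor<{2, 1, 4, 8}, bool>
+%r1 = tosa.REDUCE_ALL(%r0) {axis=2} : (tensor<{2, 1, 4, 8}, bool>) -> tensor<{2, 1, 1, 8}, bool>
+%output = tosa.RESHAPE(%r1) {new_shape={2, 8}} : (tensor<{2, 1, 1, 8}, bool>) -> tensor<{2, 8}, bool>
+```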
+
+### lower_resize_op()
+
+```
+Value lower_resize_op(Value %images, Value %size, shape output_shape, dtype output_dtype, mode_t mode)
+{
+ int64_t input_height = %images.shape[1]
+ int64_t input_width = %images.shape[2]
+ int64_t output_height = output_shape[1]
+ int64_t output_width = output_shape[2]
+
+ int32_t shift = 11
+
+ double frac_y = (double)output_height / (double)input_height
+ double frac_x = (double)output_width / (double)input_width
+ int32_t stride_y = (int32_t)std::round(frac_y * double(1 << shift))
+ int32_t stride_x = (int32_t)std::round(frac_x * double(1 << shift))
+
+ // Stride is int16
+ while (stride_y >= 32768 || stride_x >= 32768) {
+ shift--
+ stride_y = (int32_t)std::round(frac_y * double(1 << shift))
+ stride_x = (int32_t)std::round(frac_x * double(1 << shift))
+ }
+
+  %output = tosa.RESIZE(%images) {output_size={output_height, output_width},
+                                  stride={stride_y, stride_x}, offset={0, 0},
+                                  shift=shift, mode=mode} : (tensor<%images.type>) -> tensor<output_shape, output_dtype>
+
+  return %output
+}
+```
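+
+A worked example of the stride computation: upscaling a 32x32 input to 64x64
+gives frac_y = frac_x = 2.0, so with the initial shift of 11:
+
+```
+// stride_y = stride_x = round(2.0 * (1 << 11)) = 4096  (< 32768, fits int16)
+```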
+
+### lower_reversev2_op()
+
+```
+Value lower_reverse_v2_op(Value %tensor, Value %axis)
+{
+ Value %output = %tensor
+
+ if (%axis.num_elements == 0) {
+ %output = tosa.IDENTITY(%tensor) : (tensor<%tensor.type>) -> tensor<%tensor.type>
+ } else {
+ for (int i = 0; i < %axis.shape[0]; i++) {
+      size_t axis_val = positive_axis(%axis.as_constant()[i], %tensor.rank)
+      %output = tosa.REVERSE(%output) {axis=axis_val} : (tensor<%tensor.type>) -> tensor<%tensor.type>
+ }
+ }
+
+ return %output
+}
+```
+
+### lower_round_op()
+
+```
+Value lower_round_op(Value %x)
+{
+ %half = tosa.CONST() {value=0.5} : () -> tensor<{1}, float>
+ %add = tosa.ADD(%x, %half) : (tensor<%x.type>, tensor<%half.type>) -> tensor<%add.type>
+ %output = tosa.FLOOR(%add) : (tensor<%add.type>) -> tensor<%output.type>
+
+ return %output
+}
+```
+
+### lower_selectv2_op()
+
+```
+Value lower_selectv2_op(Value %condition, Value %t, Value %e, shape output_shape)
+{
+ // Reshape condition so that ranks match to support
+ // broadcasting (if necessary)
+
+ if (%condition.rank != output_shape.size) {
+ vector <size_t> cond_shape = %condition.shape
+ for (int i = 0; i < (output_shape.size - %condition.rank); i++) {
+ cond_shape.push_front(1)
+ }
+
+ %condition = tosa.RESHAPE(%condition) {new_shape=cond_shape} : (tensor<%condition.type>) -> tensor<cond_shape, %condition.dtype>
+ }
+
+ %output = tosa.SELECT(%condition, %t, %e) : (tensor<%condition.type>, tensor<%t.type>, tensor<%t.type>) -> tensor<output_shape, %t.type>
+
+ return %output
+}
+```
+
+### lower_shape_op()
+
+```
+Value lower_shape_op(Value %input)
+{
+ vector <size_t> input_shape = %input.shape
+
+  %shape = tosa.CONST() {value=input_shape} : () -> tensor<{%input.rank}, int32_t>
+ return %shape
+}
+```
+
+### lower_space_to_batch_nd_op()
+
+```
+Value lower_space_to_batch_nd_op(Value %input, Value %block_shape, Value %padding)
+{
+
+  size_t block_rank = %block_shape.shape[0]
+ size_t remaining_shape_rank = %input.rank - block_rank - 1;
+
+ // Step 1. Pad based on paddings operand (flattened representation of [input.rank][2]-shaped array)
+ vector <size_t> a1_padding
+ a1_padding[0] = 0
+ a1_padding[1] = 0
+
+ for (int i = 0; i < %padding.shape[0]; i++) {
+ a1_padding[i + 2] = %padding.as_constant()[i]
+ }
+
+ %a1_pad = tosa.PAD(%input) {padding=a1_padding} : (tensor<%input.type>) -> tensor<%a1_pad.type>
+
+ // Step 2. Reshape to
+ // [batch + padded_shape[1] / block_shape[0], block_shape[0], ...
+ // padded_shape[M] / block_shape[M-1], block_shape[M-1]] +
+ // remaining_shape
+
+ vector <size_t> a2_shape(1 + block_rank * 2 + remaining_shape_rank)
+ a2_shape[0] = %input.shape[0]
+ for (int i = 0; i < block_rank; i++) {
+    a2_shape[1 + i * 2 + 0] = %a1_pad.shape[1 + i] / %block_shape.as_constant()[i]
+    a2_shape[1 + i * 2 + 1] = %block_shape.as_constant()[i]
+ }
+
+ for (int i = 0; i < remaining_shape_rank; i++) {
+ a2_shape[1 + block_rank * 2 + i] = %input.shape[1 + block_rank + i]
+ }
+
+ %a2_reshape = tosa.RESHAPE(%a1_pad) {new_shape=a2_shape} : (tensor<%a1_pad.type>) -> tensor<%a2_reshape.type>
+
+ // Step 3 transpose to
+ // block-shape +
+ // [batch] +
+ // [padded_shape[1] / block_shape[0],
+ // ...
+ // [padded_shape[M] / block_shape[M-1]] +
+ // remaining_shape
+ vector <size_t> a3_perm(%a2_reshape.rank)
+ size_t block_num_elems = 1
+
+ for (int i = 0; i < block_rank; i++) {
+ a3_perm[i] = 1 + 2 * i + 1
+ a3_perm[block_rank + 1 + i] = 2 * i + 1
+    block_num_elems *= %block_shape.as_constant()[i]
+ }
+
+ a3_perm[block_rank] = 0
+ for (int i = (1 + block_rank * 2); i < %a2_reshape.rank; i++) {
+ a3_perm[i] = i
+ }
+
+  %a3_transpose = tosa.TRANSPOSE(%a2_reshape) {perms=a3_perm} : (tensor<%a2_reshape.type>) -> tensor<%a3_transpose.type>
+
+ // Step 4. Reshape transposed tensor to
+ // [ batch * prod(block_shape)] +
+ // [ padded_shape[1] / block_shape[0],
+ // ...,
+ // padded_shape[M] / block_shape[M-1]] +
+ // remaining_shape
+
+ vector <size_t> a4_shape(%input.rank)
+  a4_shape[0] = %input.shape[0] * block_num_elems
+
+ for (int i = 0; i < block_rank; i++) {
+    a4_shape[i + 1] = %a1_pad.shape[i + 1] / %block_shape.as_constant()[i]
+ }
+
+  for (int i = 0; i < remaining_shape_rank; i++) {
+ a4_shape[1 + block_rank + i] = %input.shape[1 + block_rank + i]
+ }
+
+  %output = tosa.RESHAPE(%a3_transpose) {new_shape=a4_shape} : (tensor<%a3_transpose.type>) -> tensor<%output.type>
+
+ return %output
+}
+```
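+
+Shape-only walk-through (assuming zero padding): with %input :
+tensor<{1, 4, 4, 1}> and block shape {2, 2}:
+
+```
+a2_shape = {1, 2, 2, 2, 2, 1}   // batch, then (quotient, block) pairs per axis
+a4_shape = {4, 2, 2, 1}         // batch * prod(block), then the quotients
+```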
+
+### lower_space_to_depth_op()
+
+```
+Value lower_space_to_depth_op(Value %input, size_t block_size[], Format_t data_format)
+{
+ assert(data_format == 'NHWC')
+
+  vector <size_t> a2_shape = {%input.shape[0],
+                              %input.shape[1] / block_size[0],
+                              block_size[0],
+                              %input.shape[2] / block_size[1],
+                              block_size[1],
+                              %input.shape[3]}
+  %a2_reshape = tosa.RESHAPE(%input) {new_shape=a2_shape} : (tensor<%input.type>) -> tensor<a2_shape, %input.dtype>
+  %a3_transpose = tosa.TRANSPOSE(%a2_reshape) {perms={0, 1, 3, 2, 4, 5}} : (tensor<%a2_reshape.type>) -> tensor<%a3_transpose.type>
+
+  vector <size_t> a4_shape = {%input.shape[0],
+                              %input.shape[1] / block_size[0],
+                              %input.shape[2] / block_size[1],
+                              %input.shape[3] * block_size[0] * block_size[1]}
+  %output = tosa.RESHAPE(%a3_transpose) {new_shape=a4_shape} : (tensor<%a3_transpose.type>) -> tensor<a4_shape, %input.dtype>
+ return %output
+}
+```
+
+### lower_split_op()
+
+```
+Value lower_split_op(Value %value, size_t axis, size_t num_split)
+{
+ Value %output[]
+
+ size_t slice_size = %value.shape[axis] / num_split
+
+ for (int i = 0; i < num_split; i++) {
+ vector <size_t> begin_vals, size_vals
+
+    for (int j = 0; j < %value.rank; j++) {
+      if (j == axis) {
+        begin_vals.push_back(slice_size * i)
+        size_vals.push_back(slice_size)
+      } else {
+        begin_vals.push_back(0)
+        size_vals.push_back(%value.shape[j])
+      }
+    }
+
+    %output[i] = tosa.SLICE(%value) {start=begin_vals, size=size_vals} : (tensor<%value.type>) -> tensor<size_vals, %value.dtype>
+  }
+
+  %output_list = tosa.IDENTITYN(%output) : (tensor<%output:*.type>) -> tensor<%output_list:*.type>
+ return %output_list
+}
+```
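+
+For example, splitting %value : tensor<{6, 4}, float> into num_split = 3 along
+axis 0 gives slice_size = 2 and three slices:
+
+```
+// %output[0] = SLICE start={0, 0}, size={2, 4}
+// %output[1] = SLICE start={2, 0}, size={2, 4}
+// %output[2] = SLICE start={4, 0}, size={2, 4}
+```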
+
+### lower_splitv_op()
+
+```
+Value lower_splitv_op(Value %value, vector <size_t> size_split, size_t axis)
+{
+ Value %output[]
+
+ size_t curr_split_start = 0
+
+ for (int i = 0; i < size_split.size(); i++) {
+ vector <size_t> begin_vals, size_vals
+
+ for (int j = 0; j < %value.rank; j++) {
+ if (j == axis) {
+ begin_vals.push_back(curr_split_start)
+ size_vals.push_back(size_split[i])
+ } else {
+ begin_vals.push_back(0)
+        size_vals.push_back(%value.shape[j])
+ }
+ }
+
+    %output[i] = tosa.SLICE(%value) {start=begin_vals, size=size_vals} : (tensor<%value.type>) -> tensor<size_vals, %value.dtype>
+
+ curr_split_start += size_split[i]
+ }
+
+  %output_list = tosa.IDENTITYN(%output) : (tensor<%output:*.type>) -> tensor<%output_list:*.type>
+ return %output_list
+}
+```
+
+### lower_squeeze_op()
+
+```
+Value lower_squeeze_op(Value %input, vector<size_t> squeeze_dims)
+{
+ vector <size_t> reshape_dims
+
+ if (squeeze_dims.size() == 0) {
+ // Remove all 1-dims
+ for (int i = 0; i < %input.rank; i++) {
+ if (%input.shape[i] != 1) {
+        reshape_dims.push_back(%input.shape[i])
+ }
+ }
+ } else {
+ // Remove the specified dimensions
+ for (int i = 0; i < %input.rank; i++) {
+      if (!squeeze_dims.contains(i) || %input.shape[i] != 1) {
+        reshape_dims.push_back(%input.shape[i])
+ }
+ }
+ }
+
+  %output = tosa.RESHAPE(%input) {new_shape=reshape_dims} : (tensor<%input.type>) -> tensor<reshape_dims, %input.dtype>
+
+ return %output
+}
+```
+
+### lower_strided_slice_op()
+
+```
+Value lower_strided_slice_op(Value %input, Value %begin_val, Value %end_val, Value %strides_val,
+ size_t begin_mask, size_t end_mask, size_t ellipsis_mask,
+ size_t new_axis_mask, size_t shrink_axis_mask)
+{
+ // Note: does not implement ellipsis_mask or reverse stride at this time
+ assert(ellipsis_mask == 0)
+
+ vector <size_t> begin(%begin_val.as_constant()), end(%end_val.as_constant()), strides(%strides_val.as_constant())
+ vector <size_t> a1_start, a1_size, a2_shape, a3_start, a3_size, a4_shape
+
+ for (int i = 0; i < %input.rank; i++) {
+ if (begin_mask & (1 << i)) {
+ begin[i] = 0
+ }
+
+ if (end_mask & (1 << i)) {
+ end[i] = %input.shape[i]
+ }
+
+ // Wrap around index if begin and end are negative
+ if (begin[i] < 0) {
+ begin[i] += %input.shape[i]
+ }
+
+ if (end[i] < 0) {
+ end[i] += %input.shape[i]
+ }
+
+ a1_start[i] = begin[i]
+ a1_size[i] = end[i] - begin[i]
+
+ a2_shape[i*2 + 0] = a1_size[i] / strides[i]
+ a2_shape[i*2 + 1] = strides[i]
+
+ a3_start[i*2 + 0] = 0
+ a3_start[i*2 + 1] = 0
+
+ if (shrink_axis_mask & (1 << i)) {
+ a3_size[i*2 + 0] = 1
+ } else {
+ a3_size[i*2 + 0] = a1_size[i] / strides[i]
+ }
+ a3_size[i*2 + 1] = 1
+
+    if (!(shrink_axis_mask & (1 << i))) {
+      if (new_axis_mask & (1 << i)) {
+        a4_shape.push_back(1)
+      }
+      a4_shape.push_back((a1_size[i] / strides[i]))
+    }
+  }
+
+ // Step 1: Slice the input array
+ %a1_slice = tosa.SLICE(%input) {start=a1_start, size=a1_size} : (tensor<%input.type>) -> tensor<a1_size, %input.type>
+
+ // Step 2: Reshape the sliced array: 2x as many dimensions as %input
+ %a2_reshape = tosa.RESHAPE(%a1_slice) {new_shape=a2_shape} : (tensor<%a1_slice.type>) -> tensor<a2_shape, %input.type>
+
+ // Step 3: Take a slice of the [0] index along each of the strided dimensions (even dimensions)
+ %a3_slice = tosa.SLICE(%a2_reshape) {start=a3_start, size=a3_size} : (tensor<%a2_reshape.type>) -> tensor<a3_size, %input.type>
+
+ // Step 4: Reshape the now-strided tensor back down to the desired number of dimensions
+ %output = tosa.RESHAPE(%a3_slice) {new_shape=a4_shape} : (tensor<%a3_slice.type>) -> tensor<a4_shape, %input.type>
+
+ return %output
+}
+```
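+
+A shape-only walk-through: %input : tensor<{10}, float> with begin = {1},
+end = {9}, strides = {2} and all masks zero:
+
+```
+a1: SLICE   start={1}, size={8}        -> tensor<{8}, float>
+a2: RESHAPE new_shape={4, 2}           -> tensor<{4, 2}, float>
+a3: SLICE   start={0, 0}, size={4, 1}  -> tensor<{4, 1}, float>
+a4: RESHAPE new_shape={4}              -> tensor<{4}, float>  // elements 1, 3, 5, 7
+```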
+
+### lower_unpack_op()
+
+```
+Value lower_unpack_op(Value %value, size_t axis, uint64_t num)
+{
+  axis = positive_axis(axis, %value.rank)
+
+ Value %output_arr[]
+
+ // Step 1: transpose 'axis' to left-most dimension, if necessary
+ Value %transposed_value
+
+ if (axis != 0) {
+ vector <size_t> perms
+
+ perms.push_back(axis)
+    for (int i = 0; i < %value.rank; i++) {
+ if (i != axis)
+ perms.push_back(i)
+ }
+
+ %transposed_value = tosa.TRANSPOSE(%value) {perms=perms} : (tensor<%value.type>) -> tensor<%transposed_value.shape, %value.dtype>
+
+ } else {
+ %transposed_value = %value
+ }
+
+ // Step 2: Slice [N, A, B, C] into [N] [A, B, C]
+  for (int i = 0; i < num; i++) {
+ vector <size_t> begin_vals, size_vals, shape_vals
+
+ begin_vals.push_back(i)
+ size_vals.push_back(1)
+
+ for (int j = 1; j < %transposed_value.rank; j++) {
+ begin_vals.push_back(0)
+      size_vals.push_back(%transposed_value.shape[j])
+      shape_vals.push_back(%transposed_value.shape[j])
+ }
+
+    %slice = tosa.SLICE(%transposed_value) {start=begin_vals, size=size_vals} : (tensor<%transposed_value.type>) -> tensor<size_vals, %value.dtype>
+    %output_arr[i] = tosa.RESHAPE(%slice) {new_shape=shape_vals} : (tensor<%slice.type>) -> tensor<shape_vals, %value.dtype>
+ }
+
+ // Combine array of sliced tensors into a list of tensors
+  %output = tosa.IDENTITYN(%output_arr) : (tensor<%output_arr:*.type>) -> tensor<%output_arr:*.type>
+ return %output
+}
+```
+
+### get_transpose_conv2d_padding_values_from_pad_type()
+
+```
+vector<int64_t> get_transpose_conv2d_padding_values_from_pad_type(tensorflow::Padding padding, tensorflow::TensorFormat data_format,
+                                                                  uint32_t first_filter_spatial_dim, type input_type, type filter_type,
+                                                                  vector<int64_t> output_shape, vector strides, vector dilations)
+{
+ int64_t pad_before, pad_after;
+ vector<int64_t> computed_padding
+
+ for (int i = 0; i < 2; i++) {
+ int64_t ifm_dim = GetTensorSpatialDimIndex(4, data_format, i);
+ int64_t ofm_dim = GetTensorSpatialDimIndex(4, data_format, i);
+    int64_t filter_dim = first_filter_spatial_dim + i
+
+    int64_t ifm_size = input_type.shape[ifm_dim]
+    int64_t ofm_size = output_shape[ofm_dim]
+    int64_t filter_size = filter_type.shape[filter_dim]
+ int64_t dim_dilation = dilations[i]
+ int64_t dim_stride = strides[i]
+ int effective_filter_size = (filter_size - 1) * dim_dilation + 1
+ int total_padding = ((ifm_size - 1) * dim_stride + effective_filter_size - ofm_size)
+ total_padding = total_padding > 0 ? total_padding : 0
+
+ pad_before = total_padding / 2
+ pad_after = total_padding - pad_before
+
+ computed_padding.push_back(pad_before)
+ }
+
+ return computed_padding
+}
+```
+
+### lower_fused_activation()
+
+```
+Value lower_fused_activation(Value %input, string activation)
+{
+ // TODO: fill in this function
+}
+```
+
+### get_table_const_tensor()
+
+```
+Value get_table_const_tensor(function func)
+{
+ // TODO: fill in this function
+}
+```
+
+## MLIR Passes Management
+
+Legalization is built on multiple MLIR passes.
+
+| MLIR Pass Name            | Input Dialect   | Output Dialect  | Description                                                  |
+| ------------------------- | --------------- | --------------- | ------------------------------------------------------------ |
+| legalize_tf               | TensorFlow      | TOSA            | Legalize TensorFlow dialect to TOSA dialect                   |
+| legalize_tflite           | TensorFlow Lite | TOSA            | Legalize TensorFlow Lite dialect to TOSA dialect              |
+| convert_tflite_qu8_to_qi8 | TensorFlow Lite | TensorFlow Lite | Convert quantized uint8 graph to int8 graph                   |
+| constant_folding          | TOSA            | TOSA            | Constant folding with memory ops into single constant        |
+| make_broadcastable        | TOSA            | TOSA            | Reshape binary op inputs to have same rank to run broadcast  |
+
+The pass list can be summarized with the following pseudocode:
+
+```
+void generate_tosa(mlir::Module module, dialect_t input_dialect)
+{
+ mlir::PassManager pm
+
+  switch (input_dialect) {
+    case TF:
+      pm.addPass(legalize_tf)
+      break
+    case TFLite:
+      pm.addPass(convert_tflite_qu8_to_qi8)
+      pm.addPass(legalize_tflite)
+      break
+    default:
+      break
+  }
+
+ pm.addPass(constant_folding)
+ pm.addPass(make_broadcastable)
+
+ pm.run(module)
+}
+```
+
+Each of the passes is described in more detail in the subsequent chapters.
+
+## TensorFlow MLIR Dialect Legalization (legalize_tf)
+
+### tf.Abs
+
+This operator is trivially lowered to tosa.ABS.
+
+### tf.AddN
+
+**TensorFlow Dialect**
+
+```
+%output = tf.AddN(%inputs)
+```
+
+**TOSA Lowering**
+
+```
+%output = tosa.ADD(%inputs:0, %inputs:1) : (tensor<%inputs:0.type>, tensor<%inputs:1.type>) -> tensor<%output.type>
+for (int i = 2; i < %inputs.size; i++) {
+ %output = tosa.ADD(%inputs:i, %output) : (tensor<%inputs:i.type>, tensor<%output.type>) -> tensor<%output.type>
+}
+```
+
+### tf.Add
+
+Element-wise addition.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.Add(%x, %y)
+```
+
+**TOSA Lowering** This operator is trivially lowered to tosa.ADD.
+
+### tf.Addv2
+
+Element-wise addition.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.Addv2(%x, %y)
+```
+
+**TOSA Lowering** This operator is trivially lowered to tosa.ADD.
+
+### tf.All
+
+Computes the "logical and" of elements across dimensions of a tensor.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.All(%input, %reduction_indices) {keep_dims}
+```
+
+**TOSA Lowering**
+
+```
+%output = lower_reduce_op<tosa.REDUCE_ALL>(%input, %output.shape, %reduction_indices, keep_dims)
+```
+
+### tf.Any
+
+Computes the "logical or" of elements across dimensions of a tensor.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.Any(%input, %reduction_indices) {keep_dims}
+```
+
+**TOSA Lowering**
+
+```
+%output = lower_reduce_op<tosa.REDUCE_ANY>(%input, %output.shape, %reduction_indices, keep_dims)
+```
+
+### tf.ArgMax
+
+Returns the index with the largest value across the given axis of the input
+tensor.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.ArgMax(%input, %dimension)
+```
+
+**TOSA Lowering**
+
+```
+int64_t axis = positive_axis(%dimension)
+%output = tosa.ARGMAX(%input) {axis=axis} : (tensor<%input.type>) -> tensor<%output.type>
+```
+
+### tf.ArgMin
+
+Returns the index with the smallest value across the given axis of the input
+tensor.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.ArgMin(%input, %dimension)
+```
+
+**TOSA Lowering**
+
+No TOSA lowering defined.
+
+### tf.Assert
+
+Asserts that the given condition is true.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.Assert(%condition, %summarize)
+```
+
+**TOSA Lowering**
+
+No TOSA lowering defined.
+
+### tf.AssignAddVariableOp
+
+Adds a value to the current value of a variable.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.AssignAddVariableOp(%resource, %value, %dtype)
+```
+
+**TOSA Lowering**
+
+No TOSA lowering defined.
+
+### tf.AssignSubVariableOp
+
+Subtracts a value to the current value of a variable.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.AssignSubVariableOp(%resource, %value, %dtype)
+```
+
+**TOSA Lowering**
+
+No TOSA lowering defined.
+
+### tf.AssignVariableOp
+
+Assigns a new value to a variable.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.AssignVariableOp(%resource, %value, %dtype)
+```
+
+**TOSA Lowering**
+
+No TOSA lowering defined.
+
+### tf.AvgPool
+
+Performs average pooling on the input.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.AvgPool(%value) {ksize, strides, padding, data_format}
+```
+
+**TOSA Lowering**
+
+```
+assert(data_format == "NHWC")
+
+tosa_padding =
+    get_padding_values_from_pad_type(%value, ksize, padding, data_format,
+ FORMAT_OHWI, strides, {1, 1, 1, 1})
+%output = tosa.AVG_POOL2D(%value) {ksize=ksize, strides=strides, padding=tosa_padding} : (tensor<%value.type>) -> tensor<%output.type>
+```
+
+### tf.BatchMatMul
+
+Multiplies slices of two tensors in batches.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.BatchMatMul(%x, %y, %adj_x, %adj_y)
+```
+
+**TOSA Lowering**
+
+No TOSA lowering defined.
+
+### tf.BatchMatMulV2
+
+Multiplies slices of two tensors in batches.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.BatchMatMulV2(%x, %y, %adj_x, %adj_y)
+```
+
+**TOSA Lowering**
+
+No TOSA lowering defined.
+
+### tf.BatchNormWithGlobalNormalization
+
+✗ Deprecated operator.
+
+### tf.BatchToSpaceND
+
+BatchToSpaceND for N-D tensors of type T.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.BatchToSpaceND(%input, %block_shape, %crops)
+```
+
+**TOSA Lowering**
+
+```
+%output = lower_batch_to_space_nd_op(%input, %block_shape, %crops, output.shape)
+```
+
+### tf.BiasAddGrad
+
+Training profile: TOSA lowering not yet defined.
+
+### tf.BiasAdd
+
+Add bias to value.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.BiasAdd(%bias, %value) {data_format}
+```
+
+**TOSA Lowering**
+
+```
+assert(data_format == 'NHWC')
+%bcast_value = apply_rank_broadcast(%value, %bias)
+%bcast_bias = apply_rank_broadcast(%bias, %value)
+%output = tosa.ADD(%bcast_value, %bcast_bias) : (tensor<%bcast_value.type>, tensor<%bcast_bias.type>) -> tensor<%output.type>
+```
+
+### tf.BitCast
+
+Bitcasts a tensor from one type to another without copying data.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.BitCast(%input, %dtype)
+```
+
+**TOSA Lowering**
+
+No TOSA lowering defined.
+
+### tf.BitwiseAnd
+
+This operator is trivially lowered to tosa.BITWISE_AND.
+
+### tf.BitwiseOr
+
+This operator is trivially lowered to tosa.BITWISE_OR.
+
+### tf.BroadcastGradientArgs
+
+Training profile: TOSA lowering not yet defined.
+
+### tf.BroadcastTo
+
+No TOSA lowering defined.
+
+### tf.Cast
+
+This operator is trivially lowered to tosa.CAST.
+
+### tf.Ceil
+
+This operator is trivially lowered to tosa.CEIL.
+
+### tf.CheckNumerics
+
+No TOSA lowering defined.
+
+### tf.ComplexAbs
+
+No TOSA lowering defined.
+
+### tf.Complex
+
+No TOSA lowering defined.
+
+### tf.ConcatOffset
+
+No TOSA lowering defined. Training profile: TOSA lowering not yet defined.
+
+### tf.Concat
+
+No TOSA lowering defined.
+
+### tf.ConcatV2
+
+Concatenates tensors along one dimension.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.ConcatV2(%values, %axis)
+```
+
+**TOSA Lowering**
+
+```
+%output = lower_concatv2_op(%values, %axis)
+```
+
+### tf.Conj
+
+No TOSA lowering defined.
+
+### tf.Const
+
+This operator is trivially lowered to tosa.CONST.
+
+### tf.Conv2DBackpropFilter
+
+No TOSA lowering defined.
+
+### tf.Conv2DBackpropInput
+
+Computes the gradients of convolution with respect to the input.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.Conv2DBackpropInput(%input_sizes, %filter, %out_backprop) {strides, use_cudnn_on_gpu, padding, explicit_paddings, data_format, dilations}
+```
+
+**TOSA Lowering**
+
+```
+// Transpose filter from HWIO to OHWI
+%tosa_filter = tosa.TRANSPOSE(%filter) {perms={2, 0, 1, 3}} : (tensor<%filter.type>) -> tensor<%tosa_filter.type>
+
+vector output_shape
+
+for (int i = 0; i < %input_sizes.num_elements; i++) {
+  output_shape.push_back(%input_sizes.as_constant()[i])
+}
+
+if (%padding == "EXPLICIT") {
+ tosa_padding =
+ get_padding_values_from_explicit_pad_attr(explict_padding, data_format)
+} else {
+ tosa_padding =
+ get_transpose_conv2d_padding_values_from_pad_type(%input_sizes, %filter, output_shape, padding, data_format, FORMAT_HWIO, strides, dilations)
+}
+
+// Create a zero bias tensor
+%zero_bias = tosa.CONST() {value=0} () -> tensor<{1}, %input.dtype>
+%output = tosa.TRANSPOSE_CONV2D(%out_backprop) {weight=%tosa_filter, bias=%zero_bias, outpad=tosa_padding, stride=strides, dilation=dilations, out_shape=output_shape} : (tensor<%out_backprop.type>) -> tensor<%output.type>
+```
+
+### tf.Conv2D
+
+Computes a 2-D convolution given 4-D input and filter tensors.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.Conv2D(%input, %filter) {strides, padding, explicit_paddings, data_format, dilations}
+```
+
+**TOSA Lowering**
+
+```
+assert(data_format == "NHWC")
+
+// Transpose filter from HWIO to OHWI
+%filter_transpose = tosa.TRANSPOSE(%filter) {perms={3, 0, 1, 2}} : (tensor<%filter.type>) -> tensor<%filter_transpose.type>
+
+if (padding == "EXPLICIT") {
+ tosa_padding =
+ get_padding_values_from_explicit_pad_attr(explict_padding, data_format)
+} else {
+ %tosa_padding =
+ get_padding_values_from_pad_type(%input, %filter.shape, padding, data_format,
+ FORMAT_HWIO, strides, dilations)
+}
+
+// Create a zero bias tensor
+%zero_bias = tosa.CONST() {value=0} () -> tensor<{1}, %input.dtype>
+
+%output = tosa.CONV2D(%input, %filter_transpose, %zero_bias) {padding=tosa_padding, stride=strides, dilation=dilations} : (tensor<%input.type>, tensor<%filter_transpose.type>, tensor<%zero_bias.type>) -> tensor<%output.type>
+```
+
+### tf.Conv3D
+
+TOSA lowering to tosa.CONV3D to be defined.
+
+### tf.Cos
+
+No TOSA lowering defined.
+
+### tf.CrossReplicaSum
+
+No TOSA lowering defined.
+
+### tf.DepthToSpace
+
+DepthToSpaceND for tensors of type T.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.DepthToSpace(%input) {block_size, data_format}
+```
+
+**TOSA Lowering**
+
+```
+%output = lower_depth_to_space_op(%input, block_size, data_format)
+```
+
+### tf.DepthwiseConv2dNative
+
+Computes a 2-D depthwise convolution given 4-D input and filter tensors.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.DepthwiseConv2dNative(%input, %filter) {strides, padding, data_format, dilations}
+```
+
+**TOSA Lowering**
+
+```
+if (padding == "EXPLICIT") {
+ tosa_padding =
+      get_padding_values_from_explicit_pad_attr(explicit_paddings, data_format)
+} else {
+ tosa_padding =
+ get_padding_values_from_pad_type(%input, %filter.shape, padding, data_format,
+ FORMAT_HWIO, strides, dilations)
+}
+
+bias_dim = %filter.shape[2] * %filter.shape[3]
+
+// Create a zero-bias tensor
+%zero_bias = tosa.CONST() {value={0} * bias_dim} () -> tensor<{bias_dim}, %input.dtype>
+
+%output = tosa.DEPTHWISE_CONV2D(%input, %filter, %zero_bias) {stride=strides, dilation=dilations, padding=padding} : (tensor<%input.type>, tensor<%filter.type>, tensor<%zero_bias.type>) -> tensor<%output.type>
+```
+
+### tf.DivNoNan
+
+No TOSA lowering defined.
+
+### tf.Div
+
+No TOSA lowering defined.
+
+### tf.DynamicStitch
+
+No TOSA lowering defined.
+
+### tf.Einsum
+
+No TOSA lowering defined.
+
+### tf.Elu
+
+Computes exponential linear: exp(features) - 1 if features < 0, features otherwise.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.Elu(%features)
+```
+
+**TOSA Lowering**
+
+```
+%output = lower_elu_op(%features)
+```
+
+### tf.EmptyTensorList
+
+No TOSA lowering defined.
+
+### tf.Equal
+
+Returns the truth value of (x == y) element-wise with broadcasting.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.Equal(%x, %y)
+```
+
+**TOSA Lowering** This operator is trivially lowered to tosa.EQUAL.
+
+### tf.Exp
+
+This operator is trivially lowered to tosa.EXP.
+
+### tf.ExpandDims
+
+Inserts a dimension of 1 into a tensor's shape.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.ExpandDims(%input, %axis)
+```
+
+**TOSA Lowering**
+
+```
+%output = lower_expand_dims(%input, %axis.to_constant())
+```
+
+### tf.FakeQuantWithMinMaxArgs
+
+Fake-quantize the 'inputs' tensor, type float to 'outputs' tensor of same type.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.FakeQuantWithMinMaxArgs(%inputs) {min, max, num_bits, narrow_range}
+```
+
+**TOSA Lowering**
+
+```
+%output = lower_fake_quant_op(%inputs, %output.type, min, max, num_bits, narrow_range)
+```
+
+### tf.FakeQuantWithMinMaxVars
+
+Fake-quantize the 'inputs' tensor of type float via global float scalars min and max.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.FakeQuantWithMinMaxVars(%inputs, %min, %max) {num_bits, narrow_range}
+```
+
+**TOSA Lowering**
+
+```
+%output = lower_fake_quant_op(%inputs, %output.type, %min.to_constant(), %max.to_constant(), num_bits, narrow_range)
+```
+
+### tf.FakeQuantWithMinMaxVarsPerChannel
+
+Fake-quantize the 'inputs' tensor of type float and one of the shapes \[d\].
+
+**TensorFlow Dialect**
+
+```
+%output = tf.FakeQuantWithMinMaxVarsPerChannel(%inputs, %min, %max) {num_bits, narrow_range}
+```
+
+No TOSA lowering defined.
+
+### tf.Fill
+
+Creates a tensor filled with a scalar value
+
+**TensorFlow Dialect**
+
+```
+%output = tf.Fill(%dims, %value)
+```
+
+**TOSA Lowering**
+
+```
+int64_t total_size = 1
+
+for (int i = 0; i < %dims.shape[0]; i++) {
+ total_size *= %dims[i]
+}
+
+vector<%value.dtype> fill_arr(total_size, %value)
+
+%output = tosa.CONST() {value=fill_arr} () -> tensor<%output.type>
+```
+
+### tf.FloorDiv
+
+Returns x // y element-wise.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.FloorDiv(%x, %y)
+```
+
+**TOSA Lowering**
+
+```
+%output = lower_floor_div(%x, %y)
+```
+
+### tf.FloorMod
+
+Returns element-wise remainder of division when x < 0 xor y < 0 is true.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.FloorMod(%x, %y)
+```
+
+**TOSA Lowering**
+
+```
+%output = lower_floor_mod(%x, %y)
+```
+
+### tf.Floor
+
+This operator is trivially lowered to tosa.FLOOR.
+
+### tf.FusedBatchNormGrad
+
+Training profile: TOSA lowering not yet defined.
+
+### tf.FusedBatchNormGradV2
+
+Training profile: TOSA lowering not yet defined.
+
+### tf.FusedBatchNormGradV3
+
+Training profile: TOSA lowering not yet defined.
+
+### tf.FusedBatchNorm
+
+Batch normalization.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.FusedBatchNorm(%x, %scale, %offset, %mean, %variance) {epsilon, data_format, is_training}
+```
+
+**TOSA Lowering**
+
+```
+assert(data_format == 'NHWC')
+assert(is_training == false)
+
+%epsilon_const = tosa.CONST() {value={epsilon}} : () -> tensor<{1}, float>
+
+%op1 = tosa.SUB(%x, %mean) : (tensor<%x.type>, tensor<%mean.type>) -> tensor<%op1.type>
+%op2 = tosa.ADD(%variance, %epsilon_const) : (tensor<%variance.type>, tensor<%epsilon_const.type>) -> tensor<%op2.type>
+%op3 = tosa.RSQRT(%op2) : (tensor<%op2.type>) -> tensor<%op3.type>
+%op4 = tosa.MUL(%op1, %op3) : (tensor<%op1.type>, tensor<%op3.type>) -> tensor<%op4.type>
+%op5 = tosa.MUL(%op4, %scale) : (tensor<%op4.type>, tensor<%scale.type>) -> tensor<%op5.type>
+%output = tosa.ADD(%op5, %offset) : (tensor<%op5.type>, tensor<%offset.type>) -> tensor<%output.type>
+```
+
+### tf.FusedBatchNormV3
+
+Batch normalization.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.FusedBatchNormV3(%x, %scale, %offset, %mean, %variance) {epsilon, data_format, is_training}
+```
+
+**TOSA Lowering**
+
+```
+assert(data_format == 'NHWC')
+assert(is_training == false)
+
+%epsilon_const = tosa.CONST() {value={epsilon}} : () -> tensor<{1}, float>
+
+%op1 = tosa.SUB(%x, %mean) : (tensor<%x.type>, tensor<%mean.type>) -> tensor<%op1.type>
+%op2 = tosa.ADD(%variance, %epsilon_const) : (tensor<%variance.type>, tensor<%epsilon_const.type>) -> tensor<%op2.type>
+%op3 = tosa.RSQRT(%op2) : (tensor<%op2.type>) -> tensor<%op3.type>
+%op4 = tosa.MUL(%op1, %op3) : (tensor<%op1.type>, tensor<%op3.type>) -> tensor<%op4.type>
+%op5 = tosa.MUL(%op4, %scale) : (tensor<%op4.type>, tensor<%scale.type>) -> tensor<%op5.type>
+%output = tosa.ADD(%op5, %offset) : (tensor<%op5.type>, tensor<%offset.type>) -> tensor<%output.type>
+```
+
+### tf.GatherNd
+
+No TOSA lowering defined.
+
+### tf.Gather
+
+Gathers slices from params according to indices.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.Gather(%params, %indices) {validate_indices}
+```
+
+**TOSA Lowering**
+
+```
+%output = tosa.GATHER(%params, %indices) {axis=0} : (tensor<%params.type>, tensor<%indices.type>) -> tensor<%output.type>
+```
+
+### tf.GatherV2
+
+Gathers slices from the params axis according to indices.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.GatherV2(%params, %indices, %axis) {batch_dims}
+```
+
+**TOSA Lowering**
+
+```
+%output = tosa.GATHER(%params, %indices) {axis=%axis.to_constant()} : (tensor<%params.type>, tensor<%indices.type>) -> tensor<%output.type>
+```
+
+### tf.GreaterEqual
+
+Returns the truth value of (x >= y) element-wise with broadcasting.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.GreaterEqual(%x, %y)
+```
+
+**TOSA Lowering** This operator is trivially lowered to tosa.GREATER_EQUAL.
+
+### tf.Greater
+
+Returns the truth value of (x > y) element-wise with broadcasting.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.Greater(%x, %y)
+```
+
+**TOSA Lowering** This operator is trivially lowered to tosa.GREATER.
+
+### tf.HashTableV2
+
+No TOSA lowering defined.
+
+### tf.IdentityN
+
+Returns a list of tensors with the same shapes and contents as the input.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.IdentityN(%input)
+```
+
+**TOSA Lowering**
+
+```
+%output = tosa.IDENTITYN(%input) : (tensor<%input:*.type>) -> tensor<%output:*.type>
+```
+
+### tf.Identity
+
+Returns a tensor with the same shape and contents as the input.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.Identity(%input)
+```
+
+**TOSA Lowering**
+
+```
+%output = tosa.IDENTITY(%input) : (tensor<%input.type>) -> tensor<%output.type>
+```
+
+### tf.If
+
+No TOSA lowering defined.
+
+### tf.Imag
+
+No TOSA lowering defined.
+
+### tf.InfeedDequeueTuple
+
+No TOSA lowering defined.
+
+### tf.Invert
+
+This operator is trivially lowered to tosa.BITWISE_NOT.
+
+### tf.InvertPermutation
+
+No TOSA lowering defined.
+
+### tf.IsFinite
+
+No TOSA lowering defined.
+
+### tf.IteratorGetNext
+
+No TOSA lowering defined.
+
+### tf.L2Loss
+
+Training profile: TOSA lowering not yet defined.
+
+### tf.LRN
+
+No TOSA lowering defined.
+
+### tf.LeakyRelu
+
+Computes rectified linear: max(features, features \* alpha).
+
+**TensorFlow Dialect**
+
+```
+%output = tf.LeakyRelu(%features) {alpha}
+```
+
+**TOSA Lowering**
+
+```
+%alpha_tensor = tosa.CONST() {value=alpha} : () -> tensor<{1}, alpha.type>
+%features_alpha = tosa.MUL(%features, %alpha_tensor) : (tensor<%features.type>, tensor<%alpha_tensor.type>) -> tensor<%features_alpha.type>
+%greater = tosa.GREATER(%features, %features_alpha) : (tensor<%features.type>, tensor<%features_alpha.type>) -> tensor<%greater.type>
+%output = tosa.SELECT(%greater, %features, %features_alpha) : (tensor<%greater.type>, tensor<%features.type>, tensor<%features_alpha.type>) -> tensor<%output.type>
+```
+
+### tf.LeftShift
+
+Computes the bitwise left-shift of x by y bits, element-wise.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.LeftShift(%x, %y)
+```
+
+**TOSA Lowering** This operator is trivially lowered to tosa.LOGICAL_LEFT_SHIFT.
+
+### tf.LegacyCall
+
+No TOSA lowering defined.
+
+### tf.LessEqual
+
+Returns the truth value of (x <= y) element-wise with broadcasting.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.LessEqual(%x, %y)
+```
+
+**TOSA Lowering**
+
+```
+%bcast_x = apply_rank_broadcast(%x, %y)
+%bcast_y = apply_rank_broadcast(%y, %x)
+
+%output_greater = tosa.GREATER(%bcast_x, %bcast_y) : (tensor<%bcast_x.type>, tensor<%bcast_y.type>) -> tensor<%output_greater.type>
+%output = tosa.LOGICAL_NOT(%output_greater) : (tensor<%output_greater.type>) -> tensor<%output.type>
+```
+
+### tf.Less
+
+Returns the truth value of (x < y) element-wise with broadcasting.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.Less(%x, %y)
+```
+
+**TOSA Lowering**
+
+```
+%bcast_x = apply_rank_broadcast(%x, %y)
+%bcast_y = apply_rank_broadcast(%y, %x)
+
+%output_greater_equal = tosa.GREATER_EQUAL(%bcast_x, %bcast_y) : (tensor<%bcast_x.type>, tensor<%bcast_y.type>) -> tensor<%output_greater_equal.type>
+%output = tosa.LOGICAL_NOT(%output_greater_equal) : (tensor<%output_greater_equal.type>) -> tensor<%output.type>
+```
+
+### tf.LinSpace
+
+No TOSA lowering defined.
+
+### tf.Log1p
+
+No TOSA lowering defined.
+
+### tf.Log
+
+This operator is trivially lowered to tosa.LOG.
+
+### tf.LogSoftmax
+
+Computes log softmax activations.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.LogSoftmax(%logits)
+```
+
+**TOSA Lowering**
+
+```
+%output = lower_log_softmax_op(%logits)
+```
+
+### tf.LogicalAnd
+
+Returns the truth value of x AND y, element-wise.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.LogicalAnd(%x, %y)
+```
+
+**TOSA Lowering** This operator is trivially lowered to tosa.LOGICAL_AND.
+
+### tf.LogicalNot
+
+This operator is trivially lowered to tosa.LOGICAL_NOT.
+
+### tf.LogicalOr
+
+Returns the truth value of x OR y, element-wise.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.LogicalOr(%x, %y)
+```
+
+**TOSA Lowering** This operator is trivially lowered to tosa.LOGICAL_OR.
+
+### tf.LookupTableFindV2
+
+No TOSA lowering defined.
+
+### tf.LookupTableInputV2
+
+No TOSA lowering defined.
+
+### tf.LookupTableSizeV2
+
+No TOSA lowering defined.
+
+### tf.MatMul
+
+Multiply the matrix a by the matrix b.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.MatMul(%a, %b)
+```
+
+**TOSA Lowering**
+
+```
+%output = tosa.MATMUL(%a, %b) : (tensor<%a.type>, tensor<%b.type>) -> tensor<%output.type>
+```
+
+### tf.MatrixDiag
+
+No TOSA lowering defined.
+
+### tf.MatrixDiagV2
+
+No TOSA lowering defined.
+
+### tf.MatrixDiagV3
+
+No TOSA lowering defined.
+
+### tf.MatrixSetDiag
+
+No TOSA lowering defined.
+
+### tf.MatrixSetDiagV2
+
+No TOSA lowering defined.
+
+### tf.MatrixSetDiagV3
+
+No TOSA lowering defined.
+
+### tf.Max
+
+Computes the maximum of elements across dimensions of a tensor.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.Max(%input, %reduction_indices) {keep_dims}
+```
+
+**TOSA Lowering**
+
+```
+%output = lower_reduce_op<tosa.REDUCE_MAX>(%input, %output.shape, %reduction_indices, keep_dims)
+```
+
+### tf.MaxPoolGrad
+
+Training profile: TOSA lowering not yet defined.
+
+### tf.MaxPool
+
+Performs max pooling on the input.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.MaxPool(%input) {ksize, strides, padding, data_format}
+```
+
+**TOSA Lowering**
+
+```
+assert(data_format == "NHWC")
+
+tosa_padding =
+ get_padding_values_from_pad_type(%input, ksize, padding, data_format,
+ FORMAT_OHWI, strides, {1, 1, 1, 1})
+%output = tosa.MAX_POOL2D(%input) {ksize=ksize, strides=strides, padding=tosa_padding} : (tensor<%input.type>) -> tensor<%output.type>
+```
+
+### tf.Maximum
+
+This operator is trivially lowered to tosa.MAXIMUM.
+
+### tf.Mean
+
+Computes the mean of elements across dimensions of a tensor.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.Mean(%input, %reduction_indices) {keep_dims}
+```
+
+**TOSA Lowering**
+
+```
+%output = lower_reduce_op<tosa.REDUCE_MEAN>(%input, %output.shape, %reduction_indices, keep_dims)
+```
+
+### tf.Min
+
+Computes the minimum of elements across dimensions of a tensor.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.Min(%input, %reduction_indices) {keep_dims}
+```
+
+**TOSA Lowering**
+
+```
+%output = lower_reduce_op<tosa.REDUCE_MIN>(%input, %output.shape, %reduction_indices, keep_dims)
+```
+
+### tf.Minimum
+
+This operator is trivially lowered to tosa.MINIMUM.
+
+### tf.MirrorPad
+
+No TOSA lowering defined.
+
+### tf.MlirPassthroughOp
+
+No TOSA lowering defined.
+
+### tf.MulNoNan
+
+No TOSA lowering defined.
+
+### tf.Mul
+
+Returns the product of x and y, element-wise.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.Mul(%x, %y)
+```
+
+**TOSA Lowering** This operator is trivially lowered to tosa.MUL.
+
+### tf.Neg
+
+This operator is trivially lowered to tosa.NEGATE.
+
+### tf.NoOp
+
+No TOSA lowering defined.
+
+### tf.NonMaxSuppressionV4
+
+No TOSA lowering defined.
+
+### tf.NonMaxSuppressionV5
+
+No TOSA lowering defined.
+
+### tf.NotEqual
+
+Returns the truth value of (x != y) element-wise with broadcasting.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.NotEqual(%x, %y)
+```
+
+**TOSA Lowering**
+
+```
+%bcast_x = apply_rank_broadcast(%x, %y)
+%bcast_y = apply_rank_broadcast(%y, %x)
+
+%equal = tosa.EQUAL(%bcast_x, %bcast_y) : (tensor<%bcast_x.type>, tensor<%bcast_y.type>) -> tensor<%equal.type>
+%output = tosa.LOGICAL_NOT(%equal) : (tensor<%equal.type>) -> tensor<%output.type>
+```
+
+### tf.OneHot
+
+No TOSA lowering defined.
+
+### tf.OutfeedEnqueueTuple
+
+No TOSA lowering defined.
+
+### tf.Pack
+
+Packs a list of N rank-R tensors into one rank-(R+1) tensor.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.Pack(%values) {axis}
+```
+
+**TOSA Lowering**
+
+```
+%output = lower_pack_op(%values, axis)
+```
+
+### tf.Pad
+
+This operator is trivially lowered to tosa.PAD.
+
+### tf.PadV2
+
+No TOSA lowering defined.
+
+### tf.ParseExampleV2
+
+No TOSA lowering defined.
+
+### tf.PartitionedCall
+
+No TOSA lowering defined.
+
+### tf.Placeholder
+
+Not seen in practice. No lowering needed.
+
+### tf.PlaceholderWithDefault
+
+Not seen in practice. No lowering needed.
+
+### tf.Pow
+
+This operator is trivially lowered to tosa.POW.
+
+### tf.PreventGradient
+
+Training profile: TOSA lowering not yet defined.
+
+### tf.Prod
+
+Computes the product of elements across dimensions of a tensor.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.Prod(%input, %reduction_indices) {keep_dims}
+```
+
+**TOSA Lowering**
+
+```
+%output = lower_reduce_op<tosa.REDUCE_PRODUCT>(%input, %output.shape, %reduction_indices, keep_dims)
+```
+
+### tf.QuantizeAndDequantize
+
+No TOSA lowering defined.
+
+### tf.QuantizeAndDequantizeV2
+
+No TOSA lowering defined.
+
+### tf.QuantizeAndDequantizeV3
+
+No TOSA lowering defined.
+
+### tf.RFFT
+
+No TOSA lowering defined.
+
+### tf.RandomShuffle
+
+No TOSA lowering defined.
+
+### tf.RandomStandardNormal
+
+No TOSA lowering defined.
+
+### tf.RandomUniform
+
+No TOSA lowering defined.
+
+### tf.Range
+
+No TOSA lowering defined.
+
+### tf.Rank
+
+Returns the rank of the tensor.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.Rank(%input)
+```
+
+**TOSA Lowering**
+
+```
+%output = tosa.CONST() {value=%input.rank} : () -> tensor<{1}, int64_t>
+```
+
+### tf.ReadVariableOp
+
+No TOSA lowering defined.
+
+### tf.RealDiv
+
+Returns x / y element-wise for real types.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.RealDiv(%x, %y)
+```
+
+**TOSA Lowering**
+
+```
+%recip = tosa.RECIPROCAL(%y) : (tensor<%y.type>) -> tensor<%recip.type>
+%output = tosa.MUL(%x, %recip) : (tensor<%x.type>, tensor<%recip.type>) -> tensor<%output.type>
+```
+
+### tf.Real
+
+No TOSA lowering defined.
+
+### tf.Reciprocal
+
+This operator is trivially lowered to tosa.RECIPROCAL.
+
+### tf.Relu6
+
+Computes rectified linear 6: min(max(features, 0), 6).
+
+**TensorFlow Dialect**
+
+```
+%output = tf.Relu6(%features)
+```
+
+**TOSA Lowering**
+
+```
+%output = tosa.RELUN(%features) {max_val=6} : (tensor<%features.type>) -> tensor<%output.type>
+```
+
+### tf.ReluGrad
+
+Training profile: TOSA lowering not yet defined.
+
+### tf.Relu
+
+Computes rectified linear: max(features, 0).
+
+**TensorFlow Dialect**
+
+```
+%output = tf.Relu(%features)
+```
+
+**TOSA Lowering**
+
+```
+%output = tosa.RELUN(%features) {max_val=FLT_MAX} : (tensor<%features.type>) -> tensor<%output.type>
+```
+
+### tf.Reshape
+
+Reshapes a tensor.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.Reshape(%tensor, %shape)
+```
+
+**TOSA Lowering**
+
+```
+%output = tosa.RESHAPE(%tensor) {new_shape=%shape.as_constant} : (tensor<%tensor.type>) -> tensor<%output.type>
+```
+
+### tf.ResizeBilinear
+
+Resizes images to size using bilinear interpolation.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.ResizeBilinear(%images, %size) {align_corners, half_pixel_centers}
+```
+
+The output size is inferred from the output shape.
+
+**TOSA Lowering**
+
+```
+%output = lower_resize_op(%images, %size, float, BILINEAR)
+```
+
+### tf.ResizeNearestNeighbor
+
+Resizes images to size using nearest neighbor interpolation.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.ResizeNearestNeighbor(%images, %size) {align_corners, half_pixel_centers}
+```
+
+The output size is inferred from the output shape.
+
+**TOSA Lowering**
+
+```
+%output = lower_resize_op(%images, %size, %output, float, NEAREST)
+```
+
+### tf.ResourceApplyAdam
+
+Training profile: TOSA lowering not yet defined.
+
+### tf.ResourceApplyGradientDescent
+
+Training profile: TOSA lowering not yet defined.
+
+### tf.ResourceApplyKerasMomentum
+
+Training profile: TOSA lowering not yet defined.
+
+### tf.ResourceGather
+
+Training profile: TOSA lowering not yet defined.
+
+### tf.ResourceScatterUpdate
+
+Training profile: TOSA lowering not yet defined.
+
+### tf.ReverseSequence
+
+No TOSA lowering defined.
+
+### tf.ReverseV2
+
+Reverses specific dimensions of a tensor.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.ReverseV2(%tensor, %axis)
+```
+
+**TOSA Lowering**
+
+```
+%output = lower_reversev2_op(%tensor, %axis)
+```
+
+### tf.RightShift
+
+Computes the bitwise right-shift of x by y bits, element-wise.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.RightShift(%x, %y)
+```
+
+**TOSA Lowering**
+
+```
+%bcast_x = apply_rank_broadcast(%x, %y)
+%bcast_y = apply_rank_broadcast(%y, %x)
+if (is_unsigned(%x.dtype)) {
+ %output = tosa.LOGICAL_RIGHT_SHIFT(%bcast_x, %bcast_y) : (tensor<%bcast_x.type>, tensor<%bcast_y.type>) -> tensor<%output.type>
+} else {
+ %output = tosa.ARITHMETIC_RIGHT_SHIFT(%bcast_x, %bcast_y) : (tensor<%bcast_x.type>, tensor<%bcast_y.type>) -> tensor<%output.type>
+}
+```
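+
+The signed path uses tosa.ARITHMETIC_RIGHT_SHIFT so the sign bit is replicated, while the unsigned path shifts in zeros via tosa.LOGICAL_RIGHT_SHIFT, matching TensorFlow's right-shift semantics for the two integer families.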
+
+### tf.Round
+
+Rounds the values of a tensor to the nearest integer, element-wise.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.Round(%x)
+```
+
+**TOSA Lowering**
+
+```
+%output = lower_round_op(%x)
+```
+
+### tf.RsqrtGrad
+
+Training profile: TOSA lowering not yet defined.
+
+### tf.Rsqrt
+
+This operator is trivially lowered to tosa.RSQRT.
+
+### tf.SegmentMax
+
+No TOSA lowering defined.
+
+### tf.SegmentMean
+
+No TOSA lowering defined.
+
+### tf.SegmentMin
+
+No TOSA lowering defined.
+
+### tf.SegmentProd
+
+No TOSA lowering defined.
+
+### tf.SegmentSum
+
+No TOSA lowering defined.
+
+### tf.Select
+
+No TOSA lowering defined.
+
+### tf.SelectV2
+
+Selects elements from t or e depending on condition.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.SelectV2(%condition, %t, %e)
+```
+
+**TOSA Lowering**
+
+```
+%output = lower_selectv2_op(%condition, %t, %e, %output.shape)
+```
+
+### tf.ShapeN
+
+No TOSA lowering defined.
+
+### tf.Shape
+
+Returns the shape of a tensor.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.Shape(%input)
+```
+
+**TOSA Lowering**
+
+```
+%output = lower_shape_op(%input)
+```
+
+### tf.Sigmoid
+
+This operator is trivially lowered to tosa.SIGMOID.
+
+### tf.Sign
+
+No TOSA lowering defined.
+
+### tf.Sin
+
+No TOSA lowering defined.
+
+### tf.Size
+
+No TOSA lowering defined.
+
+### tf.Slice
+
+Returns a slice from input.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.Slice(%input, %begin, %size)
+```
+
+**TOSA Lowering**
+
+```
+vector<size_t> output_size
+try {
+   output_size = %size.as_constant()
+} catch(ConversionFailed) {
+   output_size = %output.shape
+}
+
+%output = tosa.SLICE(%input) {start=begin, size=output_size} : (tensor<%input.type>) -> tensor<output_size, %input.dtype>
+```
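+
+A %size operand may contain -1 entries (meaning "all remaining elements" in TensorFlow) or may not be a compile-time constant; in either case the statically inferred output shape is used instead.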
+
+### tf.Snapshot
+
+No TOSA lowering defined.
+
+### tf.SoftmaxCrossEntropyWithLogits
+
+Training profile: TOSA lowering not yet defined.
+
+### tf.Softmax
+
+Computes softmax activations.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.Softmax(%logits)
+```
+
+**TOSA Lowering**
+
+```
+%op1 = tosa.EXP(%logits) : (tensor<%logits.type>) -> tensor<%op1.type>
+%op2 = tosa.REDUCE_SUM(%op1) {reduce_axis=(%logits.rank - 1)} : (tensor<%op1.type>) -> tensor<%op2.type>
+%op3 = tosa.RECIPROCAL(%op2) : (tensor<%op2.type>) -> tensor<%op3.type>
+%output = tosa.MUL(%op1, %op3) : (tensor<%op1.type>, tensor<%op3.type>) -> tensor<%output.type>
+```
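+
+For reference, the same arithmetic written as scalar C++ (a minimal sketch of the decomposition above; note there is no max-subtraction stabilization, since the op sequence above does not perform one):
+
+```
+#include <cmath>
+#include <vector>
+
+// out[i] = exp(x[i]) * (1 / sum_j exp(x[j])), mirroring the
+// EXP -> REDUCE_SUM -> RECIPROCAL -> MUL sequence along the last axis.
+std::vector<float> softmax_last_axis(const std::vector<float>& logits) {
+  std::vector<float> e(logits.size());
+  float sum = 0.0f;
+  for (size_t i = 0; i < logits.size(); ++i) {
+    e[i] = std::exp(logits[i]);  // tosa.EXP
+    sum += e[i];                 // tosa.REDUCE_SUM
+  }
+  const float rcp = 1.0f / sum;  // tosa.RECIPROCAL
+  for (float& v : e) v *= rcp;   // tosa.MUL
+  return e;
+}
+```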
+
+### tf.Softplus
+
+No TOSA lowering defined.
+
+### tf.SpaceToBatchND
+
+SpaceToBatch for N-D tensors of type T.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.SpaceToBatchND(%input, %block_shape, %paddings)
+```
+
+**TOSA Lowering**
+
+```
+%output = lower_space_to_batch_nd_op(%input, %block_shape, %paddings)
+```
+
+### tf.SpaceToDepth
+
+SpaceToDepth for tensors of type T.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.SpaceToDepth(%input) {block_size, data_format}
+```
+
+**TOSA Lowering**
+
+```
+%output = lower_space_to_depth_op(%input, block_size, data_format)
+```
+
+### tf.SparseMatMul
+
+No TOSA lowering defined.
+
+### tf.SparseSoftmaxCrossEntropyWithLogits
+
+No TOSA lowering defined.
+
+### tf.SparseToDense
+
+No TOSA lowering defined.
+
+### tf.Split
+
+Splits a tensor into num_split tensors along one dimension.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.Split(%split_dim, %value) {num_split}
+```
+
+**TOSA Lowering**
+
+```
+%output = lower_split_op(%value, %split_dim.as_constant(), num_split)
+```
+
+### tf.SplitV
+
+Splits a tensor into num_split tensors along one dimension.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.SplitV(%value, %size_splits, %split_dim) {num_split}
+```
+
+**TOSA Lowering**
+
+```
+%output = lower_splitv_op(%value, %size_splits.as_constant(), %split_dim.as_constant())
+```
+
+### tf.Sqrt
+
+No TOSA lowering defined.
+
+### tf.Square
+
+Computes the square of x, element-wise.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.Square(%x)
+```
+
+**TOSA Lowering**
+
+```
+%output = tosa.MUL(%x, %x) : (tensor<%x.type>, tensor<%x.type>) -> tensor<%output.type>
+```
+
+### tf.SquaredDifference
+
+Computes (x-y)\*(x-y) element-wise.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.SquaredDifference(%x, %y)
+```
+
+**TOSA Lowering**
+
+```
+%diff = tosa.SUB(%x, %y) : (tensor<%x.type>, tensor<%y.type>) -> tensor<%diff.type>
+%output = tosa.MUL(%diff, %diff) : (tensor<%diff.type>, tensor<%diff.type>) -> tensor<%output.type>
+```
+
+### tf.Squeeze
+
+Removes dimensions of size 1 from the shape of a tensor.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.Squeeze(%input) {squeeze_dims}
+```
+
+**TOSA Lowering**
+
+```
+%output = lower_squeeze_op(%input, squeeze_dims)
+```
+
+### tf.StatefulPartitionedCall
+
+No TOSA lowering defined.
+
+### tf.StopGradient
+
+Training profile: TOSA lowering not yet defined.
+
+### tf.StridedSliceGrad
+
+Training profile: TOSA lowering not yet defined.
+
+### tf.StridedSlice
+
+Return a strided slice from input.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.StridedSlice(%input, %begin, %end, %strides) {begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask}
+```
+
+**TOSA Lowering**
+
+```
+%output = lower_strided_slice_op(%input, %begin, %end, %strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask)
+```
+
+### tf.Sub
+
+This operator is trivially lowered to tosa.SUB.
+
+### tf.Sum
+
+Computes the sum of elements across dimensions of a tensor.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.Sum(%input, %reduction_indices) {keep_dims}
+```
+
+**TOSA Lowering**
+
+```
+%output = lower_reduce_op<tosa.REDUCE_SUM>(%input, %output.shape, %reduction_indices, keep_dims)
+```
+
+### tf.TPUCompilationResult
+
+No TOSA lowering defined.
+
+### tf.TPUCopyWithLayout
+
+No TOSA lowering defined.
+
+### tf.TPUExecuteAndUpdateVariables
+
+No TOSA lowering defined.
+
+### tf.TPUExecute
+
+No TOSA lowering defined.
+
+### tf.TPUGetLayout
+
+No TOSA lowering defined.
+
+### tf.TPUReplicateMetadata
+
+No TOSA lowering defined.
+
+### tf.TPUReplicatedInput
+
+No TOSA lowering defined.
+
+### tf.TPUReplicatedOutput
+
+No TOSA lowering defined.
+
+### tf.TPUReshardVariables
+
+No TOSA lowering defined.
+
+### tf.TanhGrad
+
+Training profile: TOSA lowering not yet defined.
+
+### tf.Tanh
+
+This operator is trivially lowered to tosa.TANH.
+
+### tf.TensorListFromTensor
+
+No TOSA lowering defined.
+
+### tf.TensorListGetItem
+
+No TOSA lowering defined.
+
+### tf.TensorListLength
+
+No TOSA lowering defined.
+
+### tf.TensorListPushBack
+
+No TOSA lowering defined.
+
+### tf.TensorListReserve
+
+No TOSA lowering defined.
+
+### tf.TensorListResize
+
+No TOSA lowering defined.
+
+### tf.TensorListSetItem
+
+No TOSA lowering defined.
+
+### tf.TensorListStack
+
+No TOSA lowering defined.
+
+### tf.TensorScatterUpdate
+
+No TOSA lowering defined.
+
+### tf.Tile
+
+Constructs a tensor by tiling a given tensor.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.Tile(%input, %multiples)
+```
+
+**TOSA Lowering**
+
+```
+%output = tosa.TILE(%input) {multiples=%multiples.as_constant()} : (tensor<%input.type>) -> tensor<%output.shape, %input.type>
+```
+
+### tf.ToBool
+
+No TOSA lowering defined.
+
+### tf.TopKV2
+
+No TOSA lowering defined.
+
+### tf.Transpose
+
+Shuffle dimensions of x according to a permutation.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.Transpose(%x, %perm)
+```
+
+**TOSA Lowering**
+
+```
+%output = tosa.TRANSPOSE(%x) {perm=%perm.as_constant()} : (tensor<%x.type>) -> tensor<%output.type>
+```
+
+### tf.TruncateDiv
+
+No TOSA lowering defined.
+
+### tf.Unique
+
+No TOSA lowering defined.
+
+### tf.Unpack
+
+Unpacks a given dimension of a rank-R tensor into num rank-(R-1) tensors.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.Unpack(%value) {axis, num}
+```
+
+**TOSA Lowering**
+
+```
+%output = lower_unpack_op(%value, axis, num)
+```
+
+### tf.UnsortedSegmentMax
+
+No TOSA lowering defined.
+
+### tf.UnsortedSegmentMin
+
+No TOSA lowering defined.
+
+### tf.UnsortedSegmentProd
+
+No TOSA lowering defined.
+
+### tf.UnsortedSegmentSum
+
+No TOSA lowering defined.
+
+### tf.VarHandle
+
+No TOSA lowering defined.
+
+### tf.VariableShape
+
+No TOSA lowering defined.
+
+### tf.Where
+
+No TOSA lowering defined.
+
+### tf.While
+
+No TOSA lowering defined.
+
+### tf.Xdivy
+
+No TOSA lowering defined.
+
+### tf.XlaDynamicUpdateSlice
+
+No TOSA lowering defined.
+
+### tf.XlaSharding
+
+No TOSA lowering defined.
+
+### tf.ZerosLike
+
+Returns a tensor of zeros with the same shape and type as x.
+
+**TensorFlow Dialect**
+
+```
+%output = tf.ZerosLike(%x)
+```
+
+**TOSA Lowering**
+
+```
+%output = tosa.CONST() {value={0} * %x.num_elements} : () -> tensor<%x.type>
+```
+
+## TensorFlow Lite MLIR Dialect Legalization (legalize_tflite)
+
+### tfl.abs
+
+This operator is trivially lowered to tosa.ABS
+
+### tfl.add_n
+
+add_n operator.
+
+**TensorFlow Lite Dialect**
+
+```
+%sum = tfl.add_n(%inputs)
+```
+
+**TOSA Lowering**
+
+```
+%output = tosa.ADD(%inputs:0, %inputs:1) : (tensor<%inputs:0.type>, tensor<%inputs:1.type>) -> tensor<%output.type>
+for (int i = 2; i < %inputs.size; i++) {
+ %output = tosa.ADD(%inputs:i, %output) : (tensor<%inputs:i.type>, tensor<%output.type>) -> tensor<%output.type>
+}
+```
+
+### tfl.add
+
+Element-wise addition operation.
+
+**TensorFlow Lite Dialect**
+
+```
+%output = tfl.add(%lhs, %rhs)
+```
+
+**TOSA Lowering**
+
+If input/output tensors are all native typed,
+
+Legalization:
+
+```
+%bcast_lhs = apply_rank_broadcast(%lhs, %rhs)
+%bcast_rhs = apply_rank_broadcast(%rhs, %lhs)
+%result = tosa.ADD(%bcast_lhs, %bcast_rhs) : (tensor<%bcast_lhs.type>, tensor<%bcast_rhs.type>) -> tensor<%output.type>
+```
+
+If input/output tensors are all quantized typed,
+
+Prepare:
+
+```
+int32_t lhs_multiplier, rhs_multiplier, output_multiplier
+int32_t lhs_shift, rhs_shift, output_shift
+int32_t input_shift = 20
+double max_scale_2x = 2.0 * max(%lhs.scale, %rhs.scale)
+double lhs_scale = double(1 << input_shift) * %lhs.scale / max_scale_2x
+double rhs_scale = double(1 << input_shift) * %rhs.scale / max_scale_2x
+double output_scale = max_scale_2x / (%output.scale * double(1 << input_shift))
+
+compute_scale_32(lhs_scale, lhs_multiplier, lhs_shift)
+compute_scale_32(rhs_scale, rhs_multiplier, rhs_shift)
+compute_scale_32(output_scale, output_multiplier, output_shift)
+
+auto lhs_int32_type = tensor<%lhs.shape, tosa.int32>
+auto rhs_int32_type = tensor<%rhs.shape, tosa.int32>
+auto output_int32_type = tensor<%output.shape, tosa.int32>
+```
+
+Legalization:
+
+```
+%bcast_lhs = apply_rank_broadcast(%lhs, %rhs)
+%bcast_rhs = apply_rank_broadcast(%rhs, %lhs)
+%rescaled_lhs = tosa.RESCALE(%bcast_lhs) {multiplier=lhs_multiplier, shift=lhs_shift} : (tensor<%bcast_lhs.type>) -> lhs_int32_type
+%rescaled_rhs = tosa.RESCALE(%bcast_rhs) {multiplier=rhs_multiplier, shift=rhs_shift} : (tensor<%bcast_rhs.type>) -> rhs_int32_type
+%add = tosa.ADD(%rescaled_lhs, %rescaled_rhs) : (lhs_int32_type, rhs_int32_type) -> output_int32_type
+%result = tosa.RESCALE(%add) {multiplier=output_multiplier, shift=output_shift} : (output_int32_type) -> tensor<%output.type>
+```
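+
+compute_scale_32 is not spelled out in this document; a plausible sketch of the decomposition it performs (name and exact rounding are assumptions) is to express a positive real scale as a 32-bit multiplier and a right-shift:
+
+```
+#include <cmath>
+#include <cstdint>
+
+// Decompose scale (> 0) so that scale ~= multiplier * 2^-shift, with the
+// multiplier normalized into [2^30, 2^31).
+void compute_scale_32(double scale, int32_t& multiplier, int32_t& shift) {
+  int exp;
+  const double q = std::frexp(scale, &exp);   // scale = q * 2^exp, q in [0.5, 1)
+  int64_t m = std::llround(q * (1LL << 31));  // 31-bit normalized mantissa
+  if (m == (1LL << 31)) {                     // rounding pushed q up to 1.0
+    m >>= 1;
+    ++exp;
+  }
+  multiplier = static_cast<int32_t>(m);
+  shift = 31 - exp;
+}
+```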
+
+### tfl.arg_max
+
+ArgMax operator.
+
+**TensorFlow Lite Dialect**
+
+```
+%output = tfl.arg_max(%input, %dim)
+```
+
+**TOSA Lowering**
+
+```
+%result = tosa.ARGMAX(%input) {axis=positive_axis(%dim.as_constant(), %input.rank)} : (tensor<%input.type>) -> tensor<%output.type>
+```
+
+### tfl.arg_min
+
+No TOSA lowering defined.
+
+### tfl.average_pool_2d
+
+Average_pool_2d operator.
+
+**TensorFlow Lite Dialect**
+
+```
+%output = tfl.average_pool_2d(%input) {filter_height, filter_width, padding, stride_h, stride_w, fused_activation_function} : (tensor<%input.type>) -> tensor<%output.type>
+```
+
+**TOSA Lowering**
+
+Prepare:
+
+```
+tosa_padding =
+ get_padding_values_from_pad_type(padding, NHWC, 1,
+ %input.type, tensor<{filter_height, filter_width}, tosa.int32>,
+ {1, stride_h, stride_w, 1}, {1, 1, 1, 1})
+```
+
+If input/output tensors are all native typed,
+
+Legalization:
+
+```
+%avgpool2d = tosa.AVG_POOL2D(%input) {kernel={filter_height, filter_width}, stride={stride_h, stride_w}, padding=tosa_padding} : (tensor<%input.type>) -> tensor<%output.type>
+if(fused_activation != NONE) {
+ %result = convert_fused_activation(%avgpool2d, fused_activation)
+}
+else {
+ %result = %avgpool2d
+}
+```
+
+If input/output tensors are all quantized typed,
+
+Legalization:
+
+```
+%avgpool2d = tosa.AVG_POOL2D(%input) {kernel={filter_height, filter_width}, stride={stride_h, stride_w}, padding=tosa_padding, quantization_info={input_zp=%input.zp, output_zp=%output.zp}} : (tensor<%input.type>) -> tensor<%output.type>
+if(fused_activation != NONE) {
+ %result = convert_fused_activation(%avgpool2d, fused_activation)
+}
+else {
+ %result = %avgpool2d
+}
+```
+
+### tfl.basic_lstm
+
+No TOSA lowering defined.
+
+### tfl.batch_to_space_nd
+
+BatchToSpaceNd operator.
+
+**TensorFlow Lite Dialect**
+
+```
+%output = tfl.batch_to_space_nd(%input, %block_shape, %indices)
+```
+
+**TOSA Lowering**
+
+```
+%result = convert_batch_to_space_nd_op(%input, %block_shape, %indices)
+```
+
+### tfl.cast
+
+This operator is trivially lowered to tosa.CAST
+
+### tfl.ceil
+
+Ceil operator.
+
+**TensorFlow Lite Dialect**
+
+```
+%y = tfl.ceil(%x)
+```
+
+**TOSA Lowering**
+
+If input/output tensors are all native typed,
+
+```
+%result = tosa.CEIL(%x) : (tensor<%x.type>) -> tensor<%y.type>
+```
+
+### tfl.concatenation
+
+Concatenation operator.
+
+**TensorFlow Lite Dialect**
+
+```
+%output = tfl.concatenation(%values) {axis}
+```
+
+**TOSA Lowering**
+
+```
+%result = lower_concatv2_op(%values, axis)
+```
+
+### tfl.pseudo_const
+
+This operator is trivially lowered to tosa.CONST
+
+### tfl.conv_2d
+
+Convolution operator.
+
+**TensorFlow Lite Dialect**
+
+```
+%output = tfl.conv_2d(%input, %filter, %bias) {dilation_h_factor, dilation_w_factor, fused_activation_function, padding, stride_h, stride_w}
+```
+
+**TOSA Lowering**
+
+If input/output tensors are all native typed,
+
+Prepare:
+
+```
+tosa_padding =
+ get_padding_values_from_pad_type(padding, NHWC, 1,
+ %input.type, %filter.type,
+ {1, stride_h, stride_w, 1}, {1, dilation_h_factor, dilation_w_factor, 1})
+```
+
+Legalization:
+
+```
+%conv2d = tosa.CONV2D(%input, %filter, %bias) {padding=tosa_padding, stride={stride_h, stride_w}, dilation={dilation_h_factor, dilation_w_factor}} : (tensor<%input.type>, tensor<%filter.type>, tensor<%bias.type>) -> tensor<%output.type>
+if(fused_activation != NONE) {
+ %result = convert_fused_activation(%conv2d, fused_activation_function)
+}
+else {
+ %result = %conv2d
+}
+```
+
+If input/output tensors are all quantized typed,
+
+Prepare:
+
+```
+int32_t output_rescale_multiplier
+int32_t output_rescale_shift
+double output_rescale_scale = %input.scale * %filter.scale / %output.scale
+
+compute_scale_32(output_rescale_scale, output_rescale_multiplier, output_rescale_shift)
+
+auto acc_type = tensor<%output.shape, tosa.int32> // TODO: support 16x8->48
+
+tosa_padding =
+ get_padding_values_from_pad_type(padding, NHWC, 1,
+ %input.type, %filter.type,
+ {1, stride_h, stride_w, 1}, {1, dilation_h_factor, dilation_w_factor, 1})
+```
+
+Legalization:
+
+```
+%conv2d = tosa.CONV2D(%input, %filter, %bias) {padding=tosa_padding, stride={stride_h, stride_w}, dilation={dilation_h_factor, dilation_w_factor}, quantization_info={input_zp=%input.zp, weight_zp=%filter.zp}} : (tensor<%input.type>, tensor<%filter.type>, tensor<%bias.type>) -> acc_type
+%rescale = tosa.RESCALE(%conv2d) {multiplier=output_rescale_multiplier, shift=output_rescale_shift} : (acc_type) -> tensor<%output.type>
+if(fused_activation != NONE) {
+ %result = convert_fused_activation(%rescale, fused_activation_function)
+}
+else {
+ %result = %rescale
+}
+```
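+
+The trailing rescale folds the int32 accumulator back into the output's quantized domain: one accumulator unit represents %input.scale * %filter.scale, so dividing by %output.scale yields the effective scale that tosa.RESCALE applies.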
+
+### tfl.convolution_2d_transpose_bias
+
+No TOSA lowering defined.
+
+### tfl.cos
+
+No TOSA lowering defined.
+
+### tfl.densify
+
+No TOSA lowering defined.
+
+### tfl.depth_to_space
+
+### tfl.depthwise_conv_2d
+
+Depthwise-separable convolution operator.
+
+**TensorFlow Lite Dialect**
+
+```
+%output = tfl.depthwise_conv_2d(%input, %filter, %bias) {dilation_h_factor, dilation_w_factor, fused_activation_function, padding, stride_h, stride_w, depth_multiplier}
+```
+
+**TOSA Lowering**
+
+If input/output tensors are all native typed,
+
+Prepare:
+
+```
+tosa_padding =
+ get_padding_values_from_pad_type(padding, NHWC, 1,
+ %input.type, %filter.type,
+ {1, stride_h, stride_w, 1}, {1, dilation_h_factor, dilation_w_factor, 1})
+```
+
+Legalization:
+
+```
+%depthwise_conv2d = tosa.DEPTHWISE_CONV2D(%input, %filter, %bias) {padding=tosa_padding, stride={stride_h, stride_w}, dilation={dilation_h_factor, dilation_w_factor}} : (tensor<%input.type>, tensor<%filter.type>, tensor<%bias.type>) -> tensor<%output.type>
+if(fused_activation != NONE) {
+ %result = convert_fused_activation(%depthwise_conv2d, fused_activation_function)
+}
+else {
+ %result = %depthwise_conv2d
+}
+```
+
+If input/output tensors are all quantized typed,
+
+Prepare:
+
+```
+int32_t output_rescale_multiplier, output_rescale_shift
+double output_rescale_scale = %input.scale * %filter.scale / %output.scale
+
+compute_scale_32(output_rescale_scale, output_rescale_multiplier, output_rescale_shift)
+
+auto acc_type = tensor<%output.shape, tosa.int32> // TODO: support 16x8->48
+
+tosa_padding =
+ get_padding_values_from_pad_type(padding, NHWC, 1,
+ %input.type, %filter.type,
+ {1, stride_h, stride_w, 1}, {1, dilation_h_factor, dilation_w_factor, 1})
+```
+
+Legalization:
+
+```
+%depthwise_conv2d = tosa.DEPTHWISE_CONV2D(%input, %filter, %bias) {padding=tosa_padding, stride={stride_h, stride_w}, dilation={dilation_h_factor, dilation_w_factor}, quantization_info={input_zp=%input.zp, weight_zp=%filter.zp}} : (tensor<%input.type>, tensor<%filter.type>, tensor<%bias.type>) -> acc_type
+%rescale = tosa.RESCALE(%depthwise_conv2d) {multiplier=output_rescale_multiplier, shift=output_rescale_shift} : (acc_type) -> tensor<%output.type>
+if(fused_activation != NONE) {
+ %result = convert_fused_activation(%rescale, fused_activation_function)
+}
+else {
+ %result = %rescale
+}
+```
+
+### tfl.dequantize
+
+Dequantize operator.
+
+**TensorFlow Lite Dialect**
+
+```
+%output = tfl.dequantize(%input)
+```
+
+**TOSA Lowering**
+
+```
+%result = convert_dequantized_op(%output.type, %input, %input.dtype.scale, %input.dtype.zp)
+```
+
+### tfl.div
+
+Division operator.
+
+**TensorFlow Lite Dialect**
+
+```
+%output = tfl.div(%lhs, %rhs)
+```
+
+**TOSA Lowering**
+
+If input/output tensors are all native typed,
+
+```
+%rcp = tosa.RECIPROCAL(%rhs) : (tensor<%rhs.type>) -> tensor<%rhs.type>
+%result = tosa.MUL(%lhs, %rcp) : (tensor<%lhs.type>, tensor<%rcp.type>) -> tensor<%output.type>
+```
+
+### tfl.elu
+
+Exponential Linear Unit operator.
+
+**TensorFlow Lite Dialect**
+
+```
+%y = tfl.elu(%x)
+```
+
+**TOSA Lowering**
+
+If input/output tensors are all native typed,
+
+```
+%result = lower_elu_op(%x)
+```
+
+### tfl.embedding_lookup
+
+Embedding lookup operator.
+
+**TensorFlow Lite Dialect**
+
+```
+%output = tfl.embedding_lookup(%lookup, %value)
+```
+
+### tfl.equal
+
+This operator is trivially lowered to tosa.EQUAL
+
+### tfl.exp
+
+Natural exponentiation operator.
+
+**TensorFlow Lite Dialect**
+
+```
+%y = tfl.exp(%x)
+```
+
+**TOSA Lowering**
+
+If input/output tensors are all native typed,
+
+```
+%result = tosa.EXP(%x) : (tensor<%x.type>) -> tensor<%y.type>
+```
+
+### tfl.expand_dims
+
+Inserts a dimension of 1 into a tensor’s shape.
+
+**TensorFlow Lite Dialect**
+
+```
+%output = tfl.expand_dims(%input, %dim)
+```
+
+**TOSA Lowering**
+
+```
+%result = lower_expand_dims(%input, %dim.as_constant())
+```
+
+### tfl.external_const
+
+No TOSA lowering defined.
+
+### tfl.fake_quant
+
+FakeQuant operator.
+
+**TensorFlow Lite Dialect**
+
+```
+%output = tfl.fake_quant(%input) {min, max, num_bits, narrow_range}
+```
+
+**TOSA Lowering**
+
+```
+%result = convert_fake_quant_op(%input, min, max, num_bits, narrow_range)
+```
+
+### tfl.fill
+
+Fill the tensor with given value.
+
+**TensorFlow Lite Dialect**
+
+```
+%res = tfl.fill(%dims, %value)
+```
+
+**TOSA Lowering**
+
+Prepare:
+
+```
+total_size = 1
+dim_vec = %dims.as_constant()
+for(int i = 0; i < dim_vec.size(); i++) {
+ total_size *= dim_vec[i]
+}
+filled_val = %value.as_constant()[0]
+output_type = tensor<dim_vec, filled_val.dtype>
+```
+
+Legalization:
+
+```
+%result = tosa.CONST() {value=[filled_val] * total_size} : () -> output_type
+```
+
+### tfl.floor_div
+
+Floor div operator.
+
+**TensorFlow Lite Dialect**
+
+```
+%output = tfl.floor_div(%lhs, %rhs)
+```
+
+**TOSA Lowering**
+
+If input/output tensors are all native typed,
+
+```
+%recip = tosa.RECIPROCAL(%rhs) : (tensor<%rhs.shape, tosa.float>) -> tensor<%rhs.shape, tosa.float>
+%mul = tosa.MUL(%lhs, %recip) : (tensor<%lhs.shape, tosa.float>, tensor<%rhs.shape, tosa.float>) -> tensor<%output.shape, tosa.float>
+%result = tosa.FLOOR(%mul) : (tensor<%output.type>) -> tensor<%output.type>
+```
+
+### tfl.floor_mod
+
+Division remainder.
+
+**TensorFlow Lite Dialect**
+
+```
+%output = tfl.floor_mod(%lhs, %rhs)
+```
+
+**TOSA Lowering**
+
+If input/output tensors are all native typed,
+
+```
+%recip = tosa.RECIPROCAL(%rhs) : (tensor<%rhs.shape, tosa.float>) -> tensor<%rhs.shape, tosa.float>
+%mul = tosa.MUL(%lhs, %recip) : (tensor<%lhs.shape, tosa.float>, tensor<%rhs.shape, tosa.float>) -> tensor<%output.shape, tosa.float>
+%floor = tosa.FLOOR(%mul) : (tensor<%output.type>) -> tensor<%output.type>
+%result = tosa.SUB(%mul, %floor) : (tensor<%output.type>, tensor<%output.type>) -> tensor<%output.type>
+```
+
+### tfl.floor
+
+This operator is trivially lowered to tosa.FLOOR
+
+### tfl.fully_connected
+
+Fully connected op.
+
+**TensorFlow Lite Dialect**
+
+```
+%output = tfl.fully_connected(%input, %filter, %bias) {fused_activation_function}
+```
+
+**TOSA Lowering**
+
+If input/output tensors are all native typed,
+
+Prepare:
+
+```
+// input[N, IC] x filter[OC, IC] + bias[OC] -> output[N, OC]
+auto bias_shape = {%filter.shape[0]}
+auto bias_type = tensor<bias_shape, tosa.float>
+auto input_reshape_shape = {%input.num_elements / %filter.shape[1], %filter.shape[1]}
+auto input_type = tensor<input_reshape_shape, %input.dtype>
+```
+
+Legalization:
+
+```
+if(!(%bias)) {
+    %bias_val = tosa.CONST() {value=[0] * %filter.shape[0]} : () -> bias_type
+}
+else {
+ %bias_val = %bias
+}
+if(%input.rank != 2) {
+ %input_val = tosa.RESHAPE(%input) {shape=input_reshape_shape} : (tensor<%input.type>) -> input_type
+}
+else {
+ %input_val = %input
+}
+%fc = tosa.FULLY_CONNECTED(%input_val, %filter, %bias_val) : (tensor<%input_val.type>, tensor<%filter.type>, tensor<%bias_val.type>) -> tensor<%output.type>
+if(fused_activation != NONE) {
+ %result = convert_fused_activation(%fc, fused_activation_function)
+}
+else {
+ %result = %fc
+}
+```
+
+If input/output tensors are all quantized typed,
+
+Prepare:
+
+```
+auto acc_dtype = tosa.int32 // TODO: support 16x8->48
+auto bias_shape = {%filter.shape[0]}
+auto bias_type = tensor<bias_shape, acc_dtype>
+auto input_reshape_shape = {%input.num_elements / %filter.shape[1], %filter.shape[1]}
+auto input_type = tensor<input_reshape_shape, %input.dtype>
+auto acc_type = tensor<%output.shape, acc_dtype>
+int32_t output_rescale_multiplier, output_rescale_shift
+double output_rescale_scale = %input.scale * %filter.scale / %output.scale
+
+compute_scale_32(output_rescale_scale, output_rescale_multiplier, output_rescale_shift)
+```
+
+Legalization:
+
+```
+if(!(%bias)) {
+    %bias_val = tosa.CONST() {value=[0] * %filter.shape[0]} : () -> bias_type
+}
+else {
+ %bias_val = %bias
+}
+if(%input.rank != 2) {
+ %input_val = tosa.RESHAPE(%input) {shape=input_reshape_shape} : (tensor<%input.type>) -> input_type
+}
+else {
+ %input_val = %input
+}
+%fc = tosa.FULLY_CONNECTED(%input_val, %filter, %bias_val) : (input_type, tensor<%filter.type>, bias_type) -> acc_type
+%rescale = tosa.RESCALE(%fc) {multiplier=output_rescale_multiplier, shift=output_rescale_shift} : (acc_type) -> tensor<%output.type>
+if(fused_activation != NONE) {
+ %result = convert_fused_activation(%rescale, fused_activation_function)
+}
+else {
+ %result = %rescale
+}
+```
+
+### tfl.gather_nd
+
+No TOSA lowering defined.
+
+### tfl.gather
+
+TODO: TOSA lowering
+
+### tfl.greater_equal
+
+This operator is trivially lowered to tosa.GREATER_EQUAL
+
+### tfl.greater
+
+This operator is trivially lowered to tosa.GREATER
+
+### tfl.hard_swish
+
+Hardswish activation function.
+
+**TensorFlow Lite Dialect**
+
+```
+%out = tfl.hard_swish(%input)
+```
+
+**TOSA Lowering**
+
+If input/output tensors are all native typed,
+
+```
+%const_3 = tosa.CONST() {value={3.0}} : () -> tensor<{1}, float>
+%const_rcp6 = tosa.CONST() {value={1.0 / 6.0}} : () -> tensor<{1}, float>
+%op1_add_in_3 = tosa.ADD(%input, %const_3) : (tensor<%input.type>, tensor<{1}, float>) -> tensor<%out.type>
+%op2_relun_op1 = tosa.RELUN(%op1_add_in_3) {max_val=6.0} : (tensor<%out.type>) -> tensor<%out.type>
+%op3_mul_in_op2 = tosa.MUL(%input, %op2_relun_op1) : (tensor<%input.type>, tensor<%out.type>) -> tensor<%out.type>
+%op4_mul_op3_rcp6 = tosa.MUL(%op3_mul_in_op2, %const_rcp6) : (tensor<%out.type>, tensor<{1}, float>) -> tensor<%out.type>
+```
+
+If input/output tensors are all quantized typed,
+
+Prepare:
+
+```
+const double input_sample_grain = 1.0 / 64.0;
+auto hardswish_func = [input_sample_grain](int32_t x) -> int32_t {
+ double v = (double)x * input_sample_grain
+ double w = v + 3.0
+ w = w < 0.0 ? 0.0 : w > 6.0 ? 6.0 : w
+ v = v * w / 6.0
+ return (int32_t)(std::round(32768.0 * v))
+}
+```
+
+Legalization:
+
+```
+%table_const = get_table_const_tensor(hardswish_func)
+%op1_rescale_in = tosa.RESCALE(%input) {multiplier=, shift=} : (tensor<%input.type>) -> tensor<%input.shape, tosa.int16>
+```
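+
+get_table_const_tensor presumably samples the reference function into the constant consumed by a 16-bit tosa.TABLE. A hedged sketch, assuming a 513-entry equally spaced sample layout:
+
+```
+#include <algorithm>
+#include <cstdint>
+#include <functional>
+#include <vector>
+
+// Sample f at 513 points and saturate each result to int16. With
+// hardswish_func above, inputs -256..256 cover v in [-4.0, 4.0].
+std::vector<int16_t> build_table(const std::function<int32_t(int32_t)>& f) {
+  std::vector<int16_t> table;
+  table.reserve(513);
+  for (int32_t i = -256; i <= 256; ++i) {
+    int32_t v = std::min(std::max(f(i), -32768), 32767);
+    table.push_back(static_cast<int16_t>(v));
+  }
+  return table;
+}
+```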
+
+### tfl.l2_normalization
+
+No TOSA lowering defined.
+
+### tfl.lstm
+
+No TOSA lowering defined.
+
+### tfl.leaky_relu
+
+TODO: TOSA lowering
+
+### tfl.less_equal
+
+TODO: TOSA lowering
+
+### tfl.less
+
+TODO: TOSA lowering
+
+### tfl.local_response_normalization
+
+No TOSA lowering defined.
+
+### tfl.log
+
+TODO: TOSA lowering
+
+### tfl.log_softmax
+
+TODO: TOSA lowering
+
+### tfl.logical_and
+
+This operator is trivially lowered to tosa.LOGICAL_AND
+
+### tfl.logical_not
+
+This operator is trivially lowered to tosa.LOGICAL_NOT
+
+### tfl.logical_or
+
+This operator is trivially lowered to tosa.LOGICAL_OR
+
+### tfl.logistic
+
+TODO: TOSA lowering
+
+### tfl.matrix_diag
+
+No TOSA lowering defined.
+
+### tfl.matrix_set_diag
+
+No TOSA lowering defined.
+
+### tfl.max_pool_2d
+
+TODO: TOSA lowering
+
+### tfl.max_pooling_with_argmax_2d
+
+No TOSA lowering defined.
+
+### tfl.max_unpooling_2d
+
+No TOSA lowering defined.
+
+### tfl.maximum
+
+This operator is trivially lowered to tosa.MAXIMUM
+
+### tfl.mean
+
+TODO: TOSA lowering
+
+### tfl.minimum
+
+This operator is trivially lowered to tosa.MINIMUM
+
+### tfl.mirror_pad
+
+No TOSA lowering defined.
+
+### tfl.mul
+
+TODO: TOSA lowering
+
+### tfl.neg
+
+This operator is trivially lowered to tosa.NEGATE
+
+### tfl.non_max_suppression_v4
+
+No TOSA lowering defined.
+
+### tfl.non_max_suppression_v5
+
+No TOSA lowering defined.
+
+### tfl.not_equal
+
+TODO: TOSA lowering
+
+### tfl.NumericVerify
+
+No TOSA lowering defined.
+
+### tfl.one_hot
+
+No TOSA lowering defined.
+
+### tfl.prelu
+
+TODO: TOSA lowering
+
+### tfl.pack
+
+TODO: TOSA lowering
+
+### tfl.pad
+
+This operator is trivially lowered to tosa.PAD
+
+### tfl.padv2
+
+No TOSA lowering defined.
+
+### tfl.pow
+
+TODO: TOSA lowering
+
+### tfl.pseudo_qconst
+
+This operator is trivially lowered to tosa.CONST
+
+### tfl.quantize
+
+TODO: TOSA lowering
+
+### tfl.range
+
+No TOSA lowering defined.
+
+### tfl.rank
+
+TODO: TOSA lowering
+
+### tfl.reduce_any
+
+TODO: TOSA lowering
+
+### tfl.reduce_max
+
+TODO: TOSA lowering
+
+### tfl.reduce_min
+
+TODO: TOSA lowering
+
+### tfl.reduce_prod
+
+TODO: TOSA lowering
+
+### tfl.relu_n1_to_1
+
+No TOSA lowering defined.
+
+### tfl.relu6
+
+TODO: TOSA lowering
+
+### tfl.relu
+
+TODO: TOSA lowering
+
+### tfl.reshape
+
+This operator is trivially lowered to tosa.RESHAPE
+
+### tfl.resize_bilinear
+
+TODO: TOSA lowering
+
+### tfl.resize_nearest_neighbor
+
+TODO: TOSA lowering
+
+### tfl.reverse_sequence
+
+No TOSA lowering defined.
+
+### tfl.reverse_v2
+
+TODO: TOSA lowering
+
+### tfl.round
+
+TODO: TOSA lowering
+
+### tfl.rsqrt
+
+TODO: TOSA lowering
+
+### tfl.svdf
+
+No TOSA lowering defined.
+
+### tfl.segment_sum
+
+No TOSA lowering defined.
+
+### tfl.select
+
+TODO: TOSA lowering
+
+### tfl.select_v2
+
+TODO: TOSA lowering
+
+### tfl.shape
+
+TODO: TOSA lowering
+
+### tfl.sin
+
+No TOSA lowering defined.
+
+### tfl.slice
+
+TODO: TOSA lowering
+
+### tfl.softmax
+
+TODO: TOSA lowering
+
+### tfl.space_to_batch_nd
+
+TODO: TOSA lowering
+
+### tfl.space_to_depth
+
+TODO: TOSA lowering
+
+### tfl.pseudo_sparse_const
+
+No TOSA lowering defined.
+
+### tfl.pseudo_sparse_qconst
+
+No TOSA lowering defined.
+
+### tfl.sparse_to_dense
+
+No TOSA lowering defined.
+
+### tfl.split
+
+TODO: TOSA lowering
+
+### tfl.split_v
+
+TODO: TOSA lowering
+
+### tfl.sqrt
+
+TODO: TOSA lowering
+
+### tfl.square
+
+TODO: TOSA lowering
+
+### tfl.squared_difference
+
+TODO: TOSA lowering
+
+### tfl.squeeze
+
+TODO: TOSA lowering
+
+### tfl.strided_slice
+
+### tfl.sub
+
+This operator is trivially lowered to tosa.SUB
+
+### tfl.sum
+
+TODO: TOSA lowering
+
+### tfl.tanh
+
+TODO: TOSA lowering
+
+### tfl.tile
+
+TODO: TOSA lowering
+
+### tfl.topk_v2
+
+No TOSA lowering defined.
+
+### tfl.transpose_conv
+
+TODO: TOSA lowering
+
+### tfl.transpose
+
+This operator is trivially lowered to tosa.TRANSPOSE
+
+### tfl.unidirectional_sequence_lstm
+
+No TOSA lowering defined.
+
+### tfl.unidirectional_sequence_rnn
+
+No TOSA lowering defined.
+
+### tfl.unique
+
+No TOSA lowering defined.
+
+### tfl.unpack
+
+TODO: TOSA lowering
+
+### tfl.where
+
+No TOSA lowering defined.
+
+### tfl.while
+
+TODO: TOSA lowering
+
+### tfl.yield
+
+This operator is trivially lowered to tosa.YIELD
+
+### tfl.zeros_like
+
+TODO: TOSA lowering
+
+## Common Passes
+
+### make_broadcastable
+
+#### Applied to OP
+
+For each of the following ops:
+
+```
+tosa.ADD, tosa.SUB, tosa.MUL, tosa.EQUAL, tosa.GREATER, tosa.GREATER_EQUAL
+```
+
+From:
+
+```
+%output = tosa.OP(%input1, %input2) : (tensor<%input1.type>, tensor<%input2.type>) -> tensor<%output.type>
+```
+
+To:
+
+```
+%bcast_input1 = apply_rank_broadcast(%input1, %input2)
+%bcast_input2 = apply_rank_broadcast(%input2, %input1)
+%result = tosa.OP(%bcast_input1, %bcast_input2) : (tensor<%bcast_input1.type>, tensor<%bcast_input2.type>) -> tensor<%output.type>
+```
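+
+apply_rank_broadcast, used throughout this document, conceptually reshapes the lower-ranked operand by prepending size-1 dimensions until the ranks match; TOSA's elementwise broadcasting then handles the size-1 dimensions. A sketch of the shape computation (helper name assumed):
+
+```
+#include <cstdint>
+#include <vector>
+
+// Prepend size-1 dims to 'shape' until it matches the rank of 'other',
+// i.e. the new_shape of the tosa.RESHAPE this pass would insert.
+std::vector<int64_t> rank_broadcast_shape(std::vector<int64_t> shape,
+                                          const std::vector<int64_t>& other) {
+  while (shape.size() < other.size())
+    shape.insert(shape.begin(), 1);
+  return shape;
+}
+```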
+
+### constant_folding
+
+#### tosa.CONST + tosa.RESHAPE
+
+From:
+
+```
+%cst = tosa.CONST()
+%reshape = tosa.RESHAPE(%cst)
+```
+
+To:
+
+```
+%result = tosa.CONST()
+```
+
+#### tosa.CONST + tosa.TRANSPOSE
+
+From:
+
+```
+%cst = tosa.CONST()
+%transpose = tosa.TRANSPOSE(%cst)
+```
+
+To:
+
+```
+%result = tosa.CONST()
+```
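+
+Both folds evaluate the shape operation on the constant's payload at compile time. A sketch of the transpose case for a row-major dense constant (function name assumed):
+
+```
+#include <cstdint>
+#include <vector>
+
+// Permute a row-major constant's elements according to 'perm', producing
+// the payload of the folded tosa.CONST.
+std::vector<float> fold_transpose(const std::vector<float>& data,
+                                  const std::vector<int64_t>& shape,
+                                  const std::vector<int64_t>& perm) {
+  const int64_t rank = shape.size();
+  std::vector<int64_t> out_shape(rank), in_strides(rank, 1);
+  for (int64_t i = rank - 2; i >= 0; --i)
+    in_strides[i] = in_strides[i + 1] * shape[i + 1];
+  for (int64_t i = 0; i < rank; ++i) out_shape[i] = shape[perm[i]];
+  std::vector<float> out(data.size());
+  std::vector<int64_t> idx(rank, 0);  // output coordinate, row-major walk
+  for (size_t o = 0; o < out.size(); ++o) {
+    int64_t in_off = 0;
+    for (int64_t i = 0; i < rank; ++i) in_off += idx[i] * in_strides[perm[i]];
+    out[o] = data[in_off];
+    for (int64_t i = rank - 1; i >= 0; --i) {  // bump mixed-radix index
+      if (++idx[i] < out_shape[i]) break;
+      idx[i] = 0;
+    }
+  }
+  return out;
+}
+```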
+
+### convert_tflite_qu8_to_qi8
+
+From:
+
+```
+%cst = tosa.CONST() : () -> tensor<%cst.shape, quant<u8>, ...>
+```
+
+To:
+
+```
+%result = tosa.CONST() : () -> tensor<%cst.shape, quant<i8>, ...>
+```
diff --git a/tensorflow/compiler/mlir/tosa/tests/BUILD b/tensorflow/compiler/mlir/tosa/tests/BUILD
new file mode 100644
index 0000000..65b3022
--- /dev/null
+++ b/tensorflow/compiler/mlir/tosa/tests/BUILD
@@ -0,0 +1,21 @@
+load("//tensorflow:tensorflow.bzl", "filegroup")
+load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
+
+package(licenses = ["notice"])
+
+glob_lit_tests(
+ data = [":test_utilities"],
+ driver = "@llvm-project//mlir:run_lit.sh",
+ test_file_exts = ["mlir"],
+)
+
+# Bundle together all of the test utilities that are used by tests.
+filegroup(
+ name = "test_utilities",
+ testonly = True,
+ data = [
+ "//tensorflow/compiler/mlir:tf-opt",
+ "@llvm-project//llvm:FileCheck",
+ "@llvm-project//llvm:not",
+ ],
+)
diff --git a/tensorflow/compiler/mlir/tosa/tests/convert-tfl-uint8.mlir b/tensorflow/compiler/mlir/tosa/tests/convert-tfl-uint8.mlir
new file mode 100644
index 0000000..1d9101d
--- /dev/null
+++ b/tensorflow/compiler/mlir/tosa/tests/convert-tfl-uint8.mlir
@@ -0,0 +1,15 @@
+// RUN: tf-opt --tosa-convert-tfl-uint8 --verify-each %s | FileCheck %s
+
+// Operations for testing --tosa-convert-tfl-uint8
+
+// ----
+
+// CHECK-LABEL: test_add_u8
+// CHECK: tosa.rescale
+// CHECK: tosa.rescale
+// CHECK: tfl.add
+// CHECK: tosa.rescale
+func @test_add_u8(%arg0: tensor<14x19x!quant.uniform<u8:f32, 0.015603500418365002:128>>, %arg1: tensor<14x19x!quant.uniform<u8:f32, 0.015612985007464886:127>>) -> tensor<14x19x!quant.uniform<u8:f32, 0.028094837442040443:127>> {
+ %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<14x19x!quant.uniform<u8:f32, 0.015603500418365002:128>>, tensor<14x19x!quant.uniform<u8:f32, 0.015612985007464886:127>>) -> tensor<14x19x!quant.uniform<u8:f32, 0.028094837442040443:127>>
+ return %0 : tensor<14x19x!quant.uniform<u8:f32, 0.028094837442040443:127>>
+}
diff --git a/tensorflow/compiler/mlir/tosa/tests/fuse-bias-tf.mlir b/tensorflow/compiler/mlir/tosa/tests/fuse-bias-tf.mlir
new file mode 100644
index 0000000..781e61b
--- /dev/null
+++ b/tensorflow/compiler/mlir/tosa/tests/fuse-bias-tf.mlir
@@ -0,0 +1,16 @@
+// RUN: tf-opt --tosa-fuse-bias-tf --verify-each %s | FileCheck %s
+
+// Operations for testing --tosa-fuse-bias-tf
+
+// ----
+
+// CHECK-LABEL: test_conv2d_bias
+// CHECK: tosa.const
+// CHECK: tosa.transpose
+// CHECK: tosa.conv2d
+func @test_conv2d_bias(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<3x3x4x8xf32>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
+ %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<1x4x4x4xf32>, tensor<3x3x4x8xf32>) -> tensor<1x4x4x8xf32>
+ %1 = "tf.BiasAdd"(%0, %arg2) {data_format = "NHWC"} : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
+
+ return %1 : tensor<1x4x4x8xf32>
+}
diff --git a/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir
new file mode 100644
index 0000000..2c3939d
--- /dev/null
+++ b/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir
@@ -0,0 +1,798 @@
+// RUN: tf-opt --tf-to-tosa-pipeline --verify-each %s | FileCheck %s
+
+// Operations for testing tf-to-tosa-pipeline
+// TODO: These tests are fairly minimal. Expand the checks to be more robust.
+
+// -----
+
+// CHECK-LABEL: test_conv2d
+// CHECK: tosa.const
+// CHECK: tosa.const
+// CHECK: tosa.transpose
+// CHECK: tosa.conv2d
+func @test_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<1x1x8x16xf32>) -> tensor<1x32x32x16xf32> {
+ %3 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<1x32x32x8xf32>, tensor<1x1x8x16xf32>) -> tensor<1x32x32x16xf32>
+ return %3 : tensor<1x32x32x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_depthwise_conv2d
+// CHECK: tosa.const
+// CHECK: tosa.depthwise_conv2d
+func @test_depthwise_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<1x1x8x2xf32>) -> tensor<1x32x32x16xf32> {
+ %5 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x32x32x8xf32>, tensor<1x1x8x2xf32>) -> tensor<1x32x32x16xf32>
+ %6 = "tf.Identity"(%5) : (tensor<1x32x32x16xf32>) -> tensor<1x32x32x16xf32>
+ return %6 : tensor<1x32x32x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_transpose_conv2d
+// CHECK-DAG: "tosa.const"() {value = dense<[2, 0, 1, 3]> : tensor<4xi32>}
+// CHECK-DAG: "tosa.const"() {value = dense<0.000000e+00> : tensor<16xf32>}
+// CHECK-DAG: tosa.transpose
+// CHECK: tosa.transpose_conv2d
+func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<1x1x16x8xf32>) -> tensor<1x32x32x16xf32> {
+ %3 = "tf.Const"() {value = dense<[1, 32, 32, 16]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %4 = "tf.Conv2DBackpropInput"(%3, %arg1, %arg0) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<4xi32>, tensor<1x1x16x8xf32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x16xf32>
+ return %4 : tensor<1x32x32x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_add
+// CHECK: tosa.add
+func @test_add(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %2 = "tf.Add"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %2 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_sub
+// CHECK: tosa.sub
+func @test_sub(%arg0: tensor<1x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %2 = "tf.Sub"(%arg0, %arg1) : (tensor<1x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %2 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_mul
+// CHECK: tosa.mul
+func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
+ %2 = "tf.Mul"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
+ return %2 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_exp
+// CHECK: tosa.exp
+func @test_exp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %2 = "tf.Exp"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %2 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_rcp
+// CHECK: tosa.reciprocal
+func @test_rcp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %2 = "tf.Reciprocal"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %2 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_relu
+// CHECK: tosa.reluN
+func @test_relu(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %2 = "tf.Relu"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %2 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_relu6
+// CHECK: tosa.reluN
+func @test_relu6(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %2 = "tf.Relu6"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %2 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_leaky_relu
+func @test_leaky_relu(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %2 = "tf.LeakyRelu"(%arg0) {alpha = 0.707330704 : f32} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %2 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_concat
+// CHECK: tosa.concat
+func @test_concat(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<26x21x3xf32> {
+ %2 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+ %3 = "tf.ConcatV2"(%arg0, %arg1, %2) : (tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<i32>) -> tensor<26x21x3xf32>
+ return %3 : tensor<26x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_bitwise_and
+// CHECK: tosa.bitwise_and
+func @test_bitwise_and(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x21x1xi32>) -> tensor<13x21x3xi32> {
+ %2 = "tf.BitwiseAnd"(%arg0, %arg1) : (tensor<13x21x3xi32>, tensor<13x21x1xi32>) -> tensor<13x21x3xi32>
+ return %2 : tensor<13x21x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_bitwise_or
+// CHECK: tosa.bitwise_or
+func @test_bitwise_or(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> {
+ %2 = "tf.BitwiseOr"(%arg0, %arg1) : (tensor<13x21x3xi32>, tensor<13x1x3xi32>) -> tensor<13x21x3xi32>
+ return %2 : tensor<13x21x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_bitwise_not
+// CHECK: tosa.bitwise_not
+func @test_bitwise_not(%arg0: tensor<13x21x1xi32>) -> tensor<13x21x1xi32> {
+ %2 = "tf.Invert"(%arg0) : (tensor<13x21x1xi32>) -> tensor<13x21x1xi32>
+ return %2 : tensor<13x21x1xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_bitwise_xor
+// CHECK: tosa.bitwise_xor
+func @test_bitwise_xor(%arg0: tensor<13x21x1xi32>, %arg1: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
+ %2 = "tf.BitwiseXor"(%arg0, %arg1) : (tensor<13x21x1xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi32>
+ return %2 : tensor<13x21x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_logical_and
+// CHECK: tosa.logical_and
+func @test_logical_and(%arg0: tensor<13x21x3xi1>, %arg1: tensor<13x21x1xi1>) -> tensor<13x21x3xi1> {
+ %2 = "tf.LogicalAnd"(%arg0, %arg1) : (tensor<13x21x3xi1>, tensor<13x21x1xi1>) -> tensor<13x21x3xi1>
+ return %2 : tensor<13x21x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_logical_or
+// CHECK: tosa.logical_or
+func @test_logical_or(%arg0: tensor<13x1x3xi1>, %arg1: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> {
+ %2 = "tf.LogicalOr"(%arg0, %arg1) : (tensor<13x1x3xi1>, tensor<13x21x3xi1>) -> tensor<13x21x3xi1>
+ return %2 : tensor<13x21x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_logical_not
+// CHECK: tosa.logical_not
+func @test_logical_not(%arg0: tensor<1x21x3xi1>) -> tensor<1x21x3xi1> {
+ %2 = "tf.LogicalNot"(%arg0) : (tensor<1x21x3xi1>) -> tensor<1x21x3xi1>
+ return %2 : tensor<1x21x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_reduce_any
+// CHECK: tosa.reduce_any
+// CHECK: tosa.reshape
+func @test_reduce_any(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> {
+ %2 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+ %3 = "tf.Any"(%arg0, %2) {keep_dims = false} : (tensor<13x21x3xi1>, tensor<1xi32>) -> tensor<21x3xi1>
+ return %3 : tensor<21x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_reduce_all
+// CHECK: tosa.reduce_all
+// CHECK: tosa.reshape
+func @test_reduce_all(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> {
+ %2 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+ %3 = "tf.All"(%arg0, %2) {keep_dims = false} : (tensor<13x21x3xi1>, tensor<1xi32>) -> tensor<21x3xi1>
+ return %3 : tensor<21x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_reduce_min
+// CHECK: tosa.reduce_min
+// CHECK: tosa.reshape
+func @test_reduce_min(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> {
+ %2 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+ %3 = "tf.Min"(%arg0, %2) {keep_dims = false} : (tensor<13x21x3xf32>, tensor<1xi32>) -> tensor<21x3xf32>
+ return %3 : tensor<21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reduce_max
+// CHECK: tosa.reduce_max
+// CHECK: tosa.reshape
+func @test_reduce_max(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> {
+ %2 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+ %3 = "tf.Max"(%arg0, %2) {keep_dims = false} : (tensor<13x21x3xf32>, tensor<1xi32>) -> tensor<21x3xf32>
+ return %3 : tensor<21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reduce_sum
+// CHECK: tosa.reduce_sum
+// CHECK: tosa.reshape
+func @test_reduce_sum(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> {
+ %2 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+ %3 = "tf.Sum"(%arg0, %2) {keep_dims = false} : (tensor<13x21x3xf32>, tensor<1xi32>) -> tensor<21x3xf32>
+ return %3 : tensor<21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reduce_mean
+// CHECK: "tosa.const"() {value = dense<0.0769230798>
+// CHECK: tosa.reduce_sum
+// CHECK: tosa.reshape
+// CHECK: tosa.reshape
+// CHECK: tosa.mul
+func @test_reduce_mean(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> {
+ %2 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+ %3 = "tf.Mean"(%arg0, %2) {keep_dims = false} : (tensor<13x21x3xf32>, tensor<1xi32>) -> tensor<21x3xf32>
+ return %3 : tensor<21x3xf32>
+}
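+
+// Note (informal): Mean over axis 0 (extent 13) lowers to reduce_sum followed
+// by a multiply, so the 0.0769230798 constant checked above is just 1/13
+// rounded to f32.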
+
+// -----
+
+// CHECK-LABEL: test_reduce_product
+// CHECK: tosa.reduce_prod
+// CHECK: tosa.reshape
+func @test_reduce_product(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> {
+ %2 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+ %3 = "tf.Prod"(%arg0, %2) {keep_dims = false} : (tensor<13x21x3xf32>, tensor<1xi32>) -> tensor<21x3xf32>
+ return %3 : tensor<21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_min
+// CHECK: tosa.minimum
+func @test_min(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x21x3xf32>) -> tensor<13x21x3xf32> {
+ %2 = "tf.Minimum"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<1x21x3xf32>) -> tensor<13x21x3xf32>
+ return %2 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_max
+// CHECK: tosa.maximum
+func @test_max(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x3xf32> {
+ %2 = "tf.Maximum"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<13x21x1xf32>) -> tensor<13x21x3xf32>
+ return %2 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_pow
+// CHECK: tosa.pow
+func @test_pow(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x3xf32> {
+ %2 = "tf.Pow"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<13x21x1xf32>) -> tensor<13x21x3xf32>
+ return %2 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_abs
+// CHECK: tosa.abs
+func @test_abs(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %2 = "tf.Abs"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %2 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_ceil
+// CHECK: tosa.ceil
+func @test_ceil(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %2 = "tf.Ceil"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %2 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_floor
+// CHECK: tosa.floor
+func @test_floor(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %2 = "tf.Floor"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %2 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_log
+// CHECK: tosa.log
+func @test_log(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %2 = "tf.Log"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %2 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_negate
+// CHECK: tosa.negate
+func @test_negate(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %2 = "tf.Neg"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %2 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_rsqrt
+// CHECK: tosa.rsqrt
+func @test_rsqrt(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %2 = "tf.Rsqrt"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %2 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_sigmoid
+// CHECK: tosa.sigmoid
+func @test_sigmoid(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %2 = "tf.Sigmoid"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %2 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_square
+// CHECK: tosa.mul
+func @test_square(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %2 = "tf.Square"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %2 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_equal
+// CHECK: tosa.equal
+func @test_equal(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xi1> {
+ %2 = "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = true} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xi1>
+ return %2 : tensor<13x21x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_greater_equal
+// CHECK: tosa.greater_equal
+func @test_greater_equal(%arg0: tensor<13x1x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xi1> {
+ %2 = "tf.GreaterEqual"(%arg0, %arg1) : (tensor<13x1x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xi1>
+ return %2 : tensor<13x21x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_greater
+// CHECK: tosa.greater
+func @test_greater(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xi1> {
+ %2 = "tf.Greater"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xi1>
+ return %2 : tensor<13x21x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_less
+// CHECK: tosa.greater_equal
+// CHECK: tosa.logical_not
+func @test_less(%arg0: tensor<13x1x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xi1> {
+ %2 = "tf.Less"(%arg0, %arg1) : (tensor<13x1x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xi1>
+ return %2 : tensor<13x21x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_less_equal
+// CHECK: tosa.greater
+// CHECK: tosa.logical_not
+func @test_less_equal(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x21x3xf32>) -> tensor<13x21x3xi1> {
+ %2 = "tf.LessEqual"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<1x21x3xf32>) -> tensor<13x21x3xi1>
+ return %2 : tensor<13x21x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_argmax
+// CHECK: tosa.argmax
+func @test_argmax(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xi32> {
+ %2 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+ %3 = "tf.ArgMax"(%arg0, %2) : (tensor<13x21x3xf32>, tensor<i32>) -> tensor<21x3xi32>
+ return %3 : tensor<21x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_avg_pool2d
+// CHECK: tosa.avg_pool2d
+func @test_avg_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+ %2 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+ return %2 : tensor<1x32x32x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_max_pool2d
+// CHECK: tosa.max_pool2d
+func @test_max_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+ %2 = "tf.MaxPool"(%arg0) {data_format = "NHWC", explicit_paddings = [], ksize = [1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+ return %2 : tensor<1x32x32x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reshape
+// CHECK: tosa.reshape
+func @test_reshape(%arg0: tensor<13x21x3xf32>) -> tensor<1x819xf32> {
+ %0 = "tf.Const"() {value = dense<[1, 819]> : tensor<2xi32>} : () -> tensor<2xi32>
+ %3 = "tf.Reshape"(%arg0, %0) : (tensor<13x21x3xf32>, tensor<2xi32>) -> tensor<1x819xf32>
+ %4 = "tf.Identity"(%3) : (tensor<1x819xf32>) -> tensor<1x819xf32>
+ return %4 : tensor<1x819xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_transpose
+// CHECK: tosa.const
+// CHECK: tosa.transpose
+func @test_transpose(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21xf32> {
+ %2 = "tf.Const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
+ %3 = "tf.Transpose"(%arg0, %2) : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21xf32>
+ return %3 : tensor<3x13x21xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_slice
+// CHECK: tosa.slice
+func @test_slice(%arg0: tensor<13x21x3xf32>) -> tensor<4x11x1xf32> {
+ %2 = "tf.Const"() {value = dense<[6, 8, 0]> : tensor<3xi64>} : () -> tensor<3xi64>
+ %3 = "tf.Const"() {value = dense<[4, 11, 1]> : tensor<3xi64>} : () -> tensor<3xi64>
+ %4 = "tf.Slice"(%arg0, %2, %3) : (tensor<13x21x3xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x11x1xf32>
+ return %4 : tensor<4x11x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_strided_slice
+// CHECK: tosa.slice
+// CHECK: tosa.reshape
+// CHECK: tosa.slice
+// CHECK: tosa.reshape
+func @test_strided_slice(%arg0: tensor<13x21x3xf32>) -> tensor<9x7x2xf32> {
+ %2 = "tf.Const"() {value = dense<[4, 0, 1]> : tensor<3xi64>} : () -> tensor<3xi64>
+ %3 = "tf.Const"() {value = dense<[13, 21, 3]> : tensor<3xi64>} : () -> tensor<3xi64>
+ %4 = "tf.Const"() {value = dense<[1, 3, 1]> : tensor<3xi64>} : () -> tensor<3xi64>
+ %5 = "tf.StridedSlice"(%arg0, %2, %3, %4) {begin_mask = 2 : i64, ellipsis_mask = 0 : i64, end_mask = 3 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<13x21x3xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<9x7x2xf32>
+ return %5 : tensor<9x7x2xf32>
+}
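+
+// Note (informal): with begin_mask = 2 and end_mask = 3, dim 0 slices 4..13
+// by stride 1 (9 elements), dim 1 slices 0..21 by stride 3 (7), and dim 2
+// slices 1..3 by stride 1 (2), matching the 9x7x2 result type.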
+
+// -----
+
+// CHECK-LABEL: test_select
+// CHECK: tosa.const
+// CHECK: tosa.reshape
+// CHECK: tosa.select
+func @test_select(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %2 = "tf.Const"() {value = dense<false> : tensor<1xi1>} : () -> tensor<1xi1>
+ %3 = "tf.SelectV2"(%2, %arg0, %arg1) : (tensor<1xi1>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %3 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_addn
+// CHECK: tosa.add
+// CHECK: tosa.add
+// CHECK: tosa.add
+func @test_addn(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<13x21x3xf32>, %arg3: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %2 = "tf.AddN"(%arg0, %arg1, %arg2, %arg3) : (tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %2 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_concatv2
+// CHECK: tosa.concat
+// CHECK: tosa.concat
+// CHECK: tosa.concat
+func @test_concatv2(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<13x21x3xf32>, %arg3: tensor<13x21x3xf32>) -> tensor<52x21x3xf32> {
+ %2 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+ %3 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %arg3, %2) : (tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<i32>) -> tensor<52x21x3xf32>
+ return %3 : tensor<52x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_stack
+// CHECK: tosa.concat
+// CHECK: tosa.concat
+// CHECK: tosa.concat
+// CHECK: tosa.reshape
+func @test_stack(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<13x21x3xf32>, %arg3: tensor<13x21x3xf32>) -> tensor<4x13x21x3xf32> {
+ %2 = "tf.Pack"(%arg0, %arg1, %arg2, %arg3) {axis = 0 : i64} : (tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<4x13x21x3xf32>
+ return %2 : tensor<4x13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_unstack
+// CHECK: tosa.slice
+// CHECK: tosa.reshape
+// CHECK: tosa.identityn
+func @test_unstack(%arg0: tensor<1x32x32x8xf32>) -> tensor<32x32x8xf32> {
+ %2 = "tf.Unpack"(%arg0) {axis = 0 : i64} : (tensor<1x32x32x8xf32>) -> tensor<32x32x8xf32>
+ %3 = "tf.Identity"(%2) : (tensor<32x32x8xf32>) -> tensor<32x32x8xf32>
+ return %3 : tensor<32x32x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_pad
+// CHECK: tosa.const
+// CHECK: tosa.pad
+func @test_pad(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %2 = "tf.Const"() {value = dense<0> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
+ %3 = "tf.Pad"(%arg0, %2) : (tensor<13x21x3xf32>, tensor<3x2xi32>) -> tensor<13x21x3xf32>
+ return %3 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_expand_dims
+// CHECK: tosa.reshape
+func @test_expand_dims(%arg0: tensor<13x21x3xf32>) -> tensor<1x13x21x3xf32> {
+ %2 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+ %3 = "tf.ExpandDims"(%arg0, %2) : (tensor<13x21x3xf32>, tensor<i32>) -> tensor<1x13x21x3xf32>
+ return %3 : tensor<1x13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_shape
+// CHECK: tosa.const
+func @test_shape() -> tensor<3xi32> {
+ %3 = "tf.Const"() {value = dense<[13, 21, 3]> : tensor<3xi32>} : () -> tensor<3xi32>
+ return %3 : tensor<3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_rank
+// CHECK: tosa.const
+func @test_rank() -> tensor<i32> {
+ %3 = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
+ return %3 : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: test_elu
+// CHECK: tosa.const
+// CHECK: tosa.const
+// CHECK: tosa.exp
+// CHECK: tosa.reshape
+// CHECK: tosa.sub
+// CHECK: tosa.reshape
+// CHECK: tosa.greater_equal
+// CHECK: tosa.select
+func @test_elu(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %2 = "tf.Elu"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %2 : tensor<13x21x3xf32>
+}
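+
+// Note (informal): Elu(x) = x for x >= 0 and exp(x) - 1 otherwise, which the
+// exp/sub/greater_equal/select sequence checked above implements elementwise.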
+
+// -----
+
+// CHECK-LABEL: test_softmax
+// CHECK: tosa.exp
+// CHECK: tosa.reduce_sum
+// CHECK: tosa.reciprocal
+// CHECK: tosa.mul
+func @test_softmax(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %2 = "tf.Softmax"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %2 : tensor<13x21x3xf32>
+}
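+
+// Note (informal): softmax(x) = exp(x) / sum(exp(x)), realized above as exp,
+// reduce_sum, reciprocal, and mul; log_softmax below simply appends a final
+// log.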
+
+// -----
+
+// CHECK-LABEL: test_log_softmax
+// CHECK: tosa.exp
+// CHECK: tosa.reduce_sum
+// CHECK: tosa.reciprocal
+// CHECK: tosa.mul
+// CHECK: tosa.log
+func @test_log_softmax(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %2 = "tf.LogSoftmax"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %2 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_matmul
+// CHECK: tosa.matmul
+func @test_matmul(%arg0: tensor<14x19xf32>, %arg1: tensor<19x28xf32>) -> tensor<14x28xf32> {
+ %2 = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<14x19xf32>, tensor<19x28xf32>) -> tensor<14x28xf32>
+ return %2 : tensor<14x28xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_add_scalar
+// CHECK: tosa.const
+// CHECK: tosa.reshape
+// CHECK: tosa.add
+func @test_add_scalar(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %2 = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
+ %3 = "tf.Add"(%arg0, %2) : (tensor<13x21x3xf32>, tensor<f32>) -> tensor<13x21x3xf32>
+ return %3 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_add_1d
+// CHECK: tosa.reduce_sum
+// CHECK: tosa.reduce_sum
+// CHECK: tosa.reshape
+// CHECK: tosa.reshape
+// CHECK: tosa.add
+func @test_add_1d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
+ %3 = "tf.Sum"(%arg1, %0) {keep_dims = false} : (tensor<13x21x3xf32>, tensor<2xi32>) -> tensor<3xf32>
+ %4 = "tf.Add"(%arg0, %3) : (tensor<13x21x3xf32>, tensor<3xf32>) -> tensor<13x21x3xf32>
+ return %4 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_add_1d_const
+// CHECK: tosa.add
+func @test_add_1d_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %3 = "tf.Add"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %3 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_split
+// CHECK: tosa.slice
+// CHECK: tosa.slice
+// CHECK: tosa.slice
+// CHECK: tosa.identityn
+func @test_split(%arg0: tensor<13x21x3xf32>) -> (tensor<13x7x3xf32>, tensor<13x7x3xf32>, tensor<13x7x3xf32>) {
+ %6 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+ %7:3 = "tf.Split"(%6, %arg0) : (tensor<i32>, tensor<13x21x3xf32>) -> (tensor<13x7x3xf32>, tensor<13x7x3xf32>, tensor<13x7x3xf32>)
+ return %7#0, %7#1, %7#2 : tensor<13x7x3xf32>, tensor<13x7x3xf32>, tensor<13x7x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_tile
+// CHECK: tosa.tile
+func @test_tile(%arg0: tensor<13x21x3xf32>) -> tensor<39x21x6xf32> {
+ %2 = "tf.Const"() {value = dense<[3, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
+ %3 = "tf.Tile"(%arg0, %2) : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<39x21x6xf32>
+ %4 = "tf.Identity"(%3) : (tensor<39x21x6xf32>) -> tensor<39x21x6xf32>
+ return %4 : tensor<39x21x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reverse
+// CHECK: tosa.reverse
+func @test_reverse(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %2 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+ %3 = "tf.ReverseV2"(%arg0, %2) : (tensor<13x21x3xf32>, tensor<1xi32>) -> tensor<13x21x3xf32>
+ return %3 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_gather
+// CHECK: tosa.const
+// CHECK: tosa.gather
+func @test_gather(%arg0: tensor<13x21x3xi32>) -> tensor<26x21x3xi32> {
+ %2 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+ %3 = "tf.Const"() {value = dense<[2, 2, 7, 6, 6, 1, 5, 4, 2, 11, 10, 11, 7, 7, 5, 3, 12, 7, 11, 0, 9, 5, 4, 12, 1, 9]> : tensor<26xi32>} : () -> tensor<26xi32>
+ %4 = "tf.GatherV2"(%arg0, %3, %2) {batch_dims = 0 : i64} : (tensor<13x21x3xi32>, tensor<26xi32>, tensor<i32>) -> tensor<26x21x3xi32>
+ return %4 : tensor<26x21x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_space_to_batch
+// CHECK-DAG: "tosa.const"() {value = dense<{{\[}}[0, 0], [0, 1], [0, 0]]>
+// CHECK-DAG: "tosa.const"() {value = dense<[2, 0, 1, 3]>
+// CHECK: tosa.pad
+// CHECK: tosa.reshape
+// CHECK: tosa.transpose
+// CHECK: tosa.reshape
+func @test_space_to_batch(%arg0: tensor<13x21x3xf32>) -> tensor<26x11x3xf32> {
+ %2 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+ %3 = "tf.Const"() {value = dense<[[0, 1]]> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
+ %4 = "tf.SpaceToBatchND"(%arg0, %2, %3) : (tensor<13x21x3xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<26x11x3xf32>
+ return %4 : tensor<26x11x3xf32>
+}
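+
+// Note (informal): block size 2 on the single spatial dim with padding [0, 1]
+// gives (21 + 1) / 2 = 11, and the batch grows to 13 * 2 = 26, matching the
+// 26x11x3 result type.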
+
+// -----
+
+// CHECK-LABEL: test_batch_to_space
+// CHECK-DAG: "tosa.const"() {value = dense<[3, 1, 2, 0]>
+// CHECK-DAG: "tosa.const"() {value = dense<[2, 3, 0, 4, 1, 5]>
+// CHECK: tosa.transpose
+// CHECK: tosa.reshape
+// CHECK: tosa.transpose
+// CHECK: tosa.reshape
+// CHECK: tosa.slice
+func @test_batch_to_space(%arg0: tensor<1x32x32x8xf32>) -> tensor<2x64x64x1xf32> {
+ %2 = "tf.Const"() {value = dense<2> : tensor<2xi32>} : () -> tensor<2xi32>
+ %3 = "tf.Const"() {value = dense<0> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
+ %4 = "tf.Const"() {value = dense<[3, 1, 2, 0]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %5 = "tf.Transpose"(%arg0, %4) : (tensor<1x32x32x8xf32>, tensor<4xi32>) -> tensor<8x32x32x1xf32>
+ %6 = "tf.BatchToSpaceND"(%5, %2, %3) : (tensor<8x32x32x1xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<2x64x64x1xf32>
+ return %6 : tensor<2x64x64x1xf32>
+}
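+
+// Note (informal): with block shape [2, 2] and zero crops, the batch of 8
+// shrinks to 8 / (2 * 2) = 2 and each spatial dim doubles to 64, matching
+// the 2x64x64x1 result type.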
+
+// -----
+
+// CHECK-LABEL: test_space_to_depth
+// CHECK: "tosa.const"() {value = dense<[0, 1, 3, 2, 4, 5]>
+// CHECK: tosa.reshape
+// CHECK: tosa.transpose
+// CHECK: tosa.reshape
+func @test_space_to_depth(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x16x16x32xf32> {
+ %2 = "tf.SpaceToDepth"(%arg0) {block_size = 2 : i64, data_format = "NHWC"} : (tensor<1x32x32x8xf32>) -> tensor<1x16x16x32xf32>
+ return %2 : tensor<1x16x16x32xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_depth_to_space
+// CHECK: "tosa.const"() {value = dense<[0, 1, 3, 2, 4, 5]>
+// CHECK: tosa.reshape
+// CHECK: tosa.transpose
+// CHECK: tosa.reshape
+func @test_depth_to_space(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x2xf32> {
+ %2 = "tf.DepthToSpace"(%arg0) {block_size = 2 : i64, data_format = "NHWC"} : (tensor<1x32x32x8xf32>) -> tensor<1x64x64x2xf32>
+ return %2 : tensor<1x64x64x2xf32>
+}
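+
+// Note (informal): block size 2 moves 2 * 2 = 4 channel values into space,
+// so 8 channels become 2 and each spatial dim doubles; the reshape /
+// transpose (perm [0, 1, 3, 2, 4, 5]) / reshape sequence above implements
+// this.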
+
+// -----
+
+// CHECK-LABEL: test_fakequant_with_min_max_args
+// CHECK-DAG: "tosa.const"() {value = dense<16383.75> : tensor<f32>}
+// CHECK-DAG: "tosa.const"() {value = dense<-1.000000e+00> : tensor<f32>}
+// CHECK-DAG: "tosa.const"() {value = dense<6.10360876E-5> : tensor<f32>}
+// CHECK: tosa.reshape
+// CHECK: tosa.mul
+// CHECK: tosa.reshape
+// CHECK: tosa.add
+// CHECK: tosa.cast
+// CHECK: tosa.rescale
+// CHECK: tosa.rescale
+// CHECK: tosa.cast
+// CHECK: tosa.reshape
+// CHECK: tosa.sub
+// CHECK: tosa.reshape
+// CHECK: tosa.mul
+func @test_fakequant_with_min_max_args(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %2 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {max = 2.000000e+00 : f32, min = -2.000000e+00 : f32, narrow_range = false, num_bits = 16 : i64} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %2 : tensor<13x21x3xf32>
+}
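+
+// Note (informal): with min = -2, max = 2, num_bits = 16 and
+// narrow_range = false there are 2^16 - 1 = 65535 quantization steps, so the
+// scale is 4 / 65535 = 6.10360876E-5 and the 16383.75 quantize multiplier is
+// its reciprocal, 65535 / 4.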
diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir
new file mode 100644
index 0000000..a8cea22
--- /dev/null
+++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir
@@ -0,0 +1,907 @@
+// RUN: tf-opt --tfl-to-tosa-pipeline --verify-each %s | FileCheck %s
+
+// Operations for testing tfl-to-tosa-pipeline
+
+// TODO: For all fakequant tests, compute and add checks on the rescale
+// attribute values.
+// TODO: These tests are fairly minimal. Expand the checks to be more robust.
+
+
+// -----
+
+// CHECK-LABEL: test_conv2d
+// CHECK: tosa.const
+// CHECK: tosa.conv2d
+func @test_conv2d(%arg0: tensor<1x32x32x8xf32>, %cst_0: tensor<16x1x1x8xf32>) -> tensor<1x32x32x16xf32> {
+ %cst = constant dense<0.000000e+00> : tensor<16xf32>
+ %0 = "tfl.conv_2d"(%arg0, %cst_0, %cst) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
+ return %0 : tensor<1x32x32x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_conv2d_bias
+// CHECK: tosa.conv2d
+func @test_conv2d_bias(%arg0: tensor<1x32x32x8xf32>, %cst: tensor<16x1x1x8xf32>, %cst_0: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
+ %0 = "tfl.conv_2d"(%arg0, %cst, %cst_0) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
+ return %0 : tensor<1x32x32x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_transpose_conv2d
+// CHECK: tosa.const
+// CHECK: tosa.transpose_conv2d
+func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %cst_0: tensor<16x1x1x8xf32>) -> tensor<1x32x32x16xf32> {
+ %cst = constant dense<[1, 32, 32, 16]> : tensor<4xi32>
+ %cst_1 = constant unit
+ %0 = "tfl.transpose_conv"(%cst, %cst_0, %arg0, %cst_1) {padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<4xi32>, tensor<16x1x1x8xf32>, tensor<1x32x32x8xf32>, none) -> tensor<1x32x32x16xf32>
+ return %0 : tensor<1x32x32x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_fakequant_conv2d
+// CHECK: tosa.const
+// CHECK: tosa.const
+// CHECK: tosa.conv2d
+// CHECK: tosa.rescale
+func @test_fakequant_conv2d(%arg0: tensor<1x32x32x8x!quant.uniform<i8:f32, 0.015684768557548523>>) -> tensor<1x32x32x16x!quant.uniform<i8:f32, 0.078431375324726104>> {
+ %0 = "tfl.pseudo_qconst"() {qtype = tensor<16x1x1x8x!quant.uniform<i8<-127:127>:f32:0, {0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1}>>, value = dense<"0x851F811ED39B1160E8BFD11A44C8815EC054BEB7658131420857498B9B7FA28499818C7AB44894E64B81C6C350A581E8042F48DB13B85A81EEE481FD28A43BBBC381A70384A46F47811C2A4D64D8D285DEDCE37F1FFC6B5BB0A3794EED7F98D9060BA5ED5EC6A37F7FF4E67364062F078AE9DDDF778155794C54AE536D7FAC05"> : tensor<16x1x1x8xi8>} : () -> tensor<16x1x1x8x!quant.uniform<i8<-127:127>:f32:0, {0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1} >>
+ %1 = "tfl.pseudo_qconst"() {qtype = tensor<16x!quant.uniform<i32:f32:0, {2.0,2.0,1.0,1.0,1.0,2.0,2.4,1.7,2.3,2.4,2.4,2.3,2.1,2.4,2.1,2.4}>>, value = dense<0> : tensor<16xi32>} : () -> tensor<16x!quant.uniform<i32:f32:0, {2.0,2.0,1.0,1.0,1.0,2.0,2.4,1.7,2.3,2.4,2.4,2.3,2.1,2.4,2.1,2.4} >>
+ %2 = "tfl.conv_2d"(%arg0, %0, %1) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x32x32x8x!quant.uniform<i8:f32, 0.015684768557548523>>, tensor<16x1x1x8x!quant.uniform<i8<-127:127>:f32:0, {0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1}>>, tensor<16x!quant.uniform<i32:f32:0, {2.0,2.0,1.0,1.0,1.0,2.0,2.4,1.7,2.3,2.4,2.4,2.3,2.1,2.4,2.1,2.4} >>) -> tensor<1x32x32x16x!quant.uniform<i8:f32, 0.078431375324726104>>
+ return %2 : tensor<1x32x32x16x!quant.uniform<i8:f32, 0.078431375324726104>>
+}
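+
+// Note (informal): the trailing rescale applies the per-channel multiplier
+// (input_scale * weight_scale) / output_scale, here about
+// (0.0156848 * 0.1) / 0.0784314 ~= 0.02 for every channel.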
+
+// -----
+
+// TODO: Compute and add checks on the rescale attribute values.
+
+// CHECK-LABEL: test_fakequant_depthwise_conv2d_bias
+// CHECK-DAG: "tosa.const"() {value = dense<[{{\[}}[{{\[}}-127, 127, 127, -127, -127, -127, -127, -127, -127, 127, 127, 127, 127, 127, -127, 127]]]]> : tensor<1x1x1x16xi8>} : () -> tensor<1x1x1x16x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>>
+// CHECK-DAG: "tosa.const"() {value = dense<[1, 2, 3, 0]> : tensor<4xi32>} : () -> tensor<4xi32>
+// CHECK-DAG: "tosa.const"() {value = dense<[-2879, 6636, 3531, 23376, -79787, -6142, 5582, -30384, 17330, -4549, -3518, 16215, 2695, -2670, 8399, -12223]> : tensor<16xi32>} : () -> tensor<16xi32>
+// CHECK: tosa.transpose
+// CHECK: tosa.reshape
+// CHECK: tosa.depthwise_conv2d
+// CHECK: tosa.rescale
+func @test_fakequant_depthwise_conv2d_bias(%arg0: tensor<1x32x32x8x!quant.uniform<i8:f32, 0.015678688883781433:-1>>) -> tensor<1x32x32x16x!quant.uniform<i8:f32, 0.078431375324726104>> {
+ %0 = "tfl.pseudo_qconst"() {qtype = tensor<1x1x1x16x!quant.uniform<i8<-127:127>:f32:3, {0.1,0.1,0.1,0.1,2.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1}>>, value = dense<[[[[-127, 127, 127, -127, -127, -127, -127, -127, -127, 127, 127, 127, 127, 127, -127, 127]]]]> : tensor<1x1x1x16xi8>} : () -> tensor<1x1x1x16x!quant.uniform<i8<-127:127>:f32:3, {0.1,0.1,0.1,0.1,2.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1} >>
+ %1 = "tfl.pseudo_qconst"() {qtype = tensor<16x!quant.uniform<i32:f32:0, {9.1E-5,1.9E-4,2.3E-4,4.5E-5,3.6E-6,2.3E-4,2.3E-4,5.6E-5,5.8E-5,1.7E-4,7.1E-5,7.3E-5,2.2E-4,1.5E-4,1.7E-4,7.3E-5}>>, value = dense<[-2879, 6636, 3531, 23376, -79787, -6142, 5582, -30384, 17330, -4549, -3518, 16215, 2695, -2670, 8399, -12223]> : tensor<16xi32>} : () -> tensor<16x!quant.uniform<i32:f32:0, {9.1E-5,1.9E-4,2.3E-4,4.5E-5,3.6E-6,2.3E-4,2.3E-4,5.6E-5,5.8E-5,1.7E-4,7.1E-5,7.3E-5,2.2E-4,1.5E-4,1.7E-4,7.3E-5} >>
+ %2 = "tfl.depthwise_conv_2d"(%arg0, %0, %1) {depth_multiplier = 2 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x32x32x8x!quant.uniform<i8:f32, 0.015678688883781433:-1>>, tensor<1x1x1x16x!quant.uniform<i8<-127:127>:f32:3, {0.1,0.1,0.1,0.1,2.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1}>>, tensor<16x!quant.uniform<i32:f32:0, {9.1E-5,1.9E-4,2.3E-4,4.5E-5,3.6E-6,2.3E-4,2.3E-4,5.6E-5,5.8E-5,1.7E-4,7.1E-5,7.3E-5,2.2E-4,1.5E-4,1.7E-4,7.3E-5} >>) -> tensor<1x32x32x16x!quant.uniform<i8:f32, 0.078431375324726104>>
+ return %2 : tensor<1x32x32x16x!quant.uniform<i8:f32, 0.078431375324726104>>
+}
+
+// -----
+
+// CHECK-LABEL: test_add
+// CHECK: tosa.add
+func @test_add(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<13x21x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_sub
+// CHECK: tosa.sub
+func @test_sub(%arg0: tensor<1x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tfl.sub"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<1x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_mul
+// CHECK: tosa.mul
+func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tfl.mul"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_exp
+// CHECK: tosa.exp
+func @test_exp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tfl.exp"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_rcp
+// CHECK: tosa.const
+// CHECK: tosa.reciprocal
+// CHECK: tosa.reshape
+// CHECK: tosa.mul
+func @test_rcp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %cst = constant dense<1.000000e+00> : tensor<f32>
+ %0 = "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<f32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_relu
+// CHECK: tosa.reluN
+func @test_relu(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tfl.relu"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_relu6
+// CHECK: tosa.reluN
+func @test_relu6(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tfl.relu6"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_leaky_relu
+func @test_leaky_relu(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tfl.leaky_relu"(%arg0) {alpha = 0.707330704 : f32} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_concat
+// CHECK: tosa.concat
+func @test_concat(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<26x21x3xf32> {
+ %0 = "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<26x21x3xf32>
+ return %0 : tensor<26x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_logical_and
+// CHECK: tosa.logical_and
+func @test_logical_and(%arg0: tensor<13x21x3xi1>, %arg1: tensor<13x21x1xi1>) -> tensor<13x21x3xi1> {
+ %0 = "tfl.logical_and"(%arg0, %arg1) : (tensor<13x21x3xi1>, tensor<13x21x1xi1>) -> tensor<13x21x3xi1>
+ return %0 : tensor<13x21x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_logical_or
+// CHECK: tosa.logical_or
+func @test_logical_or(%arg0: tensor<13x1x3xi1>, %arg1: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> {
+ %0 = "tfl.logical_or"(%arg0, %arg1) : (tensor<13x1x3xi1>, tensor<13x21x3xi1>) -> tensor<13x21x3xi1>
+ return %0 : tensor<13x21x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_logical_not
+// CHECK: tosa.logical_not
+func @test_logical_not(%arg0: tensor<1x21x3xi1>) -> tensor<1x21x3xi1> {
+ %0 = "tfl.logical_not"(%arg0) : (tensor<1x21x3xi1>) -> tensor<1x21x3xi1>
+ return %0 : tensor<1x21x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_reduce_any
+// CHECK: tosa.reduce_any
+// CHECK: tosa.reshape
+func @test_reduce_any(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> {
+ %cst = constant dense<0> : tensor<1xi32>
+ %0 = "tfl.reduce_any"(%arg0, %cst) {keep_dims = false} : (tensor<13x21x3xi1>, tensor<1xi32>) -> tensor<21x3xi1>
+ return %0 : tensor<21x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_reduce_min
+// CHECK: tosa.reduce_min
+// CHECK: tosa.reshape
+func @test_reduce_min(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> {
+ %cst = constant dense<0> : tensor<1xi32>
+ %0 = "tfl.reduce_min"(%arg0, %cst) {keep_dims = false} : (tensor<13x21x3xf32>, tensor<1xi32>) -> tensor<21x3xf32>
+ return %0 : tensor<21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reduce_max
+// CHECK: tosa.reduce_max
+// CHECK: tosa.reshape
+func @test_reduce_max(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> {
+ %cst = constant dense<0> : tensor<1xi32>
+ %0 = "tfl.reduce_max"(%arg0, %cst) {keep_dims = false} : (tensor<13x21x3xf32>, tensor<1xi32>) -> tensor<21x3xf32>
+ return %0 : tensor<21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reduce_sum
+// CHECK: tosa.reduce_sum
+// CHECK: tosa.reshape
+func @test_reduce_sum(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> {
+ %cst = constant dense<0> : tensor<1xi32>
+ %0 = "tfl.sum"(%arg0, %cst) {keep_dims = false} : (tensor<13x21x3xf32>, tensor<1xi32>) -> tensor<21x3xf32>
+ return %0 : tensor<21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reduce_mean
+// CHECK: "tosa.const"() {value = dense<0.0769230798>
+// CHECK: tosa.reduce_sum
+// CHECK: tosa.reshape
+// CHECK: tosa.reshape
+// CHECK: tosa.mul
+func @test_reduce_mean(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> {
+ %cst = constant dense<0> : tensor<1xi32>
+ %0 = "tfl.mean"(%arg0, %cst) {keep_dims = false} : (tensor<13x21x3xf32>, tensor<1xi32>) -> tensor<21x3xf32>
+ return %0 : tensor<21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reduce_product
+// CHECK: tosa.reduce_prod
+// CHECK: tosa.reshape
+func @test_reduce_product(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> {
+ %cst = constant dense<0> : tensor<1xi32>
+ %0 = "tfl.reduce_prod"(%arg0, %cst) {keep_dims = false} : (tensor<13x21x3xf32>, tensor<1xi32>) -> tensor<21x3xf32>
+ return %0 : tensor<21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_min
+// CHECK: tosa.minimum
+func @test_min(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tfl.minimum"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<1x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_max
+// CHECK: tosa.maximum
+func @test_max(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tfl.maximum"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<13x21x1xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_pow
+// CHECK: tosa.pow
+func @test_pow(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tfl.pow"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<13x21x1xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_abs
+// CHECK: tosa.abs
+func @test_abs(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tfl.abs"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_ceil
+// CHECK: tosa.ceil
+func @test_ceil(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tfl.ceil"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_floor
+// CHECK: tosa.floor
+func @test_floor(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tfl.floor"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_log
+// CHECK: tosa.log
+func @test_log(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tfl.log"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_negate
+// CHECK: tosa.negate
+func @test_negate(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tfl.neg"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_rsqrt
+// CHECK: tosa.rsqrt
+func @test_rsqrt(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tfl.rsqrt"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_sigmoid
+// CHECK: tosa.sigmoid
+func @test_sigmoid(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tfl.logistic"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_square
+// CHECK: tosa.mul
+func @test_square(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tfl.square"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_equal
+// CHECK: tosa.equal
+func @test_equal(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xi1> {
+ %0 = "tfl.equal"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xi1>
+ return %0 : tensor<13x21x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_greater_equal
+// CHECK: tosa.greater_equal
+func @test_greater_equal(%arg0: tensor<13x1x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xi1> {
+ %0 = "tfl.greater_equal"(%arg0, %arg1) : (tensor<13x1x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xi1>
+ return %0 : tensor<13x21x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_greater
+// CHECK: tosa.greater
+func @test_greater(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xi1> {
+ %0 = "tfl.greater"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xi1>
+ return %0 : tensor<13x21x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_less
+// CHECK: tosa.greater_equal
+// CHECK: tosa.logical_not
+func @test_less(%arg0: tensor<13x1x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xi1> {
+ %0 = "tfl.less"(%arg0, %arg1) : (tensor<13x1x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xi1>
+ return %0 : tensor<13x21x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_less_equal
+// CHECK: tosa.greater
+// CHECK: tosa.logical_not
+func @test_less_equal(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x21x3xf32>) -> tensor<13x21x3xi1> {
+ %0 = "tfl.less_equal"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<1x21x3xf32>) -> tensor<13x21x3xi1>
+ return %0 : tensor<13x21x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_avg_pool2d
+// CHECK: tosa.avg_pool2d
+func @test_avg_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+ %0 = "tfl.average_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+ return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_max_pool2d
+// CHECK: tosa.max_pool2d
+func @test_max_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+ %0 = "tfl.max_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+ return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reshape
+// CHECK: tosa.reshape
+func @test_reshape(%arg0: tensor<13x21x3xf32>) -> tensor<1x819xf32> {
+ %cst = constant dense<[1, 819]> : tensor<2xi32>
+ %0 = "tfl.reshape"(%arg0, %cst) : (tensor<13x21x3xf32>, tensor<2xi32>) -> tensor<1x819xf32>
+ return %0 : tensor<1x819xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_transpose
+// CHECK: tosa.const
+// CHECK: tosa.transpose
+func @test_transpose(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21xf32> {
+ %cst = constant dense<[2, 0, 1]> : tensor<3xi32>
+ %0 = "tfl.transpose"(%arg0, %cst) : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21xf32>
+ return %0 : tensor<3x13x21xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_slice
+// CHECK: tosa.slice
+func @test_slice(%arg0: tensor<13x21x3xf32>) -> tensor<4x11x1xf32> {
+ %cst = constant dense<[6, 8, 0]> : tensor<3xi32>
+ %cst_0 = constant dense<[4, 11, 1]> : tensor<3xi32>
+ %0 = "tfl.slice"(%arg0, %cst, %cst_0) : (tensor<13x21x3xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<4x11x1xf32>
+ return %0 : tensor<4x11x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_strided_slice
+// CHECK: tosa.slice
+// CHECK: tosa.reshape
+// CHECK: tosa.slice
+// CHECK: tosa.reshape
+func @test_strided_slice(%arg0: tensor<13x21x3xf32>) -> tensor<9x7x2xf32> {
+ %cst = constant dense<[4, 0, 1]> : tensor<3xi32>
+ %cst_0 = constant dense<[13, 21, 3]> : tensor<3xi32>
+ %cst_1 = constant dense<[1, 3, 1]> : tensor<3xi32>
+ %0 = "tfl.strided_slice"(%arg0, %cst, %cst_0, %cst_1) {begin_mask = 2 : i32, ellipsis_mask = 0 : i32, end_mask = 3 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<13x21x3xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<9x7x2xf32>
+ return %0 : tensor<9x7x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_select
+// CHECK: tosa.const
+// CHECK: tosa.reshape
+// CHECK: tosa.select
+func @test_select(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %cst = constant dense<false> : tensor<1xi1>
+ %0 = "tfl.select_v2"(%cst, %arg0, %arg1) : (tensor<1xi1>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_addn
+// CHECK: tosa.add
+// CHECK: tosa.add
+// CHECK: tosa.add
+func @test_addn(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<13x21x3xf32>, %arg3: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tfl.add_n"(%arg0, %arg1, %arg2, %arg3) : (tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_concatv2
+// CHECK: tosa.concat
+// CHECK: tosa.concat
+// CHECK: tosa.concat
+func @test_concatv2(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<13x21x3xf32>, %arg3: tensor<13x21x3xf32>) -> tensor<52x21x3xf32> {
+ %0 = "tfl.concatenation"(%arg0, %arg1, %arg2, %arg3) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<52x21x3xf32>
+ return %0 : tensor<52x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_stack
+// CHECK: tosa.concat
+// CHECK: tosa.concat
+// CHECK: tosa.concat
+// CHECK: tosa.reshape
+func @test_stack(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<13x21x3xf32>, %arg3: tensor<13x21x3xf32>) -> tensor<4x13x21x3xf32> {
+ %0 = "tfl.pack"(%arg0, %arg1, %arg2, %arg3) {axis = 0 : i32, values_count = 4 : i32} : (tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<4x13x21x3xf32>
+ return %0 : tensor<4x13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_unstack
+// CHECK: tosa.slice
+// CHECK: tosa.reshape
+// CHECK: tosa.identityn
+func @test_unstack(%arg0: tensor<1x32x32x8xf32>) -> tensor<32x32x8xf32> {
+ %0 = "tfl.unpack"(%arg0) {axis = 0 : i32, num = 1 : i32} : (tensor<1x32x32x8xf32>) -> tensor<32x32x8xf32>
+ return %0 : tensor<32x32x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_pad
+// CHECK: tosa.const
+// CHECK: tosa.pad
+func @test_pad(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %cst = constant dense<0> : tensor<3x2xi32>
+ %0 = "tfl.pad"(%arg0, %cst) : (tensor<13x21x3xf32>, tensor<3x2xi32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_expand_dims
+// CHECK: tosa.reshape
+func @test_expand_dims(%arg0: tensor<13x21x3xf32>) -> tensor<1x13x21x3xf32> {
+ %cst = constant dense<[1, 13, 21, 3]> : tensor<4xi32>
+ %0 = "tfl.reshape"(%arg0, %cst) : (tensor<13x21x3xf32>, tensor<4xi32>) -> tensor<1x13x21x3xf32>
+ return %0 : tensor<1x13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_shape
+// CHECK: tosa.const
+func @test_shape() -> tensor<3xi32> {
+ %cst = constant dense<[13, 21, 3]> : tensor<3xi32>
+ return %cst : tensor<3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_rank
+// CHECK: tosa.const
+func @test_rank() -> tensor<i32> {
+ %cst = constant dense<3> : tensor<i32>
+ return %cst : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: test_elu
+// CHECK: tosa.const
+// CHECK: tosa.const
+// CHECK: tosa.exp
+// CHECK: tosa.reshape
+// CHECK: tosa.sub
+// CHECK: tosa.reshape
+// CHECK: tosa.greater_equal
+// CHECK: tosa.select
+func @test_elu(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tfl.elu"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_softmax
+// CHECK: tosa.exp
+// CHECK: tosa.reduce_sum
+// CHECK: tosa.reciprocal
+// CHECK: tosa.mul
+func @test_softmax(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_log_softmax
+// CHECK: tosa.exp
+// CHECK: tosa.reduce_sum
+// CHECK: tosa.reciprocal
+// CHECK: tosa.mul
+// CHECK: tosa.log
+func @test_log_softmax(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tfl.log_softmax"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_matmul
+// CHECK-DAG: "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+// CHECK-DAG: "tosa.const"() {value = dense<0.000000e+00> : tensor<28xf32>} : () -> tensor<28xf32>
+// CHECK: tosa.transpose
+// CHECK: tosa.fully_connected
+func @test_matmul(%arg0: tensor<14x19xf32>, %arg1: tensor<19x28xf32>) -> tensor<14x28xf32> {
+ %cst = constant dense<[1, 0]> : tensor<2xi32>
+ %cst_0 = constant unit
+ %0 = "tfl.transpose"(%arg1, %cst) : (tensor<19x28xf32>, tensor<2xi32>) -> tensor<28x19xf32>
+ %1 = "tfl.fully_connected"(%arg0, %0, %cst_0) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<14x19xf32>, tensor<28x19xf32>, none) -> tensor<14x28xf32>
+ return %1 : tensor<14x28xf32>
+}
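+
+// Note (informal): the none bias of tfl.fully_connected is materialized as
+// the zero tensor<28xf32> constant checked above, since tosa.fully_connected
+// takes an explicit bias operand.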
+
+// -----
+
+// CHECK-LABEL: test_add_scalar
+// CHECK: tosa.const
+// CHECK: tosa.reshape
+// CHECK: tosa.add
+func @test_add_scalar(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %cst = constant dense<1.000000e+00> : tensor<f32>
+ %0 = "tfl.add"(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<13x21x3xf32>, tensor<f32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_add_1d
+// CHECK: tosa.reduce_sum
+// CHECK: tosa.reduce_sum
+// CHECK: tosa.reshape
+// CHECK: tosa.reshape
+// CHECK: tosa.add
+func @test_add_1d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %cst = constant dense<[0, 1]> : tensor<2xi32>
+ %0 = "tfl.sum"(%arg1, %cst) {keep_dims = false} : (tensor<13x21x3xf32>, tensor<2xi32>) -> tensor<3xf32>
+ %1 = "tfl.add"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<13x21x3xf32>, tensor<3xf32>) -> tensor<13x21x3xf32>
+ return %1 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_add_1d_const
+// CHECK: tosa.add
+func @test_add_1d_const(%arg0: tensor<13x21x3xf32>, %cst: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = tfl.add %arg0, %cst {fused_activation_function = "NONE"} : tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_split
+// CHECK: tosa.slice
+// CHECK: tosa.slice
+// CHECK: tosa.slice
+// CHECK: tosa.identityn
+func @test_split(%arg0: tensor<13x21x3xf32>) -> (tensor<13x7x3xf32>, tensor<13x7x3xf32>, tensor<13x7x3xf32>) {
+ %cst_0 = constant dense<1> : tensor<i32>
+ %0:3 = "tfl.split"(%cst_0, %arg0) {num_splits = 3 : i32} : (tensor<i32>, tensor<13x21x3xf32>) -> (tensor<13x7x3xf32>, tensor<13x7x3xf32>, tensor<13x7x3xf32>)
+ return %0#0, %0#1, %0#2 : tensor<13x7x3xf32>, tensor<13x7x3xf32>, tensor<13x7x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_tile
+// CHECK: tosa.tile
+func @test_tile(%arg0: tensor<13x21x3xf32>) -> tensor<39x21x6xf32> {
+ %cst = constant dense<[3, 1, 2]> : tensor<3xi32>
+ %0 = "tfl.tile"(%arg0, %cst) : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<39x21x6xf32>
+ return %0 : tensor<39x21x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_space_to_batch
+// CHECK-DAG: "tosa.const"() {value = dense<[{{\[}}0, 0], [0, 1], [0, 0]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
+// CHECK-DAG: "tosa.const"() {value = dense<[2, 0, 1, 3]> : tensor<4xi32>} : () -> tensor<4xi32>
+// CHECK: tosa.pad
+// CHECK: tosa.reshape
+// CHECK: tosa.transpose
+// CHECK: tosa.reshape
+func @test_space_to_batch(%arg0: tensor<13x21x3xf32>) -> tensor<26x11x3xf32> {
+ %cst = constant dense<2> : tensor<1xi32>
+ %cst_0 = constant dense<[[0, 1]]> : tensor<1x2xi32>
+ %0 = "tfl.space_to_batch_nd"(%arg0, %cst, %cst_0) : (tensor<13x21x3xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<26x11x3xf32>
+ return %0 : tensor<26x11x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_batch_to_space
+// CHECK-DAG: "tosa.const"() {value = dense<[3, 1, 2, 0]> : tensor<4xi32>} : () -> tensor<4xi32>
+// CHECK-DAG: "tosa.const"() {value = dense<[2, 3, 0, 4, 1, 5]> : tensor<6xi32>} : () -> tensor<6xi32>
+// CHECK: tosa.transpose
+// CHECK: tosa.reshape
+// CHECK: tosa.transpose
+// CHECK: tosa.reshape
+// CHECK: tosa.slice
+func @test_batch_to_space(%arg0: tensor<1x32x32x8xf32>) -> tensor<2x64x64x1xf32> {
+ %cst = constant dense<2> : tensor<2xi32>
+ %cst_0 = constant dense<0> : tensor<2x2xi32>
+ %cst_1 = constant dense<[3, 1, 2, 0]> : tensor<4xi32>
+ %0 = "tfl.transpose"(%arg0, %cst_1) : (tensor<1x32x32x8xf32>, tensor<4xi32>) -> tensor<8x32x32x1xf32>
+ %1 = "tfl.batch_to_space_nd"(%0, %cst, %cst_0) : (tensor<8x32x32x1xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<2x64x64x1xf32>
+ return %1 : tensor<2x64x64x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_space_to_depth
+// CHECK: "tosa.const"() {value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
+// CHECK: tosa.reshape
+// CHECK: tosa.transpose
+// CHECK: tosa.reshape
+func @test_space_to_depth(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x16x16x32xf32> {
+ %0 = "tfl.space_to_depth"(%arg0) {block_size = 2 : i32} : (tensor<1x32x32x8xf32>) -> tensor<1x16x16x32xf32>
+ return %0 : tensor<1x16x16x32xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_depth_to_space
+// CHECK: "tosa.const"() {value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
+// CHECK: tosa.reshape
+// CHECK: tosa.transpose
+// CHECK: tosa.reshape
+func @test_depth_to_space(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x2xf32> {
+ %0 = "tfl.depth_to_space"(%arg0) {block_size = 2 : i32} : (tensor<1x32x32x8xf32>) -> tensor<1x64x64x2xf32>
+ return %0 : tensor<1x64x64x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_fakequant_with_min_max_args
+// CHECK-DAG: "tosa.const"() {value = dense<16383.75> : tensor<f32>}
+// CHECK-DAG: "tosa.const"() {value = dense<0.000000e+00> : tensor<f32>}
+// CHECK-DAG: "tosa.const"() {value = dense<6.10360876E-5> : tensor<f32>}
+// CHECK: tosa.reshape
+// CHECK: tosa.mul
+// CHECK: tosa.reshape
+// CHECK: tosa.add
+// CHECK: tosa.cast
+// CHECK: tosa.rescale
+// CHECK: tosa.rescale
+// CHECK: tosa.cast
+// CHECK: tosa.reshape
+// CHECK: tosa.sub
+// CHECK: tosa.reshape
+// CHECK: tosa.mul
+func @test_fakequant_with_min_max_args(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tfl.quantize"(%arg0) {qtype = tensor<13x21x3x!quant.uniform<u16:f32, 6.1036087586785687E-5:32768>>} : (tensor<13x21x3xf32>) -> tensor<13x21x3x!quant.uniform<u16:f32, 6.1036087586785687E-5:32768>>
+ %1 = "tfl.dequantize"(%0) : (tensor<13x21x3x!quant.uniform<u16:f32, 6.1036087586785687E-5:32768>>) -> tensor<13x21x3xf32>
+ %2 = "tfl.dequantize"(%0) : (tensor<13x21x3x!quant.uniform<u16:f32, 6.1036087586785687E-5:32768>>) -> tensor<13x21x3xf32>
+ return %2 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_fakequant_add
+// CHECK: tosa.rescale
+// CHECK: tosa.rescale
+// CHECK: tosa.add
+// CHECK: tosa.rescale
+func @test_fakequant_add(%arg0: tensor<13x21x3x!quant.uniform<i8:f32, 0.01564602367579937:-1>>, %arg1: tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.028171317651867867:-1>> {
+ %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<13x21x3x!quant.uniform<i8:f32, 0.01564602367579937:-1>>, tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.028171317651867867:-1>>
+ return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.028171317651867867:-1>>
+}
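+
+// Note (informal): tosa.add itself performs no scaling, so both i8 operands
+// are first rescaled to a shared higher-precision scale, added, and the sum
+// rescaled to the output scale, giving the rescale/rescale/add/rescale
+// pattern.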
+
+// -----
+
+// CHECK-LABEL: test_fakequant_sub
+// CHECK: tosa.rescale
+// CHECK: tosa.rescale
+// CHECK: tosa.sub
+// CHECK: tosa.rescale
+func @test_fakequant_sub(%arg0: tensor<13x21x3x!quant.uniform<i8:f32, 0.015683440491557121:-1>>, %arg1: tensor<13x21x3x!quant.uniform<i8:f32, 0.015669029206037521>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.028217222541570663:-1>> {
+ %0 = "tfl.sub"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<13x21x3x!quant.uniform<i8:f32, 0.015683440491557121:-1>>, tensor<13x21x3x!quant.uniform<i8:f32, 0.015669029206037521>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.028217222541570663:-1>>
+ return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.028217222541570663:-1>>
+}
+
+// -----
+
+// CHECK-LABEL: test_fakequant_mul
+// CHECK: tosa.rescale
+// CHECK: tosa.rescale
+// CHECK: tosa.mul
+// CHECK: tosa.rescale
+func @test_fakequant_mul(%arg0: tensor<13x21x3x!quant.uniform<i8:f32, 0.015681236982345581>>, %arg1: tensor<13x21x3x!quant.uniform<i8:f32, 0.015647144988179207:-1>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.0078376950696110725>> {
+ %0 = "tfl.mul"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<13x21x3x!quant.uniform<i8:f32, 0.015681236982345581>>, tensor<13x21x3x!quant.uniform<i8:f32, 0.015647144988179207:-1>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.0078376950696110725>>
+ return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.0078376950696110725>>
+}
+
+// -----
+
+// CHECK-LABEL: test_fakequant_avg_pool2d
+// CHECK: tosa.avg_pool2d
+func @test_fakequant_avg_pool2d(%arg0: tensor<1x32x32x8x!quant.uniform<i8:f32, 0.015684349462389946:-1>>) -> tensor<1x32x32x8x!quant.uniform<i8:f32, 0.015684349462389946:-1>> {
+ %0 = "tfl.average_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x32x32x8x!quant.uniform<i8:f32, 0.015684349462389946:-1>>) -> tensor<1x32x32x8x!quant.uniform<i8:f32, 0.015684349462389946:-1>>
+ return %0 : tensor<1x32x32x8x!quant.uniform<i8:f32, 0.015684349462389946:-1>>
+}
+
+// -----
+
+// CHECK-LABEL: test_fakequant_max_pool2d
+// CHECK: tosa.max_pool2d
+func @test_fakequant_max_pool2d(%arg0: tensor<1x32x32x8x!quant.uniform<i8:f32, 0.01568342000246048:-1>>) -> tensor<1x32x32x8x!quant.uniform<i8:f32, 0.01568342000246048:-1>> {
+ %0 = "tfl.max_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x32x32x8x!quant.uniform<i8:f32, 0.01568342000246048:-1>>) -> tensor<1x32x32x8x!quant.uniform<i8:f32, 0.01568342000246048:-1>>
+ return %0 : tensor<1x32x32x8x!quant.uniform<i8:f32, 0.01568342000246048:-1>>
+}
+
+// -----
+
+// TODO: Add additional checks on the quantized softmax lowering, as it is
+// one of the most complicated lowerings overall.
+
+// CHECK-LABEL: test_fakequant_softmax
+// CHECK-DAG: "tosa.const"() {value = dense<"{{.*}}"> : tensor<513xi16>} : () -> tensor<513x!quant.uniform<i16:f32, 1.000000e+00>>
+// CHECK-DAG: "tosa.const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG: "tosa.const"() {value = dense<34> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG: "tosa.const"() {value = dense<-2147483648> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG: "tosa.const"() {value = dense<16> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG: "tosa.const"() {value = dense<"{{.*}}"> : tensor<513xi16>} : () -> tensor<513x!quant.uniform<i16:f32, 1.000000e+00>>
+// CHECK: tosa.rescale
+// CHECK: tosa.reduce_max
+// CHECK: tosa.sub
+// CHECK: tosa.rescale
+// CHECK: tosa.table
+// CHECK: tosa.reshape
+// CHECK: tosa.arithmetic_right_shift
+// CHECK: tosa.reduce_sum
+// CHECK: tosa.clz
+// CHECK: tosa.reshape
+// CHECK: tosa.sub
+// CHECK: tosa.logical_left_shift
+// CHECK: tosa.reshape
+// CHECK: tosa.sub
+// CHECK: tosa.reshape
+// CHECK: tosa.arithmetic_right_shift
+// CHECK: tosa.cast
+// CHECK: tosa.table
+// CHECK: tosa.rescale
+// CHECK: tosa.rescale
+// CHECK: tosa.mul
+// CHECK: tosa.arithmetic_right_shift
+// CHECK: tosa.rescale
+func @test_fakequant_softmax(%arg0: tensor<13x21x3x!quant.uniform<i8:f32, 0.0156164625659585>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 3.906250e-03:-128>> {
+ %0 = "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<13x21x3x!quant.uniform<i8:f32, 0.0156164625659585>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 3.906250e-03:-128>>
+ return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 3.906250e-03:-128>>
+}
+
+// -----
+
+// CHECK-LABEL: test_fakequant_sigmoid
+// CHECK: tosa.const
+// CHECK: tosa.rescale
+// CHECK: tosa.table
+// CHECK: tosa.rescale
+func @test_fakequant_sigmoid(%arg0: tensor<13x21x3x!quant.uniform<i8:f32, 0.015667613595724106>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 3.906250e-03:-128>> {
+ %0 = "tfl.logistic"(%arg0) : (tensor<13x21x3x!quant.uniform<i8:f32, 0.015667613595724106>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 3.906250e-03:-128>>
+ return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 3.906250e-03:-128>>
+}
+
+// -----
+
+// CHECK-LABEL: test_fakequant_tanh
+// CHECK: tosa.const
+// CHECK: tosa.rescale
+// CHECK: tosa.table
+// CHECK: tosa.rescale
+func @test_fakequant_tanh(%arg0: tensor<13x21x3x!quant.uniform<i8:f32, 0.015673128888010979:-1>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 7.812500e-03>> {
+ %0 = "tfl.tanh"(%arg0) : (tensor<13x21x3x!quant.uniform<i8:f32, 0.015673128888010979:-1>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 7.812500e-03>>
+ return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 7.812500e-03>>
+}
+
+// -----
+
+// CHECK-LABEL: test_fakequant_relu
+// CHECK: tosa.rescale
+// CHECK: tosa.reluN
+// CHECK: tosa.rescale
+func @test_fakequant_relu(%arg0: tensor<13x21x3x!quant.uniform<i8:f32, 0.015671534463763237:-1>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015671534463763237:-1>> {
+ %0 = "tfl.relu"(%arg0) : (tensor<13x21x3x!quant.uniform<i8:f32, 0.015671534463763237:-1>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015671534463763237:-1>>
+ return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015671534463763237:-1>>
+}
+
+// -----
+
+// CHECK-LABEL: test_fakequant_relu6
+// CHECK: tosa.rescale
+// CHECK: tosa.reluN
+// CHECK: tosa.rescale
+func @test_fakequant_relu6(%arg0: tensor<13x21x3x!quant.uniform<i8:f32, 0.015639215707778931>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015639215707778931>> {
+ %0 = "tfl.relu6"(%arg0) : (tensor<13x21x3x!quant.uniform<i8:f32, 0.015639215707778931>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015639215707778931>>
+ return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015639215707778931>>
+}
+
+// -----
+
+// CHECK-LABEL: test_fakequant_leaky_relu
+func @test_fakequant_leaky_relu(%arg0: tensor<13x21x3x!quant.uniform<i8:f32, 0.015563514083623886:-1>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015563514083623886:-1>> {
+ %0 = "tfl.leaky_relu"(%arg0) {alpha = 0.368738383 : f32} : (tensor<13x21x3x!quant.uniform<i8:f32, 0.015563514083623886:-1>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015563514083623886:-1>>
+ return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015563514083623886:-1>>
+}
diff --git a/tensorflow/compiler/mlir/tosa/tf_tosa_pipeline.cc b/tensorflow/compiler/mlir/tosa/tf_tosa_pipeline.cc
new file mode 100644
index 0000000..e8d1aa7
--- /dev/null
+++ b/tensorflow/compiler/mlir/tosa/tf_tosa_pipeline.cc
@@ -0,0 +1,29 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/tosa/tosa_passpipes.h"
+
+namespace mlir {
+
+namespace tosa {
+
+static mlir::PassPipelineRegistration<TOSALegalizationPipelineOptions>
+ tf_tosa_pipeline("tf-to-tosa-pipeline",
+ "TensorFlow to TOSA legalization pipeline",
+ createTFtoTOSALegalizationPipeline);
+
+} // namespace tosa
+
+} // namespace mlir
diff --git a/tensorflow/compiler/mlir/tosa/tfl_tosa_pipeline.cc b/tensorflow/compiler/mlir/tosa/tfl_tosa_pipeline.cc
new file mode 100644
index 0000000..8552a68
--- /dev/null
+++ b/tensorflow/compiler/mlir/tosa/tfl_tosa_pipeline.cc
@@ -0,0 +1,29 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/tosa/tosa_passpipes.h"
+
+namespace mlir {
+
+namespace tosa {
+
+static mlir::PassPipelineRegistration<TOSALegalizationPipelineOptions>
+ tfl_tosa_pipeline("tfl-to-tosa-pipeline",
+ "TensorFlow Lite to TOSA legalization pipeline",
+ createTFLtoTOSALegalizationPipeline);
+
+} // namespace tosa
+
+} // namespace mlir
diff --git a/tensorflow/compiler/mlir/tosa/tosa_passpipes.cc b/tensorflow/compiler/mlir/tosa/tosa_passpipes.cc
new file mode 100644
index 0000000..d18fd55
--- /dev/null
+++ b/tensorflow/compiler/mlir/tosa/tosa_passpipes.cc
@@ -0,0 +1,75 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/tosa/tosa_passpipes.h"
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Module.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+
+namespace mlir {
+
+namespace tosa {
+
+void addPreOptMlirPasses(mlir::OpPassManager& pm) {
+ // Inline all functions into main and then delete the functions themselves.
+ pm.addPass(mlir::createInlinerPass());
+
+ // Now that there is only one function, run some MLIR passes on it.
+ pm.addPass(mlir::createCanonicalizerPass());
+ pm.addPass(mlir::createCSEPass());
+
+ pm.addPass(mlir::createLoopFusionPass());
+ pm.addPass(mlir::createMemRefDataFlowOptPass());
+}
+
+void addPostOptMlirPasses(mlir::OpPassManager& pm) {
+ pm.addPass(mlir::tosa::createTosaMakeBroadcastablePass());
+ // Inline the call/return basic blocks within TOSA control flow ops.
+ pm.addPass(mlir::createInlinerPass());
+ // Clean up with DCE.
+ pm.addPass(mlir::createSymbolDCEPass());
+}
+
+void createTFtoTOSALegalizationPipeline(
+ OpPassManager& pm, const TOSALegalizationPipelineOptions& opts) {
+ addPreOptMlirPasses(pm);
+
+ pm.addPass(mlir::tosa::createFuseBiasTFPass());
+ pm.addPass(mlir::tosa::createLegalizeTFPass());
+
+ addPostOptMlirPasses(pm);
+}
+
+void createTFLtoTOSALegalizationPipeline(
+ OpPassManager& pm, const TOSALegalizationPipelineOptions& opts) {
+ addPreOptMlirPasses(pm);
+
+ pm.addPass(mlir::tosa::createConvertTFLUint8Pass());
+ pm.addPass(mlir::tosa::createLegalizeTFLPass());
+
+ addPostOptMlirPasses(pm);
+}
+
+} // namespace tosa
+
+} // namespace mlir
diff --git a/tensorflow/compiler/mlir/tosa/tosa_passpipes.h b/tensorflow/compiler/mlir/tosa/tosa_passpipes.h
new file mode 100644
index 0000000..bd77ecf4
--- /dev/null
+++ b/tensorflow/compiler/mlir/tosa/tosa_passpipes.h
@@ -0,0 +1,43 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_TOSA_TOSA_PASSES_H
+#define TENSORFLOW_COMPILER_MLIR_TOSA_TOSA_PASSES_H
+
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/IR/Module.h"
+#include "mlir/Pass/PassManager.h"
+#include "llvm/ADT/Optional.h"
+#include "tensorflow/compiler/mlir/tosa/transforms/passes.h"
+
+namespace mlir {
+
+namespace tosa {
+
+void addPreOptMlirPasses(mlir::OpPassManager& pm);
+
+void addPostOptMlirPasses(mlir::OpPassManager& pm);
+
+void createTFtoTOSALegalizationPipeline(
+ OpPassManager& pm, const TOSALegalizationPipelineOptions& opts);
+
+void createTFLtoTOSALegalizationPipeline(
+ OpPassManager& pm, const TOSALegalizationPipelineOptions& opts);
+
+} // namespace tosa
+
+} // namespace mlir
+
+#endif // TENSORFLOW_COMPILER_MLIR_TOSA_TOSA_PASSES_H
diff --git a/tensorflow/compiler/mlir/tosa/transforms/convert_tfl_uint8.cc b/tensorflow/compiler/mlir/tosa/transforms/convert_tfl_uint8.cc
new file mode 100644
index 0000000..e6767b0
--- /dev/null
+++ b/tensorflow/compiler/mlir/tosa/transforms/convert_tfl_uint8.cc
@@ -0,0 +1,369 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This pass converts a TFLite uint8 graph to the int8 domain, with adaptors at
+// input and output tensors. This is needed because TOSA precision is
+// implemented in the int8 domain. This pass does:
+// 1. match TFL::QConst with uint8, generate TFL::QConst with int8 with value
+// remapped.
+// 2. insert tosa.RESCALE uint8 -> int8 if block argument (placeholder of graph)
+// is uint8 typed.
+// 3. insert tosa.RESCALE int8 -> uint8 if original returned tensor is uint8
+// typed.
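+//
+// For example (values assumed for illustration): a uint8 tensor with zero
+// point 128 that stores the value 200 represents (200 - 128) * scale; after
+// conversion, the int8 tensor stores 200 - 128 = 72 with zero point
+// 128 - 128 = 0 and represents the same (72 - 0) * scale.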
+
+#include <climits>
+#include <cstddef>
+#include <cstdint>
+#include <iterator>
+#include <numeric>
+
+#include "mlir/Dialect/Quant/FakeQuantSupport.h"
+#include "mlir/Dialect/Quant/UniformSupport.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
+#include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
+#include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"
+#include "tensorflow/compiler/mlir/tosa/transforms/passes.h"
+
+#define PASS_NAME "tosa-convert-tfl-uint8"
+#define DEBUG_TYPE PASS_NAME
+
+namespace mlir {
+
+namespace tosa {
+
+namespace {
+// Converts a uint8-quantized TFLite graph to int8, inserting rescale
+// adaptors at the graph inputs and outputs.
+class ConvertUint8ToInt8
+ : public PassWrapper<ConvertUint8ToInt8, FunctionPass> {
+ public:
+ explicit ConvertUint8ToInt8() {}
+ void runOnFunction() override;
+};
+
+struct ConvertUint8QConstOp : public RewritePattern {
+ explicit ConvertUint8QConstOp(MLIRContext *context)
+ : RewritePattern(TFL::QConstOp::getOperationName(), 1, context) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &builder) const override {
+ auto tfl_qconst_op = cast<TFL::QConstOp>(op);
+
+ // Skip if it's not ranked tensor type.
+ auto output_type =
+ tfl_qconst_op.getResult().getType().dyn_cast<mlir::RankedTensorType>();
+ if (!output_type)
+ return builder.notifyMatchFailure(op, "not ranked tensor");
+
+ // Skip if output is not per-tensor quantized type.
+ auto output_element_type =
+ output_type.getElementType()
+ .dyn_cast<mlir::quant::UniformQuantizedType>();
+ if (!output_element_type) return failure();
+
+ // Skip if output is not uint8.
+ if (output_element_type.isSigned() ||
+ output_element_type.getStorageTypeIntegralWidth() != 8) {
+ return failure();
+ }
+
+ mlir::DenseElementsAttr src_dense_attr =
+ tfl_qconst_op.value().cast<DenseElementsAttr>();
+
+ double type_range_min =
+ static_cast<double>(output_element_type.getStorageTypeMin() -
+ output_element_type.getZeroPoint()) *
+ output_element_type.getScale();
+ double type_range_max =
+ static_cast<double>(output_element_type.getStorageTypeMax() -
+ output_element_type.getZeroPoint()) *
+ output_element_type.getScale();
+ bool narrow_range = output_element_type.getStorageTypeMin() == 1;
+
+ auto dst_qconst_type = TypeAttr::get(RankedTensorType::get(
+ output_type.getShape(),
+ buildQTypeFromMinMax(
+ builder, output_element_type.getExpressedType(),
+ builder.getF64FloatAttr(type_range_min),
+ builder.getF64FloatAttr(type_range_max),
+ builder.getI32IntegerAttr(
+ output_element_type.getStorageTypeIntegralWidth()),
+ 0, true /* signed */, builder.getBoolAttr(narrow_range))));
+
+ Type dst_dense_element_type = builder.getIntegerType(8);
+ llvm::function_ref<APInt(const APInt &)> mapping =
+ [](const APInt &in) -> APInt {
+ int64_t in_i64 = in.getLimitedValue();
+ int64_t out_i64 = in_i64 - 128;
+ return APInt(8, out_i64, true);
+ };
+
+ auto dst_dense_attr =
+ src_dense_attr.mapValues(dst_dense_element_type, mapping);
+
+ builder.replaceOpWithNewOp<TFL::QConstOp>(op, dst_qconst_type,
+ dst_dense_attr);
+
+ return success();
+ }
+};
+
+LogicalResult convert_graph_uint8_tensor(mlir::MLIRContext &context,
+ mlir::FuncOp &function) {
+ size_t num_blocks_in_main = 0;
+ mlir::Region *region = function.getCallableRegion();
+ OpBuilder builder(&context);
+
+ auto tmp_const_type = RankedTensorType::get({1}, builder.getIntegerType(8));
+ auto tmp_const_attr =
+ DenseElementsAttr::get(tmp_const_type, {static_cast<uint8_t>(0)});
+
+ for (mlir::Block &bb : region->getBlocks()) {
+ // Always have one block for each region right now.
+ num_blocks_in_main++;
+ if (num_blocks_in_main > 1) {
+ return function.emitError("Invalid MLIR: multiple blocks in a region");
+ }
+
+ if (!bb.isEntryBlock()) {
+ return function.emitError("Invalid MLIR: block must be entry block");
+ }
+
+ // Insert rescale uint8->int8 after placeholders.
+ for (Value arg : bb.getArguments()) {
+ auto uint8_type = arg.getType().dyn_cast<mlir::RankedTensorType>();
+ if (!uint8_type) continue;
+
+ auto uint8_element_type =
+ uint8_type.getElementType()
+ .dyn_cast<mlir::quant::UniformQuantizedType>();
+ if (!uint8_element_type) continue;
+
+ if (uint8_element_type.isSigned() ||
+ uint8_element_type.getStorageTypeIntegralWidth() != 8)
+ continue;
+
+ double type_range_min =
+ static_cast<double>(uint8_element_type.getStorageTypeMin() -
+ uint8_element_type.getZeroPoint()) *
+ uint8_element_type.getScale();
+ double type_range_max =
+ static_cast<double>(uint8_element_type.getStorageTypeMax() -
+ uint8_element_type.getZeroPoint()) *
+ uint8_element_type.getScale();
+ bool narrow_range = uint8_element_type.getStorageTypeMin() == 1;
+
+ Type int8_type = RankedTensorType::get(
+ uint8_type.getShape(),
+ buildQTypeFromMinMax(
+ builder, uint8_element_type.getExpressedType(),
+ builder.getF64FloatAttr(type_range_min),
+ builder.getF64FloatAttr(type_range_max),
+ builder.getI32IntegerAttr(
+ uint8_element_type.getStorageTypeIntegralWidth()),
+ 0, true /* signed */, builder.getBoolAttr(narrow_range)));
+
+ int32_t uint8_zp = uint8_element_type.getZeroPoint();
+ int32_t int8_zp = uint8_zp - 128;
+
+ // Route existing uses of arg through a temporary placeholder so the
+ // rescale op created below can take arg as its input.
+ Value tmp_val = builder.create<TFL::ConstOp>(
+ function.getLoc(), tmp_const_type, tmp_const_attr);
+ arg.replaceAllUsesWith(tmp_val);
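+ // The multiplier/shift pair {1 << 30} / {30} below encodes a scale of
+ // exactly 2^30 / 2^30 = 1.0, so this rescale only shifts the zero point.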
+ auto rescale_op = builder.create<tosa::RescaleOp>(
+ function.getLoc(), int8_type, arg,
+ builder.getI32IntegerAttr(uint8_zp),
+ builder.getI32IntegerAttr(int8_zp),
+ builder.getI32ArrayAttr({1 << 30}), builder.getI32ArrayAttr({30}),
+ builder.getBoolAttr(true), builder.getBoolAttr(false),
+ builder.getBoolAttr(false));
+
+ Operation *op_rescale_op = static_cast<Operation *>(rescale_op);
+ bb.push_front(op_rescale_op);
+ tmp_val.replaceAllUsesWith(rescale_op.getResult());
+ tmp_val.getDefiningOp()->erase();
+ }
+
+ // Record the types of the original graph outputs before converting
+ // intermediate tensors.
+ auto terminator = bb.getTerminator();
+ SmallVector<Type, 4> output_types;
+ for (Value val : terminator->getOperands()) {
+ output_types.push_back(val.getType());
+ }
+
+ // Convert intermediate tensor.
+ for (auto &op : bb) {
+ for (Value output_val : op.getResults()) {
+ // Skip if output value is not RankedTensorType.
+ auto output_type =
+ output_val.getType().dyn_cast<mlir::RankedTensorType>();
+ if (!output_type) continue;
+
+ // Skip if output value is not per-tensor quantized element type.
+ auto output_element_type =
+ output_type.getElementType()
+ .dyn_cast<mlir::quant::UniformQuantizedType>();
+ if (!output_element_type) continue;
+
+ // Skip if output is not uint8.
+ if (output_element_type.isSigned() ||
+ output_element_type.getStorageTypeIntegralWidth() != 8)
+ continue;
+
+ double type_range_min =
+ static_cast<double>(output_element_type.getStorageTypeMin() -
+ output_element_type.getZeroPoint()) *
+ output_element_type.getScale();
+ double type_range_max =
+ static_cast<double>(output_element_type.getStorageTypeMax() -
+ output_element_type.getZeroPoint()) *
+ output_element_type.getScale();
+ bool narrow_range = output_element_type.getStorageTypeMin() == 1;
+
+ Type new_type = RankedTensorType::get(
+ output_type.getShape(),
+ buildQTypeFromMinMax(
+ builder, output_element_type.getExpressedType(),
+ builder.getF64FloatAttr(type_range_min),
+ builder.getF64FloatAttr(type_range_max),
+ builder.getI32IntegerAttr(
+ output_element_type.getStorageTypeIntegralWidth()),
+ 0, true /* signed */, builder.getBoolAttr(narrow_range)));
+
+ output_val.setType(new_type);
+ }
+ }
+
+ if (terminator->getNumOperands() != output_types.size()) {
+ return function.emitError(
+ "Terminator's operand mismatch with number of outputs in graph");
+ }
+
+ // Insert an int8->uint8 rescale before each operand of the terminator.
+ for (int32_t i = 0; i < terminator->getNumOperands(); i++) {
+ auto defining_op = terminator->getOperand(i).getDefiningOp();
+ // Skip if the terminator operand is a block argument (its defining op is
+ // nullptr in that case).
+ if (!defining_op) continue;
+ Value input_val = defining_op->getResult(0);
+
+ // Check if graph output is uint8 type.
+ auto uint8_output_type =
+ output_types[i].dyn_cast<mlir::RankedTensorType>();
+ if (!uint8_output_type) continue;
+
+ auto uint8_output_element_type =
+ uint8_output_type.getElementType()
+ .dyn_cast<mlir::quant::UniformQuantizedType>();
+ if (!uint8_output_element_type) continue;
+
+ if (uint8_output_element_type.isSigned() ||
+ uint8_output_element_type.getStorageTypeIntegralWidth() != 8)
+ continue;
+
+ // Check if output coming into terminator is int8 type.
+ auto int8_output_type = terminator->getOperand(i)
+ .getType()
+ .dyn_cast<mlir::RankedTensorType>();
+ if (!int8_output_type) continue;
+
+ auto int8_output_element_type =
+ int8_output_type.getElementType()
+ .dyn_cast<mlir::quant::UniformQuantizedType>();
+ if (!int8_output_element_type) continue;
+
+ if (!int8_output_element_type.isSigned() ||
+ int8_output_element_type.getStorageTypeIntegralWidth() != 8)
+ continue;
+
+ int32_t int8_zp = int8_output_element_type.getZeroPoint();
+ int32_t uint8_zp = uint8_output_element_type.getZeroPoint();
+
+ // Sanity check: the uint8/int8 scales must match and the zero points
+ // must differ by exactly 128.
+ if (((uint8_zp - int8_zp) != 128) ||
+ (int8_output_element_type.getScale() !=
+ uint8_output_element_type.getScale())) {
+ return terminator->emitError(
+ "convert_uint8_to_int8: scale mismatch at the output tensors");
+ }
+
+ // Route existing uses of input_val through a temporary placeholder so the
+ // rescale op created below can take input_val as its input.
+ Value tmp_val = builder.create<TFL::ConstOp>(
+ function.getLoc(), tmp_const_type, tmp_const_attr);
+ input_val.replaceAllUsesWith(tmp_val);
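+ // As above, the {1 << 30} / {30} multiplier/shift pair encodes a scale of
+ // exactly 1.0; only the zero point changes.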
+ auto rescale_op = builder.create<tosa::RescaleOp>(
+ function.getLoc(), uint8_output_type, input_val,
+ builder.getI32IntegerAttr(int8_zp),
+ builder.getI32IntegerAttr(uint8_zp),
+ builder.getI32ArrayAttr({1 << 30}), builder.getI32ArrayAttr({30}),
+ builder.getBoolAttr(true), builder.getBoolAttr(false),
+ builder.getBoolAttr(false));
+
+ Operation *op_rescale_op = static_cast<Operation *>(rescale_op);
+ bb.push_back(op_rescale_op);
+ op_rescale_op->moveBefore(terminator);
+ tmp_val.replaceAllUsesWith(rescale_op.getResult());
+ tmp_val.getDefiningOp()->erase();
+ }
+ }
+
+ return success();
+}
+
+void ConvertUint8ToInt8::runOnFunction() {
+ OwningRewritePatternList patterns;
+ auto &ctx = getContext();
+ auto func = getFunction();
+
+ // Convert uint8 const tensors; constants need to be handled separately.
+ patterns.insert<ConvertUint8QConstOp>(&ctx);
+ applyPatternsAndFoldGreedily(func, std::move(patterns));
+
+ // Replace uint8 tensor in the graph and insert rescale as needed.
+ convert_graph_uint8_tensor(ctx, func);
+}
+
+} // anonymous namespace
+
+std::unique_ptr<OperationPass<FuncOp>> createConvertTFLUint8Pass() {
+ return std::make_unique<ConvertUint8ToInt8>();
+}
+
+static PassRegistration<ConvertUint8ToInt8> pass(
+ PASS_NAME, "Convert uint8 graph to int8.");
+
+} // namespace tosa
+
+} // namespace mlir
diff --git a/tensorflow/compiler/mlir/tosa/transforms/fuse_bias_tf.cc b/tensorflow/compiler/mlir/tosa/transforms/fuse_bias_tf.cc
new file mode 100644
index 0000000..3bcf1ac
--- /dev/null
+++ b/tensorflow/compiler/mlir/tosa/transforms/fuse_bias_tf.cc
@@ -0,0 +1,151 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Fuse tf.Op + tf.BiasAdd and legalize to TOSA.
+
+#include <climits>
+#include <cstddef>
+#include <cstdint>
+#include <iterator>
+#include <numeric>
+
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
+
+#define PASS_NAME "tosa-fuse-bias-tf"
+#define DEBUG_TYPE PASS_NAME
+
+// TODO: remove this macro once common function return types are replaced
+// with llvm::Optional<>.
+// Helper macro for checking the return value of a common legalization
+// function that returns a single tensor; packs the result in a list.
+#define TOSA_REPLACE_LOWERED_OP(REWRITER, OP, LOWERED_OP) \
+ if (LOWERED_OP) { \
+ REWRITER.replaceOp((OP), {(LOWERED_OP)->getResults()}); \
+ return success(); \
+ } else { \
+ return failure(); \
+ }
+
+namespace mlir {
+
+namespace tosa {
+
+namespace {
+
+class FuseBiasTF : public PassWrapper<FuseBiasTF, FunctionPass> {
+ public:
+ explicit FuseBiasTF() {}
+ void runOnFunction() override;
+};
+
+struct ConvertTFBiasAddOp : public RewritePattern {
+ explicit ConvertTFBiasAddOp(MLIRContext* context)
+ : RewritePattern(TF::BiasAddOp::getOperationName(), 1, context) {}
+ LogicalResult matchAndRewrite(Operation* op,
+ PatternRewriter& rewriter) const override;
+};
+
+// Replaces the following pattern:
+// %1 = tf.Conv2D (%ifm, %filter)
+// %2 = tf.BiasAdd(%1, %bias)
+// with
+// %1 = tosa.conv2d(%ifm, %filter, %bias)
+// This could also be done with a pair of Pat<> patterns in
+// tf_optimize_patterns.td. However, this explicit code can handle the cases
+// where either the LHS or the RHS is the defining conv2d op.
+// TODO: support other patterns, e.g. tf.DepthwiseConv2dNative.
+
+LogicalResult ConvertTFBiasAddOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tf_biasadd_op = cast<TF::BiasAddOp>(op);
+
+ auto output_type =
+ tf_biasadd_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ auto value = tf_biasadd_op.value();
+ auto bias = tf_biasadd_op.bias();
+
+ TF::Conv2DOp tf_conv2d_op =
+ dyn_cast_or_null<TF::Conv2DOp>(value.getDefiningOp());
+
+ if (!tf_conv2d_op) {
+ return failure();
+ }
+
+ // Sanity check to confirm bias() has the expected shape of a bias tensor.
+ auto filter_shape =
+ tf_conv2d_op.filter().getType().dyn_cast<RankedTensorType>().getShape();
+
+ auto bias_shape = bias.getType().dyn_cast<RankedTensorType>().getShape();
+
+ // Bias dimension must match filter output channels, where tf.conv2d's filter
+ // is [H, W, I, O]
+ if (filter_shape.back() != bias_shape.back()) return failure();
+
+ // Bias tensor that feeds into tosa.conv2d must be rank 1
+ if (bias_shape.size() != 1) return failure();
+
+ auto lowered_op = convertTFConv2DCommon(
+ rewriter, op, output_type, tf_conv2d_op.input(), tf_conv2d_op.filter(),
+ bias, tf_conv2d_op.strides(), tf_conv2d_op.dilations(),
+ tf_conv2d_op.explicit_paddings(), tf_conv2d_op.padding(),
+ tf_conv2d_op.data_format());
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+void FuseBiasTF::runOnFunction() {
+ OwningRewritePatternList patterns;
+ auto* ctx = &getContext();
+ auto func = getFunction();
+
+ // Add the generated patterns to the list.
+ patterns.insert<ConvertTFBiasAddOp>(ctx);
+ applyPatternsAndFoldGreedily(func, std::move(patterns));
+}
+
+} // anonymous namespace
+
+std::unique_ptr<OperationPass<FuncOp>> createFuseBiasTFPass() {
+ return std::make_unique<FuseBiasTF>();
+}
+
+static PassRegistration<FuseBiasTF> pass(
+ PASS_NAME, "Fuse tf.Op + tf.BiasAdd and legalized to TOSA.");
+
+} // namespace tosa
+
+} // namespace mlir
diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc
new file mode 100644
index 0000000..fc041dd
--- /dev/null
+++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc
@@ -0,0 +1,2773 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file contains legalizations common to mapping both TensorFlow and
+// TensorFlow Lite to TOSA.
+//
+// Conversion functions return nullptr on a lowering failure, or the lowered
+// operator on success. Callers must check for nullptr and return a
+// LogicalResult failure in that case. Helper macros are provided in
+// legalize_common.h to canonicalize this handling.
+
+#include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
+
+#include <climits>
+#include <cstddef>
+#include <cstdint>
+#include <iterator>
+#include <numeric>
+
+#include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"
+
+// TODO for further work:
+// * It would be better to return an llvm::Optional instead of an Operation*;
+// that enables more generic handling of the cases where callers do
+// different things with the resulting ops.
+
+namespace mlir {
+namespace tosa {
+
+// Lowers the Pack operator to TOSA.
+Operation* convertPackOp(PatternRewriter& rewriter, Operation* op,
+ Value result_value, SmallVector<Value, 8>& inputs,
+ int32_t axis) {
+ //////////////////////////////////////////////////
+ // Operator: output = Pack([values], axis) or output = Stack([values], axis)
+ // Lowering:
+ //
+ // This operator is lowered into a series of pairwise tosa.concat()
+ // operators and a reshape
+ // Depending on the inputs, a tranpose operator is also generated:
+ //
+ // Step 1: concatenate the tensors
+ // a1_concat = tosa.concat(input[0], input[1], axis)
+ // for (i = 2; i < len(input); i++)
+ // a1_concat = tosa.concat(a1_concat, input[i], axis)
+ //
+ // Step 2: reshape to N+1 dimensions
+ // a2_reshape = tosa.reshape(a1_concat, new_rank)
+ //
+ // Step 3: Transpose if a new dimension is being added:
+ //   if (axis == rank(values[0])):
+ // // perm will be [1, 2, 3, 0]
+ // a3_transpose = tosa.transpose(a2_reshape, perm)
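+ //
+ // As a worked example (shapes assumed): packing three [4, 5] tensors along
+ // axis = 1 concatenates to [4, 15] and reshapes to [4, 3, 5]. With
+ // axis = 2 (== rank), it concatenates along axis 0 to [12, 5], reshapes to
+ // [3, 4, 5], and transposes with perm [1, 2, 0] to [4, 5, 3].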
+
+ // Sanity check 1: make sure all input tensors have the same shape
+ // if input[0] has shape [A, B, C], input[1] to input[N-1] should also have
+ // shape[A, B, C]
+ auto result_type = result_value.getType().dyn_cast<RankedTensorType>();
+
+ // Check for ranked tensor type.
+ if (!result_type) {
+ op->emitOpError("PackOp: result type not ranked tensor");
+ return nullptr;
+ }
+
+ // Valid axis in TF is [-rank(input), rank(input)).
+ // Valid axis in TOSA is [0, rank(input)).
+ // Add rank(input) once to a negative axis to map it into the TOSA range.
+ auto input_type = op->getOperand(0).getType().dyn_cast<RankedTensorType>();
+ if (!input_type) {
+ op->emitOpError("PackOp: input type not ranked tensor");
+ return nullptr;
+ }
+
+ auto input_rank = input_type.getShape().size();
+ if (axis < 0) axis += input_rank;
+
+ input_type = inputs[0].getType().dyn_cast<RankedTensorType>();
+ if (!input_type) {
+ op->emitOpError("Input 0 type not ranked tensor.");
+ return nullptr;
+ }
+ ArrayRef<int64_t> input0_tensor_shape = input_type.getShape();
+ int input_tensor_rank = input0_tensor_shape.size();
+
+ for (int i = 1; i < inputs.size(); i++) {
+ input_type = inputs[i].getType().dyn_cast<RankedTensorType>();
+ if (!input_type) {
+ op->emitOpError(
+ llvm::formatv("PackOp: input {} type not ranked tensor", i));
+ return nullptr;
+ }
+ ArrayRef<int64_t> next_tensor_shape = input_type.getShape();
+ if (next_tensor_shape.size() != input_tensor_rank) {
+ op->emitOpError("PackOp: input tensor rank mismatch.");
+ return nullptr;
+ }
+ for (int d = 0; d < input0_tensor_shape.size(); d++) {
+ if (input0_tensor_shape[d] != next_tensor_shape[d]) {
+ op->emitOpError("PackOp: input tensor shape mismatch.");
+ return nullptr;
+ }
+ }
+ }
+
+ // If input tensors are rank 0, reshape them to rank 1, size 1 before
+ // performing the concat.
+ if (input_tensor_rank == 0) {
+ SmallVector<int64_t, 8> reshape_rank1_size1_shape{1};
+ auto reshape_rank1_size1_type =
+ RankedTensorType::get(ArrayRef<int64_t>(reshape_rank1_size1_shape),
+ result_type.getElementType());
+ ArrayAttr shape_rank1_size1_attr =
+ rewriter.getI64ArrayAttr(reshape_rank1_size1_shape);
+ for (int i = 0; i < inputs.size(); i++) {
+ auto a0_reshape_op = rewriter.create<tosa::ReshapeOp>(
+ op->getLoc(), reshape_rank1_size1_type, inputs[i],
+ shape_rank1_size1_attr);
+ inputs[i] = a0_reshape_op.getResult();
+ }
+ }
+
+ // Sanity check 2: axis can be from [0, rank(input)+1],
+ // where rank(input)+1 means creating a new dimension.
+ // Negative values are allowed up to -(rank(input)+1), in which case the
+ // axis "wraps around"; the axis was already normalized above.
+
+ if (axis > (input_tensor_rank + 1)) {
+ op->emitOpError("PackOp: axis out of valid range.");
+ return nullptr;
+ }
+
+ // Sanity check 3: if input shape is [A, B, C], output shape should be
+ // [N, A, B, C].
+ // 3.a: check output rank is rank(input) + 1.
+ SmallVector<int64_t, 8> output_shape_vals(result_type.getShape().begin(),
+ result_type.getShape().end());
+ if (output_shape_vals.size() != (input_tensor_rank + 1)) {
+ op->emitOpError("PackOp: output tensor rank mismatch.");
+ return nullptr;
+ }
+ // 3.b: check that the output dimension on the pack axis is N.
+ if (output_shape_vals[axis] != inputs.size()) {
+ op->emitOpError("PackOp: output tensor shape mismatch.");
+ return nullptr;
+ }
+ // In most cases PackOp.axis() is within [0, rank(input) - 1], and we can
+ // directly concatenate along that axis and then perform the reshape.
+ // For example, stack N [A, B, C] input tensor ranks along axis = 1
+ // after concatenation, output will be [A, N * B, C]
+ // and then reshape it into [A, N, B, C]
+ // a special case would be PackOp.axis() equal to rank(input), in which case
+ // we can't directly concatenate along the PackOp.axis(), instead
+ // we concat along axis=0, and reshape into [N, A, B, C]
+ // and then we need an extra transpose to [A, B, C, N].
+ int64_t concat_axis;
+ SmallVector<int32_t, 8> perm;
+ SmallVector<int64_t, 8> reshape_output_shape;
+ if (axis == 0 && input_tensor_rank == 0) {
+ concat_axis = 0;
+ // No reshape or perm needed, since the inputs were already reshaped to
+ // rank 1, size 1. Output will be rank 1, size N.
+ } else if (axis == input_tensor_rank) {
+ concat_axis = 0;
+
+ // A special case when stack axis is equal to input tensor rank:
+ // Output shape is [A, B, C, N]
+ // so reshape output will be [N, A, B, C]
+ // and perm will be [1, 2, 3, 0].
+ reshape_output_shape.push_back(output_shape_vals[axis]);
+ for (int d = 0; d < input_tensor_rank; d++) {
+ perm.push_back(d + 1);
+ reshape_output_shape.push_back(output_shape_vals[d]);
+ }
+ perm.push_back(0);
+ } else {
+ // General case, doesn't need perm vector.
+ concat_axis = axis;
+ reshape_output_shape.assign(output_shape_vals.begin(),
+ output_shape_vals.end());
+ }
+ IntegerAttr concat_axis_attr = rewriter.getI64IntegerAttr(concat_axis);
+ ArrayAttr shape_attr = rewriter.getI64ArrayAttr(reshape_output_shape);
+
+ // For each concat output, shape will be different.
+ // If input shape is [A, B, C] and concat_axis = 0, 1st concat output will
+ // be [2 * A, B, C].
+ int orig_input_dim_on_axis;
+ SmallVector<int64_t, 4> concat_output_shape;
+ if (input_tensor_rank == 0) {
+ concat_output_shape.push_back(1);
+ orig_input_dim_on_axis = 1;
+ } else {
+ for (int i = 0; i < input_tensor_rank; i++) {
+ concat_output_shape.push_back(input0_tensor_shape[i]);
+ }
+ orig_input_dim_on_axis = input0_tensor_shape[concat_axis];
+ }
+
+ concat_output_shape[concat_axis] = orig_input_dim_on_axis * 2;
+ auto concat_type = RankedTensorType::get(
+ ArrayRef<int64_t>(concat_output_shape), result_type.getElementType());
+ auto a1_concat_op = rewriter.create<tosa::ConcatOp>(
+ op->getLoc(), concat_type, inputs[0], inputs[1], concat_axis_attr);
+
+ // K-th concat output will be [(k+1) * A, B, C], last output will be [N * A,
+ // B, C].
+ for (int i = 2; i < inputs.size(); i++) {
+ concat_output_shape[concat_axis] = orig_input_dim_on_axis * (i + 1);
+ concat_type = RankedTensorType::get(ArrayRef<int64_t>(concat_output_shape),
+ result_type.getElementType());
+ a1_concat_op = rewriter.create<tosa::ConcatOp>(op->getLoc(), concat_type,
+ a1_concat_op.getResult(),
+ inputs[i], concat_axis_attr);
+ }
+
+ Operation* lowered_op = nullptr;
+ // Doesn't need reshape or transpose if input tensor is rank 0, since inputs
+ // are reshaped beforehand.
+ if (input_tensor_rank == 0) {
+ lowered_op = a1_concat_op;
+ } else {
+ // Reshape [N * A, B, C] to [N, A, B, C].
+ auto reshape_output_type = RankedTensorType::get(
+ ArrayRef<int64_t>(reshape_output_shape), result_type.getElementType());
+
+ auto a2_reshape_op =
+ rewriter.create<tosa::ReshapeOp>(op->getLoc(), reshape_output_type,
+ a1_concat_op.getResult(), shape_attr);
+
+ // If axis is equal to input tensor rank, then we need extra transpose
+ // [N, A, B, C] to [A, B, C, N]
+ if (axis == input_tensor_rank) {
+ auto a3_transpose_perm =
+ get1DConstTensor<tosa::ConstOp, int32_t>(rewriter, op, perm);
+ auto a3_transpose_op = rewriter.create<tosa::TransposeOp>(
+ op->getLoc(), result_type, a2_reshape_op.getResult(),
+ a3_transpose_perm);
+ lowered_op = a3_transpose_op;
+ } else {
+ lowered_op = a2_reshape_op;
+ }
+ }
+
+ return lowered_op;
+}
+
+// Lowers the Unpack operator to TOSA
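+// For example (shapes assumed): unpacking a [2, 3, 4] tensor along axis = 1
+// transposes to [3, 2, 4], slices it into three [1, 2, 4] pieces, and
+// reshapes each piece to [2, 4].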
+Operation* convertUnpackOp(PatternRewriter& rewriter, Operation* op,
+ Value input_value, int32_t axis) {
+ auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
+ if (!input_type) return nullptr;
+
+ auto input_shape = input_type.getShape();
+ int64_t input_rank = input_shape.size();
+
+ SmallVector<Value, 4> results_vec;
+
+ // Negative axis allowed as long as it's within [-input_rank, input_rank).
+ if (axis < 0) axis += input_rank;
+
+ assert(axis >= 0 && axis < input_shape.size());
+
+ // A list of the output types for each slice op
+ SmallVector<Type, 4> outs_type_vec;
+
+ // Step 1: transpose 'axis' to leftmost dimension.
+ Value transposed_input_value;
+ if (axis != 0) {
+ SmallVector<int32_t, 8> perm_vec;
+ SmallVector<int64_t, 2> a1_transpose_shape(input_rank);
+
+ perm_vec.push_back(axis);
+ for (int i = 0; i < input_rank; i++) {
+ if (i == axis) continue;
+ perm_vec.push_back(i);
+ }
+
+ auto a1_transpose_perm =
+ get1DConstTensor<tosa::ConstOp, int32_t>(rewriter, op, perm_vec);
+
+ for (int i = 0; i < input_rank; i++) {
+ a1_transpose_shape[i] = input_shape[perm_vec[i]];
+ }
+
+ auto a1_transpose_op = rewriter.create<tosa::TransposeOp>(
+ op->getLoc(),
+ RankedTensorType::get(ArrayRef<int64_t>(a1_transpose_shape),
+ input_type.getElementType()),
+ input_value, a1_transpose_perm);
+
+ transposed_input_value = a1_transpose_op.getResult();
+ } else {
+ // Do nothing if axis is already at leftmost dimension.
+ transposed_input_value = input_value;
+ }
+
+ // Step 2: slice [N, A, B, C] into N [A, B, C].
+ auto transposed_input_type =
+ transposed_input_value.getType().dyn_cast<RankedTensorType>();
+ if (!transposed_input_type) return nullptr;
+
+ auto transposed_input_shape = transposed_input_type.getShape();
+ int64_t transposed_input_rank = transposed_input_shape.size();
+
+ for (int i = 0; i < transposed_input_shape[0]; i++) {
+ SmallVector<int64_t, 4> begin_vals, size_vals, shape_vals;
+
+ for (int j = 0; j < transposed_input_rank; j++) {
+ if (j == 0) {
+ begin_vals.push_back(i);
+ size_vals.push_back(1);
+ } else {
+ begin_vals.push_back(0);
+ size_vals.push_back(transposed_input_shape[j]);
+ shape_vals.push_back(transposed_input_shape[j]);
+ }
+ }
+
+ ArrayAttr begin = rewriter.getI64ArrayAttr(begin_vals);
+ ArrayAttr size = rewriter.getI64ArrayAttr(size_vals);
+
+ auto a2_slice_op = rewriter.create<tosa::SliceOp>(
+ op->getLoc(),
+ RankedTensorType::get(ArrayRef<int64_t>(size_vals),
+ transposed_input_type.getElementType()),
+ transposed_input_value, begin, size);
+
+ auto a3_reshape_op = rewriter.create<tosa::ReshapeOp>(
+ op->getLoc(),
+ RankedTensorType::get(ArrayRef<int64_t>(shape_vals),
+ transposed_input_type.getElementType()),
+ a2_slice_op.getResult(), rewriter.getI64ArrayAttr(shape_vals));
+
+ outs_type_vec.push_back(RankedTensorType::get(
+ ArrayRef<int64_t>(shape_vals), transposed_input_type.getElementType()));
+
+ results_vec.push_back(a3_reshape_op.getResult());
+ }
+
+ // Combine the sequence of tosa.slice() ops into a list
+ // using the IdentityN operator.
+ return rewriter.create<tosa::IdentityNOp>(
+ op->getLoc(), ArrayRef<Type>(outs_type_vec), results_vec);
+}
+
+// Lowers the Select operator to TOSA.
+Operation* convertSelectOp(PatternRewriter& rewriter, Operation* op,
+ Value result_value, Value condition_value,
+ Value x_value, Value y_value) {
+ auto result_type = result_value.getType().dyn_cast<RankedTensorType>();
+ auto condition_type = condition_value.getType().dyn_cast<RankedTensorType>();
+ auto x_type = x_value.getType().dyn_cast<RankedTensorType>();
+ auto y_type = y_value.getType().dyn_cast<RankedTensorType>();
+
+ Operation* result_op = nullptr;
+
+ if (!result_type || !condition_type || !x_type || !y_type) {
+ op->emitOpError("Select: failed ranked tensor type check");
+ return nullptr;
+ }
+
+ // First check whether we need to reshape the condition to match
+ // the same rank as the then/else clauses.
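+ // For example, with a rank-3 result and a rank-1 condition of shape [C],
+ // the condition is reshaped to [1, 1, C] before feeding tosa.select.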
+ if (result_type.getRank() == condition_type.getRank()) {
+ // Nothing to reshape.
+ result_op = rewriter.create<tosa::SelectOp>(
+ op->getLoc(), result_type, condition_value, x_value, y_value);
+ } else {
+ // Need to reshape the condition.
+ SmallVector<int64_t, 8> new_cond_dims;
+ for (int i = 0; i < (result_type.getRank() - condition_type.getRank());
+ i++) {
+ new_cond_dims.push_back(1);
+ }
+ for (int i = 0; i < condition_type.getRank(); i++) {
+ new_cond_dims.push_back(condition_type.getShape()[i]);
+ }
+
+ auto reshape_op = rewriter.create<tosa::ReshapeOp>(
+ op->getLoc(),
+ RankedTensorType::get(ArrayRef<int64_t>(new_cond_dims),
+ condition_type.getElementType()),
+ condition_value, rewriter.getI64ArrayAttr(new_cond_dims));
+
+ auto new_select = rewriter.create<tosa::SelectOp>(
+ op->getLoc(), result_type, reshape_op, x_value, y_value);
+ result_op = new_select;
+ }
+
+ return result_op;
+}
+
+// Lowers the ZerosLike operator to TOSA by creating a constant
+// of the desired type and shape.
+Operation* convertZerosLikeOp(PatternRewriter& rewriter, Operation* op,
+ Value result, Value input) {
+ auto result_type = result.getType().dyn_cast<RankedTensorType>();
+ if (!result_type) {
+ op->emitOpError("Zeroslike: result not ranked tensor type");
+ return nullptr;
+ }
+
+ auto input_type = input.getType().dyn_cast<RankedTensorType>();
+ if (!input_type) {
+ op->emitOpError("Zeroslike: input not ranked tensor type");
+ return nullptr;
+ }
+
+ auto input_shape = input_type.getShape();
+
+ ShapedType zero_type =
+ RankedTensorType::get(input_shape, input_type.getElementType());
+ Attribute zero_attr = rewriter.getZeroAttr(zero_type);
+
+ return rewriter.create<tosa::ConstOp>(op->getLoc(), zero_type,
+ zero_attr.cast<ElementsAttr>());
+}
+
+// Lowers the Mul operator to TOSA. For quantized types, this requires
+// inserting rescale operators before and after the operation.
+Operation* convertMultiplyOp(PatternRewriter& rewriter, Operation* op,
+ Value output_val, Value input_lhs_val,
+ Value input_rhs_val) {
+ auto input_lhs_type = input_lhs_val.getType().dyn_cast<RankedTensorType>();
+ auto input_rhs_type = input_rhs_val.getType().dyn_cast<RankedTensorType>();
+ auto output_type = output_val.getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!input_lhs_type || !input_rhs_type || !output_type) return nullptr;
+
+ bool input_lhs_is_qtype =
+ input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+ bool input_rhs_is_qtype =
+ input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+ bool output_is_qtype =
+ output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+
+ if (input_lhs_is_qtype != output_is_qtype ||
+ input_rhs_is_qtype != output_is_qtype) {
+ op->emitOpError(
+ "ConvertMultiplyOp: input/output tensor should "
+ "be all quantized or all floating-point");
+ return nullptr;
+ }
+
+ Value output;
+ if (output_is_qtype) {
+ auto rescale_type =
+ RankedTensorType::get(output_type.getShape(), rewriter.getI32Type());
+ auto input_lhs_qtype = input_lhs_type.getElementType()
+ .cast<mlir::quant::UniformQuantizedType>();
+ auto input_rhs_qtype = input_rhs_type.getElementType()
+ .cast<mlir::quant::UniformQuantizedType>();
+ auto output_qtype =
+ output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
+
+ double in_lhs_scale = input_lhs_qtype.getScale();
+ double in_rhs_scale = input_rhs_qtype.getScale();
+ double output_scale = output_qtype.getScale();
+
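+ // With zero points removed, the i32 product of the rescaled inputs
+ // represents real_value / (in_lhs_scale * in_rhs_scale), so multiplying by
+ // in_lhs_scale * in_rhs_scale / output_scale converts it to the output
+ // quantization. E.g. (scales assumed) 0.02 * 0.05 / 0.001 = 1.0.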
+ double output_rescale_scale = in_lhs_scale * in_rhs_scale / output_scale;
+
+ auto op1_rescale_lhs = buildRescaleToInt32(
+ rewriter, op, input_lhs_val, 1.0f, input_lhs_qtype.getZeroPoint());
+ auto op2_rescale_rhs = buildRescaleToInt32(
+ rewriter, op, input_rhs_val, 1.0f, input_rhs_qtype.getZeroPoint());
+ auto op3_mul_op1_op2 = rewriter.create<tosa::MulOp>(
+ op->getLoc(), rescale_type, op1_rescale_lhs, op2_rescale_rhs, 0);
+ auto op4_rescale_op3 = buildRescaleFromInt32(
+ rewriter, op, output_type, op3_mul_op1_op2.getResult(),
+ output_rescale_scale, output_qtype.getZeroPoint());
+ output = op4_rescale_op3;
+ } else {
+ auto op1_mul_in = rewriter.create<tosa::MulOp>(
+ op->getLoc(), output_type, input_lhs_val, input_rhs_val, 0);
+
+ output = op1_mul_in.getResult();
+ }
+
+ return output.getDefiningOp();
+}
+
+// Lowers the SquaredDifference operator to TOSA.
+Operation* convertSquaredDifferenceOp(PatternRewriter& rewriter, Operation* op,
+ Value result, Value x, Value y) {
+ // Squared-difference is (x-y)*(x-y).
+ // This lowering calculates the difference and multiplies.
+ auto result_type = result.getType().dyn_cast<RankedTensorType>();
+ if (!result_type) {
+ op->emitOpError("SquaredDifference: result not ranked tensor type");
+ return nullptr;
+ }
+
+ auto x_type = x.getType().dyn_cast<RankedTensorType>();
+ auto y_type = y.getType().dyn_cast<RankedTensorType>();
+ if (!x_type || !y_type) {
+ op->emitOpError("SquaredDifference: inputs not ranked tensor type");
+ return nullptr;
+ }
+
+ auto sub_op = rewriter.create<tosa::SubOp>(op->getLoc(), result_type, x, y);
+ return rewriter.create<tosa::MulOp>(
+ op->getLoc(), result_type, sub_op.getResult(), sub_op.getResult(), 0);
+}
+
+// Lowers the Round operator to TOSA.
+Operation* convertRoundOp(PatternRewriter& rewriter, Operation* op,
+ Value result, Value input) {
+ // Lowers via floor(input + 0.5). Note this is round-half-up rather than
+ // round-half-to-even (banker's) rounding, so results differ on exact
+ // half values.
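+ // For example, 2.5 lowers to floor(3.0) = 3, whereas round-half-to-even
+ // would give 2.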
+ auto result_type = result.getType().dyn_cast<RankedTensorType>();
+ if (!result_type) {
+ op->emitOpError("Round: result not ranked tensor type");
+ return nullptr;
+ }
+
+ auto input_type = input.getType().dyn_cast<RankedTensorType>();
+ if (!input_type) {
+ op->emitOpError("Round: input not ranked tensor type");
+ return nullptr;
+ }
+
+ auto add_op = rewriter.create<tosa::AddOp>(
+ op->getLoc(), result_type, input,
+ getTosaConstTensorSingleF32(rewriter, op, 0.5));
+ return rewriter.create<tosa::FloorOp>(op->getLoc(), result_type,
+ add_op.getResult());
+}
+
+// Lowers ConcatV2 to TOSA.
+Operation* convertConcatV2Op(PatternRewriter& rewriter, Operation* op,
+ Value result_value, SmallVector<Value, 8>& values,
+ int32_t axis) {
+ // ConcatV2 becomes a series of TOSA Concat operators that take pairs of
+ // tensors as arguments. Rank-0 tensors are reshaped to Rank-1,
+ // shape (1,) tensors.
+ auto result_type = result_value.getType().dyn_cast<RankedTensorType>();
+ if (!result_type) {
+ op->emitOpError("ConcatV2Op: result type not ranked tensor.");
+ return nullptr;
+ }
+
+ // Valid axis in TF is [-rank(input), rank(input)).
+ // Valid axis in TOSA is [0, rank(input)).
+ // Add rank(input) once to a negative axis to map it into the TOSA range.
+ auto input_type = op->getOperand(0).getType().dyn_cast<RankedTensorType>();
+ if (!input_type) {
+ op->emitOpError("ConcatV2Op: input type not ranked tensor.");
+ return nullptr;
+ }
+
+ auto input_rank = input_type.getShape().size();
+
+ if (axis < 0) axis += input_rank;
+
+ assert(values.size() >= 2);
+
+ if (!values[0].getType().dyn_cast<RankedTensorType>() ||
+ !values[1].getType().dyn_cast<RankedTensorType>()) {
+ op->emitOpError("ConcatV2Op: value type not ranked tensor.");
+ return nullptr;
+ }
+
+ Value lhs_val = values[0];
+ Value rhs_val = values[1];
+ auto lhs_type = lhs_val.getType().cast<RankedTensorType>();
+ auto rhs_type = rhs_val.getType().cast<RankedTensorType>();
+ ArrayRef<int64_t> lhs_tensor_shape = lhs_type.getShape();
+ ArrayRef<int64_t> rhs_tensor_shape = rhs_type.getShape();
+ int input_tensor_rank = lhs_tensor_shape.size();
+
+ // Each concat output has a different shape.
+ // If input tensors are rank 0, they are reshaped to rank 1, size 1 before
+ // concatenation. Otherwise, every dimension matches the inputs except the
+ // concatenated axis.
+ //
+ // If input is [A0, B, C] and [A1, B, C] and axis = 0
+ // this concat output will be [A0 + A1, B, C].
+ SmallVector<int64_t, 4> concat_result_shape;
+ if (input_tensor_rank == 0) {
+ if (axis != 0) {
+ op->emitOpError("ConcatV2Op: axis invalid.");
+ return nullptr;
+ }
+ SmallVector<int64_t, 8> reshape_rank1_size1_shape{1};
+ auto reshape_rank1_size1_type =
+ RankedTensorType::get(ArrayRef<int64_t>(reshape_rank1_size1_shape),
+ result_type.getElementType());
+ ArrayAttr shape_rank1_size1_attr =
+ rewriter.getI64ArrayAttr(reshape_rank1_size1_shape);
+ for (int i = 0; i < values.size(); i++) {
+ auto a0_reshape_op = rewriter.create<tosa::ReshapeOp>(
+ op->getLoc(), reshape_rank1_size1_type, values[i],
+ shape_rank1_size1_attr);
+ values[i] = a0_reshape_op.getResult();
+ }
+ concat_result_shape.push_back(2);
+ } else {
+ if (axis < 0 || axis >= input_tensor_rank) {
+ op->emitOpError("ConcatV2Op: axis invalid.");
+ return nullptr;
+ }
+ for (int i = 0; i < input_tensor_rank; i++) {
+ concat_result_shape.push_back(lhs_tensor_shape[i]);
+ }
+ concat_result_shape[axis] = lhs_tensor_shape[axis] + rhs_tensor_shape[axis];
+ }
+
+ auto concat_type = RankedTensorType::get(
+ ArrayRef<int64_t>(concat_result_shape), result_type.getElementType());
+
+ mlir::quant::UniformQuantizedType lhs_quant_type =
+ lhs_type.getElementType()
+ .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
+ mlir::quant::UniformQuantizedType rhs_quant_type =
+ rhs_type.getElementType()
+ .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
+ mlir::quant::UniformQuantizedType result_quant_type =
+ result_type.getElementType()
+ .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
+
+ double lhs_scale, rhs_scale, result_scale;
+ int32_t lhs_zeropoint, rhs_zeropoint, result_zeropoint;
+
+ // tfl.concat currently allows different scales for each input tensor,
+ // which the TFLite team plans to fix in:
+ // https://github.com/tensorflow/tensorflow/issues/39658
+ //
+ // For backward compatibility, we still need to support this artifact by
+ // scaling inputs to let them have the same scales.
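+ //
+ // For example (values assumed): with lhs_scale = 0.02 and
+ // result_scale = 0.01, a stored lhs value of 50 (real value 1.0) is
+ // rescaled by 0.02 / 0.01 = 2.0 to a stored value of 100, which still
+ // represents 1.0 at the output scale.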
+ if (result_quant_type && lhs_quant_type && rhs_quant_type) {
+ lhs_scale = static_cast<double>(lhs_quant_type.getScale());
+ lhs_zeropoint = lhs_quant_type.getZeroPoint();
+ rhs_scale = static_cast<double>(rhs_quant_type.getScale());
+ rhs_zeropoint = rhs_quant_type.getZeroPoint();
+ result_scale = static_cast<double>(result_quant_type.getScale());
+ result_zeropoint = result_quant_type.getZeroPoint();
+
+ // Rescale input if scale is not equal to output tensor scale.
+ if (lhs_scale != result_scale) {
+ auto rescale_type =
+ RankedTensorType::get(lhs_type.getShape(), result_quant_type);
+
+ auto rescale_op = buildRescale(rewriter, op, rescale_type, lhs_val,
+ lhs_scale / result_scale, lhs_zeropoint,
+ result_zeropoint);
+
+ lhs_val = rescale_op;
+ }
+ if (rhs_scale != result_scale) {
+ auto rescale_type =
+ RankedTensorType::get(rhs_type.getShape(), result_quant_type);
+
+ auto rescale_op = buildRescale(rewriter, op, rescale_type, rhs_val,
+ rhs_scale / result_scale, rhs_zeropoint,
+ result_zeropoint);
+
+ rhs_val = rescale_op;
+ }
+ }
+
+ auto concat_op = rewriter.create<tosa::ConcatOp>(
+ op->getLoc(), concat_type, lhs_val, rhs_val,
+ rewriter.getI64IntegerAttr(axis));
+ for (int i = 2; i < values.size(); i++) {
+ rhs_val = values[i];
+ rhs_type = rhs_val.getType().dyn_cast<RankedTensorType>();
+ rhs_tensor_shape = rhs_type.getShape();
+ rhs_quant_type = rhs_type.getElementType()
+ .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
+
+ if (input_tensor_rank == 0) {
+ concat_result_shape[axis] = concat_result_shape[axis] + 1;
+ } else {
+ concat_result_shape[axis] =
+ concat_result_shape[axis] + rhs_tensor_shape[axis];
+ }
+ concat_type = RankedTensorType::get(ArrayRef<int64_t>(concat_result_shape),
+ result_type.getElementType());
+
+ if (rhs_quant_type && result_quant_type) {
+ rhs_scale = static_cast<double>(rhs_quant_type.getScale());
+ rhs_zeropoint = rhs_quant_type.getZeroPoint();
+
+ if (rhs_scale != result_scale) {
+ auto rescale_type =
+ RankedTensorType::get(rhs_type.getShape(), result_quant_type);
+
+ auto rescale_op = buildRescale(rewriter, op, rescale_type, rhs_val,
+ rhs_scale / result_scale, rhs_zeropoint,
+ result_zeropoint);
+
+ rhs_val = rescale_op;
+ }
+ }
+
+ concat_op = rewriter.create<tosa::ConcatOp>(
+ op->getLoc(), concat_type, concat_op.getResult(), rhs_val,
+ rewriter.getI64IntegerAttr(axis));
+ }
+
+ return concat_op;
+}
+
+// Lowers SpaceToBatchND to TOSA.
+Operation* convertSpaceToBatchNDOp(PatternRewriter& rewriter, Operation* op,
+ Value result_value, Value input_value,
+ Value block_shape_value,
+ Value paddings_value) {
+ /////////////////////////////////////////////////
+ // Operator: output = SpaceToBatchND(input, block_shape, paddings)
+ // Lowering:
+ //
+ // SpaceToBatch input tensors are broken into three pieces:
+ // (a) batch dimension (N in NHWC)
+ // (b) input being transformed to batch dimension (typically H, W in NHWC)
+ // (c) remainder of input (typically C in NHWC)
+ //
+ // Step 0. Generate padding constant for the first reshape.
+ // No padding on the batch dimension
+ // The input paddings array is addressed as [input_rank][2]
+ // No padding on the remaining dimensions
+ //
+ // a0_pad_const = tosa.const(input=Tensor<input_rank, 2>)
+ //
+ // Step 1. Pad the input tensor
+ //
+ // a1_pad_input_op = tosa.pad(input=input, shape=a0_pad_const_op)
+ //
+ // Step 2. Reshape the padded structure of shape padded_shape to
+ //    [batch] + [padded_shape[1] / block_shape[0], block_shape[0], ...
+ // padded_shape[M] / block_shape[M-1], block_shape[M-1]] +
+ // remaining_shape
+ //
+ // block_rank = M (number of elements in block_shape)
+ // New rank: input_rank + block_rank
+ //
+ // a2_reshape_a1_op = tosa.reshape(input=a1_pad_input_op, shape=a2_shape)
+ //
+ // Step 3. Transpose dimensions to:
+ // block-shape +
+ // [batch] +
+ // [padded_shape[1] / block_shape[0],
+ // ...
+ // [padded_shape[M] / block_shape[M-1]] +
+ // remaining_shape
+ //
+ //    a3_transpose_a2_op = tosa.transpose(input=a2_reshape_a1_op,
+ // perms=a3_perm)
+ //
+ // Step 4. Reshape the transposed tensor to flatten block_shape stuff
+ // into the batch dimension with the following shape:
+ // [ batch * prod(block_shape)] +
+ // [ padded_shape[1] / block_shape[0],
+ // ...,
+ // padded_shape[M] / block_shape[M-1]] +
+ // remaining_shape
+ //
+ //    a4_reshape_a3_op = tosa.reshape(input=a3_transpose_a2_op,
+ // shape=a3_shape)
+ //
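+ // As a worked example (shapes assumed): input [1, 8, 8, 1] (NHWC),
+ // block_shape = [2, 2], zero padding:
+ //   Step 2 reshape:   [1, 4, 2, 4, 2, 1]
+ //   Step 3 transpose: [2, 2, 1, 4, 4, 1] with perm [2, 4, 0, 1, 3, 5]
+ //   Step 4 reshape:   [4, 4, 4, 1]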
+
+ auto result_type = result_value.getType().dyn_cast<RankedTensorType>();
+ auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
+ auto block_shape_type =
+ block_shape_value.getType().dyn_cast<RankedTensorType>();
+ auto paddings_type = paddings_value.getType().dyn_cast<RankedTensorType>();
+
+ // Not a ranked tensor output.
+ if (!result_type) {
+ op->emitOpError("SpaceToBatchND: result type not ranked tensor");
+ return nullptr;
+ }
+ if (!input_type) {
+ op->emitOpError("SpaceToBatchND: input type not ranked tensor");
+ return nullptr;
+ }
+ if (!block_shape_type) {
+ op->emitOpError("SpaceToBatchND: block shape type not ranked tensor");
+ return nullptr;
+ }
+ if (!paddings_type) {
+ op->emitOpError("SpaceToBatchND: paddings type not ranked tensor");
+ return nullptr;
+ }
+
+ // Follow implementation in
+ // tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
+
+ // So, to figure out the spatial_shape, remove the batch dimension and
+ // then use the next block_rank dimensions. The remaining dimensions are
+ // remaining_shape.
+
+ auto block_shape = block_shape_type.getShape();
+ auto input_shape = input_type.getShape();
+
+ int block_rank = block_shape[0];
+ int batch_size = input_shape[0];
+ int input_rank = input_type.getRank();
+ int remaining_shape_rank = input_rank - block_rank - 1;
+ int block_num_elems = 1;
+ int padding_sum = 0;
+
+ ElementsAttr block_shape_elems;
+ ElementsAttr paddings_elems;
+
+ if (!matchPattern(block_shape_value, m_Constant(&block_shape_elems)))
+ return nullptr;
+
+ if (!matchPattern(paddings_value, m_Constant(&paddings_elems)))
+ return nullptr;
+
+ SmallVector<int32_t, 2> a0_pad_const(2 * (input_rank));
+ SmallVector<int64_t, 2> padded_shape(input_rank);
+
+ // 1. Pad based on paddings operand. No padding on the batch dimension.
+ // The a0_pad_const array is addressed as [input_rank][2], but
+ // it is flattened to a 1D array because LLVM appears to only accept 1D.
+ //
+ // padded_shape[] is the shape of the padded output of step a1.
+ // The name is retained for consistency with the TF reference code.
+ padded_shape[0] = input_shape[0];
+
+ // Batch dimension padding
+ a0_pad_const[0] = 0;
+ a0_pad_const[1] = 0;
+
+ // This iterator seems to be the only reliable way to get
+ // int values out of a multi-dimensional ElementsAttr.
+ int idx = 0;
+
+ for (auto i : paddings_elems.getValues<IntegerAttr>()) {
+ a0_pad_const[idx + 2] = i.getInt();
+ idx++;
+ }
+
+ // Insert padding on the spatial shape dimensions
+ for (int i = 0; i < block_rank; i++) {
+ int32_t lo_pad = a0_pad_const[2 * (i + 1) + 0];
+ int32_t hi_pad = a0_pad_const[2 * (i + 1) + 1];
+
+ padded_shape[i + 1] = input_shape[i + 1] + lo_pad + hi_pad;
+ }
+
+ // No padding on the remaining_shape dimensions
+ for (int i = 0; i < remaining_shape_rank; i++) {
+ a0_pad_const[2 * (i + block_rank + 1) + 0] = 0;
+ a0_pad_const[2 * (i + block_rank + 1) + 1] = 0;
+ padded_shape[i + block_rank + 1] = input_shape[i + block_rank + 1];
+ }
+
+ auto a0_pad_const_attr_type =
+ RankedTensorType::get({(input_rank), 2}, rewriter.getIntegerType(32));
+
+ // Create a const op to generate the tensor type for the input padding array
+ auto a0_pad_const_op = rewriter.create<tosa::ConstOp>(
+ op->getLoc(), a0_pad_const_attr_type,
+ DenseElementsAttr::get(a0_pad_const_attr_type,
+ llvm::makeArrayRef<int32_t>(a0_pad_const)));
+
+ auto a1_pad_input_op = rewriter.create<tosa::PadOp>(
+ op->getLoc(),
+ RankedTensorType::get(ArrayRef<int64_t>(padded_shape),
+ result_type.getElementType()),
+ input_value, a0_pad_const_op.getResult());
+
+ // 2. Reshape the padded structure of shape padded_shape to
+ // [batch] + [padded_shape[1] / block_shape[0], block_shape[0], ...
+ // padded_shape[M] / block_shape[M-1], block_shape[M-1]] +
+ // remaining_shape
+
+ // block_rank = M (number of elements in block_shape)
+ // New rank: input_rank + block_rank
+ SmallVector<int64_t, 2> a2_shape(1 + block_rank * 2 + remaining_shape_rank);
+
+ // First dimension is batch.
+ a2_shape[0] = input_shape[0];
+ for (int i = 0; i < block_rank; i++) {
+ int32_t block_shape_val =
+ block_shape_elems.getValue<IntegerAttr>(i).getInt();
+ a2_shape[1 + i * 2 + 0] = padded_shape[1 + i] / block_shape_val;
+ a2_shape[1 + i * 2 + 1] = block_shape_val;
+ block_num_elems *= block_shape_val;
+ }
+
+ // Copy in the remaining block shape.
+ for (int i = 0; i < remaining_shape_rank; i++) {
+ a2_shape[1 + block_rank * 2 + i] = input_shape[1 + block_rank + i];
+ }
+
+ auto a2_reshape_a1_op = rewriter.create<tosa::ReshapeOp>(
+ op->getLoc(),
+ RankedTensorType::get(ArrayRef<int64_t>(a2_shape),
+ result_type.getElementType()),
+ a1_pad_input_op.getResult(), rewriter.getI64ArrayAttr(a2_shape));
+
+ // 3. Transpose dimensions to:
+ // block-shape +
+ // [batch] +
+ // [padded_shape[1] / block_shape[0],
+ // ...
+ // [padded_shape[M] / block_shape[M-1]] +
+ // remaining_shape
+ int32_t a2_reshape_a1_rank =
+ a2_reshape_a1_op.getResult().getType().cast<RankedTensorType>().getRank();
+ SmallVector<int32_t, 8> a3_perm(a2_reshape_a1_rank);
+ SmallVector<int64_t, 2> a3_transpose_shape(a2_reshape_a1_rank);
+
+ for (int i = 0; i < block_rank; i++) {
+ a3_perm[i] = 1 + 2 * i + 1;
+ a3_perm[block_rank + 1 + i] = 1 + 2 * i;
+ }
+ a3_perm[block_rank] = 0;
+ for (int i = 1 + block_rank * 2; i < a2_reshape_a1_rank; i++) {
+ a3_perm[i] = i;
+ }
+
+ for (int i = 0; i < a3_transpose_shape.size(); i++) {
+ a3_transpose_shape[i] = a2_shape[a3_perm[i]];
+ }
+
+ auto a3_transpose_const =
+ get1DConstTensor<tosa::ConstOp, int32_t>(rewriter, op, a3_perm);
+
+ auto a3_transpose_a2_op = rewriter.create<tosa::TransposeOp>(
+ op->getLoc(),
+ RankedTensorType::get(ArrayRef<int64_t>(a3_transpose_shape),
+ result_type.getElementType()),
+ a2_reshape_a1_op.getResult(), a3_transpose_const);
+
+ // 4. Reshape the transposed tensor to flatten block_shape
+ // into the batch dimension with the following shape:
+ // [ batch * prod(block_shape)] +
+ // [ padded_shape[1] / block_shape[0],
+ // ...,
+ // padded_shape[M] / block_shape[M-1]] +
+ // remaining_shape
+ SmallVector<int64_t, 2> a4_reshape_shape(input_rank);
+
+ // Batch
+ a4_reshape_shape[0] = batch_size * block_num_elems;
+
+ // padded shape / block_shape.
+ for (int i = 0; i < block_rank; i++) {
+ int32_t block_shape_val =
+ block_shape_elems.getValue<IntegerAttr>(i).getInt();
+ a4_reshape_shape[i + 1] = padded_shape[i + 1] / block_shape_val;
+ }
+
+ // Copy in remainder shape.
+ for (int i = 0; i < remaining_shape_rank; i++) {
+ a4_reshape_shape[1 + block_rank + i] = input_shape[1 + block_rank + i];
+ }
+
+ return rewriter.create<tosa::ReshapeOp>(
+ op->getLoc(), result_type, a3_transpose_a2_op.getResult(),
+ rewriter.getI64ArrayAttr(a4_reshape_shape));
+}
+
+// Lowers BatchToSpaceND to TOSA.
+Operation* convertBatchToSpaceNDOp(PatternRewriter& rewriter, Operation* op,
+ Value result_value, Value input_value,
+ Value block_shape_value, Value crops_value) {
+ /////////////////////////////////////////////////
+ // Operator: output = BatchToSpaceND(input, block_shape, crops)
+ // Lowering:
+ //
+ // BatchToSpace input tensors are broken into three pieces:
+ // (a) batch dimension (N in NHWC)
+ // (b) input being transformed from batch dimension (typically H, W in
+ // NHWC)
+ // (c) remainder of input (typically C in NHWC)
+ //
+ // Step 1. Reshape input to:
+ // [block_shape[0],
+ // ...
+ // [block_shape[M-1],
+ // [batch / prod(block_shape)]
+ // [input_shape[1],
+ // ...
+ // [input_shape[N-1]
+ //
+ // a1_reshape_input_op = tosa.reshape(input=input, shape=a1_shape)
+ //
+ // Step 2. Permute to shape
+ // [ batch / prod(block_shape) ],
+ // [ input_shape[1] ], [ block_shape[0] ]
+ // ...
+ // [ input_shape[M] ], [ block_shape[M-1] ]
+ // + remaining_input_shapes input_shape[M+1 .. N-1]
+ //
+ // a2_transpose_a1 = tosa.transpose(input=a1_reshape_input_op,
+ // shape=a2_shape)
+ //
+ // Step 3. Reshape to:
+ // [ batch / prod(block_shape) ],
+ // [ input_shape[1] * block_shape[0] ],
+ // ..
+ // [ input_shape[M] * block_shape[M-1] ],
+ // + remaining input shapes [input_shape[M+1 .. N-1]]
+ //
+ // a3_reshape_a2 = tosa.reshape(input=a2_transpose_a1, shape=a3_shape)
+ //
+ // Step 4. Crop the start/end dimensions according to crops of the
+ // a3_reshape_a2 shape
+ //
+ // a4_slice_a3 = tosa.slice(input=a3_reshape_a2, start=a4_start,
+ // size=a4_size)
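+ //
+ // Worked example (illustrative values only, not derived from the code
+ // below): input = [4, 2, 2, 1], block_shape = [2, 2], zero crops:
+ // a1 (reshape): [2, 2, 1, 2, 2, 1]
+ // a2 (transpose): [1, 2, 2, 2, 2, 1], perm = [2, 3, 0, 4, 1, 5]
+ // a3 (reshape): [1, 4, 4, 1]
+ // a4 (slice): [1, 4, 4, 1]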
+
+ auto result_type = result_value.getType().dyn_cast<RankedTensorType>();
+ auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
+ auto block_shape_type =
+ block_shape_value.getType().dyn_cast<RankedTensorType>();
+ auto crops_type = crops_value.getType().dyn_cast<RankedTensorType>();
+
+ if (!result_type) {
+ op->emitOpError("BatchToSpaceND: result type not ranked tensor");
+ return nullptr;
+ }
+ if (!input_type) {
+ op->emitOpError("BatchToSpaceND: input type not ranked tensor");
+ return nullptr;
+ }
+ if (!block_shape_type) {
+ op->emitOpError("BatchToSpaceND: block shape type not ranked tensor");
+ return nullptr;
+ }
+ if (!crops_type) {
+ op->emitOpError("BatchToSpaceND: crops type not ranked tensor");
+ return nullptr;
+ }
+
+ // As with SpaceToBatchND, this is a 4-step lowering.
+ int block_rank = block_shape_type.getShape()[0];
+ int input_rank = input_type.getRank();
+ int crops_dims = crops_type.getShape()[0];
+ int remaining_shape_rank = input_rank - block_rank - 1;
+ auto input_shape = input_type.getShape();
+
+ ElementsAttr block_shape_elems;
+ ElementsAttr crops_elems;
+
+ if (!matchPattern(block_shape_value, m_Constant(&block_shape_elems))) {
+ op->emitOpError("BatchToSpaceND: block_shape not a constant");
+ return nullptr;
+ }
+
+ if (!matchPattern(crops_value, m_Constant(&crops_elems))) {
+ op->emitOpError("BatchToSpaceND: crops not a constant");
+ return nullptr;
+ }
+
+ SmallVector<int64_t, 4> block_shape(block_rank);
+ SmallVector<std::pair<int64_t, int64_t>, 4> crops(crops_dims);
+
+ // Extract values for block_shape and crops now.
+ int block_num_elems = 1;
+ for (int i = 0; i < block_rank; i++) {
+ int block_shape_val =
+ block_shape_elems.getValue<IntegerAttr>(i).getInt();
+ block_num_elems *= block_shape_val;
+ block_shape[i] = block_shape_val;
+ }
+
+ // This iterator seems to be the only reliable way to get
+ // int values out of a multi-dimensional ElementsAttr
+ SmallVector<int32_t, 2> crops_const(2 * (crops_dims));
+ int idx = 0;
+ for (auto i : crops_elems.getValues<IntegerAttr>()) {
+ crops_const[idx++] = i.getInt();
+ }
+
+ for (int i = 0; i < crops_dims; i++) {
+ int crops_lo = crops_const[i * 2 + 0];
+ int crops_hi = crops_const[i * 2 + 1];
+ crops[i] = std::make_pair(crops_lo, crops_hi);
+ }
+
+ // Step 1. Reshape input to:
+ // [block_shape[0],
+ // ...
+ // [block_shape[M-1],
+ // [batch / prod(block_shape)]
+ // [input_shape[1],
+ // ...
+ // [input_shape[N-1]
+ SmallVector<int64_t, 2> a1_shape(block_rank + input_rank);
+
+ for (int i = 0; i < block_rank; i++) a1_shape[i] = block_shape[i];
+
+ a1_shape[block_rank] = input_shape[0] / block_num_elems;
+
+ for (int i = 0; i < input_rank - 1; i++)
+ a1_shape[i + block_rank + 1] = input_shape[i + 1];
+
+ auto a1_reshape_input_op = rewriter.create<tosa::ReshapeOp>(
+ op->getLoc(),
+ RankedTensorType::get(ArrayRef<int64_t>(a1_shape),
+ result_type.getElementType()),
+ input_value, rewriter.getI64ArrayAttr(a1_shape));
+
+ // 2. Permute to shape
+ // [ batch / prod(block_shape) ],
+ // [ input_shape[1] ], [ block_shape[0] ]
+ // ...
+ // [ input_shape[M] ], [ block_shape[M-1] ]
+ // + remaining_input_shapes input_shape[M+1 .. N-1]
+
+ // 2a. calculate the permutation
+ SmallVector<int32_t, 8> a2_perm(block_rank + input_rank);
+ SmallVector<int64_t, 2> a2_transpose_shape(block_rank + input_rank);
+
+ a2_perm[0] = block_rank;
+ for (int i = 0; i < block_rank; i++) {
+ a2_perm[1 + i * 2 + 0] = block_rank + 1 + i;
+ a2_perm[1 + i * 2 + 1] = i;
+ }
+
+ for (int i = 0; i < remaining_shape_rank; i++) {
+ a2_perm[1 + 2 * block_rank + i] = 1 + 2 * block_rank + i;
+ }
+
+ // 2b. calculate the a2_permuted shape
+ for (int i = 0; i < (block_rank + input_rank); i++) {
+ a2_transpose_shape[i] = a1_shape[a2_perm[i]];
+ }
+
+ auto a2_transpose_perm =
+ get1DConstTensor<tosa::ConstOp, int32_t>(rewriter, op, a2_perm);
+ auto a2_transpose_a1_op = rewriter.create<tosa::TransposeOp>(
+ op->getLoc(),
+ RankedTensorType::get(ArrayRef<int64_t>(a2_transpose_shape),
+ result_type.getElementType()),
+ a1_reshape_input_op.getResult(), a2_transpose_perm);
+
+ // Step 3. Reshape to:
+ // [ batch / prod(block_shape) ],
+ // [input_shape[1] * block_shape[0] ],
+ // ..
+ // [ input_shape[M] * block_shape[M-1] ],
+ // + remaining input shapes [input_shape[M+1 .. N-1]]
+ SmallVector<int64_t, 2> a4_shape(input_rank);
+
+ a4_shape[0] = input_shape[0] / block_num_elems;
+ for (int i = 0; i < block_rank; i++) {
+ a4_shape[1 + i] = input_shape[i + 1] * block_shape[i];
+ }
+ for (int i = 0; i < remaining_shape_rank; i++) {
+ a4_shape[1 + block_rank + i] = input_shape[block_rank + 1 + i];
+ }
+
+ auto a3_reshape_a2 = rewriter.create<tosa::ReshapeOp>(
+ op->getLoc(),
+ RankedTensorType::get(ArrayRef<int64_t>(a4_shape),
+ result_type.getElementType()),
+ a2_transpose_a1_op.getResult(), rewriter.getI64ArrayAttr(a4_shape));
+
+ // 4. Crop the start/end dimensions on 'spatial dimension' according to
+ // crops
+ // Use a slice operator to do the cropping.
+ //
+ // Calculate a beginning point and a size:
+ // - Begin is the origin, offset by the lo crop amount in each dimension
+ // - Size is the reshaped tensor size, minus the quantity (lo + hi) for each
+ // dimension
+ SmallVector<int64_t, 4> a4_begin_vals(input_rank), a4_size_vals(input_rank);
+
+ for (int i = 0; i < input_rank; i++) {
+ // Batch dimension and remaining dimensions.
+ if (i == 0 || i > crops_dims) {
+ a4_begin_vals[i] = 0;
+ a4_size_vals[i] = result_type.getShape()[i];
+ } else {
+ // Spatial dimension.
+ assert(i - 1 >= 0 && i - 1 < crops_dims);
+ a4_begin_vals[i] = crops[i - 1].first;
+ a4_size_vals[i] = a4_shape[i] - crops[i - 1].first - crops[i - 1].second;
+ }
+ }
+
+ return rewriter.create<tosa::SliceOp>(
+ op->getLoc(),
+ RankedTensorType::get(ArrayRef<int64_t>(a4_size_vals),
+ result_type.getElementType()),
+ a3_reshape_a2.getResult(), rewriter.getI64ArrayAttr(a4_begin_vals),
+ rewriter.getI64ArrayAttr(a4_size_vals));
+}
+
+// Lowers ExpandDims to TOSA.
+Operation* convertExpandDimsOp(PatternRewriter& rewriter, Operation* op,
+ Value result_value, Value input_value,
+ Value dim_value) {
+ // Lowers to a reshape op with 1's inserted in the appropriate dimensions.
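+ // For example (illustrative shapes only), an input of shape [2, 3] with
+ // dim = 1 reshapes to [2, 1, 3]; a dim that is negative or >= rank
+ // appends the new axis at the end, giving [2, 3, 1].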
+ auto output_type = result_value.getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) {
+ op->emitOpError("ExpandDims: output type not ranked tensor");
+ return nullptr;
+ }
+
+ auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
+ if (!input_type) {
+ op->emitOpError("ExpandDims: input type not ranked tensor");
+ return nullptr;
+ }
+
+ auto input_shape = input_type.getShape();
+
+ ElementsAttr dim_elem;
+ if (!matchPattern(dim_value, m_Constant(&dim_elem))) return nullptr;
+
+ assert(dim_elem.getType().getRank() == 0 && "expected scalar tensor");
+ int32_t dim = dim_elem.getValue<IntegerAttr>({}).getInt();
+
+ SmallVector<int64_t, 4> reshape_dims;
+ if (dim < 0 || dim >= input_shape.size()) { // add dim at end of tensor
+ dim = input_shape.size();
+ for (int i = 0; i < input_shape.size(); i++) {
+ reshape_dims.emplace_back(input_shape[i]);
+ }
+ reshape_dims.emplace_back(1);
+ } else {
+ for (int i = 0; i < input_shape.size(); i++) {
+ if (i == dim) {
+ reshape_dims.emplace_back(1);
+ }
+ reshape_dims.emplace_back(input_shape[i]);
+ }
+ }
+
+ ArrayAttr shape_attr = rewriter.getI64ArrayAttr(reshape_dims);
+
+ return rewriter.create<tosa::ReshapeOp>(op->getLoc(), output_type,
+ input_value, shape_attr);
+}
+
+// Lowers Squeeze to TOSA.
+Operation* convertSqueezeOp(PatternRewriter& rewriter, Operation* op,
+ Value result_value, Value input_value,
+ SmallVector<int32_t, 8>& squeeze_dims) {
+ // Lowers to a reshape op where dimensions in squeeze_dims with size=1
+ // are removed.
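+ // For example (illustrative shapes only), an input of shape [1, 2, 1, 3]
+ // with empty squeeze_dims reshapes to [2, 3]; with squeeze_dims = {0} it
+ // reshapes to [2, 1, 3].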
+ auto output_type = result_value.getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) {
+ op->emitOpError("Squeeze: output type not ranked tensor");
+ return nullptr;
+ }
+
+ auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
+ if (!input_type) {
+ op->emitOpError("Squeeze: input type not ranked tensor");
+ return nullptr;
+ }
+
+ auto input_shape = input_type.getShape();
+
+ SmallVector<int64_t, 8> reshape_dims;
+
+ if (squeeze_dims.empty()) { // remove all 1-dims
+ for (int i = 0; i < input_shape.size(); i++) {
+ if (input_shape[i] != 1) {
+ reshape_dims.emplace_back(input_shape[i]);
+ }
+ }
+ } else {
+ // Remove only specified dims.
+ // First sort the array so they can be picked off in sequence.
+ std::sort(squeeze_dims.begin(), squeeze_dims.end(),
+ [](const int32_t& a, const int32_t& b) { return a < b; });
+
+ int pos = 0;
+ auto dim = squeeze_dims[pos];
+ for (int i = 0; i < input_shape.size(); i++) {
+ if (i == dim) {
+ pos = pos + 1;
+ if (pos < squeeze_dims.size())
+ dim = squeeze_dims[pos];
+ else
+ dim = -1; // Invalid
+ } else {
+ reshape_dims.emplace_back(input_shape[i]);
+ }
+ }
+ }
+
+ ArrayAttr shape_attr = rewriter.getI64ArrayAttr(reshape_dims);
+
+ return rewriter.create<tosa::ReshapeOp>(op->getLoc(), output_type,
+ input_value, shape_attr);
+}
+
+// Lowers ELU to a sequence of TOSA ops.
+Operation* convertEluOp(PatternRewriter& rewriter, Operation* op,
+ Value result_value, Value features_value) {
+ // Lowers Elu using the following formula:
+ // elu(x) = x < 0 ? (exp(x) - 1) : x
+ // one = const({1});
+ // zero = const({0});
+ // a1 = exp(x);
+ // a2 = sub(a1, one)
+ // a3 = ge(x, zero)
+ // a4 = select(a3, x, a2)
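+ //
+ // For example, elu(-1.0) = exp(-1.0) - 1.0 ~= -0.632, while
+ // elu(2.0) = 2.0.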
+ auto output_type = result_value.getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) {
+ op->emitOpError("Elu: output type not ranked tensor");
+ return nullptr;
+ }
+
+ // Can't directly create a size=1, rank=rank(input) tensor here because
+ // it would be optimized out. Instead, create rank-0 constants and rely
+ // on operand broadcasting.
+ auto one_const_op = getTosaConstTensorSingleF32(rewriter, op, 1.0);
+
+ auto zero_const_op = getTosaConstTensorSingleF32(rewriter, op, 0.0);
+
+ auto a1_exp_in_op =
+ rewriter.create<tosa::ExpOp>(op->getLoc(), output_type, features_value);
+
+ auto a2_sub_a1_one_op = rewriter.create<tosa::SubOp>(
+ op->getLoc(), output_type, a1_exp_in_op.getResult(), one_const_op);
+
+ auto a3_ge_in_zero_op = rewriter.create<tosa::GreaterEqualOp>(
+ op->getLoc(),
+ RankedTensorType::get(output_type.getShape(), rewriter.getIntegerType(1)),
+ features_value, zero_const_op);
+
+ return rewriter.create<tosa::SelectOp>(
+ op->getLoc(), output_type, a3_ge_in_zero_op.getResult(), features_value,
+ a2_sub_a1_one_op.getResult());
+}
+
+// Lowers Softmax to a sequence of TOSA ops.
+Operation* convertSoftmaxOp(PatternRewriter& rewriter, Operation* op,
+ Value result_value, Value logits_value) {
+ // softmax = exp(logits) / reduce_sum(exp(logits), -1)
+ //
+ // or equivalently, multiplying both the numerator and denominator by
+ // exp(-max(logits)), we get:
+ //
+ // softmax = exp(logits - max(logits)) / reduce_sum(exp(logits -
+ // max(logits)), -1)
+ //
+ // We'll use the first version for the direct floating-point lowering and
+ // the second for the quantized lowering, since the second restricts the
+ // input of exp() to be negative, so the LUT output always stays within
+ // [0.0, 1.0].
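+ //
+ // Note on notation: comments below use "m.n" for fixed-point values with
+ // m integer bits and n fractional bits (e.g. 0.16 is a purely fractional
+ // 16-bit value).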
+ auto output_type = result_value.getType().dyn_cast<RankedTensorType>();
+ auto input_type = logits_value.getType().dyn_cast<RankedTensorType>();
+
+ // Not a ranked tensor input/output
+ if (!output_type || !input_type) {
+ op->emitOpError("Softmax: input and result not ranked tensors");
+ return nullptr;
+ }
+
+ // reduce_sum on last dimension
+ int32_t input_rank = input_type.getShape().size();
+ ArrayRef<int64_t> logits_shape = output_type.getShape();
+
+ if (input_type.getElementType().isa<mlir::quant::QuantizedType>() &&
+ output_type.getElementType().isa<mlir::quant::QuantizedType>()) {
+ SmallVector<int64_t, 4> rsum_shape_v(input_type.getShape().begin(),
+ input_type.getShape().end() - 1);
+ rsum_shape_v.push_back(1);
+ ArrayRef<int64_t> rsum_shape(rsum_shape_v);
+ // The if condition already checks if these are UQTs
+ mlir::quant::UniformQuantizedType in_quant_type =
+ input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
+ mlir::quant::UniformQuantizedType out_quant_type =
+ output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
+
+ auto int16_element_qtype = mlir::quant::UniformQuantizedType::get(
+ true, rewriter.getIntegerType(16), rewriter.getF32Type(), 1.0f, 0,
+ -32768, 32767);
+ auto int16_logits_type =
+ RankedTensorType::get(logits_shape, int16_element_qtype);
+ auto int32_logits_type =
+ RankedTensorType::get(logits_shape, rewriter.getIntegerType(32));
+ auto int16_rsum_type =
+ RankedTensorType::get(rsum_shape, int16_element_qtype);
+ auto int32_rsum_type =
+ RankedTensorType::get(rsum_shape, rewriter.getIntegerType(32));
+
+ // Step 1. get x - max(x)
+ auto op1_rescale_in =
+ buildRescale(rewriter, op, int32_logits_type, logits_value, 1.0f,
+ in_quant_type.getZeroPoint(), 0);
+
+ auto op2_reducemax_op1 = rewriter.create<tosa::ReduceMaxOp>(
+ op->getLoc(), int32_rsum_type, op1_rescale_in,
+ rewriter.getI64IntegerAttr(input_rank - 1));
+
+ auto op3_sub_op1_op2 = rewriter.create<tosa::SubOp>(
+ op->getLoc(), int32_logits_type, op1_rescale_in,
+ op2_reducemax_op1.getResult());
+
+ // The table input ranges from -16.0 to 16.0; inputs below -16.0 are
+ // treated as exp(-16.0), which rounds to 0 in 0.16 format.
+ const double exp_sample_grain = 1.0 / 16.0;
+ auto exp_func = [exp_sample_grain](int32_t x) -> int32_t {
+ double v = static_cast<double>(x) * exp_sample_grain;
+ v = v < 0.0 ? std::exp(v) : 1.0;
+ return std::lround(32768.0 * v);
+ };
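+ // As a sanity check (illustrative values): x = -16 gives v = -1.0, so
+ // the table entry is lround(32768.0 * exp(-1.0)) = 12055.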
+
+ auto exp_table_const = getTosa1DConstTensorTable(rewriter, op, exp_func);
+
+ // Step 2. rescale input
+ auto op4_rescale_op3 = buildRescale(
+ rewriter, op, int16_logits_type, op3_sub_op1_op2.getResult(),
+ in_quant_type.getScale() * 128.0 / exp_sample_grain, 0, 0);
+
+ // Step 3. get exp() result
+ // Since step 1 already ensures the input x < 0, we can utilize the full
+ // 0.16 output range.
+
+ // Output is 0.23
+ auto op5_table_op4 = rewriter.create<tosa::TableOp>(
+ op->getLoc(), int32_logits_type, op4_rescale_op3, exp_table_const);
+
+ // Right shift 3 bits. output 0.20
+ auto op6_rshift_op5 = rewriter.create<tosa::ArithmeticRightShiftOp>(
+ op->getLoc(), int32_logits_type, op5_table_op4.getResult(),
+ getTosaConstTensorSingleI32(rewriter, op, 3), true);
+
+ // Step 4. get sum(exp()). output 12.20
+ auto op7_reducesum_op6 = rewriter.create<tosa::ReduceSumOp>(
+ op->getLoc(), int32_rsum_type, op6_rshift_op5.getResult(),
+ rewriter.getI64IntegerAttr(input_rank - 1));
+
+ // Step 5. calculate reciprocal(sum(exp()))
+ auto op8_clz_op7 = rewriter.create<tosa::ClzOp>(
+ op->getLoc(), int32_rsum_type, op7_reducesum_op6.getResult());
+
+ // rshift amount of reciprocal(sum(exp()))
+ // 12 from the integer bits of 12.20 accumulator
+ // 30 from output of multiply 0.15 x 0.15
+ // -8 to keep additional 8 bits before output rescaling
+ auto op9_sub_op8 = rewriter.create<tosa::SubOp>(
+ op->getLoc(), int32_rsum_type,
+ getTosaConstTensorSingleI32(rewriter, op, 12 + 30 - 8),
+ op8_clz_op7.getResult());
+
+ // Left shift to get 1.31 format
+ auto op10_lshift_op7_op8 = rewriter.create<tosa::LogicalLeftShiftOp>(
+ op->getLoc(), int32_rsum_type, op7_reducesum_op6.getResult(),
+ op8_clz_op7.getResult());
+
+ // Subtract (1 << 31) to make 0 <= x <= 1
+ auto op11_sub_op10 = rewriter.create<tosa::SubOp>(
+ op->getLoc(), int32_rsum_type, op10_lshift_op7_op8.getResult(),
+ getTosaConstTensorSingleI32(rewriter, op, (1u << 31)));
+
+ // Right shift 16 bits to get 16 bits index
+ auto op12_rshift_op11 = rewriter.create<tosa::ArithmeticRightShiftOp>(
+ op->getLoc(), int32_rsum_type, op11_sub_op10.getResult(),
+ getTosaConstTensorSingleI32(rewriter, op, 16), true);
+
+ // cast to 16 bits to index TABLE op
+ auto op13_cast_op12 = rewriter.create<tosa::CastOp>(
+ op->getLoc(), int16_rsum_type, op12_rshift_op11.getResult());
+
+ // Generate table for 1 / (1 + x), for 0 <= x <= 1
+ const double one_over_one_plus_x_sample_grain = 1.0 / 256.0;
+ auto one_over_one_plus_x_func =
+ [one_over_one_plus_x_sample_grain](int32_t x) -> int32_t {
+ double v = static_cast<double>(x) * one_over_one_plus_x_sample_grain;
+ v = v < 0 ? 1.0 : 1.0 / (1.0 + v);
+ return std::lround(32768.0 * v);
+ };
+
+ auto one_over_one_plus_x_table_const =
+ getTosa1DConstTensorTable(rewriter, op, one_over_one_plus_x_func);
+
+ auto op14_table_op13 = rewriter.create<tosa::TableOp>(
+ op->getLoc(), int32_rsum_type, op13_cast_op12.getResult(),
+ one_over_one_plus_x_table_const);
+
+ // Rescale the 1 / (1 + x) table output from 0.23 back to 0.16
+ auto op15_rescale_op14 = buildRescale(rewriter, op, int32_rsum_type,
+ op14_table_op13, 1.0 / 128.0, 0, 0);
+
+ // Rescale exp(x) from 0.23 back to 0.16
+ auto op16_rescale_op5 =
+ buildRescale(rewriter, op, int32_logits_type, op5_table_op4.getResult(),
+ 1.0 / 128.0, 0, 0);
+
+ // Step 6. apply the scales we just computed, explicitly in i32 space
+ // lhs: 0.16, rhs: 0.16, output: 0.32
+ auto op17_mul_op15_op16 =
+ rewriter.create<tosa::MulOp>(op->getLoc(), int32_logits_type,
+ op15_rescale_op14, op16_rescale_op5, 0);
+
+ // Apply right shift from clz
+ auto op18_rshift_op17_op9 = rewriter.create<tosa::ArithmeticRightShiftOp>(
+ op->getLoc(), int32_logits_type, op17_mul_op15_op16.getResult(),
+ op9_sub_op8.getResult(), true);
+
+ // Step 7. output scaling, extra 1.0 / 256.0 since we keep extra 8 bits
+ // in op9_sub_op8
+ auto op19_rescale_op18 = buildRescale(
+ rewriter, op, output_type, op18_rshift_op17_op9.getResult(),
+ 1.0 / (out_quant_type.getScale() * 256.0), 0,
+ out_quant_type.getZeroPoint());
+
+ return op19_rescale_op18.getDefiningOp();
+
+ } else {
+ SmallVector<int64_t, 4> rsum_shape_v(input_type.getShape().begin(),
+ input_type.getShape().end());
+ rsum_shape_v[input_rank - 1] = 1;
+ ArrayRef<int64_t> rsum_shape(rsum_shape_v);
+
+ // Floating-point lowering is more direct:
+ //
+ // op1 = exp(logits)
+ // op2 = reduce_sum(op1, -1)
+ // op3 = reciprocal(op2)
+ // op4 = mul(op1, op3)
+ auto op1_exp_in =
+ rewriter.create<tosa::ExpOp>(op->getLoc(), output_type, logits_value);
+ auto rsum_type =
+ RankedTensorType::get(rsum_shape, output_type.getElementType());
+
+ // Keep dims so we don't need to reshape later
+ auto op2_reducesum_op1 = rewriter.create<tosa::ReduceSumOp>(
+ op->getLoc(), rsum_type, op1_exp_in.getResult(),
+ rewriter.getI64IntegerAttr(input_rank - 1));
+ auto op3_reciprocal_op2 = rewriter.create<tosa::ReciprocalOp>(
+ op->getLoc(), rsum_type, op2_reducesum_op1.getResult());
+
+ return rewriter.create<tosa::MulOp>(op->getLoc(), output_type,
+ op1_exp_in.getResult(),
+ op3_reciprocal_op2.getResult(), 0);
+ }
+}
+
+// Lowers LogSoftmax to a sequence of TOSA ops.
+Operation* convertLogSoftmaxOp(PatternRewriter& rewriter, Operation* op,
+ Value result_value, Value logits_value) {
+ // log_softmax = log(exp(logits) / reduce_sum(exp(logits), -1))
+ // op1 = exp(logits)
+ // op2 = reduce_sum(op1, -1)
+ // op3 = reciprocal(op2)
+ // op4 = mul(op1, op3)
+ // op5 = log(op4)
+
+ auto output_type = result_value.getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) {
+ op->emitOpError("LogSoftmax: output type not ranked tensor.");
+ return nullptr;
+ }
+
+ auto input_type = op->getOperand(0).getType().dyn_cast<RankedTensorType>();
+ if (!input_type) {
+ op->emitOpError("LogSoftmax: input type not ranked tensor.");
+ return nullptr;
+ }
+
+ mlir::quant::UniformQuantizedType in_quant_type =
+ input_type.getElementType()
+ .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
+ mlir::quant::UniformQuantizedType out_quant_type =
+ output_type.getElementType()
+ .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
+ if (in_quant_type || out_quant_type) {
+ op->emitOpError("Quantized log_softmax lowering not implemented yet");
+ return nullptr;
+ }
+
+ auto op1_exp_in =
+ rewriter.create<tosa::ExpOp>(op->getLoc(), output_type, logits_value);
+
+ // reduce_sum on last dimension
+ int32_t input_rank = input_type.getShape().size();
+ SmallVector<int64_t, 4> rsum_shape(output_type.getShape().begin(),
+ output_type.getShape().end());
+ rsum_shape[input_rank - 1] = 1;
+ auto rsum_type = RankedTensorType::get(ArrayRef<int64_t>(rsum_shape),
+ output_type.getElementType());
+ // Keep dims so we don't need to reshape later
+ auto op2_reducesum_op1 = rewriter.create<tosa::ReduceSumOp>(
+ op->getLoc(), rsum_type, op1_exp_in.getResult(),
+ rewriter.getI64IntegerAttr(input_rank - 1));
+ auto op3_reciprocal_op2 = rewriter.create<tosa::ReciprocalOp>(
+ op->getLoc(), rsum_type, op2_reducesum_op1.getResult());
+
+ auto op4_mul_op1_op3 = rewriter.create<tosa::MulOp>(
+ op->getLoc(), output_type, op1_exp_in.getResult(),
+ op3_reciprocal_op2.getResult(), 0);
+
+ return rewriter.create<tosa::LogOp>(op->getLoc(), output_type,
+ op4_mul_op1_op3.getResult());
+}
+
+// Lowers SpaceToDepth to a sequence of TOSA ops. Supports NHWC.
+Operation* convertSpaceToDepthOp(PatternRewriter& rewriter, Operation* op,
+ Value result_value, Value input_value,
+ IntegerAttr block_size_attr,
+ StringAttr data_format) {
+ // NHWC lowering version:
+ // a2 = tf.reshape(a, [orig_shape[0], orig_shape[1]//b, b, orig_shape[2]//b,
+ // b, orig_shape[3]])
+ // a3 = tf.transpose(a2, [0, 1, 3, 2, 4, 5])
+ // a4 = tf.reshape(a3, [orig_shape[0], orig_shape[1]//b, orig_shape[2]//b,
+ // orig_shape[3]*b*b])
+ // return a4
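+ //
+ // For example (illustrative shapes only), input [1, 4, 4, 3] with b = 2:
+ // a2: [1, 2, 2, 2, 2, 3]
+ // a3: [1, 2, 2, 2, 2, 3] (perm [0, 1, 3, 2, 4, 5])
+ // a4: [1, 2, 2, 12]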
+ auto output_type = result_value.getType().dyn_cast<RankedTensorType>();
+
+ // Not a ranked tensor output.
+ if (!output_type) {
+ op->emitOpError("SpaceToDepth: output type not ranked tensor.");
+ return nullptr;
+ }
+
+ auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
+ if (!input_type) {
+ op->emitOpError("SpaceToDepth: input type not ranked tensor.");
+ return nullptr;
+ }
+
+ if (input_type.getRank() != 4) {
+ op->emitOpError("SpaceToDepth: input rank not 4.");
+ return nullptr;
+ }
+
+ auto input_shape = input_type.getShape();
+
+ if (!block_size_attr) { // This is a required parameter
+ op->emitOpError("SpaceToDepth: block size attribute not set.");
+ return nullptr;
+ }
+
+ SmallVector<int64_t, 2> block_size;
+ block_size.assign(2, block_size_attr.getInt());
+
+ if (!data_format) data_format = rewriter.getStringAttr("NHWC");
+
+ if (data_format.getValue().str() != "NHWC") {
+ op->emitOpError("SpaceToDepth: data format not NHWC.");
+ return nullptr;
+ }
+
+ assert(block_size[0] * block_size[1] != 0);
+
+ SmallVector<int64_t, 4> a_reshape_dims;
+ a_reshape_dims.push_back(input_shape[0]);
+ a_reshape_dims.push_back(input_shape[1] / block_size[0]);
+ a_reshape_dims.push_back(block_size[0]);
+ a_reshape_dims.push_back(input_shape[2] / block_size[1]);
+ a_reshape_dims.push_back(block_size[1]);
+ a_reshape_dims.push_back(input_shape[3]);
+
+ auto a_reshape_output_type = RankedTensorType::get(
+ ArrayRef<int64_t>(a_reshape_dims), output_type.getElementType());
+ auto a2_reshape_a_op = rewriter.create<tosa::ReshapeOp>(
+ op->getLoc(), a_reshape_output_type, input_value,
+ rewriter.getI64ArrayAttr(a_reshape_dims));
+
+ auto a3_transpose_perm = get1DConstTensor<tosa::ConstOp, int32_t>(
+ rewriter, op, {0, 1, 3, 2, 4, 5});
+
+ auto a3_transpose_a2_op = rewriter.create<tosa::TransposeOp>(
+ op->getLoc(), a_reshape_output_type, a2_reshape_a_op.getResult(),
+ a3_transpose_perm);
+
+ SmallVector<int64_t, 4> a3_reshape_dims;
+ a3_reshape_dims.push_back(input_shape[0]);
+ a3_reshape_dims.push_back(input_shape[1] / block_size[0]);
+ a3_reshape_dims.push_back(input_shape[2] / block_size[1]);
+ a3_reshape_dims.push_back(input_shape[3] * block_size[0] * block_size[1]);
+
+ auto a3_reshape_output_type = RankedTensorType::get(
+ ArrayRef<int64_t>(a3_reshape_dims), output_type.getElementType());
+ return rewriter.create<tosa::ReshapeOp>(
+ op->getLoc(), a3_reshape_output_type, a3_transpose_a2_op.getResult(),
+ rewriter.getI64ArrayAttr(a3_reshape_dims));
+}
+
+// Lowers DepthToSpace to a sequence of TOSA ops. Supports NHWC.
+Operation* convertDepthToSpaceOp(PatternRewriter& rewriter, Operation* op,
+ Value result_value, Value input_value,
+ IntegerAttr block_size_attr,
+ StringAttr data_format) {
+ // NHWC version
+ // a2 = tf.reshape(a, [orig_shape[0], orig_shape[1], orig_shape[2], b, b,
+ // orig_shape[3] // (b*b)])
+ // a3 = tf.transpose(a2, [0, 1, 3, 2, 4, 5])
+ // a4 = tf.reshape(a3, [orig_shape[0], orig_shape[1] * b, orig_shape[2] * b,
+ // orig_shape[3] // (b*b)])
+ // return a4
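+ //
+ // For example (illustrative shapes only), input [1, 2, 2, 12] with b = 2:
+ // a2: [1, 2, 2, 2, 2, 3]
+ // a3: [1, 2, 2, 2, 2, 3] (perm [0, 1, 3, 2, 4, 5])
+ // a4: [1, 4, 4, 3]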
+
+ auto output_type = result_value.getType().dyn_cast<RankedTensorType>();
+
+ // Not a ranked tensor output
+ if (!output_type) {
+ op->emitOpError("DepthToSpace: output type not ranked tensor.");
+ return nullptr;
+ }
+
+ auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
+ if (!input_type) {
+ op->emitOpError("DepthToSpace: input type not ranked tensor.");
+ return nullptr;
+ }
+
+ if (input_type.getRank() != 4) {
+ op->emitOpError("DepthToSpace: input rank not 4.");
+ return nullptr;
+ }
+ auto input_shape = input_type.getShape();
+
+ if (!block_size_attr) { // This is a required parameter
+ op->emitOpError("DepthToSpace: block size attribute not set.");
+ return nullptr;
+ }
+
+ SmallVector<int64_t, 2> block_size;
+ block_size.assign(2, block_size_attr.getInt());
+
+ if (!data_format) data_format = rewriter.getStringAttr("NHWC");
+ if (data_format.getValue().str() != "NHWC") {
+ op->emitOpError("DepthToSpace: data format not NHWC.");
+ return nullptr;
+ }
+
+ assert(block_size[0] * block_size[1] != 0);
+
+ SmallVector<int64_t, 4> a_reshape_dims;
+ a_reshape_dims.push_back(input_shape[0]);
+ a_reshape_dims.push_back(input_shape[1]);
+ a_reshape_dims.push_back(input_shape[2]);
+ a_reshape_dims.push_back(block_size[0]);
+ a_reshape_dims.push_back(block_size[1]);
+ a_reshape_dims.push_back(input_shape[3] / (block_size[0] * block_size[1]));
+
+ auto a_reshape_output_type = RankedTensorType::get(
+ ArrayRef<int64_t>(a_reshape_dims), output_type.getElementType());
+ auto a2_reshape_a_op = rewriter.create<tosa::ReshapeOp>(
+ op->getLoc(), a_reshape_output_type, input_value,
+ rewriter.getI64ArrayAttr(a_reshape_dims));
+
+ auto a3_transpose_perm = get1DConstTensor<tosa::ConstOp, int32_t>(
+ rewriter, op, {0, 1, 3, 2, 4, 5});
+
+ auto a3_transpose_a2_op = rewriter.create<tosa::TransposeOp>(
+ op->getLoc(), a_reshape_output_type, a2_reshape_a_op.getResult(),
+ a3_transpose_perm);
+
+ SmallVector<int64_t, 4> a3_reshape_dims;
+ a3_reshape_dims.push_back(input_shape[0]);
+ a3_reshape_dims.push_back(input_shape[1] * block_size[0]);
+ a3_reshape_dims.push_back(input_shape[2] * block_size[1]);
+ a3_reshape_dims.push_back(input_shape[3] / (block_size[0] * block_size[1]));
+
+ auto a3_reshape_output_type = RankedTensorType::get(
+ ArrayRef<int64_t>(a3_reshape_dims), output_type.getElementType());
+ return rewriter.create<tosa::ReshapeOp>(
+ op->getLoc(), a3_reshape_output_type, a3_transpose_a2_op.getResult(),
+ rewriter.getI64ArrayAttr(a3_reshape_dims));
+}
+
+// Lowers Split to a sequence of TOSA ops.
+Operation* convertSplitOp(PatternRewriter& rewriter, Operation* op,
+ Value result_value, Value input_value,
+ int32_t num_split, int32_t axis) {
+ // This lowering creates num_split slice ops and ties them together
+ // with IdentityN to get from an array of Operations to a single Operation
+ // with a list of result tensors.
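+ //
+ // For example (illustrative values only), an input of shape [4, 6] with
+ // num_split = 3 and axis = 1 produces three [4, 2] slices beginning at
+ // columns 0, 2 and 4.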
+ auto result_type = result_value.getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!result_type) {
+ op->emitOpError("Split: output type not ranked tensor.");
+ return nullptr;
+ }
+
+ auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
+ if (!input_type) {
+ op->emitOpError("Split: input type not ranked tensor.");
+ return nullptr;
+ }
+
+ auto input_shape = input_type.getShape();
+
+ SmallVector<Value, 4> results_vec;
+
+ assert(axis >= 0 && axis < input_shape.size());
+ assert((input_shape[axis] % num_split) == 0);
+ assert(num_split > 0);
+
+ int64_t slice_size = input_shape[axis] / num_split;
+
+ SmallVector<Type, 4>
+ outs_type_vec; // A list of the output types for each slice op
+
+ for (int i = 0; i < num_split; i++) {
+ // Each slice has a different beginning point.
+ // The slice size is the same for each op.
+ SmallVector<int64_t, 4> begin_vals, size_vals;
+
+ for (int j = 0; j < input_shape.size(); j++) {
+ if (j == axis) {
+ begin_vals.push_back(slice_size * i);
+ size_vals.push_back(slice_size);
+ } else {
+ begin_vals.push_back(0);
+ size_vals.push_back(input_shape[j]);
+ }
+ }
+
+ ArrayAttr begin = rewriter.getI64ArrayAttr(begin_vals);
+ ArrayAttr size = rewriter.getI64ArrayAttr(size_vals);
+
+ outs_type_vec.push_back(RankedTensorType::get(
+ ArrayRef<int64_t>(size_vals), result_type.getElementType()));
+
+ auto slice_op = rewriter.create<tosa::SliceOp>(
+ op->getLoc(),
+ RankedTensorType::get(ArrayRef<int64_t>(size_vals),
+ result_type.getElementType()),
+ input_value, begin, size);
+
+ results_vec.push_back(slice_op.getResult());
+ }
+
+ // Combine the sequence of tosa.slice() ops into a list
+ // using the IdentityN operator
+ return rewriter.create<tosa::IdentityNOp>(
+ op->getLoc(), ArrayRef<Type>(outs_type_vec), results_vec);
+}
+
+// Lowers SplitV to a sequence of TOSA ops.
+Operation* convertSplitVOp(PatternRewriter& rewriter, Operation* op,
+ Value result_value, Value input_value,
+ SmallVector<int32_t, 4>& size_split, int32_t axis) {
+ // This lowering creates num_split slice ops and ties them together
+ // with IdentityN to get from an array of Operations to a single Operation
+ // with a list of result tensors.
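+ //
+ // For example (illustrative values only), an input of shape [4, 6] with
+ // size_split = {1, 2, 3} and axis = 1 produces slices of shape [4, 1],
+ // [4, 2] and [4, 3] beginning at columns 0, 1 and 3.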
+ auto result_type = result_value.getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!result_type) {
+ op->emitOpError("SplitV: output type not ranked tensor.");
+ return nullptr;
+ }
+
+ auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
+ if (!input_type) {
+ op->emitOpError("SplitV: input type not ranked tensor.");
+ return nullptr;
+ }
+
+ auto input_shape = input_type.getShape();
+
+ SmallVector<Value, 4> results_vec;
+
+ assert(axis >= 0 && axis < input_shape.size());
+ int32_t size_split_sum = 0;
+ for (int i = 0; i < size_split.size(); i++) {
+ size_split_sum += size_split[i];
+ }
+
+ // The split sizes must sum up to the size of the axis being split
+ assert(size_split_sum == input_shape[axis]);
+
+ // Create num_split slice ops:
+ SmallVector<Type, 4>
+ outs_type_vec; // A list of the output types for each slice op
+
+ int32_t curr_split_start = 0;
+ for (int i = 0; i < size_split.size(); i++) {
+ // Each slice has a different beginning point.
+ // The slice size is different for each op.
+ SmallVector<int64_t, 4> begin_vals, size_vals;
+
+ for (int j = 0; j < input_shape.size(); j++) {
+ if (j == axis) {
+ begin_vals.push_back(curr_split_start);
+ size_vals.push_back(size_split[i]);
+ } else {
+ begin_vals.push_back(0);
+ size_vals.push_back(input_shape[j]);
+ }
+ }
+
+ ArrayAttr begin = rewriter.getI64ArrayAttr(begin_vals);
+ ArrayAttr size = rewriter.getI64ArrayAttr(size_vals);
+
+ outs_type_vec.push_back(RankedTensorType::get(
+ ArrayRef<int64_t>(size_vals), result_type.getElementType()));
+
+ auto slice_op = rewriter.create<tosa::SliceOp>(
+ op->getLoc(),
+ RankedTensorType::get(ArrayRef<int64_t>(size_vals),
+ result_type.getElementType()),
+ input_value, begin, size);
+
+ results_vec.push_back(slice_op.getResult());
+
+ // Next start position
+ curr_split_start += size_split[i];
+ }
+
+ // Combine the sequence of tosa.slice() ops into a list
+ // using the IdentityN operator
+ return rewriter.create<tosa::IdentityNOp>(
+ op->getLoc(), ArrayRef<Type>(outs_type_vec), results_vec);
+}
+
+// Lowers StridedSlice to a sequence of TOSA ops.
+Operation* convertStridedSliceOp(PatternRewriter& rewriter, Operation* op,
+ Value result_value, Value input_value,
+ Value begin_value, Value end_value,
+ Value strides_value, int32_t begin_mask,
+ int32_t end_mask, int32_t ellipsis_mask,
+ int32_t new_axis_mask,
+ int32_t shrink_axis_mask) {
+ // The mask arguments are bitmasks where bit [i] applies to
+ // dimension [i] of the input tensor.
+ //
+ // The rough algorithm for lowering strided slice is as follows:
+ //
+ // 0. Process begin/end masks, since they are basically syntactic sugar
+ // on top of the begin_value/end_value arrays
+ //
+ // 1. Slice1: Ignoring stride, slice the interesting range from the input
+ // tensor
+ //
+ // 2. Reshape2: Reshape the tensor from (1) such that each dimension with
+ // stride is split into two dimensions of size_i/stride_i, stride_i. A naive
+ // implementation doubles the input tensor rank, but only dimensions being
+ // strided actually need to be doubled.
+ //
+ // 3. Slice3: Slice the tensor from (2) such that we select index [0] from
+ // each of the stride_i dimensions in (2)
+ //
+ // 4. Reshape4: Reshape the tensor to eliminate the stride_i dimensions, add
+ // any dimensions in new_axis_mask and remove any dimensions in the
+ // shrink_axis_mask
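+ //
+ // Worked example (illustrative values only): a 1-D input of shape [13]
+ // with begin = [2], end = [10], strides = [2] and all masks zero:
+ // Slice1: begin = [2], size = [8] -> shape [8]
+ // Reshape2: shape [4, 2]
+ // Slice3: begin = [0, 0], size = [4, 1] -> shape [4, 1]
+ // Reshape4: shape [4], selecting input indices 2, 4, 6 and 8.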
+
+ // Limitations:
+ // This implementation only supports ellipsis_mask=0 for now.
+ // It does not support reverse strides yet; that will require inserting
+ // tosa.Reverse operators.
+ assert(ellipsis_mask == 0);
+
+ auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
+ auto result_type = result_value.getType().dyn_cast<RankedTensorType>();
+
+ if (!result_type) {
+ op->emitOpError("StridedSlice: output type not ranked tensor.");
+ return nullptr;
+ }
+
+ if (!input_type) {
+ op->emitOpError("StridedSlice: input type not ranked tensor.");
+ return nullptr;
+ }
+
+ int32_t input_rank = input_type.getRank();
+ auto input_shape = input_type.getShape();
+
+ // Extract the begin/end/stride tensors
+ SmallVector<int32_t, 4> begin, end, strides;
+
+ if (getVectorFromValue32(begin_value, begin) != input_rank) {
+ op->emitOpError("StridedSlice: begin doesn't match input_rank.");
+ return nullptr;
+ }
+ if (getVectorFromValue32(end_value, end) != input_rank) {
+ op->emitOpError("StridedSlice: end doesn't match input_rank.");
+ return nullptr;
+ }
+ if (getVectorFromValue32(strides_value, strides) != input_rank) {
+ op->emitOpError("StridedSlice: strides doesn't match input_rank.");
+ return nullptr;
+ }
+
+ SmallVector<int64_t, 2> a1_begin(input_rank), a1_size(input_rank);
+ SmallVector<int64_t, 2> a2_shape(input_rank * 2);
+ SmallVector<int64_t, 2> a3_begin(input_rank * 2), a3_size(input_rank * 2);
+ SmallVector<int64_t, 2> a4_shape;
+
+ // Step 0: Process the begin/end masks and build the begin/sizes for the
+ // first slice
+ int residual = 1;
+ (void)residual;
+ for (int i = 0; i < input_rank; i++) {
+ if (begin_mask & (1 << i)) begin[i] = 0;
+
+ if (end_mask & (1 << i)) end[i] = input_shape[i];
+
+ // Wrap around indices if begin or end is negative
+ if (begin[i] < 0) begin[i] += input_shape[i];
+
+ if (end[i] < 0) end[i] += input_shape[i];
+
+ // TODO: support reverse stride
+ a1_begin[i] = begin[i];
+ a1_size[i] = end[i] - begin[i];
+
+ a2_shape[i * 2 + 0] = a1_size[i] / strides[i];
+ a2_shape[i * 2 + 1] = strides[i];
+
+ a3_begin[i * 2 + 0] = 0;
+ a3_begin[i * 2 + 1] = 0;
+
+ if (shrink_axis_mask & (1 << i)) {
+ a3_size[i * 2 + 0] = 1;
+ } else {
+ a3_size[i * 2 + 0] = a1_size[i] / strides[i];
+ }
+ a3_size[i * 2 + 1] = 1;
+
+ if (!(shrink_axis_mask & (1 << i))) {
+ if (new_axis_mask & (1 << i)) a4_shape.push_back(1);
+ a4_shape.push_back((a1_size[i] / strides[i]));
+ }
+ }
+
+ // Make sure we didn't lose any dimensions from the shrink_axis_mask
+ assert(residual == 1);
+
+ // Step 1: Slice the input array
+ auto a1_slice_op = rewriter.create<tosa::SliceOp>(
+ op->getLoc(),
+ RankedTensorType::get(ArrayRef<int64_t>(a1_size),
+ input_type.getElementType()),
+ input_value, rewriter.getI64ArrayAttr(a1_begin),
+ rewriter.getI64ArrayAttr(a1_size));
+
+ // Step 2: reshape the sliced array
+ auto a2_reshape_op = rewriter.create<tosa::ReshapeOp>(
+ op->getLoc(),
+ RankedTensorType::get(ArrayRef<int64_t>(a2_shape),
+ input_type.getElementType()),
+ a1_slice_op.getResult(), rewriter.getI64ArrayAttr(a2_shape));
+
+ // Step 3: take a slice along the strides
+ auto a3_slice_op = rewriter.create<tosa::SliceOp>(
+ op->getLoc(),
+ RankedTensorType::get(ArrayRef<int64_t>(a3_size),
+ input_type.getElementType()),
+ a2_reshape_op.getResult(), rewriter.getI64ArrayAttr(a3_begin),
+ rewriter.getI64ArrayAttr(a3_size));
+
+ // Step 4: reshape the now-strided tensor
+ return rewriter.create<tosa::ReshapeOp>(
+ op->getLoc(),
+ RankedTensorType::get(ArrayRef<int64_t>(a4_shape),
+ input_type.getElementType()),
+ a3_slice_op.getResult(), rewriter.getI64ArrayAttr(a4_shape));
+}
+
+// Lowers FloorDiv to a sequence of TOSA operators.
+Operation* convertFloorDivOp(PatternRewriter& rewriter, Operation* op,
+ Value result_value, Value lhs_value,
+ Value rhs_value) {
+ // FloorDiv lowering:
+ // floor(1/rhs * lhs)
+ //
+ // a1 = reciprocal(rhs);
+ // a2 = mul(lhs, a1);
+ // a3 = floor(a2);
+ // return a3;
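+ //
+ // For example, floordiv(7.0, 2.0) = floor(0.5 * 7.0) = floor(3.5) = 3.0.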
+ auto output_type = result_value.getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return nullptr;
+
+ auto a1_reciprocal_rhs_op =
+ rewriter.create<tosa::ReciprocalOp>(op->getLoc(), output_type, rhs_value);
+ auto a2_mul_lhs_a1_op =
+ rewriter.create<tosa::MulOp>(op->getLoc(), output_type, lhs_value,
+ a1_reciprocal_rhs_op.getResult(), 0);
+ return rewriter.create<tosa::FloorOp>(op->getLoc(), output_type,
+ a2_mul_lhs_a1_op.getResult());
+}
+
+// Lowers FloorMod to a sequence of TOSA operators.
+Operation* convertFloorModOp(PatternRewriter& rewriter, Operation* op,
+ Value result_value, Value lhs_value,
+ Value rhs_value) {
+ // FloorMod lowering:
+ // (1/rhs * lhs) - floor(1/rhs * lhs)
+ // a1 = reciprocal(rhs);
+ // a2 = mul(lhs, a1);
+ // a3 = floor(a2);
+ // a4 = sub(a2, a3);
+ // return a4;
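+ //
+ // For example, lhs = 7.0 and rhs = 2.0 yields 3.5 - floor(3.5) = 0.5,
+ // i.e. the fractional part of the quotient.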
+
+ auto output_type = result_value.getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return nullptr;
+
+ auto a1_reciprocal_rhs_op =
+ rewriter.create<tosa::ReciprocalOp>(op->getLoc(), output_type, rhs_value);
+ auto a2_mul_lhs_a1_op =
+ rewriter.create<tosa::MulOp>(op->getLoc(), output_type, lhs_value,
+ a1_reciprocal_rhs_op.getResult(), 0);
+ auto a3_floor_a2_op = rewriter.create<tosa::FloorOp>(
+ op->getLoc(), output_type, a2_mul_lhs_a1_op.getResult());
+ return rewriter.create<tosa::SubOp>(op->getLoc(), output_type,
+ a2_mul_lhs_a1_op.getResult(),
+ a3_floor_a2_op.getResult());
+}
+
+// Lowers FusedActivation to a sequence of TOSA ops.
+Operation* convertFusedActivation(PatternRewriter& rewriter, Operation* op,
+ Value input_value,
+ StringAttr fused_activation_fn) {
+ auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
+ if (!input_type) return nullptr;
+
+ bool input_is_qtype =
+ input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+
+ if (input_is_qtype) {
+ auto input_qtype =
+ input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
+
+ if (fused_activation_fn.getValue() == "TANH") {
+ // TODO: implement with TABLE
+ op->emitWarning("Quantized TANH lowering TBD!");
+ return nullptr;
+ } else {
+ auto rescale_type = RankedTensorType::get(input_type.getShape(),
+ rewriter.getIntegerType(32));
+
+ auto op1_rescale_in = buildRescaleToInt32(rewriter, op, input_value, 1.0f,
+ input_qtype.getZeroPoint());
+
+ Value op2_relu_op1;
+ if (fused_activation_fn.getValue() == "RELU") {
+ auto relu_op = rewriter.create<tosa::ReluNOp>(
+ op->getLoc(), rescale_type, op1_rescale_in,
+ rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max()),
+ rewriter.getF32FloatAttr(0));
+
+ op2_relu_op1 = relu_op.getResult();
+
+ } else if (fused_activation_fn.getValue() == "RELU6") {
+ int64_t rescaled_6 = std::llround(6.0f / input_qtype.getScale()) +
+ input_qtype.getZeroPoint();
+
+ auto relu_op = rewriter.create<tosa::ReluNOp>(
+ op->getLoc(), rescale_type, op1_rescale_in,
+ rewriter.getI64IntegerAttr(rescaled_6),
+ rewriter.getF32FloatAttr(0.0f));
+
+ op2_relu_op1 = relu_op.getResult();
+
+ } else if (fused_activation_fn.getValue() == "RELU_N1_TO_1") {
+ int64_t rescaled_n1 = std::llround(-1.0f / input_qtype.getScale()) +
+ input_qtype.getZeroPoint();
+ int64_t rescaled_1 = std::llround(1.0f / input_qtype.getScale()) +
+ input_qtype.getZeroPoint();
+
+ auto relu_op = rewriter.create<tosa::ClampOp>(
+ op->getLoc(), rescale_type, op1_rescale_in,
+ rewriter.getI64IntegerAttr(rescaled_n1),
+ rewriter.getI64IntegerAttr(rescaled_1),
+ rewriter.getF32FloatAttr(0.0f), rewriter.getF32FloatAttr(0.0f));
+
+ op2_relu_op1 = relu_op.getResult();
+ } else {
+ return nullptr;
+ }
+
+ auto op3_rescale_op2 =
+ buildRescaleFromInt32(rewriter, op, input_type, op2_relu_op1, 1.0f,
+ input_qtype.getZeroPoint());
+
+ return op3_rescale_op2.getDefiningOp();
+ }
+ } else {
+ if (fused_activation_fn.getValue() == "RELU") {
+ return rewriter.create<tosa::ReluNOp>(
+ op->getLoc(), input_type, input_value,
+ rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max()),
+ rewriter.getF32FloatAttr(std::numeric_limits<float>::max()));
+ } else if (fused_activation_fn.getValue() == "RELU6") {
+ return rewriter.create<tosa::ReluNOp>(
+ op->getLoc(), input_type, input_value, rewriter.getI64IntegerAttr(6),
+ rewriter.getF32FloatAttr(6.0));
+ } else if (fused_activation_fn.getValue() == "RELU_N1_TO_1") {
+ return rewriter.create<tosa::ClampOp>(
+ op->getLoc(), input_type, input_value, rewriter.getI64IntegerAttr(-1),
+ rewriter.getI64IntegerAttr(1), rewriter.getF32FloatAttr(-1.0),
+ rewriter.getF32FloatAttr(1.0));
+ } else if (fused_activation_fn.getValue() == "TANH") {
+ return rewriter.create<tosa::TanhOp>(op->getLoc(), input_type,
+ input_value);
+ } else {
+ // Unsupported activation type. Bail out.
+ return nullptr;
+ }
+ }
+
+ return nullptr;
+}
+
+// Common function for lowering reduce operations to TOSA ops.
+template <typename T>
+Value convertReduceOpCommon(PatternRewriter& rewriter, Operation* op,
+ RankedTensorType output_type, Value input_value,
+ ElementsAttr axes_elems, bool keep_dims,
+ Type reduce_element_type, bool is_quantized,
+ double input_scale, int64_t input_zp,
+ double output_scale, int64_t output_zp) {
+ auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
+ if (!input_type) return nullptr;
+
+ ArrayRef<int64_t> input_shape = input_type.getShape();
+ ArrayRef<int64_t> output_shape = output_type.getShape();
+ auto input_rank = input_shape.size();
+ Value val = input_value;
+
+ if (axes_elems.getNumElements() == 0) {
+ // No axes means return the original tensor.
+ auto identity_op =
+ rewriter.create<tosa::IdentityOp>(op->getLoc(), output_type, val);
+ val = identity_op.getResult();
+ } else {
+ // Reduce along each axis
+ SmallVector<int64_t, 4> shape_vec(input_shape.begin(), input_shape.end());
+
+ if (is_quantized) {
+ val = buildRescaleToInt32(rewriter, op, val, input_scale, input_zp);
+ }
+
+ for (int i = 0; i < axes_elems.getNumElements(); i++) {
+ int64_t axis_val = axes_elems.getValue<IntegerAttr>(i).getInt();
+ if (axis_val < 0) axis_val += input_rank;
+ auto axis_attr = rewriter.getI64IntegerAttr(axis_val);
+
+ shape_vec[axis_val] = 1;
+ auto reduce_type = RankedTensorType::get(
+ llvm::makeArrayRef<int64_t>(shape_vec), reduce_element_type);
+
+ auto reduce_op =
+ rewriter.create<T>(op->getLoc(), reduce_type, val, axis_attr);
+
+ val = reduce_op.getResult();
+ }
+
+ if (is_quantized) {
+ auto output_rescale_type = RankedTensorType::get(
+ llvm::makeArrayRef<int64_t>(shape_vec), output_type.getElementType());
+ val = buildRescaleFromInt32(rewriter, op, output_rescale_type, val,
+ output_scale, output_zp);
+ }
+
+ // Optionally squeeze out the reduced axes.
+ if (!keep_dims) {
+ auto reshape_op = rewriter.create<tosa::ReshapeOp>(
+ op->getLoc(), output_type, val,
+ rewriter.getI64ArrayAttr(output_shape));
+ val = reshape_op.getResult();
+ }
+ }
+
+ return val;
+}
+
+// Lowers ReduceAll to a sequence of TOSA ops.
+Operation* convertReduceAllOp(PatternRewriter& rewriter, Operation* op,
+ RankedTensorType output_type, Value input_value,
+ ElementsAttr axes_elems, bool keep_dims) {
+ auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
+ if (!input_type) return nullptr;
+
+ Value val = convertReduceOpCommon<tosa::ReduceAllOp>(
+ rewriter, op, output_type, input_value, axes_elems, keep_dims,
+ output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
+
+ return val.getDefiningOp();
+}
+
+// Lowers ReduceAny to a sequence of TOSA ops.
+Operation* convertReduceAnyOp(PatternRewriter& rewriter, Operation* op,
+ RankedTensorType output_type, Value input_value,
+ ElementsAttr axes_elems, bool keep_dims) {
+ auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
+ if (!input_type) return nullptr;
+
+ Value val = convertReduceOpCommon<tosa::ReduceAnyOp>(
+ rewriter, op, output_type, input_value, axes_elems, keep_dims,
+ output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
+
+ return val.getDefiningOp();
+}
+
+// Lowers ReduceMin to a sequence of TOSA ops.
+Operation* convertReduceMinOp(PatternRewriter& rewriter, Operation* op,
+ RankedTensorType output_type, Value input_value,
+ ElementsAttr axes_elems, bool keep_dims) {
+ auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
+ if (!input_type) return nullptr;
+
+ Value val = convertReduceOpCommon<tosa::ReduceMinOp>(
+ rewriter, op, output_type, input_value, axes_elems, keep_dims,
+ output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
+
+ return val.getDefiningOp();
+}
+
+// Lowers ReduceMax to a sequence of TOSA ops.
+Operation* convertReduceMaxOp(PatternRewriter& rewriter, Operation* op,
+ RankedTensorType output_type, Value input_value,
+ ElementsAttr axes_elems, bool keep_dims) {
+ auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
+ if (!input_type) return nullptr;
+
+ Value val = convertReduceOpCommon<tosa::ReduceMaxOp>(
+ rewriter, op, output_type, input_value, axes_elems, keep_dims,
+ output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
+
+ return val.getDefiningOp();
+}
+
+// Lowers ReduceProd to a sequence of TOSA ops.
+Operation* convertReduceProdOp(PatternRewriter& rewriter, Operation* op,
+ RankedTensorType output_type, Value input_value,
+ ElementsAttr axes_elems, bool keep_dims) {
+ auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
+ if (!input_type) return nullptr;
+
+ bool input_is_qtype =
+ input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+ bool output_is_qtype =
+ output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+
+ if (input_is_qtype || output_is_qtype) {
+ op->emitOpError(
+ "ConvertReduceProdOp: input/output tensor should "
+ "be all floating-point.");
+ return nullptr;
+ }
+
+ Value val = convertReduceOpCommon<tosa::ReduceProdOp>(
+ rewriter, op, output_type, input_value, axes_elems, keep_dims,
+ output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
+
+ return val.getDefiningOp();
+}
+
+// Lowers ReduceSum to a sequence of TOSA ops.
+Operation* convertReduceSumOp(PatternRewriter& rewriter, Operation* op,
+ RankedTensorType output_type, Value input_value,
+ ElementsAttr axes_elems, bool keep_dims) {
+ auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
+ if (!input_type) return nullptr;
+
+ bool input_is_qtype =
+ input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+ bool output_is_qtype =
+ output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+
+ if (input_is_qtype != output_is_qtype) {
+ op->emitOpError(
+ "ConvertReduceSumOp: input/output tensor should "
+ "be all quantized or all floating-point.");
+ return nullptr;
+ }
+
+ double input_scale = 1.0f;
+ double output_scale = 1.0f;
+ int64_t input_zp = 0;
+ int64_t output_zp = 0;
+ Type reduce_element_type = input_type.getElementType();
+
+ if (input_is_qtype) {
+ auto input_qtype =
+ input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
+ auto output_qtype =
+ output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
+
+ int32_t input_shift = 20;
+
+ input_scale =
+ static_cast<double>(1 << input_shift) * input_qtype.getScale();
+ output_scale =
+ 1.0 / (output_qtype.getScale() * static_cast<double>(1 << input_shift));
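+ // The 2^input_shift factor cancels between input_scale and output_scale,
+ // so the end-to-end scaling is in_scale / out_scale; the shift just
+ // carries 20 extra fractional bits through the i32 accumulation.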
+
+ input_zp = input_qtype.getZeroPoint();
+ output_zp = output_qtype.getZeroPoint();
+ reduce_element_type = rewriter.getI32Type();
+ }
+
+ Value val = convertReduceOpCommon<tosa::ReduceSumOp>(
+ rewriter, op, output_type, input_value, axes_elems, keep_dims,
+ reduce_element_type, input_is_qtype, input_scale, input_zp, output_scale,
+ output_zp);
+
+ return val.getDefiningOp();
+}
+
+// Lowers ReduceMean to a sequence of TOSA ops.
+Operation* convertReduceMeanOp(PatternRewriter& rewriter, Operation* op,
+ RankedTensorType output_type, Value input_value,
+ ElementsAttr axes_elems, bool keep_dims) {
+ // reduce_mean is lowered as follows:
+ // op1 = reduce_sum(input)
+ // op2 = mul(op1, 1.0 / num_elements_on_reduced_axis)
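+ //
+ // For example (illustrative shapes only), an input of shape [2, 3, 4]
+ // reduced over axes {1, 2} has num_elements_on_reduced_axis = 12, so op2
+ // multiplies by 1.0 / 12.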
+
+ auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
+ if (!input_type) return nullptr;
+
+ bool input_is_qtype =
+ input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+ bool output_is_qtype =
+ output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+
+ if (input_is_qtype != output_is_qtype) {
+ op->emitOpError(
+ "ConvertReduceSumOp: input/output tensor should "
+ "be all quantized or all floating-point.");
+ return nullptr;
+ }
+
+ // Non-quantized mean() only supports floating-point element types.
+ if (!input_is_qtype && !output_type.getElementType().isa<mlir::FloatType>()) {
+ op->emitWarning(
+ "Failed convertReduceMean: input unquantized type but output element "
+ "not FloatType!");
+ return nullptr;
+ }
+
+ int64_t input_rank = input_type.getRank();
+ int64_t num_elems_on_reduced_axis = 1;
+ for (int i = 0; i < axes_elems.getNumElements(); i++) {
+ int64_t axis_val = axes_elems.getValue<IntegerAttr>(i).getInt();
+ if (axis_val < 0) axis_val += input_rank;
+ num_elems_on_reduced_axis *= input_type.getShape()[axis_val];
+ }
+ double div_scale = 1.0 / static_cast<double>(num_elems_on_reduced_axis);
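+  // For example (illustrative shapes): reducing a [2, 3, 4] input over
+  // axes {1, 2} gives num_elems_on_reduced_axis = 3 * 4 = 12, so
+  // div_scale = 1.0 / 12.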
+
+ double input_scale = 1.0f;
+ double output_scale = 1.0f;
+ int64_t input_zp = 0;
+ int64_t output_zp = 0;
+ Type reduce_element_type = input_type.getElementType();
+
+ if (input_is_qtype) {
+ auto input_qtype =
+ input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
+ auto output_qtype =
+ output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
+
+ int32_t input_shift = 20;
+
+ input_scale =
+ static_cast<double>(1 << input_shift) * input_qtype.getScale();
+ output_scale = div_scale / (output_qtype.getScale() *
+ static_cast<double>(1 << input_shift));
+
+ input_zp = input_qtype.getZeroPoint();
+ output_zp = output_qtype.getZeroPoint();
+ reduce_element_type = rewriter.getI32Type();
+ }
+
+ Value val = convertReduceOpCommon<tosa::ReduceSumOp>(
+ rewriter, op, output_type, input_value, axes_elems, keep_dims,
+ reduce_element_type, input_is_qtype, input_scale, input_zp, output_scale,
+ output_zp);
+
+ if (!input_is_qtype) {
+ Value div_const = getTosaConstTensorSingleF32(rewriter, op, div_scale);
+ auto mul_op = rewriter.create<tosa::MulOp>(op->getLoc(), output_type, val,
+ div_const, 0);
+ val = mul_op.getResult();
+ }
+
+ return val.getDefiningOp();
+}
+
+// Lowers ResizeBilinear and ResizeNearestNeighbor to TOSA resize.
+Operation* convertResizeOp(PatternRewriter& rewriter, Operation* op,
+ RankedTensorType output_type, Value input_value,
+ StringRef mode) {
+ auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
+ if (!input_type) return nullptr;
+
+ auto input_shape = input_type.getShape();
+ auto output_shape = output_type.getShape();
+
+ size_t input_height = input_shape[1];
+ size_t input_width = input_shape[2];
+ size_t output_height = output_shape[1];
+ size_t output_width = output_shape[2];
+
+ bool input_is_qtype =
+ input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+ bool output_is_qtype =
+ output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+
+ if (input_is_qtype != output_is_qtype) {
+ op->emitOpError(
+ "ConvertResizeOp: input/output tensor should "
+ "be all quantized or all floating-point.");
+ return nullptr;
+ }
+
+ if (!input_is_qtype) {
+ // TODO: support float type
+ op->emitOpError("ConvertResizeOp: floating-point type not supported yet ");
+ return nullptr;
+ }
+
+ int32_t shift = 11; // Set default shift to maximum allowed
+
+ double frac_y =
+ static_cast<double>(output_height) / static_cast<double>(input_height);
+ double frac_x =
+ static_cast<double>(output_width) / static_cast<double>(input_width);
+ int32_t stride_y = std::lround(frac_y * static_cast<double>(1 << shift));
+ int32_t stride_x = std::lround(frac_x * static_cast<double>(1 << shift));
+
+  // The stride must fit in int16; reduce the shift until it does
+ while (stride_y >= 32768 || stride_x >= 32768) {
+ shift--;
+ stride_y = std::lround(frac_y * static_cast<double>(1 << shift));
+ stride_x = std::lround(frac_x * static_cast<double>(1 << shift));
+ }
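+  // For example (illustrative shapes): resizing 4x4 to 8x8 gives
+  // frac = 2.0 and stride = round(2.0 * 2^11) = 4096, which fits in
+  // int16, so the shift stays at 11.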
+
+ ArrayAttr output_size =
+ rewriter.getI64ArrayAttr({static_cast<int64_t>(output_height),
+ static_cast<int64_t>(output_width)});
+ ArrayAttr stride = rewriter.getI64ArrayAttr({stride_y, stride_x});
+ ArrayAttr offset = rewriter.getI64ArrayAttr({0, 0});
+ IntegerAttr shift_attr = rewriter.getI32IntegerAttr(shift);
+ StringAttr resize_mode = rewriter.getStringAttr(mode.str());
+
+ return rewriter.create<tosa::ResizeOp>(op->getLoc(), output_type, input_value,
+ output_size, stride, offset,
+ shift_attr, resize_mode);
+}
+
+// Lowers Quantize to a sequence of TOSA quantization ops.
+Operation* convertQuantizeOp(PatternRewriter& rewriter, Operation* op,
+ RankedTensorType output_type, Value input_value,
+ double scale, int64_t zeropoint) {
+ auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
+ if (!input_type) return nullptr;
+
+ auto output_shape = output_type.getShape();
+ auto output_element_type = output_type.getElementType();
+
+  // The output element type must be a quantized integer type
+ if (!output_element_type.isa<mlir::quant::QuantizedType>()) {
+ op->emitWarning(
+ "Lowering quantizeOp but output element type not quantized!");
+ return nullptr;
+ }
+
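+  // The lowering computes cast<int32>(input * scale + zeropoint), then
+  // rescales the int32 result into the quantized output type:
+  //   op1 = mul(input, scale)    (float)
+  //   op2 = add(op1, zeropoint)  (float)
+  //   op3 = cast(op2)            (int32)
+  //   op4 = rescale(op3)         (quantized output type)
+  // Callers pass the multiplier here: convertFakeQuantOp passes 1.0 / scale.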
+ auto output_fp_type =
+ RankedTensorType::get(output_shape, rewriter.getF32Type());
+
+ Value zp_val =
+ getTosaConstTensorSingleF32(rewriter, op, static_cast<float>(zeropoint));
+
+ auto op1_mul_in = rewriter.create<tosa::MulOp>(
+ op->getLoc(), output_fp_type, input_value,
+ getTosaConstTensorSingleF32(rewriter, op, static_cast<float>(scale)), 0);
+
+ auto op2_add_op1 = rewriter.create<tosa::AddOp>(
+ op->getLoc(), output_fp_type, op1_mul_in.getResult(), zp_val);
+
+ // TOSA doesn't support CAST FLOAT->AINT8, need to CAST to INT32
+ // followed by a RESCALE
+ RankedTensorType output_int32_type =
+ RankedTensorType::get(output_shape, rewriter.getI32Type());
+
+ auto op3_cast_op2 = rewriter.create<tosa::CastOp>(
+ op->getLoc(), output_int32_type, op2_add_op1.getResult());
+
+ auto op4_rescale_op3 = buildRescale(rewriter, op, output_type,
+ op3_cast_op2.getResult(), 1.0, 0, 0);
+
+ return op4_rescale_op3.getDefiningOp();
+}
+
+// Lowers Dequantize to a sequence of TOSA dequantization ops.
+Operation* convertDequantizeOp(PatternRewriter& rewriter, Operation* op,
+ RankedTensorType output_type, Value input_value,
+ double scale, int64_t zeropoint) {
+ auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
+ if (!input_type) return nullptr;
+
+  // The input element type must be a quantized integer type
+ if (!input_type.getElementType().isa<mlir::quant::QuantizedType>())
+ return nullptr;
+
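+  // The lowering computes (cast<float>(input) - zeropoint) * scale:
+  //   op1 = rescale(input)       (to int32, unit scale)
+  //   op2 = cast(op1)            (to float)
+  //   op3 = sub(op2, zeropoint)
+  //   op4 = mul(op3, scale)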
+ auto output_shape = output_type.getShape();
+
+ RankedTensorType output_int32_type =
+ RankedTensorType::get(output_shape, rewriter.getI32Type());
+
+ Value zp_val =
+ getTosaConstTensorSingleF32(rewriter, op, static_cast<float>(zeropoint));
+
+ // TOSA doesn't support CAST AINT8 -> FLOAT, need to RESCALE to INT32
+ // followed by a CAST
+ auto op1_rescale_in =
+ buildRescale(rewriter, op, output_int32_type, input_value, 1.0, 0, 0);
+
+ auto op2_cast_op1 =
+ rewriter.create<tosa::CastOp>(op->getLoc(), output_type, op1_rescale_in);
+
+ auto op3_sub_op2 = rewriter.create<tosa::SubOp>(
+ op->getLoc(), output_type, op2_cast_op1.getResult(), zp_val);
+
+ return rewriter.create<tosa::MulOp>(
+ op->getLoc(), output_type, op3_sub_op2.getResult(),
+ getTosaConstTensorSingleF32(rewriter, op, static_cast<float>(scale)), 0);
+}
+
+// Lowers FakeQuant to a sequence of TOSA quantization ops.
+Operation* convertFakeQuantOp(PatternRewriter& rewriter, Operation* op,
+ RankedTensorType output_type, Value input_value,
+ double min, double max, int64_t num_bits,
+ bool narrow_range) {
+  // FakeQuant is lowered as follows:
+ // op1 = quantize(input)
+ // op2 = dequantize(op1)
+
+ auto input_type = input_value.getType().dyn_cast<RankedTensorType>();
+ if (!input_type) return nullptr;
+
+  // Quantized as INT<num_bits>, where num_bits can only be 8 or 16
+ if (num_bits != 8 && num_bits != 16) {
+ op->emitWarning("FakeQuantOp lowering handles only 8 and 16 for num_bits!");
+ return nullptr;
+ }
+
+ auto output_shape = output_type.getShape();
+
+ int64_t qmax = (1L << (num_bits - 1)) - 1;
+ int64_t qmin = -(1L << (num_bits - 1));
+ if (narrow_range) {
+ qmin += 1;
+ }
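+  // For example, num_bits = 8 gives [qmin, qmax] = [-128, 127], or
+  // [-127, 127] with narrow_range.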
+
+ auto int_element_qtype = mlir::quant::UniformQuantizedType::get(
+ true, rewriter.getIntegerType(num_bits), rewriter.getF32Type(), 1.0f, 0,
+ qmin, qmax);
+ auto output_int_type = RankedTensorType::get(output_shape, int_element_qtype);
+
+ double scale = (max - min) / static_cast<double>(qmax - qmin);
+ int64_t zeropoint = std::llround((-min) / scale + static_cast<double>(qmin));
+
+ // Quantize: round(x / scale + zeropoint)
+ auto quantized_op = convertQuantizeOp(rewriter, op, output_int_type,
+ input_value, 1.0 / scale, zeropoint);
+
+ // Dequantize: ((float)x - zeropoint) * scale
+ return convertDequantizeOp(rewriter, op, output_type,
+ quantized_op->getResult(0), scale, zeropoint);
+}
+
+Operation* convertTFConv2DCommon(
+ PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
+ Value input, Value filter, Value bias, ArrayAttr strides_attr,
+ ArrayAttr dilations_attr, ArrayAttr explicit_padding_attr,
+ StringRef padding_ref, StringRef data_format_ref) {
+ auto input_type = input.getType().dyn_cast<RankedTensorType>();
+ auto filter_type = filter.getType().dyn_cast<RankedTensorType>();
+  // Input and filter must be ranked tensors
+ if (!input_type) return nullptr;
+ if (!filter_type) return nullptr;
+
+  // Only NHWC is supported for now; check before creating any new ops.
+  if (data_format_ref.str() != "NHWC") {
+    op->emitWarning("convertTFConv2DCommon only supports NHWC!");
+    return nullptr;
+  }
+
+  // Transpose [H, W, I, O] to [O, H, W, I]
+  auto filter_shape = filter_type.getShape();
+  SmallVector<int64_t, 4> a1_transpose_dims;
+  a1_transpose_dims.push_back(filter_shape[3]);
+  a1_transpose_dims.push_back(filter_shape[0]);
+  a1_transpose_dims.push_back(filter_shape[1]);
+  a1_transpose_dims.push_back(filter_shape[2]);
+  auto a1_filter_transpose_perm =
+      get1DConstTensor<tosa::ConstOp, int32_t>(rewriter, op, {3, 0, 1, 2});
+  auto a1_filter_transpose_op = rewriter.create<tosa::TransposeOp>(
+      op->getLoc(),
+      RankedTensorType::get(ArrayRef<int64_t>(a1_transpose_dims),
+                            filter_type.getElementType()),
+      filter, a1_filter_transpose_perm);
+
+ ArrayAttr stride;
+ ArrayAttr dilation;
+ ArrayAttr pad;
+ {
+ if (!strides_attr) {
+ stride = rewriter.getI64ArrayAttr({1, 1});
+ } else {
+ // Note: hardcoded to NHWC for now
+ int64_t stride_h = strides_attr[1].cast<IntegerAttr>().getInt();
+ int64_t stride_w = strides_attr[2].cast<IntegerAttr>().getInt();
+ stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
+ }
+ }
+ {
+ if (!dilations_attr) {
+ dilation = rewriter.getI64ArrayAttr({1, 1});
+ } else {
+ // Note: hardcoded to NHWC for now
+ int64_t dilation_h = dilations_attr[1].cast<IntegerAttr>().getInt();
+ int64_t dilation_w = dilations_attr[2].cast<IntegerAttr>().getInt();
+ dilation = rewriter.getI64ArrayAttr({dilation_h, dilation_w});
+ }
+ }
+ {
+ tensorflow::Padding tf_pad;
+ if (!GetPaddingFromString(padding_ref.str(), &tf_pad).ok()) {
+ op->emitWarning("Could not get padding data from padding string term!");
+ return nullptr;
+ }
+
+ tensorflow::TensorFormat data_format_tf;
+ if (!FormatFromString(data_format_ref.str(), &data_format_tf))
+ return nullptr;
+
+ if (tf_pad == tensorflow::Padding::EXPLICIT) {
+ pad = getPaddingValuesFromExplicitPadAttr(explicit_padding_attr,
+ data_format_tf, rewriter);
+ } else {
+ if (!getPaddingValuesFromPadType(tf_pad, data_format_tf,
+ 0, // tensorflow::FORMAT_HWIO
+ input_type, filter_type, stride,
+ dilation, rewriter, pad))
+ return nullptr;
+ }
+ }
+
+ return rewriter.create<tosa::Conv2DOp>(op->getLoc(), output_type, input,
+ a1_filter_transpose_op.getResult(),
+ bias, pad, stride, dilation);
+}
+
+}  // namespace tosa
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h
new file mode 100644
index 0000000..85e2be7
--- /dev/null
+++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h
@@ -0,0 +1,242 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_LEGALIZE_COMMON_H
+#define TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_LEGALIZE_COMMON_H
+
+// This file contains legalizations common to mapping both TensorFlow and
+// TensorFlow Lite to TOSA.
+//
+// Conversion functions return nullptr on a lowering failure, or the lowered
+// operator on success. Callers must check for nullptr and return a
+// LogicalResult failure in that case.
+//
+// For these functions, the framework-specific operands/attributes/defaults
+// are already extracted and placed in a common form for lowering.
+#include "mlir/Dialect/Quant/FakeQuantSupport.h"
+#include "mlir/Dialect/Quant/UniformSupport.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "llvm/Support/FormatVariadic.h"
+
+namespace mlir {
+namespace tosa {
+
+// Lowers the Pack operator to TOSA.
+Operation* convertPackOp(PatternRewriter& rewriter, Operation* op,
+ Value result_value, SmallVector<Value, 8>& inputs,
+ int32_t axis);
+
+// Lowers the Unpack operator to TOSA.
+Operation* convertUnpackOp(PatternRewriter& rewriter, Operation* op,
+ Value input_value, int32_t axis);
+
+// Lowers the Select operator to TOSA.
+Operation* convertSelectOp(PatternRewriter& rewriter, Operation* op,
+ Value result_value, Value condition_value,
+ Value x_value, Value y_value);
+
+// Lowers the ZerosLike operator to TOSA by creating a constant
+// of the desired type and shape.
+Operation* convertZerosLikeOp(PatternRewriter& rewriter, Operation* op,
+ Value result, Value input);
+
+// Lowers the Mul operator to TOSA. For quantized types, this requires
+// inserting rescale operators before and after the operation.
+Operation* convertMultiplyOp(PatternRewriter& rewriter, Operation* op,
+ Value output_val, Value input_lhs_val,
+ Value input_rhs_val);
+
+// Lowers the SquaredDifference operator to TOSA.
+Operation* convertSquaredDifferenceOp(PatternRewriter& rewriter, Operation* op,
+ Value result, Value x, Value y);
+
+// Lowers the Round operator to TOSA.
+Operation* convertRoundOp(PatternRewriter& rewriter, Operation* op,
+ Value result, Value input);
+
+// Lowers ConcatV2 to TOSA.
+Operation* convertConcatV2Op(PatternRewriter& rewriter, Operation* op,
+ Value result_value, SmallVector<Value, 8>& values,
+ int32_t axis);
+
+// Lowers SpaceToBatchND to TOSA.
+Operation* convertSpaceToBatchNDOp(PatternRewriter& rewriter, Operation* op,
+ Value result_value, Value input_value,
+ Value block_shape_value,
+ Value paddings_value);
+
+// Lowers BatchToSpaceND to TOSA.
+Operation* convertBatchToSpaceNDOp(PatternRewriter& rewriter, Operation* op,
+ Value result_value, Value input_value,
+ Value block_shape_value, Value crops_value);
+
+// Lowers ExpandDims to TOSA.
+Operation* convertExpandDimsOp(PatternRewriter& rewriter, Operation* op,
+ Value result_value, Value input_value,
+ Value dim_value);
+
+// Lowers Squeeze to TOSA.
+Operation* convertSqueezeOp(PatternRewriter& rewriter, Operation* op,
+ Value result_value, Value input_value,
+ SmallVector<int32_t, 8>& squeeze_dims);
+
+// Lowers ELU to a sequence of TOSA ops.
+Operation* convertEluOp(PatternRewriter& rewriter, Operation* op,
+ Value result_value, Value features_value);
+
+// Lowers Softmax to a sequence of TOSA ops.
+Operation* convertSoftmaxOp(PatternRewriter& rewriter, Operation* op,
+ Value result_value, Value logits_value);
+
+// Lowers LogSoftmax to a sequence of TOSA ops.
+Operation* convertLogSoftmaxOp(PatternRewriter& rewriter, Operation* op,
+ Value result_value, Value logits_value);
+
+// Lowers SpaceToDepth to a sequence of TOSA ops. Supports NHWC.
+Operation* convertSpaceToDepthOp(PatternRewriter& rewriter, Operation* op,
+ Value result_value, Value input_value,
+ IntegerAttr block_size_attr,
+ StringAttr data_format);
+
+// Lowers DepthToSpace to a sequence of TOSA ops. Supports NHWC.
+Operation* convertDepthToSpaceOp(PatternRewriter& rewriter, Operation* op,
+ Value result_value, Value input_value,
+ IntegerAttr block_size_attr,
+ StringAttr data_format);
+
+// Lowers Split to a sequence of TOSA ops.
+Operation* convertSplitOp(PatternRewriter& rewriter, Operation* op,
+ Value result_value, Value input_value,
+ int32_t num_split, int32_t axis);
+
+// Lowers SplitV to a sequence of TOSA ops.
+Operation* convertSplitVOp(PatternRewriter& rewriter, Operation* op,
+ Value result_value, Value input_value,
+ SmallVector<int32_t, 4>& size_split, int32_t axis);
+
+// Lowers StridedSlice to a sequence of TOSA ops.
+Operation* convertStridedSliceOp(PatternRewriter& rewriter, Operation* op,
+ Value result_value, Value input_value,
+ Value begin_value, Value end_value,
+ Value strides_value, int32_t begin_mask,
+ int32_t end_mask, int32_t ellipsis_mask,
+ int32_t new_axis_mask,
+ int32_t shrink_axis_mask);
+
+// Lowers FloorDiv to a sequence of TOSA operators.
+Operation* convertFloorDivOp(PatternRewriter& rewriter, Operation* op,
+ Value result_value, Value lhs_value,
+ Value rhs_value);
+
+// Lowers FloorMod to a sequence of TOSA operators.
+Operation* convertFloorModOp(PatternRewriter& rewriter, Operation* op,
+ Value result_value, Value lhs_value,
+ Value rhs_value);
+
+// Lowers FusedActivation to a sequence of TOSA ops.
+Operation* convertFusedActivation(PatternRewriter& rewriter, Operation* op,
+ Value input_value,
+ StringAttr fused_activation_fn);
+
+// Helper function for implementing quantized divide by power-of-two in TOSA
+// ops.
+Operation* convertRoundingDivideByPOT(PatternRewriter& rewriter, Operation* op,
+ Value input_value, Value rshift_value);
+
+// Lowers ReduceAll to a sequence of TOSA ops.
+Operation* convertReduceAllOp(PatternRewriter& rewriter, Operation* op,
+ RankedTensorType output_type, Value input_value,
+ ElementsAttr axes_elems, bool keep_dims);
+
+// Lowers ReduceAny to a sequence of TOSA ops.
+Operation* convertReduceAnyOp(PatternRewriter& rewriter, Operation* op,
+ RankedTensorType output_type, Value input_value,
+ ElementsAttr axes_elems, bool keep_dims);
+
+// Lowers ReduceMin to a sequence of TOSA ops.
+Operation* convertReduceMinOp(PatternRewriter& rewriter, Operation* op,
+ RankedTensorType output_type, Value input_value,
+ ElementsAttr axes_elems, bool keep_dims);
+
+// Lowers ReduceMax to a sequence of TOSA ops.
+Operation* convertReduceMaxOp(PatternRewriter& rewriter, Operation* op,
+ RankedTensorType output_type, Value input_value,
+ ElementsAttr axes_elems, bool keep_dims);
+
+// Lowers ReduceProd to a sequence of TOSA ops.
+Operation* convertReduceProdOp(PatternRewriter& rewriter, Operation* op,
+ RankedTensorType output_type, Value input_value,
+ ElementsAttr axes_elems, bool keep_dims);
+
+// Lowers ReduceSum to a sequence of TOSA ops.
+Operation* convertReduceSumOp(PatternRewriter& rewriter, Operation* op,
+ RankedTensorType output_type, Value input_value,
+ ElementsAttr axes_elems, bool keep_dims);
+
+// Lowers ReduceMean to a sequence of TOSA ops.
+Operation* convertReduceMeanOp(PatternRewriter& rewriter, Operation* op,
+ RankedTensorType output_type, Value input_value,
+ ElementsAttr axes_elems, bool keep_dims);
+
+// Lowers ResizeBilinear and ResizeNearestNeighbor to TOSA resize.
+Operation* convertResizeOp(PatternRewriter& rewriter, Operation* op,
+ RankedTensorType output_type, Value input_value,
+ StringRef mode);
+
+// Lowers Quantize to a sequence of TOSA quantization ops.
+Operation* convertQuantizeOp(PatternRewriter& rewriter, Operation* op,
+ RankedTensorType output_type, Value input_value,
+ double scale, int64_t zeropoint);
+
+// Lowers Dequantize to a sequence of TOSA dequantization ops.
+Operation* convertDequantizeOp(PatternRewriter& rewriter, Operation* op,
+ RankedTensorType output_type, Value input_value,
+ double scale, int64_t zeropoint);
+
+// Lowers FakeQuant to a sequence of TOSA quantization ops.
+Operation* convertFakeQuantOp(PatternRewriter& rewriter, Operation* op,
+ RankedTensorType output_type, Value input_value,
+ double min, double max, int64_t num_bits,
+ bool narrow_range);
+
+// Lowers TF Conv2D-style ops to TOSA Conv2D, with common handling of
+// strides, dilations and padding.
+Operation* convertTFConv2DCommon(
+ PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
+ Value input, Value filter, Value bias, ArrayAttr strides_attr,
+ ArrayAttr dilations_attr, ArrayAttr explicit_padding_attr,
+ StringRef padding_ref, StringRef data_format_ref);
+
+}  // namespace tosa
+}  // namespace mlir
+
+#endif // TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_LEGALIZE_COMMON_H
diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc
new file mode 100644
index 0000000..644795e
--- /dev/null
+++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc
@@ -0,0 +1,2022 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Legalize TensorFlow to TOSA
+
+#include <climits>
+#include <cstddef>
+#include <cstdint>
+#include <iterator>
+#include <numeric>
+
+#include "mlir/Dialect/Quant/FakeQuantSupport.h"
+#include "mlir/Dialect/Quant/UniformSupport.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
+#include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"
+#include "tensorflow/compiler/mlir/tosa/transforms/passes.h"
+
+#define PASS_NAME "tosa-legalize-tf"
+#define DEBUG_TYPE PASS_NAME
+
+namespace mlir {
+
+namespace tosa {
+
+namespace {
+// Performs lowering to TOSA dialect
+class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
+ public:
+ explicit LegalizeTF() {}
+ void runOnFunction() override;
+};
+
+// All the Pat<> lowering mappings.
+#include "tensorflow/compiler/mlir/tosa/transforms/tf_legalize_patterns.inc"
+
+#define DECL_CONVERT_OP(tf_op) \
+ struct ConvertTF##tf_op##Op : public RewritePattern { \
+ explicit ConvertTF##tf_op##Op(MLIRContext *context) \
+ : RewritePattern(TF::tf_op##Op::getOperationName(), 1, context) {} \
+ LogicalResult matchAndRewrite(Operation *op, \
+ PatternRewriter &rewriter) const override; \
+ }
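+
+// For example, DECL_CONVERT_OP(Relu) declares ConvertTFReluOp, a
+// RewritePattern matching TF::ReluOp whose matchAndRewrite is implemented
+// out-of-line below.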
+
+// All the explicitly implemented complex lowerings.
+DECL_CONVERT_OP(MatMul);
+DECL_CONVERT_OP(Relu);
+DECL_CONVERT_OP(Relu6);
+DECL_CONVERT_OP(Equal);
+DECL_CONVERT_OP(NotEqual);
+DECL_CONVERT_OP(Greater);
+DECL_CONVERT_OP(GreaterEqual);
+DECL_CONVERT_OP(Add);
+DECL_CONVERT_OP(AddV2);
+DECL_CONVERT_OP(AddN);
+DECL_CONVERT_OP(Sub);
+DECL_CONVERT_OP(Mul);
+DECL_CONVERT_OP(Square);
+DECL_CONVERT_OP(SquaredDifference);
+DECL_CONVERT_OP(Round);
+DECL_CONVERT_OP(FloorDiv);
+DECL_CONVERT_OP(FloorMod);
+DECL_CONVERT_OP(Assert);
+DECL_CONVERT_OP(Maximum);
+DECL_CONVERT_OP(Minimum);
+DECL_CONVERT_OP(RealDiv);
+DECL_CONVERT_OP(ArgMax);
+DECL_CONVERT_OP(AvgPool);
+DECL_CONVERT_OP(MaxPool);
+DECL_CONVERT_OP(ConcatV2);
+DECL_CONVERT_OP(Reshape);
+DECL_CONVERT_OP(Rank);
+DECL_CONVERT_OP(Shape);
+DECL_CONVERT_OP(ExpandDims);
+DECL_CONVERT_OP(Squeeze);
+DECL_CONVERT_OP(Fill);
+DECL_CONVERT_OP(Conv2D);
+DECL_CONVERT_OP(DepthwiseConv2dNative);
+DECL_CONVERT_OP(Conv2DBackpropInput);
+DECL_CONVERT_OP(Elu);
+DECL_CONVERT_OP(Softmax);
+DECL_CONVERT_OP(LogSoftmax);
+DECL_CONVERT_OP(All);
+DECL_CONVERT_OP(Any);
+DECL_CONVERT_OP(Max);
+DECL_CONVERT_OP(Min);
+DECL_CONVERT_OP(Mean);
+DECL_CONVERT_OP(Prod);
+DECL_CONVERT_OP(Sum);
+DECL_CONVERT_OP(FusedBatchNorm);
+DECL_CONVERT_OP(FusedBatchNormV3);
+DECL_CONVERT_OP(BiasAdd);
+DECL_CONVERT_OP(Split);
+DECL_CONVERT_OP(SplitV);
+DECL_CONVERT_OP(Pack);
+DECL_CONVERT_OP(Unpack);
+DECL_CONVERT_OP(Transpose);
+DECL_CONVERT_OP(Tile);
+DECL_CONVERT_OP(Slice);
+DECL_CONVERT_OP(StridedSlice);
+DECL_CONVERT_OP(Less);
+DECL_CONVERT_OP(LessEqual);
+DECL_CONVERT_OP(Pad);
+DECL_CONVERT_OP(ResizeBilinear);
+DECL_CONVERT_OP(ResizeNearestNeighbor);
+DECL_CONVERT_OP(Gather);
+DECL_CONVERT_OP(GatherV2);
+DECL_CONVERT_OP(SelectV2);
+DECL_CONVERT_OP(SpaceToDepth);
+DECL_CONVERT_OP(DepthToSpace);
+DECL_CONVERT_OP(SpaceToBatchND);
+DECL_CONVERT_OP(BatchToSpaceND);
+DECL_CONVERT_OP(ZerosLike);
+DECL_CONVERT_OP(Sigmoid);
+DECL_CONVERT_OP(Tanh);
+DECL_CONVERT_OP(LeakyRelu);
+DECL_CONVERT_OP(Neg);
+DECL_CONVERT_OP(StopGradient);
+DECL_CONVERT_OP(ReverseV2);
+DECL_CONVERT_OP(FakeQuantWithMinMaxArgs);
+DECL_CONVERT_OP(FakeQuantWithMinMaxVars);
+#undef DECL_CONVERT_OP
+
+// TODO: remove this macro once the common legalization functions return
+// llvm::Optional<>.
+// Helper macro for checking the return value of a common legalization
+// function that returns a single tensor. Packs the result in a list.
+#define TOSA_REPLACE_LOWERED_OP(REWRITER, OP, LOWERED_OP) \
+ if (LOWERED_OP) { \
+ REWRITER.replaceOp((OP), {(LOWERED_OP)->getResults()}); \
+ return success(); \
+ } else { \
+ return failure(); \
+ }
+
+// TODO: remove this macro once the common legalization functions return
+// llvm::Optional<>.
+// Helper macro for checking the return value of a common legalization
+// function that returns a tensor list.
+#define TOSA_REPLACE_LOWERED_OP_LIST(REWRITER, OP, LOWERED_OP) \
+ if (LOWERED_OP) { \
+ REWRITER.replaceOp((OP), (LOWERED_OP)->getResults()); \
+ return success(); \
+ } else { \
+ return failure(); \
+ }
+
+LogicalResult ConvertTFReluOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_relu_op = cast<TF::ReluOp>(op);
+
+ auto output_type =
+ tf_relu_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ if (output_type.getElementType().isa<mlir::FloatType>()) {
+ rewriter.replaceOpWithNewOp<tosa::ReluNOp>(
+ op, output_type, tf_relu_op.features(), rewriter.getI64IntegerAttr(0),
+ rewriter.getF32FloatAttr(std::numeric_limits<float>::max()));
+ } else {
+ rewriter.replaceOpWithNewOp<tosa::ReluNOp>(
+ op, output_type, tf_relu_op.features(),
+ rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max()),
+ rewriter.getF32FloatAttr(0.0f));
+ }
+ return success();
+}
+
+LogicalResult ConvertTFRelu6Op::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_relu6_op = cast<TF::Relu6Op>(op);
+
+ auto output_type =
+ tf_relu6_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ if (output_type.getElementType().isa<mlir::FloatType>()) {
+ rewriter.replaceOpWithNewOp<tosa::ReluNOp>(
+ op, output_type, tf_relu6_op.features(), rewriter.getI64IntegerAttr(0),
+ rewriter.getF32FloatAttr(6.0f));
+ } else {
+ rewriter.replaceOpWithNewOp<tosa::ReluNOp>(
+ op, output_type, tf_relu6_op.features(), rewriter.getI64IntegerAttr(6),
+ rewriter.getF32FloatAttr(0.0f));
+ }
+ return success();
+}
+
+LogicalResult ConvertTFEqualOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_equal_op = cast<TF::EqualOp>(op);
+
+ auto output_type =
+ tf_equal_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ rewriter.replaceOpWithNewOp<tosa::EqualOp>(op, output_type, tf_equal_op.x(),
+ tf_equal_op.y());
+ return success();
+}
+
+LogicalResult ConvertTFNotEqualOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_not_equal_op = cast<TF::NotEqualOp>(op);
+
+ auto output_type =
+ tf_not_equal_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
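+  // NotEqual is lowered as logical_not(equal(x, y)).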
+ auto op1_equal_in = rewriter.create<tosa::EqualOp>(
+ op->getLoc(), output_type, tf_not_equal_op.x(), tf_not_equal_op.y());
+
+ auto op2_not_op1 = rewriter.create<tosa::LogicalNotOp>(
+ op->getLoc(), output_type, op1_equal_in.getResult());
+
+ rewriter.replaceOp(op, {op2_not_op1.getResult()});
+
+ return success();
+}
+
+LogicalResult ConvertTFGreaterOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_greater_op = cast<TF::GreaterOp>(op);
+
+ auto output_type =
+ tf_greater_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ rewriter.replaceOpWithNewOp<tosa::GreaterOp>(
+ op, output_type, tf_greater_op.x(), tf_greater_op.y());
+ return success();
+}
+
+LogicalResult ConvertTFGreaterEqualOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_greater_equal_op = cast<TF::GreaterEqualOp>(op);
+
+ auto output_type =
+ tf_greater_equal_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ rewriter.replaceOpWithNewOp<tosa::GreaterEqualOp>(
+ op, output_type, tf_greater_equal_op.x(), tf_greater_equal_op.y());
+ return success();
+}
+
+LogicalResult ConvertTFAddOp::matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const {
+ auto tf_add_op = cast<TF::AddOp>(op);
+
+ auto output_type =
+ tf_add_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ rewriter.replaceOpWithNewOp<tosa::AddOp>(op, output_type, tf_add_op.x(),
+ tf_add_op.y());
+ return success();
+}
+
+LogicalResult ConvertTFAddV2Op::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_addv2_op = cast<TF::AddV2Op>(op);
+
+ auto output_type =
+ tf_addv2_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ rewriter.replaceOpWithNewOp<tosa::AddOp>(op, output_type, tf_addv2_op.x(),
+ tf_addv2_op.y());
+ return success();
+}
+
+// AddN is commutative and associative, so it can be lowered to a chain of
+// binary tosa.add ops: addn(a, b, c) -> add(add(a, b), c).
+LogicalResult ConvertTFAddNOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_addn_op = cast<TF::AddNOp>(op);
+
+ auto output_type =
+ tf_addn_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ SmallVector<Value, 8> inputs(tf_addn_op.inputs());
+
+ assert(inputs.size() >= 2);
+
+ auto newOp = rewriter.create<tosa::AddOp>(op->getLoc(), output_type,
+ inputs[0], inputs[1]);
+ for (int i = 2; i < inputs.size(); i++) {
+ newOp = rewriter.create<tosa::AddOp>(op->getLoc(), output_type, inputs[i],
+ newOp.getResult());
+ }
+
+ rewriter.replaceOp(op, {newOp.getResult()});
+
+ return success();
+}
+
+LogicalResult ConvertTFSubOp::matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const {
+ auto tf_sub_op = cast<TF::SubOp>(op);
+
+ auto output_type =
+ tf_sub_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ rewriter.replaceOpWithNewOp<tosa::SubOp>(op, output_type, tf_sub_op.x(),
+ tf_sub_op.y());
+ return success();
+}
+
+LogicalResult ConvertTFMulOp::matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const {
+ auto tf_mul_op = cast<TF::MulOp>(op);
+
+ auto lowered_op = convertMultiplyOp(rewriter, op, tf_mul_op.getResult(),
+ tf_mul_op.x(), tf_mul_op.y());
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFSquareOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_square_op = cast<TF::SquareOp>(op);
+
+ auto lowered_op = convertMultiplyOp(rewriter, op, tf_square_op.getResult(),
+ tf_square_op.x(), tf_square_op.x());
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFSquaredDifferenceOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_squared_op = cast<TF::SquaredDifferenceOp>(op);
+
+ auto lowered_op =
+ convertSquaredDifferenceOp(rewriter, op, tf_squared_op.getResult(),
+ tf_squared_op.x(), tf_squared_op.y());
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFRoundOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_round_op = cast<TF::RoundOp>(op);
+
+ auto input_type = tf_round_op.x().getType().dyn_cast<RankedTensorType>();
+ if (!input_type) {
+ return op->emitOpError("Round: input not ranked tensor type");
+ }
+
+ if (input_type.getElementType().isa<FloatType>()) {
+ auto lowered_op =
+ convertRoundOp(rewriter, op, tf_round_op.getResult(), tf_round_op.x());
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+
+ } else {
+ tf_round_op.replaceAllUsesWith(tf_round_op.x());
+ return success();
+ }
+}
+
+LogicalResult ConvertTFFloorDivOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_floordiv_op = cast<TF::FloorDivOp>(op);
+
+ auto lowered_op = convertFloorDivOp(rewriter, op, tf_floordiv_op.getResult(),
+ tf_floordiv_op.x(), tf_floordiv_op.y());
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFFloorModOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_floormod_op = cast<TF::FloorModOp>(op);
+
+ auto lowered_op = convertFloorModOp(rewriter, op, tf_floormod_op.getResult(),
+ tf_floormod_op.x(), tf_floormod_op.y());
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFAssertOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
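+  // Assert has no TOSA counterpart in this lowering; the op is simply erased.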
+ op->dropAllReferences();
+ op->erase();
+ return success();
+}
+
+LogicalResult ConvertTFMaximumOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_maximum_op = cast<TF::MaximumOp>(op);
+
+ auto output_type =
+ tf_maximum_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ rewriter.replaceOpWithNewOp<tosa::MaximumOp>(
+ op, output_type, tf_maximum_op.x(), tf_maximum_op.y());
+ return success();
+}
+
+LogicalResult ConvertTFMinimumOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_minimum_op = cast<TF::MinimumOp>(op);
+
+ auto output_type =
+ tf_minimum_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ rewriter.replaceOpWithNewOp<tosa::MinimumOp>(
+ op, output_type, tf_minimum_op.x(), tf_minimum_op.y());
+ return success();
+}
+
+LogicalResult ConvertTFRealDivOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_div_op = cast<TF::RealDivOp>(op);
+
+ auto y_type = tf_div_op.y().getType().dyn_cast<RankedTensorType>();
+ auto output_type =
+ tf_div_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type || !y_type) return failure();
+
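+  // RealDiv is lowered as x * reciprocal(y).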
+ auto reciprocal_op =
+ rewriter.create<tosa::ReciprocalOp>(op->getLoc(), y_type, tf_div_op.y());
+
+ auto mul_op = rewriter.create<tosa::MulOp>(
+ op->getLoc(), output_type, tf_div_op.x(), reciprocal_op.getResult(), 0);
+ rewriter.replaceOp(op, {mul_op.getResult()});
+
+ return success();
+}
+
+LogicalResult ConvertTFArgMaxOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_argmax_op = cast<TF::ArgMaxOp>(op);
+
+ auto input_type = tf_argmax_op.input().getType().dyn_cast<RankedTensorType>();
+ auto output_type =
+ tf_argmax_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type || !input_type) return failure();
+
+ ElementsAttr axis_elems;
+ if (!matchPattern(tf_argmax_op.dimension(), m_Constant(&axis_elems)))
+ return failure();
+
+ int32_t axis = axis_elems.getValue<IntegerAttr>({}).getInt();
+ if (axis < 0) {
+ axis += input_type.getRank();
+ }
+
+ if (axis < 0 || axis >= input_type.getRank()) {
+ return op->emitOpError("TFArgMax: invalid axis value");
+ }
+
+ IntegerAttr axis_attr = rewriter.getI64IntegerAttr(axis);
+
+ rewriter.replaceOpWithNewOp<tosa::ArgMaxOp>(op, output_type,
+ tf_argmax_op.input(), axis_attr);
+
+ return success();
+}
+
+LogicalResult ConvertTFAvgPoolOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_avgpool_op = cast<TF::AvgPoolOp>(op);
+
+ auto input_type =
+ tf_avgpool_op.value().getType().dyn_cast<RankedTensorType>();
+ auto output_type =
+ tf_avgpool_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!input_type || !output_type) return failure();
+
+ auto tmpAttr = tf_avgpool_op.data_formatAttr();
+ if (tmpAttr && tmpAttr.getValue().str() != "NHWC") return failure();
+
+ ArrayAttr pad;
+ ArrayAttr stride;
+ ArrayAttr kernel;
+ {
+ auto tmpAttr = tf_avgpool_op.strides();
+ if (!tmpAttr) {
+ stride = rewriter.getI64ArrayAttr({1, 1});
+ } else {
+ // Note: hardcoded to NHWC for now
+ int64_t stride_h = tmpAttr[1].dyn_cast<IntegerAttr>().getInt();
+ int64_t stride_w = tmpAttr[2].dyn_cast<IntegerAttr>().getInt();
+ stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
+ }
+ }
+ {
+ auto tmpAttr = tf_avgpool_op.ksize();
+ if (!tmpAttr) {
+ kernel = rewriter.getI64ArrayAttr({1, 1});
+ } else {
+ // Note: hardcoded to NHWC for now
+ int64_t kernel_h = tmpAttr[1].dyn_cast<IntegerAttr>().getInt();
+ int64_t kernel_w = tmpAttr[2].dyn_cast<IntegerAttr>().getInt();
+ kernel = rewriter.getI64ArrayAttr({kernel_h, kernel_w});
+ }
+ }
+ {
+ tensorflow::Padding tf_pad;
+ if (!GetPaddingFromString(tf_avgpool_op.padding().str(), &tf_pad).ok())
+ return failure();
+
+ ArrayAttr dilation =
+ rewriter.getI64ArrayAttr({1, 1}); // Pooling has no non-unit dilation
+
+ SmallVector<int64_t, 2> i64array;
+
+ for (auto &elem : tf_avgpool_op.ksize()) {
+ int64_t value = elem.dyn_cast<IntegerAttr>().getInt();
+ i64array.emplace_back(value);
+ }
+
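+    // Build a placeholder filter type from ksize so the shared padding
+    // helper can derive the kernel dimensions.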
+ auto filter_type = RankedTensorType::get(
+ llvm::makeArrayRef<int64_t>(i64array), rewriter.getIntegerType(64));
+
+ if (!getPaddingValuesFromPadType(
+ tf_pad,
+ tensorflow::FORMAT_NHWC, // TFLite only supports this
+ 1, // tensorflow::FORMAT_OHWI,
+ input_type, filter_type, stride, dilation, rewriter, pad))
+ return failure();
+ }
+
+ rewriter.replaceOpWithNewOp<tosa::AvgPool2dOp>(
+ op, output_type, tf_avgpool_op.value(), kernel, stride, pad);
+ return success();
+}
+
+LogicalResult ConvertTFMaxPoolOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_maxpool_op = cast<TF::MaxPoolOp>(op);
+
+ auto input_type =
+ tf_maxpool_op.input().getType().dyn_cast<RankedTensorType>();
+ auto output_type =
+ tf_maxpool_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!input_type || !output_type) return failure();
+
+ auto tmpAttr = tf_maxpool_op.data_formatAttr();
+ if (tmpAttr && tmpAttr.getValue().str() != "NHWC") return failure();
+
+ ArrayAttr pad;
+ ArrayAttr stride;
+ ArrayAttr kernel;
+ {
+ auto tmpAttr = tf_maxpool_op.strides();
+ if (!tmpAttr) {
+ stride = rewriter.getI64ArrayAttr({1, 1});
+ } else {
+ // Note: hardcoded to NHWC for now
+ int64_t stride_h = tmpAttr[1].dyn_cast<IntegerAttr>().getInt();
+ int64_t stride_w = tmpAttr[2].dyn_cast<IntegerAttr>().getInt();
+ stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
+ }
+ }
+ {
+ auto tmpAttr = tf_maxpool_op.ksize();
+ if (!tmpAttr) {
+ kernel = rewriter.getI64ArrayAttr({1, 1});
+ } else {
+ // Note: hardcoded to NHWC for now
+ int64_t kernel_h = tmpAttr[1].dyn_cast<IntegerAttr>().getInt();
+ int64_t kernel_w = tmpAttr[2].dyn_cast<IntegerAttr>().getInt();
+ kernel = rewriter.getI64ArrayAttr({kernel_h, kernel_w});
+ }
+ }
+ {
+ tensorflow::Padding tf_pad;
+ if (!GetPaddingFromString(tf_maxpool_op.padding().str(), &tf_pad).ok())
+ return failure();
+
+ // Pooling has no non-unit dilation
+ ArrayAttr dilation = rewriter.getI64ArrayAttr({1, 1});
+
+ SmallVector<int64_t, 4> i64array;
+
+ for (auto &elem : tf_maxpool_op.ksize()) {
+ int64_t value = elem.dyn_cast<IntegerAttr>().getInt();
+ i64array.emplace_back(value);
+ }
+
+ auto filter_type = RankedTensorType::get(
+ llvm::makeArrayRef<int64_t>(i64array), rewriter.getIntegerType(64));
+
+ if (!getPaddingValuesFromPadType(
+ tf_pad,
+ tensorflow::FORMAT_NHWC, // TFLite only supports this
+ 1, // tensorflow::FORMAT_OHWI,
+ input_type, filter_type, stride, dilation, rewriter, pad))
+ return failure();
+ }
+
+ rewriter.replaceOpWithNewOp<tosa::MaxPool2dOp>(
+ op, output_type, tf_maxpool_op.input(), kernel, stride, pad);
+ return success();
+}
+
+LogicalResult ConvertTFConcatV2Op::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_concatv2_op = cast<TF::ConcatV2Op>(op);
+ SmallVector<Value, 8> values(tf_concatv2_op.values());
+
+ ElementsAttr axis_elems;
+ if (!matchPattern(tf_concatv2_op.axis(), m_Constant(&axis_elems)))
+ return failure();
+
+ int32_t axis = axis_elems.getValue<IntegerAttr>({}).getInt();
+
+ auto lowered_op =
+ convertConcatV2Op(rewriter, op, tf_concatv2_op.getResult(), values, axis);
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFReshapeOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_reshape_op = cast<TF::ReshapeOp>(op);
+
+ auto output_type =
+ tf_reshape_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+  // The regular way of matching the shape tensor as an elements attribute
+  // doesn't always work, so use output_type.getShape(), which is more stable.
+ SmallVector<int64_t, 8> shape_vals;
+ for (int i = 0; i < output_type.getShape().size(); i++) {
+ shape_vals.push_back(output_type.getShape()[i]);
+ }
+ ArrayAttr shape_attr = rewriter.getI64ArrayAttr(shape_vals);
+
+ rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
+ op, output_type, tf_reshape_op.tensor(), shape_attr);
+ return success();
+}
+
+LogicalResult ConvertTFRankOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_rank_op = cast<TF::RankOp>(op);
+
+ auto input_type = tf_rank_op.input().getType().dyn_cast<RankedTensorType>();
+ if (!input_type) return failure();
+
+ int32_t rank = input_type.getRank();
+
+ auto rank_type = RankedTensorType::get({1}, rewriter.getIntegerType(32));
+ auto rank_attr = DenseElementsAttr::get(rank_type, {rank});
+ auto rank_const =
+ rewriter.create<tosa::ConstOp>(op->getLoc(), rank_type, rank_attr);
+
+ rewriter.replaceOp(op, {rank_const.getResult()});
+
+ return success();
+}
+
+LogicalResult ConvertTFShapeOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_shape_op = cast<TF::ShapeOp>(op);
+
+ auto output_type =
+ tf_shape_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ auto input_type = tf_shape_op.input().getType().dyn_cast<RankedTensorType>();
+ if (!input_type) return failure();
+
+ auto input_shape = input_type.getShape();
+
+ SmallVector<int32_t, 8> shape_arr;
+ for (int i = 0; i < input_shape.size(); i++) {
+ shape_arr.emplace_back(input_shape[i]);
+ }
+
+ auto shape_type = RankedTensorType::get(
+ {static_cast<int32_t>(shape_arr.size())}, rewriter.getIntegerType(32));
+ auto shape_attr = DenseElementsAttr::get(
+ shape_type, llvm::makeArrayRef<int32_t>(shape_arr));
+ auto shape_const =
+ rewriter.create<tosa::ConstOp>(op->getLoc(), shape_type, shape_attr);
+
+ rewriter.replaceOp(op, {shape_const.getResult()});
+
+ return success();
+}
+
+LogicalResult ConvertTFExpandDimsOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_expanddims_op = cast<TF::ExpandDimsOp>(op);
+
+ auto lowered_op =
+ convertExpandDimsOp(rewriter, op, tf_expanddims_op.getResult(),
+ tf_expanddims_op.input(), tf_expanddims_op.dim());
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFSqueezeOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_squeeze_op = cast<TF::SqueezeOp>(op);
+
+ // Copy squeeze_dims into int32_t array
+ auto squeeze_dims_attr = tf_squeeze_op.squeeze_dimsAttr();
+ SmallVector<int32_t, 8> squeeze_dims;
+ for (auto &squeeze_dim : squeeze_dims_attr) {
+ squeeze_dims.emplace_back(squeeze_dim.dyn_cast<IntegerAttr>().getInt());
+ }
+
+ auto lowered_op = convertSqueezeOp(rewriter, op, tf_squeeze_op.getResult(),
+ tf_squeeze_op.input(), squeeze_dims);
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFFillOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_fill_op = cast<TF::FillOp>(op);
+
+ auto output_type =
+ tf_fill_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ ElementsAttr dims_elems;
+ if (!matchPattern(tf_fill_op.dims(), m_Constant(&dims_elems)))
+ return failure();
+ SmallVector<int64_t, 4> dims_vals;
+ uint32_t total_size = 1;
+ for (int i = 0; i < dims_elems.getNumElements(); i++) {
+ dims_vals.push_back(dims_elems.getValue<IntegerAttr>(i).getInt());
+ total_size *= dims_vals[i];
+ }
+
+ ElementsAttr value_elem;
+ if (!matchPattern(tf_fill_op.value(), m_Constant(&value_elem)))
+ return failure();
+
+ auto fill_type = RankedTensorType::get(ArrayRef<int64_t>(dims_vals),
+ value_elem.getType().getElementType());
+ DenseElementsAttr fill_attr;
+
+ // Convert to a compatible zero type
+ if (value_elem.getType().getElementType().isa<FloatType>()) {
+ llvm::SmallVector<float, 4> fill_arr(
+ total_size,
+ value_elem.getValue<FloatAttr>(0).getValue().convertToFloat());
+ fill_attr =
+ DenseElementsAttr::get(fill_type, llvm::makeArrayRef<float>(fill_arr));
+ } else {
+ llvm::SmallVector<int32_t, 4> fill_arr(
+ total_size,
+ value_elem.getValue<IntegerAttr>(0).getValue().getLimitedValue());
+ fill_attr = DenseElementsAttr::get(fill_type,
+ llvm::makeArrayRef<int32_t>(fill_arr));
+ }
+ auto fill_const_op =
+ rewriter.create<tosa::ConstOp>(op->getLoc(), fill_type, fill_attr);
+ rewriter.replaceOp(op, {fill_const_op.getResult()});
+
+ return success();
+}
+
+LogicalResult ConvertTFConv2DOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_conv2d_op = cast<TF::Conv2DOp>(op);
+
+ auto filter_type =
+ tf_conv2d_op.filter().getType().dyn_cast<RankedTensorType>();
+ auto output_type =
+ tf_conv2d_op.getResult().getType().dyn_cast<RankedTensorType>();
+
+  // Not ranked tensor types
+  if (!filter_type || !output_type) return failure();
+
+  // Set up a zero attr for subsequent pattern replacement if required
+ auto bias_dim = filter_type.getShape().back();
+ auto bias_type =
+ RankedTensorType::get({bias_dim}, filter_type.getElementType());
+ auto bias_attr = rewriter.getZeroAttr(bias_type);
+ auto bias = rewriter.create<tosa::ConstOp>(op->getLoc(), bias_type,
+ bias_attr.cast<ElementsAttr>());
+
+ auto lowered_op = convertTFConv2DCommon(
+ rewriter, op, output_type, tf_conv2d_op.input(), tf_conv2d_op.filter(),
+ bias, tf_conv2d_op.strides(), tf_conv2d_op.dilations(),
+ tf_conv2d_op.explicit_paddings(), tf_conv2d_op.padding(),
+ tf_conv2d_op.data_format());
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFDepthwiseConv2dNativeOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_dwconv2d_op = cast<TF::DepthwiseConv2dNativeOp>(op);
+
+ auto input_type =
+ tf_dwconv2d_op.input().getType().dyn_cast<RankedTensorType>();
+ auto filter_type =
+ tf_dwconv2d_op.filter().getType().dyn_cast<RankedTensorType>();
+ auto output_type =
+ tf_dwconv2d_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!input_type) return failure();
+ if (!output_type) return failure();
+
+  if (!filter_type) {
+    return op->emitOpError("DepthwiseConv2d: filter type unranked tensor");
+  }
+
+ auto tmpAttr = tf_dwconv2d_op.data_formatAttr();
+ if (tmpAttr && tmpAttr.getValue().str() != "NHWC") return failure();
+
+ ArrayAttr stride;
+ ArrayAttr dilation;
+ ArrayAttr pad;
+ {
+ auto tmpAttr = tf_dwconv2d_op.strides();
+ if (!tmpAttr) {
+ stride = rewriter.getI64ArrayAttr({1, 1});
+ } else {
+ // Note: hardcoded to NHWC for now
+ int64_t stride_h = tmpAttr[1].dyn_cast<IntegerAttr>().getInt();
+ int64_t stride_w = tmpAttr[2].dyn_cast<IntegerAttr>().getInt();
+ stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
+ }
+ }
+ {
+ auto tmpAttr = tf_dwconv2d_op.dilations();
+ if (!tmpAttr) {
+ dilation = rewriter.getI64ArrayAttr({1, 1});
+ } else {
+ // Note: hardcoded to NHWC for now
+ int64_t dilation_h = tmpAttr[1].dyn_cast<IntegerAttr>().getInt();
+ int64_t dilation_w = tmpAttr[2].dyn_cast<IntegerAttr>().getInt();
+ dilation = rewriter.getI64ArrayAttr({dilation_h, dilation_w});
+ }
+ }
+ {
+ tensorflow::Padding tf_pad;
+ if (!GetPaddingFromString(tf_dwconv2d_op.padding().str(), &tf_pad).ok())
+ return failure();
+
+ tensorflow::TensorFormat data_format_tf;
+ if (!FormatFromString(tf_dwconv2d_op.data_format().str(), &data_format_tf))
+ return failure();
+
+ if (tf_pad == tensorflow::Padding::EXPLICIT) {
+ pad = getPaddingValuesFromExplicitPadAttr(
+ tf_dwconv2d_op.explicit_paddings(), data_format_tf, rewriter);
+ } else {
+ if (!getPaddingValuesFromPadType(tf_pad, data_format_tf,
+ 0, // tensorflow::FORMAT_HWIO
+ input_type, filter_type, stride,
+ dilation, rewriter, pad))
+ return failure();
+ }
+ }
+
+  // Set up a zero bias, since the TF op carries no bias input
+  auto filter_shape = filter_type.getShape();
+ auto bias_dim = filter_shape[2] * filter_shape[3];
+ auto bias_type =
+ RankedTensorType::get({bias_dim}, filter_type.getElementType());
+ auto bias_attr = rewriter.getZeroAttr(bias_type);
+ auto bias = rewriter.create<tosa::ConstOp>(op->getLoc(), bias_type,
+ bias_attr.cast<ElementsAttr>());
+
+ rewriter.replaceOpWithNewOp<tosa::DepthwiseConv2DOp>(
+ op, output_type, tf_dwconv2d_op.input(), tf_dwconv2d_op.filter(), bias,
+ pad, stride, dilation);
+ return success();
+}
+
+LogicalResult ConvertTFConv2DBackpropInputOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_conv_op = cast<TF::Conv2DBackpropInputOp>(op);
+
+ auto input_type =
+ tf_conv_op.out_backprop().getType().dyn_cast<RankedTensorType>();
+ auto filter_type = tf_conv_op.filter().getType().dyn_cast<RankedTensorType>();
+ auto output_type =
+ tf_conv_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!input_type) return failure();
+ if (!filter_type) return failure();
+ if (!output_type) return failure();
+
+  // Transpose the TF filter from [H, W, I, O] to [I, H, W, O], which is the
+  // transpose conv's [O, H, W, I] since input and output channels swap roles.
+ auto filter_shape = filter_type.getShape();
+ llvm::SmallVector<int64_t, 4> a1_transpose_dims;
+ a1_transpose_dims.push_back(filter_shape[2]);
+ a1_transpose_dims.push_back(filter_shape[0]);
+ a1_transpose_dims.push_back(filter_shape[1]);
+ a1_transpose_dims.push_back(filter_shape[3]);
+ auto a1_filter_transpose_perm =
+ get1DConstTensor<tosa::ConstOp, int32_t>(rewriter, op, {2, 0, 1, 3});
+ auto a1_filter_transpose_op = rewriter.create<tosa::TransposeOp>(
+ op->getLoc(),
+ RankedTensorType::get(ArrayRef<int64_t>(a1_transpose_dims),
+ filter_type.getElementType()),
+ tf_conv_op.filter(), a1_filter_transpose_perm);
+
+ ArrayAttr stride;
+ ArrayAttr dilation;
+ ArrayAttr outpad;
+ ArrayAttr output_shape;
+ {
+ auto tmpAttr = tf_conv_op.strides();
+ if (!tmpAttr) {
+ stride = rewriter.getI64ArrayAttr({1, 1});
+ } else {
+ // Note: hardcoded to NHWC for now
+ int64_t stride_h = tmpAttr[1].dyn_cast<IntegerAttr>().getInt();
+ int64_t stride_w = tmpAttr[2].dyn_cast<IntegerAttr>().getInt();
+ stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
+ }
+ }
+ {
+ auto tmpAttr = tf_conv_op.dilations();
+ if (!tmpAttr) {
+ dilation = rewriter.getI64ArrayAttr({1, 1});
+ } else {
+ // Note: hardcoded to NHWC for now
+ int64_t dilation_h = tmpAttr[1].dyn_cast<IntegerAttr>().getInt();
+ int64_t dilation_w = tmpAttr[2].dyn_cast<IntegerAttr>().getInt();
+ dilation = rewriter.getI64ArrayAttr({dilation_h, dilation_w});
+ }
+ }
+ {
+ tensorflow::Padding tf_pad;
+ if (!GetPaddingFromString(tf_conv_op.padding().str(), &tf_pad).ok())
+ return failure();
+
+ tensorflow::TensorFormat data_format_tf;
+ if (!FormatFromString(tf_conv_op.data_format().str(), &data_format_tf))
+ return failure();
+
+ if (tf_pad == tensorflow::Padding::EXPLICIT) {
+ outpad = getPaddingValuesFromExplicitPadAttr(
+ tf_conv_op.explicit_paddings(), data_format_tf, rewriter);
+ } else {
+ if (!getTransposeConv2dPaddingValues(tf_pad, data_format_tf,
+ 0, // tensorflow::FORMAT_HWIO,
+ input_type, filter_type, output_type,
+ stride, dilation, rewriter, outpad))
+ return failure();
+ }
+ }
+ {
+ ElementsAttr output_shape_elems;
+ // Match from input_sizes tensor first.
+ if (matchPattern(tf_conv_op.input_sizes(),
+ m_Constant(&output_shape_elems))) {
+ llvm::SmallVector<int64_t, 4> shape_vec;
+ for (int i = 0; i < output_shape_elems.getNumElements(); i++)
+ shape_vec.push_back(
+ output_shape_elems.getValue<IntegerAttr>(i).getInt());
+ output_shape = rewriter.getI64ArrayAttr(shape_vec);
+ } else {
+ // Use output tensor's shape otherwise.
+ output_shape = rewriter.getI64ArrayAttr(output_type.getShape());
+ }
+ }
+
+ SmallVector<float, 8> zero_bias_vec(output_type.getShape()[3], 0.0f);
+ Value zero_bias =
+ get1DConstTensor<tosa::ConstOp, float>(rewriter, op, zero_bias_vec);
+
+ rewriter.replaceOpWithNewOp<tosa::TransposeConv2DOp>(
+ op, output_type, tf_conv_op.out_backprop(),
+ a1_filter_transpose_op.getResult(), zero_bias, outpad, stride, dilation,
+ output_shape);
+
+ return success();
+}
+
+LogicalResult ConvertTFAllOp::matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const {
+ auto tf_all_op = cast<TF::AllOp>(op);
+
+ auto output_type =
+ tf_all_op.getResult().getType().dyn_cast<RankedTensorType>();
+ if (!output_type) return failure();
+
+ ElementsAttr axes_elems;
+ if (!matchPattern(tf_all_op.reduction_indices(), m_Constant(&axes_elems)))
+ return failure();
+
+ bool keep_dims = false;
+ auto keep_dims_attr = tf_all_op.keep_dimsAttr();
+ if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
+
+ auto lowered_op = convertReduceAllOp(
+ rewriter, op, output_type, tf_all_op.input(), axes_elems, keep_dims);
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFAnyOp::matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const {
+ auto tf_any_op = cast<TF::AnyOp>(op);
+
+ auto output_type =
+ tf_any_op.getResult().getType().dyn_cast<RankedTensorType>();
+ if (!output_type) return failure();
+
+ ElementsAttr axes_elems;
+ if (!matchPattern(tf_any_op.reduction_indices(), m_Constant(&axes_elems)))
+ return failure();
+
+ bool keep_dims = false;
+ auto keep_dims_attr = tf_any_op.keep_dimsAttr();
+ if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
+
+ auto lowered_op = convertReduceAnyOp(
+ rewriter, op, output_type, tf_any_op.input(), axes_elems, keep_dims);
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFMaxOp::matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const {
+ auto tf_max_op = cast<TF::MaxOp>(op);
+
+ auto output_type =
+ tf_max_op.getResult().getType().dyn_cast<RankedTensorType>();
+ if (!output_type) return failure();
+
+ ElementsAttr axes_elems;
+ if (!matchPattern(tf_max_op.reduction_indices(), m_Constant(&axes_elems)))
+ return failure();
+
+ bool keep_dims = false;
+ auto keep_dims_attr = tf_max_op.keep_dimsAttr();
+ if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
+
+ auto lowered_op = convertReduceMaxOp(
+ rewriter, op, output_type, tf_max_op.input(), axes_elems, keep_dims);
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFMinOp::matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const {
+ auto tf_min_op = cast<TF::MinOp>(op);
+
+ auto output_type =
+ tf_min_op.getResult().getType().dyn_cast<RankedTensorType>();
+ if (!output_type) return failure();
+
+ ElementsAttr axes_elems;
+ if (!matchPattern(tf_min_op.reduction_indices(), m_Constant(&axes_elems)))
+ return failure();
+
+ bool keep_dims = false;
+ auto keep_dims_attr = tf_min_op.keep_dimsAttr();
+ if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
+
+ auto lowered_op = convertReduceMinOp(
+ rewriter, op, output_type, tf_min_op.input(), axes_elems, keep_dims);
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFMeanOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_mean_op = cast<TF::MeanOp>(op);
+
+ auto output_type =
+ tf_mean_op.getResult().getType().dyn_cast<RankedTensorType>();
+ if (!output_type) return failure();
+
+ ElementsAttr axes_elems;
+ if (!matchPattern(tf_mean_op.reduction_indices(), m_Constant(&axes_elems)))
+ return failure();
+
+ bool keep_dims = false;
+ auto keep_dims_attr = tf_mean_op.keep_dimsAttr();
+ if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
+
+ auto lowered_op = convertReduceMeanOp(
+ rewriter, op, output_type, tf_mean_op.input(), axes_elems, keep_dims);
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFProdOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_prod_op = cast<TF::ProdOp>(op);
+
+ auto output_type =
+ tf_prod_op.getResult().getType().dyn_cast<RankedTensorType>();
+ if (!output_type) return failure();
+
+ ElementsAttr axes_elems;
+ if (!matchPattern(tf_prod_op.reduction_indices(), m_Constant(&axes_elems)))
+ return failure();
+
+ bool keep_dims = false;
+ auto keep_dims_attr = tf_prod_op.keep_dimsAttr();
+ if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
+
+ auto lowered_op = convertReduceProdOp(
+ rewriter, op, output_type, tf_prod_op.input(), axes_elems, keep_dims);
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFSumOp::matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const {
+ auto tf_sum_op = cast<TF::SumOp>(op);
+
+ auto output_type =
+ tf_sum_op.getResult().getType().dyn_cast<RankedTensorType>();
+ if (!output_type) return failure();
+
+ ElementsAttr axes_elems;
+ if (!matchPattern(tf_sum_op.reduction_indices(), m_Constant(&axes_elems)))
+ return failure();
+
+ bool keep_dims = false;
+ auto keep_dims_attr = tf_sum_op.keep_dimsAttr();
+ if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
+
+ auto lowered_op = convertReduceSumOp(
+ rewriter, op, output_type, tf_sum_op.input(), axes_elems, keep_dims);
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFEluOp::matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const {
+ auto tf_elu_op = cast<TF::EluOp>(op);
+
+ auto lowered_op =
+ convertEluOp(rewriter, op, tf_elu_op.getResult(), tf_elu_op.features());
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFSoftmaxOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_softmax_op = cast<TF::SoftmaxOp>(op);
+
+ auto lowered_op = convertSoftmaxOp(rewriter, op, tf_softmax_op.getResult(),
+ tf_softmax_op.logits());
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFLogSoftmaxOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_logsoftmax_op = cast<TF::LogSoftmaxOp>(op);
+
+ auto lowered_op = convertLogSoftmaxOp(
+ rewriter, op, tf_logsoftmax_op.getResult(), tf_logsoftmax_op.logits());
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFFusedBatchNormOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_batchnorm_op = cast<TF::FusedBatchNormOp>(op);
+
+ auto output_type =
+ tf_batchnorm_op.getResult(0).getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ // Lowering:
+ // fused batchnorm = (input - mean) * scale * rsqrt(var + epsilon) + offset
+ //
+ // shape_0 = ones(input.rank)
+ // shape_0[input.rank-1] = input.shape[input.rank-1]
+ // shape_1 = ones(1)
+ //
+ // bmean = reshape(mean, shape_0)
+ // bscale = reshape(scale, shape_0)
+ // boffset= reshape(offset, shape_0)
+ // beps = reshape(epsilon, shape_1)
+ //
+ // op1 = sub(input, bmean)
+ // op2 = add(var, beps)
+ // op3 = rsqrt(op2)
+ // bvar = reshape(op3, shape_0)
+ // op4 = mul(op1, bvar)
+ // op5 = mul(op4, bscale)
+ // op6 = add(op5, boffset)
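+ //
+ // Since rsqrt(v) = 1 / sqrt(v), op6 expands to
+ // (input - mean) / sqrt(var + epsilon) * scale + offset,
+ // i.e. the canonical fused-batchnorm expression above.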
+
+ auto mean_type =
+ tf_batchnorm_op.mean().getType().dyn_cast<RankedTensorType>();
+ auto variance_type =
+ tf_batchnorm_op.variance().getType().dyn_cast<RankedTensorType>();
+ if (!variance_type || !mean_type) return failure();
+
+ Value mean_val, variance_val;
+
+ if (mean_type.getNumElements() == 0) {
+ mean_val = getTosaConstTensorSingleF32(rewriter, tf_batchnorm_op, 0);
+ } else {
+ mean_val = tf_batchnorm_op.mean();
+ }
+
+ if (variance_type.getNumElements() == 0) {
+ variance_val = getTosaConstTensorSingleF32(rewriter, tf_batchnorm_op, 1.0);
+ } else {
+ variance_val = tf_batchnorm_op.variance();
+ }
+
+ auto epsilon_type =
+ RankedTensorType::get({1}, variance_type.getElementType());
+ auto epsilon_attr =
+ DenseFPElementsAttr::get(epsilon_type, {tf_batchnorm_op.epsilon()});
+ auto epsilon_const =
+ rewriter.create<tosa::ConstOp>(op->getLoc(), epsilon_type, epsilon_attr);
+
+ auto op1_sub_input_mean = rewriter.create<tosa::SubOp>(
+ op->getLoc(), tf_batchnorm_op.getResult(0).getType(), tf_batchnorm_op.x(),
+ mean_val);
+
+ auto op2_add_var_epsilon =
+ rewriter.create<tosa::AddOp>(op->getLoc(), variance_val.getType(),
+ variance_val, epsilon_const.getResult());
+
+ auto op3_rsqrt_op2 = rewriter.create<tosa::RsqrtOp>(
+ op->getLoc(), variance_val.getType(), op2_add_var_epsilon.getResult());
+
+ auto op4_mul_op1_op3 = rewriter.create<tosa::MulOp>(
+ op->getLoc(), tf_batchnorm_op.getResult(0).getType(),
+ op1_sub_input_mean.getResult(), op3_rsqrt_op2.getResult(), 0);
+
+ auto op5_mul_op4_scale = rewriter.create<tosa::MulOp>(
+ op->getLoc(), tf_batchnorm_op.getResult(0).getType(),
+ op4_mul_op1_op3.getResult(), tf_batchnorm_op.scale(), 0);
+
+ auto op6_add_op5_offset = rewriter.create<tosa::AddOp>(
+ op->getLoc(), tf_batchnorm_op.getResult(0).getType(),
+ op5_mul_op4_scale.getResult(), tf_batchnorm_op.offset());
+
+ rewriter.replaceOp(op, {op6_add_op5_offset.getResult()});
+ return success();
+}
+
+LogicalResult ConvertTFFusedBatchNormV3Op::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_batchnorm_op = cast<TF::FusedBatchNormV3Op>(op);
+
+ auto output_type =
+ tf_batchnorm_op.getResult(0).getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ // Lowering:
+ // fused batchnorm = (input - mean) * scale * rsqrt(var + epsilon) + offset
+ // op1 = sub(input, mean)
+ // op2 = add(var, epsilon)
+ // op3 = rsqrt(op2)
+ // op4 = mul(op1, op3)
+ // op5 = mul(op4, scale)
+ // op6 = add(op5, offset)
+
+ auto op1_sub_input_mean = rewriter.create<tosa::SubOp>(
+ op->getLoc(), tf_batchnorm_op.getResult(0).getType(), tf_batchnorm_op.x(),
+ tf_batchnorm_op.mean());
+
+ auto variance_type =
+ tf_batchnorm_op.variance().getType().dyn_cast<RankedTensorType>();
+ if (!variance_type) return failure();
+
+ auto epsilon_type =
+ RankedTensorType::get({1}, variance_type.getElementType());
+ auto epsilon_attr =
+ DenseFPElementsAttr::get(epsilon_type, {tf_batchnorm_op.epsilon()});
+ auto epsilon_const =
+ rewriter.create<tosa::ConstOp>(op->getLoc(), epsilon_type, epsilon_attr);
+
+ auto op2_add_var_epsilon = rewriter.create<tosa::AddOp>(
+ op->getLoc(), tf_batchnorm_op.variance().getType(),
+ tf_batchnorm_op.variance(), epsilon_const);
+
+ auto op3_rsqrt_op2 = rewriter.create<tosa::RsqrtOp>(
+ op->getLoc(), tf_batchnorm_op.variance().getType(),
+ op2_add_var_epsilon.getResult());
+
+ auto op4_mul_op1_op3 = rewriter.create<tosa::MulOp>(
+ op->getLoc(), tf_batchnorm_op.getResult(0).getType(),
+ op1_sub_input_mean.getResult(), op3_rsqrt_op2.getResult(), 0);
+
+ auto op5_mul_op4_scale = rewriter.create<tosa::MulOp>(
+ op->getLoc(), tf_batchnorm_op.getResult(0).getType(),
+ op4_mul_op1_op3.getResult(), tf_batchnorm_op.scale(), 0);
+
+ auto op6_add_op5_offset = rewriter.create<tosa::AddOp>(
+ op->getLoc(), tf_batchnorm_op.getResult(0).getType(),
+ op5_mul_op4_scale.getResult(), tf_batchnorm_op.offset());
+
+ rewriter.replaceOp(op, {op6_add_op5_offset.getResult()});
+ return success();
+}
+
+LogicalResult ConvertTFBiasAddOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_biasadd_op = cast<TF::BiasAddOp>(op);
+
+ auto output_type =
+ tf_biasadd_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ auto add_op = rewriter.create<tosa::AddOp>(
+ op->getLoc(), output_type, tf_biasadd_op.value(), tf_biasadd_op.bias());
+
+ rewriter.replaceOp(op, {add_op.getResult()});
+ return success();
+}
+
+LogicalResult ConvertTFSliceOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_slice_op = cast<TF::SliceOp>(op);
+
+ auto output_type =
+ tf_slice_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ ElementsAttr begin_elems, size_elems;
+
+ SmallVector<int64_t, 4> begin_vals, size_vals;
+
+ // Assume begin is always a compile-time constant.
+ if (!matchPattern(tf_slice_op.begin(), m_Constant(&begin_elems))) {
+ return op->emitOpError("TF::Slice error: begin is not constant");
+ }
+
+ for (int i = 0; i < begin_elems.getNumElements(); i++)
+ begin_vals.push_back(begin_elems.getValue<IntegerAttr>(i).getInt());
+
+ // Try to match size as a compile-time constant first; if that fails,
+ // fall back to the output tensor's shape instead.
+ if (matchPattern(tf_slice_op.size(), m_Constant(&size_elems))) {
+ for (int i = 0; i < size_elems.getNumElements(); i++)
+ size_vals.push_back(size_elems.getValue<IntegerAttr>(i).getInt());
+ } else {
+ size_vals.assign(output_type.getShape().begin(),
+ output_type.getShape().end());
+ }
+
+ ArrayAttr begin = rewriter.getI64ArrayAttr(begin_vals);
+ ArrayAttr size = rewriter.getI64ArrayAttr(size_vals);
+
+ rewriter.replaceOpWithNewOp<tosa::SliceOp>(op, output_type,
+ tf_slice_op.input(), begin, size);
+ return success();
+}
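+
+// For example (illustrative): with constant begin = [1, 2] and size = [2, 3],
+// %0 = "tf.Slice"(%in, %begin, %size)
+// : (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xf32>
+// becomes the single op
+// %1 = "tosa.slice"(%in) {start = [1, 2], size = [2, 3]}
+// : (tensor<4x8xf32>) -> tensor<2x3xf32>
+// (this assumes the begin/size operands map to the TOSA slice's start/size
+// attributes, per the builder call above).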
+
+LogicalResult ConvertTFTileOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_tile_op = cast<TF::TileOp>(op);
+
+ auto output_type =
+ tf_tile_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ ElementsAttr multiples_elems;
+ if (!matchPattern(tf_tile_op.multiples(), m_Constant(&multiples_elems)))
+ return failure();
+ SmallVector<int64_t, 4> multiples_vals;
+ for (int i = 0; i < multiples_elems.getNumElements(); i++)
+ multiples_vals.push_back(multiples_elems.getValue<IntegerAttr>(i).getInt());
+
+ ArrayAttr multiples_attr = rewriter.getI64ArrayAttr(multiples_vals);
+
+ rewriter.replaceOpWithNewOp<tosa::TileOp>(op, output_type, tf_tile_op.input(),
+ multiples_attr);
+
+ return success();
+}
+
+LogicalResult ConvertTFTransposeOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_transpose_op = cast<TF::TransposeOp>(op);
+
+ auto output_type =
+ tf_transpose_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) {
+ return failure();
+ }
+
+ rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
+ op, output_type, tf_transpose_op.x(), tf_transpose_op.perm());
+
+ return success();
+}
+
+LogicalResult ConvertTFPackOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_pack_op = cast<TF::PackOp>(op);
+
+ SmallVector<Value, 8> inputs(tf_pack_op.values());
+
+ assert(inputs.size() >= 2);
+
+ IntegerAttr axis_attr;
+ {
+ auto tmpAttr = tf_pack_op.axisAttr();
+ if (!tmpAttr) tmpAttr = rewriter.getI64IntegerAttr(0);
+ axis_attr = tmpAttr;
+ }
+ int32_t axis_i32 = axis_attr.getInt();
+
+ auto lowered_op =
+ convertPackOp(rewriter, op, tf_pack_op.getResult(), inputs, axis_i32);
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFUnpackOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_unpack_op = cast<TF::UnpackOp>(op);
+
+ IntegerAttr axis_attr;
+ {
+ auto tmpAttr = tf_unpack_op.axisAttr();
+ if (!tmpAttr) tmpAttr = rewriter.getI64IntegerAttr(0);
+ axis_attr = tmpAttr;
+ }
+ int32_t axis_i32 = axis_attr.getInt();
+
+ auto lowered_op =
+ convertUnpackOp(rewriter, op, tf_unpack_op.value(), axis_i32);
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+// Splits the input into num_split parts along split_dim.
+LogicalResult ConvertTFSplitOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_split_op = cast<TF::SplitOp>(op);
+
+ // Get the number of splits
+ int32_t num_split = -1;
+
+ auto range = tf_split_op.getODSResults(0);
+ num_split = std::distance(range.begin(), range.end());
+
+ // Get the axis
+ int32_t axis = 0;
+ ElementsAttr axisAttrElems;
+ if (matchPattern(tf_split_op.split_dim(), m_Constant(&axisAttrElems))) {
+ axis = axisAttrElems.getValue<IntegerAttr>({}).getInt();
+ }
+
+ auto lowered_op = convertSplitOp(rewriter, op, tf_split_op.getResult(0),
+ tf_split_op.value(), num_split, axis);
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+// TF SplitV splits the input based on a vector of split sizes.
+LogicalResult ConvertTFSplitVOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_splitv_op = cast<TF::SplitVOp>(op);
+
+ // Get the size_splits array
+ SmallVector<int32_t, 4> size_split;
+ ElementsAttr size_split_elems;
+ if (!matchPattern(tf_splitv_op.size_splits(),
+ m_Constant(&size_split_elems))) {
+ return failure();
+ }
+
+ for (int i = 0; i < size_split_elems.getNumElements(); i++) {
+ size_split.push_back(size_split_elems.getValue<IntegerAttr>(i).getInt());
+ }
+
+ // Get the axis
+ ElementsAttr axisAttrElems;
+ if (!matchPattern(tf_splitv_op.split_dim(), m_Constant(&axisAttrElems))) {
+ return op->emitOpError("Cannot read split_dim elems");
+ }
+
+ int32_t axis = axisAttrElems.getValue<IntegerAttr>(0).getInt();
+
+ auto lowered_op = convertSplitVOp(rewriter, op, tf_splitv_op.getResult(0),
+ tf_splitv_op.value(), size_split, axis);
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
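+
+// Illustrative: size_splits = [3, 7] with axis = 0 on a tensor<10x4xf32>
+// yields results of shape tensor<3x4xf32> and tensor<7x4xf32>; the shared
+// helper is assumed to realize this as one tosa.slice per output.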
+
+LogicalResult ConvertTFLessOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_less_op = cast<TF::LessOp>(op);
+
+ auto output_type =
+ tf_less_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ // less(x, y) is not(greater_equal(x, y))
+ auto greater_equal_op = rewriter.create<tosa::GreaterEqualOp>(
+ op->getLoc(), output_type, tf_less_op.x(), tf_less_op.y());
+
+ auto not_op = rewriter.create<tosa::LogicalNotOp>(
+ op->getLoc(), output_type, greater_equal_op.getResult());
+
+ rewriter.replaceOp(op, {not_op.getResult()});
+ return success();
+}
+
+LogicalResult ConvertTFLessEqualOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_less_equal_op = cast<TF::LessEqualOp>(op);
+
+ auto output_type =
+ tf_less_equal_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ // less_equal(x, y) is not(greater(x, y))
+ auto greater_op = rewriter.create<tosa::GreaterOp>(
+ op->getLoc(), output_type, tf_less_equal_op.x(), tf_less_equal_op.y());
+
+ auto not_op = rewriter.create<tosa::LogicalNotOp>(op->getLoc(), output_type,
+ greater_op.getResult());
+
+ rewriter.replaceOp(op, {not_op.getResult()});
+ return success();
+}
+
+LogicalResult ConvertTFPadOp::matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const {
+ auto tf_pad_op = cast<TF::PadOp>(op);
+
+ auto output_type =
+ tf_pad_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ auto pad_op = rewriter.create<tosa::PadOp>(
+ op->getLoc(), output_type, tf_pad_op.input(), tf_pad_op.paddings());
+
+ rewriter.replaceOp(op, {pad_op.getResult()});
+ return success();
+}
+
+LogicalResult ConvertTFResizeBilinearOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_resize_op = cast<TF::ResizeBilinearOp>(op);
+
+ auto output_type =
+ tf_resize_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ auto lowered_op = convertResizeOp(
+ rewriter, op, output_type, tf_resize_op.images(), StringRef("BILINEAR"));
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFResizeNearestNeighborOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_resize_op = cast<TF::ResizeNearestNeighborOp>(op);
+
+ auto output_type =
+ tf_resize_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ auto lowered_op = convertResizeOp(
+ rewriter, op, output_type, tf_resize_op.images(), StringRef("NEAREST"));
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFMatMulOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_matmul_op = cast<TF::MatMulOp>(op);
+
+ auto a_type = tf_matmul_op.a().getType().dyn_cast<RankedTensorType>();
+ auto b_type = tf_matmul_op.b().getType().dyn_cast<RankedTensorType>();
+ auto output_type =
+ tf_matmul_op.getResult().getType().dyn_cast<RankedTensorType>();
+
+ if (!(a_type && b_type && output_type)) {
+ return op->emitOpError("MatMul: a/b/output not ranked tensors");
+ }
+
+ // Can only handle rank=2 inputs
+ if (a_type.getShape().size() != 2) return failure();
+
+ rewriter.replaceOpWithNewOp<tosa::MatMulOp>(op, output_type, tf_matmul_op.a(),
+ tf_matmul_op.b());
+
+ return success();
+}
+
+LogicalResult ConvertTFGatherOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_gather_op = cast<TF::GatherOp>(op);
+
+ auto output_type =
+ tf_gather_op.getResult().getType().dyn_cast<RankedTensorType>();
+ if (!output_type) return failure();
+
+ IntegerAttr axis_attr = rewriter.getI32IntegerAttr(0);
+
+ // TODO: batchdim_attr handling to be implemented with a revised
+ // definition of the TOSA operator.
+ rewriter.replaceOpWithNewOp<tosa::GatherOp>(
+ op, output_type, tf_gather_op.params(), tf_gather_op.indices(),
+ axis_attr);
+
+ return success();
+}
+
+LogicalResult ConvertTFGatherV2Op::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_gather_op = cast<TF::GatherV2Op>(op);
+
+ auto output_type =
+ tf_gather_op.getResult().getType().dyn_cast<RankedTensorType>();
+ if (!output_type) return failure();
+
+ // Axis is a tensor in TF. Convert to an I32 integer attribute for TOSA.
+ ElementsAttr axis_elem;
+ if (!matchPattern(tf_gather_op.axis(), m_Constant(&axis_elem)))
+ return failure();
+ assert(axis_elem.getType().getRank() == 0 && "expected 0D tensor");
+
+ IntegerAttr batchdim_attr;
+ {
+ auto tmpAttr = tf_gather_op.batch_dimsAttr();
+ if (!tmpAttr) tmpAttr = rewriter.getI64IntegerAttr(0);
+ batchdim_attr = tmpAttr;
+ }
+
+ // TODO: batchdim_attr handling to be implemented with a revised
+ // definition of the TOSA operator.
+ rewriter.replaceOpWithNewOp<tosa::GatherOp>(
+ op, output_type, tf_gather_op.params(), tf_gather_op.indices(),
+ rewriter.getI32IntegerAttr(axis_elem.getValue<IntegerAttr>({}).getInt()));
+
+ return success();
+}
+
+LogicalResult ConvertTFSelectV2Op::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_sel_op = cast<TF::SelectV2Op>(op);
+
+ auto lowered_op =
+ convertSelectOp(rewriter, op, tf_sel_op.getResult(),
+ tf_sel_op.condition(), tf_sel_op.t(), tf_sel_op.e());
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFSpaceToDepthOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_s2d_op = cast<TF::SpaceToDepthOp>(op);
+
+ auto lowered_op = convertSpaceToDepthOp(
+ rewriter, op, tf_s2d_op.getResult(), tf_s2d_op.input(),
+ tf_s2d_op.block_sizeAttr(), tf_s2d_op.data_formatAttr());
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFDepthToSpaceOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_d2s_op = cast<TF::DepthToSpaceOp>(op);
+
+ auto lowered_op = convertDepthToSpaceOp(
+ rewriter, op, tf_d2s_op.getResult(), tf_d2s_op.input(),
+ tf_d2s_op.block_sizeAttr(), tf_d2s_op.data_formatAttr());
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFSpaceToBatchNDOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_s2b_op = cast<TF::SpaceToBatchNDOp>(op);
+
+ auto lowered_op = convertSpaceToBatchNDOp(
+ rewriter, op, tf_s2b_op.getResult(), tf_s2b_op.input(),
+ tf_s2b_op.block_shape(), tf_s2b_op.paddings());
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFBatchToSpaceNDOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_b2s_op = cast<TF::BatchToSpaceNDOp>(op);
+
+ auto lowered_op = convertBatchToSpaceNDOp(
+ rewriter, op, tf_b2s_op.getResult(), tf_b2s_op.input(),
+ tf_b2s_op.block_shape(), tf_b2s_op.crops());
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFStridedSliceOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_ss_op = cast<TF::StridedSliceOp>(op);
+
+ auto lowered_op = convertStridedSliceOp(
+ rewriter, op, tf_ss_op.getResult(), tf_ss_op.input(), tf_ss_op.begin(),
+ tf_ss_op.end(), tf_ss_op.strides(), tf_ss_op.begin_maskAttr().getInt(),
+ tf_ss_op.end_maskAttr().getInt(), tf_ss_op.ellipsis_maskAttr().getInt(),
+ tf_ss_op.new_axis_maskAttr().getInt(),
+ tf_ss_op.shrink_axis_maskAttr().getInt());
+ TOSA_REPLACE_LOWERED_OP_LIST(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFZerosLikeOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_zeroslike_op = cast<TF::ZerosLikeOp>(op);
+
+ auto lowered_op = convertZerosLikeOp(
+ rewriter, op, tf_zeroslike_op.getResult(), tf_zeroslike_op.x());
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFSigmoidOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_sigmoid_op = cast<TF::SigmoidOp>(op);
+ auto output_type =
+ tf_sigmoid_op.getResult().getType().dyn_cast<RankedTensorType>();
+ if (!output_type) return failure();
+
+ rewriter.replaceOpWithNewOp<tosa::SigmoidOp>(op, output_type,
+ tf_sigmoid_op.x());
+
+ return success();
+}
+
+LogicalResult ConvertTFTanhOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_tanh_op = cast<TF::TanhOp>(op);
+ auto output_type =
+ tf_tanh_op.getResult().getType().dyn_cast<RankedTensorType>();
+ if (!output_type) return failure();
+
+ rewriter.replaceOpWithNewOp<tosa::TanhOp>(op, output_type, tf_tanh_op.x());
+
+ return success();
+}
+
+LogicalResult ConvertTFLeakyReluOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_leakyrelu_op = cast<TF::LeakyReluOp>(op);
+ auto output_type =
+ tf_leakyrelu_op.getResult().getType().dyn_cast<RankedTensorType>();
+ if (!output_type) return failure();
+
+ // TODO: add lowering with MUL + SELECT
+
+ return failure();
+}
+
+LogicalResult ConvertTFNegOp::matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const {
+ auto tf_neg_op = cast<TF::NegOp>(op);
+ auto output_type =
+ tf_neg_op.getResult().getType().dyn_cast<RankedTensorType>();
+ if (!output_type) return failure();
+
+ rewriter.replaceOpWithNewOp<tosa::NegateOp>(op, output_type, tf_neg_op.x());
+
+ return success();
+}
+
+LogicalResult ConvertTFStopGradientOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_stopgrad_op = cast<TF::StopGradientOp>(op);
+ auto output_type =
+ tf_stopgrad_op.getResult().getType().dyn_cast<RankedTensorType>();
+ if (!output_type) return failure();
+
+ rewriter.replaceOpWithNewOp<tosa::IdentityOp>(op, output_type,
+ tf_stopgrad_op.input());
+
+ return success();
+}
+
+LogicalResult ConvertTFReverseV2Op::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_reverse_op = cast<TF::ReverseV2Op>(op);
+ auto input_type =
+ tf_reverse_op.tensor().getType().dyn_cast<RankedTensorType>();
+ auto output_type =
+ tf_reverse_op.getResult().getType().dyn_cast<RankedTensorType>();
+ if (!input_type || !output_type) return failure();
+
+ ElementsAttr axis_elems;
+ if (!matchPattern(tf_reverse_op.axis(), m_Constant(&axis_elems)))
+ return failure();
+
+ auto input_rank = input_type.getShape().size();
+ Value val = tf_reverse_op.tensor();
+ if (axis_elems.getNumElements() == 0) {
+ auto identity_op =
+ rewriter.create<tosa::IdentityOp>(op->getLoc(), output_type, val);
+ val = identity_op.getResult();
+ } else {
+ for (int i = 0; i < axis_elems.getNumElements(); i++) {
+ int64_t axis_val = axis_elems.getValue<IntegerAttr>(i).getInt();
+ if (axis_val < 0) axis_val += input_rank;
+ auto axis_attr = rewriter.getI64IntegerAttr(axis_val);
+ auto reverse_op = rewriter.create<tosa::ReverseOp>(
+ op->getLoc(), output_type, val, axis_attr);
+
+ val = reverse_op.getResult();
+ }
+ }
+
+ rewriter.replaceOp(op, {val});
+
+ return success();
+}
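+
+// Illustrative: axes = [0, 2] on a rank-3 input emits two chained
+// tosa.reverse ops (one per axis), while an empty axis list degenerates to
+// a single tosa.identity, as handled above.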
+
+LogicalResult ConvertTFFakeQuantWithMinMaxArgsOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_fakequant_op = cast<TF::FakeQuantWithMinMaxArgsOp>(op);
+
+ auto output_type =
+ tf_fakequant_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ auto lowered_op =
+ convertFakeQuantOp(rewriter, op, output_type, tf_fakequant_op.inputs(),
+ tf_fakequant_op.minAttr().getValueAsDouble(),
+ tf_fakequant_op.maxAttr().getValueAsDouble(),
+ tf_fakequant_op.num_bitsAttr().getInt(),
+ tf_fakequant_op.narrow_rangeAttr().getValue());
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFFakeQuantWithMinMaxVarsOp::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto tf_fakequant_op = cast<TF::FakeQuantWithMinMaxVarsOp>(op);
+
+ auto output_type =
+ tf_fakequant_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ // Only support min/max that can be matched at compile time
+ ElementsAttr min_elems, max_elems;
+ if (!matchPattern(tf_fakequant_op.min(), m_Constant(&min_elems)))
+ return failure();
+
+ if (!matchPattern(tf_fakequant_op.max(), m_Constant(&max_elems)))
+ return failure();
+
+ if (min_elems.getNumElements() != 1 || max_elems.getNumElements() != 1)
+ return failure();
+
+ double min_val = min_elems.getValue<FloatAttr>(0).getValueAsDouble();
+ double max_val = max_elems.getValue<FloatAttr>(0).getValueAsDouble();
+
+ auto lowered_op = convertFakeQuantOp(
+ rewriter, op, output_type, tf_fakequant_op.inputs(), min_val, max_val,
+ tf_fakequant_op.num_bitsAttr().getInt(),
+ tf_fakequant_op.narrow_rangeAttr().getValue());
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+void LegalizeTF::runOnFunction() {
+ OwningRewritePatternList patterns;
+ auto *ctx = &getContext();
+ auto func = getFunction();
+
+ // Add the generated patterns to the list.
+ populateWithGenerated(ctx, patterns);
+ patterns.insert<ConvertTFMatMulOp>(ctx);
+ patterns.insert<ConvertTFReluOp>(ctx);
+ patterns.insert<ConvertTFRelu6Op>(ctx);
+ patterns.insert<ConvertTFEqualOp>(ctx);
+ patterns.insert<ConvertTFNotEqualOp>(ctx);
+ patterns.insert<ConvertTFGreaterOp>(ctx);
+ patterns.insert<ConvertTFGreaterEqualOp>(ctx);
+ patterns.insert<ConvertTFAddOp>(ctx);
+ patterns.insert<ConvertTFAddV2Op>(ctx);
+ patterns.insert<ConvertTFAddNOp>(ctx);
+ patterns.insert<ConvertTFSubOp>(ctx);
+ patterns.insert<ConvertTFMulOp>(ctx);
+ patterns.insert<ConvertTFSquareOp>(ctx);
+ patterns.insert<ConvertTFSquaredDifferenceOp>(ctx);
+ patterns.insert<ConvertTFRoundOp>(ctx);
+ patterns.insert<ConvertTFFloorDivOp>(ctx);
+ patterns.insert<ConvertTFFloorModOp>(ctx);
+ patterns.insert<ConvertTFAssertOp>(ctx);
+ patterns.insert<ConvertTFMaximumOp>(ctx);
+ patterns.insert<ConvertTFMinimumOp>(ctx);
+ patterns.insert<ConvertTFRealDivOp>(ctx);
+ patterns.insert<ConvertTFArgMaxOp>(ctx);
+ patterns.insert<ConvertTFAvgPoolOp>(ctx);
+ patterns.insert<ConvertTFMaxPoolOp>(ctx);
+ patterns.insert<ConvertTFConcatV2Op>(ctx);
+ patterns.insert<ConvertTFReshapeOp>(ctx);
+ patterns.insert<ConvertTFRankOp>(ctx);
+ patterns.insert<ConvertTFShapeOp>(ctx);
+ patterns.insert<ConvertTFExpandDimsOp>(ctx);
+ patterns.insert<ConvertTFSqueezeOp>(ctx);
+ patterns.insert<ConvertTFFillOp>(ctx);
+ patterns.insert<ConvertTFConv2DOp>(ctx);
+ patterns.insert<ConvertTFDepthwiseConv2dNativeOp>(ctx);
+ patterns.insert<ConvertTFConv2DBackpropInputOp>(ctx);
+ patterns.insert<ConvertTFEluOp>(ctx);
+ patterns.insert<ConvertTFSoftmaxOp>(ctx);
+ patterns.insert<ConvertTFLogSoftmaxOp>(ctx);
+ patterns.insert<ConvertTFAllOp>(ctx);
+ patterns.insert<ConvertTFAnyOp>(ctx);
+ patterns.insert<ConvertTFMaxOp>(ctx);
+ patterns.insert<ConvertTFMinOp>(ctx);
+ patterns.insert<ConvertTFMeanOp>(ctx);
+ patterns.insert<ConvertTFProdOp>(ctx);
+ patterns.insert<ConvertTFSumOp>(ctx);
+ patterns.insert<ConvertTFFusedBatchNormOp>(ctx);
+ patterns.insert<ConvertTFFusedBatchNormV3Op>(ctx);
+ patterns.insert<ConvertTFBiasAddOp>(ctx);
+ patterns.insert<ConvertTFSplitOp>(ctx);
+ patterns.insert<ConvertTFSplitVOp>(ctx);
+ patterns.insert<ConvertTFPackOp>(ctx);
+ patterns.insert<ConvertTFUnpackOp>(ctx);
+ patterns.insert<ConvertTFTransposeOp>(ctx);
+ patterns.insert<ConvertTFTileOp>(ctx);
+ patterns.insert<ConvertTFSliceOp>(ctx);
+ patterns.insert<ConvertTFStridedSliceOp>(ctx);
+ patterns.insert<ConvertTFLessOp>(ctx);
+ patterns.insert<ConvertTFLessEqualOp>(ctx);
+ patterns.insert<ConvertTFPadOp>(ctx);
+ patterns.insert<ConvertTFResizeBilinearOp>(ctx);
+ patterns.insert<ConvertTFResizeNearestNeighborOp>(ctx);
+ patterns.insert<ConvertTFGatherOp>(ctx);
+ patterns.insert<ConvertTFGatherV2Op>(ctx);
+ patterns.insert<ConvertTFSelectV2Op>(ctx);
+ patterns.insert<ConvertTFSpaceToDepthOp>(ctx);
+ patterns.insert<ConvertTFDepthToSpaceOp>(ctx);
+ patterns.insert<ConvertTFSpaceToBatchNDOp>(ctx);
+ patterns.insert<ConvertTFBatchToSpaceNDOp>(ctx);
+ patterns.insert<ConvertTFZerosLikeOp>(ctx);
+ patterns.insert<ConvertTFSigmoidOp>(ctx);
+ patterns.insert<ConvertTFTanhOp>(ctx);
+ patterns.insert<ConvertTFLeakyReluOp>(ctx);
+ patterns.insert<ConvertTFNegOp>(ctx);
+ patterns.insert<ConvertTFStopGradientOp>(ctx);
+ patterns.insert<ConvertTFReverseV2Op>(ctx);
+ patterns.insert<ConvertTFFakeQuantWithMinMaxArgsOp>(ctx);
+ patterns.insert<ConvertTFFakeQuantWithMinMaxVarsOp>(ctx);
+ applyPatternsAndFoldGreedily(func, std::move(patterns));
+}
+
+} // anonymous namespace
+
+// Creates an instance of the TensorFlow-to-TOSA LegalizeTF pass.
+std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass() {
+ return std::make_unique<LegalizeTF>();
+}
+
+static PassRegistration<LegalizeTF> pass(
+ PASS_NAME, "Legalize from TensorFlow to TOSA dialect");
+
+} // namespace tosa
+
+} // namespace mlir
diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc
new file mode 100644
index 0000000..c8a5b16
--- /dev/null
+++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc
@@ -0,0 +1,2838 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Legalize TensorFlow Lite to TOSA
+
+#include <climits>
+#include <cstddef>
+#include <cstdint>
+#include <fstream>
+#include <iterator>
+#include <numeric>
+#include <unordered_set>
+
+#include "mlir/Dialect/Quant/FakeQuantSupport.h"
+#include "mlir/Dialect/Quant/QuantTypes.h"
+#include "mlir/Dialect/Quant/UniformSupport.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
+#include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
+#include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"
+#include "tensorflow/compiler/mlir/tosa/transforms/passes.h"
+
+#define PASS_NAME "tosa-legalize-tfl"
+#define DEBUG_TYPE PASS_NAME
+#define HARDSWISH_EXPLICIT_RESCALING false
+
+// Conditionally avoid converting some TFLite ops to TOSA.
+// By default, all conversions will be invoked.
+//
+// The denylist file lists patterns that are not legalized from TFLite to TOSA.
+llvm::cl::opt<std::string> tfl_tosa_denylist(
+ "tfl-tosa-denylist",
+ llvm::cl::desc("<a list of patterns not legalized from TFLite to TOSA>"),
+ llvm::cl::init("transforms/tfl_tosa_denylist.txt"),
+ llvm::cl::value_desc("pattern name"));
+
+namespace mlir {
+
+namespace tosa {
+
+namespace {
+// Performs lowering to TOSA dialect.
+class LegalizeTFL : public PassWrapper<LegalizeTFL, FunctionPass> {
+ public:
+ explicit LegalizeTFL() {}
+ void runOnFunction() override;
+};
+
+#include "tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.inc"
+
+#define DECL_CONVERT_OP(tfl_op) \
+ struct ConvertTFL##tfl_op##Op : public RewritePattern { \
+ explicit ConvertTFL##tfl_op##Op(MLIRContext* context) \
+ : RewritePattern(TFL::tfl_op##Op::getOperationName(), 1, context) {} \
+ LogicalResult matchAndRewrite(Operation* op, \
+ PatternRewriter& rewriter) const override; \
+ }
+DECL_CONVERT_OP(Relu);
+DECL_CONVERT_OP(Relu6);
+DECL_CONVERT_OP(Equal);
+DECL_CONVERT_OP(NotEqual);
+DECL_CONVERT_OP(Greater);
+DECL_CONVERT_OP(GreaterEqual);
+DECL_CONVERT_OP(Add);
+DECL_CONVERT_OP(Sub);
+DECL_CONVERT_OP(Mul);
+DECL_CONVERT_OP(Square);
+DECL_CONVERT_OP(SquaredDifference);
+DECL_CONVERT_OP(Round);
+DECL_CONVERT_OP(Div);
+DECL_CONVERT_OP(Maximum);
+DECL_CONVERT_OP(Minimum);
+DECL_CONVERT_OP(FloorMod);
+DECL_CONVERT_OP(FloorDiv);
+DECL_CONVERT_OP(AddN);
+DECL_CONVERT_OP(AveragePool2D);
+DECL_CONVERT_OP(MaxPool2D);
+DECL_CONVERT_OP(Concatenation);
+DECL_CONVERT_OP(Reshape);
+DECL_CONVERT_OP(Rank);
+DECL_CONVERT_OP(Shape);
+DECL_CONVERT_OP(ExpandDims);
+DECL_CONVERT_OP(Squeeze);
+DECL_CONVERT_OP(Fill);
+DECL_CONVERT_OP(Elu);
+DECL_CONVERT_OP(Softmax);
+DECL_CONVERT_OP(LogSoftmax);
+DECL_CONVERT_OP(ReduceAny);
+DECL_CONVERT_OP(ReduceMax);
+DECL_CONVERT_OP(ReduceMin);
+DECL_CONVERT_OP(Mean);
+DECL_CONVERT_OP(ReduceProd);
+DECL_CONVERT_OP(Sum);
+DECL_CONVERT_OP(Conv2D);
+DECL_CONVERT_OP(TransposeConv);
+DECL_CONVERT_OP(DepthwiseConv2D);
+DECL_CONVERT_OP(FullyConnected);
+DECL_CONVERT_OP(Split);
+DECL_CONVERT_OP(SplitV);
+DECL_CONVERT_OP(Pack);
+DECL_CONVERT_OP(Unpack);
+DECL_CONVERT_OP(Transpose);
+DECL_CONVERT_OP(Tile);
+DECL_CONVERT_OP(Slice);
+DECL_CONVERT_OP(StridedSlice);
+DECL_CONVERT_OP(HardSwish);
+DECL_CONVERT_OP(ZerosLike);
+DECL_CONVERT_OP(Less);
+DECL_CONVERT_OP(LessEqual);
+DECL_CONVERT_OP(Pad);
+DECL_CONVERT_OP(ResizeBilinear);
+DECL_CONVERT_OP(ResizeNearestNeighbor);
+DECL_CONVERT_OP(Select);
+DECL_CONVERT_OP(SelectV2);
+DECL_CONVERT_OP(SpaceToBatchNd);
+DECL_CONVERT_OP(BatchToSpaceNd);
+DECL_CONVERT_OP(SpaceToDepth);
+DECL_CONVERT_OP(DepthToSpace);
+DECL_CONVERT_OP(Logistic);
+DECL_CONVERT_OP(Tanh);
+DECL_CONVERT_OP(PRelu);
+DECL_CONVERT_OP(LeakyRelu);
+DECL_CONVERT_OP(Neg);
+DECL_CONVERT_OP(Yield);
+DECL_CONVERT_OP(Custom);
+DECL_CONVERT_OP(ReverseV2);
+DECL_CONVERT_OP(Quantize);
+DECL_CONVERT_OP(Dequantize);
+DECL_CONVERT_OP(QConst);
+#undef DECL_CONVERT_OP
+
+// TODO: remove this macro once common legalization functions return
+// llvm::Optional<>.
+//
+// Helper macro for checking the return value of a common legalization
+// function that returns a single tensor. Packs the result in a list.
+#define TOSA_REPLACE_LOWERED_OP(REWRITER, OP, LOWERED_OP) \
+ if (LOWERED_OP) { \
+ REWRITER.replaceOp((OP), {(LOWERED_OP)->getResults()}); \
+ return success(); \
+ } else { \
+ return failure(); \
+ }
+
+// TODO: remove this macro once common legalization functions return
+// llvm::Optional<>.
+//
+// Helper macro for checking the return value of a common legalization
+// function that returns a tensor list.
+#define TOSA_REPLACE_LOWERED_OP_LIST(REWRITER, OP, LOWERED_OP) \
+ if (LOWERED_OP) { \
+ REWRITER.replaceOp((OP), (LOWERED_OP)->getResults()); \
+ return success(); \
+ } else { \
+ return failure(); \
+ }
+
+LogicalResult ConvertTFLReluOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_relu_op = cast<TFL::ReluOp>(op);
+
+ auto input_type = tfl_relu_op.x().getType().dyn_cast<RankedTensorType>();
+ auto output_type =
+ tfl_relu_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!input_type || !output_type) return failure();
+
+ bool input_is_qtype =
+ input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+ bool output_is_qtype =
+ output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+
+ if (input_is_qtype != output_is_qtype) {
+ return op->emitOpError(
+ "ConvertTFLReluOp: input/output tensor should "
+ "be all quantized or all floating-point.");
+ }
+
+ Value output;
+ if (output_is_qtype) {
+ auto rescale_type =
+ RankedTensorType::get(output_type.getShape(), rewriter.getI32Type());
+ auto input_qtype = input_type.getElementType()
+ .dyn_cast<mlir::quant::UniformQuantizedType>();
+ auto output_qtype = output_type.getElementType()
+ .dyn_cast<mlir::quant::UniformQuantizedType>();
+
+ auto op1_rescale_in = buildRescaleToInt32(rewriter, op, tfl_relu_op.x(),
+ 1.0f, input_qtype.getZeroPoint());
+ auto op2_relun_op1 = rewriter.create<tosa::ReluNOp>(
+ op->getLoc(), rescale_type, op1_rescale_in,
+ rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max()),
+ rewriter.getF32FloatAttr(0.0f));
+ auto op3_rescale_op2 = buildRescaleFromInt32(
+ rewriter, op, output_type, op2_relun_op1.getResult(), 1.0f,
+ output_qtype.getZeroPoint());
+
+ output = op3_rescale_op2;
+ } else {
+ auto op1_relun_in = rewriter.create<tosa::ReluNOp>(
+ op->getLoc(), output_type, tfl_relu_op.x(),
+ rewriter.getI64IntegerAttr(0),
+ rewriter.getF32FloatAttr(std::numeric_limits<float>::max()));
+
+ output = op1_relun_in.getResult();
+ }
+
+ rewriter.replaceOp(op, {output});
+ return success();
+}
+
+LogicalResult ConvertTFLRelu6Op::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_relu6_op = cast<TFL::Relu6Op>(op);
+
+ auto input_type = tfl_relu6_op.x().getType().dyn_cast<RankedTensorType>();
+ auto output_type =
+ tfl_relu6_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!input_type || !output_type) return failure();
+
+ bool input_is_qtype =
+ input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+ bool output_is_qtype =
+ output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+
+ if (input_is_qtype != output_is_qtype) {
+ return op->emitOpError(
+ "ConvertTFLRelu6Op: input/output tensor should "
+ "be all quantized or all floating-point.");
+ }
+
+ Value output;
+ if (output_is_qtype && input_is_qtype) {
+ auto rescale_type =
+ RankedTensorType::get(output_type.getShape(), rewriter.getI32Type());
+ auto input_qtype = input_type.getElementType()
+ .dyn_cast<mlir::quant::UniformQuantizedType>();
+ auto output_qtype = output_type.getElementType()
+ .dyn_cast<mlir::quant::UniformQuantizedType>();
+ int64_t rescaled_6 = std::llround(6.0f / input_qtype.getScale()) +
+ input_qtype.getZeroPoint();
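+ // For example (illustrative numbers): with input scale = 6.0/255 and zero
+ // point = 0, rescaled_6 = llround(6.0 / (6.0/255)) + 0 = 255, i.e. the
+ // clamp limit of 6.0 expressed in the input's quantized domain.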
+
+ auto op1_rescale_in = buildRescaleToInt32(rewriter, op, tfl_relu6_op.x(),
+ 1.0f, input_qtype.getZeroPoint());
+ auto op2_relun_op1 = rewriter.create<tosa::ReluNOp>(
+ op->getLoc(), rescale_type, op1_rescale_in,
+ rewriter.getI64IntegerAttr(rescaled_6), rewriter.getF32FloatAttr(0.0f));
+ auto op3_rescale_op2 = buildRescaleFromInt32(
+ rewriter, op, output_type, op2_relun_op1.getResult(), 1.0f,
+ output_qtype.getZeroPoint());
+
+ output = op3_rescale_op2;
+ } else {
+ auto op1_relun_in = rewriter.create<tosa::ReluNOp>(
+ op->getLoc(), output_type, tfl_relu6_op.x(),
+ rewriter.getI64IntegerAttr(0), rewriter.getF32FloatAttr(6.0f));
+
+ output = op1_relun_in.getResult();
+ }
+
+ rewriter.replaceOp(op, {output});
+ return success();
+}
+
+// TODO: Use a utility function for common code in comparison ops.
+LogicalResult ConvertTFLEqualOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_equal_op = cast<TFL::EqualOp>(op);
+
+ auto input_x_type = tfl_equal_op.x().getType().dyn_cast<RankedTensorType>();
+ auto input_y_type = tfl_equal_op.y().getType().dyn_cast<RankedTensorType>();
+ auto output_type =
+ tfl_equal_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!input_x_type || !input_y_type || !output_type) return failure();
+
+ bool input_x_is_qtype =
+ input_x_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+ bool input_y_is_qtype =
+ input_y_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+ bool output_is_qtype =
+ output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+
+ if (input_x_is_qtype != output_is_qtype ||
+ input_y_is_qtype != output_is_qtype) {
+ return op->emitOpError(
+ "ConvertTFLEqualOp: input/output tensor should "
+ "be all quantized or all floating-point.");
+ }
+
+ Value output;
+ if (output_is_qtype && input_x_is_qtype && input_y_is_qtype) {
+ auto input_x_qtype = input_x_type.getElementType()
+ .dyn_cast<mlir::quant::UniformQuantizedType>();
+ auto input_y_qtype = input_y_type.getElementType()
+ .dyn_cast<mlir::quant::UniformQuantizedType>();
+
+ if (input_x_qtype.getScale() != input_y_qtype.getScale() ||
+ input_x_qtype.getZeroPoint() != input_y_qtype.getZeroPoint()) {
+ return op->emitOpError(
+ "ConvertTFLEqualOp: input_x and input_y scale/zp "
+ "must be the same");
+ }
+
+ auto op1_rescale_x = buildRescaleToInt32(
+ rewriter, op, tfl_equal_op.x(), 1.0f, input_x_qtype.getZeroPoint());
+ auto op2_rescale_y = buildRescaleToInt32(
+ rewriter, op, tfl_equal_op.y(), 1.0f, input_y_qtype.getZeroPoint());
+ auto op3_equal_op1_op2 = rewriter.create<tosa::EqualOp>(
+ op->getLoc(), output_type, op1_rescale_x, op2_rescale_y);
+
+ output = op3_equal_op1_op2.getResult();
+ } else {
+ auto op1_equal_in = rewriter.create<tosa::EqualOp>(
+ op->getLoc(), output_type, tfl_equal_op.x(), tfl_equal_op.y());
+
+ output = op1_equal_in.getResult();
+ }
+
+ rewriter.replaceOp(op, {output});
+ return success();
+}
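+
+// Note: rescaling both operands into a zero-point-free int32 domain (scale
+// 1.0) lets a single tosa.equal compare the raw quantized values directly;
+// this is only sound because matching scale/zp was verified above.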
+
+LogicalResult ConvertTFLNotEqualOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_not_equal_op = cast<TFL::NotEqualOp>(op);
+
+ auto input_lhs_type =
+ tfl_not_equal_op.lhs().getType().dyn_cast<RankedTensorType>();
+ auto input_rhs_type =
+ tfl_not_equal_op.rhs().getType().dyn_cast<RankedTensorType>();
+ auto output_type =
+ tfl_not_equal_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!input_lhs_type || !input_rhs_type || !output_type) return failure();
+
+ bool input_lhs_is_qtype =
+ input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+ bool input_rhs_is_qtype =
+ input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+ bool output_is_qtype =
+ output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+
+ if (input_lhs_is_qtype != output_is_qtype ||
+ input_rhs_is_qtype != output_is_qtype) {
+ return op->emitOpError(
+ "ConvertTFLNotEqualOp: input/output tensor should "
+ "be all quantized or all floating-point.");
+ }
+
+ Value output;
+ if (output_is_qtype && input_lhs_is_qtype && input_rhs_is_qtype) {
+ auto input_lhs_qtype = input_lhs_type.getElementType()
+ .dyn_cast<mlir::quant::UniformQuantizedType>();
+ auto input_rhs_qtype = input_rhs_type.getElementType()
+ .dyn_cast<mlir::quant::UniformQuantizedType>();
+
+ if (input_lhs_qtype.getScale() != input_rhs_qtype.getScale() ||
+ input_lhs_qtype.getZeroPoint() != input_rhs_qtype.getZeroPoint()) {
+ return op->emitOpError(
+ "ConvertTFLNotEqualOp: lhs and rhs scale/zp "
+ "must be the same");
+ }
+
+ auto op1_rescale_lhs =
+ buildRescaleToInt32(rewriter, op, tfl_not_equal_op.lhs(), 1.0f,
+ input_lhs_qtype.getZeroPoint());
+ auto op2_rescale_rhs =
+ buildRescaleToInt32(rewriter, op, tfl_not_equal_op.rhs(), 1.0f,
+ input_rhs_qtype.getZeroPoint());
+ auto op3_equal_op1_op2 = rewriter.create<tosa::EqualOp>(
+ op->getLoc(), output_type, op1_rescale_lhs, op2_rescale_rhs);
+ auto op4_not_op3 = rewriter.create<tosa::LogicalNotOp>(
+ op->getLoc(), output_type, op3_equal_op1_op2.getResult());
+
+ output = op4_not_op3.getResult();
+ } else {
+ auto op1_equal_in = rewriter.create<tosa::EqualOp>(
+ op->getLoc(), output_type, tfl_not_equal_op.lhs(),
+ tfl_not_equal_op.rhs());
+ auto op2_not_op1 = rewriter.create<tosa::LogicalNotOp>(
+ op->getLoc(), output_type, op1_equal_in.getResult());
+
+ output = op2_not_op1.getResult();
+ }
+
+ rewriter.replaceOp(op, {output});
+ return success();
+}
+
+LogicalResult ConvertTFLGreaterOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_greater_op = cast<TFL::GreaterOp>(op);
+
+ auto input_lhs_type =
+ tfl_greater_op.lhs().getType().dyn_cast<RankedTensorType>();
+ auto input_rhs_type =
+ tfl_greater_op.rhs().getType().dyn_cast<RankedTensorType>();
+ auto output_type =
+ tfl_greater_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!input_lhs_type || !input_rhs_type || !output_type) return failure();
+
+ bool input_lhs_is_qtype =
+ input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+ bool input_rhs_is_qtype =
+ input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+ bool output_is_qtype =
+ output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+
+ if (input_lhs_is_qtype != output_is_qtype ||
+ input_rhs_is_qtype != output_is_qtype) {
+ return op->emitOpError(
+ "ConvertTFLGreaterOp: input/output tensor should "
+ "be all quantized or all floating-point.");
+ }
+
+ Value output;
+ if (output_is_qtype && input_lhs_is_qtype && input_rhs_is_qtype) {
+ auto input_lhs_qtype = input_lhs_type.getElementType()
+ .dyn_cast<mlir::quant::UniformQuantizedType>();
+ auto input_rhs_qtype = input_rhs_type.getElementType()
+ .dyn_cast<mlir::quant::UniformQuantizedType>();
+
+ if (input_lhs_qtype.getScale() != input_rhs_qtype.getScale() ||
+ input_lhs_qtype.getZeroPoint() != input_rhs_qtype.getZeroPoint()) {
+ return op->emitOpError(
+ "ConvertTFLGreaterOp: lhs and rhs scale/zp "
+ "must be the same");
+ }
+
+ auto op1_rescale_lhs =
+ buildRescaleToInt32(rewriter, op, tfl_greater_op.lhs(), 1.0f,
+ input_lhs_qtype.getZeroPoint());
+ auto op2_rescale_rhs =
+ buildRescaleToInt32(rewriter, op, tfl_greater_op.rhs(), 1.0f,
+ input_rhs_qtype.getZeroPoint());
+ auto op3_greater_op1_op2 = rewriter.create<tosa::GreaterOp>(
+ op->getLoc(), output_type, op1_rescale_lhs, op2_rescale_rhs);
+
+ output = op3_greater_op1_op2.getResult();
+ } else {
+ auto op1_greater_in = rewriter.create<tosa::GreaterOp>(
+ op->getLoc(), output_type, tfl_greater_op.lhs(), tfl_greater_op.rhs());
+
+ output = op1_greater_in.getResult();
+ }
+
+ rewriter.replaceOp(op, {output});
+ return success();
+}
+
+LogicalResult ConvertTFLGreaterEqualOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_greater_equal_op = cast<TFL::GreaterEqualOp>(op);
+
+ auto input_lhs_type =
+ tfl_greater_equal_op.lhs().getType().dyn_cast<RankedTensorType>();
+ auto input_rhs_type =
+ tfl_greater_equal_op.rhs().getType().dyn_cast<RankedTensorType>();
+ auto output_type =
+ tfl_greater_equal_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!input_lhs_type || !input_rhs_type || !output_type) return failure();
+
+ bool input_lhs_is_qtype =
+ input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+ bool input_rhs_is_qtype =
+ input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+ bool output_is_qtype =
+ output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+
+ if (input_lhs_is_qtype != output_is_qtype ||
+ input_rhs_is_qtype != output_is_qtype) {
+ return op->emitOpError(
+ "ConvertTFLGreaterEqualOp: input/output tensor should "
+ "be all quantized or all floating-point.");
+ }
+
+ Value output;
+ if (output_is_qtype && input_lhs_is_qtype && input_rhs_is_qtype) {
+ auto input_lhs_qtype = input_lhs_type.getElementType()
+ .dyn_cast<mlir::quant::UniformQuantizedType>();
+ auto input_rhs_qtype = input_rhs_type.getElementType()
+ .dyn_cast<mlir::quant::UniformQuantizedType>();
+
+ if (input_lhs_qtype.getScale() != input_rhs_qtype.getScale() ||
+ input_lhs_qtype.getZeroPoint() != input_rhs_qtype.getZeroPoint()) {
+ return op->emitOpError(
+ "ConvertTFLGreaterEqualOp: lhs and rhs scale/zp "
+ "must be the same");
+ }
+
+ auto op1_rescale_lhs =
+ buildRescaleToInt32(rewriter, op, tfl_greater_equal_op.lhs(), 1.0f,
+ input_lhs_qtype.getZeroPoint());
+ auto op2_rescale_rhs =
+ buildRescaleToInt32(rewriter, op, tfl_greater_equal_op.rhs(), 1.0f,
+ input_rhs_qtype.getZeroPoint());
+ auto op3_greater_equal_op1_op2 = rewriter.create<tosa::GreaterEqualOp>(
+ op->getLoc(), output_type, op1_rescale_lhs, op2_rescale_rhs);
+
+ output = op3_greater_equal_op1_op2.getResult();
+ } else {
+ auto op1_greater_equal_in = rewriter.create<tosa::GreaterEqualOp>(
+ op->getLoc(), output_type, tfl_greater_equal_op.lhs(),
+ tfl_greater_equal_op.rhs());
+
+ output = op1_greater_equal_in.getResult();
+ }
+
+ rewriter.replaceOp(op, {output});
+ return success();
+}
+
+// TODO: Use a utility function for common code in elementwise binary ops.
+LogicalResult ConvertTFLAddOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_add_op = cast<TFL::AddOp>(op);
+
+ auto input_lhs_type = tfl_add_op.lhs().getType().dyn_cast<RankedTensorType>();
+ auto input_rhs_type = tfl_add_op.rhs().getType().dyn_cast<RankedTensorType>();
+ auto output_type =
+ tfl_add_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!input_lhs_type || !input_rhs_type || !output_type) return failure();
+
+ bool input_lhs_is_qtype =
+ input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+ bool input_rhs_is_qtype =
+ input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+ bool output_is_qtype =
+ output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+
+ if (input_lhs_is_qtype != output_is_qtype ||
+ input_rhs_is_qtype != output_is_qtype) {
+ return op->emitOpError(
+ "ConvertTFLAddOp: input/output tensor should "
+ "be all quantized or all floating-point.");
+ }
+
+ Value output;
+ if (output_is_qtype && input_lhs_is_qtype && input_rhs_is_qtype) {
+ auto rescale_type =
+ RankedTensorType::get(output_type.getShape(), rewriter.getI32Type());
+ auto input_lhs_qtype = input_lhs_type.getElementType()
+ .dyn_cast<mlir::quant::UniformQuantizedType>();
+ auto input_rhs_qtype = input_rhs_type.getElementType()
+ .dyn_cast<mlir::quant::UniformQuantizedType>();
+ auto output_qtype = output_type.getElementType()
+ .dyn_cast<mlir::quant::UniformQuantizedType>();
+
+ // Follows the quantization scheme described in
+ // tensorflow/lite/kernels/add.cc. In detail, it does:
+ // 1. Rescale inputs to scale = 2.0 x max(lhs.scale, rhs.scale)
+ // 2. Apply an extra left shift to the inputs to increase precision,
+ // where input_shift = 20 if the input is 8-bit and
+ // input_shift = 15 if the input is 16-bit.
+ // TODO: support 16-bit
+ double in_lhs_scale = input_lhs_qtype.getScale();
+ double in_rhs_scale = input_rhs_qtype.getScale();
+ double output_scale = output_qtype.getScale();
+ double max_scale_2x = 2.0 * std::max(in_lhs_scale, in_rhs_scale);
+
+ const int32_t SHIFT_8_BIT = 20;
+ int32_t input_shift = SHIFT_8_BIT;
+
+ double lhs_rescale_scale =
+ static_cast<double>(1 << input_shift) * in_lhs_scale / max_scale_2x;
+ double rhs_rescale_scale =
+ static_cast<double>(1 << input_shift) * in_rhs_scale / max_scale_2x;
+ double output_rescale_scale =
+ max_scale_2x / (output_scale * static_cast<double>(1 << input_shift));
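+ // Worked example (illustrative numbers): lhs.scale = 0.1, rhs.scale = 0.05
+ // give max_scale_2x = 0.2, so with input_shift = 20:
+ // lhs_rescale_scale = 2^20 * 0.1 / 0.2 = 2^19
+ // rhs_rescale_scale = 2^20 * 0.05 / 0.2 = 2^18
+ // and with output.scale = 0.2, output_rescale_scale = 2^-20. The net lhs
+ // gain is 2^19 * 2^-20 = 0.5 = lhs.scale / output.scale, as required.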
+
+ auto op1_rescale_lhs =
+ buildRescaleToInt32(rewriter, op, tfl_add_op.lhs(), lhs_rescale_scale,
+ input_lhs_qtype.getZeroPoint());
+ auto op2_rescale_rhs =
+ buildRescaleToInt32(rewriter, op, tfl_add_op.rhs(), rhs_rescale_scale,
+ input_rhs_qtype.getZeroPoint());
+ auto op3_add_op1_op2 = rewriter.create<tosa::AddOp>(
+ op->getLoc(), rescale_type, op1_rescale_lhs, op2_rescale_rhs);
+ auto op4_rescale_op3 = buildRescaleFromInt32(
+ rewriter, op, output_type, op3_add_op1_op2.getResult(),
+ output_rescale_scale, output_qtype.getZeroPoint());
+ output = op4_rescale_op3;
+ } else {
+ auto op1_add_in = rewriter.create<tosa::AddOp>(
+ op->getLoc(), output_type, tfl_add_op.lhs(), tfl_add_op.rhs());
+
+ output = op1_add_in.getResult();
+ }
+
+ auto fused_activation_fn = tfl_add_op.fused_activation_functionAttr();
+
+ if (fused_activation_fn) {
+ auto fused_activation_op =
+ convertFusedActivation(rewriter, op, output, fused_activation_fn);
+
+ if (fused_activation_op) {
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, fused_activation_op);
+ }
+ }
+
+ rewriter.replaceOp(op, {output});
+ return success();
+}
+
+LogicalResult ConvertTFLSubOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_sub_op = cast<TFL::SubOp>(op);
+
+ auto input_lhs_type = tfl_sub_op.lhs().getType().dyn_cast<RankedTensorType>();
+ auto input_rhs_type = tfl_sub_op.rhs().getType().dyn_cast<RankedTensorType>();
+ auto output_type =
+ tfl_sub_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!input_lhs_type || !input_rhs_type || !output_type) return failure();
+
+ bool input_lhs_is_qtype =
+ input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+ bool input_rhs_is_qtype =
+ input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+ bool output_is_qtype =
+ output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+
+ if (input_lhs_is_qtype != output_is_qtype ||
+ input_rhs_is_qtype != output_is_qtype) {
+ return op->emitOpError(
+ "ConvertTFLSubOp: input/output tensor should "
+ "be all quantized or all floating-point.");
+ }
+
+ Value output;
+ if (output_is_qtype && input_lhs_is_qtype && input_rhs_is_qtype) {
+ auto rescale_type =
+ RankedTensorType::get(output_type.getShape(), rewriter.getI32Type());
+ auto input_lhs_qtype = input_lhs_type.getElementType()
+ .cast<mlir::quant::UniformQuantizedType>();
+ auto input_rhs_qtype = input_rhs_type.getElementType()
+ .cast<mlir::quant::UniformQuantizedType>();
+ auto output_qtype =
+ output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
+
+ // Following quantization described in tensorflow/lite/kernels/add.cc
+ // In detail, it does:
+ // 1. Rescale inputs to scale = 2.0 x max(lhs.scale, rhs.scale)
+ // 2. Extra left shift to input to increase precision
+ // Where input_shift = 20 if input is 8-bit
+ // input_shift = 15 if input is 16-bit
+ // TODO: support 16-bit
+ double in_lhs_scale = input_lhs_qtype.getScale();
+ double in_rhs_scale = input_rhs_qtype.getScale();
+ double output_scale = output_qtype.getScale();
+ double max_scale_2x = 2.0 * std::max(in_lhs_scale, in_rhs_scale);
+
+ const int32_t SHIFT_8_BIT = 20;
+ int32_t input_shift = SHIFT_8_BIT;
+
+ double lhs_rescale_scale =
+ static_cast<double>(1 << input_shift) * in_lhs_scale / max_scale_2x;
+ double rhs_rescale_scale =
+ static_cast<double>(1 << input_shift) * in_rhs_scale / max_scale_2x;
+ double output_rescale_scale =
+ max_scale_2x / (output_scale * static_cast<double>(1 << input_shift));
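+ // The same rescale derivation as in ConvertTFLAddOp applies here.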
+
+ auto op1_rescale_lhs =
+ buildRescaleToInt32(rewriter, op, tfl_sub_op.lhs(), lhs_rescale_scale,
+ input_lhs_qtype.getZeroPoint());
+ auto op2_rescale_rhs =
+ buildRescaleToInt32(rewriter, op, tfl_sub_op.rhs(), rhs_rescale_scale,
+ input_rhs_qtype.getZeroPoint());
+ auto op3_sub_op1_op2 = rewriter.create<tosa::SubOp>(
+ op->getLoc(), rescale_type, op1_rescale_lhs, op2_rescale_rhs);
+ auto op4_rescale_op3 = buildRescaleFromInt32(
+ rewriter, op, output_type, op3_sub_op1_op2.getResult(),
+ output_rescale_scale, output_qtype.getZeroPoint());
+ output = op4_rescale_op3;
+ } else {
+ auto op1_sub_in = rewriter.create<tosa::SubOp>(
+ op->getLoc(), output_type, tfl_sub_op.lhs(), tfl_sub_op.rhs());
+
+ output = op1_sub_in.getResult();
+ }
+
+ auto fused_activation_fn = tfl_sub_op.fused_activation_functionAttr();
+
+ if (fused_activation_fn) {
+ auto fused_activation_op =
+ convertFusedActivation(rewriter, op, output, fused_activation_fn);
+
+ if (fused_activation_op) {
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, fused_activation_op);
+ }
+ }
+
+ rewriter.replaceOp(op, {output});
+ return success();
+}
+
+LogicalResult ConvertTFLMulOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_mul_op = cast<TFL::MulOp>(op);
+
+ auto lowered_op = convertMultiplyOp(rewriter, op, tfl_mul_op.getResult(),
+ tfl_mul_op.lhs(), tfl_mul_op.rhs());
+
+ if (!lowered_op) {
+ return failure();
+ }
+
+ auto fused_activation_fn = tfl_mul_op.fused_activation_functionAttr();
+
+ if (fused_activation_fn) {
+ auto fused_activation_op = convertFusedActivation(
+ rewriter, op, lowered_op->getResult(0), fused_activation_fn);
+
+ if (fused_activation_op) {
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, fused_activation_op);
+ }
+ }
+
+ rewriter.replaceOp(op, {lowered_op->getResult(0)});
+ return success();
+}
+
+LogicalResult ConvertTFLSquareOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_square_op = cast<TFL::SquareOp>(op);
+
+ auto lowered_op = convertMultiplyOp(rewriter, op, tfl_square_op.getResult(),
+ tfl_square_op.x(), tfl_square_op.x());
+
+ if (!lowered_op) {
+ return failure();
+ }
+
+ rewriter.replaceOp(op, {lowered_op->getResult(0)});
+ return success();
+}
+
+LogicalResult ConvertTFLSquaredDifferenceOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_squared_op = cast<TFL::SquaredDifferenceOp>(op);
+
+ auto lowered_op =
+ convertSquaredDifferenceOp(rewriter, op, tfl_squared_op.getResult(),
+ tfl_squared_op.lhs(), tfl_squared_op.rhs());
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFLRoundOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_round_op = cast<TFL::RoundOp>(op);
+
+ auto input_type = tfl_round_op.x().getType().dyn_cast<RankedTensorType>();
+ if (!input_type) {
+ return op->emitOpError("Round: input not ranked tensor type");
+ }
+
+ if (input_type.getElementType().isa<FloatType>()) {
+ auto lowered_op = convertRoundOp(rewriter, op, tfl_round_op.getResult(),
+ tfl_round_op.x());
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+
+ } else {
+ // Rounding an integer tensor is a no-op; replace the op with its input
+ // through the rewriter so the pattern driver is notified of the change.
+ rewriter.replaceOp(op, {tfl_round_op.x()});
+ return success();
+ }
+}
+
+LogicalResult ConvertTFLDivOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_div_op = cast<TFL::DivOp>(op);
+
+ auto output_type =
+ tfl_div_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ auto fused_activation_fn = tfl_div_op.fused_activation_functionAttr();
+
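+ // Lower tfl.div to lhs * reciprocal(rhs); the trailing 0 on the mul is
+ // tosa.mul's shift attribute (no fixed-point shift here).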
+ auto reciprocal_op = rewriter.create<tosa::ReciprocalOp>(
+ op->getLoc(), output_type, tfl_div_op.rhs());
+ auto mul_op =
+ rewriter.create<tosa::MulOp>(op->getLoc(), output_type, tfl_div_op.lhs(),
+ reciprocal_op.getResult(), 0);
+
+ if (fused_activation_fn) {
+ auto fused_activation_op = convertFusedActivation(
+ rewriter, op, mul_op.getResult(), fused_activation_fn);
+
+ if (fused_activation_op) {
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, fused_activation_op);
+ }
+ }
+
+ rewriter.replaceOp(op, {mul_op.getResult()});
+
+ return success();
+}
+
+LogicalResult ConvertTFLMaximumOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_max_op = cast<TFL::MaximumOp>(op);
+
+ auto input_lhs_type = tfl_max_op.lhs().getType().dyn_cast<RankedTensorType>();
+ auto input_rhs_type = tfl_max_op.rhs().getType().dyn_cast<RankedTensorType>();
+ auto output_type =
+ tfl_max_op.getResult().getType().dyn_cast<RankedTensorType>();
+
+ // Not a ranked tensor output
+ if (!input_lhs_type || !input_rhs_type || !output_type) return failure();
+
+ bool input_lhs_is_qtype =
+ input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+ bool input_rhs_is_qtype =
+ input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+ bool output_is_qtype =
+ output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+
+ if (input_lhs_is_qtype != output_is_qtype ||
+ input_rhs_is_qtype != output_is_qtype) {
+ return op->emitOpError(
+ "ConvertTFLMaximumOp: input/output tensor should "
+ "be all quantized or all floating-point.");
+ }
+
+ Value output;
+ if (output_is_qtype) {
+ auto rescale_type =
+ RankedTensorType::get(output_type.getShape(), rewriter.getI32Type());
+
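+ // Compute the max in int32: rescale both inputs up, take the elementwise
+ // maximum, then rescale the result back down to the quantized output type.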
+ auto op1_rescale_lhs =
+ buildRescaleToInt32(rewriter, op, tfl_max_op.lhs(), 1.0f, 0);
+ auto op2_rescale_rhs =
+ buildRescaleToInt32(rewriter, op, tfl_max_op.rhs(), 1.0f, 0);
+ auto op3_max_op1_op2 = rewriter.create<tosa::MaximumOp>(
+ op->getLoc(), rescale_type, op1_rescale_lhs, op2_rescale_rhs);
+ auto op4_rescale_op3 = buildRescaleFromInt32(
+ rewriter, op, output_type, op3_max_op1_op2.getResult(), 1.0f, 0);
+
+ output = op4_rescale_op3;
+ } else {
+ auto op1_max_in = rewriter.create<tosa::MaximumOp>(
+ op->getLoc(), output_type, tfl_max_op.lhs(), tfl_max_op.rhs());
+
+ output = op1_max_in.getResult();
+ }
+
+ rewriter.replaceOp(op, {output});
+
+ return success();
+}
+
+LogicalResult ConvertTFLMinimumOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_min_op = cast<TFL::MinimumOp>(op);
+
+ auto input_lhs_type = tfl_min_op.lhs().getType().dyn_cast<RankedTensorType>();
+ auto input_rhs_type = tfl_min_op.rhs().getType().dyn_cast<RankedTensorType>();
+ auto output_type =
+ tfl_min_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!input_lhs_type || !input_rhs_type || !output_type) return failure();
+
+ bool input_lhs_is_qtype =
+ input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+ bool input_rhs_is_qtype =
+ input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+ bool output_is_qtype =
+ output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+
+ if (input_lhs_is_qtype != output_is_qtype ||
+ input_rhs_is_qtype != output_is_qtype) {
+ return op->emitOpError(
+ "ConvertTFLMinimumOp: input/output tensor should "
+ "be all quantized or all floating-point.");
+ }
+
+ Value output;
+ if (output_is_qtype) {
+ auto rescale_type =
+ RankedTensorType::get(output_type.getShape(), rewriter.getI32Type());
+
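+ // Same int32 rescale structure as ConvertTFLMaximumOp above.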
+ auto op1_rescale_lhs =
+ buildRescaleToInt32(rewriter, op, tfl_min_op.lhs(), 1.0f, 0);
+ auto op2_rescale_rhs =
+ buildRescaleToInt32(rewriter, op, tfl_min_op.rhs(), 1.0f, 0);
+ auto op3_min_op1_op2 = rewriter.create<tosa::MinimumOp>(
+ op->getLoc(), rescale_type, op1_rescale_lhs, op2_rescale_rhs);
+ auto op4_rescale_op3 = buildRescaleFromInt32(
+ rewriter, op, output_type, op3_min_op1_op2.getResult(), 1.0f, 0);
+
+ output = op4_rescale_op3;
+ } else {
+ auto op1_min_in = rewriter.create<tosa::MinimumOp>(
+ op->getLoc(), output_type, tfl_min_op.lhs(), tfl_min_op.rhs());
+
+ output = op1_min_in.getResult();
+ }
+
+ rewriter.replaceOp(op, {output});
+
+ return success();
+}
+
+LogicalResult ConvertTFLFloorDivOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_floordiv_op = cast<TFL::FloorDivOp>(op);
+
+ auto lowered_op =
+ convertFloorDivOp(rewriter, op, tfl_floordiv_op.getResult(),
+ tfl_floordiv_op.lhs(), tfl_floordiv_op.rhs());
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFLFloorModOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_floormod_op = cast<TFL::FloorModOp>(op);
+
+ auto lowered_op =
+ convertFloorModOp(rewriter, op, tfl_floormod_op.getResult(),
+ tfl_floormod_op.lhs(), tfl_floormod_op.rhs());
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFLAddNOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_addn_op = cast<TFL::AddNOp>(op);
+
+ auto output_type =
+ tfl_addn_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ SmallVector<Value, 4> inputs(tfl_addn_op.inputs());
+
+ assert(inputs.size() >= 2);
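+ // Lower add_n to a chain of binary adds:
+ // add_n(a, b, c, d) => add(d, add(c, add(a, b))).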
+
+ auto newOp = rewriter.create<tosa::AddOp>(op->getLoc(), output_type,
+ inputs[0], inputs[1]);
+ for (int i = 2; i < inputs.size(); i++) {
+ newOp = rewriter.create<tosa::AddOp>(op->getLoc(), output_type, inputs[i],
+ newOp.getResult());
+ }
+
+ rewriter.replaceOp(op, {newOp.getResult()});
+
+ return success();
+}
+
+LogicalResult ConvertTFLAveragePool2DOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_avgpool_op = cast<TFL::AveragePool2DOp>(op);
+
+ auto input_type =
+ tfl_avgpool_op.input().getType().dyn_cast<RankedTensorType>();
+ auto output_type =
+ tfl_avgpool_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ // Kernel and stride are extracted as {height, width} pairs; i64array holds
+ // a placeholder NHWC filter shape used only for the padding computation.
+ SmallVector<int64_t, 4> i64array({1, 1, 1, 1});
+ ArrayAttr kernel_size;
+ ArrayAttr stride;
+ ArrayAttr pad;
+ {
+ int64_t kernel_h = tfl_avgpool_op.filter_height();
+ int64_t kernel_w = tfl_avgpool_op.filter_width();
+ kernel_size = rewriter.getI64ArrayAttr({kernel_h, kernel_w});
+ // i64array is formatted as NHWC now
+ i64array[1] = kernel_h;
+ i64array[2] = kernel_w;
+ }
+ {
+ int64_t stride_h = tfl_avgpool_op.stride_h();
+ int64_t stride_w = tfl_avgpool_op.stride_w();
+ stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
+ }
+ {
+ tensorflow::Padding tf_pad;
+ if (!GetPaddingFromString(tfl_avgpool_op.padding().str(), &tf_pad).ok())
+ return failure();
+
+ // Pooling has no non-unit dilation
+ ArrayAttr dilation = rewriter.getI64ArrayAttr({1, 1});
+
+ auto filter_type = RankedTensorType::get(
+ llvm::makeArrayRef<int64_t>(i64array), rewriter.getIntegerType(64));
+
+ // TFLite doesn't support explicit padding
+ if (!getPaddingValuesFromPadType(
+ tf_pad,
+ tensorflow::FORMAT_NHWC, // TFLite only supports this
+ 1, // tensorflow::FORMAT_OHWI,
+ input_type, filter_type, stride, dilation, rewriter, pad))
+ return failure();
+ }
+
+ rewriter.replaceOpWithNewOp<tosa::AvgPool2dOp>(
+ op, output_type, tfl_avgpool_op.input(), kernel_size, stride, pad);
+ return success();
+}
+
+LogicalResult ConvertTFLMaxPool2DOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_maxpool_op = cast<TFL::MaxPool2DOp>(op);
+
+ auto input_type =
+ tfl_maxpool_op.input().getType().dyn_cast<RankedTensorType>();
+ auto output_type =
+ tfl_maxpool_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ // Kernel and stride are extracted as {height, width} pairs; i64array holds
+ // a placeholder NHWC filter shape used only for the padding computation.
+ SmallVector<int64_t, 4> i64array({1, 1, 1, 1});
+ ArrayAttr kernel_size;
+ ArrayAttr stride;
+ ArrayAttr pad;
+ {
+ int64_t kernel_h = tfl_maxpool_op.filter_height();
+ int64_t kernel_w = tfl_maxpool_op.filter_width();
+ kernel_size = rewriter.getI64ArrayAttr({kernel_h, kernel_w});
+ // i64array is formatted as NHWC now
+ i64array[1] = kernel_h;
+ i64array[2] = kernel_w;
+ }
+ {
+ int64_t stride_h = tfl_maxpool_op.stride_h();
+ int64_t stride_w = tfl_maxpool_op.stride_w();
+ stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
+ }
+ {
+ tensorflow::Padding tf_pad;
+ if (!GetPaddingFromString(tfl_maxpool_op.padding().str(), &tf_pad).ok())
+ return failure();
+
+ // Pooling has no non-unit dilation
+ ArrayAttr dilation = rewriter.getI64ArrayAttr({1, 1});
+
+ auto filter_type = RankedTensorType::get(
+ llvm::makeArrayRef<int64_t>(i64array), rewriter.getIntegerType(64));
+
+ // TFLite doesn't support explicit padding
+ if (!getPaddingValuesFromPadType(
+ tf_pad,
+ tensorflow::FORMAT_NHWC, // TFLite only supports this
+ 1, // tensorflow::FORMAT_OHWI,
+ input_type, filter_type, stride, dilation, rewriter, pad))
+ return failure();
+ }
+
+ rewriter.replaceOpWithNewOp<tosa::MaxPool2dOp>(
+ op, output_type, tfl_maxpool_op.input(), kernel_size, stride, pad);
+ return success();
+}
+
+LogicalResult ConvertTFLConv2DOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_conv2d_op = cast<TFL::Conv2DOp>(op);
+
+ auto input_type =
+ tfl_conv2d_op.input().getType().dyn_cast<RankedTensorType>();
+ auto filter_type =
+ tfl_conv2d_op.filter().getType().dyn_cast<RankedTensorType>();
+ auto output_type =
+ tfl_conv2d_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!input_type) return failure();
+ if (!output_type) return failure();
+ if (!filter_type) return failure();
+
+ bool input_is_qtype =
+ input_type.getElementType().isa<mlir::quant::QuantizedType>();
+ bool filter_is_qtype =
+ filter_type.getElementType().isa<mlir::quant::QuantizedType>();
+ bool output_is_qtype =
+ output_type.getElementType().isa<mlir::quant::QuantizedType>();
+
+ if ((input_is_qtype != filter_is_qtype) ||
+ (input_is_qtype != output_is_qtype)) {
+ return op->emitOpError(
+ "ConvertTFLConv2DOp: input/filter/output tensor should "
+ "be all quantized or all floating-point.");
+ }
+
+ ArrayAttr pad;
+ ArrayAttr stride;
+ ArrayAttr dilation;
+ {
+ int64_t stride_h = tfl_conv2d_op.stride_h();
+ int64_t stride_w = tfl_conv2d_op.stride_w();
+ stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
+ }
+ {
+ int64_t dilation_h = tfl_conv2d_op.dilation_h_factor();
+ int64_t dilation_w = tfl_conv2d_op.dilation_w_factor();
+ dilation = rewriter.getI64ArrayAttr({dilation_h, dilation_w});
+ }
+ {
+ tensorflow::Padding tf_pad;
+ if (!GetPaddingFromString(tfl_conv2d_op.padding().str(), &tf_pad).ok())
+ return failure();
+
+ // TFLite doesn't support explicit padding
+ if (!getPaddingValuesFromPadType(
+ tf_pad,
+ tensorflow::FORMAT_NHWC, // TFLite only supports this
+ 1, // tensorflow::FORMAT_OHWI,
+ input_type, filter_type, stride, dilation, rewriter, pad))
+ return failure();
+ }
+
+ Value unquantized_bias =
+ getUnquantizedBias(rewriter, op, tfl_conv2d_op.bias());
+
+ auto a1_conv2d_op = rewriter.create<tosa::Conv2DOp>(
+ op->getLoc(), output_type, tfl_conv2d_op.input(), tfl_conv2d_op.filter(),
+ unquantized_bias, pad, stride, dilation);
+
+ Value conv2d_output;
+ if (input_is_qtype) {
+ conv2d_output =
+ buildRescaleOpConvOutput(rewriter, op, a1_conv2d_op.getResult(),
+ input_type, filter_type, output_type);
+ } else {
+ conv2d_output = a1_conv2d_op.getResult();
+ }
+
+ auto fused_activation_fn = tfl_conv2d_op.fused_activation_functionAttr();
+
+ if (fused_activation_fn) {
+ auto fused_activation_op = convertFusedActivation(
+ rewriter, op, conv2d_output, fused_activation_fn);
+
+ if (fused_activation_op) {
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, fused_activation_op);
+ }
+ }
+
+ rewriter.replaceOp(op, {conv2d_output});
+
+ return success();
+}
+
+LogicalResult ConvertTFLTransposeConvOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_conv_op = cast<TFL::TransposeConvOp>(op);
+
+ auto input_type = tfl_conv_op.input().getType().dyn_cast<RankedTensorType>();
+ auto filter_type =
+ tfl_conv_op.weights().getType().dyn_cast<RankedTensorType>();
+ auto output_type =
+ tfl_conv_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!input_type) return failure();
+ if (!output_type) return failure();
+ if (!filter_type) return failure();
+
+ bool input_is_qtype =
+ input_type.getElementType().isa<mlir::quant::QuantizedType>();
+ bool filter_is_qtype =
+ filter_type.getElementType().isa<mlir::quant::QuantizedType>();
+ bool output_is_qtype =
+ output_type.getElementType().isa<mlir::quant::QuantizedType>();
+
+ if ((input_is_qtype != filter_is_qtype) ||
+ (input_is_qtype != output_is_qtype)) {
+ return op->emitOpError(
+ "ConvertTFLConv2DOp: input/filter/output tensor should "
+ "be all quantized or all floating-point.");
+ }
+
+ ArrayAttr stride;
+ ArrayAttr dilation;
+ ArrayAttr outpad;
+ ArrayAttr output_shape;
+ {
+ int64_t stride_h = tfl_conv_op.stride_h();
+ int64_t stride_w = tfl_conv_op.stride_w();
+ stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
+ }
+
+ // tfl.transpose_conv doesn't support dilations
+ dilation = rewriter.getI64ArrayAttr({1, 1});
+
+ {
+ tensorflow::Padding tf_pad;
+ if (!GetPaddingFromString(tfl_conv_op.padding().str(), &tf_pad).ok())
+ return failure();
+
+ if (!getTransposeConv2dPaddingValues(
+ tf_pad,
+ tensorflow::FORMAT_NHWC, // TFLite only supports this
+ 1, // tensorflow::FORMAT_OHWI,
+ input_type, filter_type, output_type, stride, dilation, rewriter,
+ outpad))
+ return failure();
+ }
+ {
+ ElementsAttr output_shape_elems;
+ // Prefer the constant output_shape operand when available
+ if (matchPattern(tfl_conv_op.output_shape(),
+ m_Constant(&output_shape_elems))) {
+ llvm::SmallVector<int64_t, 4> shape_vec;
+ for (int i = 0; i < output_shape_elems.getNumElements(); i++)
+ shape_vec.push_back(
+ output_shape_elems.getValue<IntegerAttr>(i).getInt());
+ output_shape = rewriter.getI64ArrayAttr(shape_vec);
+ } else {
+ // Use output tensor's shape otherwise
+ output_shape = rewriter.getI64ArrayAttr(output_type.getShape());
+ }
+ }
+
+ Value zero_bias;
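+ // tosa.transpose_conv2d requires a bias operand; synthesize a zero bias
+ // whose element type matches the expected accumulator (int48 for 16-bit
+ // input with 8-bit weights, int32 for 8-bit, float otherwise).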
+ if (input_is_qtype) {
+ uint32_t input_bits = input_type.getElementType()
+ .dyn_cast<mlir::quant::QuantizedType>()
+ .getStorageTypeIntegralWidth();
+ uint32_t weight_bits = filter_type.getElementType()
+ .dyn_cast<mlir::quant::QuantizedType>()
+ .getStorageTypeIntegralWidth();
+
+ if (input_bits == 16 && weight_bits == 8) {
+ SmallVector<int64_t, 8> zero_bias_vec(output_type.getShape()[3], 0);
+ zero_bias = get1DConstTensorInt48(rewriter, op, zero_bias_vec);
+ } else {
+ SmallVector<int32_t, 8> zero_bias_vec(output_type.getShape()[3], 0);
+ zero_bias =
+ get1DConstTensor<tosa::ConstOp, int32_t>(rewriter, op, zero_bias_vec);
+ }
+ } else {
+ SmallVector<float, 8> zero_bias_vec(output_type.getShape()[3], 0.0f);
+ zero_bias =
+ get1DConstTensor<tosa::ConstOp, float>(rewriter, op, zero_bias_vec);
+ }
+
+ auto a1_conv2d_op = rewriter.create<tosa::TransposeConv2DOp>(
+ op->getLoc(), output_type, tfl_conv_op.input(), tfl_conv_op.weights(),
+ zero_bias, outpad, stride, dilation, output_shape);
+
+ Value conv2d_output;
+ if (input_is_qtype) {
+ conv2d_output =
+ buildRescaleOpConvOutput(rewriter, op, a1_conv2d_op.getResult(),
+ input_type, filter_type, output_type);
+ } else {
+ conv2d_output = a1_conv2d_op.getResult();
+ }
+
+ rewriter.replaceOp(op, {conv2d_output});
+
+ return success();
+}
+
+LogicalResult ConvertTFLDepthwiseConv2DOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_conv2d_op = cast<TFL::DepthwiseConv2DOp>(op);
+
+ auto input_type =
+ tfl_conv2d_op.input().getType().dyn_cast<RankedTensorType>();
+ auto filter_type =
+ tfl_conv2d_op.filter().getType().dyn_cast<RankedTensorType>();
+ auto output_type =
+ tfl_conv2d_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!input_type) return failure();
+ if (!output_type) return failure();
+ if (!filter_type) return failure();
+
+ bool input_is_qtype =
+ input_type.getElementType().isa<mlir::quant::QuantizedType>();
+ bool filter_is_qtype =
+ filter_type.getElementType().isa<mlir::quant::QuantizedType>();
+ bool output_is_qtype =
+ output_type.getElementType().isa<mlir::quant::QuantizedType>();
+
+ if ((input_is_qtype != filter_is_qtype) ||
+ (input_is_qtype != output_is_qtype)) {
+ return op->emitOpError(
+ "ConvertTFLConv2DOp: input/filter/output tensor should "
+ "be all quantized or all floating-point.");
+ }
+
+ auto filter_shape = filter_type.getShape();
+ // Operator depthwise_conv2d
+ // TFLite orders the depthwise_conv2d filter as IHWO, while TOSA expects
+ // HWCM, where the output channels O are split into (input channels C,
+ // depth multiplier M).
+ //
+ // The lowering reorders the filter accordingly:
+ //
+ // a1_transpose = tosa.transpose(filter, {1, 2, 3, 0}) // HWOI
+ // a2_reshape = tosa.reshape(a1_transpose,
+ // {H, W, O / depth_multiplier, depth_multiplier})
+ // a3_depthwise_conv2d = tosa.depthwise_conv2d(input, a2_reshape, bias,
+ // padding, stride, dilation)
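+ // For example, an 8-channel input with depth_multiplier 2 and a 3x3 kernel
+ // has a TFLite filter of shape [1, 3, 3, 16]; it becomes [3, 3, 16, 1]
+ // after the transpose and [3, 3, 8, 2] after the reshape.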
+
+ ArrayAttr pad;
+ ArrayAttr stride;
+ ArrayAttr dilation;
+ auto depth_multiplier = tfl_conv2d_op.depth_multiplierAttr();
+
+ {
+ int64_t stride_h = tfl_conv2d_op.stride_h();
+ int64_t stride_w = tfl_conv2d_op.stride_w();
+ stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
+ }
+ {
+ int64_t dilation_h = tfl_conv2d_op.dilation_h_factor();
+ int64_t dilation_w = tfl_conv2d_op.dilation_w_factor();
+ dilation = rewriter.getI64ArrayAttr({dilation_h, dilation_w});
+ }
+ {
+ tensorflow::Padding tf_pad;
+ if (!GetPaddingFromString(tfl_conv2d_op.padding().str(), &tf_pad).ok())
+ return failure();
+
+ if (!getPaddingValuesFromPadType(
+ tf_pad,
+ tensorflow::FORMAT_NHWC, // TFLite only supports this
+ 1, // tensorflow::FORMAT_OHWI,
+ input_type, filter_type, stride, dilation, rewriter, pad))
+ return failure();
+ }
+
+ llvm::SmallVector<int64_t, 4> a1_transpose_dims;
+ a1_transpose_dims.push_back(filter_shape[1]);
+ a1_transpose_dims.push_back(filter_shape[2]);
+ a1_transpose_dims.push_back(filter_shape[3]);
+ a1_transpose_dims.push_back(filter_shape[0]);
+
+ llvm::SmallVector<int64_t, 4> a2_reshape_dims;
+ a2_reshape_dims.push_back(a1_transpose_dims[0]);
+ a2_reshape_dims.push_back(a1_transpose_dims[1]);
+ a2_reshape_dims.push_back(a1_transpose_dims[2] / depth_multiplier.getInt());
+ a2_reshape_dims.push_back(depth_multiplier.getInt());
+
+ auto a1_filter_transpose_perms =
+ get1DConstTensor<tosa::ConstOp, int32_t>(rewriter, op, {1, 2, 3, 0});
+ auto a1_filter_transpose_op = rewriter.create<tosa::TransposeOp>(
+ op->getLoc(),
+ RankedTensorType::get(ArrayRef<int64_t>(a1_transpose_dims),
+ filter_type.getElementType()),
+ tfl_conv2d_op.filter(), a1_filter_transpose_perms);
+
+ auto a2_filter_reshape_op = rewriter.create<tosa::ReshapeOp>(
+ op->getLoc(),
+ RankedTensorType::get(ArrayRef<int64_t>(a2_reshape_dims),
+ filter_type.getElementType()),
+ a1_filter_transpose_op.getResult(),
+ rewriter.getI64ArrayAttr(a2_reshape_dims));
+
+ Value unquantized_bias =
+ getUnquantizedBias(rewriter, op, tfl_conv2d_op.bias());
+
+ auto a3_depthwise_conv2d_op = rewriter.create<tosa::DepthwiseConv2DOp>(
+ op->getLoc(), output_type, tfl_conv2d_op.input(),
+ a2_filter_reshape_op.getResult(), unquantized_bias, pad, stride,
+ dilation);
+
+ Value conv2d_output;
+ if (input_is_qtype) {
+ conv2d_output = buildRescaleOpConvOutput(
+ rewriter, op, a3_depthwise_conv2d_op.getResult(), input_type,
+ filter_type, output_type);
+ } else {
+ conv2d_output = a3_depthwise_conv2d_op.getResult();
+ }
+
+ auto fused_activation_fn = tfl_conv2d_op.fused_activation_functionAttr();
+
+ if (fused_activation_fn) {
+ auto fused_activation_op = convertFusedActivation(
+ rewriter, op, conv2d_output, fused_activation_fn);
+
+ if (fused_activation_op) {
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, fused_activation_op);
+ }
+ }
+
+ rewriter.replaceOp(op, {conv2d_output});
+
+ return success();
+}
+
+LogicalResult ConvertTFLFullyConnectedOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_fc_op = cast<TFL::FullyConnectedOp>(op);
+
+ auto output_type =
+ tfl_fc_op.getResult(0).getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ auto input_type = tfl_fc_op.input().getType().dyn_cast<RankedTensorType>();
+ auto filter_type = tfl_fc_op.filter().getType().dyn_cast<RankedTensorType>();
+ auto bias_type = tfl_fc_op.bias().getType().dyn_cast<RankedTensorType>();
+ if (!input_type || !filter_type) return failure();
+
+ bool input_is_qtype =
+ input_type.getElementType().isa<mlir::quant::QuantizedType>();
+ bool filter_is_qtype =
+ filter_type.getElementType().isa<mlir::quant::QuantizedType>();
+ bool output_is_qtype =
+ output_type.getElementType().isa<mlir::quant::QuantizedType>();
+
+ if ((input_is_qtype != filter_is_qtype) ||
+ (input_is_qtype != output_is_qtype)) {
+ return op->emitOpError(
+ "ConvertTFLFullyConnectedOp: input/filter/output tensor should "
+ "be all quantized or all floating-point.");
+ }
+
+ Value input_val = tfl_fc_op.input();
+
+ // tfl.fully_connected() accepts input tensors of any rank, but
+ // tosa.fully_connected only supports rank 2, so the input may need to be
+ // reshaped. A rank-4 input is not always reshaped to
+ // (dim[0] * dim[1], dim[2] * dim[3]); some networks reshape it to
+ // (dim[0], dim[1] * dim[2] * dim[3]). The general rule is to take the inner
+ // dimension from the filter's shape[1] and fold the remaining elements into
+ // the batch dimension.
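+ // For example, a [1, 4, 4, 8] input (128 elements) with a [16, 128] filter
+ // reshapes to [1, 128] rather than [4, 32].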
+ if (input_type.getRank() != 2) {
+ int64_t num_elems = filter_type.getShape()[1];
+ int64_t num_batch = input_type.getNumElements() / num_elems;
+ SmallVector<int64_t, 2> shape_vals({num_batch, num_elems});
+
+ auto reshape_type = RankedTensorType::get(ArrayRef<int64_t>(shape_vals),
+ input_type.getElementType());
+ auto reshape_op = rewriter.create<tosa::ReshapeOp>(
+ op->getLoc(), reshape_type, tfl_fc_op.input(),
+ rewriter.getI64ArrayAttr(shape_vals));
+
+ input_val = reshape_op.getResult();
+ }
+
+ Value bias_val;
+ if (!bias_type) {
+ // For some matmuls, the bias may actually be a NoneType value with no
+ // contents. TOSA requires the bias to be an array of output_channel_count
+ // values, so create a constant with the appropriate number of zeros.
+ SmallVector<int64_t, 1> bias_shape({filter_type.getShape()[0]});
+ auto bias_type = RankedTensorType::get(ArrayRef<int64_t>(bias_shape),
+ input_type.getElementType());
+
+ DenseElementsAttr bias_attr;
+ if (input_type.getElementType().isa<FloatType>()) {
+ SmallVector<float, 2> bias_arr(bias_shape[0], 0.0f);
+ // TODO: an implicit cast is suggested instead of makeArrayRef, but it
+ // triggers a build error.
+ bias_attr = DenseElementsAttr::get(bias_type,
+ llvm::makeArrayRef<float>(bias_arr));
+ } else {
+ SmallVector<int32_t, 2> bias_arr(bias_shape[0], 0);
+ bias_attr = DenseElementsAttr::get(bias_type,
+ llvm::makeArrayRef<int32_t>(bias_arr));
+ }
+ auto bias_op =
+ rewriter.create<tosa::ConstOp>(op->getLoc(), bias_type, bias_attr);
+ bias_val = bias_op.getResult();
+ } else {
+ bias_val = getUnquantizedBias(rewriter, op, tfl_fc_op.bias());
+ }
+
+ auto fc_op = rewriter.create<tosa::FullyConnectedOp>(
+ op->getLoc(), output_type, input_val, tfl_fc_op.filter(), bias_val);
+
+ Value fc_output;
+ if (input_is_qtype) {
+ fc_output = buildRescaleOpConvOutput(rewriter, op, fc_op.getResult(),
+ input_type, filter_type, output_type);
+ } else {
+ fc_output = fc_op.getResult();
+ }
+
+ auto fused_activation_fn = tfl_fc_op.fused_activation_functionAttr();
+
+ if (fused_activation_fn) {
+ auto fused_activation_op =
+ convertFusedActivation(rewriter, op, fc_output, fused_activation_fn);
+
+ if (fused_activation_op) {
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, fused_activation_op);
+ }
+ }
+
+ rewriter.replaceOp(op, {fc_output});
+
+ return success();
+}
+
+LogicalResult ConvertTFLConcatenationOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_concat_op = cast<TFL::ConcatenationOp>(op);
+
+ SmallVector<Value, 8> values(tfl_concat_op.values());
+
+ IntegerAttr axis_attr;
+ {
+ auto tmpAttr = tfl_concat_op.axisAttr();
+ if (!tmpAttr) {
+ tmpAttr = rewriter.getI64IntegerAttr(0);
+ }
+ axis_attr = tmpAttr;
+ }
+ int32_t axis = axis_attr.getInt();
+
+ auto lowered_op =
+ convertConcatV2Op(rewriter, op, tfl_concat_op.getResult(), values, axis);
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFLReshapeOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_reshape_op = cast<TFL::ReshapeOp>(op);
+
+ auto output_type =
+ tfl_reshape_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ SmallVector<int64_t, 8> shape_vals;
+ for (int i = 0; i < output_type.getShape().size(); i++) {
+ shape_vals.push_back(output_type.getShape()[i]);
+ }
+ ArrayAttr shape_attr = rewriter.getI64ArrayAttr(shape_vals);
+
+ rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
+ op, output_type, tfl_reshape_op.input(), shape_attr);
+ return success();
+}
+
+LogicalResult ConvertTFLRankOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_rank_op = cast<TFL::RankOp>(op);
+
+ auto input_type = tfl_rank_op.input().getType().dyn_cast<RankedTensorType>();
+ if (!input_type) return failure();
+
+ int32_t rank = input_type.getRank();
+
+ auto rank_type = RankedTensorType::get({1}, rewriter.getIntegerType(32));
+ auto rank_attr = DenseElementsAttr::get(rank_type, {rank});
+ auto rank_const =
+ rewriter.create<tosa::ConstOp>(op->getLoc(), rank_type, rank_attr);
+
+ rewriter.replaceOp(op, {rank_const.getResult()});
+
+ return success();
+}
+
+LogicalResult ConvertTFLShapeOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_shape_op = cast<TFL::ShapeOp>(op);
+
+ auto output_type =
+ tfl_shape_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ auto input_type = tfl_shape_op.input().getType().dyn_cast<RankedTensorType>();
+ if (!input_type) return failure();
+
+ auto input_shape = input_type.getShape();
+
+ SmallVector<int32_t, 8> shape_arr;
+ for (int i = 0; i < input_shape.size(); i++) {
+ shape_arr.emplace_back(input_shape[i]);
+ }
+
+ auto shape_type = RankedTensorType::get(
+ {static_cast<int32_t>(shape_arr.size())}, rewriter.getIntegerType(32));
+ auto shape_attr = DenseElementsAttr::get(
+ shape_type, llvm::makeArrayRef<int32_t>(shape_arr));
+ auto shape_const =
+ rewriter.create<tosa::ConstOp>(op->getLoc(), shape_type, shape_attr);
+
+ rewriter.replaceOp(op, {shape_const.getResult()});
+
+ return success();
+}
+
+LogicalResult ConvertTFLExpandDimsOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_expanddims_op = cast<TFL::ExpandDimsOp>(op);
+
+ auto lowered_op =
+ convertExpandDimsOp(rewriter, op, tfl_expanddims_op.getResult(),
+ tfl_expanddims_op.input(), tfl_expanddims_op.dim());
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFLSqueezeOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_squeeze_op = cast<TFL::SqueezeOp>(op);
+
+ // Copy squeeze_dims into int32_t array
+ auto squeeze_dims_attr = tfl_squeeze_op.squeeze_dimsAttr();
+ SmallVector<int32_t, 8> squeeze_dims;
+ for (auto& squeeze_dim : squeeze_dims_attr) {
+ squeeze_dims.emplace_back(squeeze_dim.dyn_cast<IntegerAttr>().getInt());
+ }
+
+ auto lowered_op = convertSqueezeOp(rewriter, op, tfl_squeeze_op.getResult(),
+ tfl_squeeze_op.input(), squeeze_dims);
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFLFillOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_fill_op = cast<TFL::FillOp>(op);
+
+ auto output_type =
+ tfl_fill_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ ElementsAttr dims_elems;
+ if (!matchPattern(tfl_fill_op.dims(), m_Constant(&dims_elems)))
+ return failure();
+ SmallVector<int64_t, 4> dims_vals;
+ uint32_t total_size = 1;
+ for (int i = 0; i < dims_elems.getNumElements(); i++) {
+ dims_vals.push_back(dims_elems.getValue<IntegerAttr>(i).getInt());
+ total_size *= dims_vals[i];
+ }
+
+ ElementsAttr value_elem;
+ if (!matchPattern(tfl_fill_op.input(), m_Constant(&value_elem)))
+ return failure();
+
+ auto fill_type = RankedTensorType::get(ArrayRef<int64_t>(dims_vals),
+ value_elem.getType().getElementType());
+ DenseElementsAttr fill_attr;
+
+ // Splat the constant fill value into a dense attribute of a compatible
+ // element type.
+ if (value_elem.getType().getElementType().isa<FloatType>()) {
+ llvm::SmallVector<float, 4> fill_arr(
+ total_size,
+ value_elem.getValue<FloatAttr>(0).getValue().convertToFloat());
+ fill_attr =
+ DenseElementsAttr::get(fill_type, llvm::makeArrayRef<float>(fill_arr));
+ } else {
+ llvm::SmallVector<int32_t, 4> fill_arr(
+ total_size,
+ value_elem.getValue<IntegerAttr>(0).getValue().getLimitedValue());
+ fill_attr = DenseElementsAttr::get(fill_type,
+ llvm::makeArrayRef<int32_t>(fill_arr));
+ }
+ auto fill_const_op =
+ rewriter.create<tosa::ConstOp>(op->getLoc(), fill_type, fill_attr);
+ rewriter.replaceOp(op, {fill_const_op.getResult()});
+
+ return success();
+}
+
+LogicalResult ConvertTFLReduceAnyOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_any_op = cast<TFL::ReduceAnyOp>(op);
+
+ auto output_type =
+ tfl_any_op.getResult().getType().dyn_cast<RankedTensorType>();
+ if (!output_type) return failure();
+
+ ElementsAttr axes_elems;
+ if (!matchPattern(tfl_any_op.reduction_indices(), m_Constant(&axes_elems)))
+ return failure();
+
+ bool keep_dims = false;
+ auto keep_dims_attr = tfl_any_op.keep_dimsAttr();
+ if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
+
+ auto lowered_op = convertReduceAnyOp(
+ rewriter, op, output_type, tfl_any_op.input(), axes_elems, keep_dims);
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFLReduceMaxOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_max_op = cast<TFL::ReduceMaxOp>(op);
+
+ auto output_type =
+ tfl_max_op.getResult().getType().dyn_cast<RankedTensorType>();
+ if (!output_type) return failure();
+
+ ElementsAttr axes_elems;
+ if (!matchPattern(tfl_max_op.axes(), m_Constant(&axes_elems)))
+ return failure();
+
+ bool keep_dims = false;
+ auto keep_dims_attr = tfl_max_op.keep_dimsAttr();
+ if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
+
+ auto lowered_op = convertReduceMaxOp(
+ rewriter, op, output_type, tfl_max_op.input(), axes_elems, keep_dims);
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFLReduceMinOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_min_op = cast<TFL::ReduceMinOp>(op);
+
+ auto output_type =
+ tfl_min_op.getResult().getType().dyn_cast<RankedTensorType>();
+ if (!output_type) return failure();
+
+ ElementsAttr axes_elems;
+ if (!matchPattern(tfl_min_op.axes(), m_Constant(&axes_elems)))
+ return failure();
+
+ bool keep_dims = false;
+ auto keep_dims_attr = tfl_min_op.keep_dimsAttr();
+ if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
+
+ auto lowered_op = convertReduceMinOp(
+ rewriter, op, output_type, tfl_min_op.input(), axes_elems, keep_dims);
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFLReduceProdOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_prod_op = cast<TFL::ReduceProdOp>(op);
+
+ auto output_type =
+ tfl_prod_op.getResult().getType().dyn_cast<RankedTensorType>();
+ if (!output_type) return failure();
+
+ ElementsAttr axes_elems;
+ if (!matchPattern(tfl_prod_op.axes(), m_Constant(&axes_elems)))
+ return failure();
+
+ bool keep_dims = false;
+ auto keep_dims_attr = tfl_prod_op.keep_dimsAttr();
+ if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
+
+ auto lowered_op = convertReduceProdOp(
+ rewriter, op, output_type, tfl_prod_op.input(), axes_elems, keep_dims);
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFLMeanOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_mean_op = cast<TFL::MeanOp>(op);
+
+ auto output_type =
+ tfl_mean_op.getResult().getType().dyn_cast<RankedTensorType>();
+ if (!output_type) return failure();
+
+ ElementsAttr axes_elems;
+ if (!matchPattern(tfl_mean_op.axis(), m_Constant(&axes_elems)))
+ return failure();
+
+ bool keep_dims = false;
+ auto keep_dims_attr = tfl_mean_op.keep_dimsAttr();
+ if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
+
+ auto lowered_op = convertReduceMeanOp(
+ rewriter, op, output_type, tfl_mean_op.input(), axes_elems, keep_dims);
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFLSumOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_sum_op = cast<TFL::SumOp>(op);
+
+ auto output_type =
+ tfl_sum_op.getResult().getType().dyn_cast<RankedTensorType>();
+ if (!output_type) return failure();
+
+ ElementsAttr axes_elems;
+ if (!matchPattern(tfl_sum_op.axes(), m_Constant(&axes_elems)))
+ return failure();
+
+ bool keep_dims = false;
+ auto keep_dims_attr = tfl_sum_op.keep_dimsAttr();
+ if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
+
+ auto lowered_op = convertReduceSumOp(
+ rewriter, op, output_type, tfl_sum_op.input(), axes_elems, keep_dims);
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFLEluOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_elu_op = cast<TFL::EluOp>(op);
+
+ auto lowered_op =
+ convertEluOp(rewriter, op, tfl_elu_op.getResult(), tfl_elu_op.x());
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFLSoftmaxOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_softmax_op = cast<TFL::SoftmaxOp>(op);
+
+ auto lowered_op = convertSoftmaxOp(rewriter, op, tfl_softmax_op.getResult(),
+ tfl_softmax_op.input());
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFLLogSoftmaxOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_logsoftmax_op = cast<TFL::LogSoftmaxOp>(op);
+
+ auto lowered_op = convertLogSoftmaxOp(
+ rewriter, op, tfl_logsoftmax_op.getResult(), tfl_logsoftmax_op.input());
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFLSliceOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_slice_op = cast<TFL::SliceOp>(op);
+
+ auto output_type =
+ tfl_slice_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ ElementsAttr begin_elems, size_elems;
+
+ SmallVector<int64_t, 4> begin_vals, size_vals;
+
+ if (!matchPattern(tfl_slice_op.begin(), m_Constant(&begin_elems)) ||
+ !matchPattern(tfl_slice_op.size(), m_Constant(&size_elems))) {
+ return failure();
+ }
+
+ for (int i = 0; i < begin_elems.getNumElements(); i++)
+ begin_vals.push_back(begin_elems.getValue<IntegerAttr>(i).getInt());
+
+ for (int i = 0; i < size_elems.getNumElements(); i++)
+ size_vals.push_back(size_elems.getValue<IntegerAttr>(i).getInt());
+
+ ArrayAttr begin = rewriter.getI64ArrayAttr(begin_vals);
+ ArrayAttr size = rewriter.getI64ArrayAttr(size_vals);
+
+ rewriter.replaceOpWithNewOp<tosa::SliceOp>(op, output_type,
+ tfl_slice_op.input(), begin, size);
+ return success();
+}
+
+LogicalResult ConvertTFLTileOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_tile_op = cast<TFL::TileOp>(op);
+
+ auto output_type =
+ tfl_tile_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ ElementsAttr multiples_elems;
+ if (!matchPattern(tfl_tile_op.multiples(), m_Constant(&multiples_elems)))
+ return failure();
+ SmallVector<int64_t, 4> multiples_vals;
+ for (int i = 0; i < multiples_elems.getNumElements(); i++)
+ multiples_vals.push_back(multiples_elems.getValue<IntegerAttr>(i).getInt());
+
+ ArrayAttr multiples_attr = rewriter.getI64ArrayAttr(multiples_vals);
+ rewriter.replaceOpWithNewOp<tosa::TileOp>(
+ op, output_type, tfl_tile_op.input(), multiples_attr);
+
+ return success();
+}
+
+LogicalResult ConvertTFLTransposeOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_transpose_op = cast<TFL::TransposeOp>(op);
+
+ auto output_type =
+ tfl_transpose_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
+ op, output_type, tfl_transpose_op.input(), tfl_transpose_op.perm());
+
+ return success();
+}
+
+LogicalResult ConvertTFLPackOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_pack_op = cast<TFL::PackOp>(op);
+
+ SmallVector<Value, 8> inputs(tfl_pack_op.values());
+ assert(inputs.size() >= 2);
+
+ IntegerAttr axis_attr;
+ {
+ auto tmpAttr = tfl_pack_op.axisAttr();
+ if (!tmpAttr) tmpAttr = rewriter.getI64IntegerAttr(0);
+ axis_attr = tmpAttr;
+ }
+ int32_t axis_i32 = axis_attr.getInt();
+
+ auto lowered_op =
+ convertPackOp(rewriter, op, tfl_pack_op.getResult(), inputs, axis_i32);
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFLUnpackOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_unpack_op = cast<TFL::UnpackOp>(op);
+
+ IntegerAttr axis_attr;
+ {
+ auto tmpAttr = tfl_unpack_op.axisAttr();
+ if (!tmpAttr) tmpAttr = rewriter.getI64IntegerAttr(0);
+ axis_attr = tmpAttr;
+ }
+ int32_t axis_i32 = axis_attr.getInt();
+
+ auto lowered_op =
+ convertUnpackOp(rewriter, op, tfl_unpack_op.input(), axis_i32);
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+ // Splits into num_split equal parts along split_dim
+LogicalResult ConvertTFLSplitOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_split_op = cast<TFL::SplitOp>(op);
+
+ // Get the number of splits
+ int32_t num_split = -1;
+ auto numSplitAttr = tfl_split_op.num_splitsAttr();
+ if (numSplitAttr) {
+ num_split = numSplitAttr.getInt();
+ } else {
+ return failure();
+ }
+
+ // Get the axis
+ ElementsAttr axisAttrElems;
+ if (!matchPattern(tfl_split_op.split_dim(), m_Constant(&axisAttrElems))) {
+ return op->emitOpError("Cannot read split_dim elems");
+ }
+
+ // The axis/split_dim parameter is stored as a 0D tensor instead of
+ // an integer attribute in TFLite MLIR.
+ int32_t axis = axisAttrElems.getValue<IntegerAttr>({}).getInt();
+
+ auto lowered_op = convertSplitOp(rewriter, op, tfl_split_op.getResult(0),
+ tfl_split_op.value(), num_split, axis);
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+ // Splits into parts of the sizes given in size_splits, along split_dim
+LogicalResult ConvertTFLSplitVOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_splitv_op = cast<TFL::SplitVOp>(op);
+
+ // Get the size_splits array
+ SmallVector<int32_t, 4> size_split;
+ ElementsAttr size_split_elems;
+ if (!matchPattern(tfl_splitv_op.size_splits(),
+ m_Constant(&size_split_elems))) {
+ return failure();
+ }
+
+ for (int i = 0; i < size_split_elems.getNumElements(); i++) {
+ size_split.push_back(size_split_elems.getValue<IntegerAttr>(i).getInt());
+ }
+
+ // Get the axis
+ ElementsAttr axisAttrElems;
+ if (!matchPattern(tfl_splitv_op.split_dim(), m_Constant(&axisAttrElems))) {
+ return op->emitOpError("Cannot read split_dim elems");
+ }
+
+ // The axis/split_dim parameter is stored as a 0D tensor instead of
+ // an integer attribute in TFLite MLIR.
+ int32_t axis = axisAttrElems.getValue<IntegerAttr>({}).getInt();
+
+ auto lowered_op = convertSplitVOp(rewriter, op, tfl_splitv_op.getResult(0),
+ tfl_splitv_op.value(), size_split, axis);
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFLLessOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_less_op = cast<TFL::LessOp>(op);
+
+ auto input_lhs_type =
+ tfl_less_op.lhs().getType().dyn_cast<RankedTensorType>();
+ auto input_rhs_type =
+ tfl_less_op.rhs().getType().dyn_cast<RankedTensorType>();
+ auto output_type =
+ tfl_less_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!input_lhs_type || !input_rhs_type || !output_type) return failure();
+
+ bool input_lhs_is_qtype =
+ input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+ bool input_rhs_is_qtype =
+ input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+ bool output_is_qtype =
+ output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+
+ if (input_lhs_is_qtype != output_is_qtype ||
+ input_rhs_is_qtype != output_is_qtype) {
+ return op->emitOpError(
+ "ConvertTFLLessOp: input/output tensor should "
+ "be all quantized or all floating-point.");
+ }
+
+ Value output;
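+ // TOSA has no less op; implement less(lhs, rhs) as
+ // logical_not(greater_equal(lhs, rhs)).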
+ if (output_is_qtype) {
+ auto input_lhs_qtype = input_lhs_type.getElementType()
+ .dyn_cast<mlir::quant::UniformQuantizedType>();
+ auto input_rhs_qtype = input_rhs_type.getElementType()
+ .dyn_cast<mlir::quant::UniformQuantizedType>();
+
+ if (input_lhs_qtype.getScale() != input_rhs_qtype.getScale() ||
+ input_lhs_qtype.getZeroPoint() != input_rhs_qtype.getZeroPoint()) {
+ return op->emitOpError(
+ "ConvertTFLLessOp: input_x and input_y scale/zp "
+ "must be the same");
+ }
+
+ auto op1_rescale_lhs = buildRescaleToInt32(
+ rewriter, op, tfl_less_op.lhs(), 1.0f, input_lhs_qtype.getZeroPoint());
+ auto op2_rescale_rhs = buildRescaleToInt32(
+ rewriter, op, tfl_less_op.rhs(), 1.0f, input_rhs_qtype.getZeroPoint());
+ auto op3_greater_equal_op1_op2 = rewriter.create<tosa::GreaterEqualOp>(
+ op->getLoc(), output_type, op1_rescale_lhs, op2_rescale_rhs);
+ auto op4_not_op3 = rewriter.create<tosa::LogicalNotOp>(
+ op->getLoc(), output_type, op3_greater_equal_op1_op2.getResult());
+
+ output = op4_not_op3.getResult();
+ } else {
+ auto op1_greater_equal_in = rewriter.create<tosa::GreaterEqualOp>(
+ op->getLoc(), output_type, tfl_less_op.lhs(), tfl_less_op.rhs());
+ auto op2_not_op1 = rewriter.create<tosa::LogicalNotOp>(
+ op->getLoc(), output_type, op1_greater_equal_in.getResult());
+
+ output = op2_not_op1.getResult();
+ }
+
+ rewriter.replaceOp(op, {output});
+ return success();
+}
+
+LogicalResult ConvertTFLLessEqualOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_less_equal_op = cast<TFL::LessEqualOp>(op);
+
+ auto input_lhs_type =
+ tfl_less_equal_op.lhs().getType().dyn_cast<RankedTensorType>();
+ auto input_rhs_type =
+ tfl_less_equal_op.rhs().getType().dyn_cast<RankedTensorType>();
+ auto output_type =
+ tfl_less_equal_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!input_lhs_type || !input_rhs_type || !output_type) return failure();
+
+ bool input_lhs_is_qtype =
+ input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+ bool input_rhs_is_qtype =
+ input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+ bool output_is_qtype =
+ output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+
+ if (input_lhs_is_qtype != output_is_qtype ||
+ input_rhs_is_qtype != output_is_qtype) {
+ return op->emitOpError(
+ "ConvertTFLLessEqualOp: input/output tensor should "
+ "be all quantized or all floating-point.");
+ }
+
+ Value output;
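+ // Similarly, less_equal(lhs, rhs) lowers to
+ // logical_not(greater(lhs, rhs)).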
+ if (output_is_qtype) {
+ auto input_lhs_qtype = input_lhs_type.getElementType()
+ .dyn_cast<mlir::quant::UniformQuantizedType>();
+ auto input_rhs_qtype = input_rhs_type.getElementType()
+ .dyn_cast<mlir::quant::UniformQuantizedType>();
+
+ if (input_lhs_qtype.getScale() != input_rhs_qtype.getScale() ||
+ input_lhs_qtype.getZeroPoint() != input_rhs_qtype.getZeroPoint()) {
+ return op->emitOpError(
+ "ConvertTFLLessEqualOp: input_x and input_y scale/zp "
+ "must be the same");
+ }
+
+ auto op1_rescale_lhs =
+ buildRescaleToInt32(rewriter, op, tfl_less_equal_op.lhs(), 1.0f,
+ input_lhs_qtype.getZeroPoint());
+ auto op2_rescale_rhs =
+ buildRescaleToInt32(rewriter, op, tfl_less_equal_op.rhs(), 1.0f,
+ input_rhs_qtype.getZeroPoint());
+ auto op3_greater_op1_op2 = rewriter.create<tosa::GreaterOp>(
+ op->getLoc(), output_type, op1_rescale_lhs, op2_rescale_rhs);
+ auto op4_not_op3 = rewriter.create<tosa::LogicalNotOp>(
+ op->getLoc(), output_type, op3_greater_op1_op2.getResult());
+
+ output = op4_not_op3.getResult();
+ } else {
+ auto op1_greater_in = rewriter.create<tosa::GreaterOp>(
+ op->getLoc(), output_type, tfl_less_equal_op.lhs(),
+ tfl_less_equal_op.rhs());
+ auto op2_not_op1 = rewriter.create<tosa::LogicalNotOp>(
+ op->getLoc(), output_type, op1_greater_in.getResult());
+
+ output = op2_not_op1.getResult();
+ }
+
+ rewriter.replaceOp(op, {output});
+ return success();
+}
+
+LogicalResult ConvertTFLPadOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_pad_op = cast<TFL::PadOp>(op);
+
+ auto output_type =
+ tfl_pad_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ auto pad_op = rewriter.create<tosa::PadOp>(
+ op->getLoc(), output_type, tfl_pad_op.input(), tfl_pad_op.padding());
+
+ rewriter.replaceOp(op, {pad_op.getResult()});
+ return success();
+}
+
+LogicalResult ConvertTFLResizeBilinearOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_resize_op = cast<TFL::ResizeBilinearOp>(op);
+
+ auto output_type =
+ tfl_resize_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ auto lowered_op = convertResizeOp(
+ rewriter, op, output_type, tfl_resize_op.input(), StringRef("BILINEAR"));
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFLResizeNearestNeighborOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_resize_op = cast<TFL::ResizeNearestNeighborOp>(op);
+
+ auto output_type =
+ tfl_resize_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ auto lowered_op = convertResizeOp(
+ rewriter, op, output_type, tfl_resize_op.input(), StringRef("NEAREST"));
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFLSelectOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_sel_op = cast<TFL::SelectOp>(op);
+
+ auto lowered_op =
+ convertSelectOp(rewriter, op, tfl_sel_op.getResult(),
+ tfl_sel_op.condition(), tfl_sel_op.x(), tfl_sel_op.y());
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFLSelectV2Op::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_sel_op = cast<TFL::SelectV2Op>(op);
+
+ auto lowered_op =
+ convertSelectOp(rewriter, op, tfl_sel_op.getResult(),
+ tfl_sel_op.condition(), tfl_sel_op.x(), tfl_sel_op.y());
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFLSpaceToBatchNdOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_s2b_op = cast<TFL::SpaceToBatchNdOp>(op);
+ auto lowered_op = convertSpaceToBatchNDOp(
+ rewriter, op, tfl_s2b_op.getResult(), tfl_s2b_op.input(),
+ tfl_s2b_op.block_shape(), tfl_s2b_op.paddings());
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFLBatchToSpaceNdOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_b2s_op = cast<TFL::BatchToSpaceNdOp>(op);
+
+ auto lowered_op = convertBatchToSpaceNDOp(
+ rewriter, op, tfl_b2s_op.getResult(), tfl_b2s_op.input(),
+ tfl_b2s_op.block_shape(), tfl_b2s_op.indices());
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFLSpaceToDepthOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_s2d_op = cast<TFL::SpaceToDepthOp>(op);
+
+ auto block_size_attr = tfl_s2d_op.block_sizeAttr();
+ auto lowered_op = convertSpaceToDepthOp(rewriter, op, tfl_s2d_op.getResult(),
+ tfl_s2d_op.input(), block_size_attr,
+ rewriter.getStringAttr("NHWC"));
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFLDepthToSpaceOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_d2s_op = cast<TFL::DepthToSpaceOp>(op);
+
+ auto block_size_attr = tfl_d2s_op.block_sizeAttr();
+ auto lowered_op = convertDepthToSpaceOp(rewriter, op, tfl_d2s_op.getResult(),
+ tfl_d2s_op.input(), block_size_attr,
+ rewriter.getStringAttr("NHWC"));
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFLStridedSliceOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_ss_op = cast<TFL::StridedSliceOp>(op);
+
+ auto lowered_op = convertStridedSliceOp(
+ rewriter, op, tfl_ss_op.getResult(), tfl_ss_op.input(), tfl_ss_op.begin(),
+ tfl_ss_op.end(), tfl_ss_op.strides(), tfl_ss_op.begin_maskAttr().getInt(),
+ tfl_ss_op.end_maskAttr().getInt(), tfl_ss_op.ellipsis_maskAttr().getInt(),
+ tfl_ss_op.new_axis_maskAttr().getInt(),
+ tfl_ss_op.shrink_axis_maskAttr().getInt());
+ TOSA_REPLACE_LOWERED_OP_LIST(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFLZerosLikeOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_zeroslike_op = cast<TFL::ZerosLikeOp>(op);
+
+ auto lowered_op = convertZerosLikeOp(
+ rewriter, op, tfl_zeroslike_op.getResult(), tfl_zeroslike_op.input());
+ TOSA_REPLACE_LOWERED_OP_LIST(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFLHardSwishOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_hardswish_op = cast<TFL::HardSwishOp>(op);
+ auto output_type =
+ tfl_hardswish_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ auto input_type =
+ tfl_hardswish_op.input().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!input_type) return failure();
+
+ auto input_shape = input_type.getShape();
+
+ // TFL hardswish: f(x) -> (x * relu6(x+3))/6
+
+ // TODO: support 16-bit hardswish
+ if (input_type.getElementType().isa<mlir::quant::QuantizedType>() &&
+ output_type.getElementType().isa<mlir::quant::QuantizedType>()) {
+ // TFLite reference:
+ // tensorflow/lite/kernels/internal/reference/reference_ops.h. Note that
+ // there's a potential rounding issue in the TFLite reference.
+ mlir::quant::UniformQuantizedType in_quant_type =
+ input_type.getElementType()
+ .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
+ mlir::quant::UniformQuantizedType out_quant_type =
+ output_type.getElementType()
+ .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
+
+ auto int16_element_qtype = mlir::quant::UniformQuantizedType::get(
+ true, rewriter.getIntegerType(16), rewriter.getF32Type(), 1.0f, 0,
+ -32768, 32767);
+ auto bool_type = RankedTensorType::get(input_shape, rewriter.getI1Type());
+ auto int16_type = RankedTensorType::get(input_shape, int16_element_qtype);
+ auto int32_type = RankedTensorType::get(input_shape, rewriter.getI32Type());
+
+ // Table's real input range is [-4.0, 4.0].
+ // Use the TABLE op to compute x * relu6(x + 3) / 6
+ const double input_sample_grain = 1.0 / 64.0;
+ auto hardswish_func = [input_sample_grain](int32_t x) -> int32_t {
+ double v = static_cast<double>(x) * input_sample_grain;
+ double w = v + 3.0;
+ w = w < 0.0 ? 0.0 : w > 6.0 ? 6.0 : w;
+ v = v * w / 6.0;
+ return std::lround(32768.0 * v);
+ };
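+ // The lambda samples the function on a 1/64 input grid and returns values
+ // in Q0.15 (scaled by 32768).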
+
+ auto table_const = getTosa1DConstTensorTable(rewriter, op, hardswish_func);
+
+ // Rescale the input into Q9.7 fixed-point format
+ auto op1_rescale_in =
+ buildRescale(rewriter, op, int16_type, tfl_hardswish_op.input(),
+ (in_quant_type.getScale() * 128.0) / input_sample_grain,
+ in_quant_type.getZeroPoint(), 0);
+
+ // Table op; the output is in Q0.23 fixed-point format
+ auto op2_table_op1 = rewriter.create<tosa::TableOp>(
+ op->getLoc(), int32_type, op1_rescale_in, table_const);
+
+ // Scale the table output back into the output's quantized space.
+ auto op3_rescale_op2 =
+ buildRescale(rewriter, op, output_type, op2_table_op1.getResult(),
+ 1.0 / (128.0 * 32768.0 * out_quant_type.getScale()), 0,
+ out_quant_type.getZeroPoint());
+
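+ // Bring the raw quantized input into int32 (scale 1.0) so it can be
+ // compared against the quantized representation of 3.0.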
+ auto op4_rescale_in = buildRescale(rewriter, op, int32_type,
+ tfl_hardswish_op.input(), 1.0, 0, 0);
+
+ // Get 3.0 in quantized space
+ int32_t quantized_3 =
+ static_cast<int32_t>(std::ceil(3.0 / in_quant_type.getScale())) +
+ in_quant_type.getZeroPoint();
+
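+ // For x >= 3, relu6(x + 3) saturates at 6, so hardswish(x) == x. Select
+ // the original input in that region to avoid table quantization error.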
+ auto op5_ge_op4 = rewriter.create<tosa::GreaterEqualOp>(
+ op->getLoc(), bool_type, op4_rescale_in,
+ getTosaConstTensorSingleI32(rewriter, op, quantized_3));
+
+ auto op6_select_op5_op4_op3 = rewriter.create<tosa::SelectOp>(
+ op->getLoc(), output_type, op5_ge_op4, tfl_hardswish_op.input(),
+ op3_rescale_op2);
+
+ rewriter.replaceOp(op, {op6_select_op5_op4_op3});
+
+ return success();
+
+ } else {
+ // op1 = constop(3)
+ // op2 = add(x, op1)
+ // op3 = reluN(op2, 6)
+ // op4 = mul(x, op3)
+ // op5 = reciprocal(6)
+ // op6 = mul (op4, op5)
+
+ auto op1_value = getTosaConstTensorSingleF32(rewriter, op, 3.0);
+
+ auto op2_add_x_op1 = rewriter.create<tosa::AddOp>(
+ op->getLoc(), output_type, tfl_hardswish_op.input(), op1_value);
+
+ auto op3_relu_op2_6 = rewriter.create<tosa::ReluNOp>(
+ op->getLoc(), output_type, op2_add_x_op1.getResult(),
+ rewriter.getI64IntegerAttr(0), rewriter.getF32FloatAttr(6.0));
+
+ auto op4_mul_x_op3 = rewriter.create<tosa::MulOp>(
+ op->getLoc(), output_type, tfl_hardswish_op.input(),
+ op3_relu_op2_6.getResult(), 0);
+
+ auto op5_reciprocal_6 = rewriter.create<tosa::ReciprocalOp>(
+ op->getLoc(), output_type,
+ getTosaConstTensorSingleF32(rewriter, op, 6.0));
+
+ auto op6_mul_op4_op5 = rewriter.create<tosa::MulOp>(
+ op->getLoc(), output_type, op4_mul_x_op3.getResult(),
+ op5_reciprocal_6.getResult(), 0);
+
+ rewriter.replaceOp(op, {op6_mul_op4_op5.getResult()});
+
+ return success();
+ }
+}
+
+LogicalResult ConvertTFLLogisticOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_logistic_op = cast<TFL::LogisticOp>(op);
+
+ auto output_type =
+ tfl_logistic_op.getResult().getType().dyn_cast<RankedTensorType>();
+ auto input_type = tfl_logistic_op.x().getType().dyn_cast<RankedTensorType>();
+ if (!input_type || !output_type) return failure();
+
+ bool input_is_qtype =
+ input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+ bool output_is_qtype =
+ output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+
+ if (input_is_qtype != output_is_qtype) {
+ return op->emitOpError(
+ "ConvertTFLLogisticOp: input/output tensor should "
+ "be all quantized or all floating-point.");
+ }
+
+ if (input_is_qtype) {
+ auto int16_element_qtype = mlir::quant::UniformQuantizedType::get(
+ true, rewriter.getIntegerType(16), rewriter.getF32Type(), 1.0f, 0,
+ -32768, 32767);
+ auto int16_type =
+ RankedTensorType::get(output_type.getShape(), int16_element_qtype);
+ auto int32_type = RankedTensorType::get(output_type.getShape(),
+ rewriter.getIntegerType(32));
+ mlir::quant::UniformQuantizedType input_qtype =
+ input_type.getElementType()
+ .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
+ mlir::quant::UniformQuantizedType output_qtype =
+ output_type.getElementType()
+ .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
+ const double input_sample_grain = 1.0 / 16.0;
+ auto sigmoid_func = [input_sample_grain](int32_t x) -> int32_t {
+ // Input range [-16.0, 16.0], output range [0.0, 1.0]
+ double v = static_cast<double>(x) * input_sample_grain;
+ v = 1.0 / (1.0 + std::exp(-v));
+
+ return std::lround(32768.0 * v);
+ };
+
+ auto table_const = getTosa1DConstTensorTable(rewriter, op, sigmoid_func);
+
+ // Rescale input to 9.7 precision.
+ auto op1_rescale_in =
+ buildRescale(rewriter, op, int16_type, tfl_logistic_op.x(),
+ (input_qtype.getScale() * 128.0) / input_sample_grain,
+ input_qtype.getZeroPoint(), 0);
+
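+ // TABLE treats the int16 input as 9.7 fixed point: the top 9 bits index
+ // the 513-entry table and the low 7 bits interpolate, which is why the
+ // result below is rescaled by 1 / (32768 * 128).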
+ auto op2_table_op1 = rewriter.create<tosa::TableOp>(
+ op->getLoc(), int32_type, op1_rescale_in, table_const);
+
+ double output_rescale_scale =
+ 1.0 / (output_qtype.getScale() * 32768.0 * 128.0);
+
+ auto op3_rescale_op2 =
+ buildRescale(rewriter, op, output_type, op2_table_op1.getResult(),
+ output_rescale_scale, 0, output_qtype.getZeroPoint());
+
+ rewriter.replaceOp(op, {op3_rescale_op2});
+ } else {
+ rewriter.replaceOpWithNewOp<tosa::SigmoidOp>(op, output_type,
+ tfl_logistic_op.x());
+ }
+
+ return success();
+}
+
+LogicalResult ConvertTFLTanhOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_tanh_op = cast<TFL::TanhOp>(op);
+ auto output_type =
+ tfl_tanh_op.getResult().getType().dyn_cast<RankedTensorType>();
+ auto input_type = tfl_tanh_op.input().getType().dyn_cast<RankedTensorType>();
+ if (!input_type || !output_type) return failure();
+
+ bool input_is_qtype =
+ input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+ bool output_is_qtype =
+ output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+
+ if (input_is_qtype != output_is_qtype) {
+ return op->emitOpError(
+ "ConvertTFLTanhOp: input/output tensor should "
+ "be all quantized or all floating-point.");
+ }
+
+ if (input_is_qtype) {
+ auto int16_element_qtype = mlir::quant::UniformQuantizedType::get(
+ true, rewriter.getIntegerType(16), rewriter.getF32Type(), 1.0f, 0,
+ -32768, 32767);
+ auto int16_type =
+ RankedTensorType::get(output_type.getShape(), int16_element_qtype);
+ auto int32_type = RankedTensorType::get(output_type.getShape(),
+ rewriter.getIntegerType(32));
+ mlir::quant::UniformQuantizedType input_qtype =
+ input_type.getElementType()
+ .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
+ mlir::quant::UniformQuantizedType output_qtype =
+ output_type.getElementType()
+ .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
+ const double input_sample_grain = 1.0 / 32.0;
+ auto tanh_func = [input_sample_grain](int32_t x) -> int32_t {
+ // Input range [-8.0, 8.0], output range [-1.0, 1.0]
+ // tanh(v) = (1 - exp(-2v)) / (1 + exp(-2v))
+ double v = static_cast<double>(x) * input_sample_grain;
+ v = std::exp(-2.0 * v);
+ v = (1.0 - v) / (1.0 + v);
+
+ return std::lround(32768.0 * v);
+ };
+
+ auto table_const = getTosa1DConstTensorTable(rewriter, op, tanh_func);
+
+ // Rescale input to 9.7 precision.
+ auto op1_rescale_in =
+ buildRescale(rewriter, op, int16_type, tfl_tanh_op.input(),
+ (input_qtype.getScale() * 128.0) / input_sample_grain,
+ input_qtype.getZeroPoint(), 0);
+
+ auto op2_table_op1 = rewriter.create<tosa::TableOp>(
+ op->getLoc(), int32_type, op1_rescale_in, table_const);
+
+ double output_rescale_scale =
+ 1.0 / (output_qtype.getScale() * 32768.0 * 128.0);
+
+ auto op3_rescale_op2 =
+ buildRescale(rewriter, op, output_type, op2_table_op1.getResult(),
+ output_rescale_scale, 0, output_qtype.getZeroPoint());
+
+ rewriter.replaceOp(op, {op3_rescale_op2});
+ } else {
+ rewriter.replaceOpWithNewOp<tosa::TanhOp>(op, output_type,
+ tfl_tanh_op.input());
+ }
+
+ return success();
+}
+
+LogicalResult ConvertTFLPReluOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_prelu_op = cast<TFL::PReluOp>(op);
+ auto output_type =
+ tfl_prelu_op.getResult().getType().dyn_cast<RankedTensorType>();
+ if (!output_type) return failure();
+
+ // TODO: add lowering with MUL + SELECT + RESCALE
+
+ return failure();
+}
+
+LogicalResult ConvertTFLLeakyReluOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_leakyrelu_op = cast<TFL::LeakyReluOp>(op);
+ auto output_type =
+ tfl_leakyrelu_op.getResult().getType().dyn_cast<RankedTensorType>();
+ if (!output_type) return failure();
+
+ // TODO: add lowering with MUL + SELECT + RESCALE
+
+ return failure();
+}
+
+LogicalResult ConvertTFLNegOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_neg_op = cast<TFL::NegOp>(op);
+ auto output_type =
+ tfl_neg_op.getResult().getType().dyn_cast<RankedTensorType>();
+ if (!output_type) return failure();
+
+ rewriter.replaceOpWithNewOp<tosa::NegateOp>(op, output_type, tfl_neg_op.x());
+
+ return success();
+}
+
+LogicalResult ConvertTFLYieldOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ rewriter.replaceOpWithNewOp<tosa::YieldOp>(op, op->getResultTypes(),
+ op->getOperands());
+
+ return success();
+}
+
+LogicalResult ConvertTFLCustomOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_custom_op = cast<TFL::CustomOp>(op);
+ rewriter.replaceOpWithNewOp<tosa::CustomOp>(
+ op, op->getResultTypes(), tfl_custom_op.custom_code(), op->getOperands());
+
+ return success();
+}
+
+LogicalResult ConvertTFLReverseV2Op::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_reverse_op = cast<TFL::ReverseV2Op>(op);
+
+ auto input_type =
+ tfl_reverse_op.input().getType().dyn_cast<RankedTensorType>();
+ auto output_type =
+ tfl_reverse_op.getResult().getType().dyn_cast<RankedTensorType>();
+ if (!input_type || !output_type) return failure();
+
+ ElementsAttr axis_elems;
+ if (!matchPattern(tfl_reverse_op.axis(), m_Constant(&axis_elems)))
+ return failure();
+
+ auto input_rank = input_type.getShape().size();
+ Value val = tfl_reverse_op.input();
+ if (axis_elems.getNumElements() == 0) {
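+ // No axes to reverse: pass the input through as an identity.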
+ auto identity_op =
+ rewriter.create<tosa::IdentityOp>(op->getLoc(), output_type, val);
+ val = identity_op.getResult();
+ } else {
+ for (int i = 0; i < axis_elems.getNumElements(); i++) {
+ int64_t axis_val = axis_elems.getValue<IntegerAttr>(i).getInt();
+ if (axis_val < 0) axis_val += input_rank;
+ auto axis_attr = rewriter.getI64IntegerAttr(axis_val);
+ auto reverse_op = rewriter.create<tosa::ReverseOp>(
+ op->getLoc(), output_type, val, axis_attr);
+
+ val = reverse_op.getResult();
+ }
+ }
+
+ rewriter.replaceOp(op, {val});
+
+ return success();
+}
+
+LogicalResult ConvertTFLQuantizeOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_quantize_op = cast<TFL::QuantizeOp>(op);
+
+ auto input_type =
+ tfl_quantize_op.input().getType().dyn_cast<RankedTensorType>();
+ auto output_type =
+ tfl_quantize_op.getResult().getType().dyn_cast<RankedTensorType>();
+
+ if (!input_type || !output_type) return failure();
+
+ auto qtype =
+ tfl_quantize_op.qtypeAttr().getValue().dyn_cast<RankedTensorType>();
+ if (!qtype) return failure();
+
+ auto element_type =
+ qtype.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
+ if (!element_type) return failure();
+
+ auto input_element_type =
+ input_type.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
+
+ // If input is already a quantized type, this is basically a RESCALE (or
+ // tensorflow::ops::Requantize)
+ if (input_element_type) {
+ double rescale_scale =
+ input_element_type.getScale() / element_type.getScale();
+ auto rescale_op = buildRescale(
+ rewriter, op, output_type, tfl_quantize_op.input(), rescale_scale,
+ input_element_type.getZeroPoint(), element_type.getZeroPoint());
+
+ rewriter.replaceOp(op, {rescale_op});
+ return success();
+ } else {
+ double scale = 1 / element_type.getScale();
+ int64_t zp = element_type.getZeroPoint();
+ int64_t num_bits = element_type.getStorageTypeIntegralWidth();
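+ // An unsigned zero point is shifted into the signed domain by
+ // subtracting half the storage range.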
+ zp = element_type.isSigned() ? zp : zp - (1 << (num_bits - 1));
+
+ auto lowered_op = convertQuantizeOp(rewriter, op, output_type,
+ tfl_quantize_op.input(), scale, zp);
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+ }
+}
+
+LogicalResult ConvertTFLDequantizeOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_dequantize_op = cast<TFL::DequantizeOp>(op);
+
+ auto output_type =
+ tfl_dequantize_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ auto qtype = tfl_dequantize_op.input().getType().dyn_cast<RankedTensorType>();
+ if (!qtype) return failure();
+
+ auto element_type =
+ qtype.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
+ if (!element_type) return failure();
+
+ double scale = element_type.getScale();
+ int64_t zp = element_type.getZeroPoint();
+ int64_t num_bits = element_type.getStorageTypeIntegralWidth();
+ zp = element_type.isSigned() ? zp : zp - (1 << (num_bits - 1));
+
+ auto lowered_op = convertDequantizeOp(rewriter, op, output_type,
+ tfl_dequantize_op.input(), scale, zp);
+
+ TOSA_REPLACE_LOWERED_OP(rewriter, op, lowered_op);
+}
+
+LogicalResult ConvertTFLQConstOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto tfl_qconst_op = cast<TFL::QConstOp>(op);
+
+ auto output_type =
+ tfl_qconst_op.getResult().getType().dyn_cast<RankedTensorType>();
+ // Not a ranked tensor output
+ if (!output_type) return failure();
+
+ rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, output_type,
+ tfl_qconst_op.valueAttr());
+
+ return success();
+}
+
+void LegalizeTFL::runOnFunction() {
+ OwningRewritePatternList patterns;
+ auto* ctx = &getContext();
+ auto func = getFunction();
+
+ // Add the generated patterns to the list.
+ populateWithGenerated(ctx, patterns);
+
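+ // Shorthand for registering each hand-written ConvertTFL<Op>Op pattern
+ // alongside the TableGen-generated ones.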
+#define DEF_PATTERN_INSERT(PAT) patterns.insert<Convert##PAT##Op>(ctx);
+
+ DEF_PATTERN_INSERT(TFLRelu);
+ DEF_PATTERN_INSERT(TFLRelu6);
+ DEF_PATTERN_INSERT(TFLEqual);
+ DEF_PATTERN_INSERT(TFLNotEqual);
+ DEF_PATTERN_INSERT(TFLGreater);
+ DEF_PATTERN_INSERT(TFLGreaterEqual);
+ DEF_PATTERN_INSERT(TFLAdd);
+ DEF_PATTERN_INSERT(TFLSub);
+ DEF_PATTERN_INSERT(TFLMul);
+ DEF_PATTERN_INSERT(TFLSquare);
+ DEF_PATTERN_INSERT(TFLDiv);
+ DEF_PATTERN_INSERT(TFLMaximum);
+ DEF_PATTERN_INSERT(TFLMinimum);
+ DEF_PATTERN_INSERT(TFLFloorMod);
+ DEF_PATTERN_INSERT(TFLFloorDiv);
+ DEF_PATTERN_INSERT(TFLAddN);
+ DEF_PATTERN_INSERT(TFLAveragePool2D);
+ DEF_PATTERN_INSERT(TFLMaxPool2D);
+ DEF_PATTERN_INSERT(TFLConcatenation);
+ DEF_PATTERN_INSERT(TFLReshape);
+ DEF_PATTERN_INSERT(TFLRank);
+ DEF_PATTERN_INSERT(TFLShape);
+ DEF_PATTERN_INSERT(TFLExpandDims);
+ DEF_PATTERN_INSERT(TFLSqueeze);
+ DEF_PATTERN_INSERT(TFLFill);
+ DEF_PATTERN_INSERT(TFLElu);
+ DEF_PATTERN_INSERT(TFLSoftmax);
+ DEF_PATTERN_INSERT(TFLLogSoftmax);
+ DEF_PATTERN_INSERT(TFLReduceAny);
+ DEF_PATTERN_INSERT(TFLReduceMax);
+ DEF_PATTERN_INSERT(TFLReduceMin);
+ DEF_PATTERN_INSERT(TFLMean);
+ DEF_PATTERN_INSERT(TFLReduceProd);
+ DEF_PATTERN_INSERT(TFLSum);
+ DEF_PATTERN_INSERT(TFLConv2D);
+ DEF_PATTERN_INSERT(TFLTransposeConv);
+ DEF_PATTERN_INSERT(TFLDepthwiseConv2D);
+ DEF_PATTERN_INSERT(TFLFullyConnected);
+ DEF_PATTERN_INSERT(TFLSplit);
+ DEF_PATTERN_INSERT(TFLSplitV);
+ DEF_PATTERN_INSERT(TFLPack);
+ DEF_PATTERN_INSERT(TFLUnpack);
+ DEF_PATTERN_INSERT(TFLTranspose);
+ DEF_PATTERN_INSERT(TFLTile);
+ DEF_PATTERN_INSERT(TFLSlice);
+ DEF_PATTERN_INSERT(TFLStridedSlice);
+ DEF_PATTERN_INSERT(TFLZerosLike);
+ DEF_PATTERN_INSERT(TFLHardSwish);
+ DEF_PATTERN_INSERT(TFLLess);
+ DEF_PATTERN_INSERT(TFLLessEqual);
+ DEF_PATTERN_INSERT(TFLPad);
+ DEF_PATTERN_INSERT(TFLResizeBilinear);
+ DEF_PATTERN_INSERT(TFLResizeNearestNeighbor);
+ DEF_PATTERN_INSERT(TFLSelect);
+ DEF_PATTERN_INSERT(TFLSelectV2);
+ DEF_PATTERN_INSERT(TFLSpaceToBatchNd);
+ DEF_PATTERN_INSERT(TFLBatchToSpaceNd);
+ DEF_PATTERN_INSERT(TFLSpaceToDepth);
+ DEF_PATTERN_INSERT(TFLDepthToSpace);
+ DEF_PATTERN_INSERT(TFLLogistic);
+ DEF_PATTERN_INSERT(TFLTanh);
+ DEF_PATTERN_INSERT(TFLPRelu);
+ DEF_PATTERN_INSERT(TFLLeakyRelu);
+ DEF_PATTERN_INSERT(TFLNeg);
+ DEF_PATTERN_INSERT(TFLYield);
+ DEF_PATTERN_INSERT(TFLCustom);
+ DEF_PATTERN_INSERT(TFLReverseV2);
+ DEF_PATTERN_INSERT(TFLQuantize);
+ DEF_PATTERN_INSERT(TFLDequantize);
+ DEF_PATTERN_INSERT(TFLQConst);
+ applyPatternsAndFoldGreedily(func, std::move(patterns));
+}
+} // namespace
+
+// Creates an instance of the TensorFlow Lite dialect LegalizeTFL pass.
+std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFLPass() {
+ return std::make_unique<LegalizeTFL>();
+}
+
+static PassRegistration<LegalizeTFL> pass(
+ PASS_NAME, "Legalize from TensorFlow Lite to TOSA dialect");
+} // namespace tosa
+} // namespace mlir
diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc
new file mode 100644
index 0000000..5bae8ec
--- /dev/null
+++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc
@@ -0,0 +1,433 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"
+
+#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
+#include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
+
+// Implements legalization and post-legalization optimization helper functions
+
+namespace mlir {
+
+namespace tosa {
+
+// Create a TOSA rescale op from TFLite scaling, zero points and rounding mode
+Value buildRescale(PatternRewriter& rewriter, Operation* op,
+ RankedTensorType output_type, Value input_val, double scale,
+ int64_t input_zp, int64_t output_zp, bool double_round) {
+ int32_t multiplier;
+ int32_t shift;
+
+ // We currently only support 32-bit quantized multiplier.
+ computeMultiplierAndShift(scale, multiplier, shift, 32);
+
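+ // TOSA RESCALE computes ((in - input_zp) * multiplier >> shift) +
+ // output_zp; computeMultiplierAndShift decomposed the real-valued scale
+ // into that multiplier/shift pair.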
+ auto rescale_op = rewriter.create<tosa::RescaleOp>(
+ op->getLoc(), output_type, input_val,
+ rewriter.getI32IntegerAttr(static_cast<int32_t>(input_zp)),
+ rewriter.getI32IntegerAttr(static_cast<int32_t>(output_zp)),
+ rewriter.getI32ArrayAttr({multiplier}), rewriter.getI32ArrayAttr({shift}),
+ rewriter.getBoolAttr(true), rewriter.getBoolAttr(double_round),
+ rewriter.getBoolAttr(false));
+
+ return rescale_op.getResult();
+}
+
+// Creates TOSA rescale op with int32 output
+Value buildRescaleToInt32(PatternRewriter& rewriter, Operation* op,
+ Value input_val, double input_scale,
+ int64_t input_zp) {
+ // Output is always int32 type
+ auto input_type = input_val.getType().dyn_cast<mlir::RankedTensorType>();
+ assert(input_type);
+ auto output_type =
+ RankedTensorType::get(input_type.getShape(), rewriter.getI32Type());
+
+ return buildRescale(rewriter, op, output_type, input_val, input_scale,
+ input_zp, 0, false);
+}
+
+// Creates TOSA rescale op with int32 input
+Value buildRescaleFromInt32(PatternRewriter& rewriter, Operation* op,
+ RankedTensorType output_type, Value input_val,
+ double output_scale, int64_t output_zp) {
+ // Input should be int32 type
+ auto input_type = input_val.getType().dyn_cast<mlir::RankedTensorType>();
+ (void)input_type;
+ assert(input_type && input_type.getElementType().isInteger(32) &&
+ "expected rescale input element type to be i32");
+
+ // Potentially check input_shape == output_shape here
+ return buildRescale(rewriter, op, output_type, input_val, output_scale, 0,
+ output_zp, true);
+}
+
+// Creates a TOSA rescale op based on conv2d parameters.
+Value buildRescaleOpConvOutput(PatternRewriter& rewriter, Operation* op,
+ Value conv_val, RankedTensorType input_type,
+ RankedTensorType weight_type,
+ RankedTensorType output_type) {
+ auto input_qtype =
+ input_type.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
+ auto output_qtype = output_type.getElementType()
+ .dyn_cast<mlir::quant::UniformQuantizedType>();
+
+ double input_scale = input_qtype.getScale();
+
+ int64_t output_zp = output_qtype.getZeroPoint();
+ double output_scale = output_qtype.getScale();
+
+ if (auto weight_per_tensor_qtype =
+ weight_type.getElementType()
+ .dyn_cast<mlir::quant::UniformQuantizedType>()) {
+ // Per-tensor quantization
+ double weight_scale = weight_per_tensor_qtype.getScale();
+
+ int32_t multiplier;
+ int32_t shift;
+
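+ // The int32 accumulator out of the convolution carries a scale of
+ // input_scale * weight_scale; dividing by the output scale gives the
+ // factor that maps it into the output's quantized space.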
+ double op_tensor_scale = (input_scale * weight_scale) / output_scale;
+
+ // We currently only support 32-bit quantized multiplier.
+ computeMultiplierAndShift(op_tensor_scale, multiplier, shift, 32);
+
+ auto rescale_op = rewriter.create<tosa::RescaleOp>(
+ op->getLoc(), output_type, conv_val, rewriter.getI32IntegerAttr(0),
+ rewriter.getI32IntegerAttr(output_zp),
+ rewriter.getI32ArrayAttr({multiplier}),
+ rewriter.getI32ArrayAttr({shift}), rewriter.getBoolAttr(true),
+ rewriter.getBoolAttr(true), rewriter.getBoolAttr(false));
+
+ return rescale_op.getResult();
+
+ } else if (auto weight_per_channel_qtype =
+ weight_type.getElementType()
+ .dyn_cast<mlir::quant::UniformQuantizedPerAxisType>()) {
+ // Per-channel quantization
+ auto output_last_axis = output_type.getShape().size() - 1;
+ uint32_t output_channels = output_type.getShape()[output_last_axis];
+
+ llvm::SmallVector<int32_t, 4> multiplier_arr;
+ llvm::SmallVector<int32_t, 4> shift_arr;
+
+ llvm::SmallVector<double, 4> weight_scale_arr(
+ weight_per_channel_qtype.getScales().begin(),
+ weight_per_channel_qtype.getScales().end());
+
+ for (uint32_t oc = 0; oc < output_channels; oc++) {
+ double weight_scale = weight_scale_arr[oc];
+
+ int32_t multiplier;
+ int32_t shift;
+
+ double op_channel_scale = (input_scale * weight_scale) / output_scale;
+
+ // We currently only support 32-bit quantized multiplier.
+ computeMultiplierAndShift(op_channel_scale, multiplier, shift, 32);
+
+ multiplier_arr.push_back(multiplier);
+ shift_arr.push_back(shift);
+ }
+
+ auto rescale_op = rewriter.create<tosa::RescaleOp>(
+ op->getLoc(), output_type, conv_val, rewriter.getI32IntegerAttr(0),
+ rewriter.getI32IntegerAttr(output_zp),
+ rewriter.getI32ArrayAttr(multiplier_arr),
+ rewriter.getI32ArrayAttr(shift_arr), rewriter.getBoolAttr(true),
+ rewriter.getBoolAttr(true), rewriter.getBoolAttr(true));
+
+ return rescale_op.getResult();
+
+ } else {
+ op->emitOpError("buildConvRescaleOp: unknown weight quantized type");
+ return nullptr;
+ }
+}
+
+// Create a 513 entry TOSA constant tensor suitable for the Table operator based
+// on the values from an int32_t func(int32_t) lambda function.
+Value getTosa1DConstTensorTable(PatternRewriter& rewriter, Operation* op,
+ std::function<int32_t(int32_t)> func) {
+ llvm::SmallVector<int16_t, 4> table_vec;
+
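+ // TOSA TABLE requires 513 entries: 512 intervals across the 9-bit index
+ // range plus one extra endpoint so the final interval can interpolate.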
+ for (int32_t i = -256; i <= 256; i++) {
+ int32_t value = func(i);
+ // Table entry is int16_t; clamp to expressible range.
+ table_vec.push_back(
+ static_cast<int16_t>(std::min(std::max(value, -32768), 32767)));
+ }
+
+ auto element_qtype =
+ UniformQuantizedType::get(true, rewriter.getIntegerType(16),
+ rewriter.getF32Type(), 1.0f, 0, -32768, 32767);
+ auto const_type = RankedTensorType::get({513}, element_qtype);
+ auto storage_type =
+ RankedTensorType::get({513}, element_qtype.getStorageType());
+ auto const_attr = DenseElementsAttr::get(
+ storage_type, llvm::makeArrayRef<int16_t>(table_vec));
+
+ auto const_op =
+ rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
+ return const_op.getResult();
+}
+
+// Create a 32-bit float constant operator from a float
+Value getTosaConstTensorSingleF32(PatternRewriter& rewriter, Operation* op,
+ float val) {
+ auto const_type = RankedTensorType::get({}, rewriter.getF32Type());
+ auto const_attr = DenseElementsAttr::get(const_type, val);
+
+ auto const_op =
+ rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
+ return const_op.getResult();
+}
+
+// Create a 32-bit integer constant operator from an int
+Value getTosaConstTensorSingleI32(PatternRewriter& rewriter, Operation* op,
+ int32_t val) {
+ auto const_type = RankedTensorType::get({}, rewriter.getIntegerType(32));
+ auto const_attr = DenseElementsAttr::get(const_type, val);
+
+ auto const_op =
+ rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
+ return const_op.getResult();
+}
+
+// Create a vector from a 32-bit value tensor. Returns the size of
+// the new vector or -1 on error.
+int getVectorFromValue32(Value val, llvm::SmallVector<int32_t, 4>& vec) {
+ int i = 0;
+
+ ElementsAttr elems;
+
+ if (!matchPattern(val, m_Constant(&elems))) return -1;
+
+ for (auto idx : elems.getValues<IntegerAttr>()) {
+ vec.push_back(idx.getInt());
+ i++;
+ }
+
+ return i;
+}
+
+// Calculates the TOSA padding values based on TF operators padded with
+// SAME/VALID.
+//
+// This could pass tensorflow::FilterTensorFormat and use
+// GetFilterTensorSpatialDimIndex, but the current TF core libs do not
+// support FORMAT_OHWI parsing by that function in core/util/tensor_format.h.
+bool getPaddingValuesFromPadType(
+ tensorflow::Padding tf_pad, tensorflow::TensorFormat data_format_tf,
+ uint32_t first_filter_spatial_dim, RankedTensorType input_type,
+ RankedTensorType filter_type, ArrayAttr strides, ArrayAttr dilations,
+ PatternRewriter& rewriter, ArrayAttr& explicit_padding) {
+ assert(tf_pad != tensorflow::Padding::EXPLICIT);
+
+ // Storing the numeric padding values is useful for TOSA codegen, as opposed
+ // to holding the padding regime mnemonic, i.e. SAME, VALID, FULL, ...
+ SmallVector<int64_t, 4> computed_paddings;
+
+ int64_t pad_before, pad_after;
+ for (int i = 0; i < 2; i++) { // Two spatial dimensions X&Y
+ int64_t ifm_dim = GetTensorSpatialDimIndex(
+ 4, data_format_tf, i); // 4D tensor, NHWC/NCHW format
+ int64_t filter_dim = first_filter_spatial_dim + i;
+
+ int64_t dim_dilation = dilations[i].template cast<IntegerAttr>().getInt();
+ int64_t dim_stride = strides[i].template cast<IntegerAttr>().getInt();
+
+ tensorflow::int64 op_size, pad_before_tf,
+ pad_after_tf;  // The TF API wants tensorflow::int64, not int64_t.
+ tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2(
+ input_type.getDimSize(ifm_dim), filter_type.getDimSize(filter_dim),
+ dim_dilation, dim_stride, tf_pad, &op_size, &pad_before_tf,
+ &pad_after_tf);
+ if (!status.ok()) return false;
+
+ pad_before = pad_before_tf;
+ pad_after = pad_after_tf;
+ computed_paddings.push_back(pad_before);
+ computed_paddings.push_back(pad_after);
+ }
+
+ explicit_padding = rewriter.getI64ArrayAttr(computed_paddings);
+ return true;
+}
+
+// Calculates the TOSA padding values for explicit-padded TF operators.
+//
+// This function only handles the TF padding array explicit_padding, which is
+// only present in certain TF ops. All others encode padding using the string
+// SAME/VALID, which is interpreted by getPaddingValuesFromPadType above.
+//
+// The explicit padding array in TF holds 2 pad values for every
+// dimension, even those that are not the 2 spatial ones. Just extract the
+// 2x pad values for the XY dims.
+ArrayAttr getPaddingValuesFromExplicitPadAttr(
+ ArrayAttr explicit_pad, tensorflow::TensorFormat data_format_tf,
+ PatternRewriter& rewriter) {
+ SmallVector<int64_t, 4> computed_paddings;
+
+ int64_t pad_before, pad_after;
+ for (int i = 0; i < 2; i++) { // Two spatial dimensions X&Y
+ int64_t dim = GetTensorSpatialDimIndex(4, data_format_tf,
+ i); // 4D tensor, NHWC/NCHW format
+
+ pad_before = explicit_pad[dim * 2].template cast<IntegerAttr>().getInt();
+ pad_after = explicit_pad[dim * 2 + 1].template cast<IntegerAttr>().getInt();
+ computed_paddings.push_back(pad_before);
+ computed_paddings.push_back(pad_after);
+ }
+
+ return rewriter.getI64ArrayAttr(computed_paddings);
+}
+
+// Calculates the TOSA padding values for transposeConv2d
+bool getTransposeConv2dPaddingValues(
+ tensorflow::Padding tf_pad, tensorflow::TensorFormat data_format_tf,
+ uint32_t first_filter_spatial_dim, RankedTensorType input_type,
+ RankedTensorType filter_type, RankedTensorType output_type,
+ ArrayAttr strides, ArrayAttr dilations, PatternRewriter& rewriter,
+ ArrayAttr& explicit_padding) {
+ assert(tf_pad != tensorflow::Padding::EXPLICIT);
+
+ // Storing the numeric padding values is useful for TOSA codegen, as opposed
+ // to holding the padding regime mnemonic, i.e. SAME, VALID, FULL, ...
+
+ SmallVector<int64_t, 2> computed_paddings;
+
+ int64_t pad_before, pad_after;
+ for (int i = 0; i < 2; i++) { // Two spatial dimensions X&Y
+ int64_t ifm_dim = GetTensorSpatialDimIndex(
+ 4, data_format_tf, i); // 4D tensor, NHWC/NCHW format
+ int64_t ofm_dim = GetTensorSpatialDimIndex(
+ 4, data_format_tf, i); // 4D tensor, NHWC/NCHW format
+ int64_t filter_dim = first_filter_spatial_dim + i;
+
+ int64_t ifm_size = input_type.getDimSize(ifm_dim);
+ int64_t filter_size = filter_type.getDimSize(filter_dim);
+ int64_t ofm_size = output_type.getDimSize(ofm_dim);
+ int64_t dim_dilation = dilations[i].template cast<IntegerAttr>().getInt();
+ int64_t dim_stride = strides[i].template cast<IntegerAttr>().getInt();
+
+ int effective_filter_size = (filter_size - 1) * dim_dilation + 1;
+ int total_padding =
+ ((ifm_size - 1) * dim_stride + effective_filter_size - ofm_size);
+ total_padding = total_padding > 0 ? total_padding : 0;
+
+ pad_before = total_padding / 2;
+ pad_after = total_padding - pad_before;
+
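+ // Only the leading pad per spatial dimension is emitted; pad_after is
+ // computed for the even split but not stored.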
+ computed_paddings.push_back(pad_before);
+ }
+
+ explicit_padding = rewriter.getI64ArrayAttr(computed_paddings);
+ return true;
+}
+
+// Templated function to create a constant op in a given dialect and with a
+// given type. Specializations below.
+
+// T0: target dialect constant op
+// T1: native c++ integer type
+template <typename T0, typename T1>
+Value get1DConstTensor(PatternRewriter& rewriter, Operation* op,
+ SmallVector<T1, 8> arr) {
+ auto const_type =
+ RankedTensorType::get({static_cast<int32_t>(arr.size())},
+ rewriter.getIntegerType(sizeof(T1) * 8));
+ auto const_attr =
+ DenseElementsAttr::get(const_type, llvm::makeArrayRef<T1>(arr));
+
+ auto const_op = rewriter.create<T0>(op->getLoc(), const_type, const_attr);
+ return const_op.getResult();
+}
+
+// Specialization for float, since the generic version assumes an integer
+// element type.
+template <>
+Value get1DConstTensor<tosa::ConstOp, float>(PatternRewriter& rewriter,
+ Operation* op,
+ SmallVector<float, 8> arr) {
+ auto const_type = RankedTensorType::get({static_cast<int32_t>(arr.size())},
+ rewriter.getF32Type());
+ auto const_attr =
+ DenseElementsAttr::get(const_type, llvm::makeArrayRef<float>(arr));
+
+ auto const_op =
+ rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
+ return const_op.getResult();
+}
+
+template Value get1DConstTensor<tosa::ConstOp, int32_t>(
+ PatternRewriter&, Operation*, SmallVector<int32_t, 8> arr);
+template Value get1DConstTensor<tosa::ConstOp, int64_t>(
+ PatternRewriter&, Operation*, SmallVector<int64_t, 8> arr);
+template Value get1DConstTensor<TFL::ConstOp, int32_t>(
+ PatternRewriter&, Operation*, SmallVector<int32_t, 8> arr);
+template Value get1DConstTensor<TFL::ConstOp, int64_t>(
+ PatternRewriter&, Operation*, SmallVector<int64_t, 8> arr);
+
+// Same as get1DConstTensor, but int48 is not a native C++ type, so it needs
+// a separate entry point taking int64_t values.
+Value get1DConstTensorInt48(PatternRewriter& rewriter, Operation* op,
+ SmallVector<int64_t, 8> arr) {
+ auto const_type = RankedTensorType::get({static_cast<int32_t>(arr.size())},
+ rewriter.getIntegerType(48));
+ auto const_attr =
+ DenseElementsAttr::get(const_type, llvm::makeArrayRef<int64_t>(arr));
+
+ auto const_op =
+ rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
+ return const_op.getResult();
+}
+
+// Strip the quantization information from a bias tensor and return an
+// unquantized bias.
+Value getUnquantizedBias(PatternRewriter& rewriter, Operation* op,
+ Value input) {
+ auto input_type = input.getType().dyn_cast<mlir::RankedTensorType>();
+ assert(input_type);
+ auto input_element_type = input_type.getElementType();
+ auto input_element_qtype =
+ input_element_type.dyn_cast<mlir::quant::QuantizedType>();
+
+ if (input_element_qtype) {
+ auto output_type = RankedTensorType::get(
+ input_type.getShape(),
+ rewriter.getIntegerType(
+ input_element_qtype.getStorageTypeIntegralWidth()));
+
+ // Only a bias produced by a TFL QConst op can be rewritten here.
+ // dyn_cast_or_null also guards against block arguments, whose defining op
+ // is null.
+ auto input_defining_op =
+ dyn_cast_or_null<TFL::QConstOp>(input.getDefiningOp());
+ if (!input_defining_op) return input;
+
+ auto dense_attr = input_defining_op.value().dyn_cast<DenseElementsAttr>();
+
+ if (dense_attr) {
+ auto const_op =
+ rewriter.create<tosa::ConstOp>(op->getLoc(), output_type, dense_attr);
+ return const_op.getResult();
+ } else {
+ return input;
+ }
+
+ } else {
+ return input;
+ }
+}
+
+} // namespace tosa
+} // namespace mlir
diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h
new file mode 100644
index 0000000..450f227
--- /dev/null
+++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h
@@ -0,0 +1,132 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_LEGALIZE_UTILS_H
+#define TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_LEGALIZE_UTILS_H
+
+#include <climits>
+#include <cstddef>
+#include <cstdint>
+#include <iterator>
+#include <numeric>
+
+#include "mlir/Dialect/Quant/FakeQuantSupport.h"
+#include "mlir/Dialect/Quant/UniformSupport.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "tensorflow/core/framework/kernel_shape_util.h"
+#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
+#include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+namespace mlir {
+namespace tosa {
+
+// Create a TOSA rescale op from TFLite scaling, zero points and rounding mode
+Value buildRescale(PatternRewriter& rewriter, Operation* op,
+ RankedTensorType output_type, Value input_val, double scale,
+ int64_t input_zp, int64_t output_zp,
+ bool double_round = false);
+
+// Creates TOSA rescale op with int32 output
+Value buildRescaleToInt32(PatternRewriter& rewriter, Operation* op,
+ Value input_val, double input_scale,
+ int64_t input_zp);
+
+// Creates TOSA rescale op with int32 input
+Value buildRescaleFromInt32(PatternRewriter& rewriter, Operation* op,
+ RankedTensorType output_type, Value input_val,
+ double output_scale, int64_t output_zp);
+
+// Creates a TOSA rescale op based on conv2d parameters.
+Value buildRescaleOpConvOutput(PatternRewriter& rewriter, Operation* op,
+ Value conv_val, RankedTensorType input_type,
+ RankedTensorType weight_type,
+ RankedTensorType output_type);
+
+// Create a 513 entry TOSA constant tensor suitable for the Table operator based
+// on the values from an int32_t func(int32_t) lambda function.
+Value getTosa1DConstTensorTable(PatternRewriter& rewriter, Operation* op,
+ std::function<int32_t(int32_t)> func);
+
+// Create a 32-bit float constant operator from a float
+Value getTosaConstTensorSingleF32(PatternRewriter& rewriter, Operation* op,
+ float val);
+
+// Create a 32-bit integer constant operator from an int
+Value getTosaConstTensorSingleI32(PatternRewriter& rewriter, Operation* op,
+ int32_t val);
+
+// Create a vector from a 32-bit value tensor. Returns vector size on success
+// or -1 on error.
+int getVectorFromValue32(Value val, SmallVector<int32_t, 4>& vec);
+
+// Calculates the TOSA padding values based on TF operators padded with
+// SAME/VALID.
+bool getPaddingValuesFromPadType(
+ tensorflow::Padding tf_pad, tensorflow::TensorFormat data_format_tf,
+ uint32_t first_filter_spatial_dim, RankedTensorType input_type,
+ RankedTensorType filter_type, ArrayAttr strides, ArrayAttr dilations,
+ PatternRewriter& rewriter, ArrayAttr& explicit_pad);
+
+// Calculates the TOSA padding values for explicit-padded TF operators.
+ArrayAttr getPaddingValuesFromExplicitPadAttr(
+ ArrayAttr explicit_pad, tensorflow::TensorFormat data_format_tf,
+ PatternRewriter& rewriter);
+
+// Calculates the TOSA padding values for transposeConv2d
+bool getTransposeConv2dPaddingValues(
+ tensorflow::Padding tf_pad, tensorflow::TensorFormat data_format_tf,
+ uint32_t first_filter_spatial_dim, RankedTensorType input_type,
+ RankedTensorType filter_type, RankedTensorType output_type,
+ ArrayAttr strides, ArrayAttr dilations, PatternRewriter& rewriter,
+ ArrayAttr& explicit_pad);
+
+// Templated function to create a constant op in a given dialect and with a
+// given type. Specializations below.
+
+// T0: target dialect constant op
+// T1: native c++ integer type
+template <typename T0, typename T1>
+Value get1DConstTensor(PatternRewriter& rewriter, Operation* op,
+ SmallVector<T1, 8> arr);
+
+// Same as get1DConstTensor, but int48 is not a native C++ type, so it needs
+// a separate entry point taking int64_t values.
+Value get1DConstTensorInt48(PatternRewriter& rewriter, Operation* op,
+ SmallVector<int64_t, 8> arr);
+
+// Strip the quantization information from a bias tensor and return an
+// unquantized bias.
+Value getUnquantizedBias(PatternRewriter& rewriter, Operation* op, Value input);
+
+} // namespace tosa
+} // namespace mlir
+
+#endif // TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_LEGALIZE_UTILS_H
diff --git a/tensorflow/compiler/mlir/tosa/transforms/passes.h b/tensorflow/compiler/mlir/tosa/transforms/passes.h
new file mode 100644
index 0000000..f944908
--- /dev/null
+++ b/tensorflow/compiler/mlir/tosa/transforms/passes.h
@@ -0,0 +1,42 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_PASSES_H
+#define TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_PASSES_H
+
+#include <memory>
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+namespace tosa {
+
+struct TOSALegalizationPipelineOptions
+ : public PassPipelineOptions<TOSALegalizationPipelineOptions> {};
+
+std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass();
+std::unique_ptr<OperationPass<FuncOp>> createFuseBiasTFPass();
+std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFLPass();
+std::unique_ptr<OperationPass<FuncOp>> createConvertTFLUint8Pass();
+
+#define GEN_PASS_REGISTRATION
+#include "tensorflow/compiler/mlir/tosa/transforms/passes.h.inc"
+
+} // namespace tosa
+
+} // namespace mlir
+
+#endif // TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_PASSES_H
diff --git a/tensorflow/compiler/mlir/tosa/transforms/passes.td b/tensorflow/compiler/mlir/tosa/transforms/passes.td
new file mode 100644
index 0000000..ee87829
--- /dev/null
+++ b/tensorflow/compiler/mlir/tosa/transforms/passes.td
@@ -0,0 +1,36 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+include "mlir/Pass/PassBase.td"
+
+def TosaLegalizeTFPass : Pass<"tosa-legalize-tf", "FuncOp"> {
+ let summary = "Legalize from TensorFlow to TOSA";
+ let constructor = "createLegalizeTFPass()";
+}
+
+def TosaLegalizeTFLPass : Pass<"tosa-legalize-tfl", "FuncOp"> {
+ let summary = "Legalize from TensorFlow Lite to TOSA";
+ let constructor = "createLegalizeTFLPass()";
+}
+
+def TosaFusebiasTFPass : Pass<"tosa-fuse-bias-tf", "FuncOp"> {
+ let summary = "Fuse tf.Op + tf.BiasAdd and legalized to TOSA";
+ let constructor = "createFuseBiasTFPass()";
+}
+
+def TosaConvertTFLUint8Pass : Pass<"tosa-convert-tfl-uint8", "FuncOp"> {
+ let summary = "Convert uint8 graph to int8 graph";
+ let constructor = "createConvertTFLUint8Pass()";
+}
diff --git a/tensorflow/compiler/mlir/tosa/transforms/register_passes.h b/tensorflow/compiler/mlir/tosa/transforms/register_passes.h
new file mode 100644
index 0000000..7d13205
--- /dev/null
+++ b/tensorflow/compiler/mlir/tosa/transforms/register_passes.h
@@ -0,0 +1,34 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_REGISTER_PASSES_H
+#define TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_REGISTER_PASSES_H
+
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Pass/Pass.h"
+#include "tensorflow/compiler/mlir/tosa/transforms/passes.h"
+
+namespace mlir {
+namespace tosa {
+
+inline void registerAllTosaPasses() {
+ registerLegalizeTosaPasses();
+ registerTosaOptPasses();
+}
+
+} // namespace tosa
+} // namespace mlir
+
+#endif // TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_REGISTER_PASSES_H
diff --git a/tensorflow/compiler/mlir/tosa/transforms/tf_legalize_patterns.td b/tensorflow/compiler/mlir/tosa/transforms/tf_legalize_patterns.td
new file mode 100644
index 0000000..ef25fe2
--- /dev/null
+++ b/tensorflow/compiler/mlir/tosa/transforms/tf_legalize_patterns.td
@@ -0,0 +1,48 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// TensorFlow legalization patterns
+
+include "mlir/IR/OpBase.td"
+include "mlir/Dialect/StandardOps/IR/Ops.td"
+include "mlir/Dialect/Tosa/IR/TosaOps.td"
+include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
+
+// Nullary ops patterns.
+
+def : Pat<(TF_ConstOp ElementsAttr : $value), (Tosa_ConstOp $value)>;
+
+// Unary ops patterns.
+
+def : Pat<(TF_IdentityOp $value), (replaceWithValue $value)>;
+def : Pat<(TF_AbsOp $arg), (Tosa_AbsOp $arg)>;
+def : Pat<(TF_CeilOp $arg), (Tosa_CeilOp $arg)>;
+def : Pat<(TF_FloorOp $arg), (Tosa_FloorOp $arg)>;
+def : Pat<(TF_ExpOp $arg), (Tosa_ExpOp $arg)>;
+def : Pat<(TF_LogOp $arg), (Tosa_LogOp $arg)>;
+def : Pat<(TF_ReciprocalOp $arg), (Tosa_ReciprocalOp $arg)>;
+def : Pat<(TF_RsqrtOp $arg), (Tosa_RsqrtOp $arg)>;
+def : Pat<(TF_LogicalNotOp $arg), (Tosa_LogicalNotOp $arg)>;
+def : Pat<(TF_InvertOp $arg1), (Tosa_BitwiseNotOp $arg1)>;
+def : Pat<(TF_CastOp $in, BoolAttr : $truncate), (Tosa_CastOp $in)>;
+
+// Binary ops patterns.
+
+def : Pat<(TF_BitwiseOrOp $arg1, $arg2), (Tosa_BitwiseOrOp $arg1, $arg2)>;
+def : Pat<(TF_BitwiseXorOp $arg1, $arg2), (Tosa_BitwiseXorOp $arg1, $arg2)>;
+def : Pat<(TF_BitwiseAndOp $arg1, $arg2), (Tosa_BitwiseAndOp $arg1, $arg2)>;
+def : Pat<(TF_LogicalAndOp $l, $r), (Tosa_LogicalAndOp $l, $r)>;
+def : Pat<(TF_LogicalOrOp $l, $r), (Tosa_LogicalOrOp $l, $r)>;
+def : Pat<(TF_PowOp $l, $r), (Tosa_PowOp $l, $r)>;
diff --git a/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td b/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td
new file mode 100644
index 0000000..4314591
--- /dev/null
+++ b/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td
@@ -0,0 +1,54 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// TFLite legalization patterns
+
+include "mlir/IR/OpBase.td"
+include "mlir/Dialect/StandardOps/IR/Ops.td"
+include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
+include "mlir/Dialect/Tosa/IR/TosaOps.td"
+
+// Nullary ops patterns.
+def : Pat<(ConstantOp $value), (Tosa_ConstOp $value)>;
+
+// Unary ops patterns.
+def : Pat<(TFL_AbsOp $arg), (Tosa_AbsOp $arg)>;
+def : Pat<(TFL_CeilOp $arg), (Tosa_CeilOp $arg)>;
+def : Pat<(TFL_FloorOp $arg), (Tosa_FloorOp $arg)>;
+def : Pat<(TFL_ExpOp $arg), (Tosa_ExpOp $arg)>;
+def : Pat<(TFL_LogOp $arg), (Tosa_LogOp $arg)>;
+def : Pat<(TFL_RsqrtOp $arg), (Tosa_RsqrtOp $arg)>;
+def : Pat<(TFL_LogicalNotOp $arg), (Tosa_LogicalNotOp $arg)>;
+def : Pat<(TFL_SqrtOp $arg), (Tosa_ReciprocalOp (Tosa_RsqrtOp $arg))>;
+def : Pat<(TFL_CastOp $in), (Tosa_CastOp $in)>;
+
+//===----------------------------------------------------------------------===//
+// Binary ops patterns.
+//===----------------------------------------------------------------------===//
+
+def : Pat<(TFL_LogicalAndOp $l, $r), (Tosa_LogicalAndOp $l, $r)>;
+def : Pat<(TFL_LogicalOrOp $l, $r), (Tosa_LogicalOrOp $l, $r)>;
+def : Pat<(TFL_PowOp $l, $r), (Tosa_PowOp $l, $r)>;
+
+//===----------------------------------------------------------------------===//
+// Ternary ops patterns.
+//===----------------------------------------------------------------------===//
+
+def : Pat<(TFL_GatherOp $params,
+ $indices,
+ $axis),
+ (Tosa_GatherOp $params,
+ $indices,
+ $axis)>;