Update legalizations for TOSA v0.22 Part 2

Update Concat legalization to support variadic inputs
Fix negative axis handling in the concatenate legalization
- add rank(input) to the axis if it is negative (see the sketch after this list)
- clean up the axis range check
- run clang-format
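
A minimal sketch of that normalization (the helper normalizeAxis and the checks in main are illustrative only, not part of this change):

  #include <cassert>
  #include <cstdint>

  // Map a TF-style axis in [-rank, rank) onto the TOSA range [0, rank).
  int32_t normalizeAxis(int32_t axis, int32_t rank) {
    if (axis < 0) axis += rank;  // e.g. axis = -1, rank = 3  ->  axis = 2
    return axis;                 // callers still reject out-of-range values
  }

  int main() {
    assert(normalizeAxis(-1, 3) == 2);
    assert(normalizeAxis(0, 3) == 0);
    return 0;
  }
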
Fix numerical issues in the 8-bit sigmoid/tanh legalization
- softmax numerical behavior is improved, but not yet bit exact, since it is unclear which TFLite reference implementation to match against
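
Part of the improved softmax numerics (and of the bit-exact lowering in legalize_common.cc below) is rebuilding a 32-bit exp() value from two table lookups whose 16-bit results form the upper and lower halves. A hedged sketch of that reconstruction (function name and sample values are illustrative):

  #include <cassert>
  #include <cstdint>

  // TABLE on a 16-bit index yields a 23-bit result; the low 7 bits are zero
  // here because the index lands exactly on a table entry.
  int32_t reassembleExp(int32_t upper_table_out, int32_t lower_table_out) {
    int32_t upper = upper_table_out << 9;            // (x >> 7) << 16
    int32_t lower = (lower_table_out >> 7) + 32768;  // undo the -32768 offset
    return upper + lower;
  }

  int main() {
    // Upper half 0x0123, lower half 0xBEEF, both pre-shifted by 7 bits.
    assert(reassembleExp(0x0123 << 7, (0xBEEF - 32768) << 7) == 0x0123BEEF);
    return 0;
  }
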
Support 16-bit TOSA legalization for Add and Conv2D
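
The quantized Add lowering applies an extra left shift to the inputs for precision; 16-bit inputs use a smaller shift than 8-bit ones, presumably so the widened values still fit in the 32-bit accumulator. A trivial sketch of the selection (helper name is illustrative):

  #include <cassert>
  #include <cstdint>

  // 8-bit inputs get an extra left shift of 20 bits, 16-bit inputs get 15.
  int32_t inputShiftForAdd(uint32_t storage_bits) {
    return storage_bits == 16 ? 15 : 20;
  }

  int main() {
    assert(inputShiftForAdd(8) == 20);
    assert(inputShiftForAdd(16) == 15);
    return 0;
  }
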
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Support more 16-bit legalizations
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Fix quantized resize_bilinear legalization
- tosa.resize does not shift the input zero point, so the output zero point should not be shifted either
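
A quick arithmetic check, not part of the patch, that the multiplier/shift pair in the updated test matches the single 1.0 / (1 << 20) output rescale used by the new lowering (shift = 10 is applied per spatial axis, hence 2^20 overall):

  #include <cassert>
  #include <cmath>
  #include <cstdint>

  int main() {
    const double target_scale = 1.0 / (1 << 20);  // undo shift=10 on each axis
    const int32_t multiplier = 1073741824;        // 2^30, from the CHECK line
    const int32_t shift = 50;                     // from the CHECK line
    const double applied = multiplier / std::pow(2.0, shift);  // 2^30 / 2^50
    assert(std::abs(applied - target_scale) < 1e-18);
    return 0;
  }
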
Implement bit-exact 8-bit tfl.softmax lowering
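
The bit-exact path computes 1 / sum(exp()) with a gemmlowp-style Newton-Raphson reciprocal. A float-domain sketch of what the three fixed-point iterations compute (the real code operates in s2.29, where 1515870810 ~= 48/17 and -1010580540 ~= -32/17):

  #include <cassert>
  #include <cmath>

  double reciprocal(double d) {  // d is the "half denominator", in [0.5, 1.0)
    double x = 48.0 / 17.0 - 32.0 / 17.0 * d;           // initial estimate of 1/d
    for (int i = 0; i < 3; ++i) x = x * (2.0 - d * x);  // x_{n+1} = x_n * (2 - d * x_n)
    return x;
  }

  int main() {
    assert(std::abs(reciprocal(0.75) - 1.0 / 0.75) < 1e-9);
    return 0;
  }
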
Updated TF and TFL legalization tests
Rewrite PackOp to use variadic Concat
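
A hedged sketch of the shape bookkeeping the new Pack lowering performs in the common case (axis < rank(input)); the helper packedShape is illustrative only:

  #include <cassert>
  #include <cstdint>
  #include <vector>

  // Concatenate N same-shaped inputs along `axis`, then split that dimension
  // back out so the result has the packed shape.
  std::vector<int64_t> packedShape(std::vector<int64_t> shape, int64_t n,
                                   size_t axis) {
    shape[axis] *= n;                             // concat output, e.g. [52, 21, 3]
    std::vector<int64_t> reshaped(shape);
    reshaped[axis] /= n;
    reshaped.insert(reshaped.begin() + axis, n);  // reshape output, e.g. [4, 13, 21, 3]
    return reshaped;
  }

  int main() {
    // Matches the test_stack case: four [13, 21, 3] tensors packed along axis 0.
    assert((packedShape({13, 21, 3}, 4, 0) == std::vector<int64_t>{4, 13, 21, 3}));
    return 0;
  }
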

Change-Id: Ia3827d3f6d6d43fe2b8d6f2c5e9da7b5d3d3edc8
Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>
diff --git a/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir
index da58400..9351e73 100644
--- a/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir
+++ b/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir
@@ -119,6 +119,16 @@
 
 // -----
 
+// CHECK-LABEL: test_concat
+// CHECK: %[[VAR0:.*]] = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64}
+func @test_concat(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<26x21x3xf32> {
+  %2 = "tf.Const"()  {value = dense<0> : tensor<i32>}  : () -> tensor<i32>
+  %3 = "tf.ConcatV2"(%arg0, %arg1, %2)   : (tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<i32>) -> tensor<26x21x3xf32>
+  return %3 : tensor<26x21x3xf32>
+}
+
+// -----
+
 // CHECK-LABEL: test_bitwise_and
 // CHECK: %[[VAR0:.*]] = "tosa.bitwise_and"(%arg0, %arg1)
 func @test_bitwise_and(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x21x1xi32>) -> tensor<13x21x3xi32> {
@@ -507,6 +517,26 @@
 
 // -----
 
+// CHECK-LABEL: test_concatv2
+// CHECK: %[[VAR0:.*]] = "tosa.concat"(%arg0, %arg1, %arg2, %arg3) {axis = 0 : i64}
+func @test_concatv2(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<13x21x3xf32>, %arg3: tensor<13x21x3xf32>) -> tensor<52x21x3xf32> {
+  %2 = "tf.Const"()  {value = dense<0> : tensor<i32>}  : () -> tensor<i32>
+  %3 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %arg3, %2)   : (tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<i32>) -> tensor<52x21x3xf32>
+  return %3 : tensor<52x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_stack
+// CHECK-DAG: %[[VAR0:.*]] = "tosa.concat"(%arg0, %arg1, %arg2, %arg3) {axis = 0 : i64}
+// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = [4, 13, 21, 3]}
+func @test_stack(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<13x21x3xf32>, %arg3: tensor<13x21x3xf32>) -> tensor<4x13x21x3xf32> {
+  %2 = "tf.Pack"(%arg0, %arg1, %arg2, %arg3)  {axis = 0 : i64}  : (tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<4x13x21x3xf32>
+  return %2 : tensor<4x13x21x3xf32>
+}
+
+// -----
+
 // CHECK-LABEL: test_unstack
 // CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = [1, 32, 32, 8], start = [0, 0, 0, 0]}
 // CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = [32, 32, 8]}
diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir
index 694c929..fb24b43 100644
--- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir
+++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir
@@ -484,6 +484,25 @@
 
 // -----
 
+// CHECK-LABEL: test_concatv2
+// CHECK: %[[VAR0:.*]] = "tosa.concat"(%arg0, %arg1, %arg2, %arg3) {axis = 0 : i64}
+func @test_concatv2(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<13x21x3xf32>, %arg3: tensor<13x21x3xf32>) -> tensor<52x21x3xf32> {
+  %0 = "tfl.concatenation"(%arg0, %arg1, %arg2, %arg3)  {axis = 0 : i32, fused_activation_function = "NONE"}  : (tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<52x21x3xf32>
+  return %0 : tensor<52x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_stack
+// CHECK-DAG: %[[VAR0:.*]] = "tosa.concat"(%arg0, %arg1, %arg2, %arg3) {axis = 0 : i64}
+// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = [4, 13, 21, 3]}
+func @test_stack(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<13x21x3xf32>, %arg3: tensor<13x21x3xf32>) -> tensor<4x13x21x3xf32> {
+  %0 = "tfl.pack"(%arg0, %arg1, %arg2, %arg3)  {axis = 0 : i32, values_count = 4 : i32}  : (tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<4x13x21x3xf32>
+  return %0 : tensor<4x13x21x3xf32>
+}
+
+// -----
+
 // CHECK-LABEL: test_unstack
 // CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = [1, 32, 32, 8], start = [0, 0, 0, 0]}
 // CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = [32, 32, 8]}
@@ -796,35 +815,68 @@
 // -----
 
 // CHECK-LABEL: test_softmax_qi8
-// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<{{.*}}> : tensor<513xi16>}
-// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() {value = dense<3> : tensor<i32>}
-// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() {value = dense<34> : tensor<i32>}
-// CHECK-DAG: %[[VAR3:.*]] = "tosa.const"() {value = dense<-2147483648> : tensor<i32>}
-// CHECK-DAG: %[[VAR4:.*]] = "tosa.const"() {value = dense<16> : tensor<i32>}
-// CHECK-DAG: %[[VAR5:.*]] = "tosa.const"() {value = dense<{{.*}}> : tensor<513xi16>}
-// CHECK-DAG: %[[VAR6:.*]] = "tosa.rescale"(%arg0)
-// CHECK-DAG: %[[VAR7:.*]] = "tosa.reduce_max"(%[[VAR6]]) {axis = 2 : i64}
-// CHECK-DAG: %[[VAR8:.*]] = "tosa.sub"(%[[VAR6]], %[[VAR7]])
-// CHECK-DAG: %[[VAR9:.*]] = "tosa.rescale"(%[[VAR8]])
-// CHECK-DAG: %[[VAR10:.*]] = "tosa.table"(%[[VAR9]], %[[VAR0]])
-// CHECK-DAG: %[[VAR11:.*]] = "tosa.reshape"(%[[VAR1]]) {new_shape = [1, 1, 1]}
-// CHECK-DAG: %[[VAR12:.*]] = "tosa.arithmetic_right_shift"(%[[VAR10]], %[[VAR11]]) {round = true}
-// CHECK-DAG: %[[VAR13:.*]] = "tosa.reduce_sum"(%[[VAR12]]) {axis = 2 : i64}
-// CHECK-DAG: %[[VAR14:.*]] = "tosa.clz"(%[[VAR13]])
-// CHECK-DAG: %[[VAR15:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [1, 1, 1]}
-// CHECK-DAG: %[[VAR16:.*]] = "tosa.sub"(%[[VAR15]], %[[VAR14]])
-// CHECK-DAG: %[[VAR17:.*]] = "tosa.logical_left_shift"(%[[VAR13]], %[[VAR14]])
-// CHECK-DAG: %[[VAR18:.*]] = "tosa.reshape"(%[[VAR3]]) {new_shape = [1, 1, 1]}
-// CHECK-DAG: %[[VAR19:.*]] = "tosa.sub"(%[[VAR17]], %[[VAR18]])
-// CHECK-DAG: %[[VAR20:.*]] = "tosa.reshape"(%[[VAR4]]) {new_shape = [1, 1, 1]}
-// CHECK-DAG: %[[VAR21:.*]] = "tosa.arithmetic_right_shift"(%[[VAR19]], %[[VAR20]]) {round = true}
-// CHECK-DAG: %[[VAR22:.*]] = "tosa.cast"(%[[VAR21]])
-// CHECK-DAG: %[[VAR23:.*]] = "tosa.table"(%[[VAR22]], %[[VAR5]])
-// CHECK-DAG: %[[VAR24:.*]] = "tosa.rescale"(%[[VAR23]])
-// CHECK-DAG: %[[VAR25:.*]] = "tosa.rescale"(%[[VAR10]])
-// CHECK-DAG: %[[VAR26:.*]] = "tosa.mul"(%[[VAR24]], %[[VAR25]]) {shift = 0 : i32}
-// CHECK-DAG: %[[VAR27:.*]] = "tosa.arithmetic_right_shift"(%[[VAR26]], %[[VAR16]]) {round = true}
-// CHECK: %[[VAR28:.*]] = "tosa.rescale"(%[[VAR27]])
+// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() {value = dense<{{.*}}> : tensor<513xi16>}
+// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() {value = dense<{{.*}}> : tensor<513xi16>}
+// CHECK-DAG: %[[VAR3:.*]] = "tosa.const"() {value = dense<9> : tensor<i32>}
+// CHECK-DAG: %[[VAR4:.*]] = "tosa.const"() {value = dense<7> : tensor<i32>}
+// CHECK-DAG: %[[VAR5:.*]] = "tosa.const"() {value = dense<32768> : tensor<i32>}
+// CHECK-DAG: %[[VAR6:.*]] = "tosa.const"() {value = dense<12> : tensor<i32>}
+// CHECK-DAG: %[[VAR7:.*]] = "tosa.const"() {value = dense<1> : tensor<i32>}
+// CHECK-DAG: %[[VAR8:.*]] = "tosa.const"() {value = dense<4> : tensor<i32>}
+// CHECK-DAG: %[[VAR9:.*]] = "tosa.const"() {value = dense<536870912> : tensor<i32>}
+// CHECK-DAG: %[[VAR10:.*]] = "tosa.const"() {value = dense<1515870810> : tensor<i32>}
+// CHECK-DAG: %[[VAR11:.*]] = "tosa.const"() {value = dense<-1010580540> : tensor<i32>}
+// CHECK-DAG: %[[VAR12:.*]] = "tosa.const"() {value = dense<35> : tensor<i32>}
+// CHECK-DAG: %[[VAR13:.*]] = "tosa.rescale"(%arg0) {double_round = false, input_zp = 0 : i32, multiplier = [1073741824 : i32], output_zp = 0 : i32, per_channel = false, scale32 = true, shift = [30 : i32]}
+// CHECK-DAG: %[[VAR14:.*]] = "tosa.reduce_max"(%[[VAR13]]) {axis = 2 : i64}
+// CHECK-DAG: %[[VAR15:.*]] = "tosa.sub"(%[[VAR13]], %[[VAR14]])
+// CHECK-DAG: %[[VAR16:.*]] = "tosa.rescale"(%[[VAR15]]) {double_round = false, input_zp = 0 : i32, multiplier = [1073741824 : i32], output_zp = 0 : i32, per_channel = false, scale32 = true, shift = [23 : i32]}
+// CHECK-DAG: %[[VAR17:.*]] = "tosa.table"(%[[VAR16]], %[[VAR1]])
+// CHECK-DAG: %[[VAR18:.*]] = "tosa.table"(%[[VAR16]], %[[VAR2]])
+// CHECK-DAG: %[[VAR19:.*]] = "tosa.reshape"(%[[VAR3]]) {new_shape = [1, 1, 1]}
+// CHECK-DAG: %[[VAR20:.*]] = "tosa.logical_left_shift"(%[[VAR17]], %[[VAR19]])
+// CHECK-DAG: %[[VAR21:.*]] = "tosa.reshape"(%[[VAR4]]) {new_shape = [1, 1, 1]}
+// CHECK-DAG: %[[VAR22:.*]] = "tosa.arithmetic_right_shift"(%[[VAR18]], %[[VAR21]]) {round = true}
+// CHECK-DAG: %[[VAR23:.*]] = "tosa.reshape"(%[[VAR5]]) {new_shape = [1, 1, 1]}
+// CHECK-DAG: %[[VAR24:.*]] = "tosa.add"(%[[VAR22]], %[[VAR23]])
+// CHECK-DAG: %[[VAR25:.*]] = "tosa.add"(%[[VAR20]], %[[VAR24]])
+// CHECK-DAG: %[[VAR26:.*]] = "tosa.reshape"(%[[VAR6]]) {new_shape = [1, 1, 1]}
+// CHECK-DAG: %[[VAR27:.*]] = "tosa.arithmetic_right_shift"(%[[VAR25]], %[[VAR26]]) {round = true}
+// CHECK-DAG: %[[VAR28:.*]] = "tosa.reduce_sum"(%[[VAR27]]) {axis = 2 : i64}
+// CHECK-DAG: %[[VAR29:.*]] = "tosa.clz"(%[[VAR28]])
+// CHECK-DAG: %[[VAR30:.*]] = "tosa.reshape"(%[[VAR7]]) {new_shape = [1, 1, 1]}
+// CHECK-DAG: %[[VAR31:.*]] = "tosa.sub"(%[[VAR29]], %[[VAR30]])
+// CHECK-DAG: %[[VAR32:.*]] = "tosa.logical_left_shift"(%[[VAR28]], %[[VAR31]])
+// CHECK-DAG: %[[VAR33:.*]] = "tosa.reshape"(%[[VAR11]]) {new_shape = [1, 1, 1]}
+// CHECK-DAG: %[[VAR34:.*]] = "tosa.mul"(%[[VAR32]], %[[VAR33]]) {shift = 31 : i32}
+// CHECK-DAG: %[[VAR35:.*]] = "tosa.reshape"(%[[VAR10]]) {new_shape = [1, 1, 1]}
+// CHECK-DAG: %[[VAR36:.*]] = "tosa.add"(%[[VAR34]], %[[VAR35]])
+// CHECK-DAG: %[[VAR37:.*]] = "tosa.mul"(%[[VAR36]], %[[VAR32]]) {shift = 31 : i32}
+// CHECK-DAG: %[[VAR38:.*]] = "tosa.reshape"(%[[VAR9]]) {new_shape = [1, 1, 1]}
+// CHECK-DAG: %[[VAR39:.*]] = "tosa.sub"(%[[VAR38]], %[[VAR37]])
+// CHECK-DAG: %[[VAR40:.*]] = "tosa.mul"(%[[VAR36]], %[[VAR39]]) {shift = 31 : i32}
+// CHECK-DAG: %[[VAR41:.*]] = "tosa.reshape"(%[[VAR8]]) {new_shape = [1, 1, 1]}
+// CHECK-DAG: %[[VAR42:.*]] = "tosa.mul"(%[[VAR40]], %[[VAR41]]) {shift = 0 : i32}
+// CHECK-DAG: %[[VAR43:.*]] = "tosa.add"(%[[VAR36]], %[[VAR42]])
+// CHECK-DAG: %[[VAR44:.*]] = "tosa.mul"(%[[VAR43]], %[[VAR32]]) {shift = 31 : i32}
+// CHECK-DAG: %[[VAR45:.*]] = "tosa.reshape"(%[[VAR9]]) {new_shape = [1, 1, 1]}
+// CHECK-DAG: %[[VAR46:.*]] = "tosa.sub"(%[[VAR45]], %[[VAR44]])
+// CHECK-DAG: %[[VAR47:.*]] = "tosa.mul"(%[[VAR43]], %[[VAR46]]) {shift = 31 : i32}
+// CHECK-DAG: %[[VAR48:.*]] = "tosa.reshape"(%[[VAR8]]) {new_shape = [1, 1, 1]}
+// CHECK-DAG: %[[VAR49:.*]] = "tosa.mul"(%[[VAR47]], %[[VAR48]]) {shift = 0 : i32}
+// CHECK-DAG: %[[VAR50:.*]] = "tosa.add"(%[[VAR43]], %[[VAR49]])
+// CHECK-DAG: %[[VAR51:.*]] = "tosa.mul"(%[[VAR50]], %[[VAR32]]) {shift = 31 : i32}
+// CHECK-DAG: %[[VAR52:.*]] = "tosa.reshape"(%[[VAR9]]) {new_shape = [1, 1, 1]}
+// CHECK-DAG: %[[VAR53:.*]] = "tosa.sub"(%[[VAR52]], %[[VAR51]])
+// CHECK-DAG: %[[VAR54:.*]] = "tosa.mul"(%[[VAR50]], %[[VAR53]]) {shift = 31 : i32}
+// CHECK-DAG: %[[VAR55:.*]] = "tosa.reshape"(%[[VAR8]]) {new_shape = [1, 1, 1]}
+// CHECK-DAG: %[[VAR56:.*]] = "tosa.mul"(%[[VAR54]], %[[VAR55]]) {shift = 0 : i32}
+// CHECK-DAG: %[[VAR57:.*]] = "tosa.add"(%[[VAR50]], %[[VAR56]])
+// CHECK-DAG: %[[VAR58:.*]] = "tosa.mul"(%[[VAR25]], %[[VAR57]]) {shift = 30 : i32}
+// CHECK-DAG: %[[VAR59:.*]] = "tosa.reshape"(%[[VAR12]]) {new_shape = [1, 1, 1]}
+// CHECK-DAG: %[[VAR60:.*]] = "tosa.sub"(%[[VAR59]], %[[VAR29]])
+// CHECK-DAG: %[[VAR61:.*]] = "tosa.arithmetic_right_shift"(%[[VAR58]], %[[VAR60]]) {round = true}
+// CHECK: %[[VAR62:.*]] = "tosa.rescale"(%[[VAR61]]) {double_round = false, input_zp = 0 : i32, multiplier = [1073741824 : i32], output_zp = -128 : i32, per_channel = false, scale32 = true, shift = [30 : i32]}
 func @test_softmax_qi8(%arg0: tensor<13x21x3x!quant.uniform<i8:f32, 0.0156164625659585>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 3.906250e-03:-128>> {
   %0 = "tfl.softmax"(%arg0)  {beta = 1.000000e+00 : f32}  : (tensor<13x21x3x!quant.uniform<i8:f32, 0.0156164625659585>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 3.906250e-03:-128>>
   return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 3.906250e-03:-128>>
@@ -894,17 +946,8 @@
 // -----
 
 // CHECK-LABEL: test_resize_bilinear_qi8
-// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<0> : tensor<i32>}
-// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() {value = dense<20> : tensor<i32>}
-// CHECK-DAG: %[[VAR2:.*]] = "tosa.resize"(%arg0) {mode = "BILINEAR", offset = [-448, -448], offset_fp = [0.000000e+00 : f32, 0.000000e+00 : f32], output_size = [640, 640], shift = 10 : i32, stride = [128, 128], stride_fp = [0.000000e+00 : f32, 0.000000e+00 : f32]}
-// CHECK-DAG: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = [1, 1, 1, 1]}
-// CHECK-DAG: %[[VAR4:.*]] = "tosa.greater_equal"(%[[VAR2]], %[[VAR3]])
-// CHECK-DAG: %[[VAR5:.*]] = "tosa.abs"(%[[VAR2]])
-// CHECK-DAG: %[[VAR6:.*]] = "tosa.reshape"(%[[VAR1]]) {new_shape = [1, 1, 1, 1]}
-// CHECK-DAG: %[[VAR7:.*]] = "tosa.arithmetic_right_shift"(%[[VAR5]], %[[VAR6]]) {round = true}
-// CHECK-DAG: %[[VAR8:.*]] = "tosa.negate"(%[[VAR7]])
-// CHECK-DAG: %[[VAR9:.*]] = "tosa.select"(%[[VAR4]], %[[VAR7]], %[[VAR8]])
-// CHECK: %[[VAR10:.*]] = "tosa.cast"(%[[VAR9]])
+// CHECK-DAG: %[[VAR1:.*]] = "tosa.resize"(%arg0) {mode = "BILINEAR", offset = [-448, -448], offset_fp = [0.000000e+00 : f32, 0.000000e+00 : f32], output_size = [640, 640], shift = 10 : i32, stride = [128, 128], stride_fp = [0.000000e+00 : f32, 0.000000e+00 : f32]}
+// CHECK: %[[VAR2:.*]] = "tosa.rescale"(%[[VAR1]]) {double_round = false, input_zp = 0 : i32, multiplier = [1073741824 : i32], output_zp = 0 : i32, per_channel = false, scale32 = true, shift = [50 : i32]}
 func @test_resize_bilinear_qi8(%arg0: tensor<1x80x80x2x!quant.uniform<i8:f32, 0.42546585202217102>>) -> tensor<1x640x640x2x!quant.uniform<i8:f32, 0.42546585202217102>> {
   %0 = "tfl.pseudo_const"() {value = dense<640> : tensor<2xi32>} : () -> tensor<2xi32>
   %1 = "tfl.resize_bilinear"(%arg0, %0) {align_corners = false, half_pixel_centers = true} : (tensor<1x80x80x2x!quant.uniform<i8:f32, 0.42546585202217102>>, tensor<2xi32>) -> tensor<1x640x640x2x!quant.uniform<i8:f32, 0.42546585202217102>>
diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc
index 53727b4..c3ccff1 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc
+++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc
@@ -65,6 +65,208 @@
   *nudged_max = (quant_max_float - nudged_zero_point) * (*scale);
 }
 
+// Lowers the Pack operator to TOSA.
+llvm::Optional<Value> convertPackOp(PatternRewriter& rewriter, Operation* op,
+                                    Value result_value,
+                                    SmallVector<Value, 8>& inputs,
+                                    int32_t axis) {
+  //////////////////////////////////////////////////
+  // Operator: output = Pack([values], axis) or output = Stack([values], axis)
+  // Lowering:
+  //
+  // This operator is lowered into a single variadic tosa.concat() operator
+  // and a reshape.
+  // Depending on the axis, a transpose operator is also generated:
+  //
+  // Step 1: concatenate all the input tensors along the pack axis
+  // a1_concat = tosa.concat(input[0], input[1], ..., input[N-1], axis)
+  //
+  // Step 2: reshape to N+1 dimensions
+  // a2_reshape = tosa.reshape(a1_concat, new_rank)
+  //
+  // Step 3: Transpose if a new dimension is being added:
+  // if (axis == rank(values[0])):
+  //   // perm will be [1, 2, 3, 0]
+  //   a3_transpose = tosa.transpose(a2_reshape, perm)
+
+  // Sanity check 1: make sure all input tensors have the same shape
+  // if input[0] has shape [A, B, C], input[1] to input[N-1] should also have
+  // shape [A, B, C]
+  RankedTensorType result_type =
+      result_value.getType().dyn_cast<RankedTensorType>();
+
+  // Check for ranked tensor type.
+  if (!result_type) {
+    op->emitOpError("PackOp: result type not ranked tensor");
+    return llvm::None;
+  }
+
+  // Valid axis in TF is [-rank(input), rank(input))
+  // Valid axis in TOSA is [0, rank(input))
+  // Add rank(input) once if the axis is negative.
+  RankedTensorType input_type =
+      op->getOperand(0).getType().dyn_cast<RankedTensorType>();
+  if (!input_type) {
+    op->emitOpError("PackOp: input type not ranked tensor");
+    return llvm::None;
+  }
+
+  input_type = inputs[0].getType().dyn_cast<RankedTensorType>();
+  if (!input_type) {
+    op->emitOpError("Input 0 type not ranked tensor.");
+    return llvm::None;
+  }
+  ArrayRef<int64_t> input0_tensor_shape = input_type.getShape();
+  int input_tensor_rank = input0_tensor_shape.size();
+
+  for (int i = 1; i < inputs.size(); i++) {
+    input_type = inputs[i].getType().dyn_cast<RankedTensorType>();
+    if (!input_type) {
+      op->emitOpError(
+          llvm::formatv("PackOp: input {0} type not ranked tensor.", i));
+      return llvm::None;
+    }
+    ArrayRef<int64_t> next_tensor_shape = input_type.getShape();
+    if (next_tensor_shape.size() != input_tensor_rank) {
+      op->emitOpError("PackOp: input tensor rank mismatch.");
+      return llvm::None;
+    }
+    for (int d = 0; d < input0_tensor_shape.size(); d++) {
+      if (input0_tensor_shape[d] != next_tensor_shape[d]) {
+        op->emitOpError("PackOp: input tensor shape mismatch.");
+        return llvm::None;
+      }
+    }
+  }
+
+  // If the input tensors are rank 0, reshape them to rank 1, size 1 before
+  // performing the concat.
+  if (input_tensor_rank == 0) {
+    SmallVector<int64_t, 8> reshape_rank1_size1_shape{1};
+    RankedTensorType reshape_rank1_size1_type =
+        RankedTensorType::get(ArrayRef<int64_t>(reshape_rank1_size1_shape),
+                              result_type.getElementType());
+    ArrayAttr shape_rank1_size1_attr =
+        rewriter.getI64ArrayAttr(reshape_rank1_size1_shape);
+    for (int i = 0; i < inputs.size(); i++) {
+      auto a0_reshape_op = rewriter.create<tosa::ReshapeOp>(
+          op->getLoc(), reshape_rank1_size1_type, inputs[i],
+          shape_rank1_size1_attr);
+      inputs[i] = a0_reshape_op.getResult();
+    }
+  }
+
+  // Sanity check 2: axis can be in [0, rank(input)+1],
+  // where rank(input)+1 means creating a new dimension.
+  // Negative values are also allowed up to -(rank(input)+1)
+  // where the axis "wraps around".
+  if (axis < 0) axis += input_tensor_rank;
+  if ((axis < 0) || (axis > (input_tensor_rank + 1))) {
+    op->emitOpError("PackOp: axis out of valid range.");
+    return llvm::None;
+  }
+
+  // Sanity check 3: if input shape is [A, B, C], output shape should be
+  // [N, A, B, C]
+  // 3.a check output rank is rank(input) + 1
+  SmallVector<int64_t, 8> output_shape_vals(result_type.getShape().begin(),
+                                            result_type.getShape().end());
+  if (output_shape_vals.size() != (input_tensor_rank + 1)) {
+    op->emitOpError("PackOp: output tensor rank mismatch.");
+    return llvm::None;
+  }
+  // 3.b check the output dimension along the pack axis is N
+  if (output_shape_vals[axis] != inputs.size()) {
+    op->emitOpError("PackOp: output tensor shape mismatch.");
+    return llvm::None;
+  }
+  // In most cases PackOp.axis() is within [0, rank(input) - 1], and we can
+  // directly concatenate along that axis and then reshape.
+  // For example, stacking N tensors of shape [A, B, C] along axis = 1 gives
+  // a concatenation output of [A, N * B, C], which is then reshaped into
+  // [A, N, B, C].
+  // A special case is PackOp.axis() equal to rank(input): we cannot
+  // concatenate along that axis directly, so we concatenate along axis = 0,
+  // reshape into [N, A, B, C], and then add an extra transpose to
+  // [A, B, C, N].
+  int64_t concat_axis;
+  SmallVector<int32_t, 8> perm;
+  SmallVector<int64_t, 8> reshape_output_shape;
+  if (axis == 0 && input_tensor_rank == 0) {
+    concat_axis = 0;
+  } else if (axis == input_tensor_rank) {
+    concat_axis = 0;
+
+    // A special case when stack axis is equal to input tensor rank:
+    // Output shape is [A, B, C, N]
+    // so reshape output will be [N, A, B, C]
+    // and perm will be [1, 2, 3, 0].
+    reshape_output_shape.push_back(output_shape_vals[axis]);
+    for (int d = 0; d < input_tensor_rank; d++) {
+      perm.push_back(d + 1);
+      reshape_output_shape.push_back(output_shape_vals[d]);
+    }
+    perm.push_back(0);
+  } else {
+    // General case, doesn't need perm vector.
+    concat_axis = axis;
+    reshape_output_shape.assign(output_shape_vals.begin(),
+                                output_shape_vals.end());
+  }
+  IntegerAttr concat_axis_attr = rewriter.getI64IntegerAttr(concat_axis);
+  ArrayAttr shape_attr = rewriter.getI64ArrayAttr(reshape_output_shape);
+
+  // Concat output shape will depend on concat_axis. E.g. [N * A, B, C]
+  SmallVector<int64_t, 4> concat_output_shape;
+  if (input_tensor_rank == 0) {
+    concat_output_shape.push_back(1);
+  } else {
+    for (int i = 0; i < input_tensor_rank; i++) {
+      concat_output_shape.push_back(input0_tensor_shape[i]);
+    }
+  }
+
+  concat_output_shape[concat_axis] =
+      concat_output_shape[concat_axis] * inputs.size();
+  RankedTensorType concat_type = RankedTensorType::get(
+      ArrayRef<int64_t>(concat_output_shape), result_type.getElementType());
+
+  SmallVector<mlir::Value> inputs_0;
+  for (int i = 0; i < inputs.size(); i++) {
+    inputs_0.push_back(inputs[i]);
+  }
+  auto a1_concat_op = rewriter.create<tosa::ConcatOp>(
+      op->getLoc(), concat_type, inputs_0, concat_axis_attr);
+
+  // No reshape or transpose is needed if the input tensors are rank 0, since
+  // they were already reshaped above.
+  if (input_tensor_rank == 0) return a1_concat_op.getResult();
+
+  // Reshape [N * A, B, C] to [N, A, B, C].
+  RankedTensorType reshape_output_type = RankedTensorType::get(
+      ArrayRef<int64_t>(reshape_output_shape), result_type.getElementType());
+
+  auto a2_reshape_op = rewriter.create<tosa::ReshapeOp>(
+      op->getLoc(), reshape_output_type, a1_concat_op.getResult(), shape_attr);
+
+  // If axis is equal to the input tensor rank, we need an extra transpose
+  // from [N, A, B, C] to [A, B, C, N].
+  if (axis == input_tensor_rank) {
+    Value a3_transpose_perm =
+        get1DConstTensor<tosa::ConstOp, int32_t>(rewriter, op, perm);
+
+    return rewriter
+        .create<tosa::TransposeOp>(op->getLoc(), result_type,
+                                   a2_reshape_op.getResult(), a3_transpose_perm)
+        .getResult();
+  }
+
+  return a2_reshape_op.getResult();
+}
+
 // Lowers the Unpack operator to TOSA
 llvm::Optional<ValueRange> convertUnpackOp(PatternRewriter& rewriter,
                                            Operation* op, Value input_value,
@@ -80,8 +282,10 @@
 
   // Negative axis allowed as long as it's within [-input_rank, input_rank).
   if (axis < 0) axis += input_rank;
-
-  assert(axis >= 0 && axis < input_shape.size());
+  if ((axis < 0) || (axis >= input_rank)) {
+    op->emitOpError("UnpackOp: axis out of valid range.");
+    return llvm::None;
+  }
 
   // A list of the output types for each slice op
   SmallVector<Type, 4> outs_type_vec;
@@ -283,21 +487,28 @@
                                .cast<mlir::quant::UniformQuantizedType>();
     auto output_qtype =
         output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
-    double in_lhs_scale = input_lhs_qtype.getScale();
-    double in_rhs_scale = input_rhs_qtype.getScale();
-    double output_scale = output_qtype.getScale();
+
+    // MLIR stores scales as double, but TFLite stores them as float.
+    // Downcast from double to float to match TFLite behavior.
+    float in_lhs_scale = input_lhs_qtype.getScale();
+    float in_rhs_scale = input_rhs_qtype.getScale();
+    float output_scale = output_qtype.getScale();
 
     double output_rescale_scale = in_lhs_scale * in_rhs_scale / output_scale;
 
+    // 16 bits x 16 bits -> 32 bits
+    // The 32-bit result can be rescaled back to 16 bits with a 32-bit
+    // quantized multiplier.
+    bool scale32 = true;
+
     Value op1_rescale_lhs = buildRescaleToInt32(
         rewriter, op, input_lhs_val, 1.0f, input_lhs_qtype.getZeroPoint());
     Value op2_rescale_rhs = buildRescaleToInt32(
         rewriter, op, input_rhs_val, 1.0f, input_rhs_qtype.getZeroPoint());
     auto op3_mul_op1_op2 = rewriter.create<tosa::MulOp>(
         op->getLoc(), rescale_type, op1_rescale_lhs, op2_rescale_rhs, 0);
-    return buildRescaleFromInt32(
-        rewriter, op, output_type, op3_mul_op1_op2.getResult(),
-        output_rescale_scale, output_qtype.getZeroPoint());
+    return buildRescale(rewriter, op, output_type, op3_mul_op1_op2.getResult(),
+                        output_rescale_scale, 0, output_qtype.getZeroPoint(),
+                        true, scale32);
   }
 
   return rewriter
@@ -357,6 +568,80 @@
       .getResult();
 }
 
+// Lowers ConcatV2 to TOSA Concat.
+llvm::Optional<Value> convertConcatV2Op(PatternRewriter& rewriter,
+                                        Operation* op, Value result_value,
+                                        SmallVector<Value, 8>& values,
+                                        int32_t axis) {
+  // Check all inputs are RankedTensorType
+  for (auto v : values) {
+    if (!v.getType().dyn_cast<RankedTensorType>()) {
+      op->emitOpError("ConcatV2Op: value type not ranked tensor.");
+      return llvm::None;
+    }
+  }
+
+  // Check output is Ranked tensor type
+  if (!result_value.getType().dyn_cast<RankedTensorType>()) {
+    op->emitOpError("ConcatV2Op: output value type not ranked tensor.");
+    return llvm::None;
+  }
+
+  RankedTensorType result_type =
+      result_value.getType().dyn_cast<RankedTensorType>();
+  mlir::quant::UniformQuantizedType result_quant_type =
+      result_type.getElementType()
+          .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
+
+  SmallVector<mlir::Value> values_rescaled;
+
+  for (auto v : values) {
+    RankedTensorType operand_type = v.getType().dyn_cast<RankedTensorType>();
+    mlir::quant::UniformQuantizedType operand_quant_type =
+        operand_type.getElementType()
+            .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
+
+    // tfl.concat currently allows different scales for each input tensor,
+    // which the TFLite team will fix in:
+    // https://github.com/tensorflow/tensorflow/issues/39658
+    // For backward compatibility, we still need to support this by rescaling
+    // inputs so that they all share the same scale.
+    if (result_quant_type && operand_quant_type) {
+      double operand_scale = static_cast<double>(operand_quant_type.getScale());
+      int32_t operand_zeropoint = operand_quant_type.getZeroPoint();
+
+      double result_scale = static_cast<double>(result_quant_type.getScale());
+      int32_t result_zeropoint = result_quant_type.getZeroPoint();
+
+      // Rescale input if scale is not equal to output tensor scale.
+      if (operand_scale != result_scale) {
+        RankedTensorType rescale_type =
+            RankedTensorType::get(operand_type.getShape(), result_quant_type);
+        Value rescale_op = buildRescale(
+            rewriter, op, rescale_type, v, operand_scale / result_scale,
+            operand_zeropoint, result_zeropoint, false, true);
+        values_rescaled.push_back(rescale_op);
+      } else
+        values_rescaled.push_back(v);
+    } else
+      values_rescaled.push_back(v);
+  }
+
+  int32_t tensor_rank = result_type.getShape().size();
+
+  if (axis < 0) axis += tensor_rank;
+  if ((axis < 0) || (axis >= tensor_rank)) {
+    op->emitOpError("ConcatV2Op: axis out of valid range.");
+    return llvm::None;
+  }
+
+  auto concat_op = rewriter.create<tosa::ConcatOp>(
+      op->getLoc(), result_value.getType(), values_rescaled,
+      rewriter.getI64IntegerAttr(axis));
+
+  return concat_op.getResult();
+}
+
 // Lowers SpaceToBatchND to TOSA.
 llvm::Optional<Value> convertSpaceToBatchNDOp(PatternRewriter& rewriter,
                                               Operation* op, Value result_value,
@@ -1026,7 +1311,8 @@
 
 // Lowers Softmax to a sequence of TOSA ops.
 llvm::Optional<Value> convertSoftmaxOp(PatternRewriter& rewriter, Operation* op,
-                                       Value result_value, Value logits_value) {
+                                       Value result_value, Value logits_value,
+                                       double beta) {
   // softmax = exp(logits) / reduce_sum(exp(logits), -1)
   //
   // or equivalently multiply exp(-max(logits)) to both numerator and
@@ -1077,128 +1363,291 @@
     RankedTensorType int32_rsum_type =
         RankedTensorType::get(rsum_shape, rewriter.getIntegerType(32));
 
-    // Step 1. get x - max(x)
-    Value op1_rescale_in =
-        buildRescale(rewriter, op, int32_logits_type, logits_value, 1.0f,
-                     in_quant_type.getZeroPoint(), 0);
+    if (in_quant_type.getStorageTypeIntegralWidth() == 8) {
+      // Step 1. get x - max(x)
+      Value op1_rescale_in =
+          buildRescale(rewriter, op, int32_logits_type, logits_value, 1.0f,
+                       in_quant_type.getZeroPoint(), 0, false, true);
 
-    auto op2_reducemax_op1 = rewriter.create<tosa::ReduceMaxOp>(
-        op->getLoc(), int32_rsum_type, op1_rescale_in,
-        rewriter.getI64IntegerAttr(input_rank - 1));
+      auto op2_reducemax_op1 = rewriter.create<tosa::ReduceMaxOp>(
+          op->getLoc(), int32_rsum_type, op1_rescale_in,
+          rewriter.getI64IntegerAttr(input_rank - 1));
 
-    auto op3_sub_op1_op2 = rewriter.create<tosa::SubOp>(
-        op->getLoc(), int32_logits_type, op1_rescale_in,
-        op2_reducemax_op1.getResult());
+      auto op3_sub_op1_op2 = rewriter.create<tosa::SubOp>(
+          op->getLoc(), int32_logits_type, op1_rescale_in,
+          op2_reducemax_op1.getResult());
 
-    // Table input range from -16.0 to 16.0, input below -16.0 treated as
-    // exp(-16.0), which is 0 in 0.16
-    const double exp_sample_grain = 1.0 / 16.0;
-    auto exp_func = [exp_sample_grain](int32_t x) -> int32_t {
-      double v = static_cast<double>(x) * exp_sample_grain;
-      v = v < 0.0 ? std::exp(v) : 1.0;
-      return std::lround(32768.0 * v);
-    };
+      // Step 2. get exp() result
+      // Implemented with two table lookups whose 16-bit outputs form the
+      // upper and lower halves of a 32-bit value.
+      // Since a table output must be in [-32768, 32767] while the lower 16
+      // bits are unsigned and range over [0, 65535], the lower table is
+      // generated with an offset of -32768, which is undone before adding
+      // it to the upper 16 bits.
+      auto exp_func = [](double x) -> double { return std::exp(x); };
 
-    Value exp_table_const = getTosa1DConstTensorTable(rewriter, op, exp_func);
+      Value exp_table_const_upper, exp_table_const_lower;
+      getTosaConst32bitTable(rewriter, op, beta * in_quant_type.getScale(), 0,
+                             exp_func, exp_table_const_upper,
+                             exp_table_const_lower);
 
-    // Step 2. rescale input
-    Value op4_rescale_op3 = buildRescale(
-        rewriter, op, int16_logits_type, op3_sub_op1_op2.getResult(),
-        in_quant_type.getScale() * 128.0 / exp_sample_grain, 0, 0);
+      Value op4_rescale_op3 =
+          buildRescale(rewriter, op, int16_logits_type,
+                       op3_sub_op1_op2.getResult(), 128.0, 0, 0, false, true);
 
-    // Step 3. get exp() result
-    // Since we already make sure input x < 0 in step 1,
-    // we can utilize full output 0.16 range.
+      // Input is 9.7, where lower 7 bits are all zeros.
+      // Output is 23 bits, where lower 7 bits should be all zeros as well,
+      // since there's no interpolation here.
+      auto op5_table_op4_upper = rewriter.create<tosa::TableOp>(
+          op->getLoc(), int32_logits_type, op4_rescale_op3,
+          exp_table_const_upper);
 
-    // Output is 0.23
-    auto op5_table_op4 = rewriter.create<tosa::TableOp>(
-        op->getLoc(), int32_logits_type, op4_rescale_op3, exp_table_const);
+      auto op6_table_op4_lower = rewriter.create<tosa::TableOp>(
+          op->getLoc(), int32_logits_type, op4_rescale_op3,
+          exp_table_const_lower);
 
-    // Right shift 3 bits. output 0.20
-    auto op6_rshift_op5 = rewriter.create<tosa::ArithmeticRightShiftOp>(
-        op->getLoc(), int32_logits_type, op5_table_op4.getResult(),
-        getTosaConstTensorSingleI32(rewriter, op, 3), true);
+      // To get the 16-bit upper/lower values, we need to right shift 7 bits,
+      // and the 32-bit value we want is then (upper << 16) + lower.
+      // So effectively we left shift the upper part by 9 bits.
+      auto op7_lshift_op5 = rewriter.create<tosa::LogicalLeftShiftOp>(
+          op->getLoc(), int32_logits_type, op5_table_op4_upper.getResult(),
+          getTosaConstTensorSingleI32(rewriter, op, 9));
 
-    // Step 4. get sum(exp()). output 12.20
-    auto op7_reducesum_op6 = rewriter.create<tosa::ReduceSumOp>(
-        op->getLoc(), int32_rsum_type, op6_rshift_op5.getResult(),
-        rewriter.getI64IntegerAttr(input_rank - 1));
+      // Right shift 7 bits to get lower 16 bits.
+      auto op8_rshift_op6 = rewriter.create<tosa::ArithmeticRightShiftOp>(
+          op->getLoc(), int32_logits_type, op6_table_op4_lower.getResult(),
+          getTosaConstTensorSingleI32(rewriter, op, 7), true);
 
-    // Step 5. calculate reciprocal(sum(exp()))
-    auto op8_clz_op7 = rewriter.create<tosa::ClzOp>(
-        op->getLoc(), int32_rsum_type, op7_reducesum_op6.getResult());
+      // Recover lower bits from [-32768, 32767] back to [0, 65535]
+      auto op9_add_op8_32768 = rewriter.create<tosa::AddOp>(
+          op->getLoc(), int32_logits_type, op8_rshift_op6.getResult(),
+          getTosaConstTensorSingleI32(rewriter, op, 32768));
 
-    // rshift amount of reciprocal(sum(exp()))
-    // 12 from the integer bits of 12.20 accumulator
-    // 30 from output of multiply 0.15 x 0.15
-    // -8 to keep additional 8 bits before output rescaling
-    auto op9_sub_op8 = rewriter.create<tosa::SubOp>(
-        op->getLoc(), int32_rsum_type,
-        getTosaConstTensorSingleI32(rewriter, op, 12 + 30 - 8),
-        op8_clz_op7.getResult());
+      auto op10_add_op7_op9 = rewriter.create<tosa::AddOp>(
+          op->getLoc(), int32_logits_type, op7_lshift_op5.getResult(),
+          op9_add_op8_32768.getResult());
 
-    // Left shift to get  1.31 format
-    auto op10_lshift_op7_op8 = rewriter.create<tosa::LogicalLeftShiftOp>(
-        op->getLoc(), int32_rsum_type, op7_reducesum_op6.getResult(),
-        op8_clz_op7.getResult());
+      // Step 3. get sum(exp()). output 12.19
+      auto op11_rshift_op10_12 = rewriter.create<tosa::ArithmeticRightShiftOp>(
+          op->getLoc(), int32_logits_type, op10_add_op7_op9.getResult(),
+          getTosaConstTensorSingleI32(rewriter, op, 12), true);
 
-    // Subtract (1 << 31) to make 0 <= x <= 1
-    auto op11_sub_op10 = rewriter.create<tosa::SubOp>(
-        op->getLoc(), int32_rsum_type, op10_lshift_op7_op8.getResult(),
-        getTosaConstTensorSingleI32(rewriter, op, (1u << 31)));
+      auto op12_reducesum_op11 = rewriter.create<tosa::ReduceSumOp>(
+          op->getLoc(), int32_rsum_type, op11_rshift_op10_12.getResult(),
+          rewriter.getI64IntegerAttr(input_rank - 1));
 
-    // Right shift 16 bits to get 16 bits index
-    auto op12_rshift_op11 = rewriter.create<tosa::ArithmeticRightShiftOp>(
-        op->getLoc(), int32_rsum_type, op11_sub_op10.getResult(),
-        getTosaConstTensorSingleI32(rewriter, op, 16), true);
+      // Step 4. calculate reciprocal(sum(exp()))
+      // CLZ returns headroom_plus_one
+      auto op13_clz_op12 = rewriter.create<tosa::ClzOp>(
+          op->getLoc(), int32_rsum_type, op12_reducesum_op11.getResult());
 
-    // cast to 16 bits to index TABLE op
-    auto op13_cast_op12 = rewriter.create<tosa::CastOp>(
-        op->getLoc(), int16_rsum_type, op12_rshift_op11.getResult());
+      // minus one to get headroom
+      auto op14_sub_op13 = rewriter.create<tosa::SubOp>(
+          op->getLoc(), int32_rsum_type, op13_clz_op12.getResult(),
+          getTosaConstTensorSingleI32(rewriter, op, 1));
 
-    // Generate table for 1 / (1 + x), for 0 <= x <= 1
-    const double one_over_one_plus_x_sample_grain = 1.0 / 256.0;
-    auto one_over_one_plus_x_func =
-        [one_over_one_plus_x_sample_grain](int32_t x) -> int32_t {
-      double v = static_cast<double>(x) * one_over_one_plus_x_sample_grain;
-      v = v < 0 ? 1.0 : 1.0 / (1.0 + v);
-      return std::lround(32768.0 * v);
-    };
+      // Left shift to get s1.30 format
+      auto op15_lshift_op12_op14 = rewriter.create<tosa::LogicalLeftShiftOp>(
+          op->getLoc(), int32_rsum_type, op12_reducesum_op11.getResult(),
+          op14_sub_op13.getResult());
 
-    Value one_over_one_plus_x_table_const =
-        getTosa1DConstTensorTable(rewriter, op, one_over_one_plus_x_func);
+      // Step 5. Calculate one_over_one_plus_x() with Newton-Raphson division
+      // with 3 iterations.
+      // We need the two magic constants 48/17 and -32/17 from the
+      // Newton-Raphson algorithm, and we need to operate in s2.29 since
+      // 48/17 is > 2.0.
+      // Reference: gemmlowp/fixedpoint/fixedpoint.h
+      Value half_denominator = op15_lshift_op12_op14.getResult();
+      Value four = getTosaConstTensorSingleI32(rewriter, op, 4);
+      Value F2_one = getTosaConstTensorSingleI32(rewriter, op, (1U << 29));
+      Value constant_48_over_17 =
+          getTosaConstTensorSingleI32(rewriter, op, 1515870810);
+      Value constant_neg_32_over_17 =
+          getTosaConstTensorSingleI32(rewriter, op, -1010580540);
 
-    auto op14_table_op13 = rewriter.create<tosa::TableOp>(
-        op->getLoc(), int32_rsum_type, op13_cast_op12.getResult(),
-        one_over_one_plus_x_table_const);
+      // F2 x = constant_48_over_17 + half_denominator *
+      // constant_neg_32_over_17;
+      auto op16_mul_half_denominator = rewriter.create<tosa::MulOp>(
+          op->getLoc(), int32_rsum_type, half_denominator,
+          constant_neg_32_over_17, 31);
 
-    // Rescale sum(exp(x)) from 0.23 back to 0.16
-    Value op15_rescale_op14 = buildRescale(rewriter, op, int32_rsum_type,
-                                           op14_table_op13, 1.0 / 128.0, 0, 0);
+      auto op17_add_op16 = rewriter.create<tosa::AddOp>(
+          op->getLoc(), int32_rsum_type, op16_mul_half_denominator.getResult(),
+          constant_48_over_17);
 
-    // Rescale exp(x) from 0.23 back to 0.16
-    Value op16_rescale_op5 =
-        buildRescale(rewriter, op, int32_logits_type, op5_table_op4.getResult(),
-                     1.0 / 128.0, 0, 0);
+      // Newton-Raphson 3x iteration
+      Value nr_x = op17_add_op16.getResult();
+      for (int i = 0; i < 3; i++) {
+        // half_denominator_times_x =
+        // SaturatingRoundingDoublingHighMul(half_denominator, x)
+        auto op18_mul_x_half_denominator = rewriter.create<tosa::MulOp>(
+            op->getLoc(), int32_rsum_type, nr_x, half_denominator, 31);
 
-    // Step 6. apply the scales we just get explicitly in i32 space
-    // lhs: 0.16, rhs: 0.16, output: 0.32
-    auto op17_mul_op15_op16 =
-        rewriter.create<tosa::MulOp>(op->getLoc(), int32_logits_type,
-                                     op15_rescale_op14, op16_rescale_op5, 0);
+        // F2 one_minus_half_denominator_times_x = F2::One() -
+        // half_denominator_times_x
+        auto op19_sub_one_op18 = rewriter.create<tosa::SubOp>(
+            op->getLoc(), int32_rsum_type, F2_one,
+            op18_mul_x_half_denominator.getResult());
 
-    // Apply right shift from clz
-    auto op18_rshift_op17_op9 = rewriter.create<tosa::ArithmeticRightShiftOp>(
-        op->getLoc(), int32_logits_type, op17_mul_op15_op16.getResult(),
-        op9_sub_op8.getResult(), true);
+        // SaturatingRoundingDoublingHighMul(x,
+        // one_minus_half_denominator_times_x)
+        auto op20_mul_x_op19 =
+            rewriter.create<tosa::MulOp>(op->getLoc(), int32_rsum_type, nr_x,
+                                         op19_sub_one_op18.getResult(), 31);
 
-    // Step 7. output scaling, extra 1.0 / 256.0 since we keep extra 8 bits
-    // in op9_sub_op8
-    return buildRescale(rewriter, op, output_type,
-                        op18_rshift_op17_op9.getResult(),
-                        1.0 / (out_quant_type.getScale() * 256.0), 0,
-                        out_quant_type.getZeroPoint());
+        // x + Rescale<2>(x * one_minus_half_denominator_times_x)
+        auto op21_mul_op20_four =
+            rewriter.create<tosa::MulOp>(op->getLoc(), int32_rsum_type,
+                                         op20_mul_x_op19.getResult(), four, 0);
 
+        auto op22_add_x_op21 =
+            rewriter.create<tosa::AddOp>(op->getLoc(), int32_rsum_type, nr_x,
+                                         op21_mul_op20_four.getResult());
+
+        nr_x = op22_add_x_op21.getResult();
+      }
+
+      // Step 6. multiply exp(x) with 1 / sum(exp(x))
+      // combined with Rescale<0>(ExactMulByPot<-1>(x))
+      // so shift 30 instead of 31
+      auto op23_mul_op10_x = rewriter.create<tosa::MulOp>(
+          op->getLoc(), int32_logits_type, op10_add_op7_op9.getResult(), nr_x,
+          31 - 1);
+
+      // Right shift amount is
+      // num_bits_over_unit + 31 - (sizeof(OutputT) * 8) =
+      // (12 - headroom_plus_one) + 31 - 8 =
+      // (12 + 31 - 8) - headroom_plus_one
+      auto op24_sub_op13 = rewriter.create<tosa::SubOp>(
+          op->getLoc(), int32_rsum_type,
+          getTosaConstTensorSingleI32(rewriter, op, 12 + 31 - 8),
+          op13_clz_op12.getResult());
+
+      auto op25_rshift_op23_op24 =
+          rewriter.create<tosa::ArithmeticRightShiftOp>(
+              op->getLoc(), int32_logits_type, op23_mul_op10_x.getResult(),
+              op24_sub_op13.getResult(), true);
+
+      return buildRescale(rewriter, op, output_type,
+                          op25_rshift_op23_op24.getResult(), 1.0, 0,
+                          out_quant_type.getZeroPoint(), false, true);
+
+    } else if (in_quant_type.getStorageTypeIntegralWidth() == 16) {
+      // Step 1. get x - max(x)
+      Value op1_rescale_in =
+          buildRescale(rewriter, op, int32_logits_type, logits_value, 1.0f,
+                       in_quant_type.getZeroPoint(), 0, false, true);
+
+      auto op2_reducemax_op1 = rewriter.create<tosa::ReduceMaxOp>(
+          op->getLoc(), int32_rsum_type, op1_rescale_in,
+          rewriter.getI64IntegerAttr(input_rank - 1));
+
+      // output range is [-65535, 0]
+      auto op3_sub_op1_op2 = rewriter.create<tosa::SubOp>(
+          op->getLoc(), int32_logits_type, op1_rescale_in,
+          op2_reducemax_op1.getResult());
+
+      auto exp_func = [](double x) -> double { return std::exp(x); };
+
+      // Follow TFLite reference: tensorflow/lite/kernels/activations.cc
+      Value exp_table_const =
+          getTosaConst16bitTable(rewriter, op, exp_func, -10.0, 0);
+
+      double input_diff_scale = in_quant_type.getScale() / (10.0 / 65535.0);
+
+      // Step 2. rescale input from [-65535, 0] to [-32768, 32767] for LUT input
+      Value op4_rescale_op3 = buildRescale(
+          rewriter, op, int16_logits_type, op3_sub_op1_op2.getResult(),
+          input_diff_scale, 0, 32767, true, true);
+
+      // Step 3. get exp() result
+      // Output is 15.7.
+      // In 8-bit case, no interpolation here, since input should be right on
+      // table entry.
+      auto op5_table_op4 = rewriter.create<tosa::TableOp>(
+          op->getLoc(), int32_logits_type, op4_rescale_op3, exp_table_const);
+
+      // Right shift 7 bits. output 15. Shouldn't lose any precision since last
+      // 7 bits should be all 0.
+      auto op6_rshift_op5 = rewriter.create<tosa::ArithmeticRightShiftOp>(
+          op->getLoc(), int32_logits_type, op5_table_op4.getResult(),
+          getTosaConstTensorSingleI32(rewriter, op, 7), true);
+
+      // Step 4. get sum(exp()). output 16.15
+      auto op7_reducesum_op6 = rewriter.create<tosa::ReduceSumOp>(
+          op->getLoc(), int32_rsum_type, op6_rshift_op5.getResult(),
+          rewriter.getI64IntegerAttr(input_rank - 1));
+
+      // Step 5. calculate reciprocal(sum(exp()))
+      // CLZ returns 32 - first non zero bit
+      auto op8_clz_op7 = rewriter.create<tosa::ClzOp>(
+          op->getLoc(), int32_rsum_type, op7_reducesum_op6.getResult());
+
+      auto op9_sub_op8 = rewriter.create<tosa::SubOp>(
+          op->getLoc(), int32_rsum_type, op8_clz_op7.getResult(),
+          getTosaConstTensorSingleI32(rewriter, op, 1));
+
+      // Left shift to get  1.30 format
+      auto op10_lshift_op7_op9 = rewriter.create<tosa::LogicalLeftShiftOp>(
+          op->getLoc(), int32_rsum_type, op7_reducesum_op6.getResult(),
+          op9_sub_op8.getResult());
+
+      // Subtract (1 << 30) to make 0 <= x <= 1 under 0.30 format
+      auto op11_sub_op10 = rewriter.create<tosa::SubOp>(
+          op->getLoc(), int32_rsum_type, op10_lshift_op7_op9.getResult(),
+          getTosaConstTensorSingleI32(rewriter, op, (1u << 30)));
+
+      // Right shift 14 bits to get output range [0, 65535]
+      auto op12_rshift_op11 = rewriter.create<tosa::ArithmeticRightShiftOp>(
+          op->getLoc(), int32_rsum_type, op11_sub_op10.getResult(),
+          getTosaConstTensorSingleI32(rewriter, op, 14), true);
+
+      // Remap input to [-32768, 32767] for LUT input
+      auto op13_rescale_op12 = buildRescale(rewriter, op, int16_rsum_type,
+                                            op12_rshift_op11.getResult(), 1.0,
+                                            32768, 0, false, true);
+
+      // Generate table for 1 / (1 + x), for 0 <= x <= 1
+      auto one_over_one_plus_x_func = [](double x) -> double {
+        return 1.0 / (1.0 + x);
+      };
+
+      Value one_over_one_plus_x_table_const = getTosaConst16bitTable(
+          rewriter, op, one_over_one_plus_x_func, 0.0, 1.0);
+
+      // Get (1 / sum(exp(x))) result as 23 bits (including sign bit)
+      auto op14_table_op13 = rewriter.create<tosa::TableOp>(
+          op->getLoc(), int32_rsum_type, op13_rescale_op12,
+          one_over_one_plus_x_table_const);
+
+      // Right shift 7 bits back to 0.15
+      auto op15_rshift_op14 = rewriter.create<tosa::ArithmeticRightShiftOp>(
+          op->getLoc(), int32_rsum_type, op14_table_op13.getResult(),
+          getTosaConstTensorSingleI32(rewriter, op, 7), true);
+
+      // Step 6. multiply exp(x - max(x)) with 1 / sum(exp(x - max(x)))
+      // lhs: 0.15, rhs: 0.15, output: 0.30
+      auto op16_mul_op15_op6 = rewriter.create<tosa::MulOp>(
+          op->getLoc(), int32_logits_type, op15_rshift_op14, op6_rshift_op5, 0);
+
+      auto op17_sub_op8 = rewriter.create<tosa::SubOp>(
+          op->getLoc(), int32_rsum_type,
+          getTosaConstTensorSingleI32(rewriter, op, 31),
+          op8_clz_op7.getResult());
+
+      // Apply the clz back; we get a 0.15 output,
+      // with [0, 32767] corresponding to [0.0, 1.0]
+      auto op18_rshift_op16_op17 =
+          rewriter.create<tosa::ArithmeticRightShiftOp>(
+              op->getLoc(), int32_logits_type, op16_mul_op15_op6.getResult(),
+              op17_sub_op8.getResult(), true);
+
+      return buildRescale(rewriter, op, output_type,
+                          op18_rshift_op16_op17.getResult(),
+                          (1.0 / out_quant_type.getScale()) * (1.0 / 32768.0),
+                          0, out_quant_type.getZeroPoint(), false, true);
+    } else {
+      op->emitOpError("Softmax: unknown quantization bitwidth");
+      return llvm::None;
+    }
   } else {
     SmallVector<int64_t, 4> rsum_shape_v(input_type.getShape().begin(),
                                          input_type.getShape().end());
@@ -1982,7 +2431,7 @@
       RankedTensorType output_rescale_type = RankedTensorType::get(
           llvm::makeArrayRef<int64_t>(shape_vec), output_type.getElementType());
       val = buildRescale(rewriter, op, output_rescale_type, val, output_scale,
-                         0, output_zp, false);
+                         0, output_zp, false, true);
     }
 
     // Optionally squeeze out the reduced axes.
@@ -2314,12 +2763,16 @@
       auto input_element_qtype =
           input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
 
+      bool scale32;
+
       // TOSA RESIZE: 16 bit input -> 48 bit output, or 8 bit input -> 32 bit
       // output.
       if (input_element_qtype.getStorageTypeIntegralWidth() == 16) {
+        scale32 = false;
         output_acc_type = RankedTensorType::get(output_type.getShape(),
                                                 rewriter.getIntegerType(48));
       } else if (input_element_qtype.getStorageTypeIntegralWidth() == 8) {
+        scale32 = true;
         output_acc_type = RankedTensorType::get(output_type.getShape(),
                                                 rewriter.getI32Type());
       } else {
@@ -2335,6 +2788,7 @@
           offset, shift_attr, rewriter.getF32ArrayAttr({0.0, 0.0}),
           rewriter.getF32ArrayAttr({0.0, 0.0}), resize_mode);
 
+#ifdef RESIZE_BILINEAR_LOWER_SYMMETRIC_ROUNDING
       // TFLite resize_bilinear always assume input and output tensors have same
       // scale That means we only need to arithmetic right shift with (2 *
       // shift)
@@ -2367,6 +2821,13 @@
                                                    select_op.getResult());
 
       return cast_op.getResult();
+#else
+      // This should be the expected lowering, but it is within +-1 of the
+      // TFLite reference.
+      return buildRescale(rewriter, op, output_type, resize_op.getResult(),
+                          1.0 / (1 << 20), 0, 0, false, scale32);
+#endif
+
     } else if (mode == "NEAREST_NEIGHBOR") {
       auto resize_op = rewriter.create<tosa::ResizeOp>(
           op->getLoc(), output_type, input_value, output_size, stride, offset,
diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h
index ce122f8..a3cffff 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h
+++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h
@@ -31,6 +31,12 @@
 namespace mlir {
 namespace tosa {
 
+// Lowers the Pack operator to TOSA.
+llvm::Optional<Value> convertPackOp(PatternRewriter& rewriter, Operation* op,
+                                    Value result_value,
+                                    SmallVector<Value, 8>& inputs,
+                                    int32_t axis);
+
 // Lowers the Unpack operator to TOSA.
 llvm::Optional<ValueRange> convertUnpackOp(PatternRewriter& rewriter,
                                            Operation* op, Value input_value,
@@ -63,6 +69,12 @@
 llvm::Optional<Value> convertRoundOp(PatternRewriter& rewriter, Operation* op,
                                      Value result, Value input);
 
+// Lowers ConcatV2 to TOSA.
+llvm::Optional<Value> convertConcatV2Op(PatternRewriter& rewriter,
+                                        Operation* op, Value result_value,
+                                        SmallVector<Value, 8>& values,
+                                        int32_t axis);
+
 // Lowers SpaceToBatchND to TOSA.
 llvm::Optional<Value> convertSpaceToBatchNDOp(PatternRewriter& rewriter,
                                               Operation* op, Value result_value,
@@ -93,7 +105,8 @@
 
 // Lowers Softmax to a sequence of TOSA ops.
 llvm::Optional<Value> convertSoftmaxOp(PatternRewriter& rewriter, Operation* op,
-                                       Value result_value, Value logits_value);
+                                       Value result_value, Value logits_value,
+                                       double beta);
 
 // Lowers LogSoftmax to a sequence of TOSA ops.
 llvm::Optional<Value> convertLogSoftmaxOp(PatternRewriter& rewriter,
diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc
index 00cba1e..2b401de 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc
@@ -81,6 +81,7 @@
 DECL_CONVERT_OP(ArgMax);
 DECL_CONVERT_OP(AvgPool);
 DECL_CONVERT_OP(MaxPool);
+DECL_CONVERT_OP(ConcatV2);
 DECL_CONVERT_OP(Reshape);
 DECL_CONVERT_OP(Rank);
 DECL_CONVERT_OP(Shape);
@@ -105,6 +106,7 @@
 DECL_CONVERT_OP(BiasAdd);
 DECL_CONVERT_OP(Split);
 DECL_CONVERT_OP(SplitV);
+DECL_CONVERT_OP(Pack);
 DECL_CONVERT_OP(Unpack);
 DECL_CONVERT_OP(Transpose);
 DECL_CONVERT_OP(Tile);
@@ -633,6 +635,27 @@
   return success();
 }
 
+LogicalResult ConvertTFConcatV2Op::matchAndRewrite(
+    Operation* op, PatternRewriter& rewriter) const {
+  auto tf_concatv2_op = cast<TF::ConcatV2Op>(op);
+  SmallVector<Value, 8> values(tf_concatv2_op.values());
+
+  ElementsAttr axis_elems;
+  if (!matchPattern(tf_concatv2_op.axis(), m_Constant(&axis_elems)))
+    return failure();
+
+  int32_t axis = axis_elems.getValue<IntegerAttr>({}).getInt();
+
+  llvm::Optional<Value> result =
+      convertConcatV2Op(rewriter, op, tf_concatv2_op.getResult(), values, axis);
+
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
+}
+
 LogicalResult ConvertTFReshapeOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tf_reshape_op = cast<TF::ReshapeOp>(op);
@@ -1210,7 +1233,7 @@
   auto tf_softmax_op = cast<TF::SoftmaxOp>(op);
 
   llvm::Optional<Value> result = convertSoftmaxOp(
-      rewriter, op, tf_softmax_op.getResult(), tf_softmax_op.logits());
+      rewriter, op, tf_softmax_op.getResult(), tf_softmax_op.logits(), 1.0);
 
   if (!result) return failure();
 
@@ -1469,6 +1492,32 @@
   return success();
 }
 
+LogicalResult ConvertTFPackOp::matchAndRewrite(
+    Operation* op, PatternRewriter& rewriter) const {
+  auto tf_pack_op = cast<TF::PackOp>(op);
+
+  SmallVector<Value, 8> inputs(tf_pack_op.values());
+
+  assert(inputs.size() >= 2);
+
+  IntegerAttr axis_attr;
+  {
+    auto tmpAttr = tf_pack_op.axisAttr();
+    if (!tmpAttr) tmpAttr = rewriter.getI64IntegerAttr(0);
+    axis_attr = tmpAttr;
+  }
+  int32_t axis_i32 = axis_attr.getInt();
+
+  llvm::Optional<Value> result =
+      convertPackOp(rewriter, op, tf_pack_op.getResult(), inputs, axis_i32);
+
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
+}
+
 LogicalResult ConvertTFUnpackOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tf_unpack_op = cast<TF::UnpackOp>(op);
@@ -2139,6 +2188,7 @@
   patterns.insert<ConvertTFArgMaxOp>(ctx);
   patterns.insert<ConvertTFAvgPoolOp>(ctx);
   patterns.insert<ConvertTFMaxPoolOp>(ctx);
+  patterns.insert<ConvertTFConcatV2Op>(ctx);
   patterns.insert<ConvertTFReshapeOp>(ctx);
   patterns.insert<ConvertTFRankOp>(ctx);
   patterns.insert<ConvertTFShapeOp>(ctx);
@@ -2163,6 +2213,7 @@
   patterns.insert<ConvertTFBiasAddOp>(ctx);
   patterns.insert<ConvertTFSplitOp>(ctx);
   patterns.insert<ConvertTFSplitVOp>(ctx);
+  patterns.insert<ConvertTFPackOp>(ctx);
   patterns.insert<ConvertTFUnpackOp>(ctx);
   patterns.insert<ConvertTFTransposeOp>(ctx);
   patterns.insert<ConvertTFTileOp>(ctx);
diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc
index 8a7d3d7..1b94b7e 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc
+++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc
@@ -87,6 +87,7 @@
 DECL_CONVERT_OP(AddN);
 DECL_CONVERT_OP(AveragePool2D);
 DECL_CONVERT_OP(MaxPool2D);
+DECL_CONVERT_OP(Concatenation);
 DECL_CONVERT_OP(Reshape);
 DECL_CONVERT_OP(Rank);
 DECL_CONVERT_OP(Shape);
@@ -108,6 +109,7 @@
 DECL_CONVERT_OP(FullyConnected);
 DECL_CONVERT_OP(Split);
 DECL_CONVERT_OP(SplitV);
+DECL_CONVERT_OP(Pack);
 DECL_CONVERT_OP(Unpack);
 DECL_CONVERT_OP(Transpose);
 DECL_CONVERT_OP(Tile);
@@ -142,6 +144,16 @@
 DECL_CONVERT_OP(OneHot);
 #undef DECL_CONVERT_OP
 
+// tfl.conv2d takes a 64-bit bias, while tosa.conv2d expects 48 bits. A
+// customized truncation is needed here instead of TableGen in order to
+// handle attributes with negative values.
+struct ConvertConstantOp : public RewritePattern {
+  explicit ConvertConstantOp(MLIRContext* context)
+      : RewritePattern(ConstantOp::getOperationName(), 1, context) {}
+  LogicalResult matchAndRewrite(Operation* op,
+                                PatternRewriter& rewriter) const override;
+};
+
 LogicalResult ConvertTFLReluOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_relu_op = cast<TFL::ReluOp>(op);
@@ -569,14 +581,17 @@
     // 2. Extra left shift to input to increase precision
     // Where input_shift = 20 if input is 8-bit
     // input_shift = 15 if input is 16-bit
-    // TODO: support 16-bit
     double in_lhs_scale = input_lhs_qtype.getScale();
     double in_rhs_scale = input_rhs_qtype.getScale();
     double output_scale = output_qtype.getScale();
     double max_scale_2x = 2.0 * std::max(in_lhs_scale, in_rhs_scale);
 
     const int32_t SHIFT_8_BIT = 20;
-    int32_t input_shift = SHIFT_8_BIT;
+    const int32_t SHIFT_16_BIT = 15;
+
+    int32_t input_shift = (output_qtype.getStorageTypeIntegralWidth() == 16)
+                              ? SHIFT_16_BIT
+                              : SHIFT_8_BIT;
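+    // Presumably sized so the left-shifted operand keeps headroom in int32:
+    // roughly 8 + 20 bits for int8 inputs and 16 + 15 bits for int16 inputs.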
 
     double lhs_rescale_scale =
         static_cast<double>(1 << input_shift) * in_lhs_scale / max_scale_2x;
@@ -1271,8 +1286,10 @@
                                .getStorageTypeIntegralWidth();
 
     if (input_bits == 16 && weight_bits == 8) {
-      SmallVector<int64_t, 8> zero_bias_vec(output_type.getShape()[3], 0);
-      zero_bias = get1DConstTensorInt48(rewriter, op, zero_bias_vec);
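+      // int48 is not a native C++ type, so build the 48-bit zero bias out of
+      // APInt values.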
+      std::vector<APInt> zero_bias_vec(output_type.getShape()[3],
+                                       APInt(48, 0, true));
+      ArrayRef<APInt> zero_bias_ref = llvm::makeArrayRef<APInt>(zero_bias_vec);
+      zero_bias = get1DConstTensorInt48(rewriter, op, zero_bias_ref);
     } else {
       SmallVector<int32_t, 8> zero_bias_vec(output_type.getShape()[3], 0);
       zero_bias =
@@ -1552,6 +1569,31 @@
   return success();
 }
 
+LogicalResult ConvertTFLConcatenationOp::matchAndRewrite(
+    Operation* op, PatternRewriter& rewriter) const {
+  auto tfl_concat_op = cast<TFL::ConcatenationOp>(op);
+
+  SmallVector<Value, 8> values(tfl_concat_op.values());
+
+  IntegerAttr axis_attr;
+  {
+    auto tmpAttr = tfl_concat_op.axisAttr();
+    if (!tmpAttr) {
+      tmpAttr = rewriter.getI64IntegerAttr(0);
+    }
+    axis_attr = tmpAttr;
+  }
+  int32_t axis = axis_attr.getInt();
+
+  llvm::Optional<Value> result =
+      convertConcatV2Op(rewriter, op, tfl_concat_op.getResult(), values, axis);
+
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+  return success();
+}
+
 LogicalResult ConvertTFLReshapeOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_reshape_op = cast<TFL::ReshapeOp>(op);
@@ -1885,7 +1927,8 @@
   auto tfl_softmax_op = cast<TFL::SoftmaxOp>(op);
 
   llvm::Optional<Value> result = convertSoftmaxOp(
-      rewriter, op, tfl_softmax_op.getResult(), tfl_softmax_op.input());
+      rewriter, op, tfl_softmax_op.getResult(), tfl_softmax_op.input(),
+      tfl_softmax_op.betaAttr().getValueAsDouble());
 
   if (!result) return failure();
 
@@ -1978,6 +2021,31 @@
   return success();
 }
 
+LogicalResult ConvertTFLPackOp::matchAndRewrite(
+    Operation* op, PatternRewriter& rewriter) const {
+  auto tfl_pack_op = cast<TFL::PackOp>(op);
+
+  SmallVector<Value, 8> inputs(tfl_pack_op.values());
+  assert(inputs.size() >= 2);
+
+  IntegerAttr axis_attr;
+  {
+    auto tmpAttr = tfl_pack_op.axisAttr();
+    if (!tmpAttr) tmpAttr = rewriter.getI64IntegerAttr(0);
+    axis_attr = tmpAttr;
+  }
+  int32_t axis_i32 = axis_attr.getInt();
+
+  llvm::Optional<Value> result =
+      convertPackOp(rewriter, op, tfl_pack_op.getResult(), inputs, axis_i32);
+
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.getValue()});
+
+  return success();
+}
+
 LogicalResult ConvertTFLUnpackOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_unpack_op = cast<TFL::UnpackOp>(op);
@@ -2404,12 +2472,9 @@
 
   // TFL hardswish: f(x) -> (x * relu6(x+3))/6
 
-  // TODO: support 16-bit hardswish
   if (input_type.getElementType().isa<mlir::quant::QuantizedType>() &&
       output_type.getElementType().isa<mlir::quant::QuantizedType>()) {
-    // TFLite reference:
-    // tensorflow/lite/kernels/internal/reference/reference_ops.h note
-    // there's a potential rounding issue in TFLite reference
+    // TODO: match TFLite reference numerical behavior
     mlir::quant::UniformQuantizedType in_quant_type =
         input_type.getElementType()
             .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
@@ -2428,54 +2493,51 @@
     RankedTensorType int32_type =
         RankedTensorType::get(input_shape, rewriter.getI32Type());
 
-    // Table's real input range [-4.0, 4.0].
-    // Use TABLE op to get relu6(x+3) / 6
-    const double input_sample_grain = 1.0 / 64.0;
-    auto hardswish_func = [input_sample_grain](int32_t x) -> int32_t {
-      double v = static_cast<double>(x) * input_sample_grain;
+    auto hardswish_func = [](double v) -> double {
       double w = v + 3.0;
       w = w < 0.0 ? 0.0 : w > 6.0 ? 6.0 : w;
-      v = v * w / 6.0;
-      return std::lround(32768.0 * v);
+      return v * w / 6.0;
     };
 
-    Value table_const = getTosa1DConstTensorTable(rewriter, op, hardswish_func);
+    if (in_quant_type.getStorageTypeIntegralWidth() == 8) {
+      Value table_const = getTosaConst8bitTable(
+          rewriter, op, in_quant_type.getScale(), in_quant_type.getZeroPoint(),
+          out_quant_type.getScale(), out_quant_type.getZeroPoint(),
+          hardswish_func);
 
-    // Rescale input to 9.7
-    Value op1_rescale_in =
-        buildRescale(rewriter, op, int16_type, tfl_hardswish_op.input(),
-                     (in_quant_type.getScale() * 128.0) / input_sample_grain,
-                     in_quant_type.getZeroPoint(), 0);
+      // Rescale input to 9.7 precision.
+      // No real rescale here, just a left shift by 7 bits.
+      Value op1_rescale_in =
+          buildRescale(rewriter, op, int16_type, tfl_hardswish_op.input(),
+                       128.0, 0, 0, false, true);
 
-    // Table op. output 0.23
-    auto op2_table_op1 = rewriter.create<tosa::TableOp>(
-        op->getLoc(), int32_type, op1_rescale_in, table_const);
+      auto op2_table_op1 = rewriter.create<tosa::TableOp>(
+          op->getLoc(), int32_type, op1_rescale_in, table_const);
 
-    // scale table output back to quantized space
-    Value op3_rescale_op2 =
-        buildRescale(rewriter, op, output_type, op2_table_op1.getResult(),
-                     1.0 / (128.0 * 32768.0 * out_quant_type.getScale()), 0,
-                     out_quant_type.getZeroPoint());
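+      // The TABLE result carries 7 extra fractional bits from interpolation;
+      // the 1/128 rescale below drops them (output scale/zp are already baked
+      // into the table).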
+      Value op3_rescale_op2 =
+          buildRescale(rewriter, op, output_type, op2_table_op1.getResult(),
+                       1.0 / 128.0, 0, 0, false, true);
 
-    Value op4_rescale_in = buildRescale(rewriter, op, int32_type,
-                                        tfl_hardswish_op.input(), 1.0, 0, 0);
+      rewriter.replaceOp(op, {op3_rescale_op2});
 
-    // Get 3.0 in quantized space
-    int32_t quantized_3 =
-        static_cast<int32_t>(std::ceil(3.0 / in_quant_type.getScale())) +
-        in_quant_type.getZeroPoint();
+    } else {  // int16
+      // The table's valid input range is [-256, 256], while int16 covers
+      // [-32768, 32767]. To map [-256, 256] onto [-32768, 32767], an extra
+      // 128.0 factor is folded into the input scale.
+      Value table_const = getTosaConst8bitTable(
+          rewriter, op, in_quant_type.getScale() * 128.0,
+          in_quant_type.getZeroPoint(), out_quant_type.getScale(),
+          out_quant_type.getZeroPoint(), hardswish_func);
 
-    auto op5_ge_op4 = rewriter.create<tosa::GreaterEqualOp>(
-        op->getLoc(), bool_type, op4_rescale_in,
-        getTosaConstTensorSingleI32(rewriter, op, quantized_3));
+      auto op1_table_in = rewriter.create<tosa::TableOp>(
+          op->getLoc(), int32_type, tfl_hardswish_op.input(), table_const);
 
-    auto op6_select_op5_op4_op3 = rewriter.create<tosa::SelectOp>(
-        op->getLoc(), output_type, op5_ge_op4, tfl_hardswish_op.input(),
-        op3_rescale_op2);
+      Value op2_rescale_op1 =
+          buildRescale(rewriter, op, output_type, op1_table_in.getResult(),
+                       1.0 / 128.0, 0, 0, false, true);
 
-    rewriter.replaceOp(op, {op6_select_op5_op4_op3});
-
-    return success();
+      rewriter.replaceOp(op, {op2_rescale_op1});
+    }
 
   } else {
     // op1 = constop(3)
@@ -2507,9 +2569,9 @@
         op5_reciprocal_6.getResult(), 0);
 
     rewriter.replaceOp(op, {op6_mul_op4_op5.getResult()});
-
-    return success();
   }
+
+  return success();
 }
 
 LogicalResult ConvertTFLLogisticOp::matchAndRewrite(
@@ -2548,34 +2610,52 @@
     mlir::quant::UniformQuantizedType output_qtype =
         output_type.getElementType()
             .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
-    const double input_sample_grain = 1.0 / 16.0;
-    auto sigmoid_func = [input_sample_grain](int32_t x) -> int32_t {
-      // Input range [-16.0, 16.0], output range [0.0, 1.0]
-      double v = static_cast<double>(x) * input_sample_grain;
-      v = 1.0 / (1.0 + std::exp(-v));
 
-      return std::lround(32768.0 * v);
+    auto sigmoid_func = [](double x) -> double {
+      return 1.0 / (1.0 + std::exp(-x));
     };
 
-    Value table_const = getTosa1DConstTensorTable(rewriter, op, sigmoid_func);
+    if (input_qtype.getStorageTypeIntegralWidth() == 8) {
+      // Generate a table with 16-bit entries, with the input/output scale and
+      // zero point baked into the table generation. In the 8-bit case, only
+      // the 8 LSBs of each 16-bit entry are used. Reference:
+      // tensorflow/lite/kernels/activations.cc
+      Value table_const = getTosaConst8bitTable(
+          rewriter, op, input_qtype.getScale(), input_qtype.getZeroPoint(),
+          output_qtype.getScale(), output_qtype.getZeroPoint(), sigmoid_func);
 
-    // Rescale input to 9.7 precision.
-    Value op1_rescale_in =
-        buildRescale(rewriter, op, int16_type, tfl_logistic_op.x(),
-                     (input_qtype.getScale() * 128.0) / input_sample_grain,
-                     input_qtype.getZeroPoint(), 0);
+      // Rescale input to 9.7 precision.
+      // No real rescale here, just a left shift by 7 bits.
+      Value op1_rescale_in =
+          buildRescale(rewriter, op, int16_type, tfl_logistic_op.x(), 128.0, 0,
+                       0, false, true);
 
-    auto op2_table_op1 = rewriter.create<tosa::TableOp>(
-        op->getLoc(), int32_type, op1_rescale_in, table_const);
+      auto op2_table_op1 = rewriter.create<tosa::TableOp>(
+          op->getLoc(), int32_type, op1_rescale_in, table_const);
 
-    double output_rescale_scale =
-        1.0 / (output_qtype.getScale() * 32768.0 * 128.0);
+      Value op3_rescale_op2 =
+          buildRescale(rewriter, op, output_type, op2_table_op1.getResult(),
+                       1.0 / 128.0, 0, 0, false, true);
 
-    Value op3_rescale_op2 =
-        buildRescale(rewriter, op, output_type, op2_table_op1.getResult(),
-                     output_rescale_scale, 0, output_qtype.getZeroPoint());
+      rewriter.replaceOp(op, {op3_rescale_op2});
+    } else {  // int16
+      // The table's valid input range is [-256, 256], while int16 covers
+      // [-32768, 32767]. To map [-256, 256] onto [-32768, 32767], an extra
+      // 128.0 factor is folded into the input scale.
+      Value table_const = getTosaConst8bitTable(
+          rewriter, op, input_qtype.getScale() * 128.0,
+          input_qtype.getZeroPoint(), output_qtype.getScale(),
+          output_qtype.getZeroPoint(), sigmoid_func);
 
-    rewriter.replaceOp(op, {op3_rescale_op2});
+      auto op1_table_in = rewriter.create<tosa::TableOp>(
+          op->getLoc(), int32_type, tfl_logistic_op.x(), table_const);
+
+      Value op2_rescale_op1 =
+          buildRescale(rewriter, op, output_type, op1_table_in.getResult(),
+                       1.0 / 128.0, 0, 0, false, true);
+
+      rewriter.replaceOp(op, {op2_rescale_op1});
+    }
   } else {
     rewriter.replaceOpWithNewOp<tosa::SigmoidOp>(op, output_type,
                                                  tfl_logistic_op.x());
@@ -2619,35 +2699,54 @@
     mlir::quant::UniformQuantizedType output_qtype =
         output_type.getElementType()
             .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
-    const double input_sample_grain = 1.0 / 32.0;
-    auto tanh_func = [input_sample_grain](int32_t x) -> int32_t {
-      // Input range [-16.0, 16.0], output range [0.0, 1.0]
-      double v = static_cast<double>(x) * input_sample_grain;
-      v = std::exp(-2.0 * v);
-      v = (1.0 - v) / (1.0 + v);
 
-      return std::lround(32768.0 * v);
+    auto tanh_func = [](double x) -> double {
+      x = std::exp(-2.0 * x);
+      return (1.0 - x) / (1.0 + x);
     };
 
-    Value table_const = getTosa1DConstTensorTable(rewriter, op, tanh_func);
+    if (input_qtype.getStorageTypeIntegralWidth() == 8) {
+      // Generate a table with 16-bit entries, with the input/output scale and
+      // zero point baked into the table generation. In the 8-bit case, only
+      // the 8 LSBs of each 16-bit entry are used. Reference:
+      // tensorflow/lite/kernels/activations.cc
+      Value table_const = getTosaConst8bitTable(
+          rewriter, op, input_qtype.getScale(), input_qtype.getZeroPoint(),
+          output_qtype.getScale(), output_qtype.getZeroPoint(), tanh_func);
 
-    // Rescale input to 9.7 precision.
-    Value op1_rescale_in =
-        buildRescale(rewriter, op, int16_type, tfl_tanh_op.input(),
-                     (input_qtype.getScale() * 128.0) / input_sample_grain,
-                     input_qtype.getZeroPoint(), 0);
+      // Rescale input to 9.7 precision.
+      // No real rescale here, just a left shift by 7 bits.
+      Value op1_rescale_in =
+          buildRescale(rewriter, op, int16_type, tfl_tanh_op.input(), 128.0, 0,
+                       0, false, true);
 
-    auto op2_table_op1 = rewriter.create<tosa::TableOp>(
-        op->getLoc(), int32_type, op1_rescale_in, table_const);
+      auto op2_table_op1 = rewriter.create<tosa::TableOp>(
+          op->getLoc(), int32_type, op1_rescale_in, table_const);
 
-    double output_rescale_scale =
-        1.0 / (output_qtype.getScale() * 32768.0 * 128.0);
+      Value op3_rescale_op2 =
+          buildRescale(rewriter, op, output_type, op2_table_op1.getResult(),
+                       1.0 / 128.0, 0, 0, false, true);
 
-    Value op3_rescale_op2 =
-        buildRescale(rewriter, op, output_type, op2_table_op1.getResult(),
-                     output_rescale_scale, 0, output_qtype.getZeroPoint());
+      rewriter.replaceOp(op, {op3_rescale_op2});
+    } else {  // int16
+      // The table's valid input range is [-256, 256], while int16 covers
+      // [-32768, 32767]. To map [-256, 256] onto [-32768, 32767], an extra
+      // 128.0 factor is folded into the input scale.
+      Value table_const = getTosaConst8bitTable(
+          rewriter, op, input_qtype.getScale() * 128.0,
+          input_qtype.getZeroPoint(), output_qtype.getScale(),
+          output_qtype.getZeroPoint(), tanh_func);
 
-    rewriter.replaceOp(op, {op3_rescale_op2});
+      auto op1_table_in = rewriter.create<tosa::TableOp>(
+          op->getLoc(), int32_type, tfl_tanh_op.input(), table_const);
+
+      Value op2_rescale_op1 =
+          buildRescale(rewriter, op, output_type, op1_table_in.getResult(),
+                       1.0 / 128.0, 0, 0, false, true);
+
+      rewriter.replaceOp(op, {op2_rescale_op1});
+    }
+
   } else {
     rewriter.replaceOpWithNewOp<tosa::TanhOp>(op, output_type,
                                               tfl_tanh_op.input());
@@ -2741,11 +2840,11 @@
 
     Value op3_rescale_alpha_in = buildRescale(
         rewriter, op, output_type, tfl_leakyrelu_op.input(), scale_alpha,
-        input_qtype.getZeroPoint(), output_qtype.getZeroPoint(), true);
+        input_qtype.getZeroPoint(), output_qtype.getZeroPoint(), true, true);
 
     Value op4_rescale_identity_in = buildRescale(
         rewriter, op, output_type, tfl_leakyrelu_op.input(), scale_identity,
-        input_qtype.getZeroPoint(), output_qtype.getZeroPoint(), true);
+        input_qtype.getZeroPoint(), output_qtype.getZeroPoint(), true, true);
 
     rewriter.replaceOpWithNewOp<tosa::SelectOp>(
         op, output_type, op2_ge, op4_rescale_identity_in, op3_rescale_alpha_in);
@@ -2864,9 +2963,10 @@
   if (input_element_type) {
     double rescale_scale =
         input_element_type.getScale() / element_type.getScale();
-    Value rescale_op = buildRescale(
-        rewriter, op, output_type, tfl_quantize_op.input(), rescale_scale,
-        input_element_type.getZeroPoint(), element_type.getZeroPoint(), true);
+    Value rescale_op =
+        buildRescale(rewriter, op, output_type, tfl_quantize_op.input(),
+                     rescale_scale, input_element_type.getZeroPoint(),
+                     element_type.getZeroPoint(), true, true);
 
     rewriter.replaceOp(op, {rescale_op});
     return success();
@@ -2934,6 +3034,33 @@
   return success();
 }
 
+LogicalResult ConvertConstantOp::matchAndRewrite(
+    Operation* op, PatternRewriter& rewriter) const {
+  auto tfl_const_op = cast<ConstantOp>(op);
+
+  RankedTensorType output_type =
+      tfl_const_op.getResult().getType().dyn_cast<RankedTensorType>();
+  // Not a ranked tensor output
+  if (!output_type) return failure();
+
+  ElementsAttr attr = tfl_const_op.valueAttr().dyn_cast<ElementsAttr>();
+
+  // TOSA only supports integers up to 48 bits.
+  // Anything wider than that is not representable.
+  // For 64-bit data types, truncate the values to 48 bits.
+  if (output_type.getElementType().isInteger(64)) {
+    Type new_element_type = rewriter.getIntegerType(48);
+    output_type =
+        RankedTensorType::get(output_type.getShape(), new_element_type);
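+    // APInt::trunc keeps the low 48 bits of each value.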
+    attr = attr.mapValues(new_element_type,
+                          [](const APInt& x) -> APInt { return x.trunc(48); });
+  }
+
+  rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, output_type, attr);
+
+  return success();
+}
+
 LogicalResult ConvertTFLGatherOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_gather_op = cast<TFL::GatherOp>(op);
@@ -3018,6 +3145,7 @@
   DEF_PATTERN_INSERT(TFLAddN);
   DEF_PATTERN_INSERT(TFLAveragePool2D);
   DEF_PATTERN_INSERT(TFLMaxPool2D);
+  DEF_PATTERN_INSERT(TFLConcatenation);
   DEF_PATTERN_INSERT(TFLReshape);
   DEF_PATTERN_INSERT(TFLRank);
   DEF_PATTERN_INSERT(TFLShape);
@@ -3039,6 +3167,7 @@
   DEF_PATTERN_INSERT(TFLFullyConnected);
   DEF_PATTERN_INSERT(TFLSplit);
   DEF_PATTERN_INSERT(TFLSplitV);
+  DEF_PATTERN_INSERT(TFLPack);
   DEF_PATTERN_INSERT(TFLUnpack);
   DEF_PATTERN_INSERT(TFLTranspose);
   DEF_PATTERN_INSERT(TFLTile);
@@ -3068,6 +3197,7 @@
   DEF_PATTERN_INSERT(TFLQuantize);
   DEF_PATTERN_INSERT(TFLDequantize);
   DEF_PATTERN_INSERT(TFLQConst);
+  DEF_PATTERN_INSERT(Constant);
   DEF_PATTERN_INSERT(TFLGather);
   DEF_PATTERN_INSERT(TFLGatherNd);
   DEF_PATTERN_INSERT(TFLOneHot);
diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc
index 24017eb..f037bff 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc
+++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc
@@ -28,19 +28,21 @@
 // Create a TOSA rescale op from TFLite scaling, zero points and rounding mode
 Value buildRescale(PatternRewriter& rewriter, Operation* op,
                    RankedTensorType output_type, Value input_val, double scale,
-                   int64_t input_zp, int64_t output_zp, bool double_round) {
+                   int64_t input_zp, int64_t output_zp, bool double_round,
+                   bool scale32) {
   int32_t multiplier;
   int32_t shift;
 
-  // We currently only support 32-bit quantized multiplier.
-  computeMultiplierAndShift(scale, multiplier, shift, 32);
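+  // scale32 selects between a 32-bit and a 16-bit quantized multiplier.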
+  int32_t scale_width = scale32 ? 32 : 16;
+
+  computeMultiplierAndShift(scale, multiplier, shift, scale_width);
 
   auto rescale_op = rewriter.create<tosa::RescaleOp>(
       op->getLoc(), output_type, input_val,
       rewriter.getI32IntegerAttr(static_cast<int32_t>(input_zp)),
       rewriter.getI32IntegerAttr(static_cast<int32_t>(output_zp)),
       rewriter.getI32ArrayAttr({multiplier}), rewriter.getI32ArrayAttr({shift}),
-      rewriter.getBoolAttr(true), rewriter.getBoolAttr(double_round),
+      rewriter.getBoolAttr(scale32), rewriter.getBoolAttr(double_round),
       rewriter.getBoolAttr(false));
 
   return rescale_op.getResult();
@@ -57,7 +59,7 @@
       RankedTensorType::get(input_type.getShape(), rewriter.getI32Type());
 
   return buildRescale(rewriter, op, output_type, input_val, input_scale,
-                      input_zp, 0, false);
+                      input_zp, 0, false, true);
 }
 
 // Creates TOSA rescale op with int32 input
@@ -72,7 +74,7 @@
 
   // Potentially check input_shape == output_shape here
   return buildRescale(rewriter, op, output_type, input_val, output_scale, 0,
-                      output_zp, true);
+                      output_zp, true, true);
 }
 
 // Creates a TOSA rescale op based on conv2d parameters.
@@ -90,6 +92,9 @@
   int64_t output_zp = output_qtype.getZeroPoint();
   double output_scale = output_qtype.getScale();
 
+  bool scale32 = isScale32(output_qtype);
+  int32_t scale_width = scale32 ? 32 : 16;
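+  // The same scale width is used for the per-tensor and per-channel rescales
+  // below.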
+
   if (auto weight_per_tensor_qtype =
           weight_type.getElementType()
               .dyn_cast<mlir::quant::UniformQuantizedType>()) {
@@ -101,14 +106,13 @@
 
     double op_tensor_scale = (input_scale * weight_scale) / output_scale;
 
-    // We currently only support 32-bit quantized multiplier.
-    computeMultiplierAndShift(op_tensor_scale, multiplier, shift, 32);
+    computeMultiplierAndShift(op_tensor_scale, multiplier, shift, scale_width);
 
     auto rescale_op = rewriter.create<tosa::RescaleOp>(
         op->getLoc(), output_type, conv_val, rewriter.getI32IntegerAttr(0),
         rewriter.getI32IntegerAttr(output_zp),
         rewriter.getI32ArrayAttr({multiplier}),
-        rewriter.getI32ArrayAttr({shift}), rewriter.getBoolAttr(true),
+        rewriter.getI32ArrayAttr({shift}), rewriter.getBoolAttr(scale32),
         rewriter.getBoolAttr(true), rewriter.getBoolAttr(false));
 
     return rescale_op.getResult();
@@ -138,8 +142,8 @@
 
       double op_channel_scale = (input_scale * weight_scale) / output_scale;
 
-      // We currently only support 32-bit quantized multiplier.
-      computeMultiplierAndShift(op_channel_scale, multiplier, shift, 32);
+      computeMultiplierAndShift(op_channel_scale, multiplier, shift,
+                                scale_width);
 
       multiplier_arr.push_back(multiplier);
       shift_arr.push_back(shift);
@@ -149,7 +153,7 @@
         op->getLoc(), output_type, conv_val, rewriter.getI32IntegerAttr(0),
         rewriter.getI32IntegerAttr(output_zp),
         rewriter.getI32ArrayAttr(multiplier_arr),
-        rewriter.getI32ArrayAttr(shift_arr), rewriter.getBoolAttr(true),
+        rewriter.getI32ArrayAttr(shift_arr), rewriter.getBoolAttr(scale32),
         rewriter.getBoolAttr(true), rewriter.getBoolAttr(true));
 
     return rescale_op.getResult();
@@ -160,17 +164,22 @@
   }
 }
 
-// Create a 513 entry TOSA constant tensor suitable for the Table operator based
-// on the values from an int32_t func(int32_t) lambda function.
-Value getTosa1DConstTensorTable(PatternRewriter& rewriter, Operation* op,
-                                std::function<int32_t(int32_t)> func) {
+// Create an 8-bit TOSA TABLE constant tensor
+// Follows PopulateLookupTable() in tensorflow/lite/kernels/activations.cc
+Value getTosaConst8bitTable(PatternRewriter& rewriter, Operation* op,
+                            double input_scale, int32_t input_zp,
+                            double output_scale, int32_t output_zp,
+                            std::function<double(double)> func) {
   llvm::SmallVector<int16_t, 4> table_vec;
 
+  // TODO: rewrite this with table[256]
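+  // 513 entries covering table inputs in [-256, 256]: dequantize with the
+  // input scale/zp, apply func, then requantize with the output scale/zp.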
   for (int32_t i = -256; i <= 256; i++) {
-    int32_t value = func(i);
-    // Table entry is int16_t; clamp to expressible range.
+    double dequantized = input_scale * (i - input_zp);
+    double transformed = func(dequantized);
+    int32_t rescaled = std::llround(transformed / output_scale);
+    int32_t quantized = static_cast<int32_t>(rescaled + output_zp);
     table_vec.push_back(
-        static_cast<int16_t>(std::min(std::max(value, -32768), 32767)));
+        static_cast<int16_t>(std::min(std::max(quantized, -32768), 32767)));
   }
 
   auto element_qtype =
@@ -187,6 +196,102 @@
   return const_op.getResult();
 }
 
+// Create a 16-bit TOSA TABLE constant tensor
+// Currently only used for 16-bit softmax
+// Follows gen_lut() in tensorflow/lite/kernels/internal/common.h
+Value getTosaConst16bitTable(PatternRewriter& rewriter, Operation* op,
+                             std::function<double(double)> func, double min,
+                             double max) {
+  llvm::SmallVector<int16_t, 4> table_vec;
+
+  double step = (max - min) / 512.0f;
+  double half_step = step / 2.0f;
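+  // Per gen_lut(): take 512 evenly spaced samples starting at min (the max
+  // endpoint is appended after the loop, giving 513 entries total), biasing
+  // each sample by half of the midpoint interpolation error so that linear
+  // interpolation between entries tracks func more closely.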
+  for (int32_t i = 0; i < 512; i++) {
+    int32_t sample_val = std::llround(func(min + (i * step)) * 32768.0);
+    double midpoint_interp_val =
+        std::round(((func(min + (i + 1) * step) * 32768.0) +
+                    std::round(func(min + (i * step)) * 32768.0)) /
+                   2.0);
+    double midpoint_val =
+        std::round(func(min + (i * step) + half_step) * 32768.0);
+    double midpoint_err = midpoint_interp_val - midpoint_val;
+    int32_t bias = std::llround(midpoint_err / 2.0);
+
+    table_vec.push_back(static_cast<int16_t>(
+        std::min(std::max(sample_val - bias, -32768), 32767)));
+  }
+
+  int32_t max_val = std::llround(func(max) * 32768.0);
+  table_vec.push_back(
+      static_cast<int16_t>(std::min(std::max(max_val, -32768), 32767)));
+
+  auto element_qtype =
+      UniformQuantizedType::get(true, rewriter.getIntegerType(16),
+                                rewriter.getF32Type(), 1.0f, 0, -32768, 32767);
+  auto const_type = RankedTensorType::get({513}, element_qtype);
+  auto storage_type =
+      RankedTensorType::get({513}, element_qtype.getStorageType());
+  auto const_attr = DenseElementsAttr::get(
+      storage_type, llvm::makeArrayRef<int16_t>(table_vec));
+
+  auto const_op =
+      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
+  return const_op.getResult();
+}
+
+// Create a 32-bit TOSA TABLE constant tensor
+// Output is restricted to [-1.0, 1.0] and encoded in s0.31 format
+void getTosaConst32bitTable(PatternRewriter& rewriter, Operation* op,
+                            double input_scale, int32_t input_zp,
+                            std::function<double(double)> func,
+                            Value& upper_const, Value& lower_const) {
+  std::array<int16_t, 513> upper_table_array, lower_table_array;
+
+  double output_inv_scale = static_cast<double>(1L << 31);
+
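+  // Each value is quantized to s0.31 (scale 2^31), then split into upper and
+  // lower 16-bit halves, one per TABLE constant.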
+  for (int32_t i = -256; i <= 256; i++) {
+    double dequantized = input_scale * (i - input_zp);
+    double transformed = func(dequantized);
+    double truncated = std::min(std::max(transformed, -1.0), 1.0);
+    int64_t rescaled =
+        static_cast<int64_t>(std::round(truncated * output_inv_scale));
+
+    // 2^31 is not representable in int32_t, so store as 2^31 - 1 instead
+    if (rescaled == static_cast<int64_t>(1L << 31)) {
+      rescaled = static_cast<int64_t>(1L << 31) - 1;
+    }
+
+    int32_t upper = (rescaled >> 16) & 0xFFFF;
+    // TABLE output is signed 16 bits with range [-32768, 32767].
+    // The lower 16 bits here are unsigned with range [0, 65535], so subtract
+    // an offset of 0x8000 during table generation; the legalization must add
+    // it back when reassembling the 32-bit value.
+    int32_t lower = (rescaled & 0xFFFF) - 0x8000;
+
+    upper_table_array[i + 256] = upper;
+    lower_table_array[i + 256] = lower;
+  }
+
+  auto element_qtype =
+      UniformQuantizedType::get(true, rewriter.getIntegerType(16),
+                                rewriter.getF32Type(), 1.0f, 0, -32768, 32767);
+  auto const_type = RankedTensorType::get({513}, element_qtype);
+  auto storage_type =
+      RankedTensorType::get({513}, element_qtype.getStorageType());
+
+  auto upper_const_attr = DenseElementsAttr::get(
+      storage_type, llvm::makeArrayRef<int16_t>(upper_table_array));
+  auto lower_const_attr = DenseElementsAttr::get(
+      storage_type, llvm::makeArrayRef<int16_t>(lower_table_array));
+
+  upper_const =
+      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, upper_const_attr)
+          .getResult();
+  lower_const =
+      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, lower_const_attr)
+          .getResult();
+}
+
 // Create a 32-bit float constant operator from a float
 Value getTosaConstTensorSingleF32(PatternRewriter& rewriter, Operation* op,
                                   float val) {
@@ -387,11 +492,10 @@
 // Same as get1DConstTensor, but int48 is not native c++ type, needs additional
 // interface
 Value get1DConstTensorInt48(PatternRewriter& rewriter, Operation* op,
-                            SmallVector<int64_t, 8> arr) {
+                            ArrayRef<APInt>& arr) {
   auto const_type = RankedTensorType::get({static_cast<int32_t>(arr.size())},
                                           rewriter.getIntegerType(48));
-  auto const_attr =
-      DenseElementsAttr::get(const_type, llvm::makeArrayRef<int64_t>(arr));
+  auto const_attr = DenseElementsAttr::get(const_type, arr);
 
   auto const_op =
       rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
@@ -435,5 +539,13 @@
   return input;
 }
 
+// Check if scale32 mode is used for given output_element_type
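+// Only 8-bit quantized types use 32-bit scaling multipliers; 16-bit types
+// (whose conv accumulators are int48) use 16-bit scaling.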
+bool isScale32(mlir::quant::UniformQuantizedType output_element_type) {
+  return (output_element_type.getStorageTypeIntegralWidth() == 8);
+}
+
 }  // namespace tosa
 }  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h
index f18e573..61121ac 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h
+++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h
@@ -22,6 +22,7 @@
 #include <iterator>
 #include <numeric>
 
+#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
@@ -37,8 +38,8 @@
 // Create a TOSA rescale op from TFLite scaling, zero points and rounding mode
 Value buildRescale(PatternRewriter& rewriter, Operation* op,
                    RankedTensorType output_type, Value input_val, double scale,
-                   int64_t input_zp, int64_t output_zp,
-                   bool double_round = false);
+                   int64_t input_zp, int64_t output_zp, bool double_round,
+                   bool scale32);
 
 // Creates TOSA rescale op with int32 output
 Value buildRescaleToInt32(PatternRewriter& rewriter, Operation* op,
@@ -56,10 +57,23 @@
                                RankedTensorType weight_type,
                                RankedTensorType output_type);
 
-// Create a 513 entry TOSA constant tensor suitable for the Table operator based
-// on the values from an int32_t func(int32_t) lambda function.
-Value getTosa1DConstTensorTable(PatternRewriter& rewriter, Operation* op,
-                                std::function<int32_t(int32_t)> func);
+// Create an 8-bit TOSA TABLE constant tensor
+Value getTosaConst8bitTable(PatternRewriter& rewriter, Operation* op,
+                            double input_scale, int32_t input_zp,
+                            double output_scale, int32_t output_zp,
+                            std::function<double(double)> func);
+
+// Create a 16-bit TOSA TABLE constant tensor
+Value getTosaConst16bitTable(PatternRewriter& rewriter, Operation* op,
+                             std::function<double(double)> func, double min,
+                             double max);
+
+// Create a 32-bit TOSA TABLE constant tensor
+// Output is restricted to [-1.0, 1.0] and encoded in s0.31 format
+void getTosaConst32bitTable(PatternRewriter& rewriter, Operation* op,
+                            double input_scale, int32_t input_zp,
+                            std::function<double(double)> func,
+                            Value& upper_const, Value& lower_const);
 
 // Create a 32-bit float constant operator from a float
 Value getTosaConstTensorSingleF32(PatternRewriter& rewriter, Operation* op,
@@ -106,12 +120,15 @@
 // Same as get1DConstTensor, but int48 is not native c++ type, needs additional
 // interface
 Value get1DConstTensorInt48(PatternRewriter& rewriter, Operation* op,
-                            SmallVector<int64_t, 8> arr);
+                            ArrayRef<APInt>& arr);
 
 // Strip off quantization information for bias tensor and return a unquantized
 // bias
 Value getUnquantizedBias(PatternRewriter& rewriter, Operation* op, Value input);
 
+// Check if scale32 mode is used for given output_element_type
+bool isScale32(mlir::quant::UniformQuantizedType output_element_type);
+
 }  // namespace tosa
 }  // namespace mlir
 
diff --git a/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td b/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td
index 42a2b7a..ec856d1 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td
+++ b/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td
@@ -20,9 +20,6 @@
 include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
 include "mlir/Dialect/Tosa/IR/TosaOps.td"
 
-// Nullary ops patterns.
-def : Pat<(ConstantOp ElementsAttr:$value), (Tosa_ConstOp $value)>;
-
 // Unary ops patterns.
 def : Pat<(TFL_AbsOp $arg), (Tosa_AbsOp $arg)>;
 def : Pat<(TFL_CeilOp $arg), (Tosa_CeilOp $arg)>;