Added CMSIS-NN specialization for int8 depthwise conv op.

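For int8 inputs the op routes to a new EvalQuantizedPerChannel path: when
depth_multiplier == 1 it calls arm_depthwise_conv_s8_opt with a scratch
buffer taken from the shared CMSIS-NN scratch pool, otherwise it falls
back to the general arm_depthwise_conv_s8 kernel. Per-channel output
multipliers and shifts are filled in by
PopulateConvolutionQuantizationParams and kept in fixed OpData arrays of
kMaxChannels entries. The shared scratch buffer grows from 6000 to 13000
bytes and its size argument is now interpreted in bytes.

As a reminder of what the per-channel (multiplier, shift) pairs are for,
below is a rough, self-contained sketch of how one int32 accumulator is
requantized to int8. It follows the TFLite convention (positive shift
means a left shift); the helpers mirror the gemmlowp/TFLite reference
math, but this is a standalone illustration, not the CMSIS-NN
implementation:

  #include <algorithm>
  #include <cstdint>
  #include <limits>

  // Q31 fixed-point multiply with rounding; saturates the one overflow case.
  static int32_t SaturatingRoundingDoublingHighMul(int32_t a, int32_t b) {
    if (a == b && a == std::numeric_limits<int32_t>::min()) {
      return std::numeric_limits<int32_t>::max();
    }
    const int64_t ab = static_cast<int64_t>(a) * static_cast<int64_t>(b);
    const int32_t nudge = ab >= 0 ? (1 << 30) : (1 - (1 << 30));
    return static_cast<int32_t>((ab + nudge) >> 31);
  }

  // Arithmetic right shift with round-to-nearest.
  static int32_t RoundingDivideByPOT(int32_t x, int32_t exponent) {
    const int32_t mask = static_cast<int32_t>((1LL << exponent) - 1);
    const int32_t remainder = x & mask;
    const int32_t threshold = (mask >> 1) + ((x < 0) ? 1 : 0);
    return (x >> exponent) + ((remainder > threshold) ? 1 : 0);
  }

  // Requantize one int32 accumulator for one output channel: scale by the
  // channel's fixed-point multiplier, add the output zero point, clamp to
  // the fused activation range, narrow to int8.
  static int8_t RequantizeOneChannel(int32_t acc, int32_t multiplier,
                                     int32_t shift, int32_t output_offset,
                                     int32_t act_min, int32_t act_max) {
    const int32_t left_shift = shift > 0 ? shift : 0;
    const int32_t right_shift = shift > 0 ? 0 : -shift;
    acc = SaturatingRoundingDoublingHighMul(acc * (1 << left_shift),
                                            multiplier);
    acc = RoundingDivideByPOT(acc, right_shift);
    acc += output_offset;
    acc = std::min(std::max(acc, act_min), act_max);
    return static_cast<int8_t>(acc);
  }
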
Change-Id: Icc8b933363677eca7cc444078ff15721c734ca8f
diff --git a/tensorflow/lite/experimental/micro/kernels/cmsis-nn/depthwise_conv.cc b/tensorflow/lite/experimental/micro/kernels/cmsis-nn/depthwise_conv.cc
index 79b6120..948d672 100644
--- a/tensorflow/lite/experimental/micro/kernels/cmsis-nn/depthwise_conv.cc
+++ b/tensorflow/lite/experimental/micro/kernels/cmsis-nn/depthwise_conv.cc
@@ -20,9 +20,11 @@
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
 #include "tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h"
 #include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h"
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/kernels/padding.h"
+#include "tensorflow/lite/experimental/micro/kernels/cmsis-nn/scratch_buffer.h"
 
 namespace tflite {
 namespace ops {
@@ -34,6 +36,7 @@
 constexpr int kFilterTensor = 1;
 constexpr int kBiasTensor = 2;
 constexpr int kOutputTensor = 0;
+constexpr int kMaxChannels = 256;
 
 struct OpData {
   TfLitePaddingValues padding;
@@ -41,6 +44,12 @@
   // be represented as a fixed point multiplier plus a left shift.
   int32_t output_multiplier;
   int output_shift;
+
+  // Per channel output multiplier and shift.
+  // TODO(b/141139247): Allocate these dynamically when possible.
+  int32_t per_channel_output_multiplier[kMaxChannels];
+  int32_t per_channel_output_shift[kMaxChannels];
+
   // The range of the fused activation layer. For example for kNone and
   // uint8_t these would be 0 and 255.
   int32_t output_activation_min;
@@ -50,12 +59,17 @@
 TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
                              TfLiteDepthwiseConvParams* params, int width,
                              int height, int filter_width, int filter_height,
-                             int out_width, int out_height,
                              const TfLiteType data_type, OpData* data) {
-  data->padding.height = ComputePadding(params->stride_height, 1, height,
-                                        filter_height, out_height);
-  data->padding.width =
-      ComputePadding(params->stride_width, 1, width, filter_width, out_width);
+  bool has_bias = node->inputs->size == 3;
+  // Check number of inputs/outputs
+  TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
+  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+
+  int unused_output_height, unused_output_width;
+  data->padding = ComputePaddingHeightWidth(
+      params->stride_height, params->stride_width, 1, 1, height, width,
+      filter_height, filter_width, params->padding, &unused_output_height,
+      &unused_output_width);
 
   // Note that quantized inference requires that all tensors have their
   // parameters set. This is usually done during quantized training.
@@ -66,15 +80,12 @@
         GetOptionalInputTensor(context, node, kBiasTensor);
     TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
 
-    double real_multiplier = 0.0;
-    TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
-        context, input, filter, bias, output, &real_multiplier));
-    int exponent;
-    QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent);
-    data->output_shift = -exponent;
-    CalculateActivationRangeUint8(params->activation, output,
-                                  &data->output_activation_min,
-                                  &data->output_activation_max);
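+    // PopulateConvolutionQuantizationParams takes the per-channel shifts as
+    // int*; OpData stores them as int32_t, so the cast below assumes a
+    // 32-bit int, which holds on the Cortex-M targets CMSIS-NN supports.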
+    TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams(
+        context, input, filter, bias, output, params->activation,
+        &data->output_multiplier, &data->output_shift,
+        &data->output_activation_min, &data->output_activation_max,
+        data->per_channel_output_multiplier,
+        reinterpret_cast<int*>(data->per_channel_output_shift)));
   }
   return kTfLiteOk;
 }
@@ -91,10 +102,10 @@
   return kTfLiteOk;
 }
 
-void EvalFloat(TfLiteContext* context, TfLiteNode* node,
-               TfLiteDepthwiseConvParams* params, OpData* data,
-               const TfLiteTensor* input, const TfLiteTensor* filter,
-               const TfLiteTensor* bias, TfLiteTensor* output) {
+TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
+                       TfLiteDepthwiseConvParams* params, OpData* data,
+                       const TfLiteTensor* input, const TfLiteTensor* filter,
+                       const TfLiteTensor* bias, TfLiteTensor* output) {
   float output_activation_min, output_activation_max;
   CalculateActivationRange(params->activation, &output_activation_min,
                            &output_activation_max);
@@ -117,12 +128,113 @@
       GetTensorShape(filter), GetTensorData<float>(filter),
       GetTensorShape(bias), GetTensorData<float>(bias), GetTensorShape(output),
       GetTensorData<float>(output));
+  return kTfLiteOk;
 }
 
-void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
-                   TfLiteDepthwiseConvParams* params, OpData* data,
-                   const TfLiteTensor* input, const TfLiteTensor* filter,
-                   const TfLiteTensor* bias, TfLiteTensor* output) {
+TfLiteStatus EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
+                             TfLiteDepthwiseConvParams* params, OpData* data,
+                             const TfLiteTensor* input,
+                             const TfLiteTensor* filter,
+                             const TfLiteTensor* bias, TfLiteTensor* output) {
+#if defined(ARM_MATH_DSP) && defined(ARM_MATH_LOOPUNROLL)
+    DepthwiseParams op_params;
+    op_params.padding_type = PaddingType::kSame;
+    op_params.padding_values.width = data->padding.width;
+    op_params.padding_values.height = data->padding.height;
+    op_params.stride_width = params->stride_width;
+    op_params.stride_height = params->stride_height;
+    op_params.dilation_width_factor = params->dilation_width_factor;
+    op_params.dilation_height_factor = params->dilation_height_factor;
+    op_params.depth_multiplier = params->depth_multiplier;
+    op_params.input_offset = -input->params.zero_point;
+    op_params.weights_offset = 0;
+    op_params.output_offset = output->params.zero_point;
+    // TODO(b/130439627): Use calculated value for clamping.
+    op_params.quantized_activation_min = std::numeric_limits<int8_t>::min();
+    op_params.quantized_activation_max = std::numeric_limits<int8_t>::max();
+    RuntimeShape filter_shape = GetTensorShape(filter);
+    const int filter_height = filter_shape.Dims(1);
+    const int filter_width = filter_shape.Dims(2);
+    RuntimeShape input_shape = GetTensorShape(input);
+    const int input_height = input_shape.Dims(1);
+    const int input_width = input_shape.Dims(2);
+    const int input_depth = input_shape.Dims(3);
+    RuntimeShape output_shape = GetTensorShape(output);
+    const int output_height = output_shape.Dims(1);
+    const int output_width = output_shape.Dims(2);
+    RuntimeShape bias_shape = GetTensorShape(bias);
+
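+    // depth_multiplier == 1 means each input channel maps to exactly one
+    // output channel, so the faster arm_depthwise_conv_s8_opt kernel can be
+    // used; it requires a caller-provided scratch buffer.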
+    if (op_params.depth_multiplier == 1) {
+      int16_t* buf = nullptr;
+      const int32_t buf_size =
+        arm_depthwise_conv_s8_opt_get_buffer_size(input_depth,
+                                                  filter_width,
+                                                  filter_height);
+      TF_LITE_ENSURE_OK(context,
+                        get_cmsis_scratch_buffer(context, &buf, buf_size));
+      TF_LITE_ENSURE_EQ(context,
+                        arm_depthwise_conv_s8_opt(
+                          GetTensorData<int8_t>(input),
+                          input_width, input_height, input_depth,
+                          GetTensorData<int8_t>(filter),
+                          input_depth,
+                          filter_width, filter_height,
+                          op_params.padding_values.width,
+                          op_params.padding_values.height,
+                          op_params.stride_width,
+                          op_params.stride_height,
+                          GetTensorData<int32>(bias),
+                          GetTensorData<int8_t>(output),
+                          data->per_channel_output_shift,
+                          data->per_channel_output_multiplier,
+                          output_width,
+                          output_height,
+                          op_params.output_offset,
+                          op_params.input_offset,
+                          op_params.quantized_activation_min,
+                          op_params.quantized_activation_max,
+                          op_params.dilation_width_factor,
+                          op_params.dilation_height_factor,
+                          buf),
+                        ARM_MATH_SUCCESS);
+    } else {
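+      // General case (depth_multiplier > 1): the non-opt CMSIS-NN kernel is
+      // used and needs no scratch buffer, hence the trailing nullptr.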
+      TF_LITE_ENSURE_EQ(context,
+                        arm_depthwise_conv_s8(
+                          GetTensorData<int8_t>(input),
+                          input_width, input_height, input_depth,
+                          GetTensorData<int8_t>(filter),
+                          op_params.depth_multiplier * input_depth,
+                          op_params.depth_multiplier,
+                          filter_width, filter_height,
+                          op_params.padding_values.width,
+                          op_params.padding_values.height,
+                          op_params.stride_width,
+                          op_params.stride_height,
+                          GetTensorData<int32>(bias),
+                          GetTensorData<int8_t>(output),
+                          data->per_channel_output_shift,
+                          data->per_channel_output_multiplier,
+                          output_width,
+                          output_height,
+                          op_params.output_offset,
+                          op_params.input_offset,
+                          op_params.quantized_activation_min,
+                          op_params.quantized_activation_max,
+                          op_params.dilation_width_factor,
+                          op_params.dilation_height_factor,
+                          nullptr),
+                        ARM_MATH_SUCCESS);
+    }
+#else
+#error ARM_MATH_DSP and ARM_MATH_LOOPUNROLL must be set
+#endif
+  return kTfLiteOk;
+}
+
+TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
+                           TfLiteDepthwiseConvParams* params, OpData* data,
+                           const TfLiteTensor* input,
+                           const TfLiteTensor* filter,
+                           const TfLiteTensor* bias, TfLiteTensor* output) {
   const int32_t input_offset = -input->params.zero_point;
   const int32_t filter_offset = -filter->params.zero_point;
   const int32_t output_offset = output->params.zero_point;
@@ -181,6 +293,7 @@
         GetTensorShape(bias), GetTensorData<int32_t>(bias),
         GetTensorShape(output), GetTensorData<uint8_t>(output));
   }
+  return kTfLiteOk;
 }
 
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
@@ -198,28 +311,42 @@
   int height = SizeOfDimension(input, 1);
   int filter_width = SizeOfDimension(filter, 2);
   int filter_height = SizeOfDimension(filter, 1);
-  int out_width = ComputeOutSize(params->padding, width, filter_width,
-                                 params->stride_width);
-  int out_height = ComputeOutSize(params->padding, height, filter_height,
-                                  params->stride_height);
-  OpData local_data_object;
-  OpData* data = &local_data_object;
+
+  OpData data;
+
+  if (input->type != kTfLiteFloat32) {
+    TF_LITE_ENSURE_EQ(context, filter->quantization.type,
+                      kTfLiteAffineQuantization);
+
+    const auto* affine_quantization =
+        reinterpret_cast<TfLiteAffineQuantization*>(
+            filter->quantization.params);
+    TF_LITE_ENSURE(context, affine_quantization);
+    TF_LITE_ENSURE(context, affine_quantization->scale);
+  }
+
   TF_LITE_ENSURE_STATUS(CalculateOpData(context, node, params, width, height,
-                                        filter_width, filter_height, out_width,
-                                        out_height, data_type, data));
+                                        filter_width, filter_height, data_type,
+                                        &data));
 
   // TODO(aselle): Consider whether float conv and quantized conv should be
   // separate ops to avoid dispatch overhead here.
   switch (input->type) {  // Already know in/out types are same.
     case kTfLiteFloat32:
-      EvalFloat(context, node, params, data, input, filter, bias, output);
+      return EvalFloat(context, node, params, &data, input, filter, bias,
+                       output);
+      break;
+    case kTfLiteInt8:
+      return EvalQuantizedPerChannel(context, node, params, &data, input,
+                                     filter, bias, output);
       break;
     case kTfLiteUInt8:
-      EvalQuantized(context, node, params, data, input, filter, bias, output);
+      return EvalQuantized(context, node, params, &data, input, filter, bias,
+                           output);
       break;
     default:
-      context->ReportError(context, "Type %d not currently supported.",
-                           input->type);
+      context->ReportError(context, "Type %s (%d) not supported.",
+                           TfLiteTypeGetName(input->type), input->type);
       return kTfLiteError;
   }
   return kTfLiteOk;
diff --git a/tensorflow/lite/experimental/micro/kernels/cmsis-nn/fully_connected.cc b/tensorflow/lite/experimental/micro/kernels/cmsis-nn/fully_connected.cc
index 5c8b371..d8dda97 100644
--- a/tensorflow/lite/experimental/micro/kernels/cmsis-nn/fully_connected.cc
+++ b/tensorflow/lite/experimental/micro/kernels/cmsis-nn/fully_connected.cc
@@ -89,7 +89,6 @@
                                const TfLiteTensor* input,
                                const TfLiteTensor* filter,
                                const TfLiteTensor* bias, TfLiteTensor* output) {
-  TfLiteStatus status = kTfLiteOk;
   RuntimeShape output_shape = GetTensorShape(output);
   const int batches = output_shape.Dims(0);
   const int output_depth = output_shape.Dims(1);
@@ -100,18 +99,23 @@
 #if defined(ARM_MATH_DSP) && defined(ARM_MATH_LOOPUNROLL)
   const int32_t buf_size = arm_fully_connected_s8_get_buffer_size(accum_depth);
   int16_t* buf = nullptr;
-  status = get_cmsis_scratch_buffer(context, &buf, buf_size);
-  arm_fully_connected_s8(
-      GetTensorData<int8_t>(input), GetTensorData<int8_t>(filter), accum_depth,
-      output_depth, batches, -input->params.zero_point,
-      -filter->params.zero_point, data->output_multiplier, -data->output_shift,
-      output->params.zero_point, GetTensorData<int32_t>(bias),
-      GetTensorData<int8_t>(output), data->output_activation_min,
-      data->output_activation_max, buf);
+  TF_LITE_ENSURE_OK(context,
+                    get_cmsis_scratch_buffer(context, &buf, buf_size));
+  TF_LITE_ENSURE_EQ(context,
+                    arm_fully_connected_s8(
+                      GetTensorData<int8_t>(input),
+                      GetTensorData<int8_t>(filter),
+                      accum_depth, output_depth, batches,
+                      -input->params.zero_point, -filter->params.zero_point,
+                      data->output_multiplier, -data->output_shift,
+                      output->params.zero_point, GetTensorData<int32_t>(bias),
+                      GetTensorData<int8_t>(output), data->output_activation_min,
+                      data->output_activation_max, buf),
+                    ARM_MATH_SUCCESS);
 #else
 #error ARM_MATH_DSP and ARM_MATH_LOOPUNROLL must be set
 #endif
-  return status;
+  return kTfLiteOk;
 }
 
 TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
diff --git a/tensorflow/lite/experimental/micro/kernels/cmsis-nn/scratch_buffer.cc b/tensorflow/lite/experimental/micro/kernels/cmsis-nn/scratch_buffer.cc
index b41420e..a9bc82d 100644
--- a/tensorflow/lite/experimental/micro/kernels/cmsis-nn/scratch_buffer.cc
+++ b/tensorflow/lite/experimental/micro/kernels/cmsis-nn/scratch_buffer.cc
@@ -19,17 +19,18 @@
 // implemented.
 
 // This buffer is used by CMSIS-NN optimized operator implementations.
-// SCRATCH_BUFFER_BYTES bytes is chosenn empirically. It needs to be large
+// SCRATCH_BUFFER_BYTES bytes is chosen empirically. It needs to be large
 // enough to hold the biggest buffer needed by all CMSIS-NN operators in the
 // network.
-#define SCRATCH_BUFFER_BYTES 6000
+#define SCRATCH_BUFFER_BYTES 13000
 
-__attribute__((aligned(
-    4))) static int16_t cmsis_scratch_buffer[SCRATCH_BUFFER_BYTES / 2] = {0};
+__attribute__((aligned(4))) static int16_t
+    cmsis_scratch_buffer[SCRATCH_BUFFER_BYTES / 2] = {0};
 
 TfLiteStatus get_cmsis_scratch_buffer(TfLiteContext* context, int16_t** buf,
-                                      int32_t buf_size) {
-  TF_LITE_ENSURE(context, buf_size <= SCRATCH_BUFFER_BYTES / 2);
+                                      int32_t buf_size_bytes) {
+  TF_LITE_ENSURE(context, buf_size_bytes <= SCRATCH_BUFFER_BYTES);
   *buf = cmsis_scratch_buffer;
   return kTfLiteOk;
 }
\ No newline at end of file