[NNAPI] Handle binary ops combining NHWC+NCHW in some cases (#48812)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48812

This came up in a squeeze-and-excitation model.  Starting with an NHWC
tensor T, we take the mean across H and W, giving an NxC tensor, which
(after some fully connected layers) is reshaped to NxCx1x1 and then
multiplied with T.  To handle this, we detect the specific case of a
binary op with one NHWC input and one contiguous input whose H and W
dimensions are both 1, and allow the op to be applied after transposing
the contiguous input to NHWC.
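
For context, a minimal sketch of the pattern (module name, channel
counts, and activation choices are illustrative, not taken from the
model in question): the final multiply combines a channels-last
activation with a contiguous NxCx1x1 tensor, which is the mixed-layout
binary op that transpose_for_broadcast now handles.

```python
import torch

class SEBlockSketch(torch.nn.Module):
    # Hypothetical squeeze-and-excitation block illustrating the pattern.
    def __init__(self, channels, reduction=4):
        super().__init__()
        self.fc1 = torch.nn.Linear(channels, channels // reduction)
        self.fc2 = torch.nn.Linear(channels // reduction, channels)

    def forward(self, t):
        # t is NHWC (channels_last memory format); logical dims are still N, C, H, W.
        scale = t.mean(dim=[2, 3])  # N x C
        scale = torch.sigmoid(self.fc2(torch.relu(self.fc1(scale))))
        # Reshape to a contiguous N x C x 1 x 1 tensor and multiply with the
        # NHWC input -- one NHWC operand, one contiguous operand with H,W == 1,1.
        return t * scale.reshape(scale.shape[0], scale.shape[1], 1, 1)

# Example usage:
# t = torch.randn(2, 8, 4, 4).contiguous(memory_format=torch.channels_last)
# out = SEBlockSketch(8)(t)
```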

Test Plan: Unit test.

Reviewed By: axitkhurana

Differential Revision: D25317939

Pulled By: dreiss

fbshipit-source-id: b4c17ab3b874d1a7defa04664010ba82115f1c20
diff --git a/test/test_nnapi.py b/test/test_nnapi.py
index 2c22bdf..86d1133 100644
--- a/test/test_nnapi.py
+++ b/test/test_nnapi.py
@@ -331,6 +331,18 @@
         inp = qpt(torch.randn(2, 32), 0.05, 130, torch.quint8)
         self.check(mod, inp)
 
+    def test_seblock_mul(self):
+        class MulModel(torch.nn.Module):
+            def forward(self, lhs, rhs):
+                return lhs * rhs
+
+        self.check(
+            MulModel(),
+            [
+                nhwc(torch.randn(2, 3, 4, 4)),
+                torch.randn(1, 3, 1, 1),
+            ])
+
     def test_multi_output(self):
         class MultiModel(torch.nn.Module):
             def forward(self, lhs, rhs) -> Tuple[torch.Tensor, torch.Tensor]:
diff --git a/torch/backends/_nnapi/serializer.py b/torch/backends/_nnapi/serializer.py
index 87604ee..a5d6211 100644
--- a/torch/backends/_nnapi/serializer.py
+++ b/torch/backends/_nnapi/serializer.py
@@ -317,6 +317,9 @@
         if config is None:
             config = {}
 
+    # Add a tensor operand corresponding to a JIT Value.
+    # Returns the NNAPI operand ID.  Can be looked up later with
+    # get_tensor_operand_by_jitval.
     def add_tensor_operand(self, jitval, oper):
         assert isinstance(oper, Operand)
         if jitval in self.jitval_operand_map:
@@ -327,6 +330,15 @@
         self.jitval_operand_map[jitval] = operand_id
         return operand_id
 
+    # Add a tensor operand that does not correspond to a JIT Value.
+    # Useful for cases where multiple NNAPI operands are required
+    # to implement one JIT IR node.  Returns the NNAPI operand ID.
+    def add_anonymous_tensor_operand(self, oper):
+        assert isinstance(oper, Operand)
+        operand_id = len(self.operands)
+        self.operands.append(oper)
+        return operand_id
+
     @staticmethod
     def torch_tensor_to_operand(tensor, dim_order):
         dtype = str(tensor.dtype).replace("torch.", "")
@@ -452,6 +464,39 @@
                 f"Expected constant value of type {typekind}, but got {ctype.kind()} for value '{jitval!r}'")
         return record
 
+    def transpose_to_nhwc(self, in_id, oper):
+        if oper.shape[2:] != (1, 1):
+            raise Exception("Automatic transpose only supported for H,W == 1,1")
+
+        out_oper = oper._replace(dim_order=DimOrder.CHANNELS_LAST)
+
+        inputs = [None] * 2
+        inputs[0] = in_id
+        inputs[1] = self.add_immediate_int_vector([0, 2, 3, 1])
+
+        outputs = [None] * 1
+        outputs[0] = self.add_anonymous_tensor_operand(out_oper)
+
+        self.add_operation(NNAPI_OperationCode.TRANSPOSE, inputs, outputs)
+
+        return outputs[0], out_oper
+
+    # Transpose inputs as necessary to allow broadcasting.
+    def transpose_for_broadcast(self, in0_id, in0_oper, in1_id, in1_oper):
+        if in0_oper.dim_order == in1_oper.dim_order:
+            return in0_id, in0_oper, in1_id, in1_oper
+
+        # Assume NHWC is preferred if there is a mismatch.
+        orders = (in0_oper.dim_order, in1_oper.dim_order)
+        if orders == (DimOrder.PRESUMED_CONTIGUOUS, DimOrder.CHANNELS_LAST):
+            return self.transpose_to_nhwc(in0_id, in0_oper) + (in1_id, in1_oper)
+        if orders == (DimOrder.CHANNELS_LAST, DimOrder.PRESUMED_CONTIGUOUS):
+            return (in0_id, in0_oper) + self.transpose_to_nhwc(in1_id, in1_oper)
+
+        raise Exception(
+            "Automatic transpose not supported for dim_orders: %r, %r" %
+            (in0_oper.dim_order, in1_oper.dim_order))
+
     def get_size_arg(self, jitval):
         ctype, value = self.get_constant_value(jitval)
         if ctype.kind() == "ListType":
@@ -930,7 +975,8 @@
         in1_id, in1_oper = self.get_tensor_operand_by_jitval(node.inputsAt(1))
 
         assert in0_oper.op_type == in1_oper.op_type
-        assert in0_oper.dim_order == in1_oper.dim_order
+        in0_id, in0_oper, in1_id, in1_oper = self.transpose_for_broadcast(
+            in0_id, in0_oper, in1_id, in1_oper)
         # NOTE: PyTorch and NNAPI have the same broadcast semantics.
         out_shape = broadcast_shapes(in0_oper.shape, in1_oper.shape)
         out_oper = in0_oper._replace(shape=out_shape)