[NNAPI] Handle binary ops combining NHWC+NCHW in some cases (#48812)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48812
This came up in a squeeze-and-excitation model. Starting with an NHWC
tensor T, we perform a mean operation across H and W, giving an NxC
tensor, which (after some fully connected layers) is reshaped to
NxCx1x1, then multiplied with T. To handle this, we detect the specific
case of a binary op with one NHWC input and one contiguous input with
H,W == 1,1 and allow the op to be applied (after transposing the
contiguous input).
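
For reference, a minimal sketch of the pattern described above (the module and layer names below are illustrative, not taken from the actual model):

    import torch

    class TinySEBlock(torch.nn.Module):
        # Squeeze-and-excitation style block: mean over H,W, two fully
        # connected layers, then a channelwise multiply with the input.
        def __init__(self, channels, reduction=4):
            super().__init__()
            self.fc1 = torch.nn.Linear(channels, channels // reduction)
            self.fc2 = torch.nn.Linear(channels // reduction, channels)

        def forward(self, t):
            # t is an NHWC (channels_last) feature map with logical shape NxCxHxW.
            s = t.mean(dim=[2, 3])                # NxC
            s = torch.relu(self.fc1(s))
            s = torch.sigmoid(self.fc2(s))
            # Reshape to NxCx1x1 (contiguous) and multiply with the NHWC input.
            # This is the NHWC * contiguous case handled by this diff.
            return t * s.reshape(s.shape[0], s.shape[1], 1, 1)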
Test Plan: Unit test.
Reviewed By: axitkhurana
Differential Revision: D25317939
Pulled By: dreiss
fbshipit-source-id: b4c17ab3b874d1a7defa04664010ba82115f1c20
diff --git a/test/test_nnapi.py b/test/test_nnapi.py
index 2c22bdf..86d1133 100644
--- a/test/test_nnapi.py
+++ b/test/test_nnapi.py
@@ -331,6 +331,18 @@
inp = qpt(torch.randn(2, 32), 0.05, 130, torch.quint8)
self.check(mod, inp)
+ def test_seblock_mul(self):
+ class MulModel(torch.nn.Module):
+ def forward(self, lhs, rhs):
+ return lhs * rhs
+
+ self.check(
+ MulModel(),
+ [
+ nhwc(torch.randn(2, 3, 4, 4)),
+ torch.randn(1, 3, 1, 1),
+ ])
+
def test_multi_output(self):
class MultiModel(torch.nn.Module):
def forward(self, lhs, rhs) -> Tuple[torch.Tensor, torch.Tensor]:
diff --git a/torch/backends/_nnapi/serializer.py b/torch/backends/_nnapi/serializer.py
index 87604ee..a5d6211 100644
--- a/torch/backends/_nnapi/serializer.py
+++ b/torch/backends/_nnapi/serializer.py
@@ -317,6 +317,9 @@
if config is None:
config = {}
+ # Add a tensor operand corresponding to a JIT Value.
+ # Returns the NNAPI operand ID. Can be looked up later with
+ # get_tensor_operand_by_jitval.
def add_tensor_operand(self, jitval, oper):
assert isinstance(oper, Operand)
if jitval in self.jitval_operand_map:
@@ -327,6 +330,15 @@
self.jitval_operand_map[jitval] = operand_id
return operand_id
+ # Add a tensor operand that does not correspond to a JIT Value.
+ # Useful for cases where multiple NNAPI operands are required
+ # to implement one JIT IR node. Returns the NNAPI operand ID.
+ def add_anonymous_tensor_operand(self, oper):
+ assert isinstance(oper, Operand)
+ operand_id = len(self.operands)
+ self.operands.append(oper)
+ return operand_id
+
@staticmethod
def torch_tensor_to_operand(tensor, dim_order):
dtype = str(tensor.dtype).replace("torch.", "")
@@ -452,6 +464,39 @@
f"Expected constant value of type {typekind}, but got {ctype.kind()} for value '{jitval!r}'")
return record
+ def transpose_to_nhwc(self, in_id, oper):
+ if oper.shape[2:] != (1, 1):
+ raise Exception("Automatic transpose only supported for H,W == 1,1")
+
+ out_oper = oper._replace(dim_order=DimOrder.CHANNELS_LAST)
+
+ inputs = [None] * 2
+ inputs[0] = in_id
+ inputs[1] = self.add_immediate_int_vector([0, 2, 3, 1])
+
+ outputs = [None] * 1
+ outputs[0] = self.add_anonymous_tensor_operand(out_oper)
+
+ self.add_operation(NNAPI_OperationCode.TRANSPOSE, inputs, outputs)
+
+ return outputs[0], out_oper
+
+ # Transpose inputs as necessary to allow broadcasting.
+ def transpose_for_broadcast(self, in0_id, in0_oper, in1_id, in1_oper):
+ if in0_oper.dim_order == in1_oper.dim_order:
+ return in0_id, in0_oper, in1_id, in1_oper
+
+ # Assume NHWC is preferred if there is a mismatch.
+ orders = (in0_oper.dim_order, in1_oper.dim_order)
+ if orders == (DimOrder.PRESUMED_CONTIGUOUS, DimOrder.CHANNELS_LAST):
+ return self.transpose_to_nhwc(in0_id, in0_oper) + (in1_id, in1_oper)
+ if orders == (DimOrder.CHANNELS_LAST, DimOrder.PRESUMED_CONTIGUOUS):
+ return (in0_id, in0_oper) + self.transpose_to_nhwc(in1_id, in1_oper)
+
+ raise Exception(
+ "Automatic transpose not supported for dim_orders: %r, %r" %
+ (in0_oper.dim_order, in1_oper.dim_order))
+
def get_size_arg(self, jitval):
ctype, value = self.get_constant_value(jitval)
if ctype.kind() == "ListType":
@@ -930,7 +975,8 @@
in1_id, in1_oper = self.get_tensor_operand_by_jitval(node.inputsAt(1))
assert in0_oper.op_type == in1_oper.op_type
- assert in0_oper.dim_order == in1_oper.dim_order
+ in0_id, in0_oper, in1_id, in1_oper = self.transpose_for_broadcast(
+ in0_id, in0_oper, in1_id, in1_oper)
# NOTE: PyTorch and NNAPI have the same broadcast semantics.
out_shape = broadcast_shapes(in0_oper.shape, in1_oper.shape)
out_oper = in0_oper._replace(shape=out_shape)
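
As a sanity check on the approach (illustrative only; this snippet does not go through the NNAPI serializer), applying the same [0, 2, 3, 1] permutation that transpose_for_broadcast emits to the contiguous NxCx1x1 input gives the same elementwise result as the original multiply:

    import torch

    t = torch.randn(2, 3, 4, 4)    # stands in for the NHWC feature map (logical NCHW shape)
    s = torch.randn(1, 3, 1, 1)    # contiguous NxCx1x1 scale

    expected = t * s                          # broadcast multiply in NCHW order
    nhwc_t = t.permute(0, 2, 3, 1)            # data viewed as NHWC
    nhwc_s = s.permute(0, 2, 3, 1)            # same [0, 2, 3, 1] transpose as the serializer
    assert torch.allclose(expected.permute(0, 2, 3, 1), nhwc_t * nhwc_s)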