| from typing import Optional, List |
| |
| import torch |
| from torch.backends._nnapi.serializer import serialize_model |
| |
class NnapiModule(torch.nn.Module):
    """TorchScript-friendly module around a ``torch.classes._nnapi.Compilation``.

    Holds the serialized NNAPI model plus its weights, builds the NNAPI
    compilation on demand (see ``init``), and converts the memory layout of
    every input before running it and of every output afterwards.
    """

    comp: Optional[torch.classes._nnapi.Compilation]

    def __init__(
        self,
        ser_model: torch.Tensor,
        weights: List[torch.Tensor],
        inp_mem_fmts: List[int],
        out_mem_fmts: List[int],
        out_templates: List[torch.Tensor]):
        super().__init__()
        self.ser_model = ser_model
        self.weights = weights
        self.inp_mem_fmts = inp_mem_fmts
        self.out_mem_fmts = out_mem_fmts
        self.out_templates = out_templates
        # Created lazily by init(); forward() refuses to run until then.
        self.comp = None

    @torch.jit.export
    def init(self):
        """Create the NNAPI compilation from ``ser_model`` and ``weights``.

        May only be called once: it asserts that no compilation exists yet.
        The weights are made contiguous (and stored back) before use.
        """
        assert self.comp is None
        self.weights = [w.contiguous() for w in self.weights]
        compilation = torch.classes._nnapi.Compilation()
        compilation.init(self.ser_model, self.weights)
        self.comp = compilation

    def forward(self, args: List[torch.Tensor]) -> List[torch.Tensor]:
        """Run the compiled model on ``args`` and return its outputs.

        Memory-format codes (they must agree with DimOrder in serializer.py):
          0 -> tensor used/returned as-is (inputs are made contiguous);
          1 -> inputs permuted with (0, 2, 3, 1), outputs with (0, 3, 1, 2),
               i.e. NCHW <-> NHWC for 4-D tensors.
        Any other code raises ``Exception("Invalid mem_fmt")``.
        """
        compilation = self.comp
        assert compilation is not None
        # Fresh output buffers shaped/typed like the recorded templates.
        results = [torch.empty_like(template) for template in self.out_templates]

        assert len(args) == len(self.inp_mem_fmts)
        prepared: List[torch.Tensor] = []
        # TODO: See if it's possible to use the DimOrder constants directly.
        for arg, fmt in zip(args, self.inp_mem_fmts):
            staged = arg
            if fmt == 1:
                staged = arg.permute(0, 2, 3, 1)
            elif fmt != 0:
                raise Exception("Invalid mem_fmt")
            prepared.append(staged.contiguous())

        compilation.run(prepared, results)

        assert len(results) == len(self.out_mem_fmts)
        for pos in range(len(self.out_templates)):
            fmt = self.out_mem_fmts[pos]
            if fmt == 1:
                results[pos] = results[pos].permute(0, 3, 1, 2)
            elif fmt != 0:
                raise Exception("Invalid mem_fmt")
        return results
| |
| |
class NnapiInitWrapper(torch.nn.Module):
    """Wrapper module to ensure NNAPI init is called.

    Constructing this wrapper does NOT call ``init()`` on the wrapped module.
    Instead, the exported ``__getstate__``/``__setstate__`` pair makes the
    serialized state just the wrapped module, and restoring that state calls
    the wrapped module's ``init()``. When wrapping ``NnapiModule``, whose
    ``forward`` asserts that ``init()`` has run, a freshly constructed wrapper
    is presumably meant to be saved and reloaded before it is used
    (NOTE(review): confirm this is the intended workflow).
    """
    def __init__(self, nnapi_module):
        super().__init__()
        # The wrapped module; presumably an NnapiModule (it must provide
        # init() and a List[Tensor] -> List[Tensor] forward).
        self.nnapi_module = nnapi_module

    def forward(self, args: List[torch.Tensor]) -> List[torch.Tensor]:
        # Pure pass-through to the wrapped module.
        return self.nnapi_module(args)

    @torch.jit.export
    def __getstate__(self):
        # Serialized state is only the wrapped module.
        return self.nnapi_module

    @torch.jit.export
    def __setstate__(self, nnapi_module):
        # Restore from the state produced by __getstate__: put the wrapper in
        # eval mode, reattach the wrapped module, and build its NNAPI
        # compilation so forward() can run after loading.
        self.training = False
        self.nnapi_module = nnapi_module
        self.nnapi_module.init()
| |
| |
class ListWrapper(torch.nn.Module):
    """Adapt a list-in module so callers can pass one bare tensor.

    NNAPI modules always take a list of inputs; this wrapper packs the single
    tensor argument into a one-element list and forwards the result unchanged.
    """
    def __init__(self, mod):
        super().__init__()
        self.mod = mod

    def forward(self, t: torch.Tensor) -> List[torch.Tensor]:
        batch = [t]
        return self.mod(batch)
| |
class DelistWrapper(torch.nn.Module):
    """Adapt a list-out module so callers receive one bare tensor.

    NNAPI modules always return a list of outputs; this wrapper asserts that
    exactly one output was produced and returns that tensor by itself.
    """
    def __init__(self, mod):
        super().__init__()
        self.mod = mod

    def forward(self, ts: List[torch.Tensor]) -> torch.Tensor:
        results = self.mod(ts)
        assert len(results) == 1
        only = results[0]
        return only
| |
class ListDelistWrapper(torch.nn.Module):
    """Adapt a list-in/list-out module to a tensor-in/tensor-out interface.

    The single input tensor is packed into a one-element list for the wrapped
    module; its output list must contain exactly one tensor, which is returned.
    """
    def __init__(self, mod):
        super().__init__()
        self.mod = mod

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        batch = [t]
        results = self.mod(batch)
        assert len(results) == 1
        only = results[0]
        return only
| |
| |
| def _condensed_zeros_like(t): |
| """Get a small-storage deterministic tensor with the same shape and dtype as t |
| |
| Similar to `torch.zeros(1, dtype=out.dtype).expand(out.shape)`, |
| but this works with quantized dtypes as well. |
| |
| Similar to `torch.empty(1, dtype=out.dtype).expand(out.shape)`, |
| but always returns the same data. |
| """ |
| |
| ret = torch.empty_like(t).flatten()[1].clone().expand(t.shape) |
| assert ret.storage().size() == 1 |
| ret.storage()[0] = 0 |
| return ret |
| |
| |
def convert_model_to_nnapi(model, inputs):
    """Convert a TorchScript model into a scripted NNAPI-backed module.

    The model is frozen with ``torch.jit.freeze``, run once on ``inputs`` to
    record output templates, and serialized with ``serialize_model``. The
    result wraps an ``NnapiModule`` in ``NnapiInitWrapper``. When ``inputs``
    is a single tensor, the returned module takes a single tensor. When the
    model returns a single tensor, the returned module returns a single tensor.
    Otherwise it works on lists.

    ``init()`` is not called here. The NNAPI compilation is created by
    ``NnapiInitWrapper.__setstate__`` when that state is restored.

    Args:
        model: module accepted by ``torch.jit.freeze`` (presumably a
            ScriptModule in eval mode; see that function's docs).
        inputs: a tensor, or a sequence of tensors, used as example inputs.

    Returns:
        The ``torch.jit.script``-ed wrapper module.
    """
    model = torch.jit.freeze(model)

    single_input = isinstance(inputs, torch.Tensor)
    if single_input:
        inputs = [inputs]

    outputs = model(*inputs)

    single_output = isinstance(outputs, torch.Tensor)
    if single_output:
        outputs = [outputs]

    ser_model, used_weights, inp_mem_fmts, out_mem_fmts = serialize_model(model, inputs)
    ser_model_tensor = torch.tensor(list(ser_model), dtype=torch.uint8)

    out_templates = [_condensed_zeros_like(out) for out in outputs]
    core = NnapiModule(
        ser_model_tensor,
        used_weights,
        inp_mem_fmts,
        out_mem_fmts,
        out_templates)
    nnapi_model = NnapiInitWrapper(core)

    # Pick the interface adapter (if any) for (single_input, single_output).
    adapter = {
        (True, True): ListDelistWrapper,
        (True, False): ListWrapper,
        (False, True): DelistWrapper,
    }.get((single_input, single_output))
    if adapter is not None:
        nnapi_model = adapter(nnapi_model)

    return torch.jit.script(nnapi_model)