David Reiss | 3802edd | 2021-04-06 13:40:04 -0700 | [diff] [blame] | 1 | #!/usr/bin/env python3 |
Jane Xu | a4a6d05 | 2021-11-05 10:51:35 -0700 | [diff] [blame] | 2 | # Owner(s): ["oncall: mobile"] |
Jane Xu | 6259601 | 2021-10-29 12:40:39 -0700 | [diff] [blame] | 3 | |
Zsolt Dollenstein | b004307 | 2021-08-12 10:56:55 -0700 | [diff] [blame] | 4 | import os |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 5 | import ctypes |
Zsolt Dollenstein | b004307 | 2021-08-12 10:56:55 -0700 | [diff] [blame] | 6 | import torch |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 7 | from typing import Tuple |
David Reiss | 3802edd | 2021-04-06 13:40:04 -0700 | [diff] [blame] | 8 | from torch.backends._nnapi.prepare import convert_model_to_nnapi |
| 9 | from torch.testing._internal.common_utils import TestCase, run_tests |
| 10 | |
| 11 | |
def qpt(t, scale, zero_point, dtype=torch.quint8):
    """Return ``t`` quantized per-tensor with the given parameters.

    Args:
        t: A tensor, or any value ``torch.tensor`` accepts (e.g. nested lists
            of floats).
        scale: Quantization scale forwarded to ``torch.quantize_per_tensor``.
        zero_point: Quantization zero point forwarded to
            ``torch.quantize_per_tensor``.
        dtype: Quantized dtype of the result; defaults to ``torch.quint8``.

    Returns:
        The quantized tensor produced by ``torch.quantize_per_tensor``.
    """
    if isinstance(t, torch.Tensor):
        # Calling torch.tensor() on an existing tensor triggers PyTorch's
        # "copy construct" UserWarning.  Detaching keeps the input free of
        # autograd history, as torch.tensor() did, without the warning.
        t = t.detach()
    else:
        t = torch.tensor(t)
    return torch.quantize_per_tensor(t, scale, zero_point, dtype)
| 15 | |
| 16 | |
def nhwc(t):
    """Return a channels-last copy of ``t`` tagged with ``nnapi_nhwc = True``.

    The input tensor is left untouched: the result is a clone that has been
    made contiguous in ``torch.channels_last`` memory format.  The
    ``nnapi_nhwc`` attribute is set on the copy only.
    NOTE(review): the attribute is presumably read by the NNAPI converter to
    treat the tensor as NHWC -- confirm against the converter.
    """
    channels_last_copy = t.clone().contiguous(memory_format=torch.channels_last)
    setattr(channels_last_copy, "nnapi_nhwc", True)
    return channels_last_copy
| 21 | |
| 22 | |
class TestNNAPI(TestCase):
    """Tests for lowering traced modules to NNAPI.

    Each test builds a small module, traces it and lowers it through
    ``call_lowering_to_nnapi``.  When ``setUp`` managed to load the library
    named by the ``LIBNEURALNETWORKS_PATH`` environment variable, the lowered
    module is also executed and its output compared against eager mode;
    otherwise only the conversion itself is exercised.
    """

    def setUp(self):
        """Select the qnnpack quantized engine and detect NNAPI availability."""
        # Avoid saturation in fbgemm
        torch.backends.quantized.engine = 'qnnpack'

        libneuralnetworks_path = os.environ.get("LIBNEURALNETWORKS_PATH")
        if libneuralnetworks_path:
            ctypes.cdll.LoadLibrary(libneuralnetworks_path)
            print("Will attempt to run NNAPI models.")
            self.can_run_nnapi = True
        else:
            self.can_run_nnapi = False

    # Created for easy override by subclasses (eg TestNnapiBackend)
    def call_lowering_to_nnapi(self, traced_module, args):
        """Lower ``traced_module`` to NNAPI using ``args`` as example inputs."""
        return convert_model_to_nnapi(traced_module, args)

    # Created for subclasses to set can_run_nnapi (eg TestNnapiBackend)
    def set_can_run_nnapi(self, can_run):
        """Override whether lowered models are executed by ``check``."""
        self.can_run_nnapi = can_run

    def check(
        self,
        module,
        arg_or_args,
        *,
        trace_args=None,
        convert_args=None,
        atol_rtol=None,
        limit=None,
        expected_memory_format=None
    ):
        """Trace and lower ``module``; if NNAPI can run, compare with eager.

        Args:
            module: Module under test; it is switched to eval mode.
            arg_or_args: A single tensor, or a sequence of tensors, used as
                the module inputs.
            trace_args: Inputs used for ``torch.jit.trace`` instead of the
                real inputs when truthy.
            convert_args: Inputs handed to the lowering step instead of the
                real inputs when truthy.
            atol_rtol: Optional ``(atol, rtol)`` pair forwarded to
                ``assertEqual``.
            limit: Optional maximum number of elements whose ``int_repr`` may
                differ between eager and NNAPI outputs.  Only usable with
                quantized outputs, because ``int_repr`` is called on them.
            expected_memory_format: When truthy, the NNAPI output must be
                contiguous in this memory format.
        """
        with torch.no_grad():
            if isinstance(arg_or_args, torch.Tensor):
                args = [arg_or_args]
            else:
                args = arg_or_args
            module.eval()
            traced = torch.jit.trace(module, trace_args or args)
            nnapi_module = self.call_lowering_to_nnapi(traced, convert_args or args)
            if not self.can_run_nnapi:
                # Only test that the model was converted successfully.
                return
            eager_output = module(*args)
            nnapi_output = nnapi_module(*args)
            kwargs = {}
            if atol_rtol is not None:
                kwargs["atol"] = atol_rtol[0]
                kwargs["rtol"] = atol_rtol[1]
            self.assertEqual(eager_output, nnapi_output, **kwargs)
            if limit is not None:
                # Compare raw quantized integers (widened to int32 so the
                # subtraction cannot wrap) and count differing elements.
                mismatches = \
                    eager_output.int_repr().to(torch.int32) - \
                    nnapi_output.int_repr().to(torch.int32)
                if mismatches.count_nonzero() > limit:
                    # Too many mismatches.  Re-run the check with no tolerance
                    # to get a nice message.
                    self.assertEqual(eager_output, nnapi_output, atol=0, rtol=0)
            if expected_memory_format:
                self.assertTrue(nnapi_output.is_contiguous(memory_format=expected_memory_format))

    def float_and_quant_and_nhwc(self, inp_float, scale, zero_point):
        """Return float / quantized, contiguous / channels-last variants of an input.

        Args:
            inp_float: Float tensor supplied by the caller.
            scale: Quantization scale used for the quantized variants.
            zero_point: Quantization zero point used for the quantized
                variants.

        Returns:
            A list of ``(name, tensor)`` pairs named ``"float"``,
            ``"float-nhwc"``, ``"quant"`` and ``"quant-nhwc"``.
        """
        # The caller has already generated inp_float, so this seed only
        # affects random draws made after this call.
        torch.manual_seed(29)
        # Honour the caller-supplied quantization parameters; they used to be
        # ignored in favour of a hard-coded (0.03, 128).
        inp_quant = qpt(inp_float, scale, zero_point)
        return [
            ("float", inp_float),
            ("float-nhwc", nhwc(inp_float)),
            ("quant", inp_quant),
            ("quant-nhwc", nhwc(inp_quant)),
        ]

    def test_prelu(self):
        # PReLU with a single shared weight and with per-channel weights.
        arg = torch.tensor([[1.0, -1.0, 2.0, -2.0]]).unsqueeze(-1).unsqueeze(-1)
        single_a = torch.nn.PReLU()
        self.check(single_a, arg)
        multi_a = torch.nn.PReLU(4)
        with torch.no_grad():
            multi_a.weight.copy_(torch.tensor([.1, .2, .3, .4]))
        self.check(multi_a, nhwc(arg))

        # Test flexible size
        self.check(
            multi_a,
            arg,
            trace_args=[torch.zeros(1, 4, 3, 3)],
            convert_args=[nhwc(torch.zeros(1, 4, 0, 0))],
        )

    def test_quantize(self):
        # Float -> quint8 quantization of a channels-last input.
        self.check(
            torch.nn.quantized.Quantize(0.25, 2, torch.quint8),
            nhwc(torch.tensor([[[[1.0]], [[2.0]]]])))

    def test_dequantize(self):
        # Dequantization of a channels-last quantized input.
        self.check(
            torch.nn.quantized.DeQuantize(),
            nhwc(qpt([[[[1.0]], [[2.0]]]], 0.25, 2)))

    def test_unsqueeze(self):
        # unsqueeze along every valid positive dim and two negative dims.
        class UnsqueezeModule(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.dim = dim

            def forward(self, arg):
                return arg.unsqueeze(self.dim)

        self.check(UnsqueezeModule(-2), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(-1), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(0), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(1), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(2), torch.randn(4, 2, 2))

    def test_reshape(self):
        # reshape on contiguous and channels-last inputs, plus a rejected case.
        class ReshapeModule(torch.nn.Module):
            def __init__(self, shape):
                super().__init__()
                self.shape = shape

            def forward(self, arg):
                return arg.reshape(self.shape)

        self.check(
            ReshapeModule((2, 4)),
            torch.randn(4, 2, 1, 1))

        self.check(
            ReshapeModule((8, -1)),
            nhwc(torch.randn(4, 2, 1, 1)))

        with self.assertRaisesRegex(Exception, "target size"):
            self.check(
                ReshapeModule((2, 4)),
                nhwc(torch.randn(4, 2, 1, 1)))

    def test_flatten(self):
        # Flatten over several dim ranges, flexible inputs, NHWC and errors.
        for mod in [
            torch.nn.Flatten(),
            torch.nn.Flatten(start_dim=2, end_dim=3),
            torch.nn.Flatten(start_dim=2, end_dim=4),
            torch.nn.Flatten(start_dim=0, end_dim=-2),
            torch.nn.Flatten(start_dim=0, end_dim=4)
        ]:
            self.check(mod, torch.randn(4, 2, 1, 3, 7))

        # flex inputs
        self.check(
            torch.nn.Flatten(),
            torch.randn(4, 2, 1, 3, 7),
            convert_args=[torch.zeros(0, 2, 1, 3, 7)]
        )

        # channels last
        self.check(
            torch.nn.Flatten(),
            nhwc(torch.randn(2, 1, 4, 7))
        )
        self.check(
            torch.nn.Flatten(),
            nhwc(torch.randn(2, 3, 1, 1))
        )

        # Exceptions
        with self.assertRaisesRegex(Exception, "not supported on NHWC"):
            self.check(
                torch.nn.Flatten(),
                nhwc(torch.randn(1, 3, 4, 4))
            )
        with self.assertRaisesRegex(Exception, "Flattening flexible dims is not supported yet"):
            self.check(torch.nn.Flatten(), torch.randn(4, 2, 0, 0, 7))
        with self.assertRaisesRegex(Exception, "Only 1 dim"):
            self.check(
                torch.nn.Flatten(start_dim=1, end_dim=-2),
                torch.randn(0, 2, 1, 3, 0))

    def test_slice(self):
        # Slicing with start/stop/step, including flexible-shape inputs.
        class SliceModule(torch.nn.Module):
            def __init__(self, start, stop, step):
                super().__init__()
                self.start = start
                self.stop = stop
                self.step = step

            def forward(self, t):
                return t[1:, self.start:self.stop:self.step, :]

        class SliceModule2(torch.nn.Module):
            def forward(self, t):
                return t[3:]

        self.check(
            SliceModule(1, 5, 2),
            torch.randn(4, 6, 2)
        )
        self.check(
            SliceModule2(),
            torch.randn(5)
        )

        # flex inputs
        self.check(
            SliceModule(1, 5, 2),
            torch.randn(4, 6, 2),
            convert_args=[torch.zeros(4, 6, 0)]
        )
        with self.assertRaisesRegex(Exception, "slice with flexible shape"):
            self.check(
                SliceModule(1, 5, 2),
                torch.randn(4, 6, 2),
                convert_args=[torch.zeros(0, 0, 0)]
            )

    def test_cat(self):
        # torch.cat along dims 0 and 1, on NHWC inputs and flexible shapes.
        class CatModule(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.dim = dim

            def forward(self, t1, t2):
                return torch.cat([t1, t2], self.dim)

        self.check(
            CatModule(0),
            [
                torch.randn(1, 2, 3, 3),
                torch.randn(2, 2, 3, 3),
            ])

        self.check(
            CatModule(1),
            [
                torch.randn(1, 2, 3, 3),
                torch.randn(1, 4, 3, 3),
            ])

        self.check(
            CatModule(1),
            [
                nhwc(torch.randn(1, 2, 3, 3)),
                nhwc(torch.randn(1, 4, 3, 3)),
            ])

        self.check(
            CatModule(1),
            [
                torch.randn(1, 2, 3, 3),
                torch.randn(1, 4, 3, 3),
            ],
            convert_args=[
                torch.zeros(0, 0, 0, 0),
                torch.zeros(0, 0, 0, 0)
            ])

    def test_pointwise_unary(self):
        # relu and sigmoid on float and quantized inputs.
        for op in ["relu", "sigmoid"]:
            with self.subTest(op):
                # The module closes over ``op``; it is traced inside this
                # iteration, so the loop variable still has the right value.
                class UnaryModule(torch.nn.Module):
                    def forward(self, arg):
                        if op == "relu":
                            return torch.nn.functional.relu(arg)
                        if op == "sigmoid":
                            return torch.sigmoid(arg)
                        raise Exception("Bad op")
                self.check(UnaryModule(), torch.tensor([-1.0, 1.0]))
                self.check(
                    UnaryModule(),
                    qpt(torch.tensor([-1.0, 1.0]), 1. / 256, 0),
                )

    def test_pointwise_binary(self):
        # add/sub/mul/div with equal shapes, equal-rank broadcast, and a
        # rejected non-equal-rank broadcast.
        for op in ["add", "sub", "mul", "div"]:
            with self.subTest(op):
                class BinaryModule(torch.nn.Module):
                    def forward(self, lhs, rhs):
                        if op == "add":
                            return lhs + rhs
                        if op == "sub":
                            return lhs - rhs
                        if op == "mul":
                            return lhs * rhs
                        if op == "div":
                            return lhs / rhs
                        raise Exception("Bad op")

                self.check(
                    BinaryModule(),
                    [
                        torch.tensor([1.0, 2.0]),
                        torch.tensor([3.0, 4.0]),
                    ])

                self.check(
                    BinaryModule(),
                    [
                        torch.tensor([[1.0, 2.0]]),
                        torch.tensor([[3.0, 4.0], [5.0, 6.0]]),
                    ])

                with self.assertRaisesRegex(Exception, "Non-equal-rank broadcast"):
                    self.check(
                        BinaryModule(),
                        [
                            torch.tensor([1.0, 2.0]),
                            torch.tensor([[3.0, 4.0], [5.0, 6.0]]),
                        ])

    def test_pointwise_binary_const(self):
        # Adding a captured constant on either side must keep the memory
        # format of the non-constant argument.
        const = torch.randn(1, 4, 6, 6)

        class ArgPlusConst(torch.nn.Module):
            def forward(self, arg):
                return arg + const

        class ConstPlusArg(torch.nn.Module):
            def forward(self, arg):
                return const + arg

        arg_contig = torch.randn(2, 4, 6, 6)
        arg_nhwc = nhwc(torch.randn(2, 4, 6, 6))

        for mod_class in [ArgPlusConst, ConstPlusArg]:
            for use_nhwc in [False, True]:
                with self.subTest(mod_class=mod_class.__name__, use_nhwc=use_nhwc):
                    arg = arg_nhwc if use_nhwc else arg_contig
                    memory_format = torch.channels_last if use_nhwc else torch.contiguous_format
                    self.check(mod_class(), arg,
                               expected_memory_format=memory_format)

    def test_hardtanh(self):
        # Default and (0, 6) bounds convert; (0, 5) is expected to be rejected.
        inp = torch.tensor([-2.0, -0.5, 0.5, 2.0, 7.0])
        self.check(torch.nn.Hardtanh(), inp)
        self.check(torch.nn.Hardtanh(0.0, 6.0), inp)
        with self.assertRaisesRegex(Exception, "hardtanh with args"):
            self.check(torch.nn.Hardtanh(0.0, 5.0), inp)

    def test_softmax(self):
        # Softmax with implicit dim and with dim=0.
        inp = torch.tensor([[-2.0, -0.5], [0.5, 2.0]])
        self.check(torch.nn.Softmax(), inp)
        self.check(torch.nn.Softmax(dim=0), inp)
        # Test flexible size
        self.check(
            torch.nn.Softmax(),
            inp,
            convert_args=[torch.zeros(0, 0)],
        )

    def test_to(self):
        # ``Tensor.to("cpu")`` followed by PReLU.
        class ToCPU(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.prelu = torch.nn.PReLU()

            def forward(self, x):
                y = x.to("cpu")
                # add prelu since input operand can't be output
                return self.prelu(y)

        arg = torch.randn(1, 2, 3, 3)
        self.check(ToCPU(), arg)
        # Test flexible size
        self.check(
            ToCPU(),
            arg,
            convert_args=[torch.zeros(1, 2, 0, 0)],
        )

    def test_detach(self):
        # ``Tensor.detach()`` followed by relu, fixed and flexible sizes.
        class DetachModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                y = x.detach()
                return torch.nn.functional.relu(y)

        self.check(DetachModule(), torch.randn(1, 2, 3, 3))
        self.check(
            DetachModule(), torch.randn(1, 2, 3, 3),
            convert_args=[torch.zeros(1, 2, 0, 0)])

    def test_log_softmax(self):
        # LogSoftmax with implicit dim and with dim 0.
        inp = torch.randn(3, 10)
        self.check(torch.nn.LogSoftmax(), inp)
        self.check(torch.nn.LogSoftmax(0), inp)

    def test_mean(self):
        # torch.mean over single dims and dim lists, with and without keepdim.
        class MeanModule(torch.nn.Module):
            def __init__(self, dim, keep=False):
                super().__init__()
                self.dim = dim
                self.keep = keep

            def forward(self, t):
                return torch.mean(t, dim=self.dim, keepdim=self.keep)

        self.check(MeanModule(0), torch.randn(2, 3))
        self.check(MeanModule(1), torch.randn(2, 3))
        self.check(MeanModule([2, 3]), torch.randn(2, 3, 6, 6))
        self.check(MeanModule([2, 3]), nhwc(torch.randn(2, 3, 6, 6)))
        self.check(MeanModule([-1, -2]), nhwc(torch.randn(2, 3, 6, 6)))
        self.check(MeanModule([-1, -2], keep=True), nhwc(torch.randn(2, 3, 6, 6)))

    def test_max_pool2d(self):
        # MaxPool2d across float/quantized and contiguous/NHWC inputs.
        for (name, inp) in self.float_and_quant_and_nhwc(torch.randn(2, 3, 12, 16), 0.3, 128):
            with self.subTest(name):
                self.check(torch.nn.MaxPool2d(2), inp)
                self.check(torch.nn.MaxPool2d((3, 4)), inp)
                self.check(torch.nn.MaxPool2d((3, 4), (1, 2)), inp)

    def test_avg_pool2d(self):
        # AvgPool2d across input variants, with fixed and flexible H/W.
        for (name, inp) in self.float_and_quant_and_nhwc(torch.randn(2, 3, 12, 16), 0.3, 128):
            with self.subTest(name):
                atol_rtol = None
                limit = None
                convert_dims = (2, 3, 0, 0)
                convert_arg = torch.zeros(*convert_dims)

                for model in (
                        torch.nn.AvgPool2d(2),
                        torch.nn.AvgPool2d((3, 4)),
                        torch.nn.AvgPool2d((3, 4), (1, 2))):
                    if "quant" in name:
                        atol_rtol = (1, 0)
                        # With limit == numel the mismatch count can never
                        # exceed it, so only the atol check is effective.
                        limit = model(inp).numel()
                        convert_arg = qpt(torch.zeros(*convert_dims), 1.0 / 16, 128)
                    if "nhwc" in name:
                        convert_arg = nhwc(convert_arg)

                    self.check(model, inp, atol_rtol=atol_rtol, limit=limit)
                    self.check(
                        model,
                        inp,
                        convert_args=[convert_arg],
                        atol_rtol=atol_rtol,
                        limit=limit
                    )

    def test_adaptive_avg_pool2d(self):
        # Output size (1, 1) converts; (2, 2) is expected to be rejected.
        for (name, inp) in self.float_and_quant_and_nhwc(torch.randn(2, 3, 12, 16), 0.3, 128):
            with self.subTest(name):
                self.check(torch.nn.AdaptiveAvgPool2d((1, 1)), inp)
                with self.assertRaisesRegex(Exception, "with output size"):
                    self.check(torch.nn.AdaptiveAvgPool2d((2, 2)), inp)

    def test_upsample_nearest2d(self):
        # Nearest-neighbour upsampling by size and by scale factor.
        # Zero-sized H/W variants, keyed by the same names as the real inputs,
        # serve as convert_args for the flexible-size checks below.
        convert_args = dict(self.float_and_quant_and_nhwc(torch.randn(2, 3, 0, 0), 0.3, 128))
        for (name, inp) in self.float_and_quant_and_nhwc(torch.randn(2, 3, 12, 16), 0.3, 128):
            with self.subTest(name):
                self.check(torch.nn.UpsamplingNearest2d(size=(16, 20)), inp)
                self.check(torch.nn.UpsamplingNearest2d(size=(24, 32)), inp)
                self.check(torch.nn.UpsamplingNearest2d(size=(36, 48)), inp)
                self.check(torch.nn.UpsamplingNearest2d(scale_factor=(1.5, 1.5)), inp)
                self.check(torch.nn.UpsamplingNearest2d(scale_factor=(2.0, 2.0)), inp)
                self.check(torch.nn.UpsamplingNearest2d(scale_factor=(3.0, 3.0)), inp)

                self.check(
                    torch.nn.UpsamplingNearest2d(size=(24, 32)), inp,
                    convert_args=[convert_args[name]]
                )
                self.check(
                    torch.nn.UpsamplingNearest2d(scale_factor=(2.0, 2.0)), inp,
                    convert_args=[convert_args[name]]
                )

    def test_linear(self):
        # Float Linear with a fixed and a flexible batch dimension.
        torch.manual_seed(29)
        self.check(torch.nn.Linear(16, 32), torch.randn(2, 16))
        self.check(
            torch.nn.Linear(16, 32), torch.randn(2, 16),
            convert_args=[torch.zeros(0, 16)])

    def test_conv2d(self):
        # Conv2d over several configurations, for float/quantized and
        # contiguous/NHWC inputs, with fixed and flexible input sizes.
        cases = [
            # in_ch, out_ch, kernel, stride, padding, groups, bias, input_dim, name
            (4, 8, (3, 3), 1, 0, 1, 1, (2, 4, 16, 16), "3x3"),
            (4, 8, (3, 3), 1, 0, 1, 0, (2, 4, 16, 16), "3x3nobias"),
            (4, 16, (3, 3), 1, 1, 1, 1, (2, 4, 16, 16), "3x3p1"),
            (8, 8, (3, 3), 2, 0, 1, 1, (2, 8, 16, 16), "3x3s2"),
            (4, 8, (5, 5), 1, 0, 1, 1, (2, 4, 16, 16), "5x5"),
            (4, 4, (3, 3), 1, 0, 4, 1, (2, 4, 16, 16), "3x3dw"),
            (8, 4, (1, 1), 1, 0, 1, 1, (2, 8, 16, 16), "1x1"),
        ]

        for kind in ["float", "float-nhwc", "quant", "quant-nhwc"]:
            for case in cases:
                in_ch, out_ch, kernel, stride, padding, groups, bias, input_dim, name = case
                with self.subTest("{}-{}".format(kind, name)):
                    inp = torch.randn(input_dim)
                    model = torch.nn.Conv2d(in_ch, out_ch, kernel, stride, padding, groups=groups, bias=bool(bias))
                    output_size = model(inp).numel()
                    atol_rtol = None
                    limit = None
                    convert_dims = (0, in_ch, 0, 0)
                    convert_arg = torch.zeros(*convert_dims)

                    if "quant" in kind:
                        # Post-training static quantization: wrap, attach a
                        # qconfig, calibrate on one batch, then convert.
                        model = torch.nn.Sequential(model)
                        model.eval()
                        model.qconfig = torch.ao.quantization.get_default_qconfig('qnnpack')
                        model = torch.ao.quantization.prepare(model)
                        model(inp)
                        model = torch.ao.quantization.convert(model)
                        inp = qpt(inp, 1.0 / 16, 128)
                        # I've seen numerical differences between QNNPACK and NNAPI,
                        # but never more than 1 quantum, and never more than ~1% of
                        # the output in this test.
                        atol_rtol = (1, 0)
                        limit = output_size * 0.03
                        convert_arg = qpt(torch.zeros(*convert_dims), 1.0 / 16, 128)

                    if "nhwc" in kind:
                        inp = nhwc(inp)
                        convert_arg = nhwc(convert_arg)

                    self.check(model, inp, atol_rtol=atol_rtol, limit=limit)
                    self.check(
                        model,
                        inp,
                        convert_args=[convert_arg],
                        atol_rtol=atol_rtol,
                        limit=limit
                    )

    def test_conv2d_transpose(self):
        # ConvTranspose2d for float/quantized and contiguous/NHWC inputs.
        torch.manual_seed(29)
        in_ch, out_ch, kernel = (5, 7, (2, 2))
        input_dim = (4, 5, 3, 3)
        convert_dims = input_dim[:2] + (0, 0)

        for kind in ["float", "float-nhwc", "quant", "quant-nhwc"]:
            with self.subTest(kind):
                inp = torch.randn(input_dim)
                model = torch.nn.ConvTranspose2d(in_ch, out_ch, kernel)
                output_size = model(inp).numel()
                atol_rtol = (0.0002, 0)
                limit = None
                convert_arg = torch.zeros(*convert_dims)

                if "quant" in kind:
                    model = torch.nn.quantized.ConvTranspose2d(in_ch, out_ch, kernel)
                    model.qconfig = torch.ao.quantization.get_default_qconfig('qnnpack')
                    inp = qpt(inp, 1.0 / 16, 128)
                    # I've seen numerical differences between QNNPACK and NNAPI,
                    # but never more than 1 quantum, and never more than ~10% of
                    # the output in this test.
                    atol_rtol = (1, 0)
                    limit = output_size * 0.1
                    convert_arg = qpt(convert_arg, 1.0 / 16, 128)

                if "nhwc" in kind:
                    inp = nhwc(inp)
                    convert_arg = nhwc(convert_arg)

                self.check(model, inp, atol_rtol=atol_rtol, limit=limit)
                self.check(
                    model,
                    inp,
                    convert_args=[convert_arg],
                    atol_rtol=atol_rtol,
                    limit=limit
                )

    def test_qadd(self):
        # Quantized add / add_relu / mul through a shared QFunctional, with
        # fixed inputs and each combination of zero-valued convert_args.
        func = torch.nn.quantized.QFunctional()
        func.scale = 0.5
        func.zero_point = 120

        class AddMod(torch.nn.Module):
            def forward(self, lhs, rhs):
                return func.add(lhs, rhs)

        class AddReluMod(torch.nn.Module):
            def forward(self, lhs, rhs):
                return func.add_relu(lhs, rhs)

        class MulMod(torch.nn.Module):
            def forward(self, lhs, rhs):
                return func.mul(lhs, rhs)

        for (name, mod) in [("add", AddMod), ("add_relu", AddReluMod), ("mul", MulMod)]:
            with self.subTest(name):
                self.check(
                    mod(),
                    [
                        qpt([1.0, 2.0], 0.25, 128),
                        qpt([3.0, 4.0], 0.25, 128),
                    ])
                self.check(
                    mod(),
                    [
                        qpt([[1.0, 2.0]], 0.25, 128),
                        qpt([[3.0, 4.0]], 0.25, 128),
                    ],
                    convert_args=[
                        qpt([[1.0, 2.0]], 0.25, 128),
                        qpt(torch.zeros((1, 2)), 0.25, 128),
                    ]
                )
                self.check(
                    mod(),
                    [
                        qpt([[1.0, 2.0]], 0.25, 128),
                        qpt([[3.0, 4.0]], 0.25, 128),
                    ],
                    convert_args=[
                        qpt(torch.zeros((1, 2)), 0.25, 128),
                        qpt([[3.0, 4.0]], 0.25, 128),
                    ]
                )
                self.check(
                    mod(),
                    [
                        qpt([[1.0, 2.0]], 0.25, 128),
                        qpt([[3.0, 4.0]], 0.25, 128),
                    ],
                    convert_args=[
                        qpt(torch.zeros((1, 2)), 0.25, 128),
                        qpt(torch.zeros((1, 2)), 0.25, 128),
                    ]
                )
                # NOTE: NNAPI qadd supports broadcast, but PT does not.

    def test_qlinear(self):
        # Quantized Linear with explicitly set qint8 weight and float bias.
        torch.manual_seed(29)
        weight = qpt(torch.randn(16, 32), 0.125, 0, torch.qint8)
        bias = torch.randn(16)
        mod = torch.nn.quantized.Linear(32, 16)
        mod.set_weight_bias(weight, bias)
        inp = qpt(torch.randn(2, 32), 0.05, 130, torch.quint8)
        self.check(mod, inp)

    def test_seblock_mul(self):
        # Broadcast multiply of an NHWC tensor by a (1, C, 1, 1) tensor.
        class MulModel(torch.nn.Module):
            def forward(self, lhs, rhs):
                return lhs * rhs

        self.check(
            MulModel(),
            [
                nhwc(torch.randn(2, 3, 4, 4)),
                torch.randn(1, 3, 1, 1),
            ])

    def test_multi_output(self):
        # A module returning a tuple of two tensors.
        class MultiModel(torch.nn.Module):
            def forward(self, lhs, rhs) -> Tuple[torch.Tensor, torch.Tensor]:
                the_sum = lhs + rhs
                the_diff = lhs - rhs
                return the_sum, the_diff

        self.check(MultiModel(), [torch.tensor([1.0, 2.0]), torch.tensor([1.0, 3.0])])
| 678 | |
# Script entry point: hand control to the PyTorch test runner.
if __name__ == "__main__":
    run_tests()