#!/usr/bin/env python3
# Owner(s): ["oncall: mobile"]

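"""Tests for converting PyTorch modules to NNAPI.

These tests exercise torch.backends._nnapi conversion of common ops. By
default they only verify that conversion succeeds; set the
LIBNEURALNETWORKS_PATH environment variable to a libneuralnetworks.so
(e.g. from an Android device or emulator) to also execute the converted
models and compare their outputs against eager mode.
"""
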
import os
import ctypes
import torch
from typing import Tuple
from torch.backends._nnapi.prepare import convert_model_to_nnapi
from torch.testing._internal.common_utils import TestCase, run_tests


def qpt(t, scale, zero_point, dtype=torch.quint8):
    t = torch.tensor(t)
    return torch.quantize_per_tensor(t, scale, zero_point, dtype)


def nhwc(t):
    t = t.clone().contiguous(memory_format=torch.channels_last)
    t.nnapi_nhwc = True
    return t
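

# A quick sketch of what the helpers above produce:
#   qpt([1.0, 2.0], 0.25, 128) returns a quint8 tensor whose int_repr() is
#   [132, 136], since quantize_per_tensor stores round(x / scale) + zero_point.
#   nhwc(t) returns a channels_last copy tagged with `nnapi_nhwc = True`, which
#   the NNAPI converter reads as a request for NHWC operand layout.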


class TestNNAPI(TestCase):

    def setUp(self):
        # Use qnnpack to avoid saturation in fbgemm
        torch.backends.quantized.engine = 'qnnpack'

        libneuralnetworks_path = os.environ.get("LIBNEURALNETWORKS_PATH")
        if libneuralnetworks_path:
            ctypes.cdll.LoadLibrary(libneuralnetworks_path)
            print("Will attempt to run NNAPI models.")
            self.can_run_nnapi = True
        else:
            self.can_run_nnapi = False

    # Created for easy override by subclasses (e.g. TestNnapiBackend)
    def call_lowering_to_nnapi(self, traced_module, args):
        return convert_model_to_nnapi(traced_module, args)

    # Created for subclasses to set can_run_nnapi (e.g. TestNnapiBackend)
    def set_can_run_nnapi(self, can_run):
        self.can_run_nnapi = can_run

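    # check() is the shared harness: trace `module`, lower the trace to NNAPI,
    # and, when a driver was loaded in setUp, run both versions and compare.
    # `trace_args`/`convert_args` let tracing and conversion use different
    # shapes than the comparison inputs (zero-sized dims mark flexible dims);
    # `atol_rtol` loosens the comparison, and `limit` tolerates up to that many
    # mismatched quantized elements before re-checking with zero tolerance.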
    def check(
        self,
        module,
        arg_or_args,
        *,
        trace_args=None,
        convert_args=None,
        atol_rtol=None,
        limit=None,
        expected_memory_format=None,
    ):
        with torch.no_grad():
            if isinstance(arg_or_args, torch.Tensor):
                args = [arg_or_args]
            else:
                args = arg_or_args
            module.eval()
            traced = torch.jit.trace(module, trace_args or args)
            nnapi_module = self.call_lowering_to_nnapi(traced, convert_args or args)
            if not self.can_run_nnapi:
                # Only test that the model was converted successfully.
                return
            eager_output = module(*args)
            nnapi_output = nnapi_module(*args)
            kwargs = {}
            if atol_rtol is not None:
                kwargs["atol"] = atol_rtol[0]
                kwargs["rtol"] = atol_rtol[1]
            self.assertEqual(eager_output, nnapi_output, **kwargs)
            if limit is not None:
                mismatches = (
                    eager_output.int_repr().to(torch.int32)
                    - nnapi_output.int_repr().to(torch.int32)
                )
                if mismatches.count_nonzero() > limit:
                    # Too many mismatches. Re-run the check with no tolerance
                    # to get a nice message.
                    self.assertEqual(eager_output, nnapi_output, atol=0, rtol=0)
            if expected_memory_format:
                self.assertTrue(nnapi_output.is_contiguous(memory_format=expected_memory_format))

    def float_and_quant_and_nhwc(self, inp_float, scale, zero_point):
        # Return the four input variants (float/quantized x contiguous/NHWC)
        # exercised by the pooling and upsampling tests.
        torch.manual_seed(29)
        inp_quant = qpt(inp_float, scale, zero_point)
        return [
            ("float", inp_float),
            ("float-nhwc", nhwc(inp_float)),
            ("quant", inp_quant),
            ("quant-nhwc", nhwc(inp_quant)),
        ]

    def test_prelu(self):
        arg = torch.tensor([[1.0, -1.0, 2.0, -2.0]]).unsqueeze(-1).unsqueeze(-1)
        single_a = torch.nn.PReLU()
        self.check(single_a, arg)
        multi_a = torch.nn.PReLU(4)
        with torch.no_grad():
            multi_a.weight.copy_(torch.tensor([0.1, 0.2, 0.3, 0.4]))
        self.check(multi_a, nhwc(arg))

        # Test flexible size (zero-sized dims in convert_args mark flexible dims)
        self.check(
            multi_a,
            arg,
            trace_args=[torch.zeros(1, 4, 3, 3)],
            convert_args=[nhwc(torch.zeros(1, 4, 0, 0))],
        )

    def test_quantize(self):
        self.check(
            torch.nn.quantized.Quantize(0.25, 2, torch.quint8),
            nhwc(torch.tensor([[[[1.0]], [[2.0]]]])))

    def test_dequantize(self):
        self.check(
            torch.nn.quantized.DeQuantize(),
            nhwc(qpt([[[[1.0]], [[2.0]]]], 0.25, 2)))

    def test_unsqueeze(self):
        class UnsqueezeModule(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.dim = dim

            def forward(self, arg):
                return arg.unsqueeze(self.dim)

        self.check(UnsqueezeModule(-2), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(-1), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(0), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(1), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(2), torch.randn(4, 2, 2))

    def test_reshape(self):
        class ReshapeModule(torch.nn.Module):
            def __init__(self, shape):
                super().__init__()
                self.shape = shape

            def forward(self, arg):
                return arg.reshape(self.shape)

        self.check(
            ReshapeModule((2, 4)),
            torch.randn(4, 2, 1, 1))

        self.check(
            ReshapeModule((8, -1)),
            nhwc(torch.randn(4, 2, 1, 1)))

        with self.assertRaisesRegex(Exception, "target size"):
            self.check(
                ReshapeModule((2, 4)),
                nhwc(torch.randn(4, 2, 1, 1)))

    def test_flatten(self):
        for mod in [
            torch.nn.Flatten(),
            torch.nn.Flatten(start_dim=2, end_dim=3),
            torch.nn.Flatten(start_dim=2, end_dim=4),
            torch.nn.Flatten(start_dim=0, end_dim=-2),
            torch.nn.Flatten(start_dim=0, end_dim=4),
        ]:
            self.check(mod, torch.randn(4, 2, 1, 3, 7))

        # flex inputs
        self.check(
            torch.nn.Flatten(),
            torch.randn(4, 2, 1, 3, 7),
            convert_args=[torch.zeros(0, 2, 1, 3, 7)],
        )

        # channels last
        self.check(
            torch.nn.Flatten(),
            nhwc(torch.randn(2, 1, 4, 7)),
        )
        self.check(
            torch.nn.Flatten(),
            nhwc(torch.randn(2, 3, 1, 1)),
        )

        # Exceptions
        with self.assertRaisesRegex(Exception, "not supported on NHWC"):
            self.check(
                torch.nn.Flatten(),
                nhwc(torch.randn(1, 3, 4, 4)),
            )
        with self.assertRaisesRegex(Exception, "Flattening flexible dims is not supported yet"):
            self.check(torch.nn.Flatten(), torch.randn(4, 2, 0, 0, 7))
        with self.assertRaisesRegex(Exception, "Only 1 dim"):
            self.check(
                torch.nn.Flatten(start_dim=1, end_dim=-2),
                torch.randn(0, 2, 1, 3, 0))

    def test_slice(self):
        class SliceModule(torch.nn.Module):
            def __init__(self, start, stop, step):
                super().__init__()
                self.start = start
                self.stop = stop
                self.step = step

            def forward(self, t):
                return t[1:, self.start:self.stop:self.step, :]

        class SliceModule2(torch.nn.Module):
            def forward(self, t):
                return t[3:]

        self.check(
            SliceModule(1, 5, 2),
            torch.randn(4, 6, 2),
        )
        self.check(
            SliceModule2(),
            torch.randn(5),
        )

        # flex inputs
        self.check(
            SliceModule(1, 5, 2),
            torch.randn(4, 6, 2),
            convert_args=[torch.zeros(4, 6, 0)],
        )
        with self.assertRaisesRegex(Exception, "slice with flexible shape"):
            self.check(
                SliceModule(1, 5, 2),
                torch.randn(4, 6, 2),
                convert_args=[torch.zeros(0, 0, 0)],
            )

    def test_cat(self):
        class CatModule(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.dim = dim

            def forward(self, t1, t2):
                return torch.cat([t1, t2], self.dim)

        self.check(
            CatModule(0),
            [
                torch.randn(1, 2, 3, 3),
                torch.randn(2, 2, 3, 3),
            ])

        self.check(
            CatModule(1),
            [
                torch.randn(1, 2, 3, 3),
                torch.randn(1, 4, 3, 3),
            ])

        self.check(
            CatModule(1),
            [
                nhwc(torch.randn(1, 2, 3, 3)),
                nhwc(torch.randn(1, 4, 3, 3)),
            ])

        self.check(
            CatModule(1),
            [
                torch.randn(1, 2, 3, 3),
                torch.randn(1, 4, 3, 3),
            ],
            convert_args=[
                torch.zeros(0, 0, 0, 0),
                torch.zeros(0, 0, 0, 0),
            ])

    def test_pointwise_unary(self):
        for op in ["relu", "sigmoid"]:
            with self.subTest(op):
                class UnaryModule(torch.nn.Module):
                    def forward(self, arg):
                        if op == "relu":
                            return torch.nn.functional.relu(arg)
                        if op == "sigmoid":
                            return torch.sigmoid(arg)
                        raise Exception("Bad op")
                self.check(UnaryModule(), torch.tensor([-1.0, 1.0]))
                self.check(
                    UnaryModule(),
                    qpt(torch.tensor([-1.0, 1.0]), 1.0 / 256, 0),
                )

    def test_pointwise_binary(self):
        for op in ["add", "sub", "mul", "div"]:
            with self.subTest(op):
                class BinaryModule(torch.nn.Module):
                    def forward(self, lhs, rhs):
                        if op == "add":
                            return lhs + rhs
                        if op == "sub":
                            return lhs - rhs
                        if op == "mul":
                            return lhs * rhs
                        if op == "div":
                            return lhs / rhs
                        raise Exception("Bad op")

                self.check(
                    BinaryModule(),
                    [
                        torch.tensor([1.0, 2.0]),
                        torch.tensor([3.0, 4.0]),
                    ])

                self.check(
                    BinaryModule(),
                    [
                        torch.tensor([[1.0, 2.0]]),
                        torch.tensor([[3.0, 4.0], [5.0, 6.0]]),
                    ])

                with self.assertRaisesRegex(Exception, "Non-equal-rank broadcast"):
                    self.check(
                        BinaryModule(),
                        [
                            torch.tensor([1.0, 2.0]),
                            torch.tensor([[3.0, 4.0], [5.0, 6.0]]),
                        ])

    def test_pointwise_binary_const(self):
        const = torch.randn(1, 4, 6, 6)

        class ArgPlusConst(torch.nn.Module):
            def forward(self, arg):
                return arg + const

        class ConstPlusArg(torch.nn.Module):
            def forward(self, arg):
                return const + arg

        arg_contig = torch.randn(2, 4, 6, 6)
        arg_nhwc = nhwc(torch.randn(2, 4, 6, 6))

        for mod_class in [ArgPlusConst, ConstPlusArg]:
            for use_nhwc in [False, True]:
                with self.subTest(mod_class=mod_class.__name__, use_nhwc=use_nhwc):
                    arg = arg_nhwc if use_nhwc else arg_contig
                    memory_format = torch.channels_last if use_nhwc else torch.contiguous_format
                    self.check(mod_class(), arg,
                               expected_memory_format=memory_format)

    def test_hardtanh(self):
        inp = torch.tensor([-2.0, -0.5, 0.5, 2.0, 7.0])
        self.check(torch.nn.Hardtanh(), inp)
        self.check(torch.nn.Hardtanh(0.0, 6.0), inp)
        with self.assertRaisesRegex(Exception, "hardtanh with args"):
            self.check(torch.nn.Hardtanh(0.0, 5.0), inp)

    def test_softmax(self):
        inp = torch.tensor([[-2.0, -0.5], [0.5, 2.0]])
        self.check(torch.nn.Softmax(), inp)
        self.check(torch.nn.Softmax(dim=0), inp)
        # Test flexible size
        self.check(
            torch.nn.Softmax(),
            inp,
            convert_args=[torch.zeros(0, 0)],
        )

    def test_to(self):
        class ToCPU(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.prelu = torch.nn.PReLU()

            def forward(self, x):
                y = x.to("cpu")
                # add prelu since an input operand can't also be an output
                return self.prelu(y)

        arg = torch.randn(1, 2, 3, 3)
        self.check(ToCPU(), arg)
        # Test flexible size
        self.check(
            ToCPU(),
            arg,
            convert_args=[torch.zeros(1, 2, 0, 0)],
        )

    def test_detach(self):
        class DetachModule(torch.nn.Module):
            def forward(self, x):
                y = x.detach()
                return torch.nn.functional.relu(y)

        self.check(DetachModule(), torch.randn(1, 2, 3, 3))
        self.check(
            DetachModule(), torch.randn(1, 2, 3, 3),
            convert_args=[torch.zeros(1, 2, 0, 0)])

    def test_log_softmax(self):
        inp = torch.randn(3, 10)
        self.check(torch.nn.LogSoftmax(), inp)
        self.check(torch.nn.LogSoftmax(0), inp)

    def test_mean(self):
        class MeanModule(torch.nn.Module):
            def __init__(self, dim, keep=False):
                super().__init__()
                self.dim = dim
                self.keep = keep

            def forward(self, t):
                return torch.mean(t, dim=self.dim, keepdim=self.keep)

        self.check(MeanModule(0), torch.randn(2, 3))
        self.check(MeanModule(1), torch.randn(2, 3))
        self.check(MeanModule([2, 3]), torch.randn(2, 3, 6, 6))
        self.check(MeanModule([2, 3]), nhwc(torch.randn(2, 3, 6, 6)))
        self.check(MeanModule([-1, -2]), nhwc(torch.randn(2, 3, 6, 6)))
        self.check(MeanModule([-1, -2], keep=True), nhwc(torch.randn(2, 3, 6, 6)))

    def test_max_pool2d(self):
        for (name, inp) in self.float_and_quant_and_nhwc(torch.randn(2, 3, 12, 16), 0.03, 128):
            with self.subTest(name):
                self.check(torch.nn.MaxPool2d(2), inp)
                self.check(torch.nn.MaxPool2d((3, 4)), inp)
                self.check(torch.nn.MaxPool2d((3, 4), (1, 2)), inp)

    def test_avg_pool2d(self):
        for (name, inp) in self.float_and_quant_and_nhwc(torch.randn(2, 3, 12, 16), 0.03, 128):
            with self.subTest(name):
                atol_rtol = None
                limit = None
                convert_dims = (2, 3, 0, 0)
                convert_arg = torch.zeros(*convert_dims)

                for model in (
                        torch.nn.AvgPool2d(2),
                        torch.nn.AvgPool2d((3, 4)),
                        torch.nn.AvgPool2d((3, 4), (1, 2))):
                    if "quant" in name:
                        atol_rtol = (1, 0)
                        limit = model(inp).numel()
                        convert_arg = qpt(torch.zeros(*convert_dims), 1.0 / 16, 128)
                    if "nhwc" in name:
                        convert_arg = nhwc(convert_arg)

                    self.check(model, inp, atol_rtol=atol_rtol, limit=limit)
                    self.check(
                        model,
                        inp,
                        convert_args=[convert_arg],
                        atol_rtol=atol_rtol,
                        limit=limit,
                    )

    def test_adaptive_avg_pool2d(self):
        for (name, inp) in self.float_and_quant_and_nhwc(torch.randn(2, 3, 12, 16), 0.03, 128):
            with self.subTest(name):
                self.check(torch.nn.AdaptiveAvgPool2d((1, 1)), inp)
                with self.assertRaisesRegex(Exception, "with output size"):
                    self.check(torch.nn.AdaptiveAvgPool2d((2, 2)), inp)

    def test_upsample_nearest2d(self):
        convert_args = dict(self.float_and_quant_and_nhwc(torch.randn(2, 3, 0, 0), 0.03, 128))
        for (name, inp) in self.float_and_quant_and_nhwc(torch.randn(2, 3, 12, 16), 0.03, 128):
            with self.subTest(name):
                self.check(torch.nn.UpsamplingNearest2d(size=(16, 20)), inp)
                self.check(torch.nn.UpsamplingNearest2d(size=(24, 32)), inp)
                self.check(torch.nn.UpsamplingNearest2d(size=(36, 48)), inp)
                self.check(torch.nn.UpsamplingNearest2d(scale_factor=(1.5, 1.5)), inp)
                self.check(torch.nn.UpsamplingNearest2d(scale_factor=(2.0, 2.0)), inp)
                self.check(torch.nn.UpsamplingNearest2d(scale_factor=(3.0, 3.0)), inp)

                self.check(
                    torch.nn.UpsamplingNearest2d(size=(24, 32)), inp,
                    convert_args=[convert_args[name]],
                )
                self.check(
                    torch.nn.UpsamplingNearest2d(scale_factor=(2.0, 2.0)), inp,
                    convert_args=[convert_args[name]],
                )

    def test_linear(self):
        torch.manual_seed(29)
        self.check(torch.nn.Linear(16, 32), torch.randn(2, 16))
        self.check(
            torch.nn.Linear(16, 32), torch.randn(2, 16),
            convert_args=[torch.zeros(0, 16)])

    def test_conv2d(self):
        cases = [
            # in_ch, out_ch, kernel, stride, padding, groups, bias, input_dim, name
            ( 4, 8, (3, 3), 1, 0, 1, 1, (2, 4, 16, 16), "3x3"),        # noqa: E201,E241
            ( 4, 8, (3, 3), 1, 0, 1, 0, (2, 4, 16, 16), "3x3nobias"),  # noqa: E201,E241
            ( 4, 16, (3, 3), 1, 1, 1, 1, (2, 4, 16, 16), "3x3p1"),     # noqa: E201,E241
            ( 8, 8, (3, 3), 2, 0, 1, 1, (2, 8, 16, 16), "3x3s2"),      # noqa: E201,E241
            ( 4, 8, (5, 5), 1, 0, 1, 1, (2, 4, 16, 16), "5x5"),        # noqa: E201,E241
            ( 4, 4, (3, 3), 1, 0, 4, 1, (2, 4, 16, 16), "3x3dw"),      # noqa: E201,E241
            ( 8, 4, (1, 1), 1, 0, 1, 1, (2, 8, 16, 16), "1x1"),        # noqa: E201,E241
        ]

        for kind in ["float", "float-nhwc", "quant", "quant-nhwc"]:
            for case in cases:
                in_ch, out_ch, kernel, stride, padding, groups, bias, input_dim, name = case
                with self.subTest("{}-{}".format(kind, name)):
                    inp = torch.randn(input_dim)
                    model = torch.nn.Conv2d(in_ch, out_ch, kernel, stride, padding, groups=groups, bias=bool(bias))
                    output_size = model(inp).numel()
                    atol_rtol = None
                    limit = None
                    convert_dims = (0, in_ch, 0, 0)
                    convert_arg = torch.zeros(*convert_dims)

                    if "quant" in kind:
                        model = torch.nn.Sequential(model)
                        model.eval()
                        model.qconfig = torch.ao.quantization.get_default_qconfig('qnnpack')
                        model = torch.ao.quantization.prepare(model)
                        model(inp)
                        model = torch.ao.quantization.convert(model)
                        inp = qpt(inp, 1.0 / 16, 128)
                        # I've seen numerical differences between QNNPACK and NNAPI,
                        # but never more than 1 quantum, and never more than ~1% of
                        # the output in this test.
                        atol_rtol = (1, 0)
                        limit = output_size * 0.03
                        convert_arg = qpt(torch.zeros(*convert_dims), 1.0 / 16, 128)

                    if "nhwc" in kind:
                        inp = nhwc(inp)
                        convert_arg = nhwc(convert_arg)

                    self.check(model, inp, atol_rtol=atol_rtol, limit=limit)
                    self.check(
                        model,
                        inp,
                        convert_args=[convert_arg],
                        atol_rtol=atol_rtol,
                        limit=limit,
                    )

    def test_conv2d_transpose(self):
        torch.manual_seed(29)
        in_ch, out_ch, kernel = (5, 7, (2, 2))
        input_dim = (4, 5, 3, 3)
        convert_dims = input_dim[:2] + (0, 0)

        for kind in ["float", "float-nhwc", "quant", "quant-nhwc"]:
            with self.subTest(kind):
                inp = torch.randn(input_dim)
                model = torch.nn.ConvTranspose2d(in_ch, out_ch, kernel)
                output_size = model(inp).numel()
                atol_rtol = (0.0002, 0)
                limit = None
                convert_arg = torch.zeros(*convert_dims)

                if "quant" in kind:
                    model = torch.nn.quantized.ConvTranspose2d(in_ch, out_ch, kernel)
                    model.qconfig = torch.ao.quantization.get_default_qconfig('qnnpack')
                    inp = qpt(inp, 1.0 / 16, 128)
                    # I've seen numerical differences between QNNPACK and NNAPI,
                    # but never more than 1 quantum, and never more than ~10% of
                    # the output in this test.
                    atol_rtol = (1, 0)
                    limit = output_size * 0.1
                    convert_arg = qpt(convert_arg, 1.0 / 16, 128)

                if "nhwc" in kind:
                    inp = nhwc(inp)
                    convert_arg = nhwc(convert_arg)

                self.check(model, inp, atol_rtol=atol_rtol, limit=limit)
                self.check(
                    model,
                    inp,
                    convert_args=[convert_arg],
                    atol_rtol=atol_rtol,
                    limit=limit,
                )

    def test_qadd(self):
        func = torch.nn.quantized.QFunctional()
        func.scale = 0.5
        func.zero_point = 120
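        # With these output params, e.g., adding qpt([1.0], 0.25, 128) and
        # qpt([3.0], 0.25, 128) yields 4.0, which is stored as
        # round(4.0 / 0.5) + 120 = 128.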

        class AddMod(torch.nn.Module):
            def forward(self, lhs, rhs):
                return func.add(lhs, rhs)

        class AddReluMod(torch.nn.Module):
            def forward(self, lhs, rhs):
                return func.add_relu(lhs, rhs)

        class MulMod(torch.nn.Module):
            def forward(self, lhs, rhs):
                return func.mul(lhs, rhs)

        for (name, mod) in [("add", AddMod), ("add_relu", AddReluMod), ("mul", MulMod)]:
            with self.subTest(name):
                self.check(
                    mod(),
                    [
                        qpt([1.0, 2.0], 0.25, 128),
                        qpt([3.0, 4.0], 0.25, 128),
                    ])
                self.check(
                    mod(),
                    [
                        qpt([[1.0, 2.0]], 0.25, 128),
                        qpt([[3.0, 4.0]], 0.25, 128),
                    ],
                    convert_args=[
                        qpt([[1.0, 2.0]], 0.25, 128),
                        qpt(torch.zeros((1, 2)), 0.25, 128),
                    ],
                )
                self.check(
                    mod(),
                    [
                        qpt([[1.0, 2.0]], 0.25, 128),
                        qpt([[3.0, 4.0]], 0.25, 128),
                    ],
                    convert_args=[
                        qpt(torch.zeros((1, 2)), 0.25, 128),
                        qpt([[3.0, 4.0]], 0.25, 128),
                    ],
                )
                self.check(
                    mod(),
                    [
                        qpt([[1.0, 2.0]], 0.25, 128),
                        qpt([[3.0, 4.0]], 0.25, 128),
                    ],
                    convert_args=[
                        qpt(torch.zeros((1, 2)), 0.25, 128),
                        qpt(torch.zeros((1, 2)), 0.25, 128),
                    ],
                )
        # NOTE: NNAPI qadd supports broadcast, but PT does not.

    def test_qlinear(self):
        torch.manual_seed(29)
        weight = qpt(torch.randn(16, 32), 0.125, 0, torch.qint8)
        bias = torch.randn(16)
        mod = torch.nn.quantized.Linear(32, 16)
        mod.set_weight_bias(weight, bias)
        inp = qpt(torch.randn(2, 32), 0.05, 130, torch.quint8)
        self.check(mod, inp)

    def test_seblock_mul(self):
        class MulModel(torch.nn.Module):
            def forward(self, lhs, rhs):
                return lhs * rhs

        self.check(
            MulModel(),
            [
                nhwc(torch.randn(2, 3, 4, 4)),
                torch.randn(1, 3, 1, 1),
            ])

    def test_multi_output(self):
        class MultiModel(torch.nn.Module):
            def forward(self, lhs, rhs) -> Tuple[torch.Tensor, torch.Tensor]:
                the_sum = lhs + rhs
                the_diff = lhs - rhs
                return the_sum, the_diff

        self.check(MultiModel(), [torch.tensor([1.0, 2.0]), torch.tensor([1.0, 3.0])])


if __name__ == '__main__':
    run_tests()