# Owner(s): ["module: unknown"]

import unittest
from typing import Dict, List, Optional

import numpy as np
import torch
from torch import nn
from torch.testing._internal.common_utils import TestCase, run_tests

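# StaticModule is a thin Python wrapper over the Static Runtime binding
# exposed as torch._C._jit_to_static_module (an internal API that may change).
# A minimal usage sketch, mirroring the tests below (my_module and x are
# placeholders):
#
#     scripted = torch.jit.script(my_module)
#     sm = StaticModule(scripted)
#     out = sm(x)                    # run synchronously
#     fut = sm.runAsync((x,), {})    # returns a future
#     fut.wait()
#     result = fut.value()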
class StaticModule:
    def __init__(self, scripted):
        # `scripted` is a ScriptModule if it carries a `_c` handle; otherwise
        # (e.g. a scripted function) fall back to converting its graph.
        if hasattr(scripted, "_c"):
            self.static_module = torch._C._jit_to_static_module(scripted._c)
        else:
            self.static_module = torch._C._jit_to_static_module(scripted.graph)

    def __call__(self, *args, **kwargs):
        return self.static_module(*args, **kwargs)

    def benchmark(self, args, kwargs, warmup_runs, main_runs):
        self.static_module.benchmark(args, kwargs, warmup_runs, main_runs)

    def runAsync(self, args, kwargs):
        return self.static_module.runAsync(args, kwargs)

    def benchmark_individual_ops(self, args, kwargs, warmup_runs, main_runs):
        return self.static_module.benchmark_individual_ops(
            args, kwargs, warmup_runs, main_runs
        )


def linear_shim(
    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
    output = input.matmul(weight.t())
    if bias is not None:
        output += bias
    return output


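# Replace torch.nn.functional.linear module-wide with the plain matmul/add
# shim above, presumably so that scripted graphs decompose linear into simple
# aten ops that Static Runtime handles directly. Note the monkey-patch
# affects everything scripted after this point.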
torch.nn.functional.linear = linear_shim


class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()
        assert hid_dim % n_heads == 0
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        # self.dropout = nn.Dropout(dropout)
        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)

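    # forward implements standard scaled dot-product attention:
    #   energy = (Q @ K^T) / sqrt(head_dim), attention = softmax(energy),
    #   out = attention @ V, with heads split/merged via view + permute.
    # Dropout and masking are commented out below, presumably to keep the
    # graph simple and deterministic for the runtime tests.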
    def forward(self, query, key, value, mask):
        batch_size = query.shape[0]
        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
        # energy = energy.masked_fill(mask == 0, -1e10)
        attention = torch.softmax(energy, dim=-1)
        # x = torch.matmul(self.dropout(attention), V)
        x = torch.matmul(attention, V)
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.hid_dim)
        x = self.fc_o(x)
        return x, attention


# Taken from https://github.com/facebookresearch/dlrm/blob/master/dlrm_s_pytorch.py
def create_mlp(ln, sigmoid_layer):
    layers = nn.ModuleList()
    for i in range(0, len(ln) - 1):
        n = ln[i]
        m = ln[i + 1]

        LL = nn.Linear(int(n), int(m), bias=True)

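        # Xavier/Glorot-style initialization: the weight uses
        # std = sqrt(2 / (fan_in + fan_out)), the bias std = sqrt(1 / fan_out).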
        mean = 0.0  # std_dev = np.sqrt(variance)
        std_dev = np.sqrt(2 / (m + n))  # np.sqrt(1 / m) # np.sqrt(1 / n)
        W = np.random.normal(mean, std_dev, size=(m, n)).astype(np.float32)
        std_dev = np.sqrt(1 / m)  # np.sqrt(2 / (m + 1))
        bt = np.random.normal(mean, std_dev, size=m).astype(np.float32)
        LL.weight.data = torch.tensor(W, requires_grad=True)
        LL.bias.data = torch.tensor(bt, requires_grad=True)
        layers.append(LL)

        if i == sigmoid_layer:
            layers.append(nn.Sigmoid())
        else:
            layers.append(nn.ReLU())

    with torch.no_grad():
        s = torch.jit.script(torch.nn.Sequential(*layers))
    s.eval()
    return s


def trivial_graph(a, b, c):
    s = torch.tensor([[3, 3], [3, 3]])
    return a + b * c + s


def elementwise_square_addition(input1, input2):
    return input1 * input1 + input2 * input2


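# torch.jit.fork launches a function asynchronously and returns a
# torch.jit.Future; torch.jit.wait blocks on the future and returns its
# result (re-raising any exception set on it). The fork_wait_graph* helpers
# below exercise these ops in Static Runtime.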
def fork_wait_graph1(input1, input2):
    fut = torch.jit.fork(elementwise_square_addition, input1, input2)
    return torch.jit.wait(fut)


def fork_wait_graph2(input1, input2):
    fut = torch.jit.fork(loop_graph, input1, input2, 5)
    return torch.jit.wait(fut)


def fork_wait_graph3(input, iters: int):
    """
    Graph with multiple fork/wait operations.
    :param input: torch.Tensor input to the forked subgraph
    :param iters: number of future/wait pairs to be created
    """
    futures: List[torch.jit.Future[torch.Tensor]] = []
    for _ in range(iters):
        futures.append(torch.jit.fork(torch.neg, input))
    results = []
    for future in futures:
        results.append(torch.jit.wait(future))
    return torch.sum(torch.stack(results))


def fork_wait_graph4(input, num_forks: int, num_child_forks: int):
    """
    Graph with multi-level fork/wait operations.
    :param input: torch.Tensor input to the forked subgraph
    :param num_forks: number of top-level forks
    :param num_child_forks: number of child forks per parent fork
    """
    futures: List[torch.jit.Future[torch.Tensor]] = []
    for _ in range(num_forks):
        futures.append(torch.jit.fork(fork_wait_graph3, input, num_child_forks))
    results = []
    for future in futures:
        results.append(torch.jit.wait(future))
    return torch.sum(torch.stack(results))


def add_tensor(input1, input2):
    return input1 + input2


def fork_wait_graph_exception(input1, input2):
    fut = torch.jit.fork(add_tensor, input1, input2)
    return torch.jit.wait(fut)


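# loop_graph mixes out-of-place (c = c + b) and in-place (c *= 2, c -= a)
# updates inside a loop, a natural stress test for Static Runtime's
# aliasing and memory-reuse handling.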
def loop_graph(a, b, iters: int):
    c = a + b * 2
    for i in range(iters):
        c = c + b
        c *= 2
        c -= a
    return c


def output_graph(a, b, c, iters: int):
    s = torch.tensor([[3, 3], [3, 3]])
    k = a + b * c + s
    d: Dict[int, torch.Tensor] = {}
    for i in range(iters):
        d[i] = k + i
    return d


class SubModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = 11
        self.b = 2

    def forward(self, x):
        return self.a + self.b + x


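# SubModule2 and TestModule mutate attributes inside forward(), which
# TorchScript lowers to prim::SetAttr/prim::GetAttr; see test_attr below.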
class SubModule2(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = 12
        self.b = 2

    def forward(self, x):
        self.b = 30
        return self.a + self.b + x


class TestModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.sub1 = SubModule()
        self.sub2 = SubModule2()
        self.a = 3
        self.b = 4

    def forward(self, x):
        self.b = 20
        return self.sub1(x) + self.a + self.b + self.sub2(x)


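# Each test scripts a graph or module, runs it through the JIT interpreter
# for a reference output, then checks that StaticModule (or its runAsync
# API in the *_async variants) produces a matching result.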
class TestStaticModule(TestCase):
    def test_fork_wait_1(self):
        """
        Test a simple fork/wait operation in a graph; fork is called on
        a simple addition of the input tensors.
        """
        inp1 = torch.ones(5, 5)
        inp2 = torch.randn(5, 5)
        torch_graph = torch.jit.script(fork_wait_graph1)
        output_ref = torch_graph(inp1, inp2)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module(inp1, inp2)
        torch.testing.assert_close(output_test, output_ref)

    def test_fork_wait_1_async(self):
        """
        Test a simple fork/wait operation with the StaticRuntime
        runAsync API, which returns a future.
        """
        inp1 = torch.ones(5, 5)
        inp2 = torch.randn(5, 5)
        torch_graph = torch.jit.script(fork_wait_graph1)
        output_ref = torch_graph(inp1, inp2)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module.runAsync((inp1, inp2), {})
        output_test.wait()
        torch.testing.assert_close(output_test.value(), output_ref)

    def test_fork_wait_2(self):
        """
        Test a fork/wait operation whose forked subgraph is a loop
        performing a mix of operations.
        """
        inp1 = torch.randn(5, 5)
        inp2 = torch.randn(5, 5)
        torch_graph = torch.jit.script(fork_wait_graph2)
        output_ref = torch_graph(inp1, inp2)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module(inp1, inp2)
        torch.testing.assert_close(output_test, output_ref)

    def test_fork_wait_2_async(self):
        """
        Test a fork/wait operation on a loop subgraph with the
        StaticRuntime runAsync API, which returns a future.
        """
        inp1 = torch.randn(5, 5)
        inp2 = torch.randn(5, 5)
        torch_graph = torch.jit.script(fork_wait_graph2)
        output_ref = torch_graph(inp1, inp2)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module.runAsync((inp1, inp2), {})
        output_test.wait()
        torch.testing.assert_close(output_test.value(), output_ref)

    def test_fork_wait_3(self):
        """
        Test a graph containing multiple fork/wait operations.
        """
        input = torch.ones(3, 3)
        num_forks = 10
        torch_graph = torch.jit.script(fork_wait_graph3)
        output_ref = torch_graph(input, num_forks)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module(input, num_forks)
        torch.testing.assert_close(output_test, output_ref)

    def test_fork_wait_3_async(self):
        """
        Test a graph containing multiple fork/wait operations with the
        runAsync API, which returns a future.
        """
        input = torch.ones(3, 3)
        num_forks = 10
        torch_graph = torch.jit.script(fork_wait_graph3)
        output_ref = torch_graph(input, num_forks)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module.runAsync((input, num_forks), {})
        output_test.wait()
        torch.testing.assert_close(output_test.value(), output_ref)

    def test_fork_wait_4(self):
        """
        Test a graph containing multiple nested fork/wait operations.
        """
        input = torch.ones(3, 3)
        num_forks = 10
        num_child_forks = 10
        torch_graph = torch.jit.script(fork_wait_graph4)
        static_runtime_module = StaticModule(torch_graph)
        output_ref = torch_graph(input, num_forks, num_child_forks)
        output_test = static_runtime_module(input, num_forks, num_child_forks)
        torch.testing.assert_close(output_test, output_ref)

    def test_fork_wait_4_async(self):
        """
        Test a graph containing multiple nested fork/wait operations
        with the runAsync API, which returns a future.
        """
        input = torch.ones(3, 3)
        num_forks = 10
        num_child_forks = 10
        torch_graph = torch.jit.script(fork_wait_graph4)
        static_runtime_module = StaticModule(torch_graph)
        output_ref = torch_graph(input, num_forks, num_child_forks)
        output_test = static_runtime_module.runAsync(
            (input, num_forks, num_child_forks), {})
        output_test.wait()
        torch.testing.assert_close(output_test.value(), output_ref)

    def test_fork_wait_exception(self):
        """
        Test exception handling in fork/wait. add.Tensor is called on
        tensors with non-matching dims in the forked subgraph, and the
        exception raised by the subgraph is set on the future that
        prim::fork returns to the parent graph. The returned exception
        is checked for the substring expected_error_msg declared below.
        """
        # incompatible tensors for add due to shape mismatch
        input1 = torch.randn(4, 7)
        input2 = torch.randn(4, 5)
        torch_graph = torch.jit.script(fork_wait_graph_exception)
        try:
            static_runtime_module = StaticModule(torch_graph)
            output_test = static_runtime_module(input1, input2)
        except Exception as error:
            expected_error_msg = (
                "The size of tensor a (7) must match the size "
                "of tensor b (5) at non-singleton dimension 1"
            )
            # the test fails if the error does not contain the expected substring
            if str(error).find(expected_error_msg) == -1:
                raise RuntimeError(
                    "Tried execution of add.Tensors with incompatible shape. "
                    "Exception raised by forked runtime execution does "
                    f"not contain expected substring: \"{expected_error_msg}\""
                ) from error

    def test_fork_wait_exception_async(self):
        """
        Test exception handling in fork/wait with the runAsync API.
        add.Tensor is called on tensors with non-matching dims in the
        forked subgraph, and the exception raised by the subgraph is set
        on the future that prim::fork returns to the parent graph. The
        returned exception is checked for the substring
        expected_error_msg declared below.
        """
        # incompatible tensors for add due to shape mismatch
        input1 = torch.randn(4, 7)
        input2 = torch.randn(4, 5)
        torch_graph = torch.jit.script(fork_wait_graph_exception)
        try:
            static_runtime_module = StaticModule(torch_graph)
            output_test = static_runtime_module.runAsync(
                (input1, input2), {})
        except Exception as error:
            expected_error_msg = (
                "The size of tensor a (7) must match the size "
                "of tensor b (5) at non-singleton dimension 1"
            )
            # the test fails if the error does not contain the expected substring
            if str(error).find(expected_error_msg) == -1:
                raise RuntimeError(
                    "Tried execution of add.Tensors with incompatible shape. "
                    "Exception raised by forked runtime execution does "
                    f"not contain expected substring: \"{expected_error_msg}\""
                ) from error

    def test_multihead_attention_layer(self):
        HID_DIM = 256
        QUERY_LEN = 8
        BATCH_SIZE = 128
        LAYERS = 3
        HEADS = 8
        DROPOUT = 0.1
        device = torch.device("cpu")
        attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
        with torch.no_grad():
            src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
        src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)

        attention.eval()
        attention = torch.jit.script(attention)
        attention.eval()
        o_ref = attention(src, src, src, src_mask)

        attention_a = StaticModule(attention)
        o_test = attention_a(src, src, src, src_mask)
        o_test_kw = attention_a(src, src, value=src, mask=src_mask)

        for a, b in zip(o_ref, o_test):
            torch.testing.assert_close(a, b)

        for a, b in zip(o_ref, o_test_kw):
            torch.testing.assert_close(a, b)

    def test_multihead_attention_layer_benchmark(self):
        HID_DIM = 256
        QUERY_LEN = 8
        BATCH_SIZE = 128
        LAYERS = 3
        HEADS = 8
        DROPOUT = 0.1
        device = torch.device("cpu")
        attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
        with torch.no_grad():
            src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
        src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)

        attention.eval()
        attention = torch.jit.script(attention)
        attention_a = StaticModule(attention)

        attention_a.benchmark([src, src, src, src_mask], {}, 2, 2)
        metrics = attention_a.benchmark_individual_ops(
            [src, src, src, src_mask], {}, 2, 2
        )

    def test_mlp(self):
        # Arguments taken from the benchmark script, ./bench/dlrm_s_benchmark.sh
        ln_bot = [512, 512, 64]
        sigmoid_bot = -1
        ln_top = [100, 1024, 1024, 1024, 1]
        sigmoid_top = 3
        bot_l = create_mlp(ln_bot, sigmoid_bot)
        bot_l_acc = StaticModule(bot_l)
        top_l = create_mlp(ln_top, sigmoid_top)
        top_l_acc = StaticModule(top_l)
        with torch.no_grad():
            bot_inp = torch.randn(2048, 512)  # torch.Size([2048, 512])
            top_inp = torch.randn(2048, 100)  # torch.Size([2048, 100])
        ref_bot = bot_l(bot_inp)
        acc_bot = bot_l_acc(bot_inp)
        torch.testing.assert_close(acc_bot, ref_bot)
        ref_top = top_l(top_inp)
        acc_top = top_l_acc(top_inp)
        torch.testing.assert_close(acc_top, ref_top)
        for _ in range(5):
            with torch.no_grad():
                bot_inp = torch.randn(2048, 512)  # torch.Size([2048, 512])
                top_inp = torch.randn(2048, 100)  # torch.Size([2048, 100])
            ref_bot = bot_l(bot_inp)
            acc_bot = bot_l_acc(bot_inp)
            torch.testing.assert_close(acc_bot, ref_bot)
            ref_top = top_l(top_inp)
            acc_top = top_l_acc(top_inp)
            torch.testing.assert_close(acc_top, ref_top)

    def test_trivial_graph(self):
        s = torch.full((2, 2), 2)
        tg = torch.jit.script(trivial_graph)
        o_ref = tg(s, s, s)
        tg_a = StaticModule(tg)
        o_test = tg_a(s, s, s)
        torch.testing.assert_close(o_ref, o_test)

    def test_leaky_relu(self):
        s = torch.randn(5, 5)
        tg = torch.jit.script(nn.LeakyReLU(0.1))
        o_ref = tg(s)
        tg_a = StaticModule(tg)
        o_test = tg_a(s)
        torch.testing.assert_close(o_ref, o_test)

    def test_attr(self):
        """
        TorchScript IR of TestModule() after freezing:
        graph(%self : __torch__.test_static_runtime.___torch_mangle_0.TestModule,
              %x.1 : Tensor):
            %18 : int = prim::Constant[value=30]()
            %30 : int = prim::Constant[value=13]()
            %3 : int = prim::Constant[value=20]()
            %2 : int = prim::Constant[value=1]()
            %self.sub2.a : int = prim::Constant[value=12]()
            %self.a : int = prim::Constant[value=3]()
            = prim::SetAttr[name="b"](%self, %3)
            %17 : Tensor = aten::add(%x.1, %30, %2)
            %7 : Tensor = aten::add(%17, %self.a, %2)
            %b.1 : int = prim::GetAttr[name="b"](%self)
            %9 : Tensor = aten::add(%7, %b.1, %2)
            %sub2 : __torch__.test_static_runtime.___torch_mangle_2.SubModule2 = prim::GetAttr[name="sub2"](%self)
            = prim::SetAttr[name="b"](%sub2, %18)
            %b : int = prim::GetAttr[name="b"](%sub2)
            %22 : int = aten::add(%self.sub2.a, %b)
            %23 : Tensor = aten::add(%x.1, %22, %2)
            %12 : Tensor = aten::add(%9, %23, %2)
            return (%12)
        """
        # test the prim::SetAttr and prim::GetAttr impl in Static Runtime
        m = TestModule()

        m.eval()
        input = torch.randn(2, 2)
        output_s = m.forward(input)

        ms = torch.jit.script(m)
        sm = StaticModule(ms)
        output_sm = sm(input)
        torch.testing.assert_close(output_s, output_sm)
        sm.benchmark([input], {}, 2, 2)
        sm.benchmark_individual_ops([input], {}, 2, 2)
        sm.benchmark([], {"x": input}, 2, 2)
        sm.benchmark_individual_ops([], {"x": input}, 2, 2)

    @unittest.skip("Temporarily disabled")
    def test_fusion_trivial_graph(self):
        s = torch.full((2, 2), 2)
        tg = torch.jit.script(trivial_graph)
        o_ref = tg(s, s, s)
        torch._C._fuse_to_static_module(tg.graph)
        assert "StaticSubgraph" in str(tg.graph)
        o_test = tg(s, s, s)
        torch.testing.assert_close(o_ref, o_test)

    @unittest.skip("Temporarily disabled")
    def test_fusion_multihead_attention_layer(self):
        HID_DIM = 256
        QUERY_LEN = 8
        BATCH_SIZE = 128
        LAYERS = 3
        HEADS = 8
        DROPOUT = 0.1
        device = torch.device("cpu")
        attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
        with torch.no_grad():
            src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
        src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)

        attention.eval()
        attention = torch.jit.script(attention)
        attention.eval()
        o_ref = attention(src, src, src, src_mask)

        torch._C._fuse_to_static_module(attention._c)
        o_test = attention(src, src, src, src_mask)

        for a, b in zip(o_ref, o_test):
            torch.testing.assert_close(a, b)

    @unittest.skip("Temporarily disabled")
    def test_fusion_loop(self):
        a = torch.randn(5, 5)
        b = torch.randn(5, 5)
        c = 4
        lg = torch.jit.script(loop_graph)
        o_ref = lg(a, b, c)
        torch._C._fuse_to_static_module(lg.graph)
        assert "StaticSubgraph" in str(lg.graph)
        o_test = lg(a, b, c)
        torch.testing.assert_close(o_ref, o_test)

    @unittest.skip("Temporarily disabled")
    def test_fusion_outputs(self):
        a = torch.randn(2, 2)
        b = torch.randn(2, 2)
        c = 4
        og = torch.jit.script(output_graph)
        o_ref = og(a, b, b, c)
        torch._C._fuse_to_static_module(og.graph)
        assert "StaticSubgraph" in str(og.graph)
        o_test = og(a, b, b, c)
        for i in o_ref.keys():
            torch.testing.assert_close(o_ref[i], o_test[i])

    def test_create_object(self):
        class Foo:  # noqa: B903
            def __init__(self, x: torch.Tensor) -> None:
                self.x = x

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, y: torch.Tensor) -> torch.Tensor:
                foo = Foo(y)
                return y * foo.x

        mod = torch.jit.script(Mod()).eval()
        y = torch.randn((1,))
        expected = mod(y)

        static_mod = StaticModule(torch.jit.freeze(mod))
        actual = static_mod(y)

        self.assertEqual(expected, actual)


if __name__ == "__main__":
    run_tests()