Jane Xu | 6259601 | 2021-10-29 12:40:39 -0700 | [diff] [blame] | 1 | # Owner(s): ["module: unknown"] |
| 2 | |
Hao Lu | ccd0977 | 2021-07-10 14:04:48 -0700 | [diff] [blame] | 3 | import unittest |
| 4 | from typing import Dict, Optional |
| 5 | |
Hao Lu | e8d8de3 | 2020-10-06 20:52:29 -0700 | [diff] [blame] | 6 | import numpy as np |
Bram Wasti | ada8404 | 2020-08-12 13:02:29 -0700 | [diff] [blame] | 7 | import torch |
| 8 | from torch import nn |
Hao Lu | 8538a79 | 2020-08-28 23:17:17 -0700 | [diff] [blame] | 9 | from torch.testing._internal.common_utils import TestCase, run_tests |
Akshay Parashar | 720cb50 | 2022-06-03 23:39:04 +0000 | [diff] [blame] | 10 | from typing import List |
Bram Wasti | ada8404 | 2020-08-12 13:02:29 -0700 | [diff] [blame] | 11 | |
class StaticModule:
    """Thin Python wrapper around a ``torch._C`` static module.

    The wrapped object is built with ``torch._C._jit_to_static_module`` and
    every public method simply forwards to it.
    """

    def __init__(self, scripted):
        # An object exposing ``_c`` is treated as a scripted nn.Module;
        # anything else is converted from its ``graph`` attribute.
        source = scripted._c if hasattr(scripted, "_c") else scripted.graph
        self.static_module = torch._C._jit_to_static_module(source)

    def __call__(self, *args, **kwargs):
        """Run the static module with the given positional/keyword inputs."""
        return self.static_module(*args, **kwargs)

    def benchmark(self, args, kwargs, warmup_runs, main_runs):
        """Forward a whole-graph benchmark request; nothing is returned."""
        self.static_module.benchmark(args, kwargs, warmup_runs, main_runs)

    def runAsync(self, args, kwargs):
        """Forward to the static module's ``runAsync`` and return its result."""
        return self.static_module.runAsync(args, kwargs)

    def benchmark_individual_ops(self, args, kwargs, warmup_runs, main_runs):
        """Forward a per-op benchmark request and return whatever it reports."""
        runner = self.static_module.benchmark_individual_ops
        return runner(args, kwargs, warmup_runs, main_runs)
Bram Wasti | ada8404 | 2020-08-12 13:02:29 -0700 | [diff] [blame] | 33 | |
Bram Wasti | a475613 | 2020-09-14 12:33:02 -0700 | [diff] [blame] | 34 | |
def linear_shim(
    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Compute ``input @ weight.T`` and, if ``bias`` is given, add it in place.

    Args:
        input: tensor multiplied on the left.
        weight: tensor whose transpose is multiplied on the right.
        bias: optional tensor added to the product.

    Returns:
        The (possibly biased) matmul result.
    """
    projected = torch.matmul(input, weight.t())
    if bias is None:
        return projected
    # In-place add on the freshly created product, as in the plain formulation.
    projected += bias
    return projected
Bram Wasti | a475613 | 2020-09-14 12:33:02 -0700 | [diff] [blame] | 43 | |
| 44 | |
# Module-level monkeypatch: from import time on, torch.nn.functional.linear
# refers to linear_shim (a matmul followed by an optional in-place bias add).
# NOTE(review): presumably done so that linear layers in these tests lower to
# matmul/add ops -- confirm the motivation before removing.
torch.nn.functional.linear = linear_shim
| 46 | |
Bram Wasti | ada8404 | 2020-08-12 13:02:29 -0700 | [diff] [blame] | 47 | |
class MultiHeadAttentionLayer(nn.Module):
    """Multi-head dot-product attention used as a test model.

    ``dropout`` (constructor) and ``mask`` (forward) are accepted but not
    applied: the corresponding operations are disabled in this model.
    """

    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()
        # The hidden dimension must split evenly across the heads.
        assert hid_dim % n_heads == 0
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        # Projection layers are created in q, k, v, o order so that seeded
        # parameter initialisation stays the same.
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        # Dropout is intentionally not instantiated.
        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)

    def forward(self, query, key, value, mask):
        """Return ``(output, attention)`` for the given query/key/value.

        ``mask`` is unused; no masking is applied to the attention scores.
        """
        batch_size = query.shape[0]

        # Linear projections.
        q_proj = self.fc_q(query)
        k_proj = self.fc_k(key)
        v_proj = self.fc_v(value)

        # Split the hidden dimension into heads: (batch, heads, len, head_dim).
        q_heads = q_proj.view(batch_size, -1, self.n_heads, self.head_dim)
        q_heads = q_heads.permute(0, 2, 1, 3)
        k_heads = k_proj.view(batch_size, -1, self.n_heads, self.head_dim)
        k_heads = k_heads.permute(0, 2, 1, 3)
        v_heads = v_proj.view(batch_size, -1, self.n_heads, self.head_dim)
        v_heads = v_heads.permute(0, 2, 1, 3)

        # Scaled scores followed by a softmax over the key axis.
        scores = torch.matmul(q_heads, k_heads.permute(0, 1, 3, 2)) / self.scale
        attention = torch.softmax(scores, dim=-1)

        # Weighted values, then merge the heads back into the hidden dimension.
        context = torch.matmul(attention, v_heads)
        context = context.permute(0, 2, 1, 3).contiguous()
        merged = context.view(batch_size, -1, self.hid_dim)
        projected = self.fc_o(merged)
        return projected, attention
| 79 | |
| 80 | |
| 81 | # Taken from https://github.com/facebookresearch/dlrm/blob/master/dlrm_s_pytorch.py |
def create_mlp(ln, sigmoid_layer):
    """Build a scripted, eval-mode MLP from a list of layer widths.

    Consecutive entries of ``ln`` give the input/output width of each linear
    layer. Weights and biases are drawn from NumPy normal distributions, and
    every linear layer is followed by a ReLU, except the one at index
    ``sigmoid_layer`` which is followed by a Sigmoid.

    Returns:
        The ``torch.jit.script``-ed ``nn.Sequential`` in eval mode.
    """
    modules = nn.ModuleList()
    for idx, (n, m) in enumerate(zip(ln, ln[1:])):
        linear = nn.Linear(int(n), int(m), bias=True)

        # Weight is sampled before bias so the NumPy RNG stream is consumed
        # in a fixed order.
        weight_std = np.sqrt(2 / (m + n))
        weight = np.random.normal(0.0, weight_std, size=(m, n)).astype(np.float32)
        bias_std = np.sqrt(1 / m)
        bias = np.random.normal(0.0, bias_std, size=m).astype(np.float32)

        linear.weight.data = torch.tensor(weight, requires_grad=True)
        linear.bias.data = torch.tensor(bias, requires_grad=True)
        modules.append(linear)

        activation = nn.Sigmoid() if idx == sigmoid_layer else nn.ReLU()
        modules.append(activation)

    with torch.no_grad():
        scripted = torch.jit.script(torch.nn.Sequential(*modules))
    scripted.eval()
    return scripted
| 108 | |
| 109 | |
def trivial_graph(a, b, c):
    """Return ``a + b * c`` plus a constant 2x2 tensor filled with threes."""
    offset = torch.tensor([[3, 3], [3, 3]])
    product = b * c
    return a + product + offset
| 113 | |
def elementwise_square_addition(input1, input2):
    """Return the element-wise sum of the squares of both inputs."""
    squared1 = input1 * input1
    squared2 = input2 * input2
    return squared1 + squared2
| 116 | |
def fork_wait_graph1(input1, input2):
    """Fork ``elementwise_square_addition`` on the inputs and wait for it."""
    pending = torch.jit.fork(elementwise_square_addition, input1, input2)
    result = torch.jit.wait(pending)
    return result
| 120 | |
def fork_wait_graph2(input1, input2):
    """Fork ``loop_graph`` with five iterations on the inputs and wait for it."""
    pending = torch.jit.fork(loop_graph, input1, input2, 5)
    result = torch.jit.wait(pending)
    return result
| 124 | |
def fork_wait_graph3(input, iters: int):
    """
    graph with multiple fork/wait operations
    :param input: torch.tensor input to forked subgraph
    :param iters: number of future/wait pairs to be created
    """
    pending : List[torch.jit.Future[torch.Tensor]] = []
    # Launch every fork first ...
    for _ in range(iters):
        pending.append(torch.jit.fork(torch.neg, input))
    # ... then collect the results in launch order.
    negated = []
    for fut in pending:
        negated.append(torch.jit.wait(fut))
    stacked = torch.stack(negated)
    return torch.sum(stacked)
Hao Lu | ccd0977 | 2021-07-10 14:04:48 -0700 | [diff] [blame] | 138 | |
def fork_wait_graph4(input, num_forks: int, num_child_forks: int):
    """
    graph with multi-level fork/wait operations
    :param input: torch.tensor input to forked subgraph
    :param num_forks: number of top level forks
    :param num_child_forks: number of child forks per parent fork
    """
    pending : List[torch.jit.Future[torch.Tensor]] = []
    # Each top-level fork runs fork_wait_graph3, which forks again internally.
    for _ in range(num_forks):
        pending.append(torch.jit.fork(fork_wait_graph3, input, num_child_forks))
    partial_sums = []
    for fut in pending:
        partial_sums.append(torch.jit.wait(fut))
    stacked = torch.stack(partial_sums)
    return torch.sum(stacked)
| 153 | |
def add_tensor(input1, input2):
    """Return ``input1 + input2``."""
    total = input1 + input2
    return total
| 156 | |
def fork_wait_graph_exception(input1, input2):
    """Fork ``add_tensor`` on the inputs and wait for the result."""
    pending = torch.jit.fork(add_tensor, input1, input2)
    result = torch.jit.wait(pending)
    return result
| 160 | |
def loop_graph(a, b, iters: int):
    """Start from ``a + 2 * b`` and apply ``(acc + b) * 2 - a`` ``iters`` times."""
    acc = a + b * 2
    for _ in range(iters):
        # Rebinding creates a fresh value, so the in-place updates below
        # never touch ``a`` or ``b``.
        acc = acc + b
        acc *= 2
        acc -= a
    return acc
| 168 | |
Hao Lu | ccd0977 | 2021-07-10 14:04:48 -0700 | [diff] [blame] | 169 | |
def output_graph(a, b, c, iters: int):
    """Return a dict mapping each ``i`` in ``range(iters)`` to ``a + b * c + 3 + i``.

    The constant ``3`` is applied as a 2x2 tensor of threes.
    """
    offset = torch.tensor([[3, 3], [3, 3]])
    base = a + b * c + offset
    outputs: Dict[int, torch.Tensor] = {}
    for step in range(iters):
        outputs[step] = base + step
    return outputs
Bram Wasti | a475613 | 2020-09-14 12:33:02 -0700 | [diff] [blame] | 177 | |
Hao Lu | ccd0977 | 2021-07-10 14:04:48 -0700 | [diff] [blame] | 178 | |
class SubModule(nn.Module):
    """Module with two integer attributes that are added to the input."""

    def __init__(self):
        super().__init__()
        # Plain Python ints, only ever read in forward().
        self.a = 11
        self.b = 2

    def forward(self, x):
        """Return ``a + b + x``."""
        return self.a + self.b + x
| 187 | |
| 188 | |
class SubModule2(nn.Module):
    """Module whose forward pass overwrites attribute ``b`` before using it."""

    def __init__(self):
        super().__init__()
        self.a = 12
        self.b = 2

    def forward(self, x):
        """Set ``b`` to 30, then return ``a + b + x``."""
        # The attribute write is deliberate: forward() mutates module state.
        self.b = 30
        return self.a + self.b + x
| 198 | |
| 199 | |
class TestModule(nn.Module):
    """Parent module combining ``SubModule`` and ``SubModule2`` with own attributes."""

    def __init__(self):
        super().__init__()
        self.sub1 = SubModule()
        self.sub2 = SubModule2()
        self.a = 3
        self.b = 4

    def forward(self, x):
        """Set ``b`` to 20, then return ``sub1(x) + a + b + sub2(x)``."""
        # The attribute write is deliberate: forward() mutates module state.
        self.b = 20
        return self.sub1(x) + self.a + self.b + self.sub2(x)
| 211 | |
| 212 | |
class TestStaticModule(TestCase):
    """Tests that run TorchScript graphs/modules through ``StaticModule``.

    Most tests compare the ``StaticModule`` output with the output of the
    scripted (or eager) reference on the same inputs.
    """

    def test_fork_wait_1(self):
        """
        Test Case: To test simple fork/wait operation in a graph
        fork is called on simple addition operation on input tensors
        """
        inp1 = torch.ones(5, 5)
        inp2 = torch.randn(5, 5)
        torch_graph = torch.jit.script(fork_wait_graph1)
        output_ref = torch_graph(inp1, inp2)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module(inp1, inp2)
        torch.testing.assert_close(output_test, output_ref)

    def test_fork_wait_1_async(self):
        """
        Test Case: To test simple fork/wait operation with
        StaticRuntime runAsync API returning future
        """
        inp1 = torch.ones(5, 5)
        inp2 = torch.randn(5, 5)
        torch_graph = torch.jit.script(fork_wait_graph1)
        output_ref = torch_graph(inp1, inp2)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module.runAsync((inp1, inp2), {})
        output_test.wait()
        torch.testing.assert_close(output_test.value(), output_ref)

    def test_fork_wait_2(self):
        """
        Test Case: To test fork/wait operation in a graph on
        a loop subgraph performing mix of operations
        """
        inp1 = torch.randn(5, 5)
        inp2 = torch.randn(5, 5)
        torch_graph = torch.jit.script(fork_wait_graph2)
        output_ref = torch_graph(inp1, inp2)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module(inp1, inp2)
        torch.testing.assert_close(output_test, output_ref)

    def test_fork_wait_2_async(self):
        """
        Test Case: To test fork/wait operation on a loop
        subgraph with StaticRuntime runAsync API returning future
        """
        inp1 = torch.randn(5, 5)
        inp2 = torch.randn(5, 5)
        torch_graph = torch.jit.script(fork_wait_graph2)
        output_ref = torch_graph(inp1, inp2)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module.runAsync((inp1, inp2), {})
        output_test.wait()
        torch.testing.assert_close(output_test.value(), output_ref)

    def test_fork_wait_3(self):
        """
        Test Case: To test fork/wait operation in a graph on
        having multiple fork/wait operations
        """
        input = torch.ones(3, 3)
        num_forks = 10
        torch_graph = torch.jit.script(fork_wait_graph3)
        output_ref = torch_graph(input, num_forks)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module(input, num_forks)
        torch.testing.assert_close(output_test, output_ref)

    def test_fork_wait_3_async(self):
        """
        Test Case: To test fork/wait operation in a graph with
        multiple fork/wait operations on runAsync API returning future
        """
        input = torch.ones(3, 3)
        num_forks = 10
        torch_graph = torch.jit.script(fork_wait_graph3)
        output_ref = torch_graph(input, num_forks)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module.runAsync((input, num_forks), {})
        output_test.wait()
        torch.testing.assert_close(output_test.value(), output_ref)

    def test_fork_wait_4(self):
        """
        Test Case: To test fork/wait operation in a graph on
        multiple nested fork/wait operations
        """
        input = torch.ones(3, 3)
        num_forks = 10
        num_child_forks = 10
        torch_graph = torch.jit.script(fork_wait_graph4)
        static_runtime_module = StaticModule(torch_graph)
        output_ref = torch_graph(input, num_forks, num_child_forks)
        output_test = static_runtime_module(input, num_forks, num_child_forks)
        torch.testing.assert_close(output_test, output_ref)

    def test_fork_wait_4_async(self):
        """
        Test Case: To test fork/wait operation in a graph with multiple
        nested fork/wait operations on runAsync API returning future
        """
        input = torch.ones(3, 3)
        num_forks = 10
        num_child_forks = 10
        torch_graph = torch.jit.script(fork_wait_graph4)
        static_runtime_module = StaticModule(torch_graph)
        output_ref = torch_graph(input, num_forks, num_child_forks)
        output_test = static_runtime_module.runAsync(
            (input, num_forks, num_child_forks), {})
        output_test.wait()
        torch.testing.assert_close(output_test.value(), output_ref)

    def test_fork_wait_exception(self):
        """
        Test Case: To test exception handling in fork/wait
        operation. Add.Tensor op is called for tensors with
        non-matching dims on the forked subgraph and the
        exception raised by subgraph is set on future returned
        by prim::fork to parent graph. Returned exception is
        checked for substring expected_error_msg as declared below
        """
        # incompatible tensors for add due to shape mismatch
        input1 = torch.randn(4, 7)
        input2 = torch.randn(4, 5)
        torch_graph = torch.jit.script(fork_wait_graph_exception)
        static_runtime_module = StaticModule(torch_graph)
        expected_error_msg = (
            "The size of tensor a (7) must match the size "
            "of tensor b (5) at non-singleton dimension 1"
        )
        # assertRaises makes the test fail when no exception is raised at
        # all; a bare try/except would pass silently in that case.
        with self.assertRaises(Exception) as raised:
            static_runtime_module(input1, input2)
        # test fails if error does not contain expected substr
        self.assertIn(
            expected_error_msg,
            str(raised.exception),
            "Tried execution of add.Tensors with incompatible shape. "
            "Exception raised by forked runtime execution does "
            f"not contain expected substring: \"{expected_error_msg}\"",
        )

    def test_fork_wait_exception_async(self):
        """
        Test Case: To test exception handling in fork/wait
        operation with runAsync API. Add.Tensor op is called for
        tensors with non-matching dims on the forked subgraph
        and the exception raised by subgraph is set on future returned
        by prim::fork to parent graph. Returned exception is
        checked for substring expected_error_msg as declared below
        """
        # incompatible tensors for add due to shape mismatch
        input1 = torch.randn(4, 7)
        input2 = torch.randn(4, 5)
        torch_graph = torch.jit.script(fork_wait_graph_exception)
        static_runtime_module = StaticModule(torch_graph)
        expected_error_msg = (
            "The size of tensor a (7) must match the size "
            "of tensor b (5) at non-singleton dimension 1"
        )
        with self.assertRaises(Exception) as raised:
            output_test = static_runtime_module.runAsync(
                (input1, input2), {})
            # NOTE(review): the error may only surface once the returned
            # future is consumed, so wait on it and read its value here.
            output_test.wait()
            output_test.value()
        # test fails if error does not contain expected substr
        self.assertIn(
            expected_error_msg,
            str(raised.exception),
            "Tried execution of add.Tensors with incompatible shape. "
            "Exception raised by forked runtime execution does "
            f"not contain expected substring: \"{expected_error_msg}\"",
        )

    def test_multihead_attention_layer(self):
        """Compare StaticModule and scripted attention, positional and keyword calls."""
        HID_DIM = 256
        QUERY_LEN = 8
        BATCH_SIZE = 128
        HEADS = 8
        DROPOUT = 0.1
        device = torch.device("cpu")
        attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
        with torch.no_grad():
            src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
        src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)

        attention.eval()
        attention = torch.jit.script(attention)
        attention.eval()
        o_ref = attention(src, src, src, src_mask)

        attention_a = StaticModule(attention)
        o_test = attention_a(src, src, src, src_mask)
        o_test_kw = attention_a(src, src, value=src, mask=src_mask)

        for a, b in zip(o_ref, o_test):
            torch.testing.assert_close(a, b)

        for a, b in zip(o_ref, o_test_kw):
            torch.testing.assert_close(a, b)

    def test_multihead_attention_layer_benchmark(self):
        """Smoke-test both benchmark entry points on the attention model."""
        HID_DIM = 256
        QUERY_LEN = 8
        BATCH_SIZE = 128
        HEADS = 8
        DROPOUT = 0.1
        device = torch.device("cpu")
        attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
        with torch.no_grad():
            src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
        src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)

        attention.eval()
        attention = torch.jit.script(attention)
        attention_a = StaticModule(attention)

        # Only checks that the calls complete; results are not inspected.
        attention_a.benchmark([src, src, src, src_mask], {}, 2, 2)
        attention_a.benchmark_individual_ops(
            [src, src, src, src_mask], {}, 2, 2
        )

    def test_mlp(self):
        """Compare StaticModule and scripted MLPs on several random inputs."""
        # Arguments taken from benchmark script, ./bench/dlrm_s_benchmark.sh
        ln_bot = [512, 512, 64]
        sigmoid_bot = -1
        ln_top = [100, 1024, 1024, 1024, 1]
        sigmoid_top = 3
        bot_l = create_mlp(ln_bot, sigmoid_bot)
        bot_l_acc = StaticModule(bot_l)
        top_l = create_mlp(ln_top, sigmoid_top)
        top_l_acc = StaticModule(top_l)
        # One initial comparison plus five repeats, each on fresh inputs.
        for _ in range(6):
            with torch.no_grad():
                bot_inp = torch.randn(2048, 512)  # torch.Size([2048, 512])
                top_inp = torch.randn(2048, 100)  # torch.Size([2048, 100])
            ref_bot = bot_l(bot_inp)
            acc_bot = bot_l_acc(bot_inp)
            torch.testing.assert_close(acc_bot, ref_bot)
            ref_top = top_l(top_inp)
            acc_top = top_l_acc(top_inp)
            torch.testing.assert_close(acc_top, ref_top)

    def test_trivial_graph(self):
        """Compare StaticModule and scripted results for ``trivial_graph``."""
        s = torch.full((2, 2), 2)
        tg = torch.jit.script(trivial_graph)
        o_ref = tg(s, s, s)
        tg_a = StaticModule(tg)
        o_test = tg_a(s, s, s)
        torch.testing.assert_close(o_ref, o_test)

    def test_leaky_relu(self):
        """Compare StaticModule and scripted results for ``nn.LeakyReLU``."""
        s = torch.randn(5, 5)
        tg = torch.jit.script(nn.LeakyReLU(0.1))
        o_ref = tg(s)
        tg_a = StaticModule(tg)
        o_test = tg_a(s)
        torch.testing.assert_close(o_ref, o_test)

    def test_attr(self):
        """
        TorchScript IR of TestModule() after freezing:
        graph(%self : __torch__.test_static_runtime.___torch_mangle_0.TestModule,
              %x.1 : Tensor):
            %18 : int = prim::Constant[value=30]()
            %30 : int = prim::Constant[value=13]()
            %3 : int = prim::Constant[value=20]()
            %2 : int = prim::Constant[value=1]()
            %self.sub2.a : int = prim::Constant[value=12]()
            %self.a : int = prim::Constant[value=3]()
            = prim::SetAttr[name="b"](%self, %3)
            %17 : Tensor = aten::add(%x.1, %30, %2)
            %7 : Tensor = aten::add(%17, %self.a, %2)
            %b.1 : int = prim::GetAttr[name="b"](%self)
            %9 : Tensor = aten::add(%7, %b.1, %2)
            %sub2 : __torch__.test_static_runtime.___torch_mangle_2.SubModule2 = prim::GetAttr[name="sub2"](%self)
            = prim::SetAttr[name="b"](%sub2, %18)
            %b : int = prim::GetAttr[name="b"](%sub2)
            %22 : int = aten::add(%self.sub2.a, %b)
            %23 : Tensor = aten::add(%x.1, %22, %2)
            %12 : Tensor = aten::add(%9, %23, %2)
            return (%12)
        """
        # test prim::SetAttr and prim::GetAttr impl in Static Runtime
        m = TestModule()

        m.eval()
        input = torch.randn(2, 2)
        output_s = m.forward(input)

        ms = torch.jit.script(m)
        sm = StaticModule(ms)
        output_sm = sm(input)
        torch.testing.assert_close(output_s, output_sm)
        # Benchmarks are exercised with positional and with keyword inputs.
        sm.benchmark([input], {}, 2, 2)
        sm.benchmark_individual_ops([input], {}, 2, 2)
        sm.benchmark([], {"x": input}, 2, 2)
        sm.benchmark_individual_ops([], {"x": input}, 2, 2)

    @unittest.skip("Temporarily disabled")
    def test_fusion_trivial_graph(self):
        """Fuse ``trivial_graph`` into a static subgraph and compare results."""
        s = torch.full((2, 2), 2)
        tg = torch.jit.script(trivial_graph)
        o_ref = tg(s, s, s)
        torch._C._fuse_to_static_module(tg.graph)
        self.assertIn("StaticSubgraph", str(tg.graph))
        o_test = tg(s, s, s)
        torch.testing.assert_close(o_ref, o_test)

    @unittest.skip("Temporarily disabled")
    def test_fusion_multihead_attention_layer(self):
        """Fuse the scripted attention module in place and compare results."""
        HID_DIM = 256
        QUERY_LEN = 8
        BATCH_SIZE = 128
        HEADS = 8
        DROPOUT = 0.1
        device = torch.device("cpu")
        attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
        with torch.no_grad():
            src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
        src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)

        attention.eval()
        attention = torch.jit.script(attention)
        attention.eval()
        o_ref = attention(src, src, src, src_mask)

        torch._C._fuse_to_static_module(attention._c)
        o_test = attention(src, src, src, src_mask)

        for a, b in zip(o_ref, o_test):
            torch.testing.assert_close(a, b)

    @unittest.skip("Temporarily disabled")
    def test_fusion_loop(self):
        """Fuse ``loop_graph`` into a static subgraph and compare results."""
        a = torch.randn(5, 5)
        b = torch.randn(5, 5)
        c = 4
        lg = torch.jit.script(loop_graph)
        o_ref = lg(a, b, c)
        torch._C._fuse_to_static_module(lg.graph)
        self.assertIn("StaticSubgraph", str(lg.graph))
        o_test = lg(a, b, c)
        torch.testing.assert_close(o_ref, o_test)

    @unittest.skip("Temporarily disabled")
    def test_fusion_outputs(self):
        """Fuse ``output_graph`` and compare every entry of the returned dict."""
        a = torch.randn(2, 2)
        b = torch.randn(2, 2)
        c = 4
        og = torch.jit.script(output_graph)
        o_ref = og(a, b, b, c)
        torch._C._fuse_to_static_module(og.graph)
        self.assertIn("StaticSubgraph", str(og.graph))
        o_test = og(a, b, b, c)
        for i in o_ref:
            torch.testing.assert_close(o_ref[i], o_test[i])

    def test_create_object(self):
        """Run a frozen module that constructs a user-defined object in forward."""
        class Foo:  # noqa: B903
            def __init__(self, x: torch.Tensor) -> None:
                self.x = x

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, y: torch.Tensor) -> torch.Tensor:
                foo = Foo(y)
                return y * foo.x

        mod = torch.jit.script(Mod()).eval()
        y = torch.randn((1, ))
        expected = mod(y)

        static_mod = StaticModule(torch.jit.freeze(mod))
        actual = static_mod(y)

        self.assertEqual(expected, actual)
Bram Wasti | f4226b5 | 2020-12-10 14:01:36 -0800 | [diff] [blame] | 602 | |
# When executed as a script, hand control to run_tests from
# torch.testing._internal.common_utils.
if __name__ == "__main__":
    run_tests()