import numpy as np
import torch
from torch import nn
from torch.testing._internal.common_utils import TestCase, run_tests


class StaticRuntime:
    def __init__(self, scripted):
        # scripted is either a ScriptModule (exposing its C++ module as _c)
        # or a raw JIT graph
        if hasattr(scripted, "_c"):
            self.static_runtime = torch._C._jit_to_static_runtime(scripted._c)
        else:
            self.static_runtime = torch._C._jit_to_static_runtime(scripted.graph)

    def __call__(self, *args, **kwargs):
        if not kwargs:
            return self.static_runtime.run(args)
        else:
            return self.static_runtime.run(args, kwargs)

    def benchmark(self, args, kwargs, warmup_runs, main_runs):
        self.static_runtime.benchmark(args, kwargs, warmup_runs, main_runs)

    def benchmark_individual_ops(self, args, kwargs, warmup_runs, main_runs):
        return self.static_runtime.benchmark_individual_ops(
            args, kwargs, warmup_runs, main_runs
        )


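# Scriptable replacement for torch.nn.functional.linear built from matmul + add;
# it is installed below in place of F.linear before any module is scripted.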
def linear_shim(input, weight, bias=None):
    # type: (Tensor, Tensor, Optional[Tensor]) -> Tensor
    output = input.matmul(weight.t())
    if bias is not None:
        output += bias
    ret = output
    return ret


torch.nn.functional.linear = linear_shim


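# Standard scaled dot-product multi-head attention; the dropout and mask steps
# are left commented out in forward().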
class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()
        assert hid_dim % n_heads == 0
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        # self.dropout = nn.Dropout(dropout)
        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)

    def forward(self, query, key, value, mask):
        batch_size = query.shape[0]
        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)
        # reshape to [batch size, n heads, seq len, head dim]
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        # scaled dot-product scores: [batch size, n heads, query len, key len]
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
        # energy = energy.masked_fill(mask == 0, -1e10)
        attention = torch.softmax(energy, dim=-1)
        # x = torch.matmul(self.dropout(attention), V)
        x = torch.matmul(attention, V)
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.hid_dim)
        x = self.fc_o(x)
        return x, attention


# Taken from https://github.com/facebookresearch/dlrm/blob/master/dlrm_s_pytorch.py
def create_mlp(ln, sigmoid_layer):
    layers = nn.ModuleList()
    for i in range(0, len(ln) - 1):
        n = ln[i]
        m = ln[i + 1]

        LL = nn.Linear(int(n), int(m), bias=True)

        mean = 0.0  # std_dev = np.sqrt(variance)
        std_dev = np.sqrt(2 / (m + n))  # np.sqrt(1 / m) # np.sqrt(1 / n)
        W = np.random.normal(mean, std_dev, size=(m, n)).astype(np.float32)
        std_dev = np.sqrt(1 / m)  # np.sqrt(2 / (m + 1))
        bt = np.random.normal(mean, std_dev, size=m).astype(np.float32)
        LL.weight.data = torch.tensor(W, requires_grad=True)
        LL.bias.data = torch.tensor(bt, requires_grad=True)
        layers.append(LL)

        if i == sigmoid_layer:
            layers.append(nn.Sigmoid())
        else:
            layers.append(nn.ReLU())

    with torch.no_grad():
        s = torch.jit.script(torch.nn.Sequential(*layers))
    s.eval()
    return s


def trivial_graph(a, b, c):
    s = torch.tensor([[3, 3], [3, 3]])
    return a + b * c + s


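# The tests below script a model, lower it with StaticRuntime, and either
# compare its outputs against the TorchScript reference or exercise the
# benchmark hooks.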
class TestStaticRuntime(TestCase):
    def test_multihead_attention_layer(self):
        HID_DIM = 256
        QUERY_LEN = 8
        BATCH_SIZE = 128
        LAYERS = 3
        HEADS = 8
        DROPOUT = 0.1
        device = torch.device("cpu")
        attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
        with torch.no_grad():
            src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
            src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)

        attention.eval()
        attention = torch.jit.script(attention)
        attention.eval()
        o_ref = attention(src, src, src, src_mask)

        attention_a = StaticRuntime(attention)
        o_test = attention_a(src, src, src, src_mask)
        o_test_kw = attention_a(src, src, value=src, mask=src_mask)
        for a, b in zip(o_ref, o_test):
            torch.testing.assert_allclose(a, b)
        for a, b in zip(o_ref, o_test_kw):
            torch.testing.assert_allclose(a, b)

    def test_multihead_attention_layer_benchmark(self):
        HID_DIM = 256
        QUERY_LEN = 8
        BATCH_SIZE = 128
        LAYERS = 3
        HEADS = 8
        DROPOUT = 0.1
        device = torch.device("cpu")
        attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
        with torch.no_grad():
            src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
            src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)

        attention.eval()
        attention = torch.jit.script(attention)
        attention_a = StaticRuntime(attention)

        attention_a.benchmark([src, src, src, src_mask], {}, 10, 10)
        metrics = attention_a.benchmark_individual_ops(
            [src, src, src, src_mask], {}, 10, 10
        )

    def test_mlp(self):
        # Arguments taken from benchmark script, ./bench/dlrm_s_benchmark.sh
        ln_bot = [512, 512, 64]
        sigmoid_bot = -1
        ln_top = [100, 1024, 1024, 1024, 1]
        sigmoid_top = 3
        bot_l = create_mlp(ln_bot, sigmoid_bot)
        bot_l_acc = StaticRuntime(bot_l)
        top_l = create_mlp(ln_top, sigmoid_top)
        top_l_acc = StaticRuntime(top_l)
        with torch.no_grad():
            bot_inp = torch.randn(2048, 512)  # torch.Size([2048, 512])
            top_inp = torch.randn(2048, 100)  # torch.Size([2048, 100])
        ref_bot = bot_l(bot_inp)
        acc_bot = bot_l_acc(bot_inp)[0]
        torch.testing.assert_allclose(acc_bot, ref_bot)
        ref_top = top_l(top_inp)
        acc_top = top_l_acc(top_inp)[0]
        torch.testing.assert_allclose(acc_top, ref_top)
        for _ in range(5):
            with torch.no_grad():
                bot_inp = torch.randn(2048, 512)  # torch.Size([2048, 512])
                top_inp = torch.randn(2048, 100)  # torch.Size([2048, 100])
            ref_bot = bot_l(bot_inp)
            acc_bot = bot_l_acc(bot_inp)[0]
            torch.testing.assert_allclose(acc_bot, ref_bot)
            ref_top = top_l(top_inp)
            acc_top = top_l_acc(top_inp)[0]
            torch.testing.assert_allclose(acc_top, ref_top)

    def test_trivial_graph(self):
        s = torch.full((2, 2), 2)
        tg = torch.jit.script(trivial_graph)
        o_ref = tg(s, s, s)
        tg_a = StaticRuntime(tg)
        o_test = tg_a(s, s, s)[0]
        torch.testing.assert_allclose(o_ref, o_test)


if __name__ == "__main__":
    run_tests()