blob: 3515671781ea9b504a6dc97f5f782107350f20b4 [file] [log] [blame]
import torch
import numpy as np
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.jit_utils import JitTestCase
import unittest
LLVM_ENABLED = torch._C._llvm_enabled()
class kernel_arena_scope(object):
def __enter__(self):
self.scope = torch._C._te.KernelScope()
def __exit__(self, typ, val, traceback):
self.scope = None
class TestTensorExprPyBind(JitTestCase):
def test_simple_sum(self):
with kernel_arena_scope():
dtype = torch._C._te.Dtype.Float
N = 32
dN = torch._C._te.ExprHandle.int(N)
A = torch._C._te.Placeholder('A', dtype, [dN])
B = torch._C._te.Placeholder('B', dtype, [dN])
def compute(i):
return A.load([i]) + B.load([i])
C = torch._C._te.Compute('C', [torch._C._te.DimArg(dN, 'i')], compute)
loopnest = torch._C._te.LoopNest([C])
loopnest.prepare_for_codegen()
stmt = torch._C._te.simplify(loopnest.root_stmt())
cg = torch._C._te.construct_codegen('ir_eval', stmt, [torch._C._te.BufferArg(x) for x in [A, B, C]])
tA = torch.rand(N) * 5
tB = torch.rand(N) * 6
tC = torch.empty(N)
cg.call([tA, tB, tC])
torch.testing.assert_allclose(tA + tB, tC)
def test_external_calls(self):
with kernel_arena_scope():
dtype = torch._C._te.Dtype.Float
ZERO = torch._C._te.ExprHandle.int(0)
ONE = torch._C._te.ExprHandle.int(1)
FOUR = torch._C._te.ExprHandle.int(4)
A = torch._C._te.BufHandle('A', [ONE, FOUR], dtype)
B = torch._C._te.BufHandle('B', [FOUR, ONE], dtype)
C = torch._C._te.BufHandle('C', [ONE, ONE], dtype)
s = torch._C._te.ExternalCall(C, "nnc_aten_matmul", [A, B], [])
loopnest = torch._C._te.LoopNest(s, [C])
loopnest.prepare_for_codegen()
codegen = torch._C._te.construct_codegen('ir_eval', s, [torch._C._te.BufferArg(x) for x in [A, B, C]])
tA = torch.ones(1, 4)
tB = torch.ones(4, 1)
tC = torch.empty(1, 1)
codegen.call([tA, tB, tC])
torch.testing.assert_allclose(torch.matmul(tA, tB), tC)
@unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
def test_kernel_with_tensor_inputs(self):
def f(a, b, c):
return a + b + c
device, size = 'cpu', (4, 4)
x = torch.rand(size, device=device)
y = torch.rand(size, device=device)
z = torch.rand(size, device=device)
graph_str = """
graph(%a.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu),
%b.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu),
%c.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu)):
%6 : int = prim::Constant[value=1]()
%7 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::add(%a.1, %b.1, %6)
%3 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::add(%7, %c.1, %6)
return (%3)
"""
graph = torch._C.parse_ir(graph_str)
kernel = torch._C._te.TensorExprKernel(graph)
res1 = kernel.run((x, y, z))
res2 = kernel.fallback((x, y, z))
correct = f(x, y, z)
np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
@unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
def test_kernel_with_scalar_inputs(self):
def f(a, b, c):
return a + b + c
x = torch.tensor(0.1, dtype=torch.float, device='cpu')
y = torch.tensor(0.6, dtype=torch.float, device='cpu')
z = torch.tensor(0.7, dtype=torch.float, device='cpu')
graph_str = """
graph(%a.1 : Float(requires_grad=0, device=cpu),
%b.1 : Float(requires_grad=0, device=cpu),
%c.1 : Float(requires_grad=0, device=cpu)):
%3 : int = prim::Constant[value=1]()
%6 : Float(requires_grad=0, device=cpu) = aten::add(%a.1, %b.1, %3)
%9 : Float(requires_grad=0, device=cpu) = aten::add(%6, %c.1, %3)
return (%9)
"""
graph = torch._C.parse_ir(graph_str)
kernel = torch._C._te.TensorExprKernel(graph)
res1 = kernel.run((x, y, z))
res2 = kernel.fallback((x, y, z))
correct = f(x, y, z)
np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
@unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
def test_kernel_shape_prop(self):
device, size = 'cpu', (4, 4)
x = torch.rand(size, device=device)
y = torch.rand(size, device=device)
graph_str = """
graph(%a : Tensor, %b : Tensor):
%c : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::mul(%a, %b)
return (%c)
"""
graph = torch._C.parse_ir(graph_str)
exception_thrown = False
try:
kernel = torch._C._te.TensorExprKernel(graph)
except RuntimeError:
# Graph doesn't have shape info for inputs => compilation should
# fail
exception_thrown = True
pass
assert exception_thrown
# Inject shape info and try compiling again
example_inputs = [torch.rand(4, 4), torch.rand(4, 4)]
torch._C._te.annotate_input_shapes(graph, example_inputs)
# TODO: once we have shape propagation as well we should erase type
# info for %c from the input IR and run shape propagation here - it
# should be able to reconstruct that info
# Now compilation should pass
kernel = torch._C._te.TensorExprKernel(graph)
res = kernel.run((x, y))
correct = torch.mul(x, y)
np.testing.assert_allclose(res.numpy(), correct.numpy(), atol=1e-5)
@unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
@unittest.skip("Does not work until shape propagation is implemented")
def test_kernel_shape_prop_module(self):
class TestModule(torch.nn.Module):
def forward(self, x, y):
return x * x + y
graph = torch.jit.script(TestModule()).graph
# Try compiling the graph as-is. It should fail because it doesn't have
# shape info.
exception_thrown = False
try:
kernel = torch._C._te.TensorExprKernel(graph)
except RuntimeError:
exception_thrown = True
pass
assert exception_thrown
# Try injecting shape info for graph inputs
example_inputs = [torch.rand(4, 4), torch.rand(4, 4)]
exception_thrown = False
try:
torch._C._te.annotate_input_shapes(graph, example_inputs)
except RuntimeError:
# Graph has a 'self' argument for which we can't set shapes
exception_thrown = True
pass
assert exception_thrown
# Remove 'self' argument and try annotating shapes one more time
graph = torch._C._te.remove_unused_self_argument(graph)
# Inject shape info and try compiling again
torch._C._te.annotate_input_shapes(graph, example_inputs)
# TODO: once we have shape propagation as well we should erase type
# info for %c from the input IR and run shape propagation here - it
# should be able to reconstruct that info
# Now compilation should pass
kernel = torch._C._te.TensorExprKernel(graph)
device, size = 'cpu', (4, 4)
x = torch.rand(size, device=device)
y = torch.rand(size, device=device)
res = kernel.run((x, y))
correct = TestModule().forward(x, y)
np.testing.assert_allclose(res.numpy(), correct.numpy(), atol=1e-5)
@unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
def test_kernel_with_t(self):
def f(a):
return a.t()
device, size = 'cpu', (3, 4)
x = torch.rand(size, device=device)
graph_str = """
graph(%a.1 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
%3 : Float(4, 3, strides=[4, 1], requires_grad=0, device=cpu) = aten::t(%a.1)
return (%3)
"""
graph = torch._C.parse_ir(graph_str)
kernel = torch._C._te.TensorExprKernel(graph)
res1 = kernel.run((x,))
res2 = kernel.fallback((x,))
correct = f(x)
np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
@unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
def test_kernel_with_transpose(self):
def f(a):
return a.transpose(-1, -2)
device, size = 'cpu', (3, 4)
x = torch.rand(size, device=device)
graph_str = """
graph(%a.1 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
%2 : int = prim::Constant[value=-1]()
%3 : int = prim::Constant[value=-2]()
%4 : Float(4, 3, strides=[4, 1], requires_grad=0, device=cpu) = aten::transpose(%a.1, %2, %3)
return (%4)
"""
graph = torch._C.parse_ir(graph_str)
kernel = torch._C._te.TensorExprKernel(graph)
res1 = kernel.run((x,))
res2 = kernel.fallback((x,))
correct = f(x)
np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
@unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
def test_kernel_with_permute(self):
def f(a):
return a.permute([2, 1, 0])
device, size = 'cpu', (3, 4, 5)
x = torch.rand(size, device=device)
graph_str = """
graph(%a.1 : Float(3, 4, 5, strides=[20, 5, 1], requires_grad=0, device=cpu)):
%1 : int = prim::Constant[value=2]()
%2 : int = prim::Constant[value=1]()
%3 : int = prim::Constant[value=0]()
%4 : int[] = prim::ListConstruct(%1, %2, %3)
%5 : Float(5, 4, 3, strides=[12, 3, 1], requires_grad=0, device=cpu) = aten::permute(%a.1, %4)
return (%5)
"""
graph = torch._C.parse_ir(graph_str)
kernel = torch._C._te.TensorExprKernel(graph)
res1 = kernel.run((x,))
res2 = kernel.fallback((x,))
correct = f(x)
np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
@unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
def test_kernel_with_expand(self):
def f(a):
return a.expand((2, 3, 4))
device = 'cpu'
x = torch.rand((1, 3, 1), device=device)
graph_str = """
graph(%a : Float(1, 3, 1, strides=[3, 1, 1], requires_grad=0, device=cpu)):
%1 : int = prim::Constant[value=2]()
%2 : int = prim::Constant[value=3]()
%3 : int = prim::Constant[value=4]()
%4 : int[] = prim::ListConstruct(%1, %2, %3)
%5 : bool = prim::Constant[value=0]()
%6 : Float(2, 3, 4, strides=[12, 4, 0], requires_grad=0, device=cpu) = aten::expand(%a, %4, %5)
return (%6)
"""
graph = torch._C.parse_ir(graph_str)
kernel = torch._C._te.TensorExprKernel(graph)
res1 = kernel.run((x,))
res2 = kernel.fallback((x,))
correct = f(x)
np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
if __name__ == '__main__':
run_tests()