# Owner(s): ["oncall: jit"]

import unittest

import torch
import torch._C

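# Load the XNNPack delegate library so the "xnnpack" backend is available to to_backend lowering.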
torch.ops.load_library("//caffe2:xnnpack_backend")


class TestXNNPackBackend(unittest.TestCase):
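    # Constant tensors stored on the module should be captured at lowering time and
    # produce the same outputs as the scripted reference module.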
    def test_xnnpack_constant_data(self):
        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self._constant = torch.ones(4, 4, 4)

            def forward(self, x):
                return x + self._constant

        scripted_module = torch.jit.script(Module())
        lowered_module = torch._C._jit_to_backend(
            "xnnpack",
            scripted_module,
            {
                "forward": {
                    "inputs": [torch.randn(4, 4, 4)],
                    "outputs": [torch.randn(4, 4, 4)],
                }
            },
        )
        for _ in range(20):
            sample_input = torch.randn(4, 4, 4)
            expected_output = scripted_module(sample_input)
            actual_output = lowered_module(sample_input)
            self.assertTrue(
                torch.allclose(actual_output, expected_output, atol=1e-03, rtol=1e-03)
            )
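
    # Lowering should reject malformed method_compile_specs and accept a valid one.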
    def test_xnnpack_lowering(self):
        class Module(torch.nn.Module):
            def forward(self, x):
                return x + x

        scripted_module = torch.jit.script(Module())
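
        # A spec that only describes "backward" is missing the required "forward" entry.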
        faulty_compile_spec = {
            "backward": {
                "inputs": [torch.zeros(1)],
                "outputs": [torch.zeros(1)],
            }
        }
        error_msg = 'method_compile_spec does not contain the "forward" key.'
        with self.assertRaisesRegex(
            RuntimeError,
            error_msg,
        ):
            _ = torch._C._jit_to_backend(
                "xnnpack",
                scripted_module,
                faulty_compile_spec,
            )
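
        # The number of example inputs must match the number of arguments forward() takes.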
        mismatch_compile_spec = {
            "forward": {
                "inputs": [torch.zeros(1), torch.zeros(1)],
                "outputs": [torch.zeros(1)],
            }
        }
        error_msg = "method_compile_spec inputs do not match expected number of forward inputs"
        with self.assertRaisesRegex(
            RuntimeError,
            error_msg,
        ):
            _ = torch._C._jit_to_backend(
                "xnnpack",
                scripted_module,
                mismatch_compile_spec,
            )
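
        # A well-formed spec lowers successfully and the lowered module is callable.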
        lowered = torch._C._jit_to_backend(
            "xnnpack",
            scripted_module,
            {
                "forward": {
                    "inputs": [torch.zeros(1)],
                    "outputs": [torch.zeros(1)],
                }
            },
        )
        lowered(torch.zeros(1))
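
    # A small chain of adds lowered to XNNPack should match the scripted reference output.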
    def test_xnnpack_backend_add(self):
        class AddModule(torch.nn.Module):
            def forward(self, x, y):
                z = x + y
                z = z + x
                z = z + x
                return z

        add_module = AddModule()
        sample_inputs = (torch.rand(1, 512, 512, 3), torch.rand(1, 512, 512, 3))
        sample_output = torch.zeros(1, 512, 512, 3)

        add_module = torch.jit.script(add_module)
        expected_output = add_module(sample_inputs[0], sample_inputs[1])

        lowered_add_module = torch._C._jit_to_backend(
            "xnnpack",
            add_module,
            {
                "forward": {
                    "inputs": [sample_inputs[0].clone(), sample_inputs[1].clone()],
                    "outputs": [sample_output],
                }
            },
        )
        actual_output = lowered_add_module.forward(sample_inputs[0], sample_inputs[1])
        self.assertTrue(
            torch.allclose(actual_output, expected_output, atol=1e-03, rtol=1e-03)
        )
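
    # Broadcasting between inputs of different ranks and shapes should be handled by the delegate.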
    def test_xnnpack_broadcasting(self):
        class AddModule(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        add_module = AddModule()
        sample_inputs = (torch.rand(5, 1, 4, 1), torch.rand(3, 1, 1))
        sample_output = torch.zeros(5, 3, 4, 1)

        add_module = torch.jit.script(add_module)
        expected_output = add_module(sample_inputs[0], sample_inputs[1])

        lowered_add_module = torch._C._jit_to_backend(
            "xnnpack",
            add_module,
            {
                "forward": {
                    "inputs": [sample_inputs[0], sample_inputs[1]],
                    "outputs": [sample_output],
                }
            },
        )
        actual_output = lowered_add_module.forward(sample_inputs[0], sample_inputs[1])
        self.assertTrue(
            torch.allclose(actual_output, expected_output, atol=1e-03, rtol=1e-03)
        )
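
    # Modules containing ops the delegate cannot handle should fail at lowering time
    # with an error message listing the unsupported ops.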
    def test_xnnpack_unsupported(self):
        class AddSpliceModule(torch.nn.Module):
            def forward(self, x, y):
                z = x + y[:, :, 1, :]
                return z

        sample_inputs = (torch.rand(1, 512, 512, 3), torch.rand(1, 512, 512, 3))
        sample_output = torch.zeros(1, 512, 512, 3)
        error_msg = (
            "the module contains the following unsupported ops:\n"
            "aten::select\n"
            "aten::slice\n"
        )

        add_module = torch.jit.script(AddSpliceModule())
        with self.assertRaisesRegex(
            RuntimeError,
            error_msg,
        ):
            _ = torch._C._jit_to_backend(
                "xnnpack",
                add_module,
                {
                    "forward": {
                        "inputs": [sample_inputs[0], sample_inputs[1]],
                        "outputs": [sample_output],
                    }
                },
            )