# Owner(s): ["oncall: fx"]
import unittest
from typing import Callable, List
import numpy as np
import torch
import torch.fx.experimental.fx_acc.acc_normalizer as acc_normalizer
import torch.fx.experimental.fx_acc.acc_ops as acc_ops
import torch.fx.experimental.fx_acc.acc_tracer as acc_tracer
import torch.fx.experimental.fx_acc.acc_utils as acc_utils
import torch.nn as nn
import torchvision
from parameterized import parameterized, param
torch.manual_seed(0)
class AccTracerTest(unittest.TestCase):
def _make_model_unit_test(
self,
model,
*args,
input_shape=None,
enable_allclose=False,
**kwargs,
):
"""
        Test that the model can be traced correctly and that the traced module
        produces the same results as the original.
"""
if input_shape is None:
input_shape = [1, 3, 224, 224]
input = torch.randn(input_shape)
traced = acc_tracer.trace(model, [input])
if enable_allclose:
torch.testing.assert_allclose(model(input), traced(input))
else:
self.assertTrue(torch.equal(model(input), traced(input)))
def _make_acc_op_function_test(
self,
acc_op: Callable,
torch_op,
*args,
input_shape=(2, 3),
validate_same_kwargs=True,
enable_allclose=False,
**kwargs,
):
"""
        Test that torch_op is traced to the given acc_op (with kwargs
        normalized by acc_tracer) and that the traced module matches the
        original module's output.
"""
class TestModule(torch.nn.Module):
def __init__(self, torch_op, args, kwargs):
super().__init__()
self._torch_op = torch_op
self._args = args
self._kwargs = kwargs
def forward(self, a: torch.Tensor) -> torch.Tensor:
return self._torch_op(a, *self._args, **self._kwargs)
m = TestModule(torch_op, args, kwargs)
a = torch.randn(*input_shape)
traced = acc_tracer.trace(m, [a])
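        # Walk the traced graph: expect a placeholder for `a`, the expected
        # acc_op call (when acc_op is not None) with kwargs normalized by the
        # tracer, and an output node pointing at that acc_op.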
ph_a = acc_op_node = None
for node in traced.graph.nodes:
if node.op == "placeholder":
if str(node.target) == "a":
ph_a = node
elif node.op == "call_function":
self.assertEqual(node.target, acc_op)
self.assertEqual(node.kwargs["input"], ph_a)
if validate_same_kwargs:
for key, value in kwargs.items():
self.assertEqual(node.kwargs[key], value)
acc_op_node = node
elif node.op == "output":
if acc_op is None:
                    # If no acc_op is expected in the traced graph, the output
                    # node is the only remaining node, so nothing to check.
continue
self.assertEqual(acc_op_node, node.args[0])
else:
self.fail(f"Unexpected node: {node.format_node()}")
ref_outputs = m(a)
outputs = traced(a)
if isinstance(ref_outputs, torch.Tensor):
ref_outputs = [ref_outputs]
outputs = [outputs]
for ref_output, output in zip(ref_outputs, outputs):
if enable_allclose:
torch.testing.assert_allclose(
torch.nan_to_num(ref_output), torch.nan_to_num(output)
)
else:
self.assertTrue(
torch.equal(torch.nan_to_num(ref_output), torch.nan_to_num(output))
)
def test_sum(self):
def torch_sum(x, *args, **kwargs):
return x.sum(*args, **kwargs)
self._make_acc_op_function_test(acc_ops.sum, torch_sum)
self._make_acc_op_function_test(acc_ops.sum, torch_sum, dim=(1,), keepdim=True)
def test_max(self):
def torch_max(x, *args, **kwargs):
return x.max(*args, **kwargs)
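        # torch.max normalizes to different acc_ops depending on its call
        # signature: no dim -> max_full_reduce, with dim -> max_dim_reduce.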
self._make_acc_op_function_test(acc_ops.max_full_reduce, torch_max)
self._make_acc_op_function_test(
acc_ops.max_dim_reduce, torch_max, dim=1, keepdim=True
)
self._make_acc_op_function_test(
acc_ops.max_dim_reduce, torch_max, input_shape=(1, 4), dim=1, keepdim=True
)
self._make_acc_op_function_test(
acc_ops.max_dim_reduce, torch_max, input_shape=(3, 4, 3), dim=2
)
@parameterized.expand(
[
param("max_maximum", orig_op=torch.max, expected_op=acc_ops.maximum),
param(
"maximum_maximum", orig_op=torch.maximum, expected_op=acc_ops.maximum
),
param("min_minimum", orig_op=torch.min, expected_op=acc_ops.minimum),
param(
"minimum_minimum", orig_op=torch.minimum, expected_op=acc_ops.minimum
),
]
)
def test_maximum_minimum(self, _: str, orig_op, expected_op):
class TestModule(torch.nn.Module):
def __init__(self, orig_op):
super().__init__()
self.orig_op = orig_op
def forward(self, input: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
return self.orig_op(input, other)
m = TestModule(orig_op)
input, other = torch.randn(2, 2), torch.randn(2, 2)
traced = acc_tracer.trace(m, [input, other])
ph_in = ph_oth = mxm = None
for node in traced.graph.nodes:
if node.op == "placeholder":
if str(node.target) == "other":
ph_oth = node
else:
self.assertTrue(str(node.target) == "input")
ph_in = node
elif node.op == "call_function":
if node.target == expected_op:
self.assertEqual(node.kwargs["input"], ph_in)
self.assertEqual(node.kwargs["other"], ph_oth)
mxm = node
elif node.op == "output":
self.assertEqual(mxm, node.args[0])
else:
self.fail(f"Unexpected node: {node.format_node()}")
self.assertTrue(torch.equal(m(input, other), traced(input, other)))
def test_conv(self):
"""
Test that a conv is traced as expected.
"""
class TestModule(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(8, 7, 3, stride=2)
def forward(self, a: torch.Tensor) -> torch.Tensor:
return self.conv(a)
m = TestModule()
input = torch.randn(3, 8, 10, 10)
traced = acc_tracer.trace(m, [input])
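        # The Conv2d module call should be lowered to the functional
        # acc_ops.conv2d, with weight/bias exposed as get_attr nodes and
        # unspecified args (padding, dilation, groups) filled with defaults.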
ph = weight_attr = bias_attr = conv = None
for node in traced.graph.nodes:
if node.op == "placeholder":
self.assertEqual(str(node.target), "a")
ph = node
elif node.op == "get_attr" and node.target == "conv.weight":
weight_attr = node
elif node.op == "get_attr" and node.target == "conv.bias":
bias_attr = node
elif node.op == "call_function":
self.assertEqual(node.target, acc_ops.conv2d)
self.assertEqual(node.kwargs["input"], ph)
self.assertEqual(node.kwargs["weight"], weight_attr)
self.assertEqual(node.kwargs["bias"], bias_attr)
self.assertEqual(node.kwargs["stride"], (2, 2))
self.assertEqual(node.kwargs["padding"], (0, 0))
self.assertEqual(node.kwargs["dilation"], (1, 1))
self.assertEqual(node.kwargs["groups"], 1)
conv = node
elif node.op == "output":
self.assertEqual(conv, node.args[0])
else:
self.fail(f"Unexpected node: {node.format_node()}")
self.assertTrue(torch.equal(m(input), traced(input)))
def test_quantized_conv2d(self):
class TestModule(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.quantized.Conv2d(3, 3, 1)
def forward(self, a: torch.Tensor) -> torch.Tensor:
return self.conv(a)
m = TestModule()
input = torch.quantize_per_tensor(
torch.randn(1, 3, 1, 1), scale=0.01, zero_point=3, dtype=torch.quint8
)
traced = acc_tracer.trace(m, [input])
ph = weight_attr = bias_attr = conv = None
for node in traced.graph.nodes:
if node.op == "placeholder":
self.assertEqual(str(node.target), "a")
ph = node
elif node.op == "get_attr" and node.target == "conv_weight":
weight_attr = node
elif node.op == "get_attr" and node.target == "conv_bias":
bias_attr = node
elif node.op == "call_function":
self.assertEqual(node.target, acc_ops.quantized_conv2d)
self.assertEqual(node.kwargs["input"], ph)
self.assertEqual(node.kwargs["weight"], weight_attr)
self.assertEqual(node.kwargs["bias"], bias_attr)
conv = node
elif node.op == "output":
self.assertEqual(conv, node.args[0])
else:
self.fail(f"Unexpected node: {node.format_node()}")
self.assertTrue(torch.equal(m(input), traced(input)))
def test_quantized_convrelu2d(self):
class TestModule(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.intrinsic.quantized.ConvReLU2d(3, 3, 1)
def forward(self, a: torch.Tensor) -> torch.Tensor:
return self.conv(a)
m = TestModule()
input = torch.quantize_per_tensor(
torch.randn(1, 3, 1, 1), scale=0.01, zero_point=3, dtype=torch.quint8
)
traced = acc_tracer.trace(m, [input])
ph = weight_attr = bias_attr = conv = relu = None
for node in traced.graph.nodes:
if node.op == "placeholder":
self.assertEqual(str(node.target), "a")
ph = node
elif node.op == "get_attr" and node.target == "conv_weight":
weight_attr = node
elif node.op == "get_attr" and node.target == "conv_bias":
bias_attr = node
elif node.op == "call_function" and node.target == acc_ops.quantized_conv2d:
self.assertEqual(node.kwargs["input"], ph)
self.assertEqual(node.kwargs["weight"], weight_attr)
self.assertEqual(node.kwargs["bias"], bias_attr)
conv = node
elif node.op == "call_function" and node.target == acc_ops.relu:
self.assertEqual(node.kwargs["input"], conv)
relu = node
elif node.op == "output":
self.assertEqual(relu, node.args[0])
else:
self.fail(f"Unexpected node: {node.format_node()}")
self.assertTrue(torch.equal(m(input), traced(input)))
def test_embedding_bag(self):
"""
Test that an embedding_bag is traced as expected.
"""
class TestModule(nn.Module):
def __init__(self):
super().__init__()
self.eb = nn.EmbeddingBag(10, 3, mode="sum", include_last_offset=True)
def forward(self, inp: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
return self.eb(inp, offsets)
m = TestModule()
inp = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9])
offsets = torch.LongTensor([0, 4])
traced = acc_tracer.trace(m, [inp, offsets])
inp_node = offsets_node = weight_attr = eb_node = None
for node in traced.graph.nodes:
if node.op == "placeholder":
if str(node.target) == "inp":
inp_node = node
elif str(node.target) == "offsets":
offsets_node = node
else:
self.fail(f"Unexpected placeholder {node.target}.")
continue
elif node.op == "get_attr" and node.target == "eb.weight":
weight_attr = node
elif node.op == "call_function":
self.assertEqual(node.target, acc_ops.embedding_bag)
# Note: Normalization called from acc_tracer means we use all kwargs.
self.assertEqual(node.kwargs["input"], inp_node)
self.assertEqual(node.kwargs["offsets"], offsets_node)
self.assertEqual(node.kwargs["weight"], weight_attr)
self.assertEqual(node.kwargs["mode"], "sum")
self.assertEqual(node.kwargs["include_last_offset"], True)
# The rest of these were unspecified, so verify they fell back
# to their respective default values thanks to normalization.
self.assertEqual(node.kwargs["max_norm"], None)
self.assertEqual(node.kwargs["norm_type"], 2.0)
self.assertEqual(node.kwargs["scale_grad_by_freq"], False)
self.assertEqual(node.kwargs["sparse"], False)
self.assertEqual(node.kwargs["per_sample_weights"], None)
eb_node = node
elif node.op == "output":
self.assertEqual(eb_node, node.args[0])
self.assertTrue(torch.equal(m(inp, offsets), traced(inp, offsets)))
def test_embedding_bag_byte_and_4bit_rowwise_offsets(self):
"""
        Test that byte (8-bit) and 4-bit rowwise-quantized embedding_bag ops
        are traced as expected.
"""
class TestModule(nn.Module):
def __init__(
self,
op,
q_weights,
per_index_weights,
):
super().__init__()
self.emb = op
self.q_weights = q_weights
self.per_index_weights = per_index_weights
def forward(
self,
indices,
offsets,
):
return self.emb(
self.q_weights,
indices,
offsets,
mode=0,
per_sample_weights=self.per_index_weights,
include_last_offset=True,
)
        def run_embedding_bag_test(is_4bit):
# generate random indices, offsets, and weights.
num_embeddings = 16
embedding_dim = 32
num_lengths = 10
weights = torch.from_numpy(
(np.random.random_sample((num_embeddings, embedding_dim)) + 1).astype(
np.float32
)
)
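            # Prepack rowwise-quantizes the fp32 table; each row carries its
            # own scale/zero-point alongside the quantized values.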
q_weights = (
torch.ops.quantized.embedding_bag_4bit_prepack(weights)
if is_4bit
else torch.ops.quantized.embedding_bag_byte_prepack(weights)
)
            np_lengths = np.random.randint(0, num_lengths, size=10).astype(np.int32)
            num_indices = np.sum(np_lengths)
            indices = torch.from_numpy(
                np.random.randint(low=0, high=num_embeddings, size=num_indices)
            ).int()
lengths = torch.from_numpy(np_lengths)
offsets = torch.cat([torch.zeros([1]), torch.cumsum(lengths, 0)]).int()
            # Avoid shadowing the embedding table `weights` above.
            per_sample_weights = torch.randint(
                low=0, high=4, size=indices.size()
            ).to(torch.float32)
indices = indices.to(torch.int32)
offsets = offsets.to(torch.int32)
inputs = [
indices,
offsets,
]
op = (
torch.ops.quantized.embedding_bag_4bit_rowwise_offsets
if is_4bit
else torch.ops.quantized.embedding_bag_byte_rowwise_offsets
)
m = TestModule(
op,
q_weights,
per_sample_weights,
)
traced = acc_tracer.trace(m, inputs)
expected_target = (
acc_ops.embedding_bag_4bit_rowwise_offsets
if is_4bit
else acc_ops.embedding_bag_byte_rowwise_offsets
)
for node in traced.graph.nodes:
if node.op == "placeholder":
if str(node.target) == "indices":
inp_node = node
elif str(node.target) == "offsets":
offsets_node = node
else:
self.fail(f"Unexpected placeholder {node.target}.")
continue
elif node.op == "get_attr" and node.target == "q_weights":
weight_attr = node
elif node.op == "call_function":
self.assertEqual(node.target, expected_target)
# Note: Normalization called from acc_tracer means we use all kwargs.
self.assertEqual(node.kwargs["indices"], inp_node)
self.assertEqual(node.kwargs["offsets"], offsets_node)
self.assertEqual(node.kwargs["weight"], weight_attr)
self.assertEqual(node.kwargs["mode"], 0)
self.assertEqual(node.kwargs["include_last_offset"], True)
                    eb_node = node
elif node.op == "output":
self.assertEqual(eb_node, node.args[0])
self.assertTrue(torch.equal(m(indices, offsets), traced(indices, offsets)))
        # test 8-bit
        run_embedding_bag_test(is_4bit=False)
        # test 4-bit
        run_embedding_bag_test(is_4bit=True)
def test_quantized_batch_norm2d(self):
class TestModule(nn.Module):
def __init__(self):
super().__init__()
self.bn = nn.quantized.BatchNorm2d(3)
def forward(self, a: torch.Tensor) -> torch.Tensor:
return self.bn(a)
m = TestModule()
m.eval()
input = torch.quantize_per_tensor(
torch.randn(1, 3, 1, 1), scale=0.01, zero_point=3, dtype=torch.quint8
)
traced = acc_tracer.trace(m, [input])
ph = weight_attr = bias_attr = bn_mean = bn_var = bn = None
for node in traced.graph.nodes:
if node.op == "placeholder":
self.assertEqual(str(node.target), "a")
ph = node
elif node.op == "get_attr" and node.target == "bn.weight":
weight_attr = node
elif node.op == "get_attr" and node.target == "bn.bias":
bias_attr = node
elif node.op == "get_attr" and node.target == "bn.running_mean":
bn_mean = node
elif node.op == "get_attr" and node.target == "bn.running_var":
bn_var = node
elif node.op == "call_function":
self.assertEqual(node.target, acc_ops.quantized_batch_norm2d)
self.assertEqual(node.kwargs["input"], ph)
self.assertEqual(node.kwargs["weight"], weight_attr)
self.assertEqual(node.kwargs["bias"], bias_attr)
self.assertEqual(node.kwargs["running_mean"], bn_mean)
self.assertEqual(node.kwargs["running_var"], bn_var)
bn = node
elif node.op == "output":
self.assertEqual(bn, node.args[0])
else:
self.fail(f"Unexpected node: {node.format_node()}")
self.assertTrue(torch.equal(m(input), traced(input)))
def test_linear(self):
"""
Test that a linear is traced as expected, i.e. to the functional level and with
kwarg normalization. Also verify that symbolic shape inference worked as part of
the acc_tracer.
"""
class TestModule(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(3, 5, bias=True)
def forward(self, a: torch.Tensor) -> torch.Tensor:
return self.linear(a)
m = TestModule()
test_input = torch.randn(1, 3)
        traced = acc_tracer.trace(m, [test_input])
ph = weight_attr = bias_attr = linear = None
for node in traced.graph.nodes:
if node.op == "placeholder":
self.assertEqual(str(node.target), "a")
ph = node
elif node.op == "get_attr" and node.target == "linear.weight":
weight_attr = node
elif node.op == "get_attr" and node.target == "linear.bias":
bias_attr = node
elif node.op == "call_function":
self.assertEqual(node.target, acc_ops.linear)
self.assertEqual(node.kwargs["input"], ph)
self.assertEqual(node.kwargs["weight"], weight_attr)
self.assertEqual(node.kwargs["bias"], bias_attr)
linear = node
elif node.op == "output":
self.assertEqual(linear, node.args[0])
else:
self.fail(f"Unexpected node: {node.format_node()}")
self.assertTrue(torch.equal(m(test_input), traced(test_input)))
def test_quantized_linear(self):
class TestModule(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.quantized.Linear(3, 5)
def forward(self, a: torch.Tensor) -> torch.Tensor:
return self.linear(a)
m = TestModule()
input = torch.quantize_per_tensor(
torch.randn(2, 3), scale=0.01, zero_point=3, dtype=torch.quint8
)
traced = acc_tracer.trace(m, [input])
ph = weight_attr = bias_attr = linear = None
for node in traced.graph.nodes:
if node.op == "placeholder":
self.assertEqual(str(node.target), "a")
ph = node
elif node.op == "get_attr" and node.target == "linear_weight":
weight_attr = node
elif node.op == "get_attr" and node.target == "linear_bias":
bias_attr = node
elif node.op == "call_function":
self.assertEqual(node.target, acc_ops.quantized_linear)
self.assertEqual(node.kwargs["input"], ph)
self.assertEqual(node.kwargs["weight"], weight_attr)
self.assertEqual(node.kwargs["bias"], bias_attr)
linear = node
elif node.op == "output":
self.assertEqual(linear, node.args[0])
else:
self.fail(f"Unexpected node: {node.format_node()}")
self.assertTrue(torch.equal(m(input), traced(input)))
@parameterized.expand(
[
param("remove_exceptions_false", remove_exceptions=False),
param("remove_exceptions_true", remove_exceptions=True),
]
)
def test_batch_norm(self, _, remove_exceptions):
"""
        Test that a batch norm is traced as expected, i.e. to the functional level
        and with kwarg normalization. When remove_exceptions is False, we also
        expect to see a ConditionalExceptionWrapper in the graph, which the AST
        rewriter converted from `if x: raise y`.
"""
class TestModule(nn.Module):
def __init__(self):
super().__init__()
self.bn = torch.nn.BatchNorm2d(2)
def forward(self, a: torch.Tensor) -> torch.Tensor:
return self.bn(a)
m = TestModule()
input = torch.randn(2, 2, 1, 1)
        # Note: remove_exceptions is parameterized; when it is False, we check
        # below that the exception wrapper was found in the graph.
traced = acc_tracer.trace(
m,
[input],
remove_exceptions=remove_exceptions,
)
ph = exception_wrapper = weight = bias = mean = var = bn = None
for node in traced.graph.nodes:
if node.op == "placeholder":
self.assertEqual(str(node.target), "a")
ph = node
elif node.op == "get_attr" and node.target == "bn.weight":
weight = node
elif node.op == "get_attr" and node.target == "bn.bias":
bias = node
elif node.op == "get_attr" and node.target == "bn.running_mean":
mean = node
elif node.op == "get_attr" and node.target == "bn.running_var":
var = node
elif node.op == "call_function" and node.target == acc_ops.batch_norm:
# Note: Normalization called from acc_tracer means we use
# all kwargs.
self.assertEqual(node.kwargs["input"], ph)
self.assertEqual(node.kwargs["weight"], weight)
self.assertEqual(node.kwargs["bias"], bias)
self.assertEqual(node.kwargs["running_mean"], mean)
self.assertEqual(node.kwargs["running_var"], var)
bn = node
elif (
node.op == "call_module"
and node.target == "bn._conditional_exception_wrapper_ValueError"
):
exception_wrapper = node
elif node.op == "output":
self.assertEqual(bn, node.args[0])
self.assertTrue(remove_exceptions or exception_wrapper is not None)
self.assertTrue(torch.equal(m(input), traced(input)))
def test_remove_asserts(self):
"""
Test that a Module with asserts has the asserts automatically removed, as
well as calls to a class method that should be dead.
"""
class TestModule(nn.Module):
def __init__(self):
super().__init__()
def _test_method(self, a):
return a
def forward(self, a: torch.Tensor) -> torch.Tensor:
assert torch.equal(self._test_method(a), a)
return a
m = TestModule()
input = torch.randn(10)
traced = acc_tracer.trace(m, [input], ast_rewriter_allow_list={TestModule})
# Check we have no call_functions. If remove asserts didn't work
# correctly we would see a call to torch._assert, _test_method, and
# torch.equal.
for node in traced.graph.nodes:
self.assertFalse(node.op == "call_function")
self.assertTrue(torch.equal(m(input), traced(input)))
def test_sequential(self):
"""
Test that the tracer works for torch.nn.Sequential.
"""
class TestModule(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(nn.Sigmoid(), nn.ReLU())
def forward(self, a: torch.Tensor) -> torch.Tensor:
return self.model(a)
m = TestModule()
input = torch.randn(10)
traced = acc_tracer.trace(m, [input])
for node in traced.graph.nodes:
if node.op == "call_function":
is_sigmoid = node.target == acc_ops.sigmoid
is_relu = node.target == acc_ops.relu
self.assertTrue(is_sigmoid or is_relu)
else:
self.assertTrue(node.op == "placeholder" or node.op == "output")
self.assertTrue(torch.equal(m(input), traced(input)))
def test_unsqueeze(self):
"""
Test that torch.unsqueeze is traced correctly.
"""
self._make_acc_op_function_test(
acc_ops.unsqueeze,
torch.unsqueeze,
validate_same_kwargs=False,
dim=1,
)
def test_stack(self):
"""
Test that torch.stack is traced correctly.
"""
class TestModule(torch.nn.Module):
def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
return torch.stack((a, b), dim=1)
a, b = torch.randn(4, 5, 6), torch.randn(4, 5, 6)
mod = TestModule()
traced = acc_tracer.trace(mod, [a, b])
self.assertTrue(torch.equal(mod(a, b), traced(a, b)))
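        # torch.stack is decomposed into an unsqueeze of each input at `dim`
        # followed by a single cat; the loop below verifies that chain.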
ph_a = ph_b = unsqueeze_a = unsqueeze_b = cat_node = None
for node in traced.graph.nodes:
if node.op == "placeholder":
if str(node.target) == "a":
ph_a = node
else:
self.assertTrue(str(node.target) == "b")
ph_b = node
elif node.op == "call_function":
if node.target == acc_ops.unsqueeze:
if node.kwargs["input"] is ph_a:
unsqueeze_a = node
else:
self.assertEqual(node.kwargs["input"], ph_b)
unsqueeze_b = node
else:
self.assertEqual(node.target, acc_ops.cat)
self.assertEqual(node.kwargs["tensors"], [unsqueeze_a, unsqueeze_b])
cat_node = node
elif node.op == "output":
self.assertEqual(cat_node, node.args[0])
else:
self.fail(f"Unexpected node: {node.format_node()}")
def test_no_raise(self):
"""
        Test that we can trace `if x: raise y(msg)` when the raise isn't executed.
"""
class TestModule(nn.Module):
def __init__(self):
super().__init__()
def forward(self, a, b):
if torch.equal(a, b):
raise AssertionError("a equaled b!")
return a
m = TestModule()
in_a, in_b = torch.randn(5), torch.randn(5)
traced = acc_tracer.trace(
m,
[in_a, in_b],
remove_exceptions=False,
use_acc_normalization=False,
ast_rewriter_allow_list={TestModule},
)
# Verify the structure of the graph, including the existence of the
# exception_wrapper.
ph_a = exception_wrapper = None
for node in traced.graph.nodes:
if node.op == "placeholder":
if str(node.target) == "a":
ph_a = node
else:
self.assertTrue(str(node.target) == "b")
elif node.op == "call_module":
self.assertEqual(
node.target, "_conditional_exception_wrapper_AssertionError"
)
exception_wrapper = node
elif node.op == "output":
self.assertEqual(ph_a, node.args[0])
self.assertTrue(exception_wrapper is not None)
self.assertTrue(torch.equal(m(in_a, in_b), traced(in_a, in_b)))
def test_yes_raise(self):
"""
Test that we can trace `if x: raise y(msg)` when the raise is executed.
"""
err_str = "a equaled b!"
class TestModule(nn.Module):
def __init__(self):
super().__init__()
self.err_str = err_str
def forward(self, a, b):
if torch.equal(a, b):
raise RuntimeError(self.err_str)
return a
m = TestModule()
# Note: We must use different inputs here in order for shape_prop to work, as
# otherwise the exception is thrown (as expected/checked below).
in_a, in_b = torch.randn(5), torch.randn(5)
traced = acc_tracer.trace(
m,
[in_a, in_b],
remove_exceptions=False,
ast_rewriter_allow_list={TestModule},
)
# Verify the structure of the graph, including the existence of the
# exception_wrapper.
ph_a = exception_wrapper = None
for node in traced.graph.nodes:
if node.op == "placeholder":
if str(node.target) == "a":
ph_a = node
else:
self.assertTrue(str(node.target) == "b")
elif node.op == "call_module":
self.assertEqual(
node.target, "_conditional_exception_wrapper_RuntimeError"
)
exception_wrapper = node
elif node.op == "output":
self.assertEqual(ph_a, node.args[0])
self.assertTrue(exception_wrapper is not None)
def test(mod):
try:
# Note: Use the same input here to ensure the exception is thrown.
mod(in_a, in_a)
self.fail("Shouldn't get here because exception should be thrown.")
except RuntimeError as e:
self.assertEqual(err_str, str(e))
test(m)
test(traced)
def test_remove_raise(self):
"""
Test that we can trace `if x: raise y(msg)` and then remove the exception_wrapper.
"""
class TestModule(nn.Module):
def __init__(self):
super().__init__()
def forward(self, a, b):
if torch.equal(a, b):
raise AssertionError("a equaled b!")
return a
m = TestModule()
in_a, in_b = torch.randn(5), torch.randn(5)
traced = acc_tracer.trace(
m,
[in_a, in_b],
remove_exceptions=True,
ast_rewriter_allow_list={TestModule},
)
        # Verify the structure of the graph, including the absence of the
        # exception_wrapper.
ph_a = None
for node in traced.graph.nodes:
if node.op == "placeholder":
if str(node.target) == "a":
ph_a = node
else:
self.assertTrue(str(node.target) == "b")
elif node.op == "output":
self.assertEqual(ph_a, node.args[0])
else:
# Should not encounter any call_modules, e.g. to the
# exception_wrapper.
self.assertFalse(node.op == "call_module")
# Note: Using input in_a twice for the tracer version, which would
# trigger the raise if it was still there.
self.assertTrue(torch.equal(m(in_a, in_b), traced(in_a, in_a)))
def test_raise_no_message(self):
"""
Test that we can trace `if x: raise y` when `y` has no message.
"""
class TestModule(nn.Module):
def __init__(self):
super().__init__()
def forward(self, a, b):
if torch.equal(a, b):
raise AssertionError
return a
m = TestModule()
in_a, in_b = torch.randn(5), torch.randn(5)
traced = acc_tracer.trace(
m,
[in_a, in_b],
remove_exceptions=False,
use_acc_normalization=False,
ast_rewriter_allow_list={TestModule},
)
# Verify the structure of the graph, including the existence of the
# exception_wrapper.
ph_a = exception_wrapper = None
for node in traced.graph.nodes:
if node.op == "placeholder":
if str(node.target) == "a":
ph_a = node
else:
self.assertTrue(str(node.target) == "b")
elif node.op == "call_module":
self.assertEqual(
node.target, "_conditional_exception_wrapper_AssertionError"
)
exception_wrapper = node
elif node.op == "output":
self.assertEqual(ph_a, node.args[0])
self.assertTrue(exception_wrapper is not None)
self.assertTrue(torch.equal(m(in_a, in_b), traced(in_a, in_b)))
def test_quantized_add(self):
"""
        Test that quantized_add and acc_ops.quantize_per_tensor are traced as
        expected, and verify that their acc_out_tys are set correctly.
"""
class TestModule(nn.Module):
def __init__(self):
super().__init__()
self.q_input = torch.nn.quantized.Quantize(
scale=1.0 / 128, zero_point=5, dtype=torch.quint8
)
self.q_other = torch.nn.quantized.Quantize(
scale=1.0 / 128, zero_point=10, dtype=torch.quint8
)
def forward(self, input: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
return torch.ops.quantized.add(
self.q_input(input),
self.q_other(other),
scale=0.05,
zero_point=1,
)
m = TestModule()
input, other = torch.randn(2, 3, 4), torch.randn(2, 3, 4)
traced = acc_tracer.trace(m, [input, other])
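        # Each Quantize module becomes an acc_ops.quantize_per_tensor node
        # whose scale/zero_point are recorded in the node's `acc_out_ty`
        # kwarg; quantized.add records its output qparams the same way.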
input_ph = other_ph = q_input = q_other = q_add = None
for node in traced.graph.nodes:
if node.op == "placeholder":
if str(node.target) == "input":
input_ph = node
else:
self.assertTrue(str(node.target) == "other")
other_ph = node
elif (
node.op == "call_function"
and node.target == acc_ops.quantize_per_tensor
):
qparams = {
"scale": 1.0 / 128,
"zero_point": 5,
}
expected_md = acc_utils.build_raw_tensor_meta(
dtype=torch.quint8,
qparams=qparams,
)
if node.kwargs["input"] == input_ph:
q_input = node
else:
self.assertTrue(node.kwargs["input"] == other_ph)
q_other = node
qparams_copy = qparams.copy()
qparams_copy["zero_point"] = 10
expected_md = expected_md._replace(qparams=qparams_copy)
self.assertEqual(node.kwargs["acc_out_ty"], expected_md)
elif node.op == "call_function" and node.target == acc_ops.quantized_add:
self.assertEqual(node.kwargs["input"], q_input)
self.assertEqual(node.kwargs["other"], q_other)
qparams = {
"scale": 0.05,
"zero_point": 1,
}
expected_md = acc_utils.build_raw_tensor_meta(qparams=qparams)
self.assertEqual(node.kwargs["acc_out_ty"], expected_md)
q_add = node
elif node.op == "output":
self.assertEqual(q_add, node.args[0])
else:
self.fail(f"Unexpected node: {node.format_node()}")
self.assertTrue(torch.equal(m(input, other), traced(input, other)))
def test_quantized_mul(self):
"""
        Test that quantized_mul and acc_ops.quantize_per_tensor are traced as
        expected, and verify that their acc_out_tys are set correctly.
"""
class TestModule(nn.Module):
def __init__(self):
super().__init__()
self.q_input = torch.nn.quantized.Quantize(
scale=1.0 / 128, zero_point=5, dtype=torch.quint8
)
self.q_other = torch.nn.quantized.Quantize(
scale=1.0 / 128, zero_point=10, dtype=torch.quint8
)
def forward(self, input: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
return torch.ops.quantized.mul(
self.q_input(input),
self.q_other(other),
scale=0.05,
zero_point=1,
)
m = TestModule()
input, other = torch.randn(2, 3, 4), torch.randn(2, 3, 4)
traced = acc_tracer.trace(m, [input, other])
        input_ph = other_ph = q_input = q_other = q_mul = None
for node in traced.graph.nodes:
if node.op == "placeholder":
if str(node.target) == "input":
input_ph = node
else:
self.assertTrue(str(node.target) == "other")
other_ph = node
elif (
node.op == "call_function"
and node.target == acc_ops.quantize_per_tensor
):
qparams = {
"scale": 1.0 / 128,
"zero_point": 5,
}
expected_md = acc_utils.build_raw_tensor_meta(
dtype=torch.quint8,
qparams=qparams,
)
if node.kwargs["input"] == input_ph:
q_input = node
else:
self.assertTrue(node.kwargs["input"] == other_ph)
q_other = node
qparams_copy = qparams.copy()
qparams_copy["zero_point"] = 10
expected_md = expected_md._replace(qparams=qparams_copy)
self.assertEqual(node.kwargs["acc_out_ty"], expected_md)
elif node.op == "call_function" and node.target == acc_ops.quantized_mul:
self.assertEqual(node.kwargs["input"], q_input)
self.assertEqual(node.kwargs["other"], q_other)
qparams = {
"scale": 0.05,
"zero_point": 1,
}
expected_md = acc_utils.build_raw_tensor_meta(qparams=qparams)
self.assertEqual(node.kwargs["acc_out_ty"], expected_md)
                q_mul = node
elif node.op == "output":
                self.assertEqual(q_mul, node.args[0])
else:
self.fail(f"Unexpected node: {node.format_node()}")
self.assertTrue(torch.equal(m(input, other), traced(input, other)))
def test_cat(self):
"""
Test that torch.cat is traced correctly.
"""
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
return torch.cat([a, a, b], 0)
m = TestModule()
a, b = torch.randn(2, 2), torch.randn(2, 2)
traced = acc_tracer.trace(m, (a, b))
ph_a = ph_b = cat = None
for node in traced.graph.nodes:
if node.op == "placeholder":
if str(node.target) == "a":
ph_a = node
else:
self.assertTrue(str(node.target) == "b")
ph_b = node
elif node.op == "call_function":
self.assertEqual(node.target, acc_ops.cat)
self.assertEqual(node.kwargs["tensors"][0], ph_a)
self.assertEqual(node.kwargs["tensors"][1], ph_a)
self.assertEqual(node.kwargs["tensors"][2], ph_b)
self.assertEqual(node.kwargs["dim"], 0)
cat = node
elif node.op == "output":
self.assertEqual(cat, node.args[0])
else:
self.fail(f"Unexpected node: {node.format_node()}")
self.assertTrue(torch.equal(m(a, b), traced(a, b)))
def test_square(self):
"""
Test that torch.square is traced correctly.
"""
self._make_acc_op_function_test(acc_ops.mul, torch.square)
def test_reshape(self):
"""
Test that torch.reshape is traced correctly.
"""
self._make_acc_op_function_test(acc_ops.reshape, torch.reshape, (1, -1))
        # varargs form: x.reshape(1, -1)
self._make_acc_op_function_test(acc_ops.reshape, lambda x: x.reshape(1, -1))
        # tuple form: x.reshape((1, -1))
self._make_acc_op_function_test(acc_ops.reshape, lambda x: x.reshape((1, -1)))
def test_transpose(self):
"""
Test that torch.transpose is traced correctly.
"""
self._make_acc_op_function_test(
acc_ops.permute, lambda x: torch.transpose(x, 1, 0)
)
def test_permute(self):
"""
Test that torch.permute is traced correctly.
"""
def torch_permute(a, *dim):
return a.permute(*dim)
self._make_acc_op_function_test(acc_ops.permute, torch_permute, 1, 0)
def test_min_full_reduce(self):
"""
        Test that torch.min with no dim (full reduce) is traced correctly.
"""
self._make_acc_op_function_test(acc_ops.min_full_reduce, torch.min)
def test_matmul(self):
"""
Test that torch.matmul is traced correctly.
"""
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
return torch.matmul(a, b)
m = TestModule()
a, b = torch.randn(2, 2), torch.randn(2, 2)
traced = acc_tracer.trace(m, [a, b])
ph_a = ph_b = matmul = None
for node in traced.graph.nodes:
if node.op == "placeholder":
if str(node.target) == "a":
ph_a = node
else:
self.assertTrue(str(node.target) == "b")
ph_b = node
elif node.op == "call_function":
self.assertEqual(node.target, acc_ops.matmul)
self.assertEqual(node.kwargs["input"], ph_a)
self.assertEqual(node.kwargs["other"], ph_b)
matmul = node
elif node.op == "output":
self.assertEqual(matmul, node.args[0])
else:
self.fail(f"Unexpected node: {node.format_node()}")
self.assertTrue(torch.equal(m(a, b), traced(a, b)))
def test_bmm(self):
self._make_acc_op_function_test(
acc_ops.matmul, lambda x: torch.bmm(x, x), input_shape=(2, 4, 4)
)
def test_tile(self):
        self._make_acc_op_function_test(
acc_ops.tile, lambda x: torch.tile(x, (2, 1, 2)), input_shape=(1, 2)
)
def test_dropout(self):
self._make_acc_op_function_test(
None,
lambda x: nn.functional.dropout(x, training=False),
input_shape=(1, 2, 3),
)
def test_hardsigmoid(self):
self._make_acc_op_function_test(
acc_ops.hardsigmoid,
lambda x: nn.functional.hardsigmoid(x),
input_shape=(3, 4, 5),
)
def test_hardtanh(self):
self._make_acc_op_function_test(
acc_ops.hardtanh,
lambda x: nn.functional.hardtanh(x),
input_shape=(3, 4, 5),
)
def test_hardswish(self):
class TestModule(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
y = nn.functional.hardswish(x)
return y
m = TestModule()
x = torch.randn(3, 4, 5)
        traced = acc_tracer.trace(m, [x])
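        # hardswish decomposes into hardsigmoid(x) * x, which the graph
        # checks below verify.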
ph_x = hardsigmoid_y = res_y = None
for node in traced.graph.nodes:
if node.op == "placeholder":
ph_x = node
elif node.op == "call_function" and node.target == acc_ops.hardsigmoid:
hardsigmoid_y = node
self.assertEqual(node.kwargs["input"], ph_x)
elif node.op == "call_function" and node.target == acc_ops.mul:
res_y = node
self.assertEqual(node.kwargs["input"], hardsigmoid_y)
self.assertEqual(node.kwargs["other"], ph_x)
elif node.op == "output":
self.assertEqual(node.args[0], res_y)
else:
self.fail(f"Unexpected node: {node.format_node()}")
ref = m(x)
res = traced(x)
torch.testing.assert_allclose(ref, res)
def test_add_with_alpha(self):
"""
Test that normalization works for torch add with alpha, which requires special
normalization handling.
"""
class TestModule(nn.Module):
def __init__(self):
super().__init__()
def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
a1 = torch.add(a, b)
a2 = torch.add(a, b, alpha=1.0)
a3 = torch.add(a, b, alpha=0.5)
return a1, a2, a3
m = TestModule()
input_a = torch.randn(2, 3)
input_b = torch.randn(2, 3)
traced = acc_tracer.trace(m, [input_a, input_b])
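        # alpha == 1 (default or explicit) normalizes to a plain add; a
        # non-unit alpha is split into mul(other, alpha) feeding the add.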
ph_a = ph_b = add_1 = add_2 = add_3 = mul = None
for node in traced.graph.nodes:
if node.op == "placeholder":
if str(node.target) == "a":
ph_a = node
elif str(node.target) == "b":
ph_b = node
else:
self.fail(f"Unexpected placeholder {node.target}.")
elif node.op == "call_function" and node.target == acc_ops.mul:
mul = node
self.assertEqual(node.kwargs["input"], ph_b)
self.assertEqual(node.kwargs["other"], 0.5)
elif node.op == "call_function" and node.target == acc_ops.add:
if add_1 is None:
add_1 = node
self.assertEqual(node.kwargs["input"], ph_a)
self.assertEqual(node.kwargs["other"], ph_b)
elif add_2 is None:
add_2 = node
self.assertEqual(node.kwargs["input"], ph_a)
self.assertEqual(node.kwargs["other"], ph_b)
elif add_3 is None:
add_3 = node
self.assertEqual(node.kwargs["input"], ph_a)
self.assertEqual(node.kwargs["other"], mul)
else:
self.fail(f"Unexpected add: {node.format_node()}")
elif node.op == "output":
self.assertEqual(node.args[0][0], add_1)
self.assertEqual(node.args[0][1], add_2)
self.assertEqual(node.args[0][2], add_3)
else:
self.fail(f"Unexpected node: {node.format_node()}")
ref = m(input_a, input_b)
res = traced(input_a, input_b)
self.assertTrue(torch.equal(ref[0], res[0]))
self.assertTrue(torch.equal(ref[1], res[1]))
self.assertTrue(torch.equal(ref[2], res[2]))
def test_leaf_module_list(self):
"""
        Test that leaf_module_list is respected: the listed module stays a
        call_module rather than being traced through.
"""
class LeafModule(nn.Module):
def forward(self, x):
return x
class TestModule(nn.Module):
def __init__(self):
super().__init__()
self.mod = LeafModule()
def forward(self, x):
return self.mod(x)
x = torch.randn(1, 1)
mod = TestModule()
acc_mod = acc_tracer.trace(
mod,
[x],
leaf_module_list={LeafModule},
)
ph = leaf_module = None
for node in acc_mod.graph.nodes:
if node.op == "placeholder":
ph = node
elif node.op == "call_module":
leaf_module = node
self.assertEqual(leaf_module.target, "mod")
self.assertEqual(leaf_module.args[0], ph)
elif node.op == "output":
self.assertEqual(node.args[0], leaf_module)
else:
self.fail(f"Unexpected node: {node.format_node()}")
self.assertTrue(torch.equal(mod(x), acc_mod(x)))
def test_sign(self):
self._make_acc_op_function_test(acc_ops.sign, torch.sign)
def test_relu(self):
self._make_acc_op_function_test(acc_ops.relu, torch.relu)
def test_sigmoid(self):
self._make_acc_op_function_test(acc_ops.sigmoid, torch.sigmoid)
def test_sin(self):
self._make_acc_op_function_test(acc_ops.sin, torch.sin)
def test_cos(self):
self._make_acc_op_function_test(acc_ops.cos, torch.cos)
def test_tan(self):
self._make_acc_op_function_test(acc_ops.tan, torch.tan)
def test_sinh(self):
self._make_acc_op_function_test(acc_ops.sinh, torch.sinh)
def test_cosh(self):
self._make_acc_op_function_test(acc_ops.cosh, torch.cosh)
def test_tanh(self):
self._make_acc_op_function_test(acc_ops.tanh, torch.tanh)
def test_asin(self):
self._make_acc_op_function_test(acc_ops.asin, torch.asin)
def test_acos(self):
self._make_acc_op_function_test(acc_ops.acos, torch.acos)
def test_atan(self):
self._make_acc_op_function_test(acc_ops.atan, torch.atan)
def test_exp(self):
self._make_acc_op_function_test(acc_ops.exp, torch.exp)
def test_log(self):
self._make_acc_op_function_test(acc_ops.log, torch.log)
def test_sqrt(self):
self._make_acc_op_function_test(acc_ops.sqrt, torch.sqrt)
def test_reciprocal(self):
self._make_acc_op_function_test(acc_ops.reciprocal, torch.reciprocal)
def test_abs(self):
self._make_acc_op_function_test(acc_ops.abs, torch.abs)
def test_neg(self):
self._make_acc_op_function_test(acc_ops.neg, torch.neg)
def test_floor(self):
self._make_acc_op_function_test(acc_ops.floor, torch.floor)
def test_ceil(self):
self._make_acc_op_function_test(acc_ops.ceil, torch.ceil)
def test_softmax(self):
self._make_acc_op_function_test(acc_ops.softmax, torch.nn.functional.softmax)
def test_tensor_squeeze(self):
self._make_acc_op_function_test(acc_ops.squeeze, lambda x: x.squeeze())
def test_torch_squeeze(self):
self._make_acc_op_function_test(acc_ops.squeeze, lambda x: torch.squeeze(x))
def test_operator_mul(self):
self._make_acc_op_function_test(acc_ops.mul, lambda x: x * 7)
def test_torch_mul(self):
self._make_acc_op_function_test(acc_ops.mul, lambda x: torch.mul(x, 7))
def test_view(self):
"""
Test that Tensor.view is traced correctly.
"""
self._make_acc_op_function_test(acc_ops.reshape, lambda x: x.view(1, -1))
def test_narrow(self):
"""
Test that torch.narrow is traced correctly.
"""
        self._make_acc_op_function_test(
acc_ops.slice_tensor,
torch.narrow,
validate_same_kwargs=False,
dim=1,
start=1,
length=2,
)
def test_pow(self):
self._make_acc_op_function_test(acc_ops.pow, torch.pow, exponent=2)
def test_size(self):
class TestModule(nn.Module):
def __init__(self):
super().__init__()
def forward(self, a):
idx = a.size(1)
return a.shape[idx]
m = TestModule()
a = torch.randn(2, 1, 4)
traced = acc_tracer.trace(m, [a])
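        # a.size(1) and a.shape[idx] each trace to an acc_ops.size plus an
        # acc_ops.getitem; the first getitem's result feeds the second's idx.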
ph_a = size_1 = size_2 = getitem_1 = getitem_2 = None
for node in traced.graph.nodes:
if node.op == "placeholder":
self.assertTrue(node.target == "a")
ph_a = node
elif node.op == "call_function" and node.target == acc_ops.size:
if size_1:
size_2 = node
self.assertTrue(size_2.kwargs["input"] is ph_a)
else:
size_1 = node
self.assertTrue(size_1.kwargs["input"] is ph_a)
elif node.op == "call_function" and node.target == acc_ops.getitem:
if getitem_1:
getitem_2 = node
self.assertTrue(getitem_2.kwargs["idx"] == getitem_1)
self.assertTrue(getitem_2.kwargs["input"] == size_2)
else:
getitem_1 = node
self.assertTrue(getitem_1.kwargs["idx"] == 1)
self.assertTrue(getitem_1.kwargs["input"] == size_1)
elif node.op == "output":
self.assertEqual(node.args[0], getitem_2)
else:
self.fail(f"Unexpected node: {node.format_node()}")
ref = m(a)
res = traced(a)
self.assertEqual(ref, res)
def test_flatten(self):
"""
Test that torch.flatten is traced correctly.
"""
self._make_acc_op_function_test(
acc_ops.flatten, torch.flatten, start_dim=1, end_dim=1
)
self._make_acc_op_function_test(acc_ops.flatten, lambda x: x.flatten())
def test_topk_multi_output(self):
"""
        Test that torch.topk's multiple outputs are traced correctly.
"""
class TestModule(nn.Module):
def __init__(self):
super().__init__()
def forward(self, a: torch.Tensor) -> torch.Tensor:
return torch.topk(a, 3)[1]
m = TestModule()
input_a = torch.randn(10)
traced = acc_tracer.trace(m, [input_a])
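        # Indexing topk's multi-output result is traced as an acc_ops.getitem
        # on the topk node.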
ph_a = topk = getitem = None
for node in traced.graph.nodes:
if node.op == "placeholder" and str(node.target) == "a":
ph_a = node
elif node.op == "call_function" and node.target == acc_ops.topk:
topk = node
self.assertEqual(node.kwargs["input"], ph_a)
self.assertEqual(node.kwargs["k"], 3)
elif node.op == "call_function" and node.target == acc_ops.getitem:
getitem = node
self.assertEqual(node.kwargs["input"], topk)
self.assertEqual(node.kwargs["idx"], 1)
elif node.op == "output":
self.assertEqual(node.args[0], getitem)
else:
self.fail(f"Unexpected node: {node.format_node()}")
self.assertTrue(torch.equal(m(input_a), traced(input_a)))
def test_addmm_with_alpha_beta(self):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(
self, input: torch.Tensor, a: torch.Tensor, b: torch.Tensor
) -> torch.Tensor:
return torch.addmm(input, a, b, alpha=1.2, beta=1.1)
m = TestModule()
input, a, b = torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2)
traced = acc_tracer.trace(m, [input, a, b])
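        # addmm with non-default alpha/beta decomposes into
        # alpha * (a @ b) + beta * input: one matmul, two muls, and an add.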
ph_in = ph_a = ph_b = mm = add = mm_mul = add_mul = None
for node in traced.graph.nodes:
if node.op == "placeholder":
if str(node.target) == "a":
ph_a = node
elif str(node.target) == "b":
ph_b = node
else:
self.assertTrue(str(node.target) == "input")
ph_in = node
elif node.op == "call_function":
if node.target == acc_ops.matmul:
self.assertEqual(node.kwargs["input"], ph_a)
self.assertEqual(node.kwargs["other"], ph_b)
mm = node
elif node.target == acc_ops.add:
self.assertEqual(node.kwargs["input"], mm_mul)
self.assertEqual(node.kwargs["other"], add_mul)
add = node
elif mm_mul:
self.assertEqual(node.kwargs["input"], ph_in)
self.assertEqual(node.kwargs["other"], 1.1)
add_mul = node
else:
self.assertEqual(node.kwargs["input"], mm)
self.assertEqual(node.kwargs["other"], 1.2)
mm_mul = node
elif node.op == "output":
self.assertEqual(add, node.args[0])
else:
self.fail(f"Unexpected node: {node.format_node()}")
torch.testing.assert_allclose(m(input, a, b), traced(input, a, b))
def test_log1p(self):
class TestModule(torch.nn.Module):
def forward(self, input: torch.Tensor) -> torch.Tensor:
return torch.log1p(input)
m = TestModule().eval()
input = torch.tensor([[1.2, 0.3, -0.4]])
traced = acc_tracer.trace(m, [input])
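        # log1p decomposes into add(input, 1) followed by log.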
ph_in = add = log = None
for node in traced.graph.nodes:
if node.op == "placeholder":
self.assertTrue(str(node.target) == "input")
ph_in = node
elif node.op == "call_function":
if node.target == acc_ops.add:
self.assertEqual(node.kwargs["input"], ph_in)
self.assertEqual(node.kwargs["other"], 1)
add = node
else:
self.assertEqual(node.target, acc_ops.log)
self.assertEqual(node.kwargs["input"], add)
log = node
elif node.op == "output":
self.assertEqual(log, node.args[0])
else:
self.fail(f"Unexpected node: {node.format_node()}")
torch.testing.assert_allclose(m(input), traced(input))
def test_addmm(self):
class TestModule(torch.nn.Module):
def forward(
self, input: torch.Tensor, a: torch.Tensor, b: torch.Tensor
) -> torch.Tensor:
return torch.addmm(input, a, b)
m = TestModule()
input, a, b = torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2)
traced = acc_tracer.trace(m, [input, a, b])
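        # With default alpha/beta, addmm decomposes into just matmul + add.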
ph_in = ph_a = ph_b = mm = add = None
for node in traced.graph.nodes:
if node.op == "placeholder":
if str(node.target) == "a":
ph_a = node
elif str(node.target) == "b":
ph_b = node
else:
self.assertTrue(str(node.target) == "input")
ph_in = node
elif node.op == "call_function":
if node.target == acc_ops.matmul:
self.assertEqual(node.kwargs["input"], ph_a)
self.assertEqual(node.kwargs["other"], ph_b)
mm = node
else:
self.assertEqual(node.target, acc_ops.add)
self.assertEqual(node.kwargs["input"], mm)
self.assertEqual(node.kwargs["other"], ph_in)
add = node
elif node.op == "output":
self.assertEqual(add, node.args[0])
else:
self.fail(f"Unexpected node: {node.format_node()}")
self.assertTrue(torch.equal(m(input, a, b), traced(input, a, b)))
def test_gelu(self):
        self._make_acc_op_function_test(acc_ops.gelu, torch.nn.functional.gelu)
@parameterized.expand(
[
(1, True),
(1, False),
(None, False),
]
)
def test_argmin(self, dim, keepdim):
class TestModule(torch.nn.Module):
def __init__(self, dim, keepdim):
super().__init__()
self.dim = dim
self.keepdim = keepdim
def forward(self, input: torch.Tensor) -> torch.Tensor:
return torch.argmin(input, dim=self.dim, keepdim=self.keepdim)
m = TestModule(dim, keepdim)
input = torch.randn(2, 2)
traced = acc_tracer.trace(m, [input])
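        # argmin is decomposed by the tracer: flatten when dim is None, then
        # topk (selecting the minimum), getitem for the indices, and a
        # squeeze when keepdim is False.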
ph_in = flatten = topk = getitem = squeeze = None
for node in traced.graph.nodes:
if node.op == "placeholder":
self.assertTrue(str(node.target) == "input")
ph_in = node
elif node.op == "call_function":
if node.target == acc_ops.flatten:
self.assertEqual(node.kwargs["input"], ph_in)
flatten = node
elif node.target == acc_ops.topk:
self.assertEqual(
node.kwargs["input"], flatten if flatten else ph_in
)
topk = node
elif node.target == acc_ops.getitem:
self.assertEqual(node.kwargs["input"], topk)
getitem = node
elif node.target == acc_ops.squeeze:
self.assertEqual(node.kwargs["input"], getitem)
squeeze = node
elif node.op == "output":
self.assertEqual(squeeze if squeeze else getitem, node.args[0])
else:
self.fail(f"Unexpected node: {node.format_node()}")
if dim is None:
self.assertTrue(flatten is not None)
if not keepdim:
self.assertTrue(squeeze is not None)
self.assertTrue(torch.equal(m(input), traced(input)))
def test_t(self):
"""
Test Tensor.t() is traced correctly.
"""
self._make_acc_op_function_test(acc_ops.permute, lambda x: x.t())
self._make_acc_op_function_test(
acc_ops.permute, lambda x: x.t(), input_shape=(3,)
)
def test_split_size(self):
self._make_acc_op_function_test(
acc_ops.split,
torch.split,
validate_same_kwargs=False,
split_size_or_sections=2,
dim=1,
)
def test_split_sections(self):
class TestModule(torch.nn.Module):
def forward(self, input: torch.Tensor) -> torch.Tensor:
return torch.split(input, [2, 5, 3], 1)
m = TestModule()
input = torch.randn(1, 10)
traced = acc_tracer.trace(m, [input])
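        # Splitting with explicit section sizes lowers to one slice_tensor
        # per section, gathered by a tuple_construct.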
ph_in = slice_node_0 = slice_node_1 = slice_node_2 = None
tuple_construct_node = None
for node in traced.graph.nodes:
if node.op == "placeholder":
self.assertTrue(str(node.target) == "input")
ph_in = node
elif node.op == "call_function":
if node.target == acc_ops.slice_tensor:
self.assertEqual(node.kwargs["input"], ph_in)
if slice_node_0:
if slice_node_1:
slice_node_2 = node
else:
slice_node_1 = node
else:
slice_node_0 = node
else:
self.assertEqual(node.target, acc_ops.tuple_construct)
self.assertEqual(
node.kwargs["tensors"],
(slice_node_0, slice_node_1, slice_node_2),
)
tuple_construct_node = node
elif node.op == "output":
self.assertEqual(tuple_construct_node, node.args[0])
else:
self.fail(f"Unexpected node: {node.format_node()}")
ref_output = m(input)
output = traced(input)
for i, j in zip(ref_output, output):
self.assertTrue(torch.equal(i, j))
def test_list_input(self):
"""
Test that list inputs are traced correctly.
"""
class TestModule(nn.Module):
def __init__(self):
super().__init__()
def forward(self, a: List[torch.Tensor]) -> torch.Tensor:
return a[0] + a[1]
m = TestModule()
input = [torch.randn(2, 3), torch.randn(2, 3)]
traced = acc_tracer.trace(m, [input])
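        # A list input yields a single placeholder whose tensor_meta is a
        # list; element accesses become acc_ops.getitem nodes.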
ph = getitem_0 = getitem_1 = add = None
for node in traced.graph.nodes:
if node.op == "placeholder":
self.assertEqual(str(node.target), "a")
ph = node
elif node.op == "call_function" and node.target == acc_ops.getitem:
self.assertTrue(node.kwargs["idx"] == 0 or node.kwargs["idx"] == 1)
if node.kwargs["idx"] == 0:
getitem_0 = node
else:
getitem_1 = node
elif node.op == "call_function":
self.assertEqual(node.target, acc_ops.add)
self.assertEqual(node.kwargs["input"], getitem_0)
self.assertEqual(node.kwargs["other"], getitem_1)
add = node
elif node.op == "output":
self.assertEqual(add, node.args[0])
else:
self.fail(f"Unexpected node: {node.format_node()}")
# Check the tensor metadatas are correct given the input is a list.
self.assertTrue(isinstance(ph.meta["tensor_meta"], list))
self.assertEqual(len(ph.meta["tensor_meta"]), 2)
self.assertEqual(getitem_0.meta["tensor_meta"], ph.meta["tensor_meta"][0])
self.assertEqual(getitem_1.meta["tensor_meta"], ph.meta["tensor_meta"][1])
self.assertTrue(torch.equal(m(input), traced(input)))
def test_mobilenet_v3(self):
"""
Test that we can trace mobilenet v3 small and run/compare against the untraced version.
"""
m = torchvision.models.mobilenet_v3_small(pretrained=True)
self._make_model_unit_test(m, enable_allclose=True)
def test_mobilenet_v2(self):
"""
        Test that we can trace mobilenet v2 and run/compare against the untraced version.
"""
m = torchvision.models.mobilenet_v2(pretrained=True)
self._make_model_unit_test(m)
def test_vgg16(self):
"""
Test that we can trace vgg16 and run/compare against the untraced version.
"""
m = torchvision.models.vgg16(pretrained=True)
self._make_model_unit_test(m)
def test_resnet18(self):
"""
Test that we can trace resnet18 and run/compare against the untraced version.
"""
m = torchvision.models.resnet18(pretrained=True)
self._make_model_unit_test(m)
def test_resnext50_32x4d(self):
"""
Test that we can trace resnext and run/compare against the untraced version.
"""
m = torchvision.models.resnext50_32x4d(pretrained=True)
self._make_model_unit_test(m)
def test_cumsum(self):
self._make_acc_op_function_test(acc_ops.cumsum, torch.cumsum, dim=1)
self._make_acc_op_function_test(
acc_ops.cumsum, torch.cumsum, dim=1, dtype=torch.float
)
def test_chunk(self):
self._make_acc_op_function_test(acc_ops.chunk, torch.chunk, chunks=2, dim=0)
def test_all_acc_ops_registered(self):
self.assertEqual(
acc_normalizer._acc_ops,
{
acc_ops.linear,
acc_ops.max_pool2d,
acc_ops.flatten,
acc_ops.adaptive_avg_pool2d,
acc_ops.avg_pool2d,
acc_ops.add,
acc_ops.min_full_reduce,
acc_ops.min_dim_reduce,
acc_ops.minimum,
acc_ops.cat,
acc_ops.softmax,
acc_ops.sign,
acc_ops.permute,
acc_ops.matmul,
acc_ops.quantize_per_tensor,
acc_ops.quantize_per_channel,
acc_ops.quantized_add,
acc_ops.quantized_mul,
acc_ops.dequantize,
acc_ops.sub,
acc_ops.mul,
acc_ops.div,
acc_ops.pow,
acc_ops.relu,
acc_ops.tuple_construct,
acc_ops.unsqueeze,
acc_ops.sigmoid,
acc_ops.sum,
acc_ops.max_full_reduce,
acc_ops.max_dim_reduce,
acc_ops.maximum,
acc_ops.sinh,
acc_ops.cosh,
acc_ops.tanh,
acc_ops.asin,
acc_ops.acos,
acc_ops.atan,
acc_ops.exp,
acc_ops.log,
acc_ops.sqrt,
acc_ops.reciprocal,
acc_ops.abs,
acc_ops.neg,
acc_ops.floor,
acc_ops.ceil,
acc_ops.size,
acc_ops.split,
acc_ops.conv2d,
acc_ops.batch_norm,
acc_ops.embedding_bag,
acc_ops.embedding_bag_byte_rowwise_offsets,
acc_ops.embedding_bag_4bit_rowwise_offsets,
acc_ops.contiguous,
acc_ops.sin,
acc_ops.cos,
acc_ops.tan,
acc_ops.topk,
acc_ops.getitem,
acc_ops.squeeze,
acc_ops.tile,
acc_ops.reshape,
acc_ops.quantized_linear,
acc_ops.quantized_conv2d,
acc_ops.quantized_batch_norm2d,
acc_ops.to_dtype,
acc_ops.clamp,
acc_ops.layer_norm,
acc_ops.linalg_norm,
acc_ops.slice_tensor,
acc_ops.hardsigmoid,
acc_ops.hardtanh,
acc_ops.gelu,
acc_ops.cumsum,
acc_ops.chunk,
acc_ops.rescale_quantize_per_tensor,
acc_ops.rescale_quantize_per_channel,
},
)