# Owner(s): ["module: dynamo"]
import math

import torch

import torch._dynamo.test_case
import torch._dynamo.testing


class CustomFunc1(torch.autograd.Function):
    @staticmethod
    def forward(ctx, foo):
        return foo + foo

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class CustomFunc3(torch.autograd.Function):
    # Test that there is a graph break in the forward function
    @staticmethod
    def forward(ctx, foo):
        result = foo + foo
        torch._dynamo.graph_break()
        result = result + foo
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        (result,) = ctx.saved_tensors
        return grad_output * math.sqrt(result.numel())


class Module1(torch.nn.Module):
    def forward(self, foo):
        return CustomFunc1().apply(foo)


class Module2(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fn = CustomFunc1.apply

    def forward(self, foo):
        return self.fn(foo)


class Module3(torch.nn.Module):
    def forward(self, foo):
        return CustomFunc1().apply(foo)


class Module4(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fn = CustomFunc1.apply

    def forward(self, foo):
        return self.fn(foo)


class Module5(torch.nn.Module):
    def forward(self, foo):
        return CustomFunc3().apply(foo)


class Module6(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fn = CustomFunc3.apply

    def forward(self, foo):
        return self.fn(foo)


class LinearFunction(torch.autograd.Function):
    # Note that forward, setup_context, and backward are @staticmethods
    @staticmethod
    def forward(input, weight, bias):
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    # inputs is a Tuple of all of the inputs passed to forward.
    # output is the output of the forward().
    def setup_context(ctx, inputs, output):
        input, weight, bias = inputs
        ctx.save_for_backward(input, weight, bias)

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)
        return grad_input, grad_weight, grad_bias


class ModuleLinear(torch.nn.Module):
    def forward(self, input, weight, bias=None):
        return LinearFunction.apply(input, weight, bias)
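

# MaterializingGradFunction calls ctx.set_materialize_grads(False), so its
# backward may receive None (rather than zero-filled tensors) for outputs that
# did not participate in the loss.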
class MaterializingGradFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.set_materialize_grads(False)
        return x.clone(), x.clone()

    @staticmethod
    def backward(ctx, grad_out1, grad_out2):
        return grad_out1, grad_out2


class MaterializingGradModule(torch.nn.Module):
    def forward(self, x):
        return MaterializingGradFunction.apply(x)


class CustomFuncBwdPrintGraphBreak(torch.autograd.Function):
    @staticmethod
    def forward(ctx, foo):
        return torch.add(foo, foo)

    @staticmethod
    def backward(ctx, grad_output):
        print("graph break!")
        return grad_output


class CustomFuncBwdPrintModule(torch.nn.Module):
    def forward(self, x):
        return CustomFuncBwdPrintGraphBreak.apply(x)
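

# CustomFuncStrideBwd returns grad_output.stride() from backward;
# test_stride_in_bwd below expects Dynamo to raise Unsupported for this
# when compiling with nopython=True.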
class CustomFuncStrideBwd(torch.autograd.Function):
    @staticmethod
    def forward(ctx, foo):
        return torch.add(foo, foo)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.stride()


class CustomFuncStrideModule(torch.nn.Module):
    def forward(self, x):
        return CustomFuncStrideBwd.apply(x)


class CustomFuncSaveForBwd(torch.autograd.Function):
    @staticmethod
    def forward(ctx, foo):
        result = foo + foo
        result = result + foo
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        (result,) = ctx.saved_tensors
        return grad_output * math.sqrt(result.numel())


class SaveForBwdModule(torch.nn.Module):
    def forward(self, foo):
        return CustomFuncSaveForBwd().apply(foo)
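

# ContextSaveAndMark and ContextMarkAndSave call ctx.save_for_backward and
# ctx.mark_non_differentiable in both orders inside torch.no_grad(); the
# corresponding tests compare eager and compiled outputs for each variant.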
class ContextSaveAndMark(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        with torch.no_grad():
            ctx.save_for_backward(x)
            ctx.mark_non_differentiable(x)
            return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class ContextMarkAndSave(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        with torch.no_grad():
            ctx.mark_non_differentiable(x)
            ctx.save_for_backward(x)
            return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class ModuleWithGradFunc(torch.nn.Module):
    def __init__(self, func):
        super(ModuleWithGradFunc, self).__init__()
        self.f = func.apply

    def forward(self, x):
        return self.f(x)


class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
    # Sound behaviors, tested for working capture
    def test_autograd_function_equivalence(self):
        for grad in [True, False]:
            for i in range(1, 5):
                torch._dynamo.reset()
                model = globals()[f"Module{i}"]()
                opt_model = torch._dynamo.optimize("eager")(model)
                self.assertTrue(
                    torch.allclose(
                        opt_model(torch.ones(2, 3, requires_grad=grad)),
                        torch.tensor([2.0], requires_grad=grad),
                    )
                )

    def test_autograd_function_has_graph_break(self):
        for grad in [True, False]:
            x = torch.randn(10, requires_grad=grad)
            for model in [Module5(), Module6()]:
                torch._dynamo.reset()
                cnts = torch._dynamo.testing.CompileCounter()
                opt_model = torch._dynamo.optimize(cnts)(model)
                for _ in range(3):
                    ref = model(x)
                    res = opt_model(x)
                    self.assertTrue(torch.allclose(ref, res))
                self.assertEqual(cnts.frame_count, 2)

    def test_linear_setup_context(self):
        model = ModuleLinear()
        opt_model = torch._dynamo.optimize("eager")(model)
        input = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
        weight = torch.randn(3, 2, dtype=torch.double, requires_grad=True)
        optim_result = opt_model(input, weight)
        eager_result = model(input, weight)
        self.assertEqual(optim_result, eager_result)

    def test_materialize_grad(self):
        model = MaterializingGradModule()
        opt_model = torch._dynamo.optimize("eager")(model)
        x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
        optim_result = opt_model(x)
        eager_result = model(x)
        self.assertEqual(optim_result, eager_result)

    def test_print_in_bwd(self):
        model = CustomFuncBwdPrintModule()
        opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
        x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
        with self.assertRaisesRegex(
            torch._dynamo.exc.Unsupported, ".*BuiltinVariable\\(print\\).*"
        ):
            opt_model(x)

    def test_stride_in_bwd(self):
        model = CustomFuncStrideModule()
        opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
        x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
        with self.assertRaisesRegex(
            torch._dynamo.exc.Unsupported,
"Illegal getattr invocation stride in strict mod",
        ):
            opt_model(x)

    def test_save_for_bwd(self):
        model = SaveForBwdModule()
        opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
        x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
        opt_model(x)

    def test_classmethod(self):
        class Shake(torch.autograd.Function):
            @classmethod
            def forward(cls, ctx, foo):
                return foo + foo

            @classmethod
            def backward(cls, ctx, grad_output):
                return grad_output

        def f(x):
            return Shake.apply(x)

        x = torch.randn(4, 4, 4, 4, requires_grad=True)
        opt_m = torch.compile(backend="eager")(f)
        opt_m(x)

    def test_function_context_save_and_mark(self):
        mod = ModuleWithGradFunc(ContextSaveAndMark)
        args, kwargs = ([torch.rand([1])], {})
        before = mod(*args, **kwargs)

        torch._dynamo.reset()
        compiled_model = torch._dynamo.optimize("eager")(mod)
        after = compiled_model(*args, **kwargs)
        self.assertEqual(before, after)

    def test_function_context_mark_and_save(self):
        mod = ModuleWithGradFunc(ContextMarkAndSave)
        args, kwargs = ([torch.rand([1])], {})
        before = mod(*args, **kwargs)

        torch._dynamo.reset()
        compiled_model = torch._dynamo.optimize("eager")(mod)
        after = compiled_model(*args, **kwargs)
        self.assertEqual(before, after)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()