# blob: 462595c571ad8cc6ee921f435771f526450411f1
import os
import sys
import unittest
from common_utils import GRAPH_EXECUTOR, ProfilingMode, enable_profiling_mode
import torch
# Make the helper files in test/ (such as jit_utils, imported below) importable
# by putting the grandparent directory of this file's resolved path on sys.path.
pytorch_test_dir = os.path.realpath(__file__)
pytorch_test_dir = os.path.dirname(pytorch_test_dir)  # directory holding this file
pytorch_test_dir = os.path.dirname(pytorch_test_dir)  # its parent, presumably test/
sys.path.append(pytorch_test_dir)
from jit_utils import JitTestCase, disable_autodiff_subgraph_inlining
if __name__ == '__main__':
    # This module only defines test cases; as the message says, it is meant to
    # be run through test/test_jit.py rather than executed on its own.
    _usage_message = ("This test file is not meant to be run directly, use:\n\n"
                      "\tpython test/test_jit.py TESTNAME\n\n"
                      "instead.")
    raise RuntimeError(_usage_message)
# NB: torch.jit.script, when used as a function, resolves variable names in the
# current scope. This helper therefore has to live at module level rather than
# inside TestAutodiffSubgraphSlicing: the tests call torch.jit.script on
# functions from a different scope than the one they are defined in.
@torch.jit.ignore
def pyfn(a, b):
    # Kept as plain Python via @torch.jit.ignore. The tests use it as the
    # operation that autodiff does not support (the "x" in their diagrams).
    product = a * b
    return product
@unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "Simple Executor doesn't support gradients")
class TestAutodiffSubgraphSlicing(JitTestCase):
    """Tests for how scripted functions are sliced into differentiable subgraphs.

    Each test scripts a small function and runs it once inside
    ``disable_autodiff_subgraph_inlining()`` and ``enable_profiling_mode()``.
    It then counts the nodes and ``prim::DifferentiableGraph`` groups in the
    resulting graph.

    In the ASCII diagrams below:
    - ``o`` is an op that autodiff supports.
    - ``x`` is one it does not support (a call to ``pyfn``).
    - Arrows point from a value to the op that uses it.
    """

    # TODO: It is better if we can test directly on graphs instead of the current
    # end-to-end fashion.

    def _perform_ad_subgraph_slicing(self, fn, *input_sizes):
        """Script ``fn``, run it once, and return the graph used for those inputs.

        Args:
            fn: Function to compile with ``torch.jit.script``.
            *input_sizes: One size per positional argument of ``fn``. Each size is
                passed to ``torch.randn`` to build an input that requires grad.

        Returns:
            The graph reported by ``graph_for`` for the generated inputs.
        """
        with disable_autodiff_subgraph_inlining(), enable_profiling_mode():
            ge = torch.jit.script(fn)
            inputs = [torch.randn(size, requires_grad=True) for size in input_sizes]
            # NOTE(review): presumably profile_and_replay=True records a profile
            # and re-runs, so that graph_for() below sees the optimized graph;
            # confirm.
            ge(*inputs, profile_and_replay=True)
            return ge.graph_for(*inputs)

    def assertGraphSize(self, graph, size):
        """Assert that ``graph.nodes()`` yields exactly ``size`` non-bailout nodes.

        ``prim::BailOut`` and ``prim::BailoutTemplate`` nodes are not counted,
        so the expected sizes do not depend on which bailouts are present.
        """
        ignored_kinds = ("prim::BailOut", "prim::BailoutTemplate")
        nodes = [n for n in graph.nodes() if n.kind() not in ignored_kinds]
        self.assertEqual(len(nodes), size)

    def test_chunk_constant_script_ad(self):
        @torch.jit.script
        def func(x):
            x1, x2 = torch.chunk(x, 2)
            return (x1, x2)

        inp = torch.rand(6, 10).requires_grad_()
        with disable_autodiff_subgraph_inlining(), enable_profiling_mode():
            # The result is unused: the call only needs to run once before
            # inspecting graph_for().
            func(inp, profile_and_replay=True)
            self.assertAutodiffNode(func.graph_for(inp), True, ['prim::ConstantChunk'], [])

    def test_simple_merge(self):
        # o --> o
        def fn(x, y, z):
            a = x * y
            b = a * z
            return b

        graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1)

        self.assertGraphSize(graph, 1)
        self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1)

    def test_simple_no_merge(self):
        # o: autodiff supported. x: not autodiff supported.
        # o --> x
        def fn(x, y, z):
            a = x * y
            b = pyfn(a, z)
            return b

        graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1)

        self.assertGraphSize(graph, 2)
        self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1)

    def test_does_not_merge_unrelated(self):
        # o o
        def fn(w, x, y, z):
            a = x * y
            b = w * z
            return a, b

        graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1)

        self.assertGraphSize(graph, 3)
        self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 2)

    def test_merges_without_cycles(self):
        # o --> o --> o
        # |           ^
        #  \_________/
        def fn(w, x, y):
            a = w * x
            b = a * y
            c = a * b
            return c

        graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1)

        self.assertGraphSize(graph, 1)
        self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1)

    def test_merges_dense(self):
        #   o      o
        #   |\    /|
        #   | \  / |
        #   |  /\  |
        #   vv    vv
        #   o      o
        def fn(x, y):
            a, b = x.chunk(2)
            c, d = y.chunk(2)
            return a + c, b + d

        graph = self._perform_ad_subgraph_slicing(fn, 2, 2)

        self.assertGraphSize(graph, 2)
        self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1)

    def test_does_not_create_cycles(self):
        # o --> x --> o
        # |           ^
        #  \_________/
        def fn(w, x, y):
            a = w * x
            b = pyfn(a, y)
            c = a * b
            return c

        graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1)

        self.assertGraphSize(graph, 3)
        self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 2)

    def test_merges_up(self):
        # o --> x     o
        # |           ^
        #  \_________/
        def fn(w, x, y, z):
            a = w * x
            b = pyfn(a, y)
            c = a * z
            return b, c

        graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1)

        self.assertGraphSize(graph, 3)
        self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1)

    def test_merges_down(self):
        # o     x --> o
        # |           ^
        #  \_________/
        def fn(v, w, x, y):
            a = v * w
            b = pyfn(x, y)
            c = b * a
            return a, c

        graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1)

        # GuardElimination can't get rid of a prim::BailOut on ^pyfn
        # which makes us create two `prim::DifferentiableGraph`s
        # instead of just one
        is_profiling = GRAPH_EXECUTOR == ProfilingMode.PROFILING
        self.assertGraphSize(graph, 4 if is_profiling else 3)
        self.assertGraphContainsExactly(
            graph, 'prim::DifferentiableGraph', 2 if is_profiling else 1)

    def test_respects_lexical_scoping(self):
        def fn(x, k):
            y = x * 1.1
            if bool(k):
                k = k + y
            z = y * k
            return z, k

        graph = self._perform_ad_subgraph_slicing(fn, 1, 1)

        # We should not have combined the two multiplications into
        # the same group; they should each be a separate DiffGraph
        self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 2)