add unsupported section (#31329)
Summary:
Add a section for unsupported ops, and modules. Automatically generate the properties and attributes that aren't bound, and for ops that have semantic mismatches set up tests so the docs stay up to date.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31329
Differential Revision: D19164472
Pulled By: eellison
fbshipit-source-id: 46290bb8a64d9de928cfb1eda5ff4558c3799c88
diff --git a/docs/source/jit.rst b/docs/source/jit.rst
index ba77c1a..74d4414 100644
--- a/docs/source/jit.rst
+++ b/docs/source/jit.rst
@@ -1418,15 +1418,10 @@
TorchScript supports a subset of the tensor and neural network
functions that PyTorch provides. Most methods on Tensor as well as functions in
-the ``torch`` namespace, all functions in ``torch.nn.functional`` and all
-modules from ``torch.nn`` are supported in TorchScript, excluding those in the
-table below. For unsupported modules, we suggest using :meth:`torch.jit.trace`.
+the ``torch`` namespace, all functions in ``torch.nn.functional`` and
+most modules from ``torch.nn`` are supported in TorchScript.
-Unsupported ``torch.nn`` Modules::
-
- torch.nn.modules.adaptive.AdaptiveLogSoftmaxWithLoss
- torch.nn.modules.normalization.CrossMapLRN2d
- torch.nn.modules.rnn.RNN
+See :ref:`jit_unsupported` for a list of unsupported PyTorch functions and modules.
Python Functions and Modules
----------------------------
diff --git a/docs/source/jit_unsupported.rst b/docs/source/jit_unsupported.rst
new file mode 100644
index 0000000..94bd548
--- /dev/null
+++ b/docs/source/jit_unsupported.rst
@@ -0,0 +1,103 @@
+.. _jit_unsupported:
+
+TorchScript Unsupported Pytorch Constructs
+============================================
+
+Torch and Tensor Unsupported Attributes
+------------------------------------------
+
+
+TorchScript supports most methods defined on ``torch`` and ``torch.Tensor``, but we do not have full coverage.
+Here are specific known ops and categories of ops which have diverging behavior between
+Python and TorchScript. If you encounter something else that is not supported please
+file a GitHub issue. Deprecated ops are not listed below.
+
+
+
+.. automodule:: torch.jit.unsupported_tensor_ops
+
+
+Functions Not Correctly Bound on Torch
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The following functions will fail if used in TorchScript, either because they
+are not bound on `torch` or because Python expects a different schema than
+TorchScript.
+
+ * :func:`torch.cdist`
+ * :func:`torch.lu`
+ * :func:`torch.lu_unpack`
+ * :func:`torch.norm`
+ * :func:`torch.tensordot`
+ * :func:`torch.unique`
+ * :func:`torch.unique_consecutive`
+ * :func:`torch.nn.init.calculate_gain`
+ * :func:`torch.nn.init.eye_`
+ * :func:`torch.nn.init.dirac_`
+ * :func:`torch.nn.init.kaiming_normal_`
+ * :func:`torch.nn.init.orthogonal_`
+ * :func:`torch.nn.init.sparse`
+
+
+Ops With Divergent Schemas Between Torch & Python
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The following categories of ops have divergent schemas:
+
+Functions which construct tensors from non-tensor inputs do not support the `requires_grad`
+argument, except for `torch.tensor`. This covers the following ops:
+
+ * :func:`torch.norm`
+ * :func:`torch.bartlett_window`
+ * :func:`torch.blackman_window`
+ * :func:`torch.empty`
+ * :func:`torch.empty_like`
+ * :func:`torch.empty_strided`
+ * :func:`torch.eye`
+ * :func:`torch.full`
+ * :func:`torch.full_like`
+ * :func:`torch.hamming_window`
+ * :func:`torch.hann_window`
+ * :func:`torch.linspace`
+ * :func:`torch.logspace`
+ * :func:`torch.normal`
+ * :func:`torch.ones`
+ * :func:`torch.rand`
+ * :func:`torch.rand_like`
+ * :func:`torch.randint_like`
+ * :func:`torch.randn`
+ * :func:`torch.randn_like`
+ * :func:`torch.randperm`
+ * :func:`torch.tril_indices`
+ * :func:`torch.triu_indices`
+ * :func:`torch.zeros`
+ * :func:`torch.zeros_like`
+
+The following functions require `dtype`, `layout`, `device` as parameters in TorchScript,
+but these parameters are optional in Python.
+
+ * :func:`torch.empty_like`
+ * :func:`torch.full_like`
+ * :func:`torch.ones_like`
+ * :func:`torch.rand_like`
+ * :func:`torch.randint`
+ * :func:`torch.randn_like`
+ * :func:`torch.zeros_like`
+ * :func:`torch.sparse_coo_tensor`
+ * :meth:`~torch.Tensor.to`
+
+
+PyTorch Unsupported Modules and Classes
+------------------------------------------
+
+TorchScript cannot currently compile a number of other commonly used PyTorch
+constructs. Below are listed the modules that TorchScript does not support, and
+an incomplete list of PyTorch classes that are not supported. For unsupported modules
+we suggest using :meth:`torch.jit.trace`.
+
+ * :class:`torch.nn.RNN`
+ * :class:`torch.nn.AdaptiveLogSoftmaxWithLoss`
+ * :class:`torch.autograd.Function`
+ * :class:`torch.autograd.no_grad`
+ * :class:`torch.autograd.enable_grad`
+ * :class:`torch._C.Generator`
diff --git a/test/jit/unsupported_ops.py b/test/jit/unsupported_ops.py
new file mode 100644
index 0000000..76d7170
--- /dev/null
+++ b/test/jit/unsupported_ops.py
@@ -0,0 +1,128 @@
+import os
+import sys
+from textwrap import dedent
+
+import torch
+
+# Make the helper files in test/ importable
+pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+sys.path.append(pytorch_test_dir)
+from jit_utils import JitTestCase, execWrapper
+
+if __name__ == '__main__':
+ raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
+ "\tpython test/test_jit.py TESTNAME\n\n"
+ "instead.")
+
+# NOTE: FIXING FAILING TESTS
+# If you are seeing a test failure from this file, congrats, you improved
+# parity between JIT and Python API. Before you fix the test, you must also update
+# the corresponding section in documentation that states the unsupported behavior.
+# see: `jit_unsupported.rst`
+
+class TestUnsupportedOps(JitTestCase):
+ def test_factory_ops_requires_grad_fail(self):
+ # Keyword argument {name} unknown is a JIT-only error message,
+ # so these functions are succeeding in eager and failing in JIT
+
+ # Complete issue and set of ops is https://github.com/pytorch/pytorch/issues/30761
+ # only testing some because they should be fixed all at once
+ def ones():
+ return torch.ones([2], requires_grad=True)
+
+ def randn():
+ return torch.randn([2], requires_grad=True)
+
+ def zeros():
+ return torch.zeros([2], requires_grad=True)
+
+ for func in [ones, randn, zeros]:
+ func()
+ with self.assertRaisesRegex(Exception, "Keyword argument requires_grad unknown"):
+ torch.jit.script(func)
+
+ def test_tensor_options_behavior_mismatch(self):
+ # Any schema declaration which contains a non-optional (ScalarType dtype, Layout layout, Device device)
+ # tuple is implicitly made to be optional for pytorch eager code. This makes the schema incorrect for JIT / C++ api.
+
+ # Complete issue and set of ops is https://github.com/pytorch/pytorch/issues/30763
+ # only testing one here because they should be fixed all at once
+
+ with self.assertRaisesRegex(Exception, "Argument layout not provided."):
+ def foo(x):
+ return torch.ones_like(x, dtype=torch.double)
+ foo(torch.tensor([2.]))
+ print(torch.jit.script(foo).graph)
+
+ def test_ops_bound_in_functional(self):
+ ops_bound_in_functional = "lu_unpack", "unique", "lu"
+
+ tensor = torch.tensor([2])
+ funcs_template = dedent('''
+ def func():
+ return torch.{op}()
+ ''')
+ for op in ops_bound_in_functional:
+ funcs_str = funcs_template.format(op=op)
+ scope = {}
+ execWrapper(funcs_str, globals(), scope)
+ f = scope['func']
+ with self.assertRaisesRegex(Exception, "Unknown builtin op"):
+ cu = torch.jit.CompilationUnit(funcs_str)
+
+ def fn():
+ a = torch.tensor([[0.9041, 0.0196], [-0.3108, -2.4423], [-0.4821, 1.059]])
+ b = torch.tensor([[-2.1763, -0.4713], [-0.6986, 1.3702]])
+ return torch.cdist(a, b, compute_mode="use_mm_for_euclid_dist")
+ fn()
+ with self.assertRaisesRegex(Exception, "Expected a value of type"):
+ torch.jit.script(fn)
+
+ def norm():
+ c = torch.tensor([[1, 2, 3], [-1, 1, 4]], dtype=torch.float)
+ return torch.norm(c, p="fro")
+
+ norm()
+ with self.assertRaisesRegex(Exception, "Expected a value of type"):
+ torch.jit.script(norm)
+
+ def unique_consec():
+ x = torch.tensor([1])
+ return torch.unique_consecutive(x, return_inverse=False, return_counts=True, dim=0)
+
+ self.assertNotEqual(unique_consec(), torch.jit.script(unique_consec)())
+
+ def tensordot():
+ a = torch.arange(60.).reshape(3, 4, 5)
+ b = torch.arange(24.).reshape(4, 3, 2)
+ torch.tensordot(a, b, dims=([1, 0], [0, 1]))
+
+ tensordot()
+ with self.assertRaisesRegex(Exception, "Argument dims_self"):
+ torch.jit.script(tensordot)
+
+
+ def test_init_ops(self):
+ def calculate_gain():
+ return torch.nn.init.calculate_gain('leaky_relu', 0.2)
+
+ def eye_():
+ return torch.nn.init.eye_(torch.zeros([2, 2]))
+
+ def dirac_():
+ return torch.nn.init.dirac_(torch.empty(3, 16, 5, 5))
+
+ def kaiming_uniform_():
+ return torch.nn.init.kaiming_normal_(torch.empty(3, 5))
+
+ def orthogonal_():
+ return torch.nn.init.orthogonal_(torch.empty(3, 5))
+
+ def sparse():
+ return torch.nn.init.sparse(torch.empty(3, 5), sparsity=.1)
+
+ for func in [calculate_gain, eye_, dirac_, kaiming_uniform_, orthogonal_, sparse]:
+ # doesn't error in eager
+ func()
+ with self.assertRaisesRegex(Exception, ""):
+ torch.jit.script(func)
diff --git a/test/test_jit.py b/test/test_jit.py
index a3243e5..c853bc5 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -21,6 +21,7 @@
from jit.test_export_modes import TestExportModes # noqa: F401
from jit.test_class_type import TestClassType # noqa: F401
from jit.test_builtins import TestBuiltins # noqa: F401
+from jit.unsupported_ops import TestUnsupportedOps # noqa: F401
# Torch
from torch import Tensor
diff --git a/torch/jit/unsupported_tensor_ops.py b/torch/jit/unsupported_tensor_ops.py
new file mode 100644
index 0000000..7a9c4dd
--- /dev/null
+++ b/torch/jit/unsupported_tensor_ops.py
@@ -0,0 +1,56 @@
+import torch.jit
+from textwrap import dedent
+from torch._six import PY2
+
+def execWrapper(code, glob, loc):
+ if PY2:
+ exec(code) in glob, loc
+ else:
+ exec(code, glob, loc)
+
+def _gen_unsupported_methods_properties():
+ tensor_attrs = set(filter(lambda x: x[0] != "_", dir(torch.Tensor)))
+ tensor = torch.tensor([2])
+ funcs_template = dedent('''
+ def func(x):
+ return x.{op}()
+ ''')
+
+ deprecated_apis = set(["volatile", "resize", "reinforce", "new", "name", "map2_", "has_names", "grad_fn", "resize_as"])
+ tensor_attrs = tensor_attrs - deprecated_apis
+
+ properties = []
+ methods = []
+ sorted_tensor_attrs = sorted(list(tensor_attrs), key=lambda x: x.lower())
+ for attr in sorted_tensor_attrs:
+ funcs_str = funcs_template.format(op=attr)
+ scope = {}
+ execWrapper(funcs_str, globals(), scope)
+ try:
+ cu = torch.jit.CompilationUnit(funcs_str)
+ except Exception as e:
+ if "nonexistent attribute" not in repr(e):
+ continue
+ attr_repr = repr(getattr(tensor, attr))
+ if "bound method" in attr_repr or "built-in method" in attr_repr:
+ methods.append(attr)
+ else:
+ properties.append(attr)
+
+ methods = map(lambda x: "\t* :meth:`~torch.Tensor." + x + r"`", methods)
+ properties = map(lambda x: "\t* :attr:`~torch.Tensor." + x + r"`", properties)
+ return "\n".join(methods), "\n".join(properties)
+
+
+def _list_unsupported_tensor_ops():
+ header = """\n\n
+Unsupported Tensor Methods
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ """
+ methods, properties = _gen_unsupported_methods_properties()
+ return header + "\n" + methods + """
+Unsupported Tensor Properties
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ """ + "\n" + properties
+
+__doc__ = _list_unsupported_tensor_ops()