[functorch] hessian API
diff --git a/functorch/functorch/_src/eager_transforms.py b/functorch/functorch/_src/eager_transforms.py
index f563581..57b589c 100644
--- a/functorch/functorch/_src/eager_transforms.py
+++ b/functorch/functorch/_src/eager_transforms.py
@@ -633,6 +633,9 @@
return tree_unflatten(jac_outs_ins, spec)
return wrapper_fn
+def hessian(f, argnums=0):
+ return jacfwd(jacrev(f, argnums), argnums)
+
def grad_and_value(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable:
"""
Returns a function to compute a tuple of the gradient and primal, or
diff --git a/functorch/functorch/experimental/__init__.py b/functorch/functorch/experimental/__init__.py
index 7f0bf36..289b2df 100644
--- a/functorch/functorch/experimental/__init__.py
+++ b/functorch/functorch/experimental/__init__.py
@@ -1,2 +1,2 @@
# PyTorch forward-mode is not mature yet
-from .._src.eager_transforms import jvp, jacfwd
+from .._src.eager_transforms import jvp, jacfwd, hessian
diff --git a/functorch/test/test_eager_transforms.py b/functorch/test/test_eager_transforms.py
index 6a7175f..eaf19fe 100644
--- a/functorch/test/test_eager_transforms.py
+++ b/functorch/test/test_eager_transforms.py
@@ -28,7 +28,7 @@
functional_init, functional_init_with_buffers,
)
from functorch.experimental import (
- jvp, jacfwd,
+ jvp, jacfwd, hessian,
)
# NB: numpy is a testing dependency!
@@ -811,7 +811,7 @@
assert torch.allclose(y, expected)
@FIXME_jacrev_only
- def test_hessian_simple(self, device, jacapi):
+ def test_nested_jac_simple(self, device, jacapi):
def foo(x):
return x.sin().sum()
@@ -1006,6 +1006,14 @@
with self.assertRaisesRegex(RuntimeError, "must be int"):
z = jacapi(torch.multiply, argnums=(1, 0.0))(x, x)
+ @unittest.expectedFailure
+ def test_hessian_simple(self, device):
+ def f(x):
+ return x.sin()
+
+ x = torch.randn(3, device=device)
+ result = hessian(f)(x)
+
class TestJvp(TestCase):
def test_inplace_on_captures(self, device):
x = torch.tensor([1., 2., 3.], device=device)