[functorch] hessian API
diff --git a/functorch/functorch/_src/eager_transforms.py b/functorch/functorch/_src/eager_transforms.py
index f563581..57b589c 100644
--- a/functorch/functorch/_src/eager_transforms.py
+++ b/functorch/functorch/_src/eager_transforms.py
@@ -633,6 +633,9 @@
         return tree_unflatten(jac_outs_ins, spec)
     return wrapper_fn
 
+def hessian(f, argnums=0):
+    return jacfwd(jacrev(f, argnums), argnums)
+
 def grad_and_value(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable:
     """
     Returns a function to compute a tuple of the gradient and primal, or
diff --git a/functorch/functorch/experimental/__init__.py b/functorch/functorch/experimental/__init__.py
index 7f0bf36..289b2df 100644
--- a/functorch/functorch/experimental/__init__.py
+++ b/functorch/functorch/experimental/__init__.py
@@ -1,2 +1,2 @@
 # PyTorch forward-mode is not mature yet
-from .._src.eager_transforms import jvp, jacfwd
+from .._src.eager_transforms import jvp, jacfwd, hessian
diff --git a/functorch/test/test_eager_transforms.py b/functorch/test/test_eager_transforms.py
index 6a7175f..eaf19fe 100644
--- a/functorch/test/test_eager_transforms.py
+++ b/functorch/test/test_eager_transforms.py
@@ -28,7 +28,7 @@
     functional_init, functional_init_with_buffers,
 )
 from functorch.experimental import (
-    jvp, jacfwd,
+    jvp, jacfwd, hessian,
 )
 
 # NB: numpy is a testing dependency!
@@ -811,7 +811,7 @@
         assert torch.allclose(y, expected)
 
     @FIXME_jacrev_only
-    def test_hessian_simple(self, device, jacapi):
+    def test_nested_jac_simple(self, device, jacapi):
         def foo(x):
             return x.sin().sum()
 
@@ -1006,6 +1006,14 @@
         with self.assertRaisesRegex(RuntimeError, "must be int"):
             z = jacapi(torch.multiply, argnums=(1, 0.0))(x, x)
 
+    @unittest.expectedFailure
+    def test_hessian_simple(self, device):
+        def f(x):
+            return x.sin()
+
+        x = torch.randn(3, device=device)
+        result = hessian(f)(x)
+
 class TestJvp(TestCase):
     def test_inplace_on_captures(self, device):
         x = torch.tensor([1., 2., 3.], device=device)