[functorch] Synchronize jacobians_hessians tutorial (pytorch/functorch#555)

diff --git a/functorch/notebooks/colab/jacobians_hessians_colab.ipynb b/functorch/notebooks/colab/jacobians_hessians_colab.ipynb
index 20a4fe7..17fd856 100644
--- a/functorch/notebooks/colab/jacobians_hessians_colab.ipynb
+++ b/functorch/notebooks/colab/jacobians_hessians_colab.ipynb
@@ -3,7 +3,7 @@
   "nbformat_minor": 0,
   "metadata": {
     "colab": {
-      "name": "functorch_hessians_colab.ipynb",
+      "name": "jacobians_hessians_colab.ipynb",
       "provenance": [],
       "collapsed_sections": [
         "0I5Mm2q2f5aw"
@@ -42,8 +42,8 @@
     {
       "cell_type": "markdown",
       "source": [
-        "**Getting setup** - running functorch currently requires Pytorch Nightly.  \n",
-        "Thus we'll go through a pytorch nightly install and build functorch. \n",
+        "**Getting setup** - running functorch currently requires at least PyTorch 1.11.  \n",
+        "Thus we'll go through a pytorch 1.11 install and build functorch. \n",
         "\n",
         "After that and a restart, you'll be ready to run the tutorial here on colab."
       ],
@@ -70,7 +70,7 @@
       "metadata": {
         "id": "MklsA-KRhZKC"
       },
-      "execution_count": 2,
+      "execution_count": 7,
       "outputs": []
     },
     {
@@ -94,9 +94,9 @@
           "base_uri": "https://localhost:8080/"
         },
         "id": "HxidO4dpiPGi",
-        "outputId": "8a285bcd-a791-4d19-9a71-8e53d8325eba"
+        "outputId": "d6d31c17-02cf-427b-cae8-6994c57c2320"
       },
-      "execution_count": 3,
+      "execution_count": 8,
       "outputs": [
         {
           "output_type": "stream",
@@ -130,9 +130,9 @@
           "base_uri": "https://localhost:8080/"
         },
         "id": "NIoTNykP9xI5",
-        "outputId": "f8678909-b7f0-4c37-8186-a6ce38c0f483"
+        "outputId": "5cc5a77d-9696-4cde-a7e5-3d835058afee"
       },
-      "execution_count": 4,
+      "execution_count": 9,
       "outputs": [
         {
           "output_type": "stream",
@@ -148,7 +148,7 @@
     {
       "cell_type": "markdown",
       "source": [
-        "And install the relevant nightly version.  (this defaults to 11.1 Cuda which works on most colabs). "
+        "And install the relevant version. (this defaults to 11.1 Cuda which works on most colabs). "
       ],
       "metadata": {
         "id": "n-DFUwBVkHaX"
@@ -162,45 +162,59 @@
       "metadata": {
         "id": "BH5ffJBkkRR8"
       },
-      "execution_count": 5,
+      "execution_count": 10,
       "outputs": []
     },
     {
       "cell_type": "code",
       "source": [
-        "!pip install --pre torch -f https://download.pytorch.org/whl/nightly/{cuda_version}/torch_nightly.html --upgrade"
+        "!pip install --pre torch -f https://download.pytorch.org/whl/test/{cuda_version}/torch_test.html --upgrade"
       ],
       "metadata": {
         "colab": {
-          "base_uri": "https://localhost:8080/"
+          "base_uri": "https://localhost:8080/",
+          "height": 0
         },
         "id": "Bi2oymijkav5",
-        "outputId": "d78924ae-a04e-44ce-c28f-b7d20b9b2cfc"
+        "outputId": "25de4fb2-4424-452e-9b2a-1a32d72a9ea2"
       },
-      "execution_count": 6,
+      "execution_count": 11,
       "outputs": [
         {
           "output_type": "stream",
           "name": "stdout",
           "text": [
-            "Looking in links: https://download.pytorch.org/whl/nightly/cu111/torch_nightly.html\n",
+            "Looking in links: https://download.pytorch.org/whl/test/cu111/torch_test.html\n",
             "Collecting torch\n",
-            "  Downloading https://download.pytorch.org/whl/nightly/cu111/torch-1.12.0.dev20220216%2Bcu111-cp37-cp37m-linux_x86_64.whl (1922.9 MB)\n",
-            "\u001b[K     |█████████████▉                  | 834.1 MB 60.9 MB/s eta 0:00:18tcmalloc: large alloc 1147494400 bytes == 0x559246f1a000 @  0x7fa0182ce615 0x55920e8183bc 0x55920e8f918a 0x55920e81b1cd 0x55920e90db3d 0x55920e88f458 0x55920e88a02f 0x55920e81caba 0x55920e88f2c0 0x55920e88a02f 0x55920e81caba 0x55920e88bcd4 0x55920e90e986 0x55920e88b350 0x55920e90e986 0x55920e88b350 0x55920e90e986 0x55920e88b350 0x55920e81cf19 0x55920e860a79 0x55920e81bb32 0x55920e88f1dd 0x55920e88a02f 0x55920e81caba 0x55920e88bcd4 0x55920e88a02f 0x55920e81caba 0x55920e88aeae 0x55920e81c9da 0x55920e88b108 0x55920e88a02f\n",
-            "\u001b[K     |█████████████████▋              | 1055.7 MB 1.3 MB/s eta 0:11:14tcmalloc: large alloc 1434370048 bytes == 0x55928b570000 @  0x7fa0182ce615 0x55920e8183bc 0x55920e8f918a 0x55920e81b1cd 0x55920e90db3d 0x55920e88f458 0x55920e88a02f 0x55920e81caba 0x55920e88f2c0 0x55920e88a02f 0x55920e81caba 0x55920e88bcd4 0x55920e90e986 0x55920e88b350 0x55920e90e986 0x55920e88b350 0x55920e90e986 0x55920e88b350 0x55920e81cf19 0x55920e860a79 0x55920e81bb32 0x55920e88f1dd 0x55920e88a02f 0x55920e81caba 0x55920e88bcd4 0x55920e88a02f 0x55920e81caba 0x55920e88aeae 0x55920e81c9da 0x55920e88b108 0x55920e88a02f\n",
-            "\u001b[K     |██████████████████████▎         | 1336.2 MB 1.3 MB/s eta 0:07:18tcmalloc: large alloc 1792966656 bytes == 0x5592103a2000 @  0x7fa0182ce615 0x55920e8183bc 0x55920e8f918a 0x55920e81b1cd 0x55920e90db3d 0x55920e88f458 0x55920e88a02f 0x55920e81caba 0x55920e88f2c0 0x55920e88a02f 0x55920e81caba 0x55920e88bcd4 0x55920e90e986 0x55920e88b350 0x55920e90e986 0x55920e88b350 0x55920e90e986 0x55920e88b350 0x55920e81cf19 0x55920e860a79 0x55920e81bb32 0x55920e88f1dd 0x55920e88a02f 0x55920e81caba 0x55920e88bcd4 0x55920e88a02f 0x55920e81caba 0x55920e88aeae 0x55920e81c9da 0x55920e88b108 0x55920e88a02f\n",
-            "\u001b[K     |████████████████████████████▏   | 1691.1 MB 1.3 MB/s eta 0:03:02tcmalloc: large alloc 2241208320 bytes == 0x55927b18a000 @  0x7fa0182ce615 0x55920e8183bc 0x55920e8f918a 0x55920e81b1cd 0x55920e90db3d 0x55920e88f458 0x55920e88a02f 0x55920e81caba 0x55920e88f2c0 0x55920e88a02f 0x55920e81caba 0x55920e88bcd4 0x55920e90e986 0x55920e88b350 0x55920e90e986 0x55920e88b350 0x55920e90e986 0x55920e88b350 0x55920e81cf19 0x55920e860a79 0x55920e81bb32 0x55920e88f1dd 0x55920e88a02f 0x55920e81caba 0x55920e88bcd4 0x55920e88a02f 0x55920e81caba 0x55920e88aeae 0x55920e81c9da 0x55920e88b108 0x55920e88a02f\n",
-            "\u001b[K     |████████████████████████████████| 1922.9 MB 1.2 MB/s eta 0:00:01tcmalloc: large alloc 1922924544 bytes == 0x559300aec000 @  0x7fa0182cd1e7 0x55920e84e5d7 0x55920e8183bc 0x55920e8f918a 0x55920e81b1cd 0x55920e90db3d 0x55920e88f458 0x55920e88a02f 0x55920e81caba 0x55920e88b108 0x55920e88a02f 0x55920e81caba 0x55920e88b108 0x55920e88a02f 0x55920e81caba 0x55920e88b108 0x55920e88a02f 0x55920e81caba 0x55920e88b108 0x55920e88a02f 0x55920e81caba 0x55920e88b108 0x55920e81c9da 0x55920e88b108 0x55920e88a02f 0x55920e81caba 0x55920e88bcd4 0x55920e88a02f 0x55920e81caba 0x55920e88bcd4 0x55920e88a02f\n",
-            "tcmalloc: large alloc 2403655680 bytes == 0x5593734c4000 @  0x7fa0182ce615 0x55920e8183bc 0x55920e8f918a 0x55920e81b1cd 0x55920e90db3d 0x55920e88f458 0x55920e88a02f 0x55920e81caba 0x55920e88b108 0x55920e88a02f 0x55920e81caba 0x55920e88b108 0x55920e88a02f 0x55920e81caba 0x55920e88b108 0x55920e88a02f 0x55920e81caba 0x55920e88b108 0x55920e88a02f 0x55920e81caba 0x55920e88b108 0x55920e81c9da 0x55920e88b108 0x55920e88a02f 0x55920e81caba 0x55920e88bcd4 0x55920e88a02f 0x55920e81caba 0x55920e88bcd4 0x55920e88a02f 0x55920e81d151\n",
-            "\u001b[K     |████████████████████████████████| 1922.9 MB 4.7 kB/s \n",
+            "  Downloading https://download.pytorch.org/whl/test/cu111/torch-1.11.0%2Bcu111-cp37-cp37m-linux_x86_64.whl (1922.7 MB)\n",
+            "\u001b[K     |█████████████▉                  | 834.1 MB 1.8 MB/s eta 0:10:05tcmalloc: large alloc 1147494400 bytes == 0x56452bfee000 @  0x7f6180125615 0x564529eee3bc 0x564529fcf18a 0x564529ef11cd 0x564529fe3b3d 0x564529f65458 0x564529f6002f 0x564529ef2aba 0x564529f652c0 0x564529f6002f 0x564529ef2aba 0x564529f61cd4 0x564529fe4986 0x564529f61350 0x564529fe4986 0x564529f61350 0x564529fe4986 0x564529f61350 0x564529ef2f19 0x564529f36a79 0x564529ef1b32 0x564529f651dd 0x564529f6002f 0x564529ef2aba 0x564529f61cd4 0x564529f6002f 0x564529ef2aba 0x564529f60eae 0x564529ef29da 0x564529f61108 0x564529f6002f\n",
+            "\u001b[K     |█████████████████▋              | 1055.7 MB 1.6 MB/s eta 0:09:07tcmalloc: large alloc 1434370048 bytes == 0x564570644000 @  0x7f6180125615 0x564529eee3bc 0x564529fcf18a 0x564529ef11cd 0x564529fe3b3d 0x564529f65458 0x564529f6002f 0x564529ef2aba 0x564529f652c0 0x564529f6002f 0x564529ef2aba 0x564529f61cd4 0x564529fe4986 0x564529f61350 0x564529fe4986 0x564529f61350 0x564529fe4986 0x564529f61350 0x564529ef2f19 0x564529f36a79 0x564529ef1b32 0x564529f651dd 0x564529f6002f 0x564529ef2aba 0x564529f61cd4 0x564529f6002f 0x564529ef2aba 0x564529f60eae 0x564529ef29da 0x564529f61108 0x564529f6002f\n",
+            "\u001b[K     |██████████████████████▎         | 1336.2 MB 42.7 MB/s eta 0:00:14tcmalloc: large alloc 1792966656 bytes == 0x5645c5e30000 @  0x7f6180125615 0x564529eee3bc 0x564529fcf18a 0x564529ef11cd 0x564529fe3b3d 0x564529f65458 0x564529f6002f 0x564529ef2aba 0x564529f652c0 0x564529f6002f 0x564529ef2aba 0x564529f61cd4 0x564529fe4986 0x564529f61350 0x564529fe4986 0x564529f61350 0x564529fe4986 0x564529f61350 0x564529ef2f19 0x564529f36a79 0x564529ef1b32 0x564529f651dd 0x564529f6002f 0x564529ef2aba 0x564529f61cd4 0x564529f6002f 0x564529ef2aba 0x564529f60eae 0x564529ef29da 0x564529f61108 0x564529f6002f\n",
+            "\u001b[K     |████████████████████████████▏   | 1691.1 MB 1.3 MB/s eta 0:03:02tcmalloc: large alloc 2241208320 bytes == 0x56452bfee000 @  0x7f6180125615 0x564529eee3bc 0x564529fcf18a 0x564529ef11cd 0x564529fe3b3d 0x564529f65458 0x564529f6002f 0x564529ef2aba 0x564529f652c0 0x564529f6002f 0x564529ef2aba 0x564529f61cd4 0x564529fe4986 0x564529f61350 0x564529fe4986 0x564529f61350 0x564529fe4986 0x564529f61350 0x564529ef2f19 0x564529f36a79 0x564529ef1b32 0x564529f651dd 0x564529f6002f 0x564529ef2aba 0x564529f61cd4 0x564529f6002f 0x564529ef2aba 0x564529f60eae 0x564529ef29da 0x564529f61108 0x564529f6002f\n",
+            "\u001b[K     |████████████████████████████████| 1922.7 MB 1.1 MB/s eta 0:00:01tcmalloc: large alloc 1922670592 bytes == 0x5645b1950000 @  0x7f61801241e7 0x564529f245d7 0x564529eee3bc 0x564529fcf18a 0x564529ef11cd 0x564529fe3b3d 0x564529f65458 0x564529f6002f 0x564529ef2aba 0x564529f61108 0x564529f6002f 0x564529ef2aba 0x564529f61108 0x564529f6002f 0x564529ef2aba 0x564529f61108 0x564529f6002f 0x564529ef2aba 0x564529f61108 0x564529f6002f 0x564529ef2aba 0x564529f61108 0x564529ef29da 0x564529f61108 0x564529f6002f 0x564529ef2aba 0x564529f61cd4 0x564529f6002f 0x564529ef2aba 0x564529f61cd4 0x564529f6002f\n",
+            "tcmalloc: large alloc 2403344384 bytes == 0x5646242ea000 @  0x7f6180125615 0x564529eee3bc 0x564529fcf18a 0x564529ef11cd 0x564529fe3b3d 0x564529f65458 0x564529f6002f 0x564529ef2aba 0x564529f61108 0x564529f6002f 0x564529ef2aba 0x564529f61108 0x564529f6002f 0x564529ef2aba 0x564529f61108 0x564529f6002f 0x564529ef2aba 0x564529f61108 0x564529f6002f 0x564529ef2aba 0x564529f61108 0x564529ef29da 0x564529f61108 0x564529f6002f 0x564529ef2aba 0x564529f61cd4 0x564529f6002f 0x564529ef2aba 0x564529f61cd4 0x564529f6002f 0x564529ef3151\n",
+            "\u001b[K     |████████████████████████████████| 1922.7 MB 3.0 kB/s \n",
             "\u001b[?25hRequirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch) (3.10.0.2)\n",
             "Installing collected packages: torch\n",
             "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
-            "torchvision 0.11.1+cu111 requires torch==1.10.0, but you have torch 1.12.0.dev20220216+cu111 which is incompatible.\n",
-            "torchtext 0.11.0 requires torch==1.10.0, but you have torch 1.12.0.dev20220216+cu111 which is incompatible.\n",
-            "torchaudio 0.10.0+cu111 requires torch==1.10.0, but you have torch 1.12.0.dev20220216+cu111 which is incompatible.\u001b[0m\n",
-            "Successfully installed torch-1.12.0.dev20220216+cu111\n"
+            "torchvision 0.11.1+cu111 requires torch==1.10.0, but you have torch 1.11.0+cu111 which is incompatible.\n",
+            "torchtext 0.11.0 requires torch==1.10.0, but you have torch 1.11.0+cu111 which is incompatible.\n",
+            "torchaudio 0.10.0+cu111 requires torch==1.10.0, but you have torch 1.11.0+cu111 which is incompatible.\u001b[0m\n",
+            "Successfully installed torch-1.11.0+cu111\n"
           ]
+        },
+        {
+          "output_type": "display_data",
+          "data": {
+            "application/vnd.colab-display-data+json": {
+              "pip_warning": {
+                "packages": [
+                  "torch"
+                ]
+              }
+            }
+          },
+          "metadata": {}
         }
       ]
     },
@@ -223,9 +237,9 @@
           "base_uri": "https://localhost:8080/"
         },
         "id": "VEJs4UEOkay0",
-        "outputId": "ca624c7e-dfe6-4233-a51d-b8ad5188a0c4"
+        "outputId": "0ae3eb7d-4227-4464-de75-642571dafd79"
       },
-      "execution_count": 7,
+      "execution_count": 12,
       "outputs": [
         {
           "output_type": "stream",
@@ -233,7 +247,7 @@
           "text": [
             "Collecting ninja\n",
             "  Downloading ninja-1.10.2.3-py2.py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.whl (108 kB)\n",
-            "\u001b[?25l\r\u001b[K     |███                             | 10 kB 19.8 MB/s eta 0:00:01\r\u001b[K     |██████                          | 20 kB 8.7 MB/s eta 0:00:01\r\u001b[K     |█████████                       | 30 kB 7.4 MB/s eta 0:00:01\r\u001b[K     |████████████▏                   | 40 kB 6.8 MB/s eta 0:00:01\r\u001b[K     |███████████████▏                | 51 kB 4.0 MB/s eta 0:00:01\r\u001b[K     |██████████████████▏             | 61 kB 4.2 MB/s eta 0:00:01\r\u001b[K     |█████████████████████▏          | 71 kB 4.3 MB/s eta 0:00:01\r\u001b[K     |████████████████████████▎       | 81 kB 4.8 MB/s eta 0:00:01\r\u001b[K     |███████████████████████████▎    | 92 kB 3.7 MB/s eta 0:00:01\r\u001b[K     |██████████████████████████████▎ | 102 kB 4.0 MB/s eta 0:00:01\r\u001b[K     |████████████████████████████████| 108 kB 4.0 MB/s \n",
+            "\u001b[?25l\r\u001b[K     |███                             | 10 kB 19.7 MB/s eta 0:00:01\r\u001b[K     |██████                          | 20 kB 8.3 MB/s eta 0:00:01\r\u001b[K     |█████████                       | 30 kB 5.4 MB/s eta 0:00:01\r\u001b[K     |████████████▏                   | 40 kB 5.1 MB/s eta 0:00:01\r\u001b[K     |███████████████▏                | 51 kB 3.8 MB/s eta 0:00:01\r\u001b[K     |██████████████████▏             | 61 kB 4.5 MB/s eta 0:00:01\r\u001b[K     |█████████████████████▏          | 71 kB 4.6 MB/s eta 0:00:01\r\u001b[K     |████████████████████████▎       | 81 kB 4.6 MB/s eta 0:00:01\r\u001b[K     |███████████████████████████▎    | 92 kB 5.1 MB/s eta 0:00:01\r\u001b[K     |██████████████████████████████▎ | 102 kB 4.9 MB/s eta 0:00:01\r\u001b[K     |████████████████████████████████| 108 kB 4.9 MB/s \n",
             "\u001b[?25hInstalling collected packages: ninja\n",
             "Successfully installed ninja-1.10.2.3\n"
           ]
@@ -251,36 +265,39 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 8,
+      "execution_count": 13,
       "metadata": {
         "id": "UtBgzUPDfIQg",
         "colab": {
           "base_uri": "https://localhost:8080/"
         },
-        "outputId": "2c103dc1-7123-4320-b012-652b2d27d0a8"
+        "outputId": "26e76a73-a2d8-46c9-bce3-c272b1fed450"
       },
       "outputs": [
         {
           "output_type": "stream",
           "name": "stdout",
           "text": [
-            "Collecting git+https://github.com/pytorch/functorch.git\n",
-            "  Cloning https://github.com/pytorch/functorch.git to /tmp/pip-req-build-htz8t0jk\n",
-            "  Running command git clone -q https://github.com/pytorch/functorch.git /tmp/pip-req-build-htz8t0jk\n",
-            "Requirement already satisfied: torch>=1.10.0.dev in /usr/local/lib/python3.7/dist-packages (from functorch==0.2.0a0+2cf76f3) (1.12.0.dev20220216+cu111)\n",
-            "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch>=1.10.0.dev->functorch==0.2.0a0+2cf76f3) (3.10.0.2)\n",
+            "Collecting git+https://github.com/pytorch/functorch.git@release/0.1\n",
+            "  Cloning https://github.com/pytorch/functorch.git (to revision release/0.1) to /tmp/pip-req-build-kr52k4jf\n",
+            "  Running command git clone -q https://github.com/pytorch/functorch.git /tmp/pip-req-build-kr52k4jf\n",
+            "  Running command git checkout -b release/0.1 --track origin/release/0.1\n",
+            "  Switched to a new branch 'release/0.1'\n",
+            "  Branch 'release/0.1' set up to track remote branch 'release/0.1' from 'origin'.\n",
+            "Requirement already satisfied: torch<1.12,>=1.11 in /usr/local/lib/python3.7/dist-packages (from functorch==0.1.0) (1.11.0+cu111)\n",
+            "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch<1.12,>=1.11->functorch==0.1.0) (3.10.0.2)\n",
             "Building wheels for collected packages: functorch\n",
             "  Building wheel for functorch (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
-            "  Created wheel for functorch: filename=functorch-0.2.0a0+2cf76f3-cp37-cp37m-linux_x86_64.whl size=21457003 sha256=be6cfe683ff09d15bac0a66e14d6d2d476a15a18273ceb0fc64a1d13fa0e37d7\n",
-            "  Stored in directory: /tmp/pip-ephem-wheel-cache-zrhpj6mp/wheels/b0/a9/4a/ffec50dda854c8d9f2ba21e4ffc0f2489ea97946cb1102c5ab\n",
+            "  Created wheel for functorch: filename=functorch-0.1.0-cp37-cp37m-linux_x86_64.whl size=20865554 sha256=abc550fabeca713b81ad728c64b7300a755b721f7018b07359a0b0c032591cbc\n",
+            "  Stored in directory: /tmp/pip-ephem-wheel-cache-py1n1nsx/wheels/36/1e/b5/1f1fa47f6155cd0302354303feaf209e777785883d94956873\n",
             "Successfully built functorch\n",
             "Installing collected packages: functorch\n",
-            "Successfully installed functorch-0.2.0a0+2cf76f3\n"
+            "Successfully installed functorch-0.1.0\n"
           ]
         }
       ],
       "source": [
-        "!pip install --user \"git+https://github.com/pytorch/functorch.git\""
+        "!pip install --user \"git+https://github.com/pytorch/functorch.git@release/0.1\""
       ]
     },
     {
@@ -302,9 +319,9 @@
         "colab": {
           "base_uri": "https://localhost:8080/"
         },
-        "outputId": "cf9301e6-13ce-4cb6-aefe-c43b14561ec8"
+        "outputId": "a59e08c1-7206-4439-e08e-c4b8ff004f49"
       },
-      "execution_count": 9,
+      "execution_count": 14,
       "outputs": [
         {
           "output_type": "stream",
@@ -336,20 +353,29 @@
       "metadata": {
         "id": "SvUfIxRyeAaL"
       },
-      "execution_count": 2,
+      "execution_count": 1,
       "outputs": []
     },
     {
       "cell_type": "markdown",
       "source": [
-        "# Jacobians, hessians, and more: composing functorch transforms\n",
+        "# Jacobians, Hessians, hvp, vhp, and more: composing functorch transforms\n",
         "\n",
-        "Computing jacobians or hessians are useful in a number of non-traditional deep learning models. \n",
+        "Computing quantities related to Jacobians or Hessians is useful in a number of non-traditional deep learning models. \n",
         "\n",
-        "It is difficult (or annoying) to compute these quantities efficiently using a standard autodiff system like PyTorch Autograd; functorch provides ways of computing various higher-order autodiff quantities efficiently.\n"
+        "It is difficult (or annoying) to compute these quantities efficiently using a standard autodiff system like PyTorch Autograd; functorch provides ways of computing various higher-order autodiff quantities efficiently."
       ],
       "metadata": {
-        "id": "nLdOLDH6m9oy"
+        "id": "OeTtrGkGfsE9"
+      }
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## Computing the Jacobian"
+      ],
+      "metadata": {
+        "id": "viWZDMQtflUG"
       }
     },
     {
@@ -362,21 +388,19 @@
         "_ = torch.manual_seed(0)\n"
       ],
       "metadata": {
-        "id": "vUsb3VfexJrY"
+        "id": "w_IinyjzflUH"
       },
-      "execution_count": 3,
+      "execution_count": 1,
       "outputs": []
     },
     {
       "cell_type": "markdown",
       "source": [
-        "**Comparing functorch vs the naive approach:**\n",
-        "\n",
         "Let’s start with a function that we’d like to compute the jacobian of.  This is a simple linear function with non-linear activation.\n",
         "\n"
       ],
       "metadata": {
-        "id": "aNkX6lFIxzcm"
+        "id": "cibF_PEYflUH"
       }
     },
     {
@@ -386,9 +410,9 @@
         "    return F.linear(x, weight, bias).tanh()"
       ],
       "metadata": {
-        "id": "C3a9_clvyPho"
+        "id": "qhcD9hWYflUH"
       },
-      "execution_count": 4,
+      "execution_count": 2,
       "outputs": []
     },
     {
@@ -398,7 +422,7 @@
         "\n"
       ],
       "metadata": {
-        "id": "mFJDWMM9yaYZ"
+        "id": "G8tqQrO_flUH"
       }
     },
     {
@@ -410,9 +434,9 @@
         "x = torch.randn(D) # feature vector"
       ],
       "metadata": {
-        "id": "WiSMupvCyecd"
+        "id": "FZ4uJfZGflUH"
       },
-      "execution_count": 55,
+      "execution_count": 3,
       "outputs": []
     },
     {
@@ -424,7 +448,7 @@
         "by using a different unit vector each time."
       ],
       "metadata": {
-        "id": "cTgIIZ9Wyih8"
+        "id": "uMAW-ArQflUH"
       }
     },
     {
@@ -436,9 +460,9 @@
         "    return torch.stack(jacobian_rows)"
       ],
       "metadata": {
-        "id": "ItURFU3M-p98"
+        "id": "z-BJPtbpflUI"
       },
-      "execution_count": 56,
+      "execution_count": 4,
       "outputs": []
     },
     {
@@ -456,20 +480,18 @@
         "colab": {
           "base_uri": "https://localhost:8080/"
         },
-        "id": "1gehVA1c-BHd",
-        "outputId": "81454f59-59e6-470f-e6a6-92c671137ad8"
+        "outputId": "f1f1ec12-56ef-40f7-8c3c-cbad7bf86644",
+        "id": "zuWGSXspflUI"
       },
-      "execution_count": 57,
+      "execution_count": 5,
       "outputs": [
         {
           "output_type": "stream",
           "name": "stdout",
           "text": [
             "torch.Size([16, 16])\n",
-            "tensor([-5.9625e-06,  1.9876e-05,  7.0103e-06,  1.1086e-05, -1.1939e-05,\n",
-            "         1.0975e-05,  8.3484e-06, -1.4599e-06, -1.9937e-05,  1.4976e-05,\n",
-            "        -7.4515e-06, -2.2042e-06,  5.0195e-07,  1.5267e-05, -7.8227e-06,\n",
-            "         6.9435e-06])\n"
+            "tensor([-0.5956, -0.6096, -0.1326, -0.2295,  0.4490,  0.3661, -0.1672, -1.1190,\n",
+            "         0.1705, -0.6683,  0.1851,  0.1630,  0.0634,  0.6547,  0.5908, -0.1308])\n"
           ]
         }
       ]
@@ -482,7 +504,7 @@
         "\n"
       ],
       "metadata": {
-        "id": "BEZaNt1d_bc1"
+        "id": "mxlEOUieflUI"
       }
     },
     {
@@ -498,15 +520,15 @@
         "assert torch.allclose(ft_jacobian, jacobian)"
       ],
       "metadata": {
-        "id": "Zfnn2C2g-6Fb"
+        "id": "DeF6uy4WflUI"
       },
-      "execution_count": 58,
+      "execution_count": 6,
       "outputs": []
     },
     {
       "cell_type": "markdown",
       "source": [
-        "In another tutorial a composition of reverse-mode AD and vmap gave us per-sample-gradients. \n",
+        "In future tutorial a composition of reverse-mode AD and vmap will give us per-sample-gradients. \n",
         "In this tutorial, composing reverse-mode AD and vmap gives us Jacobian computation! \n",
         "Various compositions of vmap and autodiff transforms can give us different interesting quantities.\n",
         "\n",
@@ -514,7 +536,7 @@
         "\n"
       ],
       "metadata": {
-        "id": "4gDqecJbyVgt"
+        "id": "Hy4REmwDflUI"
       }
     },
     {
@@ -528,9 +550,9 @@
         "assert torch.allclose(ft_jacobian, jacobian)"
       ],
       "metadata": {
-        "id": "t0EfptYTAO47"
+        "id": "Rt7i6_YlflUI"
       },
-      "execution_count": 59,
+      "execution_count": 7,
       "outputs": []
     },
     {
@@ -545,7 +567,7 @@
         "\n"
       ],
       "metadata": {
-        "id": "PDEIxPZoxUb7"
+        "id": "JYe2H1UcflUJ"
       }
     },
     {
@@ -554,7 +576,7 @@
         "Let's make a quick function to evaluate performance and deal with microseconds and milliseconds measurements:"
       ],
       "metadata": {
-        "id": "gHxrra_jA3ur"
+        "id": "i_143LZwflUJ"
       }
     },
     {
@@ -571,9 +593,9 @@
         "  "
       ],
       "metadata": {
-        "id": "rENMCuodBIef"
+        "id": "II7r6jBtflUJ"
       },
-      "execution_count": 60,
+      "execution_count": 8,
       "outputs": []
     },
     {
@@ -582,7 +604,7 @@
         "And then run the performance comparison:"
       ],
       "metadata": {
-        "id": "IaPfXXHngmUG"
+        "id": "r4clPnPKflUJ"
       }
     },
     {
@@ -603,22 +625,22 @@
         "colab": {
           "base_uri": "https://localhost:8080/"
         },
-        "id": "quEKCA2-Afdq",
-        "outputId": "72cf8a9f-759f-479e-9525-190da282b802"
+        "outputId": "cbf77a19-aac9-428d-eba1-74d337c53e49",
+        "id": "ZPtoxF6eflUJ"
       },
-      "execution_count": 61,
+      "execution_count": 9,
       "outputs": [
         {
           "output_type": "stream",
           "name": "stdout",
           "text": [
-            "<torch.utils.benchmark.utils.common.Measurement object at 0x7f682eb5a450>\n",
+            "<torch.utils.benchmark.utils.common.Measurement object at 0x7fa9a911b350>\n",
             "compute_jac(xp)\n",
-            "  2.04 ms\n",
+            "  2.25 ms\n",
             "  1 measurement, 500 runs , 1 thread\n",
-            "<torch.utils.benchmark.utils.common.Measurement object at 0x7f6733f08810>\n",
+            "<torch.utils.benchmark.utils.common.Measurement object at 0x7fa9a6a99d50>\n",
             "jacrev(predict, argnums=2)(weight, bias, x)\n",
-            "  810.29 us\n",
+            "  884.34 us\n",
             "  1 measurement, 500 runs , 1 thread\n"
           ]
         }
@@ -630,7 +652,7 @@
         "Lets do a relative performance comparison of the above with our get_perf function:"
       ],
       "metadata": {
-        "id": "5tY4c45fxVMi"
+        "id": "nGBBi4dZflUJ"
       }
     },
     {
@@ -642,16 +664,16 @@
         "colab": {
           "base_uri": "https://localhost:8080/"
         },
-        "id": "gQ_iv23m97A7",
-        "outputId": "e9a44c9e-9ed7-41f5-dd3c-77f0b7599f06"
+        "outputId": "85d0bc5f-34aa-4826-f953-6c637404490c",
+        "id": "zqV2RzEXflUJ"
       },
-      "execution_count": 62,
+      "execution_count": 10,
       "outputs": [
         {
           "output_type": "stream",
           "name": "stdout",
           "text": [
-            " Performance delta: 60.3299 percent improvement with vmap \n"
+            " Performance delta: 60.7170 percent improvement with vmap \n"
           ]
         }
       ]
@@ -662,18 +684,19 @@
         "Furthemore, it’s pretty easy to flip the problem around and say we want to compute Jacobians of the parameters to our model (weight, bias) instead of the input."
       ],
       "metadata": {
-        "id": "wtUxdj8gD1w7"
+        "id": "EQAB99EQflUJ"
       }
     },
     {
       "cell_type": "code",
       "source": [
-        "ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x) # note the change in input via argnums params of 0,1 to map to weight and bias"
+        "# note the change in input via argnums params of 0,1 to map to weight and bias\n",
+        "ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x)"
       ],
       "metadata": {
-        "id": "iKtvWR0n-b3E"
+        "id": "8UZpC8DnflUK"
       },
-      "execution_count": 63,
+      "execution_count": 11,
       "outputs": []
     },
     {
@@ -682,31 +705,25 @@
         "## reverse-mode Jacobian (jacrev) vs forward-mode Jacobian (jacfwd)\n"
       ],
       "metadata": {
-        "id": "zKm1sgT0EPx8"
+        "id": "F3USYENIflUK"
       }
     },
     {
       "cell_type": "markdown",
       "source": [
         "We offer two APIs to compute jacobians: **jacrev** and **jacfwd**: \n",
-        "\n",
         "- jacrev uses reverse-mode AD. As you saw above it is a composition of our vjp and vmap transforms. \n",
-        "\n",
         "- jacfwd uses forward-mode AD. It is implemented as a composition of our jvp and vmap transforms. \n",
         "\n",
         "jacfwd and jacrev can be substituted for each other but they have different performance characteristics.\n",
         "\n",
-        "As a general rule of thumb, if you’re computing the jacobian of an 𝑅𝑁−>𝑅𝑀 function, and there are many more outputs than inputs (i.e. M > N) then jacfwd is preferred, otherwise use jacrev. \n",
+        "As a general rule of thumb, if you’re computing the jacobian of an $𝑅^N \\to R^M$ function, and there are many more outputs than inputs (i.e. $M > N$) then jacfwd is preferred, otherwise use jacrev. There are exceptions to this rule, but a non-rigorous argument for this follows:\n",
         "\n",
-        "There are exceptions to this rule, but a non-rigorous argument for this follows:\n",
-        "\n",
-        "In reverse-mode AD, we are computing the jacobian row-by-row, while in forward-mode AD (which computes Jacobian-vector products), we are computing it column-by-column. \n",
-        "\n",
-        "The Jacobian matrix has M rows and N columns, so if it is taller or wider one way we may prefer the method that deals with fewer rows or columns.\n",
+        "In reverse-mode AD, we are computing the jacobian row-by-row, while in forward-mode AD (which computes Jacobian-vector products), we are computing it column-by-column. The Jacobian matrix has M rows and N columns, so if it is taller or wider one way we may prefer the method that deals with fewer rows or columns.\n",
         "\n"
       ],
       "metadata": {
-        "id": "LDqTlkfXEP0q"
+        "id": "V7B3vE8dflUK"
       }
     },
     {
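The markdown cell above notes that jacfwd is implemented as a composition of the jvp and vmap transforms, mirroring the vmap-of-vjp construction shown earlier for jacrev. For reference, here is a minimal standalone sketch of that forward-mode composition (not part of the notebook; the helper name `jvp_in_x` is hypothetical, and the functorch 0.1 API used throughout this tutorial is assumed):

```python
import torch
import torch.nn.functional as F
from functools import partial
from functorch import vmap, jvp, jacfwd

def predict(weight, bias, x):
    return F.linear(x, weight, bias).tanh()

D = 16
weight = torch.randn(D, D)
bias = torch.randn(D)
x = torch.randn(D)
unit_vectors = torch.eye(D)

# forward-mode AD: one jvp per input direction yields one Jacobian *column*
def jvp_in_x(tangent):
    return jvp(partial(predict, weight, bias), (x,), (tangent,))[1]

# vmap over the unit vectors computes all columns at once; row i of the
# result is column i of the Jacobian, so we transpose to recover it
columns = vmap(jvp_in_x)(unit_vectors)
assert torch.allclose(columns.T, jacfwd(predict, argnums=2)(weight, bias, x))
```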
@@ -715,9 +732,9 @@
         "from functorch import jacrev, jacfwd"
       ],
       "metadata": {
-        "id": "GrQG0lRoFML7"
+        "id": "k7Tok7m3flUK"
       },
-      "execution_count": 64,
+      "execution_count": 12,
       "outputs": []
     },
     {
@@ -727,7 +744,7 @@
         "\n"
       ],
       "metadata": {
-        "id": "7QIZkss7FQhK"
+        "id": "YrV-gZAaflUL"
       }
     },
     {
@@ -753,26 +770,26 @@
         "print(f'jacrev time: {jacrev_timing}')\n"
       ],
       "metadata": {
-        "id": "N0M0i6xf-nBt",
         "colab": {
           "base_uri": "https://localhost:8080/"
         },
-        "outputId": "add77ba5-a947-4bb7-8c36-a2e742086cab"
+        "outputId": "dd882726-9723-47c0-a72f-3c7835a85aa1",
+        "id": "m5j-4hSxflUL"
       },
-      "execution_count": 65,
+      "execution_count": 13,
       "outputs": [
         {
           "output_type": "stream",
           "name": "stdout",
           "text": [
             "torch.Size([2048, 32])\n",
-            "jacfwd time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f6734014c90>\n",
+            "jacfwd time: <torch.utils.benchmark.utils.common.Measurement object at 0x7fa9a5d792d0>\n",
             "jacfwd(predict, argnums=2)(weight, bias, x)\n",
-            "  1.18 ms\n",
+            "  1.32 ms\n",
             "  1 measurement, 500 runs , 1 thread\n",
-            "jacrev time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f67344b8650>\n",
+            "jacrev time: <torch.utils.benchmark.utils.common.Measurement object at 0x7fa9a4dee450>\n",
             "jacrev(predict, argnums=2)(weight, bias, x)\n",
-            "  14.98 ms\n",
+            "  12.46 ms\n",
             "  1 measurement, 500 runs , 1 thread\n"
           ]
         }
@@ -784,7 +801,7 @@
         "and then do a relative benchmark:"
       ],
       "metadata": {
-        "id": "UEh5jIK2FpBJ"
+        "id": "k_Sg-4tVflUL"
       }
     },
     {
@@ -796,16 +813,16 @@
         "colab": {
           "base_uri": "https://localhost:8080/"
         },
-        "id": "YmEqbvDeFrtt",
-        "outputId": "a1eae78f-9a08-4b14-d96c-18483070b9d1"
+        "outputId": "3a6586a1-269d-46d8-d119-e24f6d46277f",
+        "id": "_4T96zGjflUL"
       },
-      "execution_count": 67,
+      "execution_count": 14,
       "outputs": [
         {
           "output_type": "stream",
           "name": "stdout",
           "text": [
-            " Performance delta: 1170.0622 percent improvement with jacrev \n"
+            " Performance delta: 842.8274 percent improvement with jacrev \n"
           ]
         }
       ]
@@ -816,7 +833,7 @@
         "and now the reverse - more outputs (M) than inputs (N):"
       ],
       "metadata": {
-        "id": "aZAXlFUNFxAY"
+        "id": "RCDPot1yflUL"
       }
     },
     {
@@ -841,22 +858,22 @@
         "colab": {
           "base_uri": "https://localhost:8080/"
         },
-        "id": "jHVkLcr9_SIe",
-        "outputId": "86f7fb41-d14b-4b5f-9086-709278c99f67"
+        "outputId": "913e9ccd-3d4f-472a-a749-19cee36d0a16",
+        "id": "_DRFqzqZflUM"
       },
-      "execution_count": 71,
+      "execution_count": 15,
       "outputs": [
         {
           "output_type": "stream",
           "name": "stdout",
           "text": [
-            "jacfwd time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f67340145d0>\n",
+            "jacfwd time: <torch.utils.benchmark.utils.common.Measurement object at 0x7fa9a5d64790>\n",
             "jacfwd(predict, argnums=2)(weight, bias, x)\n",
-            "  8.99 ms\n",
+            "  7.99 ms\n",
             "  1 measurement, 500 runs , 1 thread\n",
-            "jacrev time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f67344b8110>\n",
+            "jacrev time: <torch.utils.benchmark.utils.common.Measurement object at 0x7fa9a5d67b50>\n",
             "jacrev(predict, argnums=2)(weight, bias, x)\n",
-            "  1.03 ms\n",
+            "  1.09 ms\n",
             "  1 measurement, 500 runs , 1 thread\n"
           ]
         }
@@ -868,7 +885,7 @@
         "and a relative perf comparison:"
       ],
       "metadata": {
-        "id": "I47HDJBwGAM4"
+        "id": "5SRbMCNsflUM"
       }
     },
     {
@@ -877,19 +894,19 @@
         "get_perf(jacrev_timing, \"jacrev\", jacfwd_timing, \"jacfwd\")"
       ],
       "metadata": {
-        "id": "jPdAcIgu1es-",
         "colab": {
           "base_uri": "https://localhost:8080/"
         },
-        "outputId": "ceebc303-9531-4903-84a9-0cbefe4bc318"
+        "outputId": "c282ce25-4f6e-44cd-aed7-60f6f5010e5b",
+        "id": "uF_9GaoiflUM"
       },
-      "execution_count": 72,
+      "execution_count": 16,
       "outputs": [
         {
           "output_type": "stream",
           "name": "stdout",
           "text": [
-            " Performance delta: 775.3424 percent improvement with jacfwd \n"
+            " Performance delta: 635.2095 percent improvement with jacfwd \n"
           ]
         }
       ]
@@ -900,31 +917,31 @@
         "## Hessian computation with functorch.hessian\n"
       ],
       "metadata": {
-        "id": "NRr6l4u0obus"
+        "id": "J29FQaBQflUM"
       }
     },
     {
       "cell_type": "markdown",
       "source": [
-        "We offer a convenience API to compute hessians: functorch.hessian. \n",
+        "We offer a convenience API to compute hessians: `functorch.hessian`. \n",
         "Hessians are the jacobian of the jacobian (or the partial derivative of the partial derivative, aka second order).\n",
         "\n",
         "This suggests that one can just compose functorch’s jacobian transforms to compute the Hessian. \n",
-        "Indeed, under the hood, hessian(f) is simply jacfwd(jacrev(f)).\n",
+        "Indeed, under the hood, `hessian(f)` is simply `jacfwd(jacrev(f))`.\n",
         "\n"
       ],
       "metadata": {
-        "id": "k0vSE1C1GeUJ"
+        "id": "My4DPH97flUM"
       }
     },
     {
       "cell_type": "markdown",
       "source": [
-        "Note: to boost performance: depending on your model, you may also want to use jacfwd(jacfwd(f)) or jacrev(jacrev(f)) instead to compute hessians leveraging the rule of thumb above regarding wider vs taller matrices.\n",
+        "Note: to boost performance: depending on your model, you may also want to use `jacfwd(jacfwd(f))` or `jacrev(jacrev(f))` instead to compute hessians leveraging the rule of thumb above regarding wider vs taller matrices.\n",
         "\n"
       ],
       "metadata": {
-        "id": "NEu1Zfo2G9fa"
+        "id": "FJt038l5flUM"
       }
     },
     {
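Since `hessian(f)` is `jacfwd(jacrev(f))`, the identity can also be sanity-checked with a standalone scalar function whose Hessian is known in closed form. The function below (`x.sin().sum()`, whose Hessian is `diag(-sin(x))`) is an illustrative choice, not taken from the notebook:

```python
import torch
from functorch import hessian, jacfwd, jacrev

def f(x):
    return x.sin().sum()  # scalar output; the Hessian is diag(-sin(x))

x = torch.randn(8)

hess_api = hessian(f)(x)            # the convenience API
hess_fwdrev = jacfwd(jacrev(f))(x)  # what hessian(f) does under the hood

assert torch.allclose(hess_api, hess_fwdrev)
assert torch.allclose(hess_api, torch.diag(-x.sin()))
```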
@@ -944,9 +961,9 @@
         "#hess_revrev = jacrev(jacrev(predict, argnums=2), argnums=2)(weight, bias, x)\n"
       ],
       "metadata": {
-        "id": "tYhxPLb-Gdh-"
+        "id": "jEqr2ywZflUM"
       },
-      "execution_count": 87,
+      "execution_count": 17,
       "outputs": []
     },
     {
@@ -955,7 +972,7 @@
         "Let's verify we have the same result regardless of using hessian api or using jacfwd(jacfwd())"
       ],
       "metadata": {
-        "id": "Qm_TPCCiso9u"
+        "id": "n9BHcICQflUN"
       }
     },
     {
@@ -967,10 +984,10 @@
         "colab": {
           "base_uri": "https://localhost:8080/"
         },
-        "id": "uL23RF5UrroT",
-        "outputId": "48327575-6a69-4e3c-898d-06b4792f44ca"
+        "outputId": "e457e3bc-f085-4f90-966d-f98893b98ea8",
+        "id": "eHiWRkjJflUN"
       },
-      "execution_count": 89,
+      "execution_count": 18,
       "outputs": [
         {
           "output_type": "execute_result",
@@ -980,7 +997,7 @@
             ]
           },
           "metadata": {},
-          "execution_count": 89
+          "execution_count": 18
         }
       ]
     },
@@ -990,23 +1007,18 @@
         "## Batch Jacobian and Batch Hessian\n"
       ],
       "metadata": {
-        "id": "9xBE48HXIOOj"
+        "id": "Gjt1RO8HflUN"
       }
     },
     {
       "cell_type": "markdown",
       "source": [
-        "In the above examples we’ve been operating with a single feature vector. \n",
+        "In the above examples we’ve been operating with a single feature vector. In some cases you might want to take the Jacobian of a batch of outputs with respect to a batch of inputs. That is, given a batch of inputs of shape `(B, N)` and a function that goes from $R^N \\to R^M$, we would like a Jacobian of shape `(B, M, N)`. \n",
         "\n",
-        "In some cases you might want to take the Jacobian of a batch of outputs with respect to a batch of inputs. \n",
-        "\n",
-        "That is, given a batch of inputs of shape (B, N) and a function that goes from R^N -> R^M, we would like a Jacobian of shape (B, M, N). \n",
-        "\n",
-        "The easiest way to do this is to use vmap:\n",
-        "\n"
+        "The easiest way to do this is to use vmap:"
       ],
       "metadata": {
-        "id": "sJVzGqnEIhJA"
+        "id": "RjIzdoQNflUN"
       }
     },
     {
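The vmap approach mentioned above can be sketched end-to-end as follows. The dimensions here are illustrative placeholders (the notebook's own values are not shown in this hunk), and `in_dims=(None, None, 0)` marks weight and bias as shared while x is batched along its first dimension:

```python
import torch
import torch.nn.functional as F
from functorch import vmap, jacrev

def predict(weight, bias, x):
    return F.linear(x, weight, bias).tanh()

batch_size, Din, Dout = 64, 31, 33  # illustrative sizes
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(batch_size, Din)

# jacrev(predict, argnums=2) computes one (Dout, Din) Jacobian per example;
# vmap batches that computation over the leading dimension of x
compute_batch_jacobian = vmap(jacrev(predict, argnums=2), in_dims=(None, None, 0))
batch_jacobian = compute_batch_jacobian(weight, bias, x)
print(batch_jacobian.shape)  # torch.Size([64, 33, 31]), i.e. (B, M, N)
```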
@@ -1024,13 +1036,13 @@
         "x = torch.randn(batch_size, Din)"
       ],
       "metadata": {
-        "id": "gEEWzX2QndqN",
         "colab": {
           "base_uri": "https://localhost:8080/"
         },
-        "outputId": "c55e07d8-e3b6-40f6-f8f4-39014ff7d9b9"
+        "outputId": "561eb618-e00f-40d5-bd99-fa51ab82051f",
+        "id": "B1eoEO4UflUN"
       },
-      "execution_count": 91,
+      "execution_count": 19,
       "outputs": [
         {
           "output_type": "stream",
@@ -1048,9 +1060,9 @@
         "batch_jacobian0 = compute_batch_jacobian(weight, bias, x)"
       ],
       "metadata": {
-        "id": "khYtmCqJn1h-"
+        "id": "nZ_V02NhflUN"
       },
-      "execution_count": 92,
+      "execution_count": 20,
       "outputs": []
     },
     {
@@ -1060,7 +1072,7 @@
         "\n"
       ],
       "metadata": {
-        "id": "rXE9tY05JHaJ"
+        "id": "_OLDiY3MflUN"
       }
     },
     {
@@ -1073,21 +1085,21 @@
         "assert torch.allclose(batch_jacobian0, batch_jacobian1)"
       ],
       "metadata": {
-        "id": "eohigCobop4R"
+        "id": "_QH4hD8PflUO"
       },
-      "execution_count": 93,
+      "execution_count": 21,
       "outputs": []
     },
     {
       "cell_type": "markdown",
       "source": [
-        "If you instead have a function that goes from 𝑅𝑁−>𝑅𝑀 but inputs that are batched, you compose vmap with jacrev to compute batched jacobians:\n",
+        "If you instead have a function that goes from $𝑅^𝑁 \\to 𝑅^𝑀$ but inputs that are batched, you compose vmap with jacrev to compute batched jacobians:\n",
         "\n",
         "Finally, batch hessians can be computed similarly. It’s easiest to think about them by using vmap to batch over hessian computation, but in some cases the sum trick also works.\n",
         "\n"
       ],
       "metadata": {
-        "id": "O3AGffymp_39"
+        "id": "eUjw65cCflUO"
       }
     },
     {
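Batch Hessians follow the same pattern; a short sketch, reusing `predict`, `weight`, `bias`, and the batched `x` from the previous sketch:

```python
from functorch import vmap, hessian

# hessian(predict, argnums=2) yields one (Dout, Din, Din) Hessian per example;
# vmap batches it over the leading dimension of x
compute_batch_hessian = vmap(hessian(predict, argnums=2), in_dims=(None, None, 0))
batch_hessian = compute_batch_hessian(weight, bias, x)
print(batch_hessian.shape)  # torch.Size([64, 33, 31, 31]), i.e. (B, M, N, N)
```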
@@ -1102,10 +1114,10 @@
         "colab": {
           "base_uri": "https://localhost:8080/"
         },
-        "id": "HYddoLSTfi-g",
-        "outputId": "5ddbd9b0-57f8-4ac4-a399-a6bdadfaa167"
+        "outputId": "f3135cfa-e9e5-4f18-8cb7-0655e8a37cb5",
+        "id": "3vAyQjMsflUO"
       },
-      "execution_count": 95,
+      "execution_count": 22,
       "outputs": [
         {
           "output_type": "execute_result",
@@ -1115,19 +1127,97 @@
             ]
           },
           "metadata": {},
-          "execution_count": 95
+          "execution_count": 22
         }
       ]
     },
     {
-      "cell_type": "code",
+      "cell_type": "markdown",
       "source": [
-        ""
+        "## Computing Hessian-vector products\n",
+        "\n",
+        "The naive way to compute a Hessian-vector product (hvp) is to materialize the full Hessian and perform a dot-product with a vector. We can do better: it turns out we don't need to materialize the full Hessian to do this. We'll go through two (of many) different strategies to compute Hessian-vector products:\n",
+        "- composing reverse-mode AD with reverse-mode AD\n",
+        "- composing reverse-mode AD with forward-mode AD\n",
+        "\n",
+        "Composing reverse-mode AD with forward-mode AD (as opposed to reverse-mode with reverse-mode) is generally the more memory efficient way to compute a hvp because forward-mode AD doesn't need to construct an Autograd graph and save intermediates for backward:"
       ],
       "metadata": {
-        "id": "22e5fo2jqANi"
+        "id": "Wa8E48sQgpkb"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "from functorch import jvp, grad, vjp\n",
+        "\n",
+        "def hvp(f, primals, tangents):\n",
+        "  return jvp(grad(f), primals, tangents)[1]"
+      ],
+      "metadata": {
+        "id": "trw6WbAth6BM"
       },
-      "execution_count": null,
+      "execution_count": 23,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "Here's some sample usage."
+      ],
+      "metadata": {
+        "id": "DQMpRo6nitfr"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "def f(x):\n",
+        "  return x.sin().sum()\n",
+        "\n",
+        "x = torch.randn(2048)\n",
+        "tangent = torch.randn(2048)\n",
+        "\n",
+        "result = hvp(f, (x,), (tangent,))"
+      ],
+      "metadata": {
+        "id": "sPwg8SOdiVAK"
+      },
+      "execution_count": 24,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "If PyTorch forward-AD does not have coverage for your operations, then we can instead compose reverse-mode AD with reverse-mode AD:"
+      ],
+      "metadata": {
+        "id": "zGvUIcB0j1Ez"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "def hvp_revrev(f, primals, tangents):\n",
+        "  _, vjp_fn = vjp(grad(f), *primals)\n",
+        "  return vjp_fn(*tangents)"
+      ],
+      "metadata": {
+        "id": "mdDFZdlekAOK"
+      },
+      "execution_count": 25,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "result_hvp_revrev = hvp_revrev(f, (x,), (tangent,))\n",
+        "assert torch.allclose(result, result_hvp_revrev[0])"
+      ],
+      "metadata": {
+        "id": "_CuCk9X0lW7C"
+      },
+      "execution_count": 26,
       "outputs": []
     }
   ]
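The notebook title also mentions vector-Hessian products (vhp). Because the Hessian of a twice-differentiable function is symmetric, the vjp-of-grad composition used in `hvp_revrev` above computes a vhp directly; a hedged sketch, reusing `f`, `x`, `tangent`, and `result` from the hvp cells:

```python
import torch
from functorch import grad, vjp

def vhp(f, primals, cotangents):
    # v^T H: pull the cotangent vector back through grad(f)
    _, vjp_fn = vjp(grad(f), *primals)
    return vjp_fn(*cotangents)

result_vhp = vhp(f, (x,), (tangent,))
# the Hessian of f is symmetric, so the vhp matches the hvp computed earlier
assert torch.allclose(result_vhp[0], result)
```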
diff --git a/functorch/notebooks/jacobians_hessians.ipynb b/functorch/notebooks/jacobians_hessians.ipynb
index 0a60ff9..172479a 100644
--- a/functorch/notebooks/jacobians_hessians.ipynb
+++ b/functorch/notebooks/jacobians_hessians.ipynb
@@ -1,489 +1,952 @@
 {
- "cells": [
-  {
-   "cell_type": "markdown",
-   "id": "98c5346d-c11a-4be1-8a20-447d7390fdd9",
-   "metadata": {},
-   "source": [
-    "# Jacobians, hessians, and more: composing functorch transforms\n",
-    "\n",
-    "<a href=\"https://colab.research.google.com/github/pytorch/functorch/blob/main/notebooks/colab/jacobians_hessians_colab.ipynb\">\n",
-    "  <img style=\"width: auto\" src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
-    "</a>\n",
-    "\n",
-    "Computing jacobians or hessians are useful in a number of non-traditional\n",
-    "deep learning models. It is difficult (or annoying) to compute these quantities\n",
-    "efficiently using a standard autodiff system like PyTorch Autograd; functorch\n",
-    "provides ways of computing various higher-order autodiff quantities efficiently."
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "93379d25",
-   "metadata": {},
-   "source": []
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "id": "a1aabaa9-6e86-4717-b645-b979a6b980a6",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import torch\n",
-    "import torch.nn as nn\n",
-    "import torch.nn.functional as F\n",
-    "from functools import partial\n",
-    "_ = torch.manual_seed(0)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "f033af1c-7d5d-4acf-a3b0-aa53c38f07c8",
-   "metadata": {},
-   "source": [
-    "## Setup: Comparing functorch vs the naive approach\n",
-    "\n",
-    "Let's start with a function that we'd like to compute the jacobian of.\n",
-    "This is a simple linear function with non-linear activation."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 2,
-   "id": "6e056fcb-8cc3-4ea4-a0dd-0cb93beb07cc",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def predict(weight, bias, x):\n",
-    "    return F.linear(x, weight, bias).tanh()"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "987040aa-65a9-40eb-ab17-fa8a5078a9f8",
-   "metadata": {},
-   "source": [
-    "Here's some dummy data: a weight, a bias, and a feature vector."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 3,
-   "id": "41c34504-5873-4861-b513-25bfc2e431b6",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "D = 16\n",
-    "weight = torch.randn(D, D)\n",
-    "bias = torch.randn(D)\n",
-    "x = torch.randn(D)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "1ceff130-0765-47e9-aa59-de66b23651e0",
-   "metadata": {},
-   "source": [
-    "Let's think of `predict` as a function that maps the input `x` from $R^D -> R^D$.\n",
-    "PyTorch Autograd computes vector-Jacobian products. In order to compute the full\n",
-    "Jacobian of this $R^D -> R^D$ function, we would have to compute it row-by-row\n",
-    "by using a different unit vector each time."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 4,
-   "id": "0b13ce38-1f4c-4b4e-b28c-94de44bf6cf1",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "xp = x.clone().requires_grad_()\n",
-    "unit_vectors = torch.eye(D)\n",
-    "\n",
-    "def compute_jac(xp):\n",
-    "    jacobian_rows = [torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0]\n",
-    "                     for vec in unit_vectors]\n",
-    "    return torch.stack(jacobian_rows)\n",
-    "\n",
-    "jacobian = compute_jac(xp)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "04fac9ba-a185-4478-bf29-1011c4e75ac0",
-   "metadata": {},
-   "source": [
-    "Instead of computing the jacobian row-by-row, we can use `vmap` to get rid\n",
-    "of the for-loop and vectorize the computation. We can't directly apply vmap\n",
-    "to PyTorch Autograd; instead, functorch provides a `vjp` transform:"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 5,
-   "id": "23e7861d-62a6-4b24-aac5-cca018002867",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "from functorch import vmap, vjp\n",
-    "_, vjp_fn = vjp(partial(predict, weight, bias), x)\n",
-    "ft_jacobian, = vmap(vjp_fn)(unit_vectors)\n",
-    "assert torch.allclose(ft_jacobian, jacobian)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "98c5396e-83cd-4c10-9cdc-8989605ff178",
-   "metadata": {},
-   "source": [
-    "In another tutorial a composition of reverse-mode AD and vmap gave us\n",
-    "per-sample-gradients. In this tutorial, composing reverse-mode AD and vmap\n",
-    "gives us Jacobian computation! Various compositions of vmap and autodiff\n",
-    "transforms can give us different interesting quantities.\n",
-    "\n",
-    "functorch provides `jacrev` as a convenience function that performs\n",
-    "the vmap-vjp composition to compute jacobians. `jacrev` accepts an argnums\n",
-    "argument that says which argument we would like to compute Jacobians with\n",
-    "respect to."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 6,
-   "id": "21d24a8c-69af-4c0e-b222-901c0f72182f",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "from functorch import jacrev\n",
-    "ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)\n",
-    "assert torch.allclose(ft_jacobian, jacobian)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "8d55f7df-907e-4236-9104-c3d1b86ce935",
-   "metadata": {},
-   "source": [
-    "Let's compare the performance of the two ways to compute jacobian.\n",
-    "The functorch version is much faster (and becomes even faster the more outputs\n",
-    "there are). In general, we expect that vectorization via `vmap` can help\n",
-    "eliminate overhead and give better utilization of your hardware."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 7,
-   "id": "2f502dc0-2faf-42b3-a9f3-d77737e94abd",
-   "metadata": {},
-   "outputs": [
+  "cells": [
     {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "<torch.utils.benchmark.utils.common.Measurement object at 0x7f4462a8b3a0>\n",
-      "compute_jac(xp)\n",
-      "  1.08 ms\n",
-      "  1 measurement, 500 runs , 1 thread\n",
-      "<torch.utils.benchmark.utils.common.Measurement object at 0x7f4461e3ee20>\n",
-      "jacrev(predict, argnums=2)(weight, bias, x)\n",
-      "  361.07 us\n",
-      "  1 measurement, 500 runs , 1 thread\n"
-     ]
-    }
-   ],
-   "source": [
-    "from torch.utils.benchmark import Timer\n",
-    "without_vmap = Timer(stmt=\"compute_jac(xp)\", globals=globals())\n",
-    "with_vmap = Timer(stmt=\"jacrev(predict, argnums=2)(weight, bias, x)\", globals=globals())\n",
-    "print(without_vmap.timeit(500))\n",
-    "print(with_vmap.timeit(500))"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "50fb7863-20f2-416b-9273-2c8ab2f7980c",
-   "metadata": {},
-   "source": [
-    "Furthemore, it's pretty easy to flip the problem around and say we want to compute\n",
-    "Jacobians of the parameters to our model (weight, bias) instead of the input."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 8,
-   "id": "14d12ec7-c40f-42b9-b549-d79fac60d541",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "0a27d8b9-448b-4b73-9e11-f443fde24f6f",
-   "metadata": {},
-   "source": [
-    "## reverse-mode Jacobian (jacrev) vs forward-mode Jacobian (jacfwd)\n",
-    "\n",
-    "We offer two APIs to compute jacobians: jacrev and jacfwd:\n",
-    "- jacrev uses reverse-mode AD. As you saw above it is a composition of our\n",
-    "vjp and vmap transforms.\n",
-    "- jacfwd uses forward-mode AD. It is implemented as a composition of our\n",
-    "jvp and vmap transforms.\n",
-    "jacfwd and jacrev can be subsituted for each other and have different\n",
-    "performance characteristics.\n",
-    "\n",
-    "As a general rule of thumb, if you're computing the jacobian of an $R^N -> R^M$\n",
-    "function, if there are many more outputs than inputs (i.e. M > N) then jacfwd is\n",
-    "preferred, otherwise use jacrev. There are exceptions to this rule, but a\n",
-    "non-rigorous argument for this follows:\n",
-    "\n",
-    "In reverse-mode AD, we are computing the jacobian row-by-row, while in\n",
-    "forward-mode AD (which computes Jacobian-vector products), we are computing\n",
-    "it column-by-column. The Jacobian matrix has M rows and N columns, so if it is\n",
-    "taller or wider one way we may prefer the method that deals with fewer rows or\n",
-    "columns."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 9,
-   "id": "201c6470-0e35-4e5b-a3c0-cb602da8ca5e",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "from functorch import jacrev, jacfwd"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "01bd6c60-5cb7-4628-a63c-2ebf9d9e1cff",
-   "metadata": {},
-   "source": [
-    "Benchmark with more inputs than outputs"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 10,
-   "id": "08a83c44-9e3f-4734-9d25-9844a5691791",
-   "metadata": {},
-   "outputs": [
+      "cell_type": "markdown",
+      "source": [
+        "# Jacobians, Hessians, hvp, vhp, and more: composing functorch transforms\n",
+        "\n",
+        "<a href=\"https://colab.research.google.com/github/pytorch/functorch/blob/main/notebooks/colab/jacobians_hessians_colab.ipynb\">\n",
+        "  <img style=\"width: auto\" src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
+        "</a>\n",
+        "\n",
+        "Computing jacobians or hessians are useful in a number of non-traditional\n",
+        "deep learning models. It is difficult (or annoying) to compute these quantities\n",
+        "efficiently using a standard autodiff system like PyTorch Autograd; functorch\n",
+        "provides ways of computing various higher-order autodiff quantities efficiently."
+      ],
+      "metadata": {
+        "id": "zPbR6-eP51fe"
+      },
+      "id": "zPbR6-eP51fe"
+    },
     {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "jacfwd time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f44629bc760>\n",
-      "jacfwd(predict, argnums=2)(weight, bias, x)\n",
-      "  603.91 us\n",
-      "  1 measurement, 500 runs , 1 thread\n",
-      "jacrev time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f4461e1b8b0>\n",
-      "jacrev(predict, argnums=2)(weight, bias, x)\n",
-      "  5.25 ms\n",
-      "  1 measurement, 500 runs , 1 thread\n"
-     ]
-    }
-   ],
-   "source": [
-    "Din = 32\n",
-    "Dout = 2048\n",
-    "weight = torch.randn(Dout, Din)\n",
-    "bias = torch.randn(Dout)\n",
-    "x = torch.randn(Din)\n",
-    "\n",
-    "using_fwd = Timer(stmt=\"jacfwd(predict, argnums=2)(weight, bias, x)\", globals=globals())\n",
-    "using_bwd = Timer(stmt=\"jacrev(predict, argnums=2)(weight, bias, x)\", globals=globals())\n",
-    "print(f'jacfwd time: {using_fwd.timeit(500)}')\n",
-    "print(f'jacrev time: {using_bwd.timeit(500)}')"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "66984d8c-135c-49aa-b177-655836f87e3c",
-   "metadata": {},
-   "source": [
-    "Benchmark with more outputs than inputs"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 11,
-   "id": "18029e78-c722-4163-af56-79b820074bfc",
-   "metadata": {},
-   "outputs": [
+      "cell_type": "markdown",
+      "source": [
+        "## Computing the Jacobian"
+      ],
+      "metadata": {
+        "id": "3kDj8fhn52j3"
+      },
+      "id": "3kDj8fhn52j3"
+    },
     {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "jacfwd time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f4461e19a60>\n",
-      "jacfwd(predict, argnums=2)(weight, bias, x)\n",
-      "  5.33 ms\n",
-      "  1 measurement, 500 runs , 1 thread\n",
-      "jacrev time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f4461e30ee0>\n",
-      "jacrev(predict, argnums=2)(weight, bias, x)\n",
-      "  424.29 us\n",
-      "  1 measurement, 500 runs , 1 thread\n"
-     ]
+      "cell_type": "code",
+      "source": [
+        "import torch\n",
+        "import torch.nn as nn\n",
+        "import torch.nn.functional as F\n",
+        "from functools import partial\n",
+        "_ = torch.manual_seed(0)"
+      ],
+      "metadata": {
+        "id": "w_IinyjzflUH"
+      },
+      "execution_count": null,
+      "outputs": [],
+      "id": "w_IinyjzflUH"
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "Let’s start with a function that we’d like to compute the jacobian of.  This is a simple linear function with non-linear activation.\n",
+        "\n"
+      ],
+      "metadata": {
+        "id": "cibF_PEYflUH"
+      },
+      "id": "cibF_PEYflUH"
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "def predict(weight, bias, x):\n",
+        "    return F.linear(x, weight, bias).tanh()"
+      ],
+      "metadata": {
+        "id": "qhcD9hWYflUH"
+      },
+      "execution_count": null,
+      "outputs": [],
+      "id": "qhcD9hWYflUH"
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "Let's add some dummy data:   a weight, a bias, and a feature vector x.\n",
+        "\n"
+      ],
+      "metadata": {
+        "id": "G8tqQrO_flUH"
+      },
+      "id": "G8tqQrO_flUH"
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "D = 16\n",
+        "weight = torch.randn(D, D)\n",
+        "bias = torch.randn(D)\n",
+        "x = torch.randn(D) # feature vector"
+      ],
+      "metadata": {
+        "id": "FZ4uJfZGflUH"
+      },
+      "execution_count": null,
+      "outputs": [],
+      "id": "FZ4uJfZGflUH"
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "Let's think of `predict` as a function that maps the input `x` from $R^D -> R^D$.\n",
+        "PyTorch Autograd computes vector-Jacobian products. In order to compute the full\n",
+        "Jacobian of this $R^D -> R^D$ function, we would have to compute it row-by-row\n",
+        "by using a different unit vector each time."
+      ],
+      "metadata": {
+        "id": "uMAW-ArQflUH"
+      },
+      "id": "uMAW-ArQflUH"
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "def compute_jac(xp):\n",
+        "    jacobian_rows = [torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0]\n",
+        "                     for vec in unit_vectors]\n",
+        "    return torch.stack(jacobian_rows)"
+      ],
+      "metadata": {
+        "id": "z-BJPtbpflUI"
+      },
+      "execution_count": null,
+      "outputs": [],
+      "id": "z-BJPtbpflUI"
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "xp = x.clone().requires_grad_()\n",
+        "unit_vectors = torch.eye(D)\n",
+        "\n",
+        "jacobian = compute_jac(xp)\n",
+        "\n",
+        "print(jacobian.shape)\n",
+        "print(jacobian[0])  # show first row"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "outputId": "f1f1ec12-56ef-40f7-8c3c-cbad7bf86644",
+        "id": "zuWGSXspflUI"
+      },
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "torch.Size([16, 16])\n",
+            "tensor([-0.5956, -0.6096, -0.1326, -0.2295,  0.4490,  0.3661, -0.1672, -1.1190,\n",
+            "         0.1705, -0.6683,  0.1851,  0.1630,  0.0634,  0.6547,  0.5908, -0.1308])\n"
+          ]
+        }
+      ],
+      "id": "zuWGSXspflUI"
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "Instead of computing the jacobian row-by-row, we can use vmap to get rid of the for-loop and vectorize the computation. \n",
+        "We can’t directly apply vmap to PyTorch Autograd; instead, functorch provides a vjp transform:\n",
+        "\n"
+      ],
+      "metadata": {
+        "id": "mxlEOUieflUI"
+      },
+      "id": "mxlEOUieflUI"
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "from functorch import vmap, vjp\n",
+        "\n",
+        "_, vjp_fn = vjp(partial(predict, weight, bias), x)\n",
+        "\n",
+        "ft_jacobian, = vmap(vjp_fn)(unit_vectors)\n",
+        "\n",
+        "# lets confirm both methods compute the same result\n",
+        "assert torch.allclose(ft_jacobian, jacobian)"
+      ],
+      "metadata": {
+        "id": "DeF6uy4WflUI"
+      },
+      "execution_count": null,
+      "outputs": [],
+      "id": "DeF6uy4WflUI"
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "In future tutorial a composition of reverse-mode AD and vmap will give us per-sample-gradients. \n",
+        "In this tutorial, composing reverse-mode AD and vmap gives us Jacobian computation! \n",
+        "Various compositions of vmap and autodiff transforms can give us different interesting quantities.\n",
+        "\n",
+        "functorch provides **jacrev** as a convenience function that performs the vmap-vjp composition to compute jacobians. **jacrev** accepts an argnums argument that says which argument we would like to compute Jacobians with respect to.\n",
+        "\n"
+      ],
+      "metadata": {
+        "id": "Hy4REmwDflUI"
+      },
+      "id": "Hy4REmwDflUI"
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "from functorch import jacrev\n",
+        "\n",
+        "ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)\n",
+        "\n",
+        "# confirm \n",
+        "assert torch.allclose(ft_jacobian, jacobian)"
+      ],
+      "metadata": {
+        "id": "Rt7i6_YlflUI"
+      },
+      "execution_count": null,
+      "outputs": [],
+      "id": "Rt7i6_YlflUI"
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "Let’s compare the performance of the two ways to compute the jacobian. The functorch version is much faster (and becomes even faster the more outputs there are). \n",
+        "\n",
+        "In general, we expect that vectorization via vmap can help eliminate overhead and give better utilization of your hardware.\n",
+        "\n",
+        "Vmap does this magic by pushing the outer loop down into the functions primitive operations in order to obtain better performance.\n",
+        "\n",
+        "\n"
+      ],
+      "metadata": {
+        "id": "JYe2H1UcflUJ"
+      },
+      "id": "JYe2H1UcflUJ"
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "Let's make a quick function to evaluate performance and deal with microseconds and milliseconds measurements:"
+      ],
+      "metadata": {
+        "id": "i_143LZwflUJ"
+      },
+      "id": "i_143LZwflUJ"
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "def get_perf(first, first_descriptor, second, second_descriptor):\n",
+        "  \"\"\"  takes torch.benchmark objects and compares delta of second vs first. \"\"\"\n",
+        "  faster = second.times[0]\n",
+        "  slower = first.times[0]\n",
+        "  gain = (slower-faster)/slower\n",
+        "  if gain < 0: gain *=-1 \n",
+        "  final_gain = gain*100\n",
+        "  print(f\" Performance delta: {final_gain:.4f} percent improvement with {second_descriptor} \")"
+      ],
+      "metadata": {
+        "id": "II7r6jBtflUJ"
+      },
+      "execution_count": null,
+      "outputs": [],
+      "id": "II7r6jBtflUJ"
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "And then run the performance comparison:"
+      ],
+      "metadata": {
+        "id": "r4clPnPKflUJ"
+      },
+      "id": "r4clPnPKflUJ"
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "from torch.utils.benchmark import Timer\n",
+        "\n",
+        "without_vmap = Timer(stmt=\"compute_jac(xp)\", globals=globals())\n",
+        "with_vmap = Timer(stmt=\"jacrev(predict, argnums=2)(weight, bias, x)\", globals=globals())\n",
+        "\n",
+        "no_vmap_timer = without_vmap.timeit(500)\n",
+        "with_vmap_timer = with_vmap.timeit(500)\n",
+        "\n",
+        "print(no_vmap_timer)\n",
+        "print(with_vmap_timer)"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "outputId": "cbf77a19-aac9-428d-eba1-74d337c53e49",
+        "id": "ZPtoxF6eflUJ"
+      },
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "<torch.utils.benchmark.utils.common.Measurement object at 0x7fa9a911b350>\n",
+            "compute_jac(xp)\n",
+            "  2.25 ms\n",
+            "  1 measurement, 500 runs , 1 thread\n",
+            "<torch.utils.benchmark.utils.common.Measurement object at 0x7fa9a6a99d50>\n",
+            "jacrev(predict, argnums=2)(weight, bias, x)\n",
+            "  884.34 us\n",
+            "  1 measurement, 500 runs , 1 thread\n"
+          ]
+        }
+      ],
+      "id": "ZPtoxF6eflUJ"
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "Lets do a relative performance comparison of the above with our get_perf function:"
+      ],
+      "metadata": {
+        "id": "nGBBi4dZflUJ"
+      },
+      "id": "nGBBi4dZflUJ"
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "get_perf(no_vmap_timer, \"without vmap\",  with_vmap_timer, \"vmap\");"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "outputId": "85d0bc5f-34aa-4826-f953-6c637404490c",
+        "id": "zqV2RzEXflUJ"
+      },
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            " Performance delta: 60.7170 percent improvement with vmap \n"
+          ]
+        }
+      ],
+      "id": "zqV2RzEXflUJ"
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "Furthemore, it’s pretty easy to flip the problem around and say we want to compute Jacobians of the parameters to our model (weight, bias) instead of the input."
+      ],
+      "metadata": {
+        "id": "EQAB99EQflUJ"
+      },
+      "id": "EQAB99EQflUJ"
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# note the change in input via argnums params of 0,1 to map to weight and bias\n",
+        "ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x)"
+      ],
+      "metadata": {
+        "id": "8UZpC8DnflUK"
+      },
+      "execution_count": null,
+      "outputs": [],
+      "id": "8UZpC8DnflUK"
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## reverse-mode Jacobian (jacrev) vs forward-mode Jacobian (jacfwd)\n"
+      ],
+      "metadata": {
+        "id": "F3USYENIflUK"
+      },
+      "id": "F3USYENIflUK"
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "We offer two APIs to compute jacobians: **jacrev** and **jacfwd**: \n",
+        "- jacrev uses reverse-mode AD. As you saw above it is a composition of our vjp and vmap transforms. \n",
+        "- jacfwd uses forward-mode AD. It is implemented as a composition of our jvp and vmap transforms. \n",
+        "\n",
+        "jacfwd and jacrev can be substituted for each other but they have different performance characteristics.\n",
+        "\n",
+        "As a general rule of thumb, if you’re computing the jacobian of an $𝑅^N \\to R^M$ function, and there are many more outputs than inputs (i.e. $M > N$) then jacfwd is preferred, otherwise use jacrev. There are exceptions to this rule, but a non-rigorous argument for this follows:\n",
+        "\n",
+        "In reverse-mode AD, we are computing the jacobian row-by-row, while in forward-mode AD (which computes Jacobian-vector products), we are computing it column-by-column. The Jacobian matrix has M rows and N columns, so if it is taller or wider one way we may prefer the method that deals with fewer rows or columns.\n",
+        "\n"
+      ],
+      "metadata": {
+        "id": "V7B3vE8dflUK"
+      },
+      "id": "V7B3vE8dflUK"
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "from functorch import jacrev, jacfwd"
+      ],
+      "metadata": {
+        "id": "k7Tok7m3flUK"
+      },
+      "execution_count": null,
+      "outputs": [],
+      "id": "k7Tok7m3flUK"
+    },
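+    {
+      "cell_type": "markdown",
+      "source": [
+        "As an aside, here is a minimal sketch of the jvp-plus-vmap composition behind **jacfwd** (illustrative only, not the actual functorch implementation; the helper name `jacfwd_sketch` is ours). Each jvp against a unit vector extracts one *column* of the Jacobian:"
+      ],
+      "metadata": {
+        "id": "skFwdMd01flUK"
+      },
+      "id": "skFwdMd01flUK"
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "from functorch import jvp\n",
+        "\n",
+        "def jacfwd_sketch(f, x):\n",
+        "    # a jvp against a unit vector extracts one column of the Jacobian\n",
+        "    def get_column(v):\n",
+        "        return jvp(f, (x,), (v,))[1]\n",
+        "    # vmap stacks the columns along dim 0, so transpose to recover J\n",
+        "    return vmap(get_column)(torch.eye(x.shape[0])).T\n",
+        "\n",
+        "# weight, bias, x are still the square D=16 tensors from above\n",
+        "assert torch.allclose(jacfwd_sketch(partial(predict, weight, bias), x),\n",
+        "                      jacfwd(predict, argnums=2)(weight, bias, x))"
+      ],
+      "metadata": {
+        "id": "skFwdCd01flUK"
+      },
+      "execution_count": null,
+      "outputs": [],
+      "id": "skFwdCd01flUK"
+    },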
+    {
+      "cell_type": "markdown",
+      "source": [
+        "First, let's benchmark with more inputs than outputs:\n",
+        "\n"
+      ],
+      "metadata": {
+        "id": "YrV-gZAaflUL"
+      },
+      "id": "YrV-gZAaflUL"
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "Din = 32\n",
+        "Dout = 2048\n",
+        "weight = torch.randn(Dout, Din)\n",
+        "\n",
+        "bias = torch.randn(Dout)\n",
+        "x = torch.randn(Din)\n",
+        "\n",
+        "# remember the general rule about taller vs wider...here we have a taller matrix:\n",
+        "print(weight.shape)\n",
+        "\n",
+        "using_fwd = Timer(stmt=\"jacfwd(predict, argnums=2)(weight, bias, x)\", globals=globals())\n",
+        "using_bwd = Timer(stmt=\"jacrev(predict, argnums=2)(weight, bias, x)\", globals=globals())\n",
+        "\n",
+        "jacfwd_timing = using_fwd.timeit(500)\n",
+        "jacrev_timing = using_bwd.timeit(500)\n",
+        "\n",
+        "print(f'jacfwd time: {jacfwd_timing}')\n",
+        "print(f'jacrev time: {jacrev_timing}')\n"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "outputId": "dd882726-9723-47c0-a72f-3c7835a85aa1",
+        "id": "m5j-4hSxflUL"
+      },
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "torch.Size([2048, 32])\n",
+            "jacfwd time: <torch.utils.benchmark.utils.common.Measurement object at 0x7fa9a5d792d0>\n",
+            "jacfwd(predict, argnums=2)(weight, bias, x)\n",
+            "  1.32 ms\n",
+            "  1 measurement, 500 runs , 1 thread\n",
+            "jacrev time: <torch.utils.benchmark.utils.common.Measurement object at 0x7fa9a4dee450>\n",
+            "jacrev(predict, argnums=2)(weight, bias, x)\n",
+            "  12.46 ms\n",
+            "  1 measurement, 500 runs , 1 thread\n"
+          ]
+        }
+      ],
+      "id": "m5j-4hSxflUL"
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "and then do a relative benchmark:"
+      ],
+      "metadata": {
+        "id": "k_Sg-4tVflUL"
+      },
+      "id": "k_Sg-4tVflUL"
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "get_perf(jacfwd_timing, \"jacfwd\", jacrev_timing, \"jacrev\", );"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "outputId": "3a6586a1-269d-46d8-d119-e24f6d46277f",
+        "id": "_4T96zGjflUL"
+      },
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            " Performance delta: 842.8274 percent improvement with jacrev \n"
+          ]
+        }
+      ],
+      "id": "_4T96zGjflUL"
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "and now the reverse - more outputs (M) than inputs (N):"
+      ],
+      "metadata": {
+        "id": "RCDPot1yflUL"
+      },
+      "id": "RCDPot1yflUL"
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "Din = 2048\n",
+        "Dout = 32\n",
+        "weight = torch.randn(Dout, Din)\n",
+        "bias = torch.randn(Dout)\n",
+        "x = torch.randn(Din)\n",
+        "\n",
+        "using_fwd = Timer(stmt=\"jacfwd(predict, argnums=2)(weight, bias, x)\", globals=globals())\n",
+        "using_bwd = Timer(stmt=\"jacrev(predict, argnums=2)(weight, bias, x)\", globals=globals())\n",
+        "\n",
+        "jacfwd_timing = using_fwd.timeit(500)\n",
+        "jacrev_timing = using_bwd.timeit(500)\n",
+        "\n",
+        "print(f'jacfwd time: {jacfwd_timing}')\n",
+        "print(f'jacrev time: {jacrev_timing}')"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "outputId": "913e9ccd-3d4f-472a-a749-19cee36d0a16",
+        "id": "_DRFqzqZflUM"
+      },
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "jacfwd time: <torch.utils.benchmark.utils.common.Measurement object at 0x7fa9a5d64790>\n",
+            "jacfwd(predict, argnums=2)(weight, bias, x)\n",
+            "  7.99 ms\n",
+            "  1 measurement, 500 runs , 1 thread\n",
+            "jacrev time: <torch.utils.benchmark.utils.common.Measurement object at 0x7fa9a5d67b50>\n",
+            "jacrev(predict, argnums=2)(weight, bias, x)\n",
+            "  1.09 ms\n",
+            "  1 measurement, 500 runs , 1 thread\n"
+          ]
+        }
+      ],
+      "id": "_DRFqzqZflUM"
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "and a relative perf comparison:"
+      ],
+      "metadata": {
+        "id": "5SRbMCNsflUM"
+      },
+      "id": "5SRbMCNsflUM"
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "get_perf(jacrev_timing, \"jacrev\", jacfwd_timing, \"jacfwd\")"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "outputId": "c282ce25-4f6e-44cd-aed7-60f6f5010e5b",
+        "id": "uF_9GaoiflUM"
+      },
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            " Performance delta: 635.2095 percent improvement with jacfwd \n"
+          ]
+        }
+      ],
+      "id": "uF_9GaoiflUM"
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## Hessian computation with functorch.hessian\n"
+      ],
+      "metadata": {
+        "id": "J29FQaBQflUM"
+      },
+      "id": "J29FQaBQflUM"
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "We offer a convenience API to compute hessians: `functorch.hessian`. \n",
+        "Hessians are the jacobian of the jacobian (or the partial derivative of the partial derivative, aka second order).\n",
+        "\n",
+        "This suggests that one can just compose functorch’s jacobian transforms to compute the Hessian. \n",
+        "Indeed, under the hood, `hessian(f)` is simply `jacfwd(jacrev(f))`.\n",
+        "\n"
+      ],
+      "metadata": {
+        "id": "My4DPH97flUM"
+      },
+      "id": "My4DPH97flUM"
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "Note: to boost performance: depending on your model, you may also want to use `jacfwd(jacfwd(f))` or `jacrev(jacrev(f))` instead to compute hessians leveraging the rule of thumb above regarding wider vs taller matrices.\n",
+        "\n"
+      ],
+      "metadata": {
+        "id": "FJt038l5flUM"
+      },
+      "id": "FJt038l5flUM"
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "from functorch import hessian\n",
+        "\n",
+        "# lets reduce the size in order not to blow out colab. Hessians require significant memory:\n",
+        "Din = 512\n",
+        "Dout = 32\n",
+        "weight = torch.randn(Dout, Din)\n",
+        "bias = torch.randn(Dout)\n",
+        "x = torch.randn(Din)\n",
+        "\n",
+        "hess_api = hessian(predict, argnums=2)(weight, bias, x)\n",
+        "hess_fwdfwd = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x)\n",
+        "#hess_revrev = jacrev(jacrev(predict, argnums=2), argnums=2)(weight, bias, x)\n"
+      ],
+      "metadata": {
+        "id": "jEqr2ywZflUM"
+      },
+      "execution_count": null,
+      "outputs": [],
+      "id": "jEqr2ywZflUM"
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "Let's verify we have the same result regardless of using hessian api or using jacfwd(jacfwd())"
+      ],
+      "metadata": {
+        "id": "n9BHcICQflUN"
+      },
+      "id": "n9BHcICQflUN"
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "torch.allclose(hess_api, hess_fwdfwd)"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "outputId": "e457e3bc-f085-4f90-966d-f98893b98ea8",
+        "id": "eHiWRkjJflUN"
+      },
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "execute_result",
+          "data": {
+            "text/plain": [
+              "True"
+            ]
+          },
+          "metadata": {},
+          "execution_count": 18
+        }
+      ],
+      "id": "eHiWRkjJflUN"
+    },
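+    {
+      "cell_type": "markdown",
+      "source": [
+        "For completeness, the reverse-over-reverse composition (commented out above) should agree as well; a quick check under the same setup:"
+      ],
+      "metadata": {
+        "id": "revrevMd01flUN"
+      },
+      "id": "revrevMd01flUN"
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "hess_revrev = jacrev(jacrev(predict, argnums=2), argnums=2)(weight, bias, x)\n",
+        "assert torch.allclose(hess_api, hess_revrev)"
+      ],
+      "metadata": {
+        "id": "revrevCd01flUN"
+      },
+      "execution_count": null,
+      "outputs": [],
+      "id": "revrevCd01flUN"
+    },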
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## Batch Jacobian and Batch Hessian\n"
+      ],
+      "metadata": {
+        "id": "Gjt1RO8HflUN"
+      },
+      "id": "Gjt1RO8HflUN"
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "In the above examples we’ve been operating with a single feature vector. In some cases you might want to take the Jacobian of a batch of outputs with respect to a batch of inputs. That is, given a batch of inputs of shape `(B, N)` and a function that goes from $R^N \\to R^M$, we would like a Jacobian of shape `(B, M, N)`. \n",
+        "\n",
+        "The easiest way to do this is to use vmap:"
+      ],
+      "metadata": {
+        "id": "RjIzdoQNflUN"
+      },
+      "id": "RjIzdoQNflUN"
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "batch_size = 64\n",
+        "Din = 31\n",
+        "Dout = 33\n",
+        "\n",
+        "weight = torch.randn(Dout, Din)\n",
+        "print(f\"weight shape = {weight.shape}\")\n",
+        "\n",
+        "bias = torch.randn(Dout)\n",
+        "\n",
+        "x = torch.randn(batch_size, Din)"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "outputId": "561eb618-e00f-40d5-bd99-fa51ab82051f",
+        "id": "B1eoEO4UflUN"
+      },
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "weight shape = torch.Size([33, 31])\n"
+          ]
+        }
+      ],
+      "id": "B1eoEO4UflUN"
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "compute_batch_jacobian = vmap(jacrev(predict, argnums=2), in_dims=(None, None, 0))\n",
+        "batch_jacobian0 = compute_batch_jacobian(weight, bias, x)"
+      ],
+      "metadata": {
+        "id": "nZ_V02NhflUN"
+      },
+      "execution_count": null,
+      "outputs": [],
+      "id": "nZ_V02NhflUN"
+    },
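+    {
+      "cell_type": "markdown",
+      "source": [
+        "We can sanity-check the shape: with a batch of 64 inputs and a function from $R^{31} \to R^{33}$, we expect a Jacobian of shape `(B, M, N) = (64, 33, 31)`:"
+      ],
+      "metadata": {
+        "id": "bjShapeMd01flUN"
+      },
+      "id": "bjShapeMd01flUN"
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# expect (batch_size, Dout, Din) = (64, 33, 31)\n",
+        "print(batch_jacobian0.shape)"
+      ],
+      "metadata": {
+        "id": "bjShapeCd01flUN"
+      },
+      "execution_count": null,
+      "outputs": [],
+      "id": "bjShapeCd01flUN"
+    },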
+    {
+      "cell_type": "markdown",
+      "source": [
+        "If you have a function that goes from (B, N) -> (B, M) instead and are certain that each input produces an independent output, then it’s also sometimes possible to do this without using vmap by summing the outputs and then computing the Jacobian of that function:\n",
+        "\n"
+      ],
+      "metadata": {
+        "id": "_OLDiY3MflUN"
+      },
+      "id": "_OLDiY3MflUN"
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "def predict_with_output_summed(weight, bias, x):\n",
+        "    return predict(weight, bias, x).sum(0)\n",
+        "\n",
+        "batch_jacobian1 = jacrev(predict_with_output_summed, argnums=2)(weight, bias, x).movedim(1, 0)\n",
+        "assert torch.allclose(batch_jacobian0, batch_jacobian1)"
+      ],
+      "metadata": {
+        "id": "_QH4hD8PflUO"
+      },
+      "execution_count": null,
+      "outputs": [],
+      "id": "_QH4hD8PflUO"
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "If you instead have a function that goes from $𝑅^𝑁 \\to 𝑅^𝑀$ but inputs that are batched, you compose vmap with jacrev to compute batched jacobians:\n",
+        "\n",
+        "Finally, batch hessians can be computed similarly. It’s easiest to think about them by using vmap to batch over hessian computation, but in some cases the sum trick also works.\n",
+        "\n"
+      ],
+      "metadata": {
+        "id": "eUjw65cCflUO"
+      },
+      "id": "eUjw65cCflUO"
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "compute_batch_hessian = vmap(hessian(predict, argnums=2), in_dims=(None, None, 0))\n",
+        "\n",
+        "batch_hess = compute_batch_hessian(weight, bias, x)\n",
+        "batch_hess.shape"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "outputId": "f3135cfa-e9e5-4f18-8cb7-0655e8a37cb5",
+        "id": "3vAyQjMsflUO"
+      },
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "execute_result",
+          "data": {
+            "text/plain": [
+              "torch.Size([64, 33, 31, 31])"
+            ]
+          },
+          "metadata": {},
+          "execution_count": 22
+        }
+      ],
+      "id": "3vAyQjMsflUO"
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## Computing Hessian-vector products\n",
+        "\n",
+        "The naive way to compute a Hessian-vector product (hvp) is to materialize the full Hessian and perform a dot-product with a vector. We can do better: it turns out we don't need to materialize the full Hessian to do this. We'll go through two (of many) different strategies to compute Hessian-vector products:\n",
+        "- composing reverse-mode AD with reverse-mode AD\n",
+        "- composing reverse-mode AD with forward-mode AD\n",
+        "\n",
+        "Composing reverse-mode AD with forward-mode AD (as opposed to reverse-mode with reverse-mode) is generally the more memory efficient way to compute a hvp because forward-mode AD doesn't need to construct an Autograd graph and save intermediates for backward:"
+      ],
+      "metadata": {
+        "id": "Wa8E48sQgpkb"
+      },
+      "id": "Wa8E48sQgpkb"
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "from functorch import jvp, grad, vjp\n",
+        "\n",
+        "def hvp(f, primals, tangents):\n",
+        "  return jvp(grad(f), primals, tangents)[1]"
+      ],
+      "metadata": {
+        "id": "trw6WbAth6BM"
+      },
+      "execution_count": null,
+      "outputs": [],
+      "id": "trw6WbAth6BM"
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "Here's some sample usage."
+      ],
+      "metadata": {
+        "id": "DQMpRo6nitfr"
+      },
+      "id": "DQMpRo6nitfr"
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "def f(x):\n",
+        "  return x.sin().sum()\n",
+        "\n",
+        "x = torch.randn(2048)\n",
+        "tangent = torch.randn(2048)\n",
+        "\n",
+        "result = hvp(f, (x,), (tangent,))"
+      ],
+      "metadata": {
+        "id": "sPwg8SOdiVAK"
+      },
+      "execution_count": null,
+      "outputs": [],
+      "id": "sPwg8SOdiVAK"
+    },
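+    {
+      "cell_type": "markdown",
+      "source": [
+        "As a quick (and memory-hungry) sanity check, the hvp should match multiplying the materialized Hessian by the tangent; a minimal sketch reusing the `hessian` API from earlier:"
+      ],
+      "metadata": {
+        "id": "hvpChkMd01iVAK"
+      },
+      "id": "hvpChkMd01iVAK"
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# materializes a 2048 x 2048 Hessian purely for verification\n",
+        "full_hessian = hessian(f)(x)\n",
+        "assert torch.allclose(result, full_hessian @ tangent)"
+      ],
+      "metadata": {
+        "id": "hvpChkCd01iVAK"
+      },
+      "execution_count": null,
+      "outputs": [],
+      "id": "hvpChkCd01iVAK"
+    },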
+    {
+      "cell_type": "markdown",
+      "source": [
+        "If PyTorch forward-AD does not have coverage for your operations, then we can instead compose reverse-mode AD with reverse-mode AD:"
+      ],
+      "metadata": {
+        "id": "zGvUIcB0j1Ez"
+      },
+      "id": "zGvUIcB0j1Ez"
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "def hvp_revrev(f, primals, tangents):\n",
+        "  _, vjp_fn = vjp(grad(f), *primals)\n",
+        "  return vjp_fn(*tangents)"
+      ],
+      "metadata": {
+        "id": "mdDFZdlekAOK"
+      },
+      "execution_count": null,
+      "outputs": [],
+      "id": "mdDFZdlekAOK"
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "result_hvp_revrev = hvp_revrev(f, (x,), (tangent,))\n",
+        "assert torch.allclose(result, result_hvp_revrev[0])"
+      ],
+      "metadata": {
+        "id": "_CuCk9X0lW7C"
+      },
+      "execution_count": null,
+      "outputs": [],
+      "id": "_CuCk9X0lW7C"
     }
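+    ,
+    {
+      "cell_type": "markdown",
+      "source": [
+        "The title also promises vector-Hessian products (vhp). Since the Hessian of a twice-differentiable scalar function is symmetric, the reverse-over-reverse composition above is already computing one; here is the same idea spelled out as a `vhp` helper (our naming, a minimal sketch):"
+      ],
+      "metadata": {
+        "id": "vhpMd01lW7C"
+      },
+      "id": "vhpMd01lW7C"
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "def vhp(f, primals, cotangents):\n",
+        "  # a vjp of grad(f) computes v^T H directly; because H is symmetric\n",
+        "  # here, this coincides with the hvp computed above\n",
+        "  _, vjp_fn = vjp(grad(f), *primals)\n",
+        "  return vjp_fn(*cotangents)\n",
+        "\n",
+        "vhp_result = vhp(f, (x,), (tangent,))\n",
+        "assert torch.allclose(vhp_result[0], result)"
+      ],
+      "metadata": {
+        "id": "vhpCd01lW7C"
+      },
+      "execution_count": null,
+      "outputs": [],
+      "id": "vhpCd01lW7C"
+    }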
-   ],
-   "source": [
-    "Din = 2048\n",
-    "Dout = 32\n",
-    "weight = torch.randn(Dout, Din)\n",
-    "bias = torch.randn(Dout)\n",
-    "x = torch.randn(Din)\n",
-    "\n",
-    "using_fwd = Timer(stmt=\"jacfwd(predict, argnums=2)(weight, bias, x)\", globals=globals())\n",
-    "using_bwd = Timer(stmt=\"jacrev(predict, argnums=2)(weight, bias, x)\", globals=globals())\n",
-    "print(f'jacfwd time: {using_fwd.timeit(500)}')\n",
-    "print(f'jacrev time: {using_bwd.timeit(500)}')"
-   ]
+  ],
+  "metadata": {
+    "kernelspec": {
+      "display_name": "Python 3",
+      "language": "python",
+      "name": "python3"
+    },
+    "language_info": {
+      "codemirror_mode": {
+        "name": "ipython",
+        "version": 3
+      },
+      "file_extension": ".py",
+      "mimetype": "text/x-python",
+      "name": "python",
+      "nbconvert_exporter": "python",
+      "pygments_lexer": "ipython3",
+      "version": "3.8.3"
+    },
+    "colab": {
+      "name": "jacobians_hessians.ipynb",
+      "provenance": []
+    }
   },
-  {
-   "cell_type": "markdown",
-   "id": "7612bcdd-f93d-4c7a-89f6-0a734a7aca5b",
-   "metadata": {},
-   "source": [
-    "## Hessian computation with functorch.hessian\n",
-    "\n",
-    "We offer a convenience API to compute hessians: functorch.hessian.\n",
-    "Hessians are the jacobian of the jacobian, which suggests that one can just\n",
-    "compose functorch's jacobian transforms to compute one.\n",
-    "Indeed, under the hood, ``hessian(f)`` is simply ``jacfwd(jacrev(f))``\n",
-    "\n",
-    "Depending on your model, you may also want to use `jacfwd(jacfwd(f))` or\n",
-    "`jacrev(jacrev(f))` instead to compute hessians."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 12,
-   "id": "845118a5-b923-48f2-adbc-9509efd9143f",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "from functorch import hessian\n",
-    "# # TODO(rzou): make sure PyTorch has tanh_backward implemented for jvp!!\n",
-    "# hess0 = hessian(predict, argnums=2)(weight, bias, x)\n",
-    "# hess1 = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x)\n",
-    "hess2 = jacrev(jacrev(predict, argnums=2), argnums=2)(weight, bias, x)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "b4dd11da-19c0-4f82-8710-fc92c81967ea",
-   "metadata": {},
-   "source": [
-    "## Batch Jacobian (and Batch Hessian)\n",
-    "\n",
-    "In the above examples we've been operating with a single feature vector.\n",
-    "In some cases you might want to take the Jacobian of a batch of outputs\n",
-    "with respect to a batch of inputs. That is, given a batch of inputs of shape `(B, N)` and a function\n",
-    "that goes from `R^N -> R^M`, we would like a Jacobian of shape `(B, M, N)`.\n",
-    "The easiest way to do this is to use vmap:"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 13,
-   "id": "af55d512-b8cf-4304-8d2c-2df88a0352ed",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "batch_size = 64\n",
-    "Din = 31\n",
-    "Dout = 33\n",
-    "weight = torch.randn(Dout, Din)\n",
-    "bias = torch.randn(Dout)\n",
-    "x = torch.randn(batch_size, Din)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 14,
-   "id": "e2a159d6-6724-4086-bc5c-490d1e7515f8",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "compute_batch_jacobian = vmap(jacrev(predict, argnums=2), in_dims=(None, None, 0))\n",
-    "batch_jacobian0 = compute_batch_jacobian(weight, bias, x)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "b6fae848-d3d7-4148-93d4-ed609372e14f",
-   "metadata": {},
-   "source": [
-    "If you have a function that goes from `(B, N) -> (B, M)` instead and are certain that each input produces an independent\n",
-    "output, then it's also sometimes possible to do this without using `vmap` by summing the outputs and then computing the Jacobian of that function:"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 15,
-   "id": "13de15ed-d2f7-4733-816e-eb265aa83429",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def predict_with_output_summed(weight, bias, x):\n",
-    "    return predict(weight, bias, x).sum(0)\n",
-    "\n",
-    "batch_jacobian1 = jacrev(predict_with_output_summed, argnums=2)(weight, bias, x).movedim(1, 0)\n",
-    "assert torch.allclose(batch_jacobian0, batch_jacobian1)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "c15adba4-e8ab-482f-bfa6-bdd6055eb7fb",
-   "metadata": {},
-   "source": [
-    "If you instead have a function that goes from $R^N -> R^M$ but inputs that are\n",
-    "batched, you compose vmap with jacrev to compute batched jacobians:"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "0d257bbb-d208-40be-9982-6b88a0eb9f3a",
-   "metadata": {},
-   "source": [
-    "Finally, batch hessians can be computed similarly. It's easiest to think about\n",
-    "them by using vmap to batch over hessian computation, but in some cases the sum\n",
-    "trick also works."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 16,
-   "id": "10353024-d0db-4c63-8865-c21973dcfc03",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "compute_batch_hessian = vmap(hessian(predict, argnums=2), in_dims=(None, None, 0))\n",
-    "# TODO(rzou): PyTorch forward-mode AD does not support tanh_backward\n",
-    "# batch_hess = compute_batch_hessian(weight, bias, x)"
-   ]
-  }
- ],
- "metadata": {
-  "kernelspec": {
-   "display_name": "Python 3",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.8.3"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 5
+  "nbformat": 4,
+  "nbformat_minor": 5
 }