[functorch] Colab ready tutorials (3 total) with updated colab badge links (pytorch/functorch#519)
* Create readme.md
* colab version, jacobians hessians tutorial
* per sample gradients tutorial in colab
* colab ready tutorial, ensembling
* add colab badge and link to colab ready tutorial
* Delete ensembling_colab.ipynb
* add colab badge and link to colab ready version
* add colab badge and link to colab ready tutorial
* add colab badge and link to colab ready tutorial, add meta-learning mention as addtl use case
diff --git a/functorch/notebooks/colab/ensembling_colab.ipynb b/functorch/notebooks/colab/ensembling_colab.ipynb
new file mode 100644
index 0000000..93fc5f9
--- /dev/null
+++ b/functorch/notebooks/colab/ensembling_colab.ipynb
@@ -0,0 +1,676 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "name": "ensembling_colab.ipynb",
+ "provenance": [],
+ "collapsed_sections": [
+ "0I5Mm2q2f5aw"
+ ],
+ "machine_shape": "hm",
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Welcome to the functorch tutorial on ensembling models, in colab."
+ ],
+ "metadata": {
+ "id": "W6b4RUiYnhSt"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Configuring your colab to run functorch \n"
+ ],
+ "metadata": {
+ "id": "0I5Mm2q2f5aw"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "**Getting setup** - running functorch currently requires Pytorch Nightly. \n",
+ "Thus we'll go through a pytorch nightly install and build functorch. \n",
+ "\n",
+ "After that and a restart, you'll be ready to run the tutorial here on colab."
+ ],
+ "metadata": {
+ "id": "jnHxd2KFgPJg"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Let's setup a restart function:"
+ ],
+ "metadata": {
+ "id": "PvwZSOklhpB2"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def colab_restart():\n",
+ " print(\"--> Restarting colab instance\") \n",
+ " get_ipython().kernel.do_shutdown(True)"
+ ],
+ "metadata": {
+ "id": "MklsA-KRhZKC"
+ },
+ "execution_count": 1,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Next, let's confirm that we have a gpu. \n",
+ "(If not, select Runtime -> Change Runtime type above,\n",
+ " and select GPU under Hardward Accelerator )"
+ ],
+ "metadata": {
+ "id": "Njk9qPgTiiGS"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!nvcc --version"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "HxidO4dpiPGi",
+ "outputId": "f97e7000-a327-45c4-a993-f6469c121a7f"
+ },
+ "execution_count": 2,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "nvcc: NVIDIA (R) Cuda compiler driver\n",
+ "Copyright (c) 2005-2020 NVIDIA Corporation\n",
+ "Built on Mon_Oct_12_20:09:46_PDT_2020\n",
+ "Cuda compilation tools, release 11.1, V11.1.105\n",
+ "Build cuda_11.1.TC455_06.29190527_0\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Let's remove the default PyTorch install:"
+ ],
+ "metadata": {
+ "id": "HanoUO62jtKx"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!pip uninstall -y torch"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "NIoTNykP9xI5",
+ "outputId": "d36069bb-cc94-45e9-a1e4-635b9bc616f4"
+ },
+ "execution_count": 3,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Found existing installation: torch 1.10.0+cu111\n",
+ "Uninstalling torch-1.10.0+cu111:\n",
+ " Successfully uninstalled torch-1.10.0+cu111\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "And install the relevant nightly version. (this defaults to 11.1 Cuda which works on most colabs). "
+ ],
+ "metadata": {
+ "id": "n-DFUwBVkHaX"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "cuda_version = \"cu111\" # optionally - cu113 (for 11.3) is an option as well if you have 11.3 listed above in the nvcc output. "
+ ],
+ "metadata": {
+ "id": "BH5ffJBkkRR8"
+ },
+ "execution_count": 4,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!pip install --pre torch -f https://download.pytorch.org/whl/nightly/{cuda_version}/torch_nightly.html --upgrade"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "Bi2oymijkav5",
+ "outputId": "bfd678c2-5d52-4ba1-cea2-3125c8f0fcbb"
+ },
+ "execution_count": 5,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Looking in links: https://download.pytorch.org/whl/nightly/cu111/torch_nightly.html\n",
+ "Collecting torch\n",
+ " Downloading https://download.pytorch.org/whl/nightly/cu111/torch-1.12.0.dev20220217%2Bcu111-cp37-cp37m-linux_x86_64.whl (1923.7 MB)\n",
+ "\u001b[K |█████████████▉ | 834.1 MB 133.5 MB/s eta 0:00:09tcmalloc: large alloc 1147494400 bytes == 0x555621e3c000 @ 0x7f9e1d70c615 0x5555e92683bc 0x5555e934918a 0x5555e926b1cd 0x5555e935db3d 0x5555e92df458 0x5555e92da02f 0x5555e926caba 0x5555e92df2c0 0x5555e92da02f 0x5555e926caba 0x5555e92dbcd4 0x5555e935e986 0x5555e92db350 0x5555e935e986 0x5555e92db350 0x5555e935e986 0x5555e92db350 0x5555e926cf19 0x5555e92b0a79 0x5555e926bb32 0x5555e92df1dd 0x5555e92da02f 0x5555e926caba 0x5555e92dbcd4 0x5555e92da02f 0x5555e926caba 0x5555e92daeae 0x5555e926c9da 0x5555e92db108 0x5555e92da02f\n",
+ "\u001b[K |█████████████████▋ | 1055.7 MB 1.8 MB/s eta 0:08:01tcmalloc: large alloc 1434370048 bytes == 0x555666492000 @ 0x7f9e1d70c615 0x5555e92683bc 0x5555e934918a 0x5555e926b1cd 0x5555e935db3d 0x5555e92df458 0x5555e92da02f 0x5555e926caba 0x5555e92df2c0 0x5555e92da02f 0x5555e926caba 0x5555e92dbcd4 0x5555e935e986 0x5555e92db350 0x5555e935e986 0x5555e92db350 0x5555e935e986 0x5555e92db350 0x5555e926cf19 0x5555e92b0a79 0x5555e926bb32 0x5555e92df1dd 0x5555e92da02f 0x5555e926caba 0x5555e92dbcd4 0x5555e92da02f 0x5555e926caba 0x5555e92daeae 0x5555e926c9da 0x5555e92db108 0x5555e92da02f\n",
+ "\u001b[K |██████████████████████▎ | 1336.2 MB 1.2 MB/s eta 0:07:54tcmalloc: large alloc 1792966656 bytes == 0x5555eb2c4000 @ 0x7f9e1d70c615 0x5555e92683bc 0x5555e934918a 0x5555e926b1cd 0x5555e935db3d 0x5555e92df458 0x5555e92da02f 0x5555e926caba 0x5555e92df2c0 0x5555e92da02f 0x5555e926caba 0x5555e92dbcd4 0x5555e935e986 0x5555e92db350 0x5555e935e986 0x5555e92db350 0x5555e935e986 0x5555e92db350 0x5555e926cf19 0x5555e92b0a79 0x5555e926bb32 0x5555e92df1dd 0x5555e92da02f 0x5555e926caba 0x5555e92dbcd4 0x5555e92da02f 0x5555e926caba 0x5555e92daeae 0x5555e926c9da 0x5555e92db108 0x5555e92da02f\n",
+ "\u001b[K |████████████████████████████▏ | 1691.1 MB 1.2 MB/s eta 0:03:17tcmalloc: large alloc 2241208320 bytes == 0x5556560ac000 @ 0x7f9e1d70c615 0x5555e92683bc 0x5555e934918a 0x5555e926b1cd 0x5555e935db3d 0x5555e92df458 0x5555e92da02f 0x5555e926caba 0x5555e92df2c0 0x5555e92da02f 0x5555e926caba 0x5555e92dbcd4 0x5555e935e986 0x5555e92db350 0x5555e935e986 0x5555e92db350 0x5555e935e986 0x5555e92db350 0x5555e926cf19 0x5555e92b0a79 0x5555e926bb32 0x5555e92df1dd 0x5555e92da02f 0x5555e926caba 0x5555e92dbcd4 0x5555e92da02f 0x5555e926caba 0x5555e92daeae 0x5555e926c9da 0x5555e92db108 0x5555e92da02f\n",
+ "\u001b[K |████████████████████████████████| 1923.7 MB 2.5 MB/s eta 0:00:01tcmalloc: large alloc 1923702784 bytes == 0x5556dba0e000 @ 0x7f9e1d70b1e7 0x5555e929e5d7 0x5555e92683bc 0x5555e934918a 0x5555e926b1cd 0x5555e935db3d 0x5555e92df458 0x5555e92da02f 0x5555e926caba 0x5555e92db108 0x5555e92da02f 0x5555e926caba 0x5555e92db108 0x5555e92da02f 0x5555e926caba 0x5555e92db108 0x5555e92da02f 0x5555e926caba 0x5555e92db108 0x5555e92da02f 0x5555e926caba 0x5555e92db108 0x5555e926c9da 0x5555e92db108 0x5555e92da02f 0x5555e926caba 0x5555e92dbcd4 0x5555e92da02f 0x5555e926caba 0x5555e92dbcd4 0x5555e92da02f\n",
+ "tcmalloc: large alloc 2404630528 bytes == 0x55574e4a4000 @ 0x7f9e1d70c615 0x5555e92683bc 0x5555e934918a 0x5555e926b1cd 0x5555e935db3d 0x5555e92df458 0x5555e92da02f 0x5555e926caba 0x5555e92db108 0x5555e92da02f 0x5555e926caba 0x5555e92db108 0x5555e92da02f 0x5555e926caba 0x5555e92db108 0x5555e92da02f 0x5555e926caba 0x5555e92db108 0x5555e92da02f 0x5555e926caba 0x5555e92db108 0x5555e926c9da 0x5555e92db108 0x5555e92da02f 0x5555e926caba 0x5555e92dbcd4 0x5555e92da02f 0x5555e926caba 0x5555e92dbcd4 0x5555e92da02f 0x5555e926d151\n",
+ "\u001b[K |████████████████████████████████| 1923.7 MB 8.0 kB/s \n",
+ "\u001b[?25hRequirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch) (3.10.0.2)\n",
+ "Installing collected packages: torch\n",
+ "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
+ "torchvision 0.11.1+cu111 requires torch==1.10.0, but you have torch 1.12.0.dev20220217+cu111 which is incompatible.\n",
+ "torchtext 0.11.0 requires torch==1.10.0, but you have torch 1.12.0.dev20220217+cu111 which is incompatible.\n",
+ "torchaudio 0.10.0+cu111 requires torch==1.10.0, but you have torch 1.12.0.dev20220217+cu111 which is incompatible.\u001b[0m\n",
+ "Successfully installed torch-1.12.0.dev20220217+cu111\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Let's install Ninja to accelerate the functorch building process:"
+ ],
+ "metadata": {
+ "id": "OkL1Q6KXlzid"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!pip install ninja"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "VEJs4UEOkay0",
+ "outputId": "6aef4cb4-c158-457c-e093-1eb919d2e419"
+ },
+ "execution_count": 6,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Collecting ninja\n",
+ " Downloading ninja-1.10.2.3-py2.py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.whl (108 kB)\n",
+ "\u001b[?25l\r\u001b[K |███ | 10 kB 15.0 MB/s eta 0:00:01\r\u001b[K |██████ | 20 kB 9.2 MB/s eta 0:00:01\r\u001b[K |█████████ | 30 kB 6.1 MB/s eta 0:00:01\r\u001b[K |████████████▏ | 40 kB 3.6 MB/s eta 0:00:01\r\u001b[K |███████████████▏ | 51 kB 3.5 MB/s eta 0:00:01\r\u001b[K |██████████████████▏ | 61 kB 4.2 MB/s eta 0:00:01\r\u001b[K |█████████████████████▏ | 71 kB 4.4 MB/s eta 0:00:01\r\u001b[K |████████████████████████▎ | 81 kB 5.0 MB/s eta 0:00:01\r\u001b[K |███████████████████████████▎ | 92 kB 5.2 MB/s eta 0:00:01\r\u001b[K |██████████████████████████████▎ | 102 kB 4.2 MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 108 kB 4.2 MB/s \n",
+ "\u001b[?25hInstalling collected packages: ninja\n",
+ "Successfully installed ninja-1.10.2.3\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Next we'll install and build functorch (eta is ~6 minutes):"
+ ],
+ "metadata": {
+ "id": "s3rrVgGkmNpi"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "id": "UtBgzUPDfIQg",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "affd08d0-7a6f-41d7-d32b-00418120231a"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Collecting git+https://github.com/pytorch/functorch.git\n",
+ " Cloning https://github.com/pytorch/functorch.git to /tmp/pip-req-build-fu9dydpo\n",
+ " Running command git clone -q https://github.com/pytorch/functorch.git /tmp/pip-req-build-fu9dydpo\n",
+ "Requirement already satisfied: torch>=1.10.0.dev in /usr/local/lib/python3.7/dist-packages (from functorch==0.2.0a0+8915608) (1.12.0.dev20220217+cu111)\n",
+ "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch>=1.10.0.dev->functorch==0.2.0a0+8915608) (3.10.0.2)\n",
+ "Building wheels for collected packages: functorch\n",
+ " Building wheel for functorch (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+ " Created wheel for functorch: filename=functorch-0.2.0a0+8915608-cp37-cp37m-linux_x86_64.whl size=21388647 sha256=7df53c10ec474b040d0ab8d774b29c400af196b80303a8f0341d21732c46e28d\n",
+ " Stored in directory: /tmp/pip-ephem-wheel-cache-r01u5lf5/wheels/b0/a9/4a/ffec50dda854c8d9f2ba21e4ffc0f2489ea97946cb1102c5ab\n",
+ "Successfully built functorch\n",
+ "Installing collected packages: functorch\n",
+ "Successfully installed functorch-0.2.0a0+8915608\n"
+ ]
+ }
+ ],
+ "source": [
+ "!pip install --user \"git+https://github.com/pytorch/functorch.git\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Finally - restart colab and after that - just skip directly down to the '-- Tutorial Start --' section to get underway."
+ ],
+ "metadata": {
+ "id": "T8dhR1XEmcJ6"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "colab_restart() "
+ ],
+ "metadata": {
+ "id": "xo2UY9b8ma8t",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "5e0f9316-7a33-4069-aff2-2fb8b8631d04"
+ },
+ "execution_count": 8,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "--> Restarting colab instance\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## -- Tutorial Start -- \n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "nj6_fW76wM0d"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Confirm we are ready to start. \n",
+ "# If this errs, please make sure you have completed the 'configuring your colab' steps above first and then return here.\n",
+ "\n",
+ "import functorch "
+ ],
+ "metadata": {
+ "id": "SvUfIxRyeAaL"
+ },
+ "execution_count": 1,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Model Ensembling\n",
+ "\n",
+ "This example illustrates how to vectorize model ensembling, using vmap.\n",
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "nLdOLDH6m9oy"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "**What is model ensembling?**\n",
+ "\n",
+ "Model ensembling combines the predictions from multiple models together. \n",
+ "\n",
+ "Traditionally this is done by running each model on some inputs separately and then combining the predictions. \n",
+ "\n",
+ "However, if you’re running models with the same architecture, then it may be possible to combine them together using vmap. vmap is a function transform that maps functions across dimensions of the input tensors. \n",
+ "\n",
+ "One of its use cases is eliminating for-loops and speeding them up through vectorization.\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "CJJBTOl-tawq"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Let’s demonstrate how to do this using an ensemble of simple CNNs."
+ ],
+ "metadata": {
+ "id": "z21OfixOvBaM"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F\n",
+ "from functools import partial\n",
+ "torch.manual_seed(0);"
+ ],
+ "metadata": {
+ "id": "Gb-yt4VKUUuc"
+ },
+ "execution_count": 3,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Here's a simple CNN\n",
+ "\n",
+ "class SimpleCNN(nn.Module):\n",
+ " def __init__(self):\n",
+ " super(SimpleCNN, self).__init__()\n",
+ " self.conv1 = nn.Conv2d(1, 32, 3, 1)\n",
+ " self.conv2 = nn.Conv2d(32, 64, 3, 1)\n",
+ " self.fc1 = nn.Linear(9216, 128)\n",
+ " self.fc2 = nn.Linear(128, 10)\n",
+ "\n",
+ " def forward(self, x):\n",
+ " x = self.conv1(x)\n",
+ " x = F.relu(x)\n",
+ " x = self.conv2(x)\n",
+ " x = F.relu(x)\n",
+ " x = F.max_pool2d(x, 2)\n",
+ " x = torch.flatten(x, 1)\n",
+ " x = self.fc1(x)\n",
+ " x = F.relu(x)\n",
+ " x = self.fc2(x)\n",
+ " output = F.log_softmax(x, dim=1)\n",
+ " output = x\n",
+ " return output"
+ ],
+ "metadata": {
+ "id": "tf-HKHjUUbyY"
+ },
+ "execution_count": 4,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Let’s generate a batch of dummy data and pretend that we’re working with an MNIST dataset. \n",
+ "\n",
+ "Thus, the dummy images are 28 by 28, and we have a minibatch of size 64.\n",
+ "\n",
+ "Furthermore, lets say we want to combine the predictions from 10 different models. \n"
+ ],
+ "metadata": {
+ "id": "VEDPe-EoU5Fa"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "device = 'cuda'\n",
+ "\n",
+ "num_models = 10\n",
+ "\n",
+ "data = torch.randn(100, 64, 1, 28, 28, device=device)\n",
+ "targets = torch.randint(10, (6400,), device=device)\n",
+ "\n",
+ "models = [SimpleCNN().to(device) for _ in range(num_models)]\n"
+ ],
+ "metadata": {
+ "id": "WB2Qe3AHUvPN"
+ },
+ "execution_count": 5,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "We have a couple of options for generating predictions. \n",
+ "\n",
+ "Maybe we want to give each model a different randomized minibatch of data. \n",
+ "\n",
+ "Alternatively, maybe we want to run the same minibatch of data through each model (e.g. if we were testing the effect of different model initializations).\n",
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "GOGJ-OUxVcT5"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Option 1: different minibatch for each model"
+ ],
+ "metadata": {
+ "id": "CwJBb09MxCN3"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "minibatches = data[:num_models]\n",
+ "predictions_diff_minibatch_loop = [model(minibatch) for model, minibatch in zip(models, minibatches)]"
+ ],
+ "metadata": {
+ "id": "WYjMx8QTUvRu"
+ },
+ "execution_count": 10,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Option 2: Same minibatch"
+ ],
+ "metadata": {
+ "id": "HNw4_IVzU5Pz"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "minibatch = data[0]\n",
+ "predictions2 = [model(minibatch) for model in models]"
+ ],
+ "metadata": {
+ "id": "vUsb3VfexJrY"
+ },
+ "execution_count": 7,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Using vmap to vectorize the ensemble\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "aNkX6lFIxzcm"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Let’s use vmap to speed up the for-loop. We must first prepare the models for use with vmap.\n",
+ "\n",
+ "First, let’s combine the states of the model together by stacking each parameter. For example, model[i].fc1.weight has shape [9216, 128]; we are going to stack the .fc1.weight of each of the 10 models to produce a big weight of shape [10, 9216, 128].\n",
+ "\n",
+ "functorch offers the 'combine_state_for_ensemble' convenience function to do that. It returns a stateless version of the model (fmodel) and stacked parameters and buffers.\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "-sFMojhryviM"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from functorch import combine_state_for_ensemble\n",
+ "\n",
+ "fmodel, params, buffers = combine_state_for_ensemble(models)\n",
+ "[p.requires_grad_() for p in params];\n"
+ ],
+ "metadata": {
+ "id": "C3a9_clvyPho"
+ },
+ "execution_count": 8,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Option 1: get predictions using a different minibatch for each model. \n",
+ "\n",
+ "By default, vmap maps a function across the first dimension of all inputs to the passed-in function. \n",
+ "\n",
+ "After using the combine_state_for_ensemble, each of the params and buffers have an additional dimension of size 'num_models' at the front, and minibatches has a dimension of size 'num_models'.\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "mFJDWMM9yaYZ"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "print([p.size(0) for p in params]) # show the leading 'num_models' dimension\n",
+ "\n",
+ "assert minibatches.shape == (num_models, 64, 1, 28, 28) # verify minibatch has leading dimension of size 'num_models'"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "ezuFQx1G1zLG",
+ "outputId": "15dcae9a-1a0e-4c24-fb7c-f8059a630400"
+ },
+ "execution_count": 12,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "[10, 10, 10, 10, 10, 10, 10, 10]\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from functorch import vmap\n",
+ "\n",
+ "predictions1_vmap = vmap(fmodel)(params, buffers, minibatches)\n",
+ "\n",
+ "# verify the vmap predictions match the \n",
+ "assert torch.allclose(predictions1_vmap, torch.stack(predictions_diff_minibatch_loop), atol=1e-3, rtol=1e-5)"
+ ],
+ "metadata": {
+ "id": "VroLnfD82DDf"
+ },
+ "execution_count": 11,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Option 2: get predictions using the same minibatch of data.\n",
+ "\n",
+ "vmap has an in_dims arg that specifies which dimensions to map over. \n",
+ "\n",
+ "By using None, we tell vmap we want the same minibatch to apply for all of the 10 models.\n",
+ "\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "tlkmyQyfY6XU"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "predictions2_vmap = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)\n",
+ "\n",
+ "assert torch.allclose(predictions2_vmap, torch.stack(predictions2), atol=1e-3, rtol=1e-5)"
+ ],
+ "metadata": {
+ "id": "WiSMupvCyecd"
+ },
+ "execution_count": 13,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "A quick note: there are limitations around what types of functions can be transformed by vmap. \n",
+ "\n",
+ "The best functions to transform are ones that are pure functions: a function where the outputs are only determined by the inputs that have no side effects (e.g. mutation). \n",
+ "\n",
+ "vmap is unable to handle mutation of arbitrary Python data structures, but it is able to handle many in-place PyTorch operations.\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "wMsbppPNZklo"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "In general, vectorization with vmap should be faster than running a function in a for-loop and competitive with manual batching. \n",
+ "\n",
+ "There are some exceptions though, like if we haven’t implemented the vmap rule for a particular operation or if the underlying kernels weren’t optimized for older hardware (GPUs). \n",
+ "\n",
+ "If you see any of these cases, please let us know by opening an issue at our [GitHub](https://github.com/pytorch/functorch)!\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "UI74G9JarQU8"
+ }
+ }
+ ]
+}
\ No newline at end of file
diff --git a/functorch/notebooks/colab/jacobians_hessians_colab.ipynb b/functorch/notebooks/colab/jacobians_hessians_colab.ipynb
new file mode 100644
index 0000000..20a4fe7
--- /dev/null
+++ b/functorch/notebooks/colab/jacobians_hessians_colab.ipynb
@@ -0,0 +1,1134 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "name": "functorch_hessians_colab.ipynb",
+ "provenance": [],
+ "collapsed_sections": [
+ "0I5Mm2q2f5aw"
+ ],
+ "machine_shape": "hm",
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Welcome to the functorch tutorial on Jacobians, Hessians and more - on colab! "
+ ],
+ "metadata": {
+ "id": "W6b4RUiYnhSt"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Configuring your colab to run functorch \n"
+ ],
+ "metadata": {
+ "id": "0I5Mm2q2f5aw"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "**Getting setup** - running functorch currently requires Pytorch Nightly. \n",
+ "Thus we'll go through a pytorch nightly install and build functorch. \n",
+ "\n",
+ "After that and a restart, you'll be ready to run the tutorial here on colab."
+ ],
+ "metadata": {
+ "id": "jnHxd2KFgPJg"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Let's setup a restart function:"
+ ],
+ "metadata": {
+ "id": "PvwZSOklhpB2"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def colab_restart():\n",
+ " print(\"--> Restarting colab instance\") \n",
+ " get_ipython().kernel.do_shutdown(True)"
+ ],
+ "metadata": {
+ "id": "MklsA-KRhZKC"
+ },
+ "execution_count": 2,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Next, let's confirm that we have a gpu. \n",
+ "(If not, select Runtime -> Change Runtime type above,\n",
+ " and select GPU under Hardward Accelerator )"
+ ],
+ "metadata": {
+ "id": "Njk9qPgTiiGS"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!nvcc --version"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "HxidO4dpiPGi",
+ "outputId": "8a285bcd-a791-4d19-9a71-8e53d8325eba"
+ },
+ "execution_count": 3,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "nvcc: NVIDIA (R) Cuda compiler driver\n",
+ "Copyright (c) 2005-2020 NVIDIA Corporation\n",
+ "Built on Mon_Oct_12_20:09:46_PDT_2020\n",
+ "Cuda compilation tools, release 11.1, V11.1.105\n",
+ "Build cuda_11.1.TC455_06.29190527_0\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Let's remove the default PyTorch install:"
+ ],
+ "metadata": {
+ "id": "HanoUO62jtKx"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!pip uninstall -y torch"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "NIoTNykP9xI5",
+ "outputId": "f8678909-b7f0-4c37-8186-a6ce38c0f483"
+ },
+ "execution_count": 4,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Found existing installation: torch 1.10.0+cu111\n",
+ "Uninstalling torch-1.10.0+cu111:\n",
+ " Successfully uninstalled torch-1.10.0+cu111\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "And install the relevant nightly version. (this defaults to 11.1 Cuda which works on most colabs). "
+ ],
+ "metadata": {
+ "id": "n-DFUwBVkHaX"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "cuda_version = \"cu111\" # optionally - cu113 (for 11.3) is an option as well if you have 11.3 listed above in the nvcc output. "
+ ],
+ "metadata": {
+ "id": "BH5ffJBkkRR8"
+ },
+ "execution_count": 5,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!pip install --pre torch -f https://download.pytorch.org/whl/nightly/{cuda_version}/torch_nightly.html --upgrade"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "Bi2oymijkav5",
+ "outputId": "d78924ae-a04e-44ce-c28f-b7d20b9b2cfc"
+ },
+ "execution_count": 6,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Looking in links: https://download.pytorch.org/whl/nightly/cu111/torch_nightly.html\n",
+ "Collecting torch\n",
+ " Downloading https://download.pytorch.org/whl/nightly/cu111/torch-1.12.0.dev20220216%2Bcu111-cp37-cp37m-linux_x86_64.whl (1922.9 MB)\n",
+ "\u001b[K |█████████████▉ | 834.1 MB 60.9 MB/s eta 0:00:18tcmalloc: large alloc 1147494400 bytes == 0x559246f1a000 @ 0x7fa0182ce615 0x55920e8183bc 0x55920e8f918a 0x55920e81b1cd 0x55920e90db3d 0x55920e88f458 0x55920e88a02f 0x55920e81caba 0x55920e88f2c0 0x55920e88a02f 0x55920e81caba 0x55920e88bcd4 0x55920e90e986 0x55920e88b350 0x55920e90e986 0x55920e88b350 0x55920e90e986 0x55920e88b350 0x55920e81cf19 0x55920e860a79 0x55920e81bb32 0x55920e88f1dd 0x55920e88a02f 0x55920e81caba 0x55920e88bcd4 0x55920e88a02f 0x55920e81caba 0x55920e88aeae 0x55920e81c9da 0x55920e88b108 0x55920e88a02f\n",
+ "\u001b[K |█████████████████▋ | 1055.7 MB 1.3 MB/s eta 0:11:14tcmalloc: large alloc 1434370048 bytes == 0x55928b570000 @ 0x7fa0182ce615 0x55920e8183bc 0x55920e8f918a 0x55920e81b1cd 0x55920e90db3d 0x55920e88f458 0x55920e88a02f 0x55920e81caba 0x55920e88f2c0 0x55920e88a02f 0x55920e81caba 0x55920e88bcd4 0x55920e90e986 0x55920e88b350 0x55920e90e986 0x55920e88b350 0x55920e90e986 0x55920e88b350 0x55920e81cf19 0x55920e860a79 0x55920e81bb32 0x55920e88f1dd 0x55920e88a02f 0x55920e81caba 0x55920e88bcd4 0x55920e88a02f 0x55920e81caba 0x55920e88aeae 0x55920e81c9da 0x55920e88b108 0x55920e88a02f\n",
+ "\u001b[K |██████████████████████▎ | 1336.2 MB 1.3 MB/s eta 0:07:18tcmalloc: large alloc 1792966656 bytes == 0x5592103a2000 @ 0x7fa0182ce615 0x55920e8183bc 0x55920e8f918a 0x55920e81b1cd 0x55920e90db3d 0x55920e88f458 0x55920e88a02f 0x55920e81caba 0x55920e88f2c0 0x55920e88a02f 0x55920e81caba 0x55920e88bcd4 0x55920e90e986 0x55920e88b350 0x55920e90e986 0x55920e88b350 0x55920e90e986 0x55920e88b350 0x55920e81cf19 0x55920e860a79 0x55920e81bb32 0x55920e88f1dd 0x55920e88a02f 0x55920e81caba 0x55920e88bcd4 0x55920e88a02f 0x55920e81caba 0x55920e88aeae 0x55920e81c9da 0x55920e88b108 0x55920e88a02f\n",
+ "\u001b[K |████████████████████████████▏ | 1691.1 MB 1.3 MB/s eta 0:03:02tcmalloc: large alloc 2241208320 bytes == 0x55927b18a000 @ 0x7fa0182ce615 0x55920e8183bc 0x55920e8f918a 0x55920e81b1cd 0x55920e90db3d 0x55920e88f458 0x55920e88a02f 0x55920e81caba 0x55920e88f2c0 0x55920e88a02f 0x55920e81caba 0x55920e88bcd4 0x55920e90e986 0x55920e88b350 0x55920e90e986 0x55920e88b350 0x55920e90e986 0x55920e88b350 0x55920e81cf19 0x55920e860a79 0x55920e81bb32 0x55920e88f1dd 0x55920e88a02f 0x55920e81caba 0x55920e88bcd4 0x55920e88a02f 0x55920e81caba 0x55920e88aeae 0x55920e81c9da 0x55920e88b108 0x55920e88a02f\n",
+ "\u001b[K |████████████████████████████████| 1922.9 MB 1.2 MB/s eta 0:00:01tcmalloc: large alloc 1922924544 bytes == 0x559300aec000 @ 0x7fa0182cd1e7 0x55920e84e5d7 0x55920e8183bc 0x55920e8f918a 0x55920e81b1cd 0x55920e90db3d 0x55920e88f458 0x55920e88a02f 0x55920e81caba 0x55920e88b108 0x55920e88a02f 0x55920e81caba 0x55920e88b108 0x55920e88a02f 0x55920e81caba 0x55920e88b108 0x55920e88a02f 0x55920e81caba 0x55920e88b108 0x55920e88a02f 0x55920e81caba 0x55920e88b108 0x55920e81c9da 0x55920e88b108 0x55920e88a02f 0x55920e81caba 0x55920e88bcd4 0x55920e88a02f 0x55920e81caba 0x55920e88bcd4 0x55920e88a02f\n",
+ "tcmalloc: large alloc 2403655680 bytes == 0x5593734c4000 @ 0x7fa0182ce615 0x55920e8183bc 0x55920e8f918a 0x55920e81b1cd 0x55920e90db3d 0x55920e88f458 0x55920e88a02f 0x55920e81caba 0x55920e88b108 0x55920e88a02f 0x55920e81caba 0x55920e88b108 0x55920e88a02f 0x55920e81caba 0x55920e88b108 0x55920e88a02f 0x55920e81caba 0x55920e88b108 0x55920e88a02f 0x55920e81caba 0x55920e88b108 0x55920e81c9da 0x55920e88b108 0x55920e88a02f 0x55920e81caba 0x55920e88bcd4 0x55920e88a02f 0x55920e81caba 0x55920e88bcd4 0x55920e88a02f 0x55920e81d151\n",
+ "\u001b[K |████████████████████████████████| 1922.9 MB 4.7 kB/s \n",
+ "\u001b[?25hRequirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch) (3.10.0.2)\n",
+ "Installing collected packages: torch\n",
+ "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
+ "torchvision 0.11.1+cu111 requires torch==1.10.0, but you have torch 1.12.0.dev20220216+cu111 which is incompatible.\n",
+ "torchtext 0.11.0 requires torch==1.10.0, but you have torch 1.12.0.dev20220216+cu111 which is incompatible.\n",
+ "torchaudio 0.10.0+cu111 requires torch==1.10.0, but you have torch 1.12.0.dev20220216+cu111 which is incompatible.\u001b[0m\n",
+ "Successfully installed torch-1.12.0.dev20220216+cu111\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Let's install Ninja to accelerate the functorch building process:"
+ ],
+ "metadata": {
+ "id": "OkL1Q6KXlzid"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!pip install ninja"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "VEJs4UEOkay0",
+ "outputId": "ca624c7e-dfe6-4233-a51d-b8ad5188a0c4"
+ },
+ "execution_count": 7,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Collecting ninja\n",
+ " Downloading ninja-1.10.2.3-py2.py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.whl (108 kB)\n",
+ "\u001b[?25l\r\u001b[K |███ | 10 kB 19.8 MB/s eta 0:00:01\r\u001b[K |██████ | 20 kB 8.7 MB/s eta 0:00:01\r\u001b[K |█████████ | 30 kB 7.4 MB/s eta 0:00:01\r\u001b[K |████████████▏ | 40 kB 6.8 MB/s eta 0:00:01\r\u001b[K |███████████████▏ | 51 kB 4.0 MB/s eta 0:00:01\r\u001b[K |██████████████████▏ | 61 kB 4.2 MB/s eta 0:00:01\r\u001b[K |█████████████████████▏ | 71 kB 4.3 MB/s eta 0:00:01\r\u001b[K |████████████████████████▎ | 81 kB 4.8 MB/s eta 0:00:01\r\u001b[K |███████████████████████████▎ | 92 kB 3.7 MB/s eta 0:00:01\r\u001b[K |██████████████████████████████▎ | 102 kB 4.0 MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 108 kB 4.0 MB/s \n",
+ "\u001b[?25hInstalling collected packages: ninja\n",
+ "Successfully installed ninja-1.10.2.3\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Next we'll install and build functorch (eta is ~6 minutes):"
+ ],
+ "metadata": {
+ "id": "s3rrVgGkmNpi"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {
+ "id": "UtBgzUPDfIQg",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "2c103dc1-7123-4320-b012-652b2d27d0a8"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Collecting git+https://github.com/pytorch/functorch.git\n",
+ " Cloning https://github.com/pytorch/functorch.git to /tmp/pip-req-build-htz8t0jk\n",
+ " Running command git clone -q https://github.com/pytorch/functorch.git /tmp/pip-req-build-htz8t0jk\n",
+ "Requirement already satisfied: torch>=1.10.0.dev in /usr/local/lib/python3.7/dist-packages (from functorch==0.2.0a0+2cf76f3) (1.12.0.dev20220216+cu111)\n",
+ "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch>=1.10.0.dev->functorch==0.2.0a0+2cf76f3) (3.10.0.2)\n",
+ "Building wheels for collected packages: functorch\n",
+ " Building wheel for functorch (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+ " Created wheel for functorch: filename=functorch-0.2.0a0+2cf76f3-cp37-cp37m-linux_x86_64.whl size=21457003 sha256=be6cfe683ff09d15bac0a66e14d6d2d476a15a18273ceb0fc64a1d13fa0e37d7\n",
+ " Stored in directory: /tmp/pip-ephem-wheel-cache-zrhpj6mp/wheels/b0/a9/4a/ffec50dda854c8d9f2ba21e4ffc0f2489ea97946cb1102c5ab\n",
+ "Successfully built functorch\n",
+ "Installing collected packages: functorch\n",
+ "Successfully installed functorch-0.2.0a0+2cf76f3\n"
+ ]
+ }
+ ],
+ "source": [
+ "!pip install --user \"git+https://github.com/pytorch/functorch.git\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Finally - restart colab and after that - just skip directly down to the '-- Tutorial Start --' section to get underway."
+ ],
+ "metadata": {
+ "id": "T8dhR1XEmcJ6"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "colab_restart() "
+ ],
+ "metadata": {
+ "id": "xo2UY9b8ma8t",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "cf9301e6-13ce-4cb6-aefe-c43b14561ec8"
+ },
+ "execution_count": 9,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "--> Restarting colab instance\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## -- Tutorial Start -- \n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "nj6_fW76wM0d"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Confirm we are ready to start. \n",
+ "# If this errs, please make sure you have completed the install steps above first and then return here.\n",
+ "\n",
+ "import functorch "
+ ],
+ "metadata": {
+ "id": "SvUfIxRyeAaL"
+ },
+ "execution_count": 2,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Jacobians, hessians, and more: composing functorch transforms\n",
+ "\n",
+ "Computing jacobians or hessians are useful in a number of non-traditional deep learning models. \n",
+ "\n",
+ "It is difficult (or annoying) to compute these quantities efficiently using a standard autodiff system like PyTorch Autograd; functorch provides ways of computing various higher-order autodiff quantities efficiently.\n"
+ ],
+ "metadata": {
+ "id": "nLdOLDH6m9oy"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F\n",
+ "from functools import partial\n",
+ "_ = torch.manual_seed(0)\n"
+ ],
+ "metadata": {
+ "id": "vUsb3VfexJrY"
+ },
+ "execution_count": 3,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "**Comparing functorch vs the naive approach:**\n",
+ "\n",
+ "Let’s start with a function that we’d like to compute the jacobian of. This is a simple linear function with non-linear activation.\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "aNkX6lFIxzcm"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def predict(weight, bias, x):\n",
+ " return F.linear(x, weight, bias).tanh()"
+ ],
+ "metadata": {
+ "id": "C3a9_clvyPho"
+ },
+ "execution_count": 4,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Let's add some dummy data: a weight, a bias, and a feature vector x.\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "mFJDWMM9yaYZ"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "D = 16\n",
+ "weight = torch.randn(D, D)\n",
+ "bias = torch.randn(D)\n",
+ "x = torch.randn(D) # feature vector"
+ ],
+ "metadata": {
+ "id": "WiSMupvCyecd"
+ },
+ "execution_count": 55,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Let's think of `predict` as a function that maps the input `x` from $R^D -> R^D$.\n",
+ "PyTorch Autograd computes vector-Jacobian products. In order to compute the full\n",
+ "Jacobian of this $R^D -> R^D$ function, we would have to compute it row-by-row\n",
+ "by using a different unit vector each time."
+ ],
+ "metadata": {
+ "id": "cTgIIZ9Wyih8"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def compute_jac(xp):\n",
+ " jacobian_rows = [torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0]\n",
+ " for vec in unit_vectors]\n",
+ " return torch.stack(jacobian_rows)"
+ ],
+ "metadata": {
+ "id": "ItURFU3M-p98"
+ },
+ "execution_count": 56,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "xp = x.clone().requires_grad_()\n",
+ "unit_vectors = torch.eye(D)\n",
+ "\n",
+ "jacobian = compute_jac(xp)\n",
+ "\n",
+ "print(jacobian.shape)\n",
+ "print(jacobian[0]) # show first row"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "1gehVA1c-BHd",
+ "outputId": "81454f59-59e6-470f-e6a6-92c671137ad8"
+ },
+ "execution_count": 57,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "torch.Size([16, 16])\n",
+ "tensor([-5.9625e-06, 1.9876e-05, 7.0103e-06, 1.1086e-05, -1.1939e-05,\n",
+ " 1.0975e-05, 8.3484e-06, -1.4599e-06, -1.9937e-05, 1.4976e-05,\n",
+ " -7.4515e-06, -2.2042e-06, 5.0195e-07, 1.5267e-05, -7.8227e-06,\n",
+ " 6.9435e-06])\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Instead of computing the jacobian row-by-row, we can use vmap to get rid of the for-loop and vectorize the computation. \n",
+ "We can’t directly apply vmap to PyTorch Autograd; instead, functorch provides a vjp transform:\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "BEZaNt1d_bc1"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from functorch import vmap, vjp\n",
+ "\n",
+ "_, vjp_fn = vjp(partial(predict, weight, bias), x)\n",
+ "\n",
+ "ft_jacobian, = vmap(vjp_fn)(unit_vectors)\n",
+ "\n",
+ "# lets confirm both methods compute the same result\n",
+ "assert torch.allclose(ft_jacobian, jacobian)"
+ ],
+ "metadata": {
+ "id": "Zfnn2C2g-6Fb"
+ },
+ "execution_count": 58,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "In another tutorial a composition of reverse-mode AD and vmap gave us per-sample-gradients. \n",
+ "In this tutorial, composing reverse-mode AD and vmap gives us Jacobian computation! \n",
+ "Various compositions of vmap and autodiff transforms can give us different interesting quantities.\n",
+ "\n",
+ "functorch provides **jacrev** as a convenience function that performs the vmap-vjp composition to compute jacobians. **jacrev** accepts an argnums argument that says which argument we would like to compute Jacobians with respect to.\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "4gDqecJbyVgt"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from functorch import jacrev\n",
+ "\n",
+ "ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)\n",
+ "\n",
+ "# confirm \n",
+ "assert torch.allclose(ft_jacobian, jacobian)"
+ ],
+ "metadata": {
+ "id": "t0EfptYTAO47"
+ },
+ "execution_count": 59,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Let’s compare the performance of the two ways to compute the jacobian. The functorch version is much faster (and becomes even faster the more outputs there are). \n",
+ "\n",
+ "In general, we expect that vectorization via vmap can help eliminate overhead and give better utilization of your hardware.\n",
+ "\n",
+ "Vmap does this magic by pushing the outer loop down into the functions primitive operations in order to obtain better performance.\n",
+ "\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "PDEIxPZoxUb7"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Let's make a quick function to evaluate performance and deal with microseconds and milliseconds measurements:"
+ ],
+ "metadata": {
+ "id": "gHxrra_jA3ur"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def get_perf(first, first_descriptor, second, second_descriptor):\n",
+ " \"\"\" takes torch.benchmark objects and compares delta of second vs first. \"\"\"\n",
+ " faster = second.times[0]\n",
+ " slower = first.times[0]\n",
+ " gain = (slower-faster)/slower\n",
+ " if gain < 0: gain *=-1 \n",
+ " final_gain = gain*100\n",
+ " print(f\" Performance delta: {final_gain:.4f} percent improvement with {second_descriptor} \")\n",
+ " "
+ ],
+ "metadata": {
+ "id": "rENMCuodBIef"
+ },
+ "execution_count": 60,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "And then run the performance comparison:"
+ ],
+ "metadata": {
+ "id": "IaPfXXHngmUG"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from torch.utils.benchmark import Timer\n",
+ "\n",
+ "without_vmap = Timer(stmt=\"compute_jac(xp)\", globals=globals())\n",
+ "with_vmap = Timer(stmt=\"jacrev(predict, argnums=2)(weight, bias, x)\", globals=globals())\n",
+ "\n",
+ "no_vmap_timer = without_vmap.timeit(500)\n",
+ "with_vmap_timer = with_vmap.timeit(500)\n",
+ "\n",
+ "print(no_vmap_timer)\n",
+ "print(with_vmap_timer)\n"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "quEKCA2-Afdq",
+ "outputId": "72cf8a9f-759f-479e-9525-190da282b802"
+ },
+ "execution_count": 61,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "<torch.utils.benchmark.utils.common.Measurement object at 0x7f682eb5a450>\n",
+ "compute_jac(xp)\n",
+ " 2.04 ms\n",
+ " 1 measurement, 500 runs , 1 thread\n",
+ "<torch.utils.benchmark.utils.common.Measurement object at 0x7f6733f08810>\n",
+ "jacrev(predict, argnums=2)(weight, bias, x)\n",
+ " 810.29 us\n",
+ " 1 measurement, 500 runs , 1 thread\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Lets do a relative performance comparison of the above with our get_perf function:"
+ ],
+ "metadata": {
+ "id": "5tY4c45fxVMi"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "get_perf(no_vmap_timer, \"without vmap\", with_vmap_timer, \"vmap\");"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "gQ_iv23m97A7",
+ "outputId": "e9a44c9e-9ed7-41f5-dd3c-77f0b7599f06"
+ },
+ "execution_count": 62,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ " Performance delta: 60.3299 percent improvement with vmap \n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Furthemore, it’s pretty easy to flip the problem around and say we want to compute Jacobians of the parameters to our model (weight, bias) instead of the input."
+ ],
+ "metadata": {
+ "id": "wtUxdj8gD1w7"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x) # note the change in input via argnums params of 0,1 to map to weight and bias"
+ ],
+ "metadata": {
+ "id": "iKtvWR0n-b3E"
+ },
+ "execution_count": 63,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## reverse-mode Jacobian (jacrev) vs forward-mode Jacobian (jacfwd)\n"
+ ],
+ "metadata": {
+ "id": "zKm1sgT0EPx8"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "We offer two APIs to compute jacobians: **jacrev** and **jacfwd**: \n",
+ "\n",
+ "- jacrev uses reverse-mode AD. As you saw above it is a composition of our vjp and vmap transforms. \n",
+ "\n",
+ "- jacfwd uses forward-mode AD. It is implemented as a composition of our jvp and vmap transforms. \n",
+ "\n",
+ "jacfwd and jacrev can be substituted for each other but they have different performance characteristics.\n",
+ "\n",
+ "As a general rule of thumb, if you’re computing the jacobian of an 𝑅𝑁−>𝑅𝑀 function, and there are many more outputs than inputs (i.e. M > N) then jacfwd is preferred, otherwise use jacrev. \n",
+ "\n",
+ "There are exceptions to this rule, but a non-rigorous argument for this follows:\n",
+ "\n",
+ "In reverse-mode AD, we are computing the jacobian row-by-row, while in forward-mode AD (which computes Jacobian-vector products), we are computing it column-by-column. \n",
+ "\n",
+ "The Jacobian matrix has M rows and N columns, so if it is taller or wider one way we may prefer the method that deals with fewer rows or columns.\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "LDqTlkfXEP0q"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from functorch import jacrev, jacfwd"
+ ],
+ "metadata": {
+ "id": "GrQG0lRoFML7"
+ },
+ "execution_count": 64,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "First, let's benchmark with more inputs than outputs:\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "7QIZkss7FQhK"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "Din = 32\n",
+ "Dout = 2048\n",
+ "weight = torch.randn(Dout, Din)\n",
+ "\n",
+ "bias = torch.randn(Dout)\n",
+ "x = torch.randn(Din)\n",
+ "\n",
+ "# remember the general rule about taller vs wider...here we have a taller matrix:\n",
+ "print(weight.shape)\n",
+ "\n",
+ "using_fwd = Timer(stmt=\"jacfwd(predict, argnums=2)(weight, bias, x)\", globals=globals())\n",
+ "using_bwd = Timer(stmt=\"jacrev(predict, argnums=2)(weight, bias, x)\", globals=globals())\n",
+ "\n",
+ "jacfwd_timing = using_fwd.timeit(500)\n",
+ "jacrev_timing = using_bwd.timeit(500)\n",
+ "\n",
+ "print(f'jacfwd time: {jacfwd_timing}')\n",
+ "print(f'jacrev time: {jacrev_timing}')\n"
+ ],
+ "metadata": {
+ "id": "N0M0i6xf-nBt",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "add77ba5-a947-4bb7-8c36-a2e742086cab"
+ },
+ "execution_count": 65,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "torch.Size([2048, 32])\n",
+ "jacfwd time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f6734014c90>\n",
+ "jacfwd(predict, argnums=2)(weight, bias, x)\n",
+ " 1.18 ms\n",
+ " 1 measurement, 500 runs , 1 thread\n",
+ "jacrev time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f67344b8650>\n",
+ "jacrev(predict, argnums=2)(weight, bias, x)\n",
+ " 14.98 ms\n",
+ " 1 measurement, 500 runs , 1 thread\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "and then do a relative benchmark:"
+ ],
+ "metadata": {
+ "id": "UEh5jIK2FpBJ"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "get_perf(jacfwd_timing, \"jacfwd\", jacrev_timing, \"jacrev\", );"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "YmEqbvDeFrtt",
+ "outputId": "a1eae78f-9a08-4b14-d96c-18483070b9d1"
+ },
+ "execution_count": 67,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ " Performance delta: 1170.0622 percent improvement with jacrev \n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "and now the reverse - more outputs (M) than inputs (N):"
+ ],
+ "metadata": {
+ "id": "aZAXlFUNFxAY"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "Din = 2048\n",
+ "Dout = 32\n",
+ "weight = torch.randn(Dout, Din)\n",
+ "bias = torch.randn(Dout)\n",
+ "x = torch.randn(Din)\n",
+ "\n",
+ "using_fwd = Timer(stmt=\"jacfwd(predict, argnums=2)(weight, bias, x)\", globals=globals())\n",
+ "using_bwd = Timer(stmt=\"jacrev(predict, argnums=2)(weight, bias, x)\", globals=globals())\n",
+ "\n",
+ "jacfwd_timing = using_fwd.timeit(500)\n",
+ "jacrev_timing = using_bwd.timeit(500)\n",
+ "\n",
+ "print(f'jacfwd time: {jacfwd_timing}')\n",
+ "print(f'jacrev time: {jacrev_timing}')"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "jHVkLcr9_SIe",
+ "outputId": "86f7fb41-d14b-4b5f-9086-709278c99f67"
+ },
+ "execution_count": 71,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "jacfwd time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f67340145d0>\n",
+ "jacfwd(predict, argnums=2)(weight, bias, x)\n",
+ " 8.99 ms\n",
+ " 1 measurement, 500 runs , 1 thread\n",
+ "jacrev time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f67344b8110>\n",
+ "jacrev(predict, argnums=2)(weight, bias, x)\n",
+ " 1.03 ms\n",
+ " 1 measurement, 500 runs , 1 thread\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "and a relative perf comparison:"
+ ],
+ "metadata": {
+ "id": "I47HDJBwGAM4"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "get_perf(jacrev_timing, \"jacrev\", jacfwd_timing, \"jacfwd\")"
+ ],
+ "metadata": {
+ "id": "jPdAcIgu1es-",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "ceebc303-9531-4903-84a9-0cbefe4bc318"
+ },
+ "execution_count": 72,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ " Performance delta: 775.3424 percent improvement with jacfwd \n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Hessian computation with functorch.hessian\n"
+ ],
+ "metadata": {
+ "id": "NRr6l4u0obus"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "We offer a convenience API to compute hessians: functorch.hessian. \n",
+ "Hessians are the jacobian of the jacobian (or the partial derivative of the partial derivative, aka second order).\n",
+ "\n",
+ "This suggests that one can just compose functorch’s jacobian transforms to compute the Hessian. \n",
+ "Indeed, under the hood, hessian(f) is simply jacfwd(jacrev(f)).\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "k0vSE1C1GeUJ"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Note: to boost performance: depending on your model, you may also want to use jacfwd(jacfwd(f)) or jacrev(jacrev(f)) instead to compute hessians leveraging the rule of thumb above regarding wider vs taller matrices.\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "NEu1Zfo2G9fa"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from functorch import hessian\n",
+ "\n",
+ "# lets reduce the size in order not to blow out colab. Hessians require significant memory:\n",
+ "Din = 512\n",
+ "Dout = 32\n",
+ "weight = torch.randn(Dout, Din)\n",
+ "bias = torch.randn(Dout)\n",
+ "x = torch.randn(Din)\n",
+ "\n",
+ "hess_api = hessian(predict, argnums=2)(weight, bias, x)\n",
+ "hess_fwdfwd = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x)\n",
+ "#hess_revrev = jacrev(jacrev(predict, argnums=2), argnums=2)(weight, bias, x)\n"
+ ],
+ "metadata": {
+ "id": "tYhxPLb-Gdh-"
+ },
+ "execution_count": 87,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Let's verify we have the same result regardless of using hessian api or using jacfwd(jacfwd())"
+ ],
+ "metadata": {
+ "id": "Qm_TPCCiso9u"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "torch.allclose(hess_api, hess_fwdfwd)"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "uL23RF5UrroT",
+ "outputId": "48327575-6a69-4e3c-898d-06b4792f44ca"
+ },
+ "execution_count": 89,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "True"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 89
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Batch Jacobian and Batch Hessian\n"
+ ],
+ "metadata": {
+ "id": "9xBE48HXIOOj"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "In the above examples we’ve been operating with a single feature vector. \n",
+ "\n",
+ "In some cases you might want to take the Jacobian of a batch of outputs with respect to a batch of inputs. \n",
+ "\n",
+ "That is, given a batch of inputs of shape (B, N) and a function that goes from R^N -> R^M, we would like a Jacobian of shape (B, M, N). \n",
+ "\n",
+ "The easiest way to do this is to use vmap:\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "sJVzGqnEIhJA"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "batch_size = 64\n",
+ "Din = 31\n",
+ "Dout = 33\n",
+ "\n",
+ "weight = torch.randn(Dout, Din)\n",
+ "print(f\"weight shape = {weight.shape}\")\n",
+ "\n",
+ "bias = torch.randn(Dout)\n",
+ "\n",
+ "x = torch.randn(batch_size, Din)"
+ ],
+ "metadata": {
+ "id": "gEEWzX2QndqN",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "c55e07d8-e3b6-40f6-f8f4-39014ff7d9b9"
+ },
+ "execution_count": 91,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "weight shape = torch.Size([33, 31])\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "compute_batch_jacobian = vmap(jacrev(predict, argnums=2), in_dims=(None, None, 0))\n",
+ "batch_jacobian0 = compute_batch_jacobian(weight, bias, x)"
+ ],
+ "metadata": {
+ "id": "khYtmCqJn1h-"
+ },
+ "execution_count": 92,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "If you have a function that goes from (B, N) -> (B, M) instead and are certain that each input produces an independent output, then it’s also sometimes possible to do this without using vmap by summing the outputs and then computing the Jacobian of that function:\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "rXE9tY05JHaJ"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def predict_with_output_summed(weight, bias, x):\n",
+ " return predict(weight, bias, x).sum(0)\n",
+ "\n",
+ "batch_jacobian1 = jacrev(predict_with_output_summed, argnums=2)(weight, bias, x).movedim(1, 0)\n",
+ "assert torch.allclose(batch_jacobian0, batch_jacobian1)"
+ ],
+ "metadata": {
+ "id": "eohigCobop4R"
+ },
+ "execution_count": 93,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "If you instead have a function that goes from 𝑅𝑁−>𝑅𝑀 but inputs that are batched, you compose vmap with jacrev to compute batched jacobians:\n",
+ "\n",
+ "Finally, batch hessians can be computed similarly. It’s easiest to think about them by using vmap to batch over hessian computation, but in some cases the sum trick also works.\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "O3AGffymp_39"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "compute_batch_hessian = vmap(hessian(predict, argnums=2), in_dims=(None, None, 0))\n",
+ "\n",
+ "batch_hess = compute_batch_hessian(weight, bias, x)\n",
+ "batch_hess.shape"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "HYddoLSTfi-g",
+ "outputId": "5ddbd9b0-57f8-4ac4-a399-a6bdadfaa167"
+ },
+ "execution_count": 95,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "torch.Size([64, 33, 31, 31])"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 95
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ ""
+ ],
+ "metadata": {
+ "id": "22e5fo2jqANi"
+ },
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+}
\ No newline at end of file
diff --git a/functorch/notebooks/colab/per_sample_grads_colab.ipynb b/functorch/notebooks/colab/per_sample_grads_colab.ipynb
new file mode 100644
index 0000000..2400d17
--- /dev/null
+++ b/functorch/notebooks/colab/per_sample_grads_colab.ipynb
@@ -0,0 +1,929 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "name": "per-sample-gradients_colab.ipynb",
+ "provenance": [],
+ "collapsed_sections": [],
+ "machine_shape": "hm",
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Welcome to the functorch tutorial on Per-Sample-Gradients, in colab."
+ ],
+ "metadata": {
+ "id": "W6b4RUiYnhSt"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Configuring your colab to run functorch \n"
+ ],
+ "metadata": {
+ "id": "0I5Mm2q2f5aw"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "**Getting setup** - running functorch currently requires Pytorch Nightly. \n",
+ "Thus we'll go through a pytorch nightly install and build functorch. \n",
+ "\n",
+ "After that and a restart, you'll be ready to run the tutorial here on colab."
+ ],
+ "metadata": {
+ "id": "jnHxd2KFgPJg"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Let's setup a restart function:"
+ ],
+ "metadata": {
+ "id": "PvwZSOklhpB2"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def colab_restart():\n",
+ " print(\"--> Restarting colab instance\") \n",
+ " get_ipython().kernel.do_shutdown(True)"
+ ],
+ "metadata": {
+ "id": "MklsA-KRhZKC"
+ },
+ "execution_count": 2,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Next, let's confirm that we have a gpu. \n",
+ "(If not, select Runtime -> Change Runtime type above,\n",
+ " and select GPU under Hardward Accelerator )"
+ ],
+ "metadata": {
+ "id": "Njk9qPgTiiGS"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!nvcc --version"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "HxidO4dpiPGi",
+ "outputId": "675468f4-6d81-4590-b56f-c61d041300c3"
+ },
+ "execution_count": 3,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "nvcc: NVIDIA (R) Cuda compiler driver\n",
+ "Copyright (c) 2005-2020 NVIDIA Corporation\n",
+ "Built on Mon_Oct_12_20:09:46_PDT_2020\n",
+ "Cuda compilation tools, release 11.1, V11.1.105\n",
+ "Build cuda_11.1.TC455_06.29190527_0\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Let's remove the default PyTorch install:"
+ ],
+ "metadata": {
+ "id": "HanoUO62jtKx"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!pip uninstall -y torch"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "NIoTNykP9xI5",
+ "outputId": "d6fa784d-b837-4f35-e7ac-3ee834fe188d"
+ },
+ "execution_count": 4,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Found existing installation: torch 1.10.0+cu111\n",
+ "Uninstalling torch-1.10.0+cu111:\n",
+ " Successfully uninstalled torch-1.10.0+cu111\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "And install the relevant nightly version. (this defaults to 11.1 Cuda which works on most colabs). "
+ ],
+ "metadata": {
+ "id": "n-DFUwBVkHaX"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "cuda_version = \"cu111\" # optionally - cu113 (for 11.3) is an option as well if you have 11.3 listed above in the nvcc output. "
+ ],
+ "metadata": {
+ "id": "BH5ffJBkkRR8"
+ },
+ "execution_count": 5,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!pip install --pre torch -f https://download.pytorch.org/whl/nightly/{cuda_version}/torch_nightly.html --upgrade"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "Bi2oymijkav5",
+ "outputId": "3e99707b-cb55-45cf-ed89-1312e5bf6a7a"
+ },
+ "execution_count": 6,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Looking in links: https://download.pytorch.org/whl/nightly/cu111/torch_nightly.html\n",
+ "Collecting torch\n",
+ " Downloading https://download.pytorch.org/whl/nightly/cu111/torch-1.12.0.dev20220216%2Bcu111-cp37-cp37m-linux_x86_64.whl (1922.9 MB)\n",
+ "\u001b[K |█████████████▉ | 834.1 MB 95.9 MB/s eta 0:00:12tcmalloc: large alloc 1147494400 bytes == 0x55baa2524000 @ 0x7fd8e627d615 0x55ba696303bc 0x55ba6971118a 0x55ba696331cd 0x55ba69725b3d 0x55ba696a7458 0x55ba696a202f 0x55ba69634aba 0x55ba696a72c0 0x55ba696a202f 0x55ba69634aba 0x55ba696a3cd4 0x55ba69726986 0x55ba696a3350 0x55ba69726986 0x55ba696a3350 0x55ba69726986 0x55ba696a3350 0x55ba69634f19 0x55ba69678a79 0x55ba69633b32 0x55ba696a71dd 0x55ba696a202f 0x55ba69634aba 0x55ba696a3cd4 0x55ba696a202f 0x55ba69634aba 0x55ba696a2eae 0x55ba696349da 0x55ba696a3108 0x55ba696a202f\n",
+ "\u001b[K |█████████████████▋ | 1055.7 MB 1.2 MB/s eta 0:11:42tcmalloc: large alloc 1434370048 bytes == 0x55bae6b7a000 @ 0x7fd8e627d615 0x55ba696303bc 0x55ba6971118a 0x55ba696331cd 0x55ba69725b3d 0x55ba696a7458 0x55ba696a202f 0x55ba69634aba 0x55ba696a72c0 0x55ba696a202f 0x55ba69634aba 0x55ba696a3cd4 0x55ba69726986 0x55ba696a3350 0x55ba69726986 0x55ba696a3350 0x55ba69726986 0x55ba696a3350 0x55ba69634f19 0x55ba69678a79 0x55ba69633b32 0x55ba696a71dd 0x55ba696a202f 0x55ba69634aba 0x55ba696a3cd4 0x55ba696a202f 0x55ba69634aba 0x55ba696a2eae 0x55ba696349da 0x55ba696a3108 0x55ba696a202f\n",
+ "\u001b[K |██████████████████████▎ | 1336.2 MB 91.0 MB/s eta 0:00:07tcmalloc: large alloc 1792966656 bytes == 0x55ba6b9ac000 @ 0x7fd8e627d615 0x55ba696303bc 0x55ba6971118a 0x55ba696331cd 0x55ba69725b3d 0x55ba696a7458 0x55ba696a202f 0x55ba69634aba 0x55ba696a72c0 0x55ba696a202f 0x55ba69634aba 0x55ba696a3cd4 0x55ba69726986 0x55ba696a3350 0x55ba69726986 0x55ba696a3350 0x55ba69726986 0x55ba696a3350 0x55ba69634f19 0x55ba69678a79 0x55ba69633b32 0x55ba696a71dd 0x55ba696a202f 0x55ba69634aba 0x55ba696a3cd4 0x55ba696a202f 0x55ba69634aba 0x55ba696a2eae 0x55ba696349da 0x55ba696a3108 0x55ba696a202f\n",
+ "\u001b[K |████████████████████████████▏ | 1691.1 MB 1.2 MB/s eta 0:03:18tcmalloc: large alloc 2241208320 bytes == 0x55bad6794000 @ 0x7fd8e627d615 0x55ba696303bc 0x55ba6971118a 0x55ba696331cd 0x55ba69725b3d 0x55ba696a7458 0x55ba696a202f 0x55ba69634aba 0x55ba696a72c0 0x55ba696a202f 0x55ba69634aba 0x55ba696a3cd4 0x55ba69726986 0x55ba696a3350 0x55ba69726986 0x55ba696a3350 0x55ba69726986 0x55ba696a3350 0x55ba69634f19 0x55ba69678a79 0x55ba69633b32 0x55ba696a71dd 0x55ba696a202f 0x55ba69634aba 0x55ba696a3cd4 0x55ba696a202f 0x55ba69634aba 0x55ba696a2eae 0x55ba696349da 0x55ba696a3108 0x55ba696a202f\n",
+ "\u001b[K |████████████████████████████████| 1922.9 MB 95.2 MB/s eta 0:00:01tcmalloc: large alloc 1922924544 bytes == 0x55bb5c0f6000 @ 0x7fd8e627c1e7 0x55ba696665d7 0x55ba696303bc 0x55ba6971118a 0x55ba696331cd 0x55ba69725b3d 0x55ba696a7458 0x55ba696a202f 0x55ba69634aba 0x55ba696a3108 0x55ba696a202f 0x55ba69634aba 0x55ba696a3108 0x55ba696a202f 0x55ba69634aba 0x55ba696a3108 0x55ba696a202f 0x55ba69634aba 0x55ba696a3108 0x55ba696a202f 0x55ba69634aba 0x55ba696a3108 0x55ba696349da 0x55ba696a3108 0x55ba696a202f 0x55ba69634aba 0x55ba696a3cd4 0x55ba696a202f 0x55ba69634aba 0x55ba696a3cd4 0x55ba696a202f\n",
+ "tcmalloc: large alloc 2403655680 bytes == 0x55bbceace000 @ 0x7fd8e627d615 0x55ba696303bc 0x55ba6971118a 0x55ba696331cd 0x55ba69725b3d 0x55ba696a7458 0x55ba696a202f 0x55ba69634aba 0x55ba696a3108 0x55ba696a202f 0x55ba69634aba 0x55ba696a3108 0x55ba696a202f 0x55ba69634aba 0x55ba696a3108 0x55ba696a202f 0x55ba69634aba 0x55ba696a3108 0x55ba696a202f 0x55ba69634aba 0x55ba696a3108 0x55ba696349da 0x55ba696a3108 0x55ba696a202f 0x55ba69634aba 0x55ba696a3cd4 0x55ba696a202f 0x55ba69634aba 0x55ba696a3cd4 0x55ba696a202f 0x55ba69635151\n",
+ "\u001b[K |████████████████████████████████| 1922.9 MB 4.6 kB/s \n",
+ "\u001b[?25hRequirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch) (3.10.0.2)\n",
+ "Installing collected packages: torch\n",
+ "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
+ "torchvision 0.11.1+cu111 requires torch==1.10.0, but you have torch 1.12.0.dev20220216+cu111 which is incompatible.\n",
+ "torchtext 0.11.0 requires torch==1.10.0, but you have torch 1.12.0.dev20220216+cu111 which is incompatible.\n",
+ "torchaudio 0.10.0+cu111 requires torch==1.10.0, but you have torch 1.12.0.dev20220216+cu111 which is incompatible.\u001b[0m\n",
+ "Successfully installed torch-1.12.0.dev20220216+cu111\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Let's install Ninja to accelerate the functorch building process:"
+ ],
+ "metadata": {
+ "id": "OkL1Q6KXlzid"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!pip install ninja"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "VEJs4UEOkay0",
+ "outputId": "978fa3a2-5db4-4494-8f7f-e4ef3955f54d"
+ },
+ "execution_count": 7,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Collecting ninja\n",
+ " Downloading ninja-1.10.2.3-py2.py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.whl (108 kB)\n",
+ "\u001b[?25l\r\u001b[K |███ | 10 kB 38.2 MB/s eta 0:00:01\r\u001b[K |██████ | 20 kB 25.5 MB/s eta 0:00:01\r\u001b[K |█████████ | 30 kB 17.1 MB/s eta 0:00:01\r\u001b[K |████████████▏ | 40 kB 14.8 MB/s eta 0:00:01\r\u001b[K |███████████████▏ | 51 kB 8.6 MB/s eta 0:00:01\r\u001b[K |██████████████████▏ | 61 kB 10.1 MB/s eta 0:00:01\r\u001b[K |█████████████████████▏ | 71 kB 8.0 MB/s eta 0:00:01\r\u001b[K |████████████████████████▎ | 81 kB 9.0 MB/s eta 0:00:01\r\u001b[K |███████████████████████████▎ | 92 kB 10.0 MB/s eta 0:00:01\r\u001b[K |██████████████████████████████▎ | 102 kB 8.9 MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 108 kB 8.9 MB/s \n",
+ "\u001b[?25hInstalling collected packages: ninja\n",
+ "Successfully installed ninja-1.10.2.3\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Next we'll install and build functorch (eta is ~6 minutes):"
+ ],
+ "metadata": {
+ "id": "s3rrVgGkmNpi"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {
+ "id": "UtBgzUPDfIQg",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "fe4ba10d-d615-41fc-8ce5-8a5fba68d1ec"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Collecting git+https://github.com/pytorch/functorch.git\n",
+ " Cloning https://github.com/pytorch/functorch.git to /tmp/pip-req-build-uzx6hua9\n",
+ " Running command git clone -q https://github.com/pytorch/functorch.git /tmp/pip-req-build-uzx6hua9\n",
+ "Requirement already satisfied: torch>=1.10.0.dev in /usr/local/lib/python3.7/dist-packages (from functorch==0.2.0a0+588410b) (1.12.0.dev20220216+cu111)\n",
+ "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch>=1.10.0.dev->functorch==0.2.0a0+588410b) (3.10.0.2)\n",
+ "Building wheels for collected packages: functorch\n",
+ " Building wheel for functorch (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+ " Created wheel for functorch: filename=functorch-0.2.0a0+588410b-cp37-cp37m-linux_x86_64.whl size=21303322 sha256=535625b8a293f957fc3e0deddb44e9490cd7b3cb5a0918968573abf195505b66\n",
+ " Stored in directory: /tmp/pip-ephem-wheel-cache-4l5feo3i/wheels/b0/a9/4a/ffec50dda854c8d9f2ba21e4ffc0f2489ea97946cb1102c5ab\n",
+ "Successfully built functorch\n",
+ "Installing collected packages: functorch\n",
+ "Successfully installed functorch-0.2.0a0+588410b\n"
+ ]
+ }
+ ],
+ "source": [
+ "!pip install --user \"git+https://github.com/pytorch/functorch.git\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Finally - restart colab and after that - just skip directly down to the '-- Tutorial Start --' section to get underway."
+ ],
+ "metadata": {
+ "id": "T8dhR1XEmcJ6"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "colab_restart() "
+ ],
+ "metadata": {
+ "id": "xo2UY9b8ma8t",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "2a3cf5f3-5f22-4e4a-ecc2-8ad04071ff63"
+ },
+ "execution_count": 9,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "--> Restarting colab instance\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## -- Tutorial Start -- \n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "nj6_fW76wM0d"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Confirm we are ready to start. \n",
+ "# If this errs, please make sure you have completed the 'configuring your colab' steps above first and then return here.\n",
+ "\n",
+ "import functorch "
+ ],
+ "metadata": {
+ "id": "SvUfIxRyeAaL"
+ },
+ "execution_count": 1,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# What is per-sample-gradients?\n",
+ "Per-sample-gradient computation is computing the gradient for each and every sample in a batch of data. \n",
+ "It is a useful quantity for differential privacy, meta-learning, and optimization research.\n",
+ "\n",
+ "Let's walk through a simple example of per-sample-gradients in action below with a simple CNN model. \n",
+ "\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "nLdOLDH6m9oy"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F\n",
+ "from functools import partial\n",
+ "\n",
+ "torch.manual_seed(0);"
+ ],
+ "metadata": {
+ "id": "Gb-yt4VKUUuc"
+ },
+ "execution_count": 3,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Here's a simple CNN and loss function:\n",
+ "\n",
+ "class SimpleCNN(nn.Module):\n",
+ " def __init__(self):\n",
+ " super(SimpleCNN, self).__init__()\n",
+ " self.conv1 = nn.Conv2d(1, 32, 3, 1)\n",
+ " self.conv2 = nn.Conv2d(32, 64, 3, 1)\n",
+ " self.fc1 = nn.Linear(9216, 128)\n",
+ " self.fc2 = nn.Linear(128, 10)\n",
+ "\n",
+ " def forward(self, x):\n",
+ " x = self.conv1(x)\n",
+ " x = F.relu(x)\n",
+ " x = self.conv2(x)\n",
+ " x = F.relu(x)\n",
+ " x = F.max_pool2d(x, 2)\n",
+ " x = torch.flatten(x, 1)\n",
+ " x = self.fc1(x)\n",
+ " x = F.relu(x)\n",
+ " x = self.fc2(x)\n",
+ " output = F.log_softmax(x, dim=1)\n",
+ " output = x\n",
+ " return output\n",
+ "\n",
+ "def loss_fn(predictions, targets):\n",
+ " return F.nll_loss(predictions, targets)"
+ ],
+ "metadata": {
+ "id": "tf-HKHjUUbyY"
+ },
+ "execution_count": 5,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Let’s generate a batch of dummy data and pretend that we’re working with an MNIST dataset. \n",
+ "\n",
+ "Thus, the dummy images are 28 by 28, and we have a minibatch of size 64.\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "VEDPe-EoU5Fa"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "device = 'cuda'\n",
+ "\n",
+ "num_models = 10\n",
+ "batch_size = 64\n",
+ "data = torch.randn(batch_size, 1, 28, 28, device=device)\n",
+ "\n",
+ "targets = torch.randint(10, (64,), device=device)\n"
+ ],
+ "metadata": {
+ "id": "WB2Qe3AHUvPN"
+ },
+ "execution_count": 7,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "In regular model training, one would forward the minibatch through the model, and then call .backward() to compute gradients. This would generate an 'average' gradient of the entire mini-batch:\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "GOGJ-OUxVcT5"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "model = SimpleCNN().to(device=device)\n",
+ "predictions = model(data) # move the entire mini-batch through the model\n",
+ "\n",
+ "loss = loss_fn(predictions, targets)\n",
+ "loss.backward() # back propogate the 'average' gradient of this mini-batch\n"
+ ],
+ "metadata": {
+ "id": "WYjMx8QTUvRu"
+ },
+ "execution_count": 8,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "In contrast to the above approach, per-sample-gradient computation is equivalent to: \n",
+ "\n",
+ "for each individual sample of the data, perform a forward and a backward pass to get an individual (per-sample) gradient.\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "HNw4_IVzU5Pz"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def compute_grad(sample, target):\n",
+ " \n",
+ " sample = sample.unsqueeze(0) # prepend batch dimension for processing\n",
+ " target = target.unsqueeze(0)\n",
+ "\n",
+ " prediction = model(sample)\n",
+ " loss = loss_fn(prediction, target)\n",
+ "\n",
+ " return torch.autograd.grad(loss, list(model.parameters()))\n",
+ "\n",
+ "\n",
+ "def compute_sample_grads(data, targets):\n",
+ " \"\"\" manually process each sample with per sample gradient \"\"\"\n",
+ " sample_grads = [compute_grad(data[i], targets[i]) for i in range(batch_size)]\n",
+ " sample_grads = zip(*sample_grads)\n",
+ " sample_grads = [torch.stack(shards) for shards in sample_grads]\n",
+ " return sample_grads\n",
+ "\n",
+ "per_sample_grads = compute_sample_grads(data, targets)\n"
+ ],
+ "metadata": {
+ "id": "vUsb3VfexJrY"
+ },
+ "execution_count": 40,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "sample_grads[0] is the per-sample-grad for model.conv1.weight.\n",
+ "\n",
+ "model.conv1.weight.shape is [32, 1, 3, 3]; \n",
+ "\n",
+ "notice how there is one gradient, per sample, in the batch for a total of 64.\n",
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "aNkX6lFIxzcm"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "print(per_sample_grads[0].shape)\n"
+ ],
+ "metadata": {
+ "id": "C3a9_clvyPho",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "407abc1a-846f-4e50-83bc-c90719a26073"
+ },
+ "execution_count": 12,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "torch.Size([64, 32, 1, 3, 3])\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Per-sample-grads, *the efficient way*, using functorch\n",
+ "\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "mFJDWMM9yaYZ"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "We can compute per-sample-gradients efficiently by using function transforms. \n",
+ "\n",
+ "First, let’s create a stateless functional version of model by using functorch.make_functional_with_buffers. \n",
+ "\n",
+ "This will seperate state (the parameters) from the model and turn the model into a pure function:\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "tlkmyQyfY6XU"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from functorch import make_functional_with_buffers, vmap, grad\n",
+ "\n",
+ "fmodel, params, buffers = make_functional_with_buffers(model)"
+ ],
+ "metadata": {
+ "id": "WiSMupvCyecd"
+ },
+ "execution_count": 13,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Let's review the changes - first, the model has become stateless FunctionalModuleWithBuffers:"
+ ],
+ "metadata": {
+ "id": "wMsbppPNZklo"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "fmodel"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "Xj0cZOJMZbbB",
+ "outputId": "2e87dfde-3af2-4e1f-cd91-5c232446fb53"
+ },
+ "execution_count": 15,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "FunctionalModuleWithBuffers(\n",
+ " (stateless_model): SimpleCNN(\n",
+ " (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))\n",
+ " (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))\n",
+ " (fc1): Linear(in_features=9216, out_features=128, bias=True)\n",
+ " (fc2): Linear(in_features=128, out_features=10, bias=True)\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 15
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "And the model parameters now exist independently of the model, stored as a tuple:"
+ ],
+ "metadata": {
+ "id": "zv4_YYPxZvvg"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "for x in params:\n",
+ " print(f\"{x.shape}\")\n",
+ "\n",
+ "print(f\"\\n{type(params)}\")"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "tH0TAZhBZ3bS",
+ "outputId": "97c4401f-cccb-43f6-b071-c85a18fc439b"
+ },
+ "execution_count": 33,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "torch.Size([32, 1, 3, 3])\n",
+ "torch.Size([32])\n",
+ "torch.Size([64, 32, 3, 3])\n",
+ "torch.Size([64])\n",
+ "torch.Size([128, 9216])\n",
+ "torch.Size([128])\n",
+ "torch.Size([10, 128])\n",
+ "torch.Size([10])\n",
+ "\n",
+ "<class 'tuple'>\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Next, let’s define a function to compute the loss of the model given a single input rather than a batch of inputs. \n",
+ "\n",
+ "It is important that this function accepts the parameters, the input, and the target, because we will be transforming over them. \n",
+ "\n",
+ "Note - because the model was originally written to handle batches, we’ll use torch.unsqueeze to add a batch dimension.\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "cTgIIZ9Wyih8"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def compute_loss_stateless_model (params, buffers, sample, target):\n",
+ "\n",
+ " batch = sample.unsqueeze(0)\n",
+ " targets = target.unsqueeze(0)\n",
+ "\n",
+ " predictions = fmodel(params, buffers, batch) \n",
+ "\n",
+ " loss = loss_fn(predictions, targets)\n",
+ "\n",
+ " return loss"
+ ],
+ "metadata": {
+ "id": "ItURFU3M-p98"
+ },
+ "execution_count": 34,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Now, let’s use functorch's **grad** to create a new function that computes the gradient \n",
+ "with respect to the first argument of compute_loss (i.e. the params)."
+ ],
+ "metadata": {
+ "id": "Qo3sbDK2i_bH"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "ft_compute_grad = grad(compute_loss_stateless_model)\n"
+ ],
+ "metadata": {
+ "id": "sqRp_Sxni-Xm"
+ },
+ "execution_count": 35,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "The **ft_compute_grad** function computes the gradient for a single (sample, target) pair. \n",
+ "\n",
+ "We can use vmap to get it to compute the gradient over an entire batch of samples and targets. \n",
+ "\n",
+ "Note that in_dims=(None, None, 0, 0) because we wish to map ft_compute_grad over the 0th dimension of the data and targets, and use the same params and buffers for each.\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "2pG3Ofqjjc8O"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))"
+ ],
+ "metadata": {
+ "id": "62ecNMO6inqX"
+ },
+ "execution_count": 37,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Finally, let’s used our transformed function to compute per-sample-gradients:\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "_alXdQ3QkETu"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "ft_per_sample_grads = ft_compute_sample_grad(params, buffers, data, targets)\n",
+ "\n",
+ "# we can double check that the results using functorch grad and vmap match the results of hand processing each one individually:\n",
+ "for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads):\n",
+ " assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5)"
+ ],
+ "metadata": {
+ "id": "1gehVA1c-BHd"
+ },
+ "execution_count": 48,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "A quick note: there are limitations around what types of functions can be transformed by vmap. \n",
+ "\n",
+ "The best functions to transform are ones that are pure functions: a function where the outputs are only determined by the inputs, and that have no side effects (e.g. mutation). \n",
+ "\n",
+ "vmap is unable to handle mutation of arbitrary Python data structures, but it is able to handle many in-place PyTorch operations.\n",
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "BEZaNt1d_bc1"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Performance comparison"
+ ],
+ "metadata": {
+ "id": "BASP151Iml7B"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Curious about how the performance of vmap compares? \n",
+ "Currently the best results are obtained on newer GPU's such as the A100 (Ampere), but we can run some live results right here in colab, with the caveat that results will vary based on the age of the gpu assigned:"
+ ],
+ "metadata": {
+ "id": "jr1xNpV4nJ7u"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def get_perf(first, first_descriptor, second, second_descriptor):\n",
+ " \"\"\" takes torch.benchmark objects and compares delta of second vs first. \"\"\"\n",
+ " second_res = second.times[0]\n",
+ " first_res = first.times[0]\n",
+ "\n",
+ " gain = (first_res-second_res)/first_res\n",
+ " if gain < 0: gain *=-1 \n",
+ " final_gain = gain*100\n",
+ "\n",
+ " print(f\" Performance delta: {final_gain:.4f} percent improvement with {first_descriptor} \")\n"
+ ],
+ "metadata": {
+ "id": "GnAnMkYmoc-j"
+ },
+ "execution_count": 61,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from torch.utils.benchmark import Timer\n",
+ "\n",
+ "without_vmap = Timer( stmt=\"compute_sample_grads(data, targets)\", globals=globals())\n",
+ "\n",
+ "with_vmap = Timer(stmt=\"ft_compute_sample_grad(params, buffers, data, targets)\",globals=globals())\n",
+ "\n",
+ "no_vmap_timing = without_vmap.timeit(100)\n",
+ "with_vmap_timing = with_vmap.timeit(100)\n",
+ "\n",
+ "\n",
+ "print(f'Per-sample-grads without vmap {no_vmap_timing}')\n",
+ "print(f'Per-sample-grads with vmap {with_vmap_timing}')"
+ ],
+ "metadata": {
+ "id": "Zfnn2C2g-6Fb",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "922f3901-773f-446b-b562-88e78f49036c"
+ },
+ "execution_count": 49,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Per-sample-grads without vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7f71ac3f1850>\n",
+ "compute_sample_grads(data, targets)\n",
+ " 79.86 ms\n",
+ " 1 measurement, 100 runs , 1 thread\n",
+ "Per-sample-grads with vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7f7143e26f10>\n",
+ "ft_compute_sample_grad(params, buffers, data, targets)\n",
+ " 12.93 ms\n",
+ " 1 measurement, 100 runs , 1 thread\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "get_perf(with_vmap_timing, \"vmap\", no_vmap_timing,\"no vmap\" )"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "NV9R3LZQoavl",
+ "outputId": "e11e8be9-287d-4e60-e517-e08f8d6909bd"
+ },
+ "execution_count": 62,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ " Performance delta: 517.5791 percent improvement with vmap \n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "In general, vectorization with vmap should be faster than running a function in a for-loop and competitive with manual batching. \n",
+ "\n",
+ "There are other optimized solutions to computing per-sample-gradients in PyTorch that also perform much better than the naive method like in https://github.com/pytorch/opacus. But it’s cool that we get the speedup on this example.\n",
+ "\n",
+ "\n",
+ "There are some exceptions though, like if we haven’t implemented the vmap rule for a particular operation or if the underlying kernels weren’t optimized for older hardware (GPUs). \n",
+ "\n",
+ "If you see any of these cases, please let us know by opening an issue at our [GitHub](https://github.com/pytorch/functorch)!\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "UI74G9JarQU8"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ ""
+ ],
+ "metadata": {
+ "id": "2s6DpZUWobC9"
+ },
+ "execution_count": 62,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ ""
+ ],
+ "metadata": {
+ "id": "22e5fo2jqANi"
+ },
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+}
\ No newline at end of file
diff --git a/functorch/notebooks/colab/readme.md b/functorch/notebooks/colab/readme.md
new file mode 100644
index 0000000..b3feae5
--- /dev/null
+++ b/functorch/notebooks/colab/readme.md
@@ -0,0 +1,5 @@
+### Holds the colab ready versions of the notebook tutorials.
+
+These are similar to the jupyter notebooks, but have additional colab specific changes including the building of functorch in colab to prep for running.
+
+The colabs and notebooks are not auto-synced atm, thus currently updates to one need to be synched to the other.
diff --git a/functorch/notebooks/ensembling.ipynb b/functorch/notebooks/ensembling.ipynb
index 817a4c1..5dac091 100644
--- a/functorch/notebooks/ensembling.ipynb
+++ b/functorch/notebooks/ensembling.ipynb
@@ -9,6 +9,10 @@
"\n",
"This example illustrates how to vectorize model ensembling using vmap.\n",
"\n",
+ "<a href=\"https://colab.research.google.com/github/pytorch/functorch/blob/main/notebooks/colab/jacobians_hessians_colab.ipynb\">\n",
+ " <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
+ "</a>\n",
+ "\n",
"## What is model ensembling?\n",
"Model ensembling combines the predictions from multiple models together.\n",
"Traditionally this is done by running each model on some inputs separately\n",
diff --git a/functorch/notebooks/jacobians_hessians.ipynb b/functorch/notebooks/jacobians_hessians.ipynb
index c76da97..7ceef2a 100644
--- a/functorch/notebooks/jacobians_hessians.ipynb
+++ b/functorch/notebooks/jacobians_hessians.ipynb
@@ -7,6 +7,10 @@
"source": [
"# Jacobians, hessians, and more: composing functorch transforms\n",
"\n",
+ "<a href=\"https://colab.research.google.com/github/pytorch/functorch/blob/main/notebooks/colab/jacobians_hessians_colab.ipynb\">\n",
+ " <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
+ "</a>\n",
+ "\n",
"Computing jacobians or hessians are useful in a number of non-traditional\n",
"deep learning models. It is difficult (or annoying) to compute these quantities\n",
"efficiently using a standard autodiff system like PyTorch Autograd; functorch\n",
@@ -14,6 +18,12 @@
]
},
{
+ "cell_type": "markdown",
+ "id": "93379d25",
+ "metadata": {},
+ "source": []
+ },
+ {
"cell_type": "code",
"execution_count": 1,
"id": "a1aabaa9-6e86-4717-b645-b979a6b980a6",
@@ -471,7 +481,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.8.5"
+ "version": "3.8.3"
}
},
"nbformat": 4,
diff --git a/functorch/notebooks/per_sample_grads.ipynb b/functorch/notebooks/per_sample_grads.ipynb
index 685becd..748faca 100644
--- a/functorch/notebooks/per_sample_grads.ipynb
+++ b/functorch/notebooks/per_sample_grads.ipynb
@@ -7,11 +7,17 @@
"source": [
"# Per-sample-gradients\n",
"\n",
+ "\n",
+ "\n",
"## What is it?\n",
"\n",
"Per-sample-gradient computation is computing the gradient for each and every\n",
- "sample in a batch of data. It is a useful quantity in differential privacy\n",
- "and optimization research."
+ "sample in a batch of data. It is a useful quantity in differential privacy, meta-learning,\n",
+ "and optimization research.\n",
+ "\n",
+ "<a href=\"https://colab.research.google.com/github/pytorch/functorch/blob/main/notebooks/colab/per_sample_gradients_colab.ipynb\">\n",
+ " <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
+ "</a>"
]
},
{