| { |
| "cells": [ |
| { |
| "cell_type": "markdown", |
| "id": "a474c143-05c4-43b6-b12c-17b592d07a6a", |
| "metadata": {}, |
| "source": [ |
| "# Per-sample-gradients\n", |
| "\n", |
| "## What is it?\n", |
| "\n", |
| "Per-sample-gradient computation is computing the gradient for each and every\n", |
| "sample in a batch of data. It is a useful quantity in differential privacy\n", |
| "and optimization research." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 1, |
| "id": "4dba22dd-4b45-4816-8955-e87b50e672e6", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "import torch\n", |
| "import torch.nn as nn\n", |
| "import torch.nn.functional as F\n", |
| "from functools import partial\n", |
| "torch.manual_seed(0)\n", |
| "\n", |
| "# Here's a simple CNN\n", |
| "class SimpleCNN(nn.Module):\n", |
| " def __init__(self):\n", |
| " super(SimpleCNN, self).__init__()\n", |
| " self.conv1 = nn.Conv2d(1, 32, 3, 1)\n", |
| " self.conv2 = nn.Conv2d(32, 64, 3, 1)\n", |
| " self.fc1 = nn.Linear(9216, 128)\n", |
| " self.fc2 = nn.Linear(128, 10)\n", |
| "\n", |
| " def forward(self, x):\n", |
| " x = self.conv1(x)\n", |
| " x = F.relu(x)\n", |
| " x = self.conv2(x)\n", |
| " x = F.relu(x)\n", |
| " x = F.max_pool2d(x, 2)\n", |
| " x = torch.flatten(x, 1)\n", |
| " x = self.fc1(x)\n", |
| " x = F.relu(x)\n", |
| " x = self.fc2(x)\n", |
| " output = F.log_softmax(x, dim=1)\n", |
| " return output\n", |
| "\n", |
| "def loss_fn(predictions, targets):\n", |
| " return F.nll_loss(predictions, targets)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "395ed4f2-0b25-4617-9b25-d9781fc16a15", |
| "metadata": {}, |
| "source": [ |
| "Let's generate a batch of dummy data. Pretend that we're working with an\n", |
| "MNIST dataset where the images are 28 by 28 and we have a minibatch of size 64." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 2, |
| "id": "872481df-377f-46d2-b32e-86670454a6f1", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", |
| "batch_size = 64\n", |
| "data = torch.randn(batch_size, 1, 28, 28, device=device)\n", |
| "targets = torch.randint(10, (batch_size,), device=device)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "e4121248-aaba-478e-a9bc-6c121bcb46fd", |
| "metadata": {}, |
| "source": [ |
| "In regular model training, one would forward the batch of examples and then\n", |
| "call `.backward()` to compute gradients:" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 3, |
| "id": "19a76ba8-d037-4eb8-9bc2-43444ddcc0eb", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "model = SimpleCNN().to(device=device)\n", |
| "predictions = model(data)\n", |
| "loss = loss_fn(predictions, targets)\n", |
| "loss.backward()" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "4a3a94e4-9c7e-4ff0-8e15-b50e4d13f18b", |
| "metadata": {}, |
| "source": [ |
| "Conceptually, per-sample-gradient computation is equivalent to: for each sample\n", |
| "of the data, perform a forward and a backward pass to get a gradient." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 4, |
| "id": "5f704d47-60ee-440c-8b62-6400d57113b4", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "def compute_grad(sample, target):\n", |
| " sample = sample.unsqueeze(0)\n", |
| " target = target.unsqueeze(0)\n", |
| " prediction = model(sample)\n", |
| " loss = loss_fn(prediction, target)\n", |
| " return torch.autograd.grad(loss, list(model.parameters()))\n", |
| "\n", |
| "def compute_sample_grads(data, targets):\n", |
| " sample_grads = [compute_grad(data[i], targets[i]) for i in range(batch_size)]\n", |
| " sample_grads = zip(*sample_grads)\n", |
| " sample_grads = [torch.stack(shards) for shards in sample_grads]\n", |
| " return sample_grads\n", |
| "\n", |
| "per_sample_grads = compute_sample_grads(data, targets)" |
| ] |
| }, |
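| { |
| "cell_type": "markdown", |
| "id": "1a2b3c4d-0001-4abc-8def-000000000001", |
| "metadata": {}, |
| "source": [ |
| "As a sanity check on the loop-and-stack approach above: the batch loss is the mean of the per-sample losses, so the batch gradient should equal the mean of the per-sample gradients. Here is a minimal sketch on a hypothetical linear model (plain `torch.autograd`, no functorch) that verifies this property:" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "1a2b3c4d-0002-4abc-8def-000000000002", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "w = torch.randn(3, requires_grad=True)\n", |
| "xs = torch.randn(5, 3)\n", |
| "\n", |
| "# One backward pass per sample: gradient of loss_i = (w . x_i)^2 w.r.t. w\n", |
| "per_sample = torch.stack([torch.autograd.grad((w @ x).pow(2), w)[0] for x in xs])\n", |
| "\n", |
| "# The gradient of the mean loss equals the mean of the per-sample gradients\n", |
| "batch_grad = torch.autograd.grad((xs @ w).pow(2).mean(), w)[0]\n", |
| "assert torch.allclose(per_sample.mean(dim=0), batch_grad, atol=1e-6)" |
| ] |
| }, |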
| { |
| "cell_type": "markdown", |
| "id": "355c31c6-c312-48e7-9c9a-43d57da2f70f", |
| "metadata": {}, |
| "source": [ |
| "`per_sample_grads[0]` is the per-sample-grad for `model.conv1.weight`.\n", |
| "`model.conv1.weight.shape` is `[32, 1, 3, 3]`; notice how there is one gradient\n", |
| "per sample in the batch for a total of 64." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 5, |
| "id": "fa0ba3c9-f37d-4bc7-bbd9-aaadf45fa702", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "torch.Size([64, 32, 1, 3, 3])\n" |
| ] |
| } |
| ], |
| "source": [ |
| "print(per_sample_grads[0].shape)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "569eac3c-e286-41c9-8601-e2f7a9be31b1", |
| "metadata": {}, |
| "source": [ |
| "## Per-sample-grads using functorch\n", |
| "\n", |
| "We can compute per-sample-gradients efficiently by using function transforms.\n", |
| "First, let's create a stateless functional version of `model` by using\n", |
| "`functorch.make_functional_with_buffers`." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 6, |
| "id": "c6a97623-a4f7-4492-8f0e-0f7ea47b149e", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "from functorch import make_functional_with_buffers, vmap, grad\n", |
| "fmodel, params, buffers = make_functional_with_buffers(model)" |
| ] |
| }, |
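| { |
| "cell_type": "markdown", |
| "id": "2b3c4d5e-0003-4abc-8def-000000000003", |
| "metadata": {}, |
| "source": [ |
| "An aside: in more recent PyTorch releases, the functorch APIs were merged into `torch.func`, where `make_functional_with_buffers` is superseded by `torch.func.functional_call`, which runs a module with an explicit dictionary of parameters and buffers. A rough sketch of the modern equivalent (assuming PyTorch 2.0 or later; `fmodel_v2` is an illustrative name):" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "2b3c4d5e-0004-4abc-8def-000000000004", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "from torch.func import functional_call\n", |
| "\n", |
| "state = {**dict(model.named_parameters()), **dict(model.named_buffers())}\n", |
| "\n", |
| "def fmodel_v2(state, x):\n", |
| "    # Run the module as a pure function of its state and input\n", |
| "    return functional_call(model, state, (x,))\n", |
| "\n", |
| "assert torch.allclose(fmodel_v2(state, data), model(data))" |
| ] |
| }, |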
| { |
| "cell_type": "markdown", |
| "id": "668da123-623e-4dc9-be1f-bfbba2abd6db", |
| "metadata": {}, |
| "source": [ |
| "Next, let's define a function to compute the loss of the model given a single\n", |
| "input rather than a batch of inputs. It is important that this function accepts the\n", |
| "parameters, the input, and the target, because we will be transforming over them.\n", |
| "Because the model was originally written to handle batches, we'll use\n", |
| "`torch.unsqueeze` to add a batch dimension." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 7, |
| "id": "a4d324bf-e94a-46ac-875d-5cba8dc4ea20", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "def compute_loss(params, buffers, sample, target):\n", |
| " batch = sample.unsqueeze(0)\n", |
| " targets = target.unsqueeze(0)\n", |
| " predictions = fmodel(params, buffers, batch)\n", |
| " loss = loss_fn(predictions, targets)\n", |
| " return loss" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "b9014be3-7516-428e-b6b9-ee71a4e9f6c8", |
| "metadata": {}, |
| "source": [ |
| "Now, let's use `grad` to create a new function that computes the gradient\n", |
| "with respect to the first argument of `compute_loss` (i.e. the params)." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 8, |
| "id": "ee5c01ef-ba90-4e96-b619-6102c07fa39d", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "ft_compute_grad = grad(compute_loss)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "6d764bb2-94c1-4a7c-9623-a7ade929359d", |
| "metadata": {}, |
| "source": [ |
| "`ft_compute_grad` computes the gradient for a single (sample, target) pair.\n", |
| "We can use `vmap` to get it to compute the gradient over an entire batch\n", |
| "of samples and targets. Note that we pass `in_dims=(None, None, 0, 0)` because we wish\n", |
| "to map `ft_compute_grad` over the 0th dimension of the data and targets\n", |
| "and use the same params and buffers for each." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 9, |
| "id": "4e9c5029-b492-485c-b953-22e22c8f3468", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))" |
| ] |
| }, |
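| { |
| "cell_type": "markdown", |
| "id": "3c4d5e6f-0005-4abc-8def-000000000005", |
| "metadata": {}, |
| "source": [ |
| "The `in_dims` semantics can be seen on a tiny example: `None` broadcasts an argument to every call, while `0` slices the argument along dimension 0. A hypothetical sketch, reusing the `vmap` we imported above:" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "3c4d5e6f-0006-4abc-8def-000000000006", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "def dot(w, x):\n", |
| "    return w @ x\n", |
| "\n", |
| "w = torch.randn(3)      # shared across the batch -> in_dims entry None\n", |
| "xs = torch.randn(5, 3)  # mapped over dimension 0 -> in_dims entry 0\n", |
| "out = vmap(dot, in_dims=(None, 0))(w, xs)\n", |
| "assert out.shape == (5,) and torch.allclose(out, xs @ w)" |
| ] |
| }, |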
| { |
| "cell_type": "markdown", |
| "id": "bf45f8fd-a2c3-4916-83b4-fd3087bfeace", |
| "metadata": {}, |
| "source": [ |
| "Finally, let's use our transformed function to compute per-sample-gradients:" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 10, |
| "id": "1c75070c-b844-45e3-8d62-7953dd396403", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "ft_per_sample_grads = ft_compute_sample_grad(params, buffers, data, targets)\n", |
| "for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads):\n", |
| " assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=1e-3, rtol=1e-5)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "ff42597b-f1f6-433f-a4b9-184420030117", |
| "metadata": {}, |
| "source": [ |
| "A quick note: there are limitations around what types of functions can be\n", |
| "transformed by vmap. The best functions to transform are ones that are\n", |
| "pure functions: functions whose outputs are determined only by the inputs and\n", |
| "that have no side effects (e.g. mutation). vmap is unable to handle mutation of\n", |
| "arbitrary Python data structures, but it is able to handle many in-place\n", |
| "PyTorch operations." |
| ] |
| }, |
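| { |
| "cell_type": "markdown", |
| "id": "4d5e6f7a-0007-4abc-8def-000000000007", |
| "metadata": {}, |
| "source": [ |
| "To make this concrete, here is a hypothetical pair of functions: one mutates a Python list as a side effect (under `vmap`, the appended values are internal batched tensors rather than per-sample results), next to a pure rewrite that `vmap` handles cleanly:" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "4d5e6f7a-0008-4abc-8def-000000000008", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "log = []\n", |
| "\n", |
| "def impure(x):\n", |
| "    log.append(x.sum())  # side effect on a Python list: avoid under vmap\n", |
| "    return x.sin()\n", |
| "\n", |
| "def pure(x):\n", |
| "    return x.sin()       # output determined only by the input\n", |
| "\n", |
| "out = vmap(pure)(torch.randn(4, 3))\n", |
| "assert out.shape == (4, 3)" |
| ] |
| }, |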
| { |
| "cell_type": "markdown", |
| "id": "0fa16650-9951-43a9-8db4-9543a0d597a2", |
| "metadata": {}, |
| "source": [ |
| "## Performance\n", |
| "Curious about performance numbers? Here's how the numbers look on my machine with an A100 GPU:" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 11, |
| "id": "00449ce9-9dee-4371-8d1e-8d301214bd86", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "Per-sample-grads without vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7fecb028ef10>\n", |
| "compute_sample_grads(data, targets)\n", |
| " 79.88 ms\n", |
| " 1 measurement, 500 runs , 1 thread\n", |
| "Per-sample-grads with vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7fecb028ef70>\n", |
| "ft_compute_sample_grad(params, buffers, data, targets)\n", |
| " 3.05 ms\n", |
| " 1 measurement, 500 runs , 1 thread\n" |
| ] |
| } |
| ], |
| "source": [ |
| "from torch.utils.benchmark import Timer\n", |
| "without_vmap = Timer(\n", |
| " stmt=\"compute_sample_grads(data, targets)\",\n", |
| " globals=globals())\n", |
| "with_vmap = Timer(\n", |
| " stmt=\"ft_compute_sample_grad(params, buffers, data, targets)\",\n", |
| " globals=globals())\n", |
| "print(f'Per-sample-grads without vmap {without_vmap.timeit(500)}')\n", |
| "print(f'Per-sample-grads with vmap {with_vmap.timeit(500)}')" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "1bb71dc6-fc12-4b27-839a-79bbcb8e76e6", |
| "metadata": {}, |
| "source": [ |
| "This may not be the fairest comparison: there are other optimized solutions for computing per-sample-gradients in PyTorch, such as Opacus (https://github.com/pytorch/opacus), that perform much better than the naive for-loop approach. Still, it is nice to see the speedup on this example.\n", |
| "\n", |
| "In general, vectorization with vmap should be faster than running a function in a for-loop and competitive with manual batching. There are some exceptions, though, such as when we have not implemented the vmap rule for a particular operation or when the underlying kernels are not optimized for older hardware. If you run into any of these cases, please let us know by opening an issue on our GitHub!" |
| ] |
| } |
| ], |
| "metadata": { |
| "kernelspec": { |
| "display_name": "Python 3", |
| "language": "python", |
| "name": "python3" |
| }, |
| "language_info": { |
| "codemirror_mode": { |
| "name": "ipython", |
| "version": 3 |
| }, |
| "file_extension": ".py", |
| "mimetype": "text/x-python", |
| "name": "python", |
| "nbconvert_exporter": "python", |
| "pygments_lexer": "ipython3", |
| "version": "3.8.5" |
| } |
| }, |
| "nbformat": 4, |
| "nbformat_minor": 5 |
| } |