| { |
| "cells": [ |
| { |
| "cell_type": "markdown", |
| "id": "de1548fb-a313-4e9c-ae5d-8ec4c12ddd94", |
| "metadata": {}, |
| "source": [ |
| "# Model ensembling\n", |
| "\n", |
| "This example illustrates how to vectorize model ensembling using vmap.\n", |
| "\n", |
| "## What is model ensembling?\n", |
| "Model ensembling combines the predictions from multiple models together.\n", |
| "Traditionally this is done by running each model on some inputs separately\n", |
| "and then combining the predictions. However, if you're running models with\n", |
| "the same architecture, then it may be possible to combine them together\n", |
| "using `vmap`. `vmap` is a function transform that maps functions across\n", |
| "dimensions of the input tensors. One of its use cases is replacing\n", |
| "for-loops with vectorized computation, which is often faster.\n", |
| "\n", |
| "Let's demonstrate how to do this using an ensemble of simple CNNs." |
| ] |
| }, |
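| { |
| "cell_type": "markdown", |
| "id": "sketch-vmap-basics", |
| "metadata": {}, |
| "source": [ |
| "Before building the ensemble, here is a minimal sketch of what mapping a function across\n", |
| "a dimension means. It is separate from the ensemble itself and assumes PyTorch 2.0 or later,\n", |
| "where `vmap` is also exposed as `torch.vmap`; the tensors `x` and `y` are made up for illustration.\n", |
| "\n", |
| "```python\n", |
| "import torch\n", |
| "\n", |
| "# Assumption: PyTorch 2.0+, where vmap is available as torch.vmap.\n", |
| "x = torch.randn(5, 3)\n", |
| "y = torch.randn(5, 3)\n", |
| "\n", |
| "# Explicit for-loop: one dot product per row.\n", |
| "looped = torch.stack([torch.dot(xi, yi) for xi, yi in zip(x, y)])\n", |
| "\n", |
| "# vmap maps torch.dot over dimension 0 of both inputs in a single call.\n", |
| "batched = torch.vmap(torch.dot)(x, y)\n", |
| "\n", |
| "assert batched.shape == (5,)\n", |
| "assert torch.allclose(batched, looped, atol=1e-5)\n", |
| "```\n", |
| "\n", |
| "Both versions compute one dot product per row; the `vmap` version does it in a single call." |
| ] |
| }, |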
| { |
| "cell_type": "code", |
| "execution_count": 1, |
| "id": "38a958e7-47f1-463b-8ade-e55fc9832bf1", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "import torch\n", |
| "import torch.nn as nn\n", |
| "import torch.nn.functional as F\n", |
| "from functools import partial\n", |
| "torch.manual_seed(0)\n", |
| "\n", |
| "# Here's a simple CNN\n", |
| "class SimpleCNN(nn.Module):\n", |
| "    def __init__(self):\n", |
| "        super().__init__()\n", |
| "        self.conv1 = nn.Conv2d(1, 32, 3, 1)\n", |
| "        self.conv2 = nn.Conv2d(32, 64, 3, 1)\n", |
| "        self.fc1 = nn.Linear(9216, 128)\n", |
| "        self.fc2 = nn.Linear(128, 10)\n", |
| "\n", |
| "    def forward(self, x):\n", |
| "        x = self.conv1(x)\n", |
| "        x = F.relu(x)\n", |
| "        x = self.conv2(x)\n", |
| "        x = F.relu(x)\n", |
| "        x = F.max_pool2d(x, 2)\n", |
| "        x = torch.flatten(x, 1)\n", |
| "        x = self.fc1(x)\n", |
| "        x = F.relu(x)\n", |
| "        x = self.fc2(x)\n", |
| "        # Return the raw logits (no softmax applied)\n", |
| "        return x" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "b143b79b-9762-4198-b1f6-7c71d5a7df9f", |
| "metadata": {}, |
| "source": [ |
| "Let's generate some dummy data. Pretend that we're working with an MNIST dataset\n", |
| "where the images are 28 by 28.\n", |
| "Furthermore, let's say we wish to combine the predictions from 10 different\n", |
| "models." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 2, |
| "id": "cf025d68-00b0-4de4-88ca-d010c965ee99", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", |
| "num_models = 10\n", |
| "data = torch.randn(100, 64, 1, 28, 28, device=device)\n", |
| "targets = torch.randint(10, (6400,), device=device)\n", |
| "models = [SimpleCNN().to(device) for _ in range(num_models)]" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "82923d20-b20e-482e-99a5-0af6ce2aa8d8", |
| "metadata": {}, |
| "source": [ |
| "We have a couple of options for generating predictions. Maybe we want\n", |
| "to give each model a different randomized minibatch of data, or maybe we\n", |
| "want to run the same minibatch of data through each model (e.g. if we were\n", |
| "testing the effect of different model initializations)." |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "2a16258d-5ab1-4c99-a741-88e4c77bb8c1", |
| "metadata": {}, |
| "source": [ |
| "Option 1: different minibatch for each model" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 3, |
| "id": "f25c6f7d-4992-4b7e-a64a-52c101222a29", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "minibatches = data[:num_models]\n", |
| "predictions1 = [model(minibatch) for model, minibatch in zip(models, minibatches)]" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "35739f27-7d3b-4fe7-92ee-040d510d3656", |
| "metadata": {}, |
| "source": [ |
| "Option 2: same minibatch for every model" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 4, |
| "id": "f1ac17df-9304-4398-8045-c68afea8eb71", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "minibatch = data[0]\n", |
| "predictions2 = [model(minibatch) for model in models]" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "d2b978b6-86b9-4c5f-a8c4-a9cdb490ff49", |
| "metadata": {}, |
| "source": [ |
| "## Using vmap to vectorize the ensemble\n", |
| "\n", |
| "Let's use `vmap` to speed up the for-loop. We must first prepare the models\n", |
| "for use with `vmap`.\n", |
| "\n", |
| "First, let's combine the states of the models by stacking each parameter.\n", |
| "For example, `models[i].fc1.weight` has shape `[128, 9216]`; we are going to stack the\n", |
| "`.fc1.weight` of each of the 10 models to produce a big weight of shape `[10, 128, 9216]`.\n", |
| "\n", |
| "functorch offers the following convenience function to do that. It returns a\n", |
| "stateless version of the model (fmodel) and stacked parameters and buffers." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 5, |
| "id": "e0f7476f-bb13-4394-a087-9f33fbe9b41f", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "from functorch import combine_state_for_ensemble\n", |
| "fmodel, params, buffers = combine_state_for_ensemble(models)\n", |
| "for p in params:\n", |
| "    p.requires_grad_()" |
| ] |
| }, |
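| { |
| "cell_type": "markdown", |
| "id": "sketch-stacked-shapes", |
| "metadata": {}, |
| "source": [ |
| "To see what stacking does to the shapes, here is a small sketch. It uses\n", |
| "`torch.func.stack_module_state`, which we assume is available (PyTorch 2.0 or later) and\n", |
| "which returns dictionaries keyed by parameter name. The tiny `nn.Linear` layers are\n", |
| "hypothetical stand-ins for the CNNs above.\n", |
| "\n", |
| "```python\n", |
| "import torch\n", |
| "import torch.nn as nn\n", |
| "from torch.func import stack_module_state  # assumption: PyTorch 2.0+\n", |
| "\n", |
| "# Three small linear layers stand in for the ensemble members (hypothetical sizes).\n", |
| "layers = [nn.Linear(4, 2) for _ in range(3)]\n", |
| "\n", |
| "# nn.Linear(in_features, out_features) stores its weight as [out_features, in_features].\n", |
| "assert layers[0].weight.shape == (2, 4)\n", |
| "\n", |
| "# Each parameter is stacked along a new leading dimension of size len(layers).\n", |
| "stacked_params, stacked_buffers = stack_module_state(layers)\n", |
| "\n", |
| "assert stacked_params['weight'].shape == (3, 2, 4)\n", |
| "assert stacked_params['bias'].shape == (3, 2)\n", |
| "assert len(stacked_buffers) == 0  # nn.Linear has no buffers\n", |
| "\n", |
| "# The stacked weight holds the same values as stacking by hand.\n", |
| "manual = torch.stack([layer.weight for layer in layers])\n", |
| "assert torch.equal(stacked_params['weight'], manual)\n", |
| "```\n", |
| "\n", |
| "The same reasoning gives `[10, 128, 9216]` for the stacked `fc1.weight` of the 10 CNNs." |
| ] |
| }, |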
| { |
| "cell_type": "markdown", |
| "id": "9cdd7f97-f583-4e6f-a756-7378251d7bed", |
| "metadata": {}, |
| "source": [ |
| "Option 1: get predictions using a different minibatch for each model.\n", |
| "\n", |
| "By default, vmap maps a function across the first dimension of all inputs to the\n", |
| "passed-in function. After `combine_state_for_ensemble`, each tensor in `params` and\n", |
| "`buffers` has an additional dimension of size `num_models` at the front,\n", |
| "and `minibatches` also has a leading dimension of size `num_models`." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 6, |
| "id": "bbece999-78bc-43d3-89e3-50cb961fbe70", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "[10, 10, 10, 10, 10, 10, 10, 10]\n" |
| ] |
| } |
| ], |
| "source": [ |
| "print([p.size(0) for p in params])\n", |
| "assert minibatches.shape == (num_models, 64, 1, 28, 28)\n", |
| "from functorch import vmap\n", |
| "predictions1_vmap = vmap(fmodel)(params, buffers, minibatches)\n", |
| "assert torch.allclose(predictions1_vmap, torch.stack(predictions1), atol=1e-3, rtol=1e-5)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "095eb8a9-6666-44df-a85f-66bb481d3460", |
| "metadata": {}, |
| "source": [ |
| "Option 2: get predictions using the same minibatch of data.\n", |
| "\n", |
| "vmap has an `in_dims` argument that specifies which dimension of each input to map over.\n", |
| "Passing `None` for the minibatch tells vmap not to map over it, so the same minibatch\n", |
| "is applied to all 10 models." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 7, |
| "id": "26befbfc-021a-4d77-bfff-a980257dbd08", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "predictions2_vmap = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)\n", |
| "assert torch.allclose(predictions2_vmap, torch.stack(predictions2), atol=1e-3, rtol=1e-5)" |
| ] |
| }, |
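| { |
| "cell_type": "markdown", |
| "id": "sketch-shared-minibatch", |
| "metadata": {}, |
| "source": [ |
| "The same shared-minibatch pattern can be sketched with the `torch.func` API, followed by\n", |
| "one possible way to combine the predictions. This assumes PyTorch 2.0 or later; the helper\n", |
| "`call_one`, the tiny linear models, and the choice to average softmax probabilities are\n", |
| "illustrative assumptions, not something this example prescribes.\n", |
| "\n", |
| "```python\n", |
| "import copy\n", |
| "import torch\n", |
| "import torch.nn as nn\n", |
| "from torch.func import functional_call, stack_module_state  # assumption: PyTorch 2.0+\n", |
| "\n", |
| "torch.manual_seed(0)\n", |
| "tiny_models = [nn.Linear(4, 3) for _ in range(5)]  # hypothetical tiny ensemble\n", |
| "shared_batch = torch.randn(8, 4)\n", |
| "\n", |
| "stacked_params, stacked_buffers = stack_module_state(tiny_models)\n", |
| "\n", |
| "# A structure-only copy of one model; its own weights are never used.\n", |
| "base = copy.deepcopy(tiny_models[0]).to('meta')\n", |
| "\n", |
| "def call_one(p, b, x):\n", |
| "    # Run the base architecture with the given parameters and buffers.\n", |
| "    return functional_call(base, (p, b), (x,))\n", |
| "\n", |
| "# in_dims=(0, 0, None): map over the model dimension, share the minibatch.\n", |
| "preds = torch.vmap(call_one, in_dims=(0, 0, None))(stacked_params, stacked_buffers, shared_batch)\n", |
| "assert preds.shape == (5, 8, 3)\n", |
| "\n", |
| "# Same result as looping over the models one by one.\n", |
| "looped = torch.stack([m(shared_batch) for m in tiny_models])\n", |
| "assert torch.allclose(preds, looped, atol=1e-5)\n", |
| "\n", |
| "# One way to combine: average the class probabilities over the model dimension.\n", |
| "ensemble_probs = preds.softmax(dim=-1).mean(dim=0)\n", |
| "assert ensemble_probs.shape == (8, 3)\n", |
| "```\n", |
| "\n", |
| "The output keeps the model dimension in front, so reducing over `dim=0` combines the ensemble." |
| ] |
| }, |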
| { |
| "cell_type": "markdown", |
| "id": "5572d906-d0b5-4d1d-b601-5ee4922422e6", |
| "metadata": {}, |
| "source": [ |
| "A quick note: there are limitations around what types of functions can be\n", |
| "transformed by vmap. The best functions to transform are pure functions:\n", |
| "functions whose outputs are determined only by their inputs and that have\n", |
| "no side effects (e.g. mutation). vmap is unable to handle mutation of\n", |
| "arbitrary Python data structures, but it is able to handle many in-place\n", |
| "PyTorch operations." |
| ] |
| }, |
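| { |
| "cell_type": "markdown", |
| "id": "sketch-side-effects", |
| "metadata": {}, |
| "source": [ |
| "A small sketch of why side effects are a problem, assuming PyTorch 2.0 or later for\n", |
| "`torch.vmap`. The function `record_and_double` and the list `calls` are hypothetical names\n", |
| "used only for illustration.\n", |
| "\n", |
| "```python\n", |
| "import torch\n", |
| "\n", |
| "calls = []\n", |
| "\n", |
| "def record_and_double(x):\n", |
| "    # Side effect: mutate a Python list defined outside the function.\n", |
| "    calls.append('called')\n", |
| "    return x * 2\n", |
| "\n", |
| "out = torch.vmap(record_and_double)(torch.arange(4.0))\n", |
| "\n", |
| "# The function body ran once on a batched tensor, not once per element,\n", |
| "# so the list does not record four separate calls.\n", |
| "assert len(calls) == 1\n", |
| "assert torch.equal(out, torch.tensor([0.0, 2.0, 4.0, 6.0]))\n", |
| "```\n", |
| "\n", |
| "The returned tensor is still correct, but code that relies on the side effect would\n", |
| "behave differently from the equivalent for-loop." |
| ] |
| }, |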
| { |
| "cell_type": "markdown", |
| "id": "c250b753-427b-49ad-a1ef-9764abffb36f", |
| "metadata": {}, |
| "source": [ |
| "## Performance\n", |
| "\n", |
| "Curious about performance numbers? Here's how the numbers look on my machine with an A100 GPU:" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 8, |
| "id": "fef2b3bc-230a-436c-bb0f-29084bce80a1", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "Predictions without vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7f3c8c061e80>\n", |
| "[model(minibatch) for model, minibatch in zip(models, minibatches)]\n", |
| " 5.91 ms\n", |
| " 1 measurement, 500 runs , 1 thread\n", |
| "Predictions with vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7f3c860b5520>\n", |
| "vmap(fmodel)(params, buffers, minibatches)\n", |
| " 2.07 ms\n", |
| " 1 measurement, 500 runs , 1 thread\n" |
| ] |
| } |
| ], |
| "source": [ |
| "from torch.utils.benchmark import Timer\n", |
| "without_vmap = Timer(\n", |
| "    stmt=\"[model(minibatch) for model, minibatch in zip(models, minibatches)]\",\n", |
| "    globals=globals())\n", |
| "with_vmap = Timer(\n", |
| "    stmt=\"vmap(fmodel)(params, buffers, minibatches)\",\n", |
| "    globals=globals())\n", |
| "print(f'Predictions without vmap {without_vmap.timeit(500)}')\n", |
| "print(f'Predictions with vmap {with_vmap.timeit(500)}')" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "3361a73b-a0e9-41b3-ba60-a54d9b8efa44", |
| "metadata": {}, |
| "source": [ |
| "In general, vectorization with vmap should be faster than running a function in a for-loop. There are some exceptions though, for example if the vmap rule for a particular operation has not been implemented or if the underlying kernels are not optimized for older hardware. If you see any of these cases, please let us know by opening an issue on GitHub!" |
| ] |
| } |
| ], |
| "metadata": { |
| "kernelspec": { |
| "display_name": "Python 3", |
| "language": "python", |
| "name": "python3" |
| }, |
| "language_info": { |
| "codemirror_mode": { |
| "name": "ipython", |
| "version": 3 |
| }, |
| "file_extension": ".py", |
| "mimetype": "text/x-python", |
| "name": "python", |
| "nbconvert_exporter": "python", |
| "pygments_lexer": "ipython3", |
| "version": "3.8.5" |
| } |
| }, |
| "nbformat": 4, |
| "nbformat_minor": 5 |
| } |