{
"cells": [
{
"cell_type": "markdown",
"id": "de1548fb-a313-4e9c-ae5d-8ec4c12ddd94",
"metadata": {},
"source": [
"# Model ensembling\n",
"\n",
"This example illustrates how to vectorize model ensembling using vmap.\n",
"\n",
"## What is model ensembling?\n",
"Model ensembling combines the predictions from multiple models together.\n",
"Traditionally this is done by running each model on some inputs separately\n",
"and then combining the predictions. However, if you're running models with\n",
"the same architecture, then it may be possible to combine them together\n",
"using `vmap`. `vmap` is a function transform that maps functions across\n",
"dimensions of the input tensors. One of its use cases is eliminating\n",
"for-loops and speeding them up through vectorization.\n",
"\n",
"Let's demonstrate how to do this using an ensemble of simple CNNs."
]
},
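{
"cell_type": "markdown",
"id": "1f2e3d4c-7a1b-4c2d-9e3f-0123456789ab",
"metadata": {},
"source": [
"Before we get to CNNs, here is a minimal sketch of the idea: `vmap` turns a\n",
"function written for a single example into one that works on a whole batch,\n",
"replacing an explicit for-loop. (`square_then_sum` below is a toy function we\n",
"made up for illustration; it is not part of the ensembling workflow.)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2a3b4c5d-8b2c-4d3e-9f4a-0123456789ab",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from functorch import vmap\n",
"\n",
"def square_then_sum(x):  # written for a single vector\n",
"    return (x * x).sum()\n",
"\n",
"batch = torch.randn(4, 3)\n",
"# explicit for-loop over the leading dimension\n",
"looped = torch.stack([square_then_sum(row) for row in batch])\n",
"# vmap maps square_then_sum over dim 0 of `batch`\n",
"vmapped = vmap(square_then_sum)(batch)\n",
"assert torch.allclose(looped, vmapped)"
]
},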
{
"cell_type": "code",
"execution_count": 1,
"id": "38a958e7-47f1-463b-8ade-e55fc9832bf1",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from functools import partial\n",
"torch.manual_seed(0)\n",
"\n",
"# Here's a simple CNN\n",
"class SimpleCNN(nn.Module):\n",
" def __init__(self):\n",
" super(SimpleCNN, self).__init__()\n",
" self.conv1 = nn.Conv2d(1, 32, 3, 1)\n",
" self.conv2 = nn.Conv2d(32, 64, 3, 1)\n",
" self.fc1 = nn.Linear(9216, 128)\n",
" self.fc2 = nn.Linear(128, 10)\n",
"\n",
" def forward(self, x):\n",
" x = self.conv1(x)\n",
" x = F.relu(x)\n",
" x = self.conv2(x)\n",
" x = F.relu(x)\n",
" x = F.max_pool2d(x, 2)\n",
" x = torch.flatten(x, 1)\n",
" x = self.fc1(x)\n",
" x = F.relu(x)\n",
" x = self.fc2(x)\n",
" output = F.log_softmax(x, dim=1)\n",
" output = x\n",
" return output"
]
},
{
"cell_type": "markdown",
"id": "b143b79b-9762-4198-b1f6-7c71d5a7df9f",
"metadata": {},
"source": [
"Let's generate some dummy data. Pretend that we're working with an MNIST dataset\n",
"where the images are 28 by 28.\n",
"Furthermore, let's say we wish to combine the predictions from 10 different\n",
"models."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "cf025d68-00b0-4de4-88ca-d010c965ee99",
"metadata": {},
"outputs": [],
"source": [
"device = 'cuda'\n",
"num_models = 10\n",
"data = torch.randn(100, 64, 1, 28, 28, device=device)\n",
"targets = torch.randint(10, (6400,), device=device)\n",
"models = [SimpleCNN().to(device) for _ in range(num_models)]"
]
},
{
"cell_type": "markdown",
"id": "82923d20-b20e-482e-99a5-0af6ce2aa8d8",
"metadata": {},
"source": [
"We have a couple of options for generating predictions. Maybe we want\n",
"to give each model a different randomized minibatch of data, or maybe we\n",
"want to run the same minibatch of data through each model (e.g. if we were\n",
"testing the effect of different model initializations)."
]
},
{
"cell_type": "markdown",
"id": "2a16258d-5ab1-4c99-a741-88e4c77bb8c1",
"metadata": {},
"source": [
"Option 1: different minibatch for each model"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "f25c6f7d-4992-4b7e-a64a-52c101222a29",
"metadata": {},
"outputs": [],
"source": [
"minibatches = data[:num_models]\n",
"predictions1 = [model(minibatch) for model, minibatch in zip(models, minibatches)]"
]
},
{
"cell_type": "markdown",
"id": "35739f27-7d3b-4fe7-92ee-040d510d3656",
"metadata": {},
"source": [
"Option 2: Same minibatch"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "f1ac17df-9304-4398-8045-c68afea8eb71",
"metadata": {},
"outputs": [],
"source": [
"minibatch = data[0]\n",
"predictions2 = [model(minibatch) for model in models]"
]
},
{
"cell_type": "markdown",
"id": "d2b978b6-86b9-4c5f-a8c4-a9cdb490ff49",
"metadata": {},
"source": [
"## Using vmap to vectorize the ensemble\n",
"\n",
"Let's use `vmap` to speed up the for-loop. We must first prepare the models\n",
"for use with `vmap`.\n",
"\n",
"First, let's combine the states of the model together by stacking each parameter.\n",
"For example, `model[i].fc1.weight` has shape `[9216, 128]`; we are going to stack the\n",
"`.fc1.weight` of each of the 10 models to produce a big weight of shape `[10, 9216, 128]`.\n",
"\n",
"functorch offers the following convenience function to do that. It returns a\n",
"stateless version of the model (fmodel) and stacked parameters and buffers."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "e0f7476f-bb13-4394-a087-9f33fbe9b41f",
"metadata": {},
"outputs": [],
"source": [
"from functorch import combine_state_for_ensemble\n",
"fmodel, params, buffers = combine_state_for_ensemble(models)\n",
"[p.requires_grad_() for p in params]\n",
"pass"
]
},
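{
"cell_type": "markdown",
"id": "3c4d5e6f-9c3d-4e4f-8a5b-0123456789ab",
"metadata": {},
"source": [
"As a quick sanity check, the stacking is conceptually what you would get by\n",
"calling `torch.stack` on each parameter by hand. This is a rough sketch, not\n",
"how functorch implements it; to avoid assuming anything about the ordering\n",
"of `params`, we only compare shapes."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4d5e6f7a-0d4e-4f5a-8b6c-0123456789ab",
"metadata": {},
"outputs": [],
"source": [
"# manually stack fc1's weight across the 10 models\n",
"manual_fc1_weight = torch.stack([m.fc1.weight for m in models])\n",
"print(manual_fc1_weight.shape)  # torch.Size([10, 128, 9216])\n",
"# the same stacked shape should show up among the combined params\n",
"assert any(p.shape == manual_fc1_weight.shape for p in params)"
]
},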
{
"cell_type": "markdown",
"id": "9cdd7f97-f583-4e6f-a756-7378251d7bed",
"metadata": {},
"source": [
"Option 1: get predictions using a different minibatch for each model.\n",
"By default, vmap maps a function across the first dimension of all inputs to the\n",
"passed-in function. After `combine_state_for_ensemble`, each of of `params`,\n",
"`buffers` have an additional dimension of size `num_models` at the front;\n",
"and `minibatches` has a dimension of size `num_models`."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "bbece999-78bc-43d3-89e3-50cb961fbe70",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[10, 10, 10, 10, 10, 10, 10, 10]\n"
]
}
],
"source": [
"print([p.size(0) for p in params])\n",
"assert minibatches.shape == (num_models, 64, 1, 28, 28)\n",
"from functorch import vmap\n",
"predictions1_vmap = vmap(fmodel)(params, buffers, minibatches)\n",
"assert torch.allclose(predictions1_vmap, torch.stack(predictions1), atol=1e-3, rtol=1e-5)"
]
},
{
"cell_type": "markdown",
"id": "095eb8a9-6666-44df-a85f-66bb481d3460",
"metadata": {},
"source": [
"Option 2: get predictions using the same minibatch of data\n",
"vmap has an in_dims arg that specify which dimensions to map over.\n",
"Using `None`, we tell vmap we want the same minibatch to apply for all of\n",
"the 10 models."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "26befbfc-021a-4d77-bfff-a980257dbd08",
"metadata": {},
"outputs": [],
"source": [
"predictions2_vmap = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)\n",
"assert torch.allclose(predictions2_vmap, torch.stack(predictions2), atol=1e-3, rtol=1e-5)"
]
},
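{
"cell_type": "markdown",
"id": "5e6f7a8b-1e5f-4a6b-8c7d-0123456789ab",
"metadata": {},
"source": [
"With the stacked predictions in hand, the final \"combine\" step of ensembling\n",
"is up to you. One common choice, shown here purely as an illustration, is to\n",
"average the models' outputs over the model dimension."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6f7a8b9c-2f6a-4b7c-8d8e-0123456789ab",
"metadata": {},
"outputs": [],
"source": [
"# predictions2_vmap has shape [num_models, batch_size, num_classes];\n",
"# averaging over dim 0 gives one ensembled prediction per example\n",
"ensemble_prediction = predictions2_vmap.mean(dim=0)\n",
"assert ensemble_prediction.shape == (64, 10)"
]
},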
{
"cell_type": "markdown",
"id": "5572d906-d0b5-4d1d-b601-5ee4922422e6",
"metadata": {},
"source": [
"A quick note: there are limitations around what types of functions can be\n",
"transformed by vmap. The best functions to transform are ones that are\n",
"pure functions: a function where the outputs are only determined by the inputs\n",
"that have no side effects (e.g. mutation). vmap is unable to handle mutation of\n",
"arbitrary Python data structures, but it is able to handle many in-place\n",
"PyTorch operations."
]
},
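{
"cell_type": "markdown",
"id": "7a8b9c0d-3a7b-4c8d-8e9f-9876543210ab",
"metadata": {},
"source": [
"For example, an in-place PyTorch op on a tensor created inside the function\n",
"is fine under `vmap`. (`add_one_inplace` is a toy function made up for this\n",
"illustration.)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b9c0d1e-4b8c-4d9e-8f0a-9876543210ab",
"metadata": {},
"outputs": [],
"source": [
"def add_one_inplace(x):\n",
"    y = x.clone()\n",
"    y.add_(1.0)  # in-place PyTorch op: vmap can handle this\n",
"    return y\n",
"\n",
"out = vmap(add_one_inplace)(torch.randn(3, 5))\n",
"assert out.shape == (3, 5)"
]
},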
{
"cell_type": "markdown",
"id": "c250b753-427b-49ad-a1ef-9764abffb36f",
"metadata": {},
"source": [
"## Performance\n",
"\n",
"Curious about performance numbers? Here's how the numbers look on my machine with an A100 GPU:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "fef2b3bc-230a-436c-bb0f-29084bce80a1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Predictions without vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7f3c8c061e80>\n",
"[model(minibatch) for model, minibatch in zip(models, minibatches)]\n",
" 5.91 ms\n",
" 1 measurement, 500 runs , 1 thread\n",
"Predictions with vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7f3c860b5520>\n",
"vmap(fmodel)(params, buffers, minibatches)\n",
" 2.07 ms\n",
" 1 measurement, 500 runs , 1 thread\n"
]
}
],
"source": [
"from torch.utils.benchmark import Timer\n",
"without_vmap = Timer(\n",
" stmt=\"[model(minibatch) for model, minibatch in zip(models, minibatches)]\",\n",
" globals=globals())\n",
"with_vmap = Timer(\n",
" stmt=\"vmap(fmodel)(params, buffers, minibatches)\",\n",
" globals=globals())\n",
"print(f'Predictions without vmap {without_vmap.timeit(500)}')\n",
"print(f'Predictions with vmap {with_vmap.timeit(500)}')"
]
},
{
"cell_type": "markdown",
"id": "3361a73b-a0e9-41b3-ba60-a54d9b8efa44",
"metadata": {},
"source": [
"In general, vectorization with vmap should be faster than running a function in a for-loop. There are some exceptions though, like if we haven't implemented the vmap rule for a particular operation or if the underlying kernels weren't optimized for older hardware. If you see any of these cases, please let us know by opening an issue at our GitHub!"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}