| Patching Batch Norm | 
 | =================== | 
 |  | 
 | What's happening? | 
 | ----------------- | 
Batch Norm requires in-place updates to ``running_mean`` and ``running_var`` using statistics
computed from the input. functorch does not support an in-place update to a regular tensor that
takes in a batched tensor (i.e. ``regular.add_(batched)`` is not allowed). So when vmapping over a
batch of inputs to a single module, we end up with an error.
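
A minimal sketch that typically triggers the error, assuming a toy ``BatchNorm1d`` module in
training mode (the exact message varies across versions):

.. code-block:: python

    import torch
    from torch.func import vmap

    net = torch.nn.BatchNorm1d(3)          # tracks running stats by default
    x = torch.randn(5, 2, 3)               # 5 vmapped examples, each of shape (2, 3)

    # Each vmap slice tries to update running_mean/running_var in place with batched
    # statistics (the disallowed regular.add_(batched) pattern), which raises an error.
    vmap(net)(x)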
 |  | 
 | How to fix | 
 | ---------- | 
One of the best supported ways is to switch BatchNorm for GroupNorm. Options 1 and 2 support this.
 |  | 
Options 1 through 3 assume that you don't need running stats. If you're using a module, this means
it's assumed you won't use batch norm in evaluation mode. If you have a use case that involves
running batch norm with vmap in evaluation mode, see Option 4 below.
 |  | 
 | Option 1: Change the BatchNorm | 
 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | 
If you want to switch to GroupNorm, anywhere that you have BatchNorm, replace it with:
 |  | 
 | .. code-block:: python | 
 |  | 
    GroupNorm(G, C)
 |  | 
Here ``C`` is the same ``C`` as in the original BatchNorm. ``G`` is the number of groups to
break ``C`` into. As such, you need ``C % G == 0``; as a fallback, you can set ``G = C``, meaning
each channel will be treated separately.
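
For example, a hypothetical ``BatchNorm2d(64)`` layer could become:

.. code-block:: python

    from torch.nn import GroupNorm

    # 64 channels split into 32 groups (64 % 32 == 0); GroupNorm keeps no running stats
    GroupNorm(32, 64)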
 |  | 
If you must use BatchNorm and you've built the module yourself, you can change the module to
not use running stats. In other words, anywhere that there's a BatchNorm module, set the
``track_running_stats`` flag to ``False``:
 |  | 
 | .. code-block:: python | 
 |  | 
 |     BatchNorm2d(64, track_running_stats=False) | 
 |  | 
 |  | 
 | Option 2: torchvision parameter | 
 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | 
Some torchvision models, like resnet and regnet, accept a ``norm_layer`` parameter, which
often defaults to ``BatchNorm2d``.
 |  | 
Instead, you can set it to GroupNorm:
 |  | 
 | .. code-block:: python | 
 |  | 
    import torchvision
    from torch.nn import GroupNorm

    # g is the number of groups you choose; it must divide the channel count c
    torchvision.models.resnet18(norm_layer=lambda c: GroupNorm(num_groups=g, num_channels=c))
 |  | 
 | Here, once again, ``c % g == 0`` so as a fallback, set ``g = c``. | 
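
A minimal sketch of that fallback, where every channel gets its own group:

.. code-block:: python

    import torchvision
    from torch.nn import GroupNorm

    # g = c: each channel is normalized on its own
    torchvision.models.resnet18(norm_layer=lambda c: GroupNorm(num_groups=c, num_channels=c))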
 |  | 
If you are attached to BatchNorm, be sure to use a version that doesn't use running stats:
 |  | 
 | .. code-block:: python | 
 |  | 
    import torchvision
    from functools import partial
    from torch.nn import BatchNorm2d

    torchvision.models.resnet18(norm_layer=partial(BatchNorm2d, track_running_stats=False))
 |  | 
 | Option 3: functorch's patching | 
 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | 
functorch has added some functionality to allow for quick, in-place patching of a module so that
it does not use running stats. Changing the norm layer is more fragile, so we have not offered
that. If you have a net where you want the BatchNorm to not use running stats, you can run
``replace_all_batch_norm_modules_`` to update the module in-place:
 |  | 
 | .. code-block:: python | 
 |  | 
 |     from torch.func import replace_all_batch_norm_modules_ | 
 |     replace_all_batch_norm_modules_(net) | 
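
For example, patching a torchvision resnet and then vmapping over an extra batch dimension
might look like this (a sketch; the patch only removes the running-stat updates, the rest of
the model is unchanged):

.. code-block:: python

    import torch
    import torchvision
    from torch.func import replace_all_batch_norm_modules_, vmap

    net = torchvision.models.resnet18()
    replace_all_batch_norm_modules_(net)

    x = torch.randn(4, 1, 3, 224, 224)   # 4 vmapped examples, each a batch of 1 image
    out = vmap(net)(x)                    # no in-place running-stat updates remain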
 |  | 
 | Option 4: eval mode | 
 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | 
When run in eval mode, ``running_mean`` and ``running_var`` will not be updated in-place.
Therefore, vmap can support this mode:
 |  | 
 | .. code-block:: python | 
 |  | 
    from torch.func import vmap

    model.eval()
    vmap(model)(x)
    model.train()