[functorch] Add dp_cifar10 differential-privacy examples and supporting files
diff --git a/functorch/.gitignore b/functorch/.gitignore
new file mode 100644
index 0000000..e38e12d
--- /dev/null
+++ b/functorch/.gitignore
@@ -0,0 +1,4 @@
+build/
+dist/
+functorch.egg-info/
+*__pycache__*
diff --git a/functorch/examples/.gitignore b/functorch/examples/.gitignore
new file mode 100644
index 0000000..ca86b70
--- /dev/null
+++ b/functorch/examples/.gitignore
@@ -0,0 +1 @@
+cifar10/
diff --git a/functorch/examples/dp_cifar10/.gdbinit b/functorch/examples/dp_cifar10/.gdbinit
new file mode 100644
index 0000000..d11fc06
--- /dev/null
+++ b/functorch/examples/dp_cifar10/.gdbinit
@@ -0,0 +1,2 @@
+catch throw
+r cifar10_transforms.py
diff --git a/functorch/examples/dp_cifar10/cifar10_expandweights.py b/functorch/examples/dp_cifar10/cifar10_expandweights.py
new file mode 100644
index 0000000..2a33e49
--- /dev/null
+++ b/functorch/examples/dp_cifar10/cifar10_expandweights.py
@@ -0,0 +1,436 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+"""
+Runs CIFAR10 training with differential privacy.
+"""
+
+import argparse
+import os
+import shutil
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.utils.data
+import torch.utils.data.distributed
+import torch.utils.tensorboard as tensorboard
+import torchvision.models as models
+import torchvision.transforms as transforms
+from opacus import PrivacyEngine
+from opacus.utils import stats
+from opacus.utils.module_modification import convert_batchnorm_modules
+from torchvision.datasets import CIFAR10
+from tqdm import tqdm
+
+# This is based on the zou3519/pytorch:expand_weights branch.
+
+def save_checkpoint(state, is_best, filename="checkpoint.tar"):
+    torch.save(state, filename)
+    if is_best:
+        shutil.copyfile(filename, "model_best.pth.tar")
+
+
+def accuracy(preds, labels):
+    return (preds == labels).mean()
+
+
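+# Given per-sample gradients (one tensor per parameter, each with a leading
+# batch dimension), compute_norms returns a vector of length batch_size holding
+# each sample's total gradient norm taken across all parameters.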
+def compute_norms(sample_grads):
+    batch_size = sample_grads[0].shape[0]
+    norms = [sample_grad.view(batch_size, -1).norm(2, dim=-1) for sample_grad in sample_grads]
+    norms = torch.stack(norms, dim=0).norm(2, dim=0)
+    return norms
+
+def clip_and_accumulate_and_add_noise(model, max_per_sample_grad_norm=1.0, noise_multiplier=1.0):
+    sample_grads = tuple(param.grad_sample for param in model.parameters())
+
+    # step 0: compute the norms
+    sample_norms = compute_norms(sample_grads)
+
+    # step 1: compute clipping factors
+    clip_factor = max_per_sample_grad_norm / (sample_norms + 1e-6)
+    clip_factor = clip_factor.clamp(max=1.0)
+
+    # step 2: clip
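+    # (`einsum('i,i...')` scales each per-sample gradient by its clip factor and
+    # sums over the batch dimension, yielding one aggregated, clipped gradient
+    # per parameter)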
+    grads = tuple(torch.einsum('i,i...', clip_factor, sample_grad)
+                  for sample_grad in sample_grads)
+
+    # step 3: add gaussian noise
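+    # (noise is added once to the summed, clipped gradients, with standard
+    # deviation max_per_sample_grad_norm * noise_multiplier, as in DP-SGD)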
+    stddev = max_per_sample_grad_norm * noise_multiplier
+    noises = tuple(torch.normal(0, stddev, grad_param.shape, device=grad_param.device)
+                   for grad_param in grads)
+    grads = tuple(noise + grad_param for noise, grad_param in zip(noises, grads))
+
+    # step 4: assign the new grads, delete the sample grads
+    for param, param_grad in zip(model.parameters(), grads):
+        param.grad = param_grad
+        del param.grad_sample
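+
+# Usage sketch: once a backward pass has populated `param.grad_sample` on every
+# parameter,
+#
+#     clip_and_accumulate_and_add_noise(model)  # fills .grad, clears .grad_sample
+#     optimizer.step()                          # consumes .grad as usual
+#
+# the optimizer then sees ordinary `.grad` tensors, just as in non-private training.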
+
+def train(args, model, train_loader, optimizer, epoch, device):
+    model.train()
+    criterion = nn.CrossEntropyLoss()
+
+    losses = []
+    top1_acc = []
+
+    for i, (images, target) in enumerate(tqdm(train_loader)):
+
+        images = images.to(device)
+        target = target.to(device)
+
+        # Step 1: compute per-sample-grads (provided by pytorch)
+        # loss.backward() populates the grad_sample attribute of each param
+        with model.compute_per_sample_grads(batch_size=images.shape[0]):
+            output = model(images)
+            loss = criterion(output, target)
+            loss.backward()
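+            # (each param.grad_sample should now hold per-sample gradients with
+            # a leading batch dimension, i.e. shape (batch_size, *param.shape))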
+
+        # Step 2: Clip the per-sample-grads, sum them to form grads, and add noise
+        # Opacus implements this but I wrote a custom one to show how this would work.
+        # This deletes the grad_sample attributes and populates the grad attributes
+        clip_and_accumulate_and_add_noise(
+            model, args.max_per_sample_grad_norm, args.sigma)
+
+        preds = np.argmax(output.detach().cpu().numpy(), axis=1)
+        labels = target.detach().cpu().numpy()
+        losses.append(loss.item())
+
+        # measure accuracy and record loss
+        acc1 = accuracy(preds, labels)
+
+        top1_acc.append(acc1)
+        stats.update(stats.StatType.TRAIN, acc1=acc1)
+
+        # make sure we take a step after processing the last mini-batch in the
+        # epoch to ensure we start the next epoch with a clean state
+        if ((i + 1) % args.n_accumulation_steps == 0) or ((i + 1) == len(train_loader)):
+            optimizer.step()
+            optimizer.zero_grad()
+        else:
+            optimizer.virtual_step()
+
+        if i % args.print_freq == 0:
+            print(
+                f"\tTrain Epoch: {epoch} \t"
+                f"Loss: {np.mean(losses):.6f} "
+                f"Acc@1: {np.mean(top1_acc):.6f} "
+            )
+
+def test(args, model, test_loader, device):
+    model.eval()
+    criterion = nn.CrossEntropyLoss()
+    losses = []
+    top1_acc = []
+
+    with torch.no_grad():
+        for images, target in tqdm(test_loader):
+            images = images.to(device)
+            target = target.to(device)
+
+            output = model(images)
+            loss = criterion(output, target)
+            preds = np.argmax(output.detach().cpu().numpy(), axis=1)
+            labels = target.detach().cpu().numpy()
+            acc1 = accuracy(preds, labels)
+
+            losses.append(loss.item())
+            top1_acc.append(acc1)
+
+    top1_avg = np.mean(top1_acc)
+    stats.update(stats.StatType.TEST, acc1=top1_avg)
+
+    print(f"\tTest set:" f"Loss: {np.mean(losses):.6f} " f"Acc@1: {top1_avg :.6f} ")
+    return np.mean(top1_acc)
+
+
+def main():
+    parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
+    parser.add_argument(
+        "-j",
+        "--workers",
+        default=2,
+        type=int,
+        metavar="N",
+        help="number of data loading workers (default: 2)",
+    )
+    parser.add_argument(
+        "--epochs",
+        default=90,
+        type=int,
+        metavar="N",
+        help="number of total epochs to run",
+    )
+    parser.add_argument(
+        "--start-epoch",
+        default=1,
+        type=int,
+        metavar="N",
+        help="manual epoch number (useful on restarts)",
+    )
+    parser.add_argument(
+        "-b",
+        "--batch-size",
+        # This should be 256, but that OOMs using the prototype.
+        default=64,
+        type=int,
+        metavar="N",
+        help="mini-batch size (default: 64), this is the total "
+        "batch size of all GPUs on the current node when "
+        "using Data Parallel or Distributed Data Parallel",
+    )
+    parser.add_argument(
+        "-na",
+        "--n_accumulation_steps",
+        default=1,
+        type=int,
+        metavar="N",
+        help="number of mini-batches to accumulate into an effective batch",
+    )
+    parser.add_argument(
+        "--lr",
+        "--learning-rate",
+        default=0.001,
+        type=float,
+        metavar="LR",
+        help="initial learning rate",
+        dest="lr",
+    )
+    parser.add_argument(
+        "--momentum", default=0.9, type=float, metavar="M", help="SGD momentum"
+    )
+    parser.add_argument(
+        "--wd",
+        "--weight-decay",
+        default=5e-4,
+        type=float,
+        metavar="W",
+        help="SGD weight decay (default: 1e-4)",
+        dest="weight_decay",
+    )
+    parser.add_argument(
+        "-p",
+        "--print-freq",
+        default=10,
+        type=int,
+        metavar="N",
+        help="print frequency (default: 10)",
+    )
+    parser.add_argument(
+        "--resume",
+        default="",
+        type=str,
+        metavar="PATH",
+        help="path to latest checkpoint (default: none)",
+    )
+    parser.add_argument(
+        "-e",
+        "--evaluate",
+        dest="evaluate",
+        action="store_true",
+        help="evaluate model on validation set",
+    )
+    parser.add_argument(
+        "--seed", default=None, type=int, help="seed for initializing training. "
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda",
+        help="GPU ID for this process (default: 'cuda')",
+    )
+    parser.add_argument(
+        "--sigma",
+        type=float,
+        default=1.0,
+        metavar="S",
+        help="Noise multiplier (default 1.0)",
+    )
+    parser.add_argument(
+        "-c",
+        "--max-per-sample-grad_norm",
+        type=float,
+        default=1.0,
+        metavar="C",
+        help="Clip per-sample gradients to this norm (default 1.0)",
+    )
+    parser.add_argument(
+        "--disable-dp",
+        action="store_true",
+        default=False,
+        help="Disable privacy training and just train with vanilla SGD",
+    )
+    parser.add_argument(
+        "--secure-rng",
+        action="store_true",
+        default=False,
+        help="Enable Secure RNG to have trustworthy privacy guarantees. Comes at a performance cost",
+    )
+    parser.add_argument(
+        "--delta",
+        type=float,
+        default=1e-5,
+        metavar="D",
+        help="Target delta (default: 1e-5)",
+    )
+
+    parser.add_argument(
+        "--checkpoint-file",
+        type=str,
+        default="checkpoint",
+        help="path to save check points",
+    )
+    parser.add_argument(
+        "--data-root",
+        type=str,
+        default="../cifar10",
+        help="Where CIFAR10 is/will be stored",
+    )
+    parser.add_argument(
+        "--log-dir", type=str, default="", help="Where Tensorboard log will be stored"
+    )
+    parser.add_argument(
+        "--optim",
+        type=str,
+        default="Adam",
+        help="Optimizer to use (Adam, RMSprop, SGD)",
+    )
+
+    args = parser.parse_args()
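+    # NB: the Opacus PrivacyEngine is force-disabled below; per-sample gradient
+    # clipping and noise addition are handled manually in train() instead.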
+    args.disable_dp = True
+
+    if args.disable_dp and args.n_accumulation_steps > 1:
+        raise ValueError("Virtual steps only works with enabled DP")
+
+    # The following few lines enable stats gathering about the run:
+    # 1. where the stats should be logged
+    stats.set_global_summary_writer(
+        tensorboard.SummaryWriter(os.path.join("/tmp/stat", args.log_dir))
+    )
+    # 2. enable stats
+    stats.add(
+        # stats about gradient norms aggregated for all layers
+        stats.Stat(stats.StatType.GRAD, "AllLayers", frequency=0.1),
+        # stats about gradient norms per layer
+        stats.Stat(stats.StatType.GRAD, "PerLayer", frequency=0.1),
+        # stats about clipping
+        stats.Stat(stats.StatType.GRAD, "ClippingStats", frequency=0.1),
+        # stats on training accuracy
+        stats.Stat(stats.StatType.TRAIN, "accuracy", frequency=0.01),
+        # stats on validation accuracy
+        stats.Stat(stats.StatType.TEST, "accuracy"),
+    )
+
+    # The following lines enable stat gathering for the clipping process
+    # and default to flat (not per-layer) clipping for the Privacy Engine
+    clipping = {"clip_per_layer": False, "enable_stat": True}
+
+    if args.secure_rng:
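+        # (the assert below makes the secure-RNG path unreachable in this example)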
+        assert False
+        try:
+            import torchcsprng as prng
+        except ImportError as e:
+            msg = (
+                "To use secure RNG, you must install the torchcsprng package! "
+                "Check out the instructions here: https://github.com/pytorch/csprng#installation"
+            )
+            raise ImportError(msg) from e
+
+        generator = prng.create_random_device_generator("/dev/urandom")
+
+    else:
+        generator = None
+
+    augmentations = [
+        transforms.RandomCrop(32, padding=4),
+        transforms.RandomHorizontalFlip(),
+    ]
+    normalize = [
+        transforms.ToTensor(),
+        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+    ]
+    train_transform = transforms.Compose(
+        augmentations + normalize if args.disable_dp else normalize
+    )
+
+    test_transform = transforms.Compose(normalize)
+
+    train_dataset = CIFAR10(
+        root=args.data_root, train=True, download=True, transform=train_transform
+    )
+
+    train_loader = torch.utils.data.DataLoader(
+        train_dataset,
+        batch_size=args.batch_size,
+        shuffle=True,
+        num_workers=args.workers,
+        drop_last=True,
+        generator=generator,
+    )
+
+    test_dataset = CIFAR10(
+        root=args.data_root, train=False, download=True, transform=test_transform
+    )
+    test_loader = torch.utils.data.DataLoader(
+        test_dataset,
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.workers,
+    )
+
+    best_acc1 = 0
+    device = torch.device(args.device)
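+    # BatchNorm mixes statistics across samples, which is incompatible with
+    # per-sample gradients; convert_batchnorm_modules replaces BatchNorm layers
+    # (with GroupNorm, Opacus's default replacement).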
+    model = convert_batchnorm_modules(models.resnet18(num_classes=10))
+    # model = CIFAR10Model()
+    model = model.to(device)
+
+    if args.optim == "SGD":
+        optimizer = optim.SGD(
+            model.parameters(),
+            lr=args.lr,
+            momentum=args.momentum,
+            weight_decay=args.weight_decay,
+        )
+    elif args.optim == "RMSprop":
+        optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
+    elif args.optim == "Adam":
+        optimizer = optim.Adam(model.parameters(), lr=args.lr)
+    else:
+        raise NotImplementedError("Optimizer not recognized. Please check spelling")
+
+    if not args.disable_dp:
+        privacy_engine = PrivacyEngine(
+            model,
+            batch_size=args.batch_size * args.n_accumulation_steps,
+            sample_size=len(train_dataset),
+            alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
+            noise_multiplier=args.sigma,
+            max_grad_norm=args.max_per_sample_grad_norm,
+            secure_rng=args.secure_rng,
+            **clipping,
+        )
+        privacy_engine.attach(optimizer)
+
+    for epoch in range(args.start_epoch, args.epochs + 1):
+        train(args, model, train_loader, optimizer, epoch, device)
+        top1_acc = test(args, model, test_loader, device)
+
+        # remember best acc@1 and save checkpoint
+        is_best = top1_acc > best_acc1
+        best_acc1 = max(top1_acc, best_acc1)
+
+        save_checkpoint(
+            {
+                "epoch": epoch + 1,
+                "arch": "ResNet18",
+                "state_dict": model.state_dict(),
+                "best_acc1": best_acc1,
+                "optimizer": optimizer.state_dict(),
+            },
+            is_best,
+            filename=args.checkpoint_file + ".tar",
+        )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/functorch/examples/dp_cifar10/cifar10_opacus.py b/functorch/examples/dp_cifar10/cifar10_opacus.py
new file mode 100644
index 0000000..43e917c
--- /dev/null
+++ b/functorch/examples/dp_cifar10/cifar10_opacus.py
@@ -0,0 +1,405 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+"""
+Runs CIFAR10 training with differential privacy.
+"""
+
+import argparse
+import os
+import shutil
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.utils.data
+import torch.utils.data.distributed
+import torch.utils.tensorboard as tensorboard
+import torchvision.models as models
+import torchvision.transforms as transforms
+from opacus import PrivacyEngine
+from opacus.utils import stats
+from opacus.utils.module_modification import convert_batchnorm_modules
+from torchvision.datasets import CIFAR10
+from tqdm import tqdm
+
+
+def save_checkpoint(state, is_best, filename="checkpoint.tar"):
+    torch.save(state, filename)
+    if is_best:
+        shutil.copyfile(filename, "model_best.pth.tar")
+
+
+def accuracy(preds, labels):
+    return (preds == labels).mean()
+
+
+def train(args, model, train_loader, optimizer, epoch, device):
+    model.train()
+    criterion = nn.CrossEntropyLoss()
+
+    losses = []
+    top1_acc = []
+
+    for i, (images, target) in enumerate(tqdm(train_loader)):
+
+        images = images.to(device)
+        target = target.to(device)
+
+        # compute output
+        output = model(images)
+        loss = criterion(output, target)
+        preds = np.argmax(output.detach().cpu().numpy(), axis=1)
+        labels = target.detach().cpu().numpy()
+
+        # measure accuracy and record loss
+        acc1 = accuracy(preds, labels)
+
+        losses.append(loss.item())
+        top1_acc.append(acc1)
+        stats.update(stats.StatType.TRAIN, acc1=acc1)
+
+        # compute gradient and do SGD step
+        loss.backward()
+
+        # make sure we take a step after processing the last mini-batch in the
+        # epoch to ensure we start the next epoch with a clean state
+        if ((i + 1) % args.n_accumulation_steps == 0) or ((i + 1) == len(train_loader)):
+            optimizer.step()
+            optimizer.zero_grad()
+        else:
+            optimizer.virtual_step()
+
+        if i % args.print_freq == 0:
+            if not args.disable_dp:
+                epsilon, best_alpha = optimizer.privacy_engine.get_privacy_spent(
+                    args.delta
+                )
+                print(
+                    f"\tTrain Epoch: {epoch} \t"
+                    f"Loss: {np.mean(losses):.6f} "
+                    f"Acc@1: {np.mean(top1_acc):.6f} "
+                    f"(ε = {epsilon:.2f}, δ = {args.delta}) for α = {best_alpha}"
+                )
+            else:
+                print(
+                    f"\tTrain Epoch: {epoch} \t"
+                    f"Loss: {np.mean(losses):.6f} "
+                    f"Acc@1: {np.mean(top1_acc):.6f} "
+                )
+
+
+def test(args, model, test_loader, device):
+    model.eval()
+    criterion = nn.CrossEntropyLoss()
+    losses = []
+    top1_acc = []
+
+    with torch.no_grad():
+        for images, target in tqdm(test_loader):
+            images = images.to(device)
+            target = target.to(device)
+
+            output = model(images)
+            loss = criterion(output, target)
+            preds = np.argmax(output.detach().cpu().numpy(), axis=1)
+            labels = target.detach().cpu().numpy()
+            acc1 = accuracy(preds, labels)
+
+            losses.append(loss.item())
+            top1_acc.append(acc1)
+
+    top1_avg = np.mean(top1_acc)
+    stats.update(stats.StatType.TEST, acc1=top1_avg)
+
+    print(f"\tTest set:" f"Loss: {np.mean(losses):.6f} " f"Acc@1: {top1_avg :.6f} ")
+    return np.mean(top1_acc)
+
+
+def main():
+    parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
+    parser.add_argument(
+        "-j",
+        "--workers",
+        default=2,
+        type=int,
+        metavar="N",
+        help="number of data loading workers (default: 2)",
+    )
+    parser.add_argument(
+        "--epochs",
+        default=90,
+        type=int,
+        metavar="N",
+        help="number of total epochs to run",
+    )
+    parser.add_argument(
+        "--start-epoch",
+        default=1,
+        type=int,
+        metavar="N",
+        help="manual epoch number (useful on restarts)",
+    )
+    parser.add_argument(
+        "-b",
+        "--batch-size",
+        default=256,
+        type=int,
+        metavar="N",
+        help="mini-batch size (default: 256), this is the total "
+        "batch size of all GPUs on the current node when "
+        "using Data Parallel or Distributed Data Parallel",
+    )
+    parser.add_argument(
+        "-na",
+        "--n_accumulation_steps",
+        default=1,
+        type=int,
+        metavar="N",
+        help="number of mini-batches to accumulate into an effective batch",
+    )
+    parser.add_argument(
+        "--lr",
+        "--learning-rate",
+        default=0.001,
+        type=float,
+        metavar="LR",
+        help="initial learning rate",
+        dest="lr",
+    )
+    parser.add_argument(
+        "--momentum", default=0.9, type=float, metavar="M", help="SGD momentum"
+    )
+    parser.add_argument(
+        "--wd",
+        "--weight-decay",
+        default=5e-4,
+        type=float,
+        metavar="W",
+        help="SGD weight decay (default: 1e-4)",
+        dest="weight_decay",
+    )
+    parser.add_argument(
+        "-p",
+        "--print-freq",
+        default=10,
+        type=int,
+        metavar="N",
+        help="print frequency (default: 10)",
+    )
+    parser.add_argument(
+        "--resume",
+        default="",
+        type=str,
+        metavar="PATH",
+        help="path to latest checkpoint (default: none)",
+    )
+    parser.add_argument(
+        "-e",
+        "--evaluate",
+        dest="evaluate",
+        action="store_true",
+        help="evaluate model on validation set",
+    )
+    parser.add_argument(
+        "--seed", default=None, type=int, help="seed for initializing training. "
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda",
+        help="GPU ID for this process (default: 'cuda')",
+    )
+    parser.add_argument(
+        "--sigma",
+        type=float,
+        default=1.0,
+        metavar="S",
+        help="Noise multiplier (default 1.0)",
+    )
+    parser.add_argument(
+        "-c",
+        "--max-per-sample-grad_norm",
+        type=float,
+        default=1.0,
+        metavar="C",
+        help="Clip per-sample gradients to this norm (default 1.0)",
+    )
+    parser.add_argument(
+        "--disable-dp",
+        action="store_true",
+        default=False,
+        help="Disable privacy training and just train with vanilla SGD",
+    )
+    parser.add_argument(
+        "--secure-rng",
+        action="store_true",
+        default=False,
+        help="Enable Secure RNG to have trustworthy privacy guarantees. Comes at a performance cost",
+    )
+    parser.add_argument(
+        "--delta",
+        type=float,
+        default=1e-5,
+        metavar="D",
+        help="Target delta (default: 1e-5)",
+    )
+
+    parser.add_argument(
+        "--checkpoint-file",
+        type=str,
+        default="checkpoint",
+        help="path to save check points",
+    )
+    parser.add_argument(
+        "--data-root",
+        type=str,
+        default="../cifar10",
+        help="Where CIFAR10 is/will be stored",
+    )
+    parser.add_argument(
+        "--log-dir", type=str, default="", help="Where Tensorboard log will be stored"
+    )
+    parser.add_argument(
+        "--optim",
+        type=str,
+        default="Adam",
+        help="Optimizer to use (Adam, RMSprop, SGD)",
+    )
+
+    args = parser.parse_args()
+
+    if args.disable_dp and args.n_accumulation_steps > 1:
+        raise ValueError("Virtual steps only works with enabled DP")
+
+    # The following few lines enable stats gathering about the run:
+    # 1. where the stats should be logged
+    stats.set_global_summary_writer(
+        tensorboard.SummaryWriter(os.path.join("/tmp/stat", args.log_dir))
+    )
+    # 2. enable stats
+    stats.add(
+        # stats about gradient norms aggregated for all layers
+        stats.Stat(stats.StatType.GRAD, "AllLayers", frequency=0.1),
+        # stats about gradient norms per layer
+        stats.Stat(stats.StatType.GRAD, "PerLayer", frequency=0.1),
+        # stats about clipping
+        stats.Stat(stats.StatType.GRAD, "ClippingStats", frequency=0.1),
+        # stats on training accuracy
+        stats.Stat(stats.StatType.TRAIN, "accuracy", frequency=0.01),
+        # stats on validation accuracy
+        stats.Stat(stats.StatType.TEST, "accuracy"),
+    )
+
+    # The following lines enable stat gathering for the clipping process
+    # and default to flat (not per-layer) clipping for the Privacy Engine
+    clipping = {"clip_per_layer": False, "enable_stat": True}
+
+    if args.secure_rng:
+        try:
+            import torchcsprng as prng
+        except ImportError as e:
+            msg = (
+                "To use secure RNG, you must install the torchcsprng package! "
+                "Check out the instructions here: https://github.com/pytorch/csprng#installation"
+            )
+            raise ImportError(msg) from e
+
+        generator = prng.create_random_device_generator("/dev/urandom")
+
+    else:
+        generator = None
+
+    augmentations = [
+        transforms.RandomCrop(32, padding=4),
+        transforms.RandomHorizontalFlip(),
+    ]
+    normalize = [
+        transforms.ToTensor(),
+        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+    ]
+    train_transform = transforms.Compose(
+        augmentations + normalize if args.disable_dp else normalize
+    )
+
+    test_transform = transforms.Compose(normalize)
+
+    train_dataset = CIFAR10(
+        root=args.data_root, train=True, download=True, transform=train_transform
+    )
+
+    train_loader = torch.utils.data.DataLoader(
+        train_dataset,
+        batch_size=args.batch_size,
+        shuffle=True,
+        num_workers=args.workers,
+        drop_last=True,
+        generator=generator,
+    )
+
+    test_dataset = CIFAR10(
+        root=args.data_root, train=False, download=True, transform=test_transform
+    )
+    test_loader = torch.utils.data.DataLoader(
+        test_dataset,
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.workers,
+    )
+
+    best_acc1 = 0
+    device = torch.device(args.device)
+    model = convert_batchnorm_modules(models.resnet18(num_classes=10))
+    model = model.to(device)
+
+    if args.optim == "SGD":
+        optimizer = optim.SGD(
+            model.parameters(),
+            lr=args.lr,
+            momentum=args.momentum,
+            weight_decay=args.weight_decay,
+        )
+    elif args.optim == "RMSprop":
+        optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
+    elif args.optim == "Adam":
+        optimizer = optim.Adam(model.parameters(), lr=args.lr)
+    else:
+        raise NotImplementedError("Optimizer not recognized. Please check spelling")
+
+    if not args.disable_dp:
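+        # Attaching the PrivacyEngine makes optimizer.step() clip per-sample
+        # gradients and add noise automatically; this is the Opacus baseline that
+        # the expand-weights and functorch variants reimplement by hand.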
+        privacy_engine = PrivacyEngine(
+            model,
+            batch_size=args.batch_size * args.n_accumulation_steps,
+            sample_size=len(train_dataset),
+            alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
+            noise_multiplier=args.sigma,
+            max_grad_norm=args.max_per_sample_grad_norm,
+            secure_rng=args.secure_rng,
+            **clipping,
+        )
+        privacy_engine.attach(optimizer)
+
+    for epoch in range(args.start_epoch, args.epochs + 1):
+        train(args, model, train_loader, optimizer, epoch, device)
+        top1_acc = test(args, model, test_loader, device)
+
+        # remember best acc@1 and save checkpoint
+        is_best = top1_acc > best_acc1
+        best_acc1 = max(top1_acc, best_acc1)
+
+        save_checkpoint(
+            {
+                "epoch": epoch + 1,
+                "arch": "ResNet18",
+                "state_dict": model.state_dict(),
+                "best_acc1": best_acc1,
+                "optimizer": optimizer.state_dict(),
+            },
+            is_best,
+            filename=args.checkpoint_file + ".tar",
+        )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/functorch/examples/dp_cifar10/cifar10_transforms.py b/functorch/examples/dp_cifar10/cifar10_transforms.py
new file mode 100644
index 0000000..8ba40b7
--- /dev/null
+++ b/functorch/examples/dp_cifar10/cifar10_transforms.py
@@ -0,0 +1,475 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+"""
+Runs CIFAR10 training with differential privacy.
+"""
+
+import argparse
+import os
+import shutil
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.utils.data
+import torch.utils.data.distributed
+import torch.utils.tensorboard as tensorboard
+import torchvision.models as models
+import torchvision.transforms as transforms
+from opacus import PrivacyEngine
+from opacus.utils import stats
+from opacus.utils.module_modification import convert_batchnorm_modules
+from torchvision.datasets import CIFAR10
+from tqdm import tqdm
+
+from make_functional import make_functional, load_weights
+from torch.eager_transforms import vmap, grad_with_value
+from functools import partial
+# from resnet import resnet18
+
+def save_checkpoint(state, is_best, filename="checkpoint.tar"):
+    torch.save(state, filename)
+    if is_best:
+        shutil.copyfile(filename, "model_best.pth.tar")
+
+
+def accuracy(preds, labels):
+    return (preds == labels).mean()
+
+
+def compute_norms(sample_grads):
+    batch_size = sample_grads[0].shape[0]
+    norms = [sample_grad.view(batch_size, -1).norm(2, dim=-1) for sample_grad in sample_grads]
+    norms = torch.stack(norms, dim=0).norm(2, dim=0)
+    return norms
+
+def clip_and_accumulate_and_add_noise(model, max_per_sample_grad_norm=1.0, noise_multiplier=1.0):
+    sample_grads = tuple(param.grad_sample for param in model.parameters())
+
+    # step 0: compute the norms
+    sample_norms = compute_norms(sample_grads)
+
+    # step 1: compute clipping factors
+    clip_factor = max_per_sample_grad_norm / (sample_norms + 1e-6)
+    clip_factor = clip_factor.clamp(max=1.0)
+
+    # step 2: clip
+    grads = tuple(torch.einsum('i,i...', clip_factor, sample_grad)
+                  for sample_grad in sample_grads)
+
+    # step 3: add gaussian noise
+    stddev = max_per_sample_grad_norm * noise_multiplier
+    noises = tuple(torch.normal(0, stddev, grad_param.shape, device=grad_param.device)
+                   for grad_param in grads)
+    grads = tuple(noise + grad_param for noise, grad_param in zip(noises, grads))
+
+    # step 4: assign the new grads, delete the sample grads
+    for param, param_grad in zip(model.parameters(), grads):
+        param.grad = param_grad
+        del param.grad_sample
+
+def train(args, model, train_loader, optimizer, epoch, device):
+    model.train()
+    criterion = nn.CrossEntropyLoss()
+
+    losses = []
+    top1_acc = []
+
+    for i, (images, target) in enumerate(tqdm(train_loader)):
+
+        images = images.to(device)
+        target = target.to(device)
+
+        # Step 1: compute per-sample-grads
+
+        # In order to use the functional vmap+grad transforms, we need to be
+        # able to pass the weights to the model explicitly.
+        weights, func_model, descriptors = make_functional(model)
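+        # (make_functional is assumed to return the current weights, a functional
+        # version of the model that takes those weights explicitly, and
+        # descriptors that let load_weights put the weights back onto the module)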
+
+        # To use vmap+grad to compute per-sample-grads, the forward pass
+        # must be re-formulated on a single example.
+        # We use the `grad` operator to compute forward+backward on a single example,
+        # and finally `vmap` to do forward+backward on multiple examples.
+        def compute_loss_and_output(weights, image, target):
+            images = image.unsqueeze(0)
+            targets = target.unsqueeze(0)
+            output = func_model(weights, (images,))
+            loss = criterion(output, targets)
+            return loss, output.squeeze(0)
+
+        # `grad(f)` is a functional API that returns a function `f'` which
+        # computes gradients by running both the forward and backward pass.
+        # We also want to extract some intermediate values from the
+        # computation (i.e. the loss and the output).
+        #
+        # To extract the loss, we use the `grad_with_value` API, which returns
+        # the gradient of the loss w.r.t. the weights along with the loss itself.
+        #
+        # To extract the output, we use the `has_aux=True` flag.
+        # `has_aux=True` assumes that `f` returns a tuple of two values, where
+        # the first is to be differentiated and the second ("auxiliary") value
+        # is not. `f'` then returns the gradient, the loss, and the auxiliary
+        # value.
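+        # A minimal sketch of that convention (assuming this prototype API):
+        #     grads, loss, aux = grad_with_value(f, has_aux=True)(weights, *inputs)
+        # where f(weights, *inputs) returns (loss, aux).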
+        grads_loss_output = grad_with_value(compute_loss_and_output, has_aux=True)
+        def packed(weights, images, target):
+            grads, loss, output = grads_loss_output(weights, images, target)
+            return (*grads, loss, output)
+        result = vmap(partial(packed, weights))(images, target)
+        sample_grads, sample_loss, output = result[:-2], result[-2], result[-1]
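+        # result[:-2] are the per-sample gradients (one entry per weight, each
+        # with a leading batch dimension), result[-2] is the per-sample loss,
+        # and result[-1] holds the per-sample outputs.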
+        loss = sample_loss.mean()
+
+        # `load_weights` is the inverse operation of make_functional. We put
+        # things back into a model so that they're easier to manipulate
+        load_weights(model, descriptors, weights)
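+        # Storing the per-sample gradients on `weight.grad_sample` mirrors the
+        # Opacus convention, so clip_and_accumulate_and_add_noise above can
+        # consume them unchanged.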
+        for grad_sample, weight in zip(sample_grads, model.parameters()):
+            weight.grad_sample = grad_sample.detach()
+
+        # Step 2: Clip the per-sample-grads, sum them to form grads, and add noise
+        clip_and_accumulate_and_add_noise(
+            model, args.max_per_sample_grad_norm, args.sigma)
+
+        preds = np.argmax(output.detach().cpu().numpy(), axis=1)
+        labels = target.detach().cpu().numpy()
+        losses.append(loss.item())
+
+        # measure accuracy and record loss
+        acc1 = accuracy(preds, labels)
+
+        top1_acc.append(acc1)
+        stats.update(stats.StatType.TRAIN, acc1=acc1)
+
+        # make sure we take a step after processing the last mini-batch in the
+        # epoch to ensure we start the next epoch with a clean state
+        if ((i + 1) % args.n_accumulation_steps == 0) or ((i + 1) == len(train_loader)):
+            optimizer.step()
+            optimizer.zero_grad()
+        else:
+            optimizer.virtual_step()
+
+        if i % args.print_freq == 0:
+            print(
+                f"\tTrain Epoch: {epoch} \t"
+                f"Loss: {np.mean(losses):.6f} "
+                f"Acc@1: {np.mean(top1_acc):.6f} "
+            )
+
+def test(args, model, test_loader, device):
+    model.eval()
+    criterion = nn.CrossEntropyLoss()
+    losses = []
+    top1_acc = []
+
+    with torch.no_grad():
+        for images, target in tqdm(test_loader):
+            images = images.to(device)
+            target = target.to(device)
+
+            output = model(images)
+            loss = criterion(output, target)
+            preds = np.argmax(output.detach().cpu().numpy(), axis=1)
+            labels = target.detach().cpu().numpy()
+            acc1 = accuracy(preds, labels)
+
+            losses.append(loss.item())
+            top1_acc.append(acc1)
+
+    top1_avg = np.mean(top1_acc)
+    stats.update(stats.StatType.TEST, acc1=top1_avg)
+
+    print(f"\tTest set:" f"Loss: {np.mean(losses):.6f} " f"Acc@1: {top1_avg :.6f} ")
+    return np.mean(top1_acc)
+
+
+def main():
+    parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
+    parser.add_argument(
+        "-j",
+        "--workers",
+        default=2,
+        type=int,
+        metavar="N",
+        help="number of data loading workers (default: 2)",
+    )
+    parser.add_argument(
+        "--epochs",
+        default=90,
+        type=int,
+        metavar="N",
+        help="number of total epochs to run",
+    )
+    parser.add_argument(
+        "--start-epoch",
+        default=1,
+        type=int,
+        metavar="N",
+        help="manual epoch number (useful on restarts)",
+    )
+    parser.add_argument(
+        "-b",
+        "--batch-size",
+        # This should be 256, but that OOMs using the prototype.
+        default=64,
+        type=int,
+        metavar="N",
+        help="mini-batch size (default: 64), this is the total "
+        "batch size of all GPUs on the current node when "
+        "using Data Parallel or Distributed Data Parallel",
+    )
+    parser.add_argument(
+        "-na",
+        "--n_accumulation_steps",
+        default=1,
+        type=int,
+        metavar="N",
+        help="number of mini-batches to accumulate into an effective batch",
+    )
+    parser.add_argument(
+        "--lr",
+        "--learning-rate",
+        default=0.001,
+        type=float,
+        metavar="LR",
+        help="initial learning rate",
+        dest="lr",
+    )
+    parser.add_argument(
+        "--momentum", default=0.9, type=float, metavar="M", help="SGD momentum"
+    )
+    parser.add_argument(
+        "--wd",
+        "--weight-decay",
+        default=5e-4,
+        type=float,
+        metavar="W",
+        help="SGD weight decay (default: 1e-4)",
+        dest="weight_decay",
+    )
+    parser.add_argument(
+        "-p",
+        "--print-freq",
+        default=10,
+        type=int,
+        metavar="N",
+        help="print frequency (default: 10)",
+    )
+    parser.add_argument(
+        "--resume",
+        default="",
+        type=str,
+        metavar="PATH",
+        help="path to latest checkpoint (default: none)",
+    )
+    parser.add_argument(
+        "-e",
+        "--evaluate",
+        dest="evaluate",
+        action="store_true",
+        help="evaluate model on validation set",
+    )
+    parser.add_argument(
+        "--seed", default=None, type=int, help="seed for initializing training. "
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda",
+        help="GPU ID for this process (default: 'cuda')",
+    )
+    parser.add_argument(
+        "--sigma",
+        type=float,
+        default=1.0,
+        metavar="S",
+        help="Noise multiplier (default 1.0)",
+    )
+    parser.add_argument(
+        "-c",
+        "--max-per-sample-grad_norm",
+        type=float,
+        default=1.0,
+        metavar="C",
+        help="Clip per-sample gradients to this norm (default 1.0)",
+    )
+    parser.add_argument(
+        "--disable-dp",
+        action="store_true",
+        default=False,
+        help="Disable privacy training and just train with vanilla SGD",
+    )
+    parser.add_argument(
+        "--secure-rng",
+        action="store_true",
+        default=False,
+        help="Enable Secure RNG to have trustworthy privacy guarantees. Comes at a performance cost",
+    )
+    parser.add_argument(
+        "--delta",
+        type=float,
+        default=1e-5,
+        metavar="D",
+        help="Target delta (default: 1e-5)",
+    )
+
+    parser.add_argument(
+        "--checkpoint-file",
+        type=str,
+        default="checkpoint",
+        help="path to save check points",
+    )
+    parser.add_argument(
+        "--data-root",
+        type=str,
+        default="../cifar10",
+        help="Where CIFAR10 is/will be stored",
+    )
+    parser.add_argument(
+        "--log-dir", type=str, default="", help="Where Tensorboard log will be stored"
+    )
+    parser.add_argument(
+        "--optim",
+        type=str,
+        default="Adam",
+        help="Optimizer to use (Adam, RMSprop, SGD)",
+    )
+
+    args = parser.parse_args()
+    args.disable_dp = True
+
+    if args.disable_dp and args.n_accumulation_steps > 1:
+        raise ValueError("Virtual steps only works with enabled DP")
+
+    # The following few lines enable stats gathering about the run:
+    # 1. where the stats should be logged
+    stats.set_global_summary_writer(
+        tensorboard.SummaryWriter(os.path.join("/tmp/stat", args.log_dir))
+    )
+    # 2. enable stats
+    stats.add(
+        # stats about gradient norms aggregated for all layers
+        stats.Stat(stats.StatType.GRAD, "AllLayers", frequency=0.1),
+        # stats about gradient norms per layer
+        stats.Stat(stats.StatType.GRAD, "PerLayer", frequency=0.1),
+        # stats about clipping
+        stats.Stat(stats.StatType.GRAD, "ClippingStats", frequency=0.1),
+        # stats on training accuracy
+        stats.Stat(stats.StatType.TRAIN, "accuracy", frequency=0.01),
+        # stats on validation accuracy
+        stats.Stat(stats.StatType.TEST, "accuracy"),
+    )
+
+    # The following lines enable stat gathering for the clipping process
+    # and default to flat (not per-layer) clipping for the Privacy Engine
+    clipping = {"clip_per_layer": False, "enable_stat": True}
+
+    if args.secure_rng:
+        assert False
+        try:
+            import torchcsprng as prng
+        except ImportError as e:
+            msg = (
+                "To use secure RNG, you must install the torchcsprng package! "
+                "Check out the instructions here: https://github.com/pytorch/csprng#installation"
+            )
+            raise ImportError(msg) from e
+
+        generator = prng.create_random_device_generator("/dev/urandom")
+
+    else:
+        generator = None
+
+    augmentations = [
+        transforms.RandomCrop(32, padding=4),
+        transforms.RandomHorizontalFlip(),
+    ]
+    normalize = [
+        transforms.ToTensor(),
+        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+    ]
+    train_transform = transforms.Compose(
+        augmentations + normalize if args.disable_dp else normalize
+    )
+
+    test_transform = transforms.Compose(normalize)
+
+    train_dataset = CIFAR10(
+        root=args.data_root, train=True, download=True, transform=train_transform
+    )
+
+    train_loader = torch.utils.data.DataLoader(
+        train_dataset,
+        batch_size=args.batch_size,
+        shuffle=True,
+        num_workers=args.workers,
+        drop_last=True,
+        generator=generator,
+    )
+
+    test_dataset = CIFAR10(
+        root=args.data_root, train=False, download=True, transform=test_transform
+    )
+    test_loader = torch.utils.data.DataLoader(
+        test_dataset,
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.workers,
+    )
+
+    best_acc1 = 0
+    device = torch.device(args.device)
+    model = convert_batchnorm_modules(models.resnet18(num_classes=10))
+    # model = CIFAR10Model()
+    model = model.to(device)
+
+    if args.optim == "SGD":
+        optimizer = optim.SGD(
+            model.parameters(),
+            lr=args.lr,
+            momentum=args.momentum,
+            weight_decay=args.weight_decay,
+        )
+    elif args.optim == "RMSprop":
+        optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
+    elif args.optim == "Adam":
+        optimizer = optim.Adam(model.parameters(), lr=args.lr)
+    else:
+        raise NotImplementedError("Optimizer not recognized. Please check spelling")
+
+    if not args.disable_dp:
+        privacy_engine = PrivacyEngine(
+            model,
+            batch_size=args.batch_size * args.n_accumulation_steps,
+            sample_size=len(train_dataset),
+            alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
+            noise_multiplier=args.sigma,
+            max_grad_norm=args.max_per_sample_grad_norm,
+            secure_rng=args.secure_rng,
+            **clipping,
+        )
+        privacy_engine.attach(optimizer)
+
+    for epoch in range(args.start_epoch, args.epochs + 1):
+        train(args, model, train_loader, optimizer, epoch, device)
+        top1_acc = test(args, model, test_loader, device)
+
+        # remember best acc@1 and save checkpoint
+        is_best = top1_acc > best_acc1
+        best_acc1 = max(top1_acc, best_acc1)
+
+        save_checkpoint(
+            {
+                "epoch": epoch + 1,
+                "arch": "ResNet18",
+                "state_dict": model.state_dict(),
+                "best_acc1": best_acc1,
+                "optimizer": optimizer.state_dict(),
+            },
+            is_best,
+            filename=args.checkpoint_file + ".tar",
+        )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/functorch/examples/dp_cifar10/cifar10_transforms.py.~1~ b/functorch/examples/dp_cifar10/cifar10_transforms.py.~1~
new file mode 100644
index 0000000..1b28af4
--- /dev/null
+++ b/functorch/examples/dp_cifar10/cifar10_transforms.py.~1~
@@ -0,0 +1,438 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+"""
+Runs CIFAR10 training with differential privacy.
+"""
+
+import argparse
+import os
+import shutil
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.utils.data
+import torch.utils.data.distributed
+import torch.utils.tensorboard as tensorboard
+import torchvision.models as models
+import torchvision.transforms as transforms
+from opacus import PrivacyEngine
+from opacus.utils import stats
+from opacus.utils.module_modification import convert_batchnorm_modules
+from torchvision.datasets import CIFAR10
+from tqdm import tqdm
+
+from functools import partial
+from make_functional import make_functional, load_weights
+
+# NB: The following might not exist depending on what you're using
+from torch import vmap
+from functional_utils import grad, grad_with_value
+
+
+def save_checkpoint(state, is_best, filename="checkpoint.tar"):
+    torch.save(state, filename)
+    if is_best:
+        shutil.copyfile(filename, "model_best.pth.tar")
+
+
+def accuracy(preds, labels):
+    return (preds == labels).mean()
+
+
+def train(args, model, train_loader, optimizer, epoch, device):
+    model.train()
+    criterion = nn.CrossEntropyLoss()
+
+    losses = []
+    top1_acc = []
+
+    for i, (images, target) in enumerate(tqdm(train_loader)):
+        images = images.to(device)
+        target = target.to(device)
+
+        # Step 1: compute per-sample-grads
+
+        # TODO(rzou): does the group norm work correctly?
+        weights, func_model, descriptors = make_functional(model)
+        [weight.requires_grad_(False) for weight in weights]
+
+        def compute_loss_and_output(weights, image, target):
+            images = image.unsqueeze(0)
+            target = target.unsqueeze(0)
+            output = func_model(weights, (images,))
+            loss = criterion(output, target)
+            return loss, output.squeeze(0)
+
+        grads_loss_output = grad_with_value(compute_loss_and_output, has_aux=True)
+        grads, sample_loss, output = vmap(partial(grads_loss_output, weights))(images, target)
+        loss = sample_loss.mean(0)
+
+        # Step 2: Clip the per-sample-grads and sum them to form grads
+
+        # TODO(rzou): Right now we just sum the grads. Instead we need to clip them.
+        for sample_grad, weight in zip(grads, weights):
+            weight_grad = sample_grad.sum(0)
+            weight.grad = weight_grad
+
+        # `load_weights` is the inverse operation of make_functional. We put
+        # things back into a model so that we can directly apply optimizers.
+        # TODO(rzou): this might not be necessary, optimizers just take
+        # the params straight up.
+        [weight.requires_grad_(True) for weight in weights]
+        load_weights(model, descriptors, weights, as_params=True)
+
+        preds = np.argmax(output.detach().cpu().numpy(), axis=1)
+        labels = target.detach().cpu().numpy()
+
+        # measure accuracy and record loss
+        acc1 = accuracy(preds, labels)
+
+        losses.append(loss.item())
+        top1_acc.append(acc1)
+        stats.update(stats.StatType.TRAIN, acc1=acc1)
+
+        # make sure we take a step after processing the last mini-batch in the
+        # epoch to ensure we start the next epoch with a clean state
+        if ((i + 1) % args.n_accumulation_steps == 0) or ((i + 1) == len(train_loader)):
+            optimizer.step()
+            optimizer.zero_grad()
+        else:
+            optimizer.virtual_step()
+
+        if i % args.print_freq == 0:
+            if not args.disable_dp:
+                epsilon, best_alpha = optimizer.privacy_engine.get_privacy_spent(
+                    args.delta
+                )
+                print(
+                    f"\tTrain Epoch: {epoch} \t"
+                    f"Loss: {np.mean(losses):.6f} "
+                    f"Acc@1: {np.mean(top1_acc):.6f} "
+                    f"(ε = {epsilon:.2f}, δ = {args.delta}) for α = {best_alpha}"
+                )
+            else:
+                print(
+                    f"\tTrain Epoch: {epoch} \t"
+                    f"Loss: {np.mean(losses):.6f} "
+                    f"Acc@1: {np.mean(top1_acc):.6f} "
+                )
+
+
+def test(args, model, test_loader, device):
+    model.eval()
+    criterion = nn.CrossEntropyLoss()
+    losses = []
+    top1_acc = []
+
+    with torch.no_grad():
+        for images, target in tqdm(test_loader):
+            images = images.to(device)
+            target = target.to(device)
+
+            output = model(images)
+            loss = criterion(output, target)
+            preds = np.argmax(output.detach().cpu().numpy(), axis=1)
+            labels = target.detach().cpu().numpy()
+            acc1 = accuracy(preds, labels)
+
+            losses.append(loss.item())
+            top1_acc.append(acc1)
+
+    top1_avg = np.mean(top1_acc)
+    stats.update(stats.StatType.TEST, acc1=top1_avg)
+
+    print(f"\tTest set:" f"Loss: {np.mean(losses):.6f} " f"Acc@1: {top1_avg :.6f} ")
+    return np.mean(top1_acc)
+
+
+def main():
+    parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
+    parser.add_argument(
+        "-j",
+        "--workers",
+        default=2,
+        type=int,
+        metavar="N",
+        help="number of data loading workers (default: 2)",
+    )
+    parser.add_argument(
+        "--epochs",
+        default=90,
+        type=int,
+        metavar="N",
+        help="number of total epochs to run",
+    )
+    parser.add_argument(
+        "--start-epoch",
+        default=1,
+        type=int,
+        metavar="N",
+        help="manual epoch number (useful on restarts)",
+    )
+    parser.add_argument(
+        "-b",
+        "--batch-size",
+        default=256,
+        type=int,
+        metavar="N",
+        help="mini-batch size (default: 256), this is the total "
+        "batch size of all GPUs on the current node when "
+        "using Data Parallel or Distributed Data Parallel",
+    )
+    parser.add_argument(
+        "-na",
+        "--n_accumulation_steps",
+        default=1,
+        type=int,
+        metavar="N",
+        help="number of mini-batches to accumulate into an effective batch",
+    )
+    parser.add_argument(
+        "--lr",
+        "--learning-rate",
+        default=0.001,
+        type=float,
+        metavar="LR",
+        help="initial learning rate",
+        dest="lr",
+    )
+    parser.add_argument(
+        "--momentum", default=0.9, type=float, metavar="M", help="SGD momentum"
+    )
+    parser.add_argument(
+        "--wd",
+        "--weight-decay",
+        default=5e-4,
+        type=float,
+        metavar="W",
+        help="SGD weight decay (default: 1e-4)",
+        dest="weight_decay",
+    )
+    parser.add_argument(
+        "-p",
+        "--print-freq",
+        default=10,
+        type=int,
+        metavar="N",
+        help="print frequency (default: 10)",
+    )
+    parser.add_argument(
+        "--resume",
+        default="",
+        type=str,
+        metavar="PATH",
+        help="path to latest checkpoint (default: none)",
+    )
+    parser.add_argument(
+        "-e",
+        "--evaluate",
+        dest="evaluate",
+        action="store_true",
+        help="evaluate model on validation set",
+    )
+    parser.add_argument(
+        "--seed", default=None, type=int, help="seed for initializing training. "
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda",
+        help="GPU ID for this process (default: 'cuda')",
+    )
+    parser.add_argument(
+        "--sigma",
+        type=float,
+        default=1.0,
+        metavar="S",
+        help="Noise multiplier (default 1.0)",
+    )
+    parser.add_argument(
+        "-c",
+        "--max-per-sample-grad_norm",
+        type=float,
+        default=1.0,
+        metavar="C",
+        help="Clip per-sample gradients to this norm (default 1.0)",
+    )
+    parser.add_argument(
+        "--disable-dp",
+        action="store_true",
+        default=False,
+        help="Disable privacy training and just train with vanilla SGD",
+    )
+    parser.add_argument(
+        "--secure-rng",
+        action="store_true",
+        default=False,
+        help="Enable Secure RNG to have trustworthy privacy guarantees. Comes at a performance cost",
+    )
+    parser.add_argument(
+        "--delta",
+        type=float,
+        default=1e-5,
+        metavar="D",
+        help="Target delta (default: 1e-5)",
+    )
+
+    parser.add_argument(
+        "--checkpoint-file",
+        type=str,
+        default="checkpoint",
+        help="path to save check points",
+    )
+    parser.add_argument(
+        "--data-root",
+        type=str,
+        default="../cifar10",
+        help="Where CIFAR10 is/will be stored",
+    )
+    parser.add_argument(
+        "--log-dir", type=str, default="", help="Where Tensorboard log will be stored"
+    )
+    parser.add_argument(
+        "--optim",
+        type=str,
+        default="Adam",
+        help="Optimizer to use (Adam, RMSprop, SGD)",
+    )
+
+    args = parser.parse_args()
+    args.disable_dp = True
+
+    if args.disable_dp and args.n_accumulation_steps > 1:
+        raise ValueError("Virtual steps only works with enabled DP")
+
+    # The following few lines enable stats gathering about the run:
+    # 1. where the stats should be logged
+    stats.set_global_summary_writer(
+        tensorboard.SummaryWriter(os.path.join("/tmp/stat", args.log_dir))
+    )
+    # 2. enable stats
+    stats.add(
+        # stats about gradient norms aggregated for all layers
+        stats.Stat(stats.StatType.GRAD, "AllLayers", frequency=0.1),
+        # stats about gradient norms per layer
+        stats.Stat(stats.StatType.GRAD, "PerLayer", frequency=0.1),
+        # stats about clipping
+        stats.Stat(stats.StatType.GRAD, "ClippingStats", frequency=0.1),
+        # stats on training accuracy
+        stats.Stat(stats.StatType.TRAIN, "accuracy", frequency=0.01),
+        # stats on validation accuracy
+        stats.Stat(stats.StatType.TEST, "accuracy"),
+    )
+
+    # The following lines enable stat gathering for the clipping process
+    # and default to flat (not per-layer) clipping for the Privacy Engine
+    clipping = {"clip_per_layer": False, "enable_stat": True}
+
+    if args.secure_rng:
+        assert False
+        try:
+            import torchcsprng as prng
+        except ImportError as e:
+            msg = (
+                "To use secure RNG, you must install the torchcsprng package! "
+                "Check out the instructions here: https://github.com/pytorch/csprng#installation"
+            )
+            raise ImportError(msg) from e
+
+        generator = prng.create_random_device_generator("/dev/urandom")
+
+    else:
+        generator = None
+
+    augmentations = [
+        transforms.RandomCrop(32, padding=4),
+        transforms.RandomHorizontalFlip(),
+    ]
+    normalize = [
+        transforms.ToTensor(),
+        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+    ]
+    train_transform = transforms.Compose(
+        augmentations + normalize if args.disable_dp else normalize
+    )
+
+    test_transform = transforms.Compose(normalize)
+
+    train_dataset = CIFAR10(
+        root=args.data_root, train=True, download=True, transform=train_transform
+    )
+
+    train_loader = torch.utils.data.DataLoader(
+        train_dataset,
+        batch_size=args.batch_size,
+        shuffle=True,
+        num_workers=args.workers,
+        drop_last=True,
+        generator=generator,
+    )
+
+    test_dataset = CIFAR10(
+        root=args.data_root, train=False, download=True, transform=test_transform
+    )
+    test_loader = torch.utils.data.DataLoader(
+        test_dataset,
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.workers,
+    )
+
+    best_acc1 = 0
+    device = torch.device(args.device)
+    model = convert_batchnorm_modules(models.resnet18(num_classes=10))
+    model = model.to(device)
+
+    if args.optim == "SGD":
+        optimizer = optim.SGD(
+            model.parameters(),
+            lr=args.lr,
+            momentum=args.momentum,
+            weight_decay=args.weight_decay,
+        )
+    elif args.optim == "RMSprop":
+        optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
+    elif args.optim == "Adam":
+        optimizer = optim.Adam(model.parameters(), lr=args.lr)
+    else:
+        raise NotImplementedError("Optimizer not recognized. Please check spelling")
+
+    if not args.disable_dp:
+        privacy_engine = PrivacyEngine(
+            model,
+            batch_size=args.batch_size * args.n_accumulation_steps,
+            sample_size=len(train_dataset),
+            alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
+            noise_multiplier=args.sigma,
+            max_grad_norm=args.max_per_sample_grad_norm,
+            secure_rng=args.secure_rng,
+            **clipping,
+        )
+        privacy_engine.attach(optimizer)
+
+    for epoch in range(args.start_epoch, args.epochs + 1):
+        train(args, model, train_loader, optimizer, epoch, device)
+        top1_acc = test(args, model, test_loader, device)
+
+        # remember best acc@1 and save checkpoint
+        is_best = top1_acc > best_acc1
+        best_acc1 = max(top1_acc, best_acc1)
+
+        save_checkpoint(
+            {
+                "epoch": epoch + 1,
+                "arch": "ResNet18",
+                "state_dict": model.state_dict(),
+                "best_acc1": best_acc1,
+                "optimizer": optimizer.state_dict(),
+            },
+            is_best,
+            filename=args.checkpoint_file + ".tar",
+        )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/functorch/examples/dp_cifar10/cifar10_transforms.py.~2~ b/functorch/examples/dp_cifar10/cifar10_transforms.py.~2~
new file mode 100644
index 0000000..53cffb5
--- /dev/null
+++ b/functorch/examples/dp_cifar10/cifar10_transforms.py.~2~
@@ -0,0 +1,491 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+"""
+Runs CIFAR10 training with differential privacy.
+"""
+
+import argparse
+import os
+import shutil
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.utils.data
+import torch.utils.data.distributed
+import torch.utils.tensorboard as tensorboard
+import torchvision.models as models
+import torchvision.transforms as transforms
+from opacus import PrivacyEngine
+from opacus.utils import stats
+from opacus.utils.module_modification import convert_batchnorm_modules
+from torchvision.datasets import CIFAR10
+from tqdm import tqdm
+
+from torch import vmap
+from make_functional import make_functional, load_weights
+from functional_utils import grad, grad_with_value
+from functools import partial
+# from resnet import resnet18
+
+class CudaMemoryLeakCheck():
+    def __init__(self, name):
+        self.name = name
+
+        # initialize context & RNG to prevent false positive detections
+        # when the test is the first to initialize those
+        from torch.testing._internal.common_cuda import initialize_cuda_context_rng
+        initialize_cuda_context_rng()
+
+    @staticmethod
+    def get_cuda_memory_usage():
+        # we don't need to synchronize CUDA because the statistics are not tracked
+        # at actual freeing, but when the block is marked as free.
+        num_devices = torch.cuda.device_count()
+        import gc
+        gc.collect()
+        return tuple(torch.cuda.memory_allocated(i) for i in range(num_devices))
+
+    def __enter__(self):
+        self.befores = self.get_cuda_memory_usage()
+
+    def __exit__(self, exec_type, exec_value, traceback):
+        # Don't check for leaks if an exception was thrown
+        if exec_type is not None:
+            return
+
+        afters = self.get_cuda_memory_usage()
+
+        for i, (before, after) in enumerate(zip(self.befores, afters)):
+            if after - before == 0:
+                continue
+            raise RuntimeError(f'{self.name} leaked {after-before} bytes')
+
+
+def save_checkpoint(state, is_best, filename="checkpoint.tar"):
+    torch.save(state, filename)
+    if is_best:
+        shutil.copyfile(filename, "model_best.pth.tar")
+
+
+def accuracy(preds, labels):
+    return (preds == labels).mean()
+
+
+def compute_norms(sample_grads):
+    batch_size = sample_grads[0].shape[0]
+    norms = [sample_grad.view(batch_size, -1).norm(2, dim=-1) for sample_grad in sample_grads]
+    norms = torch.stack(norms, dim=0).norm(2, dim=0)
+    return norms
+
+def clip_and_accumulate_and_add_noise(sample_grads, max_per_sample_grad_norm=1.0, noise_multiplier=1.0):
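+    # Given per-sample gradients (each tensor has a leading batch dimension),
+    # perform the core DP-SGD step: clip every sample's gradient to
+    # max_per_sample_grad_norm, sum the clipped gradients over the batch,
+    # and add Gaussian noise.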
+    # step 0: compute the norms
+    sample_norms = compute_norms(sample_grads)
+
+    # step 1: compute clipping factors
+    clip_factor = max_per_sample_grad_norm / (sample_norms + 1e-6)
+    clip_factor = clip_factor.clamp(max=1.0)
+
+    # step 2: clip
+    grads = tuple(torch.einsum('i,i...', clip_factor, sample_grad)
+                  for sample_grad in sample_grads)
+
+    # step 3: add gaussian noise
+    stddev = max_per_sample_grad_norm * noise_multiplier
+    noises = tuple(torch.normal(0, stddev, grad_param.shape, device=grad_param.device)
+                   for grad_param in grads)
+    grads = tuple(noise + grad_param for noise, grad_param in zip(noises, grads))
+
+    return grads
+
+def train(args, model, train_loader, optimizer, epoch, device):
+    use_prototype = False
+    model.train()
+    criterion = nn.CrossEntropyLoss()
+
+    losses = []
+    top1_acc = []
+
+    for i, (images, target) in enumerate(tqdm(train_loader)):
+
+        images = images.to(device)
+        target = target.to(device)
+
+        # Step 1: compute per-sample-grads
+        weights, func_model, descriptors = make_functional(model)
+
+        def compute_loss_and_output(weights, image, target):
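+            # vmap hands this function a single example (no batch dimension),
+            # so re-add a length-1 batch dimension that the model and the
+            # loss function expect.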
+            images = image.unsqueeze(0)
+            targets = target.unsqueeze(0)
+            output = func_model(weights, (images,))
+            loss = criterion(output, targets)
+            return loss, output.squeeze(0)
+
+        # grad_with_value(f) returns a function that computes (1) the gradient and
+        # (2) the value of `f`. `has_aux=True` means that `f` returns a tuple of two
+        # values: only the first is differentiated; the second is passed through
+        # unchanged as an additional (third) output.
+        #
+        # We need `grad_with_value(..., has_aux=True)` because we do some
+        # analyses on both the returned loss and the output.
+        grads_loss_output = grad_with_value(compute_loss_and_output, has_aux=True)
+        sample_grads, sample_loss, output = vmap(partial(grads_loss_output, weights))(images, target)
+        loss = sample_loss.mean()
+
+        # Step 2: Clip the per-sample-grads, sum them to form grads, and add noise
+        grads = clip_and_accumulate_and_add_noise(
+            sample_grads, args.max_per_sample_grad_norm, args.sigma)
+
+        # `load_weights` is the inverse operation of make_functional. We put
+        # things back into a model so that we can directly apply optimizers.
+        # TODO(rzou): this might not be necessary, optimizers just take
+        # the params straight up.
+        load_weights(model, descriptors, weights)
+
+        for weight_grad, weight in zip(grads, model.parameters()):
+            weight.grad = weight_grad.detach()
+
+        preds = np.argmax(output.detach().cpu().numpy(), axis=1)
+        labels = target.detach().cpu().numpy()
+        losses.append(loss.item())
+
+        # measure accuracy and record loss
+        acc1 = accuracy(preds, labels)
+
+        top1_acc.append(acc1)
+        stats.update(stats.StatType.TRAIN, acc1=acc1)
+
+        # make sure we take a step after processing the last mini-batch in the
+        # epoch to ensure we start the next epoch with a clean state
+        if ((i + 1) % args.n_accumulation_steps == 0) or ((i + 1) == len(train_loader)):
+            optimizer.step()
+            optimizer.zero_grad()
+        else:
+            optimizer.virtual_step()
+
+        if i % args.print_freq == 0:
+            print(
+                f"\tTrain Epoch: {epoch} \t"
+                f"Loss: {np.mean(losses):.6f} "
+                f"Acc@1: {np.mean(top1_acc):.6f} "
+            )
+
+def test(args, model, test_loader, device):
+    model.eval()
+    criterion = nn.CrossEntropyLoss()
+    losses = []
+    top1_acc = []
+
+    with torch.no_grad():
+        for images, target in tqdm(test_loader):
+            images = images.to(device)
+            target = target.to(device)
+
+            output = model(images)
+            loss = criterion(output, target)
+            preds = np.argmax(output.detach().cpu().numpy(), axis=1)
+            labels = target.detach().cpu().numpy()
+            acc1 = accuracy(preds, labels)
+
+            losses.append(loss.item())
+            top1_acc.append(acc1)
+
+    top1_avg = np.mean(top1_acc)
+    stats.update(stats.StatType.TEST, acc1=top1_avg)
+
+    print(f"\tTest set:" f"Loss: {np.mean(losses):.6f} " f"Acc@1: {top1_avg :.6f} ")
+    return np.mean(top1_acc)
+
+
+def main():
+    parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
+    parser.add_argument(
+        "-j",
+        "--workers",
+        default=2,
+        type=int,
+        metavar="N",
+        help="number of data loading workers (default: 2)",
+    )
+    parser.add_argument(
+        "--epochs",
+        default=90,
+        type=int,
+        metavar="N",
+        help="number of total epochs to run",
+    )
+    parser.add_argument(
+        "--start-epoch",
+        default=1,
+        type=int,
+        metavar="N",
+        help="manual epoch number (useful on restarts)",
+    )
+    parser.add_argument(
+        "-b",
+        "--batch-size",
+        # This should be 256, but that OOMs using the prototype.
+        default=64,
+        type=int,
+        metavar="N",
+        help="mini-batch size (default: 64), this is the total "
+        "batch size of all GPUs on the current node when "
+        "using Data Parallel or Distributed Data Parallel",
+    )
+    parser.add_argument(
+        "-na",
+        "--n_accumulation_steps",
+        default=1,
+        type=int,
+        metavar="N",
+        help="number of mini-batches to accumulate into an effective batch",
+    )
+    parser.add_argument(
+        "--lr",
+        "--learning-rate",
+        default=0.001,
+        type=float,
+        metavar="LR",
+        help="initial learning rate",
+        dest="lr",
+    )
+    parser.add_argument(
+        "--momentum", default=0.9, type=float, metavar="M", help="SGD momentum"
+    )
+    parser.add_argument(
+        "--wd",
+        "--weight-decay",
+        default=5e-4,
+        type=float,
+        metavar="W",
+        help="SGD weight decay (default: 1e-4)",
+        dest="weight_decay",
+    )
+    parser.add_argument(
+        "-p",
+        "--print-freq",
+        default=10,
+        type=int,
+        metavar="N",
+        help="print frequency (default: 10)",
+    )
+    parser.add_argument(
+        "--resume",
+        default="",
+        type=str,
+        metavar="PATH",
+        help="path to latest checkpoint (default: none)",
+    )
+    parser.add_argument(
+        "-e",
+        "--evaluate",
+        dest="evaluate",
+        action="store_true",
+        help="evaluate model on validation set",
+    )
+    parser.add_argument(
+        "--seed", default=None, type=int, help="seed for initializing training. "
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda",
+        help="GPU ID for this process (default: 'cuda')",
+    )
+    parser.add_argument(
+        "--sigma",
+        type=float,
+        default=1.0,
+        metavar="S",
+        help="Noise multiplier (default 1.0)",
+    )
+    parser.add_argument(
+        "-c",
+        "--max-per-sample-grad_norm",
+        type=float,
+        default=1.0,
+        metavar="C",
+        help="Clip per-sample gradients to this norm (default 1.0)",
+    )
+    parser.add_argument(
+        "--disable-dp",
+        action="store_true",
+        default=False,
+        help="Disable privacy training and just train with vanilla SGD",
+    )
+    parser.add_argument(
+        "--secure-rng",
+        action="store_true",
+        default=False,
+        help="Enable Secure RNG to have trustworthy privacy guarantees. Comes at a performance cost",
+    )
+    parser.add_argument(
+        "--delta",
+        type=float,
+        default=1e-5,
+        metavar="D",
+        help="Target delta (default: 1e-5)",
+    )
+
+    parser.add_argument(
+        "--checkpoint-file",
+        type=str,
+        default="checkpoint",
+        help="path to save check points",
+    )
+    parser.add_argument(
+        "--data-root",
+        type=str,
+        default="../cifar10",
+        help="Where CIFAR10 is/will be stored",
+    )
+    parser.add_argument(
+        "--log-dir", type=str, default="", help="Where Tensorboard log will be stored"
+    )
+    parser.add_argument(
+        "--optim",
+        type=str,
+        default="Adam",
+        help="Optimizer to use (Adam, RMSprop, SGD)",
+    )
+
+    args = parser.parse_args()
+    args.disable_dp = True
+
+    if args.disable_dp and args.n_accumulation_steps > 1:
+        raise ValueError("Virtual steps only works with enabled DP")
+
+    # The following few lines, enable stats gathering about the run
+    # 1. where the stats should be logged
+    stats.set_global_summary_writer(
+        tensorboard.SummaryWriter(os.path.join("/tmp/stat", args.log_dir))
+    )
+    # 2. enable stats
+    stats.add(
+        # stats about gradient norms aggregated for all layers
+        stats.Stat(stats.StatType.GRAD, "AllLayers", frequency=0.1),
+        # stats about gradient norms per layer
+        stats.Stat(stats.StatType.GRAD, "PerLayer", frequency=0.1),
+        # stats about clipping
+        stats.Stat(stats.StatType.GRAD, "ClippingStats", frequency=0.1),
+        # stats on training accuracy
+        stats.Stat(stats.StatType.TRAIN, "accuracy", frequency=0.01),
+        # stats on validation accuracy
+        stats.Stat(stats.StatType.TEST, "accuracy"),
+    )
+
+    # The following lines enable stat gathering for the clipping process
+    # and use flat (rather than per-layer) clipping for the Privacy Engine
+    clipping = {"clip_per_layer": False, "enable_stat": True}
+
+    if args.secure_rng:
+        assert False
+        try:
+            import torchcsprng as prng
+        except ImportError as e:
+            msg = (
+                "To use secure RNG, you must install the torchcsprng package! "
+                "Check out the instructions here: https://github.com/pytorch/csprng#installation"
+            )
+            raise ImportError(msg) from e
+
+        generator = prng.create_random_device_generator("/dev/urandom")
+
+    else:
+        generator = None
+
+    augmentations = [
+        transforms.RandomCrop(32, padding=4),
+        transforms.RandomHorizontalFlip(),
+    ]
+    normalize = [
+        transforms.ToTensor(),
+        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+    ]
+    train_transform = transforms.Compose(
+        augmentations + normalize if args.disable_dp else normalize
+    )
+
+    test_transform = transforms.Compose(normalize)
+
+    train_dataset = CIFAR10(
+        root=args.data_root, train=True, download=True, transform=train_transform
+    )
+
+    train_loader = torch.utils.data.DataLoader(
+        train_dataset,
+        batch_size=args.batch_size,
+        shuffle=True,
+        num_workers=args.workers,
+        drop_last=True,
+        generator=generator,
+    )
+
+    test_dataset = CIFAR10(
+        root=args.data_root, train=False, download=True, transform=test_transform
+    )
+    test_loader = torch.utils.data.DataLoader(
+        test_dataset,
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.workers,
+    )
+
+    best_acc1 = 0
+    device = torch.device(args.device)
+    model = convert_batchnorm_modules(models.resnet18(num_classes=10))
+    # model = CIFAR10Model()
+    model = model.to(device)
+
+    if args.optim == "SGD":
+        optimizer = optim.SGD(
+            model.parameters(),
+            lr=args.lr,
+            momentum=args.momentum,
+            weight_decay=args.weight_decay,
+        )
+    elif args.optim == "RMSprop":
+        optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
+    elif args.optim == "Adam":
+        optimizer = optim.Adam(model.parameters(), lr=args.lr)
+    else:
+        raise NotImplementedError("Optimizer not recognized. Please check spelling")
+
+    if not args.disable_dp:
+        privacy_engine = PrivacyEngine(
+            model,
+            batch_size=args.batch_size * args.n_accumulation_steps,
+            sample_size=len(train_dataset),
+            alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
+            noise_multiplier=args.sigma,
+            max_grad_norm=args.max_per_sample_grad_norm,
+            secure_rng=args.secure_rng,
+            **clipping,
+        )
+        privacy_engine.attach(optimizer)
+
+    for epoch in range(args.start_epoch, args.epochs + 1):
+        train(args, model, train_loader, optimizer, epoch, device)
+        top1_acc = test(args, model, test_loader, device)
+
+        # remember best acc@1 and save checkpoint
+        is_best = top1_acc > best_acc1
+        best_acc1 = max(top1_acc, best_acc1)
+
+        save_checkpoint(
+            {
+                "epoch": epoch + 1,
+                "arch": "ResNet18",
+                "state_dict": model.state_dict(),
+                "best_acc1": best_acc1,
+                "optimizer": optimizer.state_dict(),
+            },
+            is_best,
+            filename=args.checkpoint_file + ".tar",
+        )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/functorch/examples/dp_cifar10/cifar10_transforms.py.~3~ b/functorch/examples/dp_cifar10/cifar10_transforms.py.~3~
new file mode 100644
index 0000000..7d72ec3
--- /dev/null
+++ b/functorch/examples/dp_cifar10/cifar10_transforms.py.~3~
@@ -0,0 +1,457 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+"""
+Runs CIFAR10 training with differential privacy.
+"""
+
+import argparse
+import os
+import shutil
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.utils.data
+import torch.utils.data.distributed
+import torch.utils.tensorboard as tensorboard
+import torchvision.models as models
+import torchvision.transforms as transforms
+from opacus import PrivacyEngine
+from opacus.utils import stats
+from opacus.utils.module_modification import convert_batchnorm_modules
+from torchvision.datasets import CIFAR10
+from tqdm import tqdm
+
+from torch import vmap
+from make_functional import make_functional, load_weights
+from functional_utils import grad, grad_with_value
+from functools import partial
+# from resnet import resnet18
+
+def save_checkpoint(state, is_best, filename="checkpoint.tar"):
+    torch.save(state, filename)
+    if is_best:
+        shutil.copyfile(filename, "model_best.pth.tar")
+
+
+def accuracy(preds, labels):
+    return (preds == labels).mean()
+
+
+def compute_norms(sample_grads):
+    batch_size = sample_grads[0].shape[0]
+    norms = [sample_grad.view(batch_size, -1).norm(2, dim=-1) for sample_grad in sample_grads]
+    norms = torch.stack(norms, dim=0).norm(2, dim=0)
+    return norms
+
+def clip_and_accumulate_and_add_noise(sample_grads, max_per_sample_grad_norm=1.0, noise_multiplier=1.0):
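+    # Given per-sample gradients (each tensor has a leading batch dimension),
+    # perform the core DP-SGD step: clip every sample's gradient to
+    # max_per_sample_grad_norm, sum the clipped gradients over the batch,
+    # and add Gaussian noise.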
+    # step 0: compute the norms
+    sample_norms = compute_norms(sample_grads)
+
+    # step 1: compute clipping factors
+    clip_factor = max_per_sample_grad_norm / (sample_norms + 1e-6)
+    clip_factor = clip_factor.clamp(max=1.0)
+
+    # step 2: clip
+    grads = tuple(torch.einsum('i,i...', clip_factor, sample_grad)
+                  for sample_grad in sample_grads)
+
+    # step 3: add gaussian noise
+    stddev = max_per_sample_grad_norm * noise_multiplier
+    noises = tuple(torch.normal(0, stddev, grad_param.shape, device=grad_param.device)
+                   for grad_param in grads)
+    grads = tuple(noise + grad_param for noise, grad_param in zip(noises, grads))
+
+    return grads
+
+def train(args, model, train_loader, optimizer, epoch, device):
+    use_prototype = False
+    model.train()
+    criterion = nn.CrossEntropyLoss()
+
+    losses = []
+    top1_acc = []
+
+    for i, (images, target) in enumerate(tqdm(train_loader)):
+
+        images = images.to(device)
+        target = target.to(device)
+
+        # Step 1: compute per-sample-grads
+        weights, func_model, descriptors = make_functional(model)
+
+        def compute_loss_and_output(weights, image, target):
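+            # vmap hands this function a single example (no batch dimension),
+            # so re-add a length-1 batch dimension that the model and the
+            # loss function expect.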
+            images = image.unsqueeze(0)
+            targets = target.unsqueeze(0)
+            output = func_model(weights, (images,))
+            loss = criterion(output, targets)
+            return loss, output.squeeze(0)
+
+        # grad_with_value(f) returns a function that computes (1) the gradient and
+        # (2) the value of `f`. `has_aux=True` means that `f` returns a tuple of two
+        # values: only the first is differentiated; the second is passed through
+        # unchanged as an additional (third) output.
+        #
+        # We need `grad_with_value(..., has_aux=True)` because we do some
+        # analyses on both the returned loss and the output.
+        grads_loss_output = grad_with_value(compute_loss_and_output, has_aux=True)
+        sample_grads, sample_loss, output = vmap(partial(grads_loss_output, weights))(images, target)
+        loss = sample_loss.mean()
+
+        # Step 2: Clip the per-sample-grads, sum them to form grads, and add noise
+        grads = clip_and_accumulate_and_add_noise(
+            sample_grads, args.max_per_sample_grad_norm, args.sigma)
+
+        # `load_weights` is the inverse operation of make_functional. We put
+        # things back into a model so that we can directly apply optimizers.
+        # TODO(rzou): this might not be necessary, optimizers just take
+        # the params straight up.
+        load_weights(model, descriptors, weights)
+
+        for weight_grad, weight in zip(grads, model.parameters()):
+            weight.grad = weight_grad.detach()
+
+        preds = np.argmax(output.detach().cpu().numpy(), axis=1)
+        labels = target.detach().cpu().numpy()
+        losses.append(loss.item())
+
+        # measure accuracy and record loss
+        acc1 = accuracy(preds, labels)
+
+        top1_acc.append(acc1)
+        stats.update(stats.StatType.TRAIN, acc1=acc1)
+
+        # make sure we take a step after processing the last mini-batch in the
+        # epoch to ensure we start the next epoch with a clean state
+        if ((i + 1) % args.n_accumulation_steps == 0) or ((i + 1) == len(train_loader)):
+            optimizer.step()
+            optimizer.zero_grad()
+        else:
+            optimizer.virtual_step()
+
+        if i % args.print_freq == 0:
+            print(
+                f"\tTrain Epoch: {epoch} \t"
+                f"Loss: {np.mean(losses):.6f} "
+                f"Acc@1: {np.mean(top1_acc):.6f} "
+            )
+
+def test(args, model, test_loader, device):
+    model.eval()
+    criterion = nn.CrossEntropyLoss()
+    losses = []
+    top1_acc = []
+
+    with torch.no_grad():
+        for images, target in tqdm(test_loader):
+            images = images.to(device)
+            target = target.to(device)
+
+            output = model(images)
+            loss = criterion(output, target)
+            preds = np.argmax(output.detach().cpu().numpy(), axis=1)
+            labels = target.detach().cpu().numpy()
+            acc1 = accuracy(preds, labels)
+
+            losses.append(loss.item())
+            top1_acc.append(acc1)
+
+    top1_avg = np.mean(top1_acc)
+    stats.update(stats.StatType.TEST, acc1=top1_avg)
+
+    print(f"\tTest set:" f"Loss: {np.mean(losses):.6f} " f"Acc@1: {top1_avg :.6f} ")
+    return np.mean(top1_acc)
+
+
+def main():
+    parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
+    parser.add_argument(
+        "-j",
+        "--workers",
+        default=2,
+        type=int,
+        metavar="N",
+        help="number of data loading workers (default: 2)",
+    )
+    parser.add_argument(
+        "--epochs",
+        default=90,
+        type=int,
+        metavar="N",
+        help="number of total epochs to run",
+    )
+    parser.add_argument(
+        "--start-epoch",
+        default=1,
+        type=int,
+        metavar="N",
+        help="manual epoch number (useful on restarts)",
+    )
+    parser.add_argument(
+        "-b",
+        "--batch-size",
+        # This should be 256, but that OOMs using the prototype.
+        default=64,
+        type=int,
+        metavar="N",
+        help="mini-batch size (default: 64), this is the total "
+        "batch size of all GPUs on the current node when "
+        "using Data Parallel or Distributed Data Parallel",
+    )
+    parser.add_argument(
+        "-na",
+        "--n_accumulation_steps",
+        default=1,
+        type=int,
+        metavar="N",
+        help="number of mini-batches to accumulate into an effective batch",
+    )
+    parser.add_argument(
+        "--lr",
+        "--learning-rate",
+        default=0.001,
+        type=float,
+        metavar="LR",
+        help="initial learning rate",
+        dest="lr",
+    )
+    parser.add_argument(
+        "--momentum", default=0.9, type=float, metavar="M", help="SGD momentum"
+    )
+    parser.add_argument(
+        "--wd",
+        "--weight-decay",
+        default=5e-4,
+        type=float,
+        metavar="W",
+        help="SGD weight decay (default: 1e-4)",
+        dest="weight_decay",
+    )
+    parser.add_argument(
+        "-p",
+        "--print-freq",
+        default=10,
+        type=int,
+        metavar="N",
+        help="print frequency (default: 10)",
+    )
+    parser.add_argument(
+        "--resume",
+        default="",
+        type=str,
+        metavar="PATH",
+        help="path to latest checkpoint (default: none)",
+    )
+    parser.add_argument(
+        "-e",
+        "--evaluate",
+        dest="evaluate",
+        action="store_true",
+        help="evaluate model on validation set",
+    )
+    parser.add_argument(
+        "--seed", default=None, type=int, help="seed for initializing training. "
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda",
+        help="GPU ID for this process (default: 'cuda')",
+    )
+    parser.add_argument(
+        "--sigma",
+        type=float,
+        default=1.0,
+        metavar="S",
+        help="Noise multiplier (default 1.0)",
+    )
+    parser.add_argument(
+        "-c",
+        "--max-per-sample-grad_norm",
+        type=float,
+        default=1.0,
+        metavar="C",
+        help="Clip per-sample gradients to this norm (default 1.0)",
+    )
+    parser.add_argument(
+        "--disable-dp",
+        action="store_true",
+        default=False,
+        help="Disable privacy training and just train with vanilla SGD",
+    )
+    parser.add_argument(
+        "--secure-rng",
+        action="store_true",
+        default=False,
+        help="Enable Secure RNG to have trustworthy privacy guarantees. Comes at a performance cost",
+    )
+    parser.add_argument(
+        "--delta",
+        type=float,
+        default=1e-5,
+        metavar="D",
+        help="Target delta (default: 1e-5)",
+    )
+
+    parser.add_argument(
+        "--checkpoint-file",
+        type=str,
+        default="checkpoint",
+        help="path to save check points",
+    )
+    parser.add_argument(
+        "--data-root",
+        type=str,
+        default="../cifar10",
+        help="Where CIFAR10 is/will be stored",
+    )
+    parser.add_argument(
+        "--log-dir", type=str, default="", help="Where Tensorboard log will be stored"
+    )
+    parser.add_argument(
+        "--optim",
+        type=str,
+        default="Adam",
+        help="Optimizer to use (Adam, RMSprop, SGD)",
+    )
+
+    args = parser.parse_args()
+    args.disable_dp = True
+
+    if args.disable_dp and args.n_accumulation_steps > 1:
+        raise ValueError("Virtual steps only works with enabled DP")
+
+    # The following few lines, enable stats gathering about the run
+    # 1. where the stats should be logged
+    stats.set_global_summary_writer(
+        tensorboard.SummaryWriter(os.path.join("/tmp/stat", args.log_dir))
+    )
+    # 2. enable stats
+    stats.add(
+        # stats about gradient norms aggregated for all layers
+        stats.Stat(stats.StatType.GRAD, "AllLayers", frequency=0.1),
+        # stats about gradient norms per layer
+        stats.Stat(stats.StatType.GRAD, "PerLayer", frequency=0.1),
+        # stats about clipping
+        stats.Stat(stats.StatType.GRAD, "ClippingStats", frequency=0.1),
+        # stats on training accuracy
+        stats.Stat(stats.StatType.TRAIN, "accuracy", frequency=0.01),
+        # stats on validation accuracy
+        stats.Stat(stats.StatType.TEST, "accuracy"),
+    )
+
+    # The following lines enable stat gathering for the clipping process
+    # and use flat (rather than per-layer) clipping for the Privacy Engine
+    clipping = {"clip_per_layer": False, "enable_stat": True}
+
+    if args.secure_rng:
+        assert False
+        try:
+            import torchcsprng as prng
+        except ImportError as e:
+            msg = (
+                "To use secure RNG, you must install the torchcsprng package! "
+                "Check out the instructions here: https://github.com/pytorch/csprng#installation"
+            )
+            raise ImportError(msg) from e
+
+        generator = prng.create_random_device_generator("/dev/urandom")
+
+    else:
+        generator = None
+
+    augmentations = [
+        transforms.RandomCrop(32, padding=4),
+        transforms.RandomHorizontalFlip(),
+    ]
+    normalize = [
+        transforms.ToTensor(),
+        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+    ]
+    train_transform = transforms.Compose(
+        augmentations + normalize if args.disable_dp else normalize
+    )
+
+    test_transform = transforms.Compose(normalize)
+
+    train_dataset = CIFAR10(
+        root=args.data_root, train=True, download=True, transform=train_transform
+    )
+
+    train_loader = torch.utils.data.DataLoader(
+        train_dataset,
+        batch_size=args.batch_size,
+        shuffle=True,
+        num_workers=args.workers,
+        drop_last=True,
+        generator=generator,
+    )
+
+    test_dataset = CIFAR10(
+        root=args.data_root, train=False, download=True, transform=test_transform
+    )
+    test_loader = torch.utils.data.DataLoader(
+        test_dataset,
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.workers,
+    )
+
+    best_acc1 = 0
+    device = torch.device(args.device)
+    model = convert_batchnorm_modules(models.resnet18(num_classes=10))
+    # model = CIFAR10Model()
+    model = model.to(device)
+
+    if args.optim == "SGD":
+        optimizer = optim.SGD(
+            model.parameters(),
+            lr=args.lr,
+            momentum=args.momentum,
+            weight_decay=args.weight_decay,
+        )
+    elif args.optim == "RMSprop":
+        optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
+    elif args.optim == "Adam":
+        optimizer = optim.Adam(model.parameters(), lr=args.lr)
+    else:
+        raise NotImplementedError("Optimizer not recognized. Please check spelling")
+
+    if not args.disable_dp:
+        privacy_engine = PrivacyEngine(
+            model,
+            batch_size=args.batch_size * args.n_accumulation_steps,
+            sample_size=len(train_dataset),
+            alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
+            noise_multiplier=args.sigma,
+            max_grad_norm=args.max_per_sample_grad_norm,
+            secure_rng=args.secure_rng,
+            **clipping,
+        )
+        privacy_engine.attach(optimizer)
+
+    for epoch in range(args.start_epoch, args.epochs + 1):
+        train(args, model, train_loader, optimizer, epoch, device)
+        top1_acc = test(args, model, test_loader, device)
+
+        # remember best acc@1 and save checkpoint
+        is_best = top1_acc > best_acc1
+        best_acc1 = max(top1_acc, best_acc1)
+
+        save_checkpoint(
+            {
+                "epoch": epoch + 1,
+                "arch": "ResNet18",
+                "state_dict": model.state_dict(),
+                "best_acc1": best_acc1,
+                "optimizer": optimizer.state_dict(),
+            },
+            is_best,
+            filename=args.checkpoint_file + ".tar",
+        )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/functorch/examples/dp_cifar10/cifar10_transforms.py.~4~ b/functorch/examples/dp_cifar10/cifar10_transforms.py.~4~
new file mode 100644
index 0000000..21c5c9a
--- /dev/null
+++ b/functorch/examples/dp_cifar10/cifar10_transforms.py.~4~
@@ -0,0 +1,456 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+"""
+Runs CIFAR10 training with differential privacy.
+"""
+
+import argparse
+import os
+import shutil
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.utils.data
+import torch.utils.data.distributed
+import torch.utils.tensorboard as tensorboard
+import torchvision.models as models
+import torchvision.transforms as transforms
+from opacus import PrivacyEngine
+from opacus.utils import stats
+from opacus.utils.module_modification import convert_batchnorm_modules
+from torchvision.datasets import CIFAR10
+from tqdm import tqdm
+
+from torch import vmap
+from make_functional import make_functional, load_weights
+from functional_utils import grad, grad_with_value
+from functools import partial
+# from resnet import resnet18
+
+def save_checkpoint(state, is_best, filename="checkpoint.tar"):
+    torch.save(state, filename)
+    if is_best:
+        shutil.copyfile(filename, "model_best.pth.tar")
+
+
+def accuracy(preds, labels):
+    return (preds == labels).mean()
+
+
+def compute_norms(sample_grads):
+    batch_size = sample_grads[0].shape[0]
+    norms = [sample_grad.view(batch_size, -1).norm(2, dim=-1) for sample_grad in sample_grads]
+    norms = torch.stack(norms, dim=0).norm(2, dim=0)
+    return norms
+
+def clip_and_accumulate_and_add_noise(sample_grads, max_per_sample_grad_norm=1.0, noise_multiplier=1.0):
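+    # Given per-sample gradients (each tensor has a leading batch dimension),
+    # perform the core DP-SGD step: clip every sample's gradient to
+    # max_per_sample_grad_norm, sum the clipped gradients over the batch,
+    # and add Gaussian noise.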
+    # step 0: compute the norms
+    sample_norms = compute_norms(sample_grads)
+
+    # step 1: compute clipping factors
+    clip_factor = max_per_sample_grad_norm / (sample_norms + 1e-6)
+    clip_factor = clip_factor.clamp(max=1.0)
+
+    # step 2: clip
+    grads = tuple(torch.einsum('i,i...', clip_factor, sample_grad)
+                  for sample_grad in sample_grads)
+
+    # step 3: add gaussian noise
+    stddev = max_per_sample_grad_norm * noise_multiplier
+    noises = tuple(torch.normal(0, stddev, grad_param.shape, device=grad_param.device)
+                   for grad_param in grads)
+    grads = tuple(noise + grad_param for noise, grad_param in zip(noises, grads))
+
+    return grads
+
+def train(args, model, train_loader, optimizer, epoch, device):
+    model.train()
+    criterion = nn.CrossEntropyLoss()
+
+    losses = []
+    top1_acc = []
+
+    for i, (images, target) in enumerate(tqdm(train_loader)):
+
+        images = images.to(device)
+        target = target.to(device)
+
+        # Step 1: compute per-sample-grads
+        weights, func_model, descriptors = make_functional(model)
+
+        def compute_loss_and_output(weights, image, target):
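+            # vmap hands this function a single example (no batch dimension),
+            # so re-add a length-1 batch dimension that the model and the
+            # loss function expect.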
+            images = image.unsqueeze(0)
+            targets = target.unsqueeze(0)
+            output = func_model(weights, (images,))
+            loss = criterion(output, targets)
+            return loss, output.squeeze(0)
+
+        # grad_with_value(f) returns a function that computes (1) the gradient and
+        # (2) the value of `f`. `has_aux=True` means that `f` returns a tuple of two
+        # values: only the first is differentiated; the second is passed through
+        # unchanged as an additional (third) output.
+        #
+        # We need `grad_with_value(..., has_aux=True)` because we do some
+        # analyses on both the returned loss and the output.
+        grads_loss_output = grad_with_value(compute_loss_and_output, has_aux=True)
+        sample_grads, sample_loss, output = vmap(partial(grads_loss_output, weights))(images, target)
+        loss = sample_loss.mean()
+
+        # Step 2: Clip the per-sample-grads, sum them to form grads, and add noise
+        grads = clip_and_accumulate_and_add_noise(
+            sample_grads, args.max_per_sample_grad_norm, args.sigma)
+
+        # `load_weights` is the inverse operation of make_functional. We put
+        # things back into a model so that we can directly apply optimizers.
+        # TODO(rzou): this might not be necessary, optimizers just take
+        # the params straight up.
+        load_weights(model, descriptors, weights)
+
+        for weight_grad, weight in zip(grads, model.parameters()):
+            weight.grad = weight_grad.detach()
+
+        preds = np.argmax(output.detach().cpu().numpy(), axis=1)
+        labels = target.detach().cpu().numpy()
+        losses.append(loss.item())
+
+        # measure accuracy and record loss
+        acc1 = accuracy(preds, labels)
+
+        top1_acc.append(acc1)
+        stats.update(stats.StatType.TRAIN, acc1=acc1)
+
+        # make sure we take a step after processing the last mini-batch in the
+        # epoch to ensure we start the next epoch with a clean state
+        if ((i + 1) % args.n_accumulation_steps == 0) or ((i + 1) == len(train_loader)):
+            optimizer.step()
+            optimizer.zero_grad()
+        else:
+            optimizer.virtual_step()
+
+        if i % args.print_freq == 0:
+            print(
+                f"\tTrain Epoch: {epoch} \t"
+                f"Loss: {np.mean(losses):.6f} "
+                f"Acc@1: {np.mean(top1_acc):.6f} "
+            )
+
+def test(args, model, test_loader, device):
+    model.eval()
+    criterion = nn.CrossEntropyLoss()
+    losses = []
+    top1_acc = []
+
+    with torch.no_grad():
+        for images, target in tqdm(test_loader):
+            images = images.to(device)
+            target = target.to(device)
+
+            output = model(images)
+            loss = criterion(output, target)
+            preds = np.argmax(output.detach().cpu().numpy(), axis=1)
+            labels = target.detach().cpu().numpy()
+            acc1 = accuracy(preds, labels)
+
+            losses.append(loss.item())
+            top1_acc.append(acc1)
+
+    top1_avg = np.mean(top1_acc)
+    stats.update(stats.StatType.TEST, acc1=top1_avg)
+
+    print(f"\tTest set:" f"Loss: {np.mean(losses):.6f} " f"Acc@1: {top1_avg :.6f} ")
+    return np.mean(top1_acc)
+
+
+def main():
+    parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
+    parser.add_argument(
+        "-j",
+        "--workers",
+        default=2,
+        type=int,
+        metavar="N",
+        help="number of data loading workers (default: 2)",
+    )
+    parser.add_argument(
+        "--epochs",
+        default=90,
+        type=int,
+        metavar="N",
+        help="number of total epochs to run",
+    )
+    parser.add_argument(
+        "--start-epoch",
+        default=1,
+        type=int,
+        metavar="N",
+        help="manual epoch number (useful on restarts)",
+    )
+    parser.add_argument(
+        "-b",
+        "--batch-size",
+        # This should be 256, but that OOMs using the prototype.
+        default=64,
+        type=int,
+        metavar="N",
+        help="mini-batch size (default: 64), this is the total "
+        "batch size of all GPUs on the current node when "
+        "using Data Parallel or Distributed Data Parallel",
+    )
+    parser.add_argument(
+        "-na",
+        "--n_accumulation_steps",
+        default=1,
+        type=int,
+        metavar="N",
+        help="number of mini-batches to accumulate into an effective batch",
+    )
+    parser.add_argument(
+        "--lr",
+        "--learning-rate",
+        default=0.001,
+        type=float,
+        metavar="LR",
+        help="initial learning rate",
+        dest="lr",
+    )
+    parser.add_argument(
+        "--momentum", default=0.9, type=float, metavar="M", help="SGD momentum"
+    )
+    parser.add_argument(
+        "--wd",
+        "--weight-decay",
+        default=5e-4,
+        type=float,
+        metavar="W",
+        help="SGD weight decay (default: 1e-4)",
+        dest="weight_decay",
+    )
+    parser.add_argument(
+        "-p",
+        "--print-freq",
+        default=10,
+        type=int,
+        metavar="N",
+        help="print frequency (default: 10)",
+    )
+    parser.add_argument(
+        "--resume",
+        default="",
+        type=str,
+        metavar="PATH",
+        help="path to latest checkpoint (default: none)",
+    )
+    parser.add_argument(
+        "-e",
+        "--evaluate",
+        dest="evaluate",
+        action="store_true",
+        help="evaluate model on validation set",
+    )
+    parser.add_argument(
+        "--seed", default=None, type=int, help="seed for initializing training. "
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda",
+        help="GPU ID for this process (default: 'cuda')",
+    )
+    parser.add_argument(
+        "--sigma",
+        type=float,
+        default=1.0,
+        metavar="S",
+        help="Noise multiplier (default 1.0)",
+    )
+    parser.add_argument(
+        "-c",
+        "--max-per-sample-grad_norm",
+        type=float,
+        default=1.0,
+        metavar="C",
+        help="Clip per-sample gradients to this norm (default 1.0)",
+    )
+    parser.add_argument(
+        "--disable-dp",
+        action="store_true",
+        default=False,
+        help="Disable privacy training and just train with vanilla SGD",
+    )
+    parser.add_argument(
+        "--secure-rng",
+        action="store_true",
+        default=False,
+        help="Enable Secure RNG to have trustworthy privacy guarantees. Comes at a performance cost",
+    )
+    parser.add_argument(
+        "--delta",
+        type=float,
+        default=1e-5,
+        metavar="D",
+        help="Target delta (default: 1e-5)",
+    )
+
+    parser.add_argument(
+        "--checkpoint-file",
+        type=str,
+        default="checkpoint",
+        help="path to save check points",
+    )
+    parser.add_argument(
+        "--data-root",
+        type=str,
+        default="../cifar10",
+        help="Where CIFAR10 is/will be stored",
+    )
+    parser.add_argument(
+        "--log-dir", type=str, default="", help="Where Tensorboard log will be stored"
+    )
+    parser.add_argument(
+        "--optim",
+        type=str,
+        default="Adam",
+        help="Optimizer to use (Adam, RMSprop, SGD)",
+    )
+
+    args = parser.parse_args()
+    args.disable_dp = True
+
+    if args.disable_dp and args.n_accumulation_steps > 1:
+        raise ValueError("Virtual steps only works with enabled DP")
+
+    # The following few lines, enable stats gathering about the run
+    # 1. where the stats should be logged
+    stats.set_global_summary_writer(
+        tensorboard.SummaryWriter(os.path.join("/tmp/stat", args.log_dir))
+    )
+    # 2. enable stats
+    stats.add(
+        # stats about gradient norms aggregated for all layers
+        stats.Stat(stats.StatType.GRAD, "AllLayers", frequency=0.1),
+        # stats about gradient norms per layer
+        stats.Stat(stats.StatType.GRAD, "PerLayer", frequency=0.1),
+        # stats about clipping
+        stats.Stat(stats.StatType.GRAD, "ClippingStats", frequency=0.1),
+        # stats on training accuracy
+        stats.Stat(stats.StatType.TRAIN, "accuracy", frequency=0.01),
+        # stats on validation accuracy
+        stats.Stat(stats.StatType.TEST, "accuracy"),
+    )
+
+    # The following lines enable stat gathering for the clipping process
+    # and use flat (rather than per-layer) clipping for the Privacy Engine
+    clipping = {"clip_per_layer": False, "enable_stat": True}
+
+    if args.secure_rng:
+        assert False
+        try:
+            import torchcsprng as prng
+        except ImportError as e:
+            msg = (
+                "To use secure RNG, you must install the torchcsprng package! "
+                "Check out the instructions here: https://github.com/pytorch/csprng#installation"
+            )
+            raise ImportError(msg) from e
+
+        generator = prng.create_random_device_generator("/dev/urandom")
+
+    else:
+        generator = None
+
+    augmentations = [
+        transforms.RandomCrop(32, padding=4),
+        transforms.RandomHorizontalFlip(),
+    ]
+    normalize = [
+        transforms.ToTensor(),
+        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+    ]
+    train_transform = transforms.Compose(
+        augmentations + normalize if args.disable_dp else normalize
+    )
+
+    test_transform = transforms.Compose(normalize)
+
+    train_dataset = CIFAR10(
+        root=args.data_root, train=True, download=True, transform=train_transform
+    )
+
+    train_loader = torch.utils.data.DataLoader(
+        train_dataset,
+        batch_size=args.batch_size,
+        shuffle=True,
+        num_workers=args.workers,
+        drop_last=True,
+        generator=generator,
+    )
+
+    test_dataset = CIFAR10(
+        root=args.data_root, train=False, download=True, transform=test_transform
+    )
+    test_loader = torch.utils.data.DataLoader(
+        test_dataset,
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.workers,
+    )
+
+    best_acc1 = 0
+    device = torch.device(args.device)
+    model = convert_batchnorm_modules(models.resnet18(num_classes=10))
+    # model = CIFAR10Model()
+    model = model.to(device)
+
+    if args.optim == "SGD":
+        optimizer = optim.SGD(
+            model.parameters(),
+            lr=args.lr,
+            momentum=args.momentum,
+            weight_decay=args.weight_decay,
+        )
+    elif args.optim == "RMSprop":
+        optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
+    elif args.optim == "Adam":
+        optimizer = optim.Adam(model.parameters(), lr=args.lr)
+    else:
+        raise NotImplementedError("Optimizer not recognized. Please check spelling")
+
+    if not args.disable_dp:
+        privacy_engine = PrivacyEngine(
+            model,
+            batch_size=args.batch_size * args.n_accumulation_steps,
+            sample_size=len(train_dataset),
+            alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
+            noise_multiplier=args.sigma,
+            max_grad_norm=args.max_per_sample_grad_norm,
+            secure_rng=args.secure_rng,
+            **clipping,
+        )
+        privacy_engine.attach(optimizer)
+
+    for epoch in range(args.start_epoch, args.epochs + 1):
+        train(args, model, train_loader, optimizer, epoch, device)
+        top1_acc = test(args, model, test_loader, device)
+
+        # remember best acc@1 and save checkpoint
+        is_best = top1_acc > best_acc1
+        best_acc1 = max(top1_acc, best_acc1)
+
+        save_checkpoint(
+            {
+                "epoch": epoch + 1,
+                "arch": "ResNet18",
+                "state_dict": model.state_dict(),
+                "best_acc1": best_acc1,
+                "optimizer": optimizer.state_dict(),
+            },
+            is_best,
+            filename=args.checkpoint_file + ".tar",
+        )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/functorch/examples/dp_cifar10/make_functional.py b/functorch/examples/dp_cifar10/make_functional.py
new file mode 100644
index 0000000..0304c7f
--- /dev/null
+++ b/functorch/examples/dp_cifar10/make_functional.py
@@ -0,0 +1,71 @@
+import torch
+import torch.nn as nn
+from torch import Tensor
+from typing import List, Tuple
+import copy
+
+# Utilities to make nn.Module "functional"
+# In particular, the goal is to be able to provide a function that takes the
+# parameters as input and evaluates the nn.Module on fixed inputs.
+def _del_nested_attr(obj: nn.Module, names: List[str]) -> None:
+    """
+    Deletes the attribute specified by the given list of names.
+    For example, to delete the attribute obj.conv.weight,
+    use _del_nested_attr(obj, ['conv', 'weight'])
+    """
+    if len(names) == 1:
+        delattr(obj, names[0])
+    else:
+        _del_nested_attr(getattr(obj, names[0]), names[1:])
+
+def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) -> None:
+    """
+    Set the attribute specified by the given list of names to value.
+    For example, to set the attribute obj.conv.weight,
+    use _set_nested_attr(obj, ['conv', 'weight'], value)
+    """
+    if len(names) == 1:
+        setattr(obj, names[0], value)
+    else:
+        _set_nested_attr(getattr(obj, names[0]), names[1:], value)
+
+def extract_weights(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]:
+    """
+    This function removes all the Parameters from the model and
+    returns them as a tuple, along with their original attribute names.
+    The weights must be re-loaded with `load_weights` before the model
+    can be used again.
+    Note that this function modifies the model in place and after this
+    call, mod.parameters() will be empty.
+    """
+    orig_params = tuple(mod.parameters())
+    # Remove all the parameters in the model
+    names = []
+    for name, p in list(mod.named_parameters()):
+        _del_nested_attr(mod, name.split("."))
+        names.append(name)
+
+    # NB: the parameters are returned as-is (still nn.Parameter); they are not
+    # detached into plain Tensors here.
+    params = tuple(p for p in orig_params)
+    return params, names
+
+def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...], as_params=False) -> None:
+    """
+    Reload a set of weights so that `mod` can be used again to perform a forward pass.
+    Note that the `params` are regular Tensors (that can have history) and so are left
+    as Tensors. This means that mod.parameters() will still be empty after this call.
+    """
+    for name, p in zip(names, params):
+        if as_params:
+            p = nn.Parameter(p)
+        _set_nested_attr(mod, name.split("."), p)
+
+def make_functional(model: nn.Module):
+    weights, descriptors = extract_weights(model)
+
+    def fun(weights, data):
+        mutable_model = copy.deepcopy(model)
+        load_weights(mutable_model, descriptors, weights)
+        return mutable_model(*data)
+
+    return weights, fun, descriptors
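+
+
+if __name__ == "__main__":
+    # A minimal smoke test (a sketch; the Linear module below is arbitrary):
+    # make a module functional, call the stateless function with the extracted
+    # weights, and check the output shape.
+    linear = nn.Linear(3, 2)
+    weights, func_model, names = make_functional(linear)
+    out = func_model(weights, (torch.randn(4, 3),))
+    assert out.shape == (4, 2)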
diff --git a/functorch/examples/dp_cifar10/make_functional.py.~1~ b/functorch/examples/dp_cifar10/make_functional.py.~1~
new file mode 100644
index 0000000..c231a7b
--- /dev/null
+++ b/functorch/examples/dp_cifar10/make_functional.py.~1~
@@ -0,0 +1,71 @@
+import torch
+import torch.nn as nn
+from torch import Tensor
+from typing import List, Tuple
+import copy
+
+# Utilities to make nn.Module "functional"
+# In particular, the goal is to be able to provide a function that takes the
+# parameters as input and evaluates the nn.Module on fixed inputs.
+def _del_nested_attr(obj: nn.Module, names: List[str]) -> None:
+    """
+    Deletes the attribute specified by the given list of names.
+    For example, to delete the attribute obj.conv.weight,
+    use _del_nested_attr(obj, ['conv', 'weight'])
+    """
+    if len(names) == 1:
+        delattr(obj, names[0])
+    else:
+        _del_nested_attr(getattr(obj, names[0]), names[1:])
+
+def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) -> None:
+    """
+    Set the attribute specified by the given list of names to value.
+    For example, to set the attribute obj.conv.weight,
+    use _set_nested_attr(obj, ['conv', 'weight'], value)
+    """
+    if len(names) == 1:
+        setattr(obj, names[0], value)
+    else:
+        _set_nested_attr(getattr(obj, names[0]), names[1:], value)
+
+def extract_weights(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]:
+    """
+    This function removes all the Parameters from the model and
+    returns them as a tuple, along with their original attribute names.
+    The weights must be re-loaded with `load_weights` before the model
+    can be used again.
+    Note that this function modifies the model in place and after this
+    call, mod.parameters() will be empty.
+    """
+    orig_params = tuple(mod.parameters())
+    # Remove all the parameters in the model
+    names = []
+    for name, p in list(mod.named_parameters()):
+        _del_nested_attr(mod, name.split("."))
+        names.append(name)
+
+    # Make params regular Tensors instead of nn.Parameter
+    params = tuple(p.detach().requires_grad_() for p in orig_params)
+    return params, names
+
+def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...], as_params=False) -> None:
+    """
+    Reload a set of weights so that `mod` can be used again to perform a forward pass.
+    Note that the `params` are regular Tensors (that can have history) and so are left
+    as Tensors. This means that mod.parameters() will still be empty after this call.
+    """
+    for name, p in zip(names, params):
+        if as_params:
+            p = nn.Parameter(p)
+        _set_nested_attr(mod, name.split("."), p)
+
+def make_functional(model: nn.Module):
+    weights, descriptors = extract_weights(model)
+
+    def fun(weights, data):
+        mutable_model = copy.deepcopy(model)
+        load_weights(mutable_model, descriptors, weights)
+        return mutable_model(*data)
+
+    return weights, fun, descriptors
diff --git a/functorch/examples/ensembling/parallel_train.py b/functorch/examples/ensembling/parallel_train.py
new file mode 100644
index 0000000..fd95afa
--- /dev/null
+++ b/functorch/examples/ensembling/parallel_train.py
@@ -0,0 +1,122 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from functorch import make_functional, grad_with_value, vmap
+
+# Adapted from http://willwhitney.com/parallel-training-jax.html
+# GOAL: Demonstrate that it is possible to use eager-mode vmap
+# to parallelize training over models.
+
+# NB: this code runs off of a branch on zou3519/pytorch:dynlayer
+
+DEVICE = 'cpu'
+
+# Step 1: Make some spirals
+def make_spirals(n_samples, noise_std=0., rotations=1.):
+    ts = torch.linspace(0, 1, n_samples, device=DEVICE)
+    rs = ts ** 0.5
+    thetas = rs * rotations * 2 * math.pi
+    signs = torch.randint(0, 2, (n_samples,), device=DEVICE) * 2 - 1
+    labels = (signs > 0).to(torch.long)
+
+    xs = rs * signs * torch.cos(thetas) + torch.randn(n_samples, device=DEVICE) * noise_std
+    ys = rs * signs * torch.sin(thetas) + torch.randn(n_samples, device=DEVICE) * noise_std
+    points = torch.stack([xs, ys], dim=1)
+    return points, labels
+
+points, labels = make_spirals(100, noise_std=0.05)
+
+
+# Step 2: Define two-layer MLP and loss function
+class MLPClassifier(nn.Module):
+    def __init__(self, hidden_dim=32, n_classes=2):
+        super().__init__()
+        self.hidden_dim = hidden_dim
+        self.n_classes = n_classes
+
+        self.fc1 = nn.Linear(2, self.hidden_dim)
+        self.fc2 = nn.Linear(self.hidden_dim, self.n_classes)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = F.relu(x)
+        x = self.fc2(x)
+        x = F.log_softmax(x, -1)
+        return x
+
+loss_fn = nn.NLLLoss()
+
+# Step 3: Make the model functional(!!) and define a training function.
+# NB: this mechanism doesn't exist in PyTorch today, but we want it to:
+# https://github.com/pytorch/pytorch/issues/49171
+weights, func_model, _ = make_functional(MLPClassifier().to(DEVICE))
+
+def train_step_fn(weights, batch, targets, lr=0.2):
+    def compute_loss(weights, batch, targets):
+        output = func_model(weights, (batch,))
+        loss = loss_fn(output, targets)
+        return loss
+
+    grad_weights, loss = grad_with_value(compute_loss)(weights, batch, targets)
+
+    # NB: PyTorch is missing a "functional optimizer API" (possibly coming soon)
+    # so we are going to re-implement SGD here.
+    new_weights = []
+    with torch.no_grad():
+        for grad_weight, weight in zip(grad_weights, weights):
+            new_weights.append(weight - grad_weight * lr)
+    # NB: return looks weird because torch.vmap must return Tensors
+    return (loss, *new_weights)
+
+
+def unpack(train_result):
+    return train_result[0], train_result[1:]
+
+# Step 4: Let's verify this actually trains.
+# We should see the loss decrease.
+def step4():
+    global weights
+    for i in range(2000):
+        loss, weights = unpack(train_step_fn(weights, points, labels))
+        if i % 100 == 0:
+            print(loss)
+
+step4()
+
+# Step 5: We're ready for multiple models. Let's define an init_fn
+# that, given a number of models, returns to us all of the weights.
+def init_fn(num_models):
+    models = tuple(MLPClassifier().to(DEVICE) for _ in range(num_models))
+    weights = tuple(make_functional(model)[0] for model in models)
+    weights = tuple(zip(*weights))
+    weights = tuple(torch.stack(shards).detach() for shards in weights)
+    return weights
+
+# Step 6: Now, can we try multiple models at the same time?
+# The answer is: yes! `loss` now holds one value per model, and we can see
+# that the values keep decreasing.
+def step6():
+    parallel_train_step_fn = vmap(train_step_fn, in_dims=(0, None, None))
+    batched_weights = init_fn(num_models=2)
+    for i in range(2000):
+        loss, batched_weights = unpack(parallel_train_step_fn(batched_weights, points, labels))
+        if i % 200 == 0:
+            print(loss)
+
+step6()
+
+# Step 7: Now, the flaw with step 6 is that we were training on exactly the
+# same data. This can lead to all of the models in the ensemble overfitting in
+# the same way. The solution that http://willwhitney.com/parallel-training-jax.html
+# applies is to randomly subset the data so that the models do not receive
+# exactly the same data in each training step!
+# Because the goal of this doc is to show that we can use eager-mode vmap to
+# achieve things similar to JAX, the rest of this is left as an exercise to the
+# reader; a sketch of one way to do it is included below.
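+def step7_sketch(num_models=2, batch_size=50, steps=2000):
+    # A sketch only (not called by default). `batch_size` and `steps` are
+    # illustrative choices, and this assumes the prototype vmap supports the
+    # advanced indexing used below. The idea: sample a different random
+    # minibatch per model and vmap over the data as well (in_dims=(0, 0, 0)).
+    parallel_train_step_fn = vmap(train_step_fn, in_dims=(0, 0, 0))
+    batched_weights = init_fn(num_models)
+    for i in range(steps):
+        idx = torch.randint(0, points.shape[0], (num_models, batch_size), device=DEVICE)
+        loss, batched_weights = unpack(
+            parallel_train_step_fn(batched_weights, points[idx], labels[idx]))
+        if i % 200 == 0:
+            print(loss)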
+
+# In conclusion, to achieve what http://willwhitney.com/parallel-training-jax.html
+# does, we used the following additional items that PyTorch does not have:
+# 1. NN module functional API that turns a module into a (state, state_less_fn) pair
+# 2. Functional optimizers
+# 3. A "functional" grad API (that effectively wraps autograd.grad)
+# 4. Composability between the functional grad API and torch.vmap.
diff --git a/functorch/examples/maml_omniglot/.gdbinit b/functorch/examples/maml_omniglot/.gdbinit
new file mode 100644
index 0000000..cf49729
--- /dev/null
+++ b/functorch/examples/maml_omniglot/.gdbinit
@@ -0,0 +1,2 @@
+catch throw
+r maml-omniglot-transforms.py
diff --git a/functorch/examples/maml_omniglot/.gitignore b/functorch/examples/maml_omniglot/.gitignore
new file mode 100644
index 0000000..1a2aff2
--- /dev/null
+++ b/functorch/examples/maml_omniglot/.gitignore
@@ -0,0 +1,3 @@
+omniglot/
+maml-accs.png
+
diff --git a/functorch/examples/maml_omniglot/README.md b/functorch/examples/maml_omniglot/README.md
new file mode 100644
index 0000000..06657f6
--- /dev/null
+++ b/functorch/examples/maml_omniglot/README.md
@@ -0,0 +1,17 @@
+# Omniglot MAML examples
+
+In this directory we've provided some examples of training Omniglot models that reproduce the experiments from [the original MAML paper](https://arxiv.org/abs/1703.03400).
+
+They can be run via `python {filename}`.
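+
+For example, a quick CPU run might look like the following (the scripts expose `--device` and `--task_num` flags, among others):
+
+```bash
+python maml-omniglot-higher.py --device cpu --task_num 8
+```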
+
+`maml-omniglot-higher.py` uses the [facebookresearch/higher](https://github.com/facebookresearch/higher) metalearning package and is the reference implementation. It runs all of its tasks sequentially.
+
+`maml-omniglot-transforms.py` uses an experimental vmap (and functional grad) prototype. It runs all of its tasks in parallel. In theory this should lead to some speedups, but we haven't finished implementing all the rules for vmap that would actually make training faster.
+
+`maml-omniglot-ptonly.py` is an implementation of `maml-omniglot-transforms.py` that runs all of its tasks sequentially (and also doesn't use the higher package).
+
+The prototype vmap used for these experiments currently runs off of a branch.
+We'd love some feedback on the prototype and encourage folks to try it out.
+It's a bit difficult to install, but here are some options:
+1. If you're on the FAIR cluster, we can share a path to a conda environment
+2. We are looking into building binaries using our branch and shipping them.
diff --git a/functorch/examples/maml_omniglot/maml-omniglot-higher.py b/functorch/examples/maml_omniglot/maml-omniglot-higher.py
new file mode 100755
index 0000000..c5ad2b6
--- /dev/null
+++ b/functorch/examples/maml_omniglot/maml-omniglot-higher.py
@@ -0,0 +1,277 @@
+#!/usr/bin/env python3
+#
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This example shows how to use higher to do Model Agnostic Meta Learning (MAML)
+for few-shot Omniglot classification.
+For more details see the original MAML paper:
+https://arxiv.org/abs/1703.03400
+
+This code has been modified from Jackie Loong's PyTorch MAML implementation:
+https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot_train.py
+
+Our MAML++ fork and experiments are available at:
+https://github.com/bamos/HowToTrainYourMAMLPytorch
+"""
+
+import argparse
+import time
+import typing
+
+import pandas as pd
+import numpy as np
+import matplotlib as mpl
+mpl.use('Agg')
+import matplotlib.pyplot as plt
+plt.style.use('bmh')
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+import torch.optim as optim
+
+import higher
+
+from support.omniglot_loaders import OmniglotNShot
+
+
+def main():
+    argparser = argparse.ArgumentParser()
+    argparser.add_argument('--n_way', type=int, help='n way', default=5)
+    argparser.add_argument(
+        '--k_spt', type=int, help='k shot for support set', default=5)
+    argparser.add_argument(
+        '--k_qry', type=int, help='k shot for query set', default=15)
+    argparser.add_argument(
+        '--device', type=str, help='device', default='cuda')
+    argparser.add_argument(
+        '--task_num',
+        type=int,
+        help='meta batch size, namely task num',
+        default=32)
+    argparser.add_argument('--seed', type=int, help='random seed', default=1)
+    args = argparser.parse_args()
+
+    torch.manual_seed(args.seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(args.seed)
+    np.random.seed(args.seed)
+
+    # Set up the Omniglot loader.
+    device = args.device
+    db = OmniglotNShot(
+        '/tmp/omniglot-data',
+        batchsz=args.task_num,
+        n_way=args.n_way,
+        k_shot=args.k_spt,
+        k_query=args.k_qry,
+        imgsz=28,
+        device=device,
+    )
+
+    # Create a vanilla PyTorch neural network that will be
+    # automatically monkey-patched by higher later.
+    # Before higher, models could *not* be created like this
+    # and the parameters needed to be manually updated and copied
+    # for the updates.
+    net = nn.Sequential(
+        nn.Conv2d(1, 64, 3),
+        nn.BatchNorm2d(64, momentum=1, affine=True),
+        nn.ReLU(inplace=True),
+        nn.MaxPool2d(2, 2),
+        nn.Conv2d(64, 64, 3),
+        nn.BatchNorm2d(64, momentum=1, affine=True),
+        nn.ReLU(inplace=True),
+        nn.MaxPool2d(2, 2),
+        nn.Conv2d(64, 64, 3),
+        nn.BatchNorm2d(64, momentum=1, affine=True),
+        nn.ReLU(inplace=True),
+        nn.MaxPool2d(2, 2),
+        Flatten(),
+        nn.Linear(64, args.n_way)).to(device)
+
+    # We will use Adam to (meta-)optimize the initial parameters
+    # to be adapted.
+    meta_opt = optim.Adam(net.parameters(), lr=1e-3)
+
+    log = []
+    for epoch in range(100):
+        train(db, net, device, meta_opt, epoch, log)
+        test(db, net, device, epoch, log)
+        plot(log)
+
+
+def train(db, net, device, meta_opt, epoch, log):
+    net.train()
+    n_train_iter = db.x_train.shape[0] // db.batchsz
+
+    for batch_idx in range(n_train_iter):
+        start_time = time.time()
+        # Sample a batch of support and query images and labels.
+        x_spt, y_spt, x_qry, y_qry = db.next()
+
+        task_num, setsz, c_, h, w = x_spt.size()
+        querysz = x_qry.size(1)
+
+        # TODO: Maybe pull this out into a separate module so it
+        # doesn't have to be duplicated between `train` and `test`?
+
+        # Initialize the inner optimizer to adapt the parameters to
+        # the support set.
+        n_inner_iter = 5
+        inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)
+
+        qry_losses = []
+        qry_accs = []
+        meta_opt.zero_grad()
+        for i in range(task_num):
+            with higher.innerloop_ctx(
+                net, inner_opt, copy_initial_weights=False
+            ) as (fnet, diffopt):
+                # Optimize the likelihood of the support set by taking
+                # gradient steps w.r.t. the model's parameters.
+                # This adapts the model's meta-parameters to the task.
+                # higher is able to automatically keep copies of
+                # your network's parameters as they are being updated.
+                for _ in range(n_inner_iter):
+                    spt_logits = fnet(x_spt[i])
+                    spt_loss = F.cross_entropy(spt_logits, y_spt[i])
+                    diffopt.step(spt_loss)
+
+                # The final set of adapted parameters will induce some
+                # final loss and accuracy on the query dataset.
+                # These will be used to update the model's meta-parameters.
+                qry_logits = fnet(x_qry[i])
+                qry_loss = F.cross_entropy(qry_logits, y_qry[i])
+                qry_losses.append(qry_loss.detach())
+                qry_acc = (qry_logits.argmax(
+                    dim=1) == y_qry[i]).sum().item() / querysz
+                qry_accs.append(qry_acc)
+
+                # print([b.shape for b in fnet[1].buffers()])
+
+                # Update the model's meta-parameters to optimize the query
+                # losses across all of the tasks sampled in this batch.
+                # This unrolls through the gradient steps.
+                qry_loss.backward()
+
+        meta_opt.step()
+        qry_losses = sum(qry_losses) / task_num
+        qry_accs = 100. * sum(qry_accs) / task_num
+        i = epoch + float(batch_idx) / n_train_iter
+        iter_time = time.time() - start_time
+        if batch_idx % 4 == 0:
+            print(
+                f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
+            )
+
+        log.append({
+            'epoch': i,
+            'loss': qry_losses,
+            'acc': qry_accs,
+            'mode': 'train',
+            'time': time.time(),
+        })
+
+
+def test(db, net, device, epoch, log):
+    # Crucially in our testing procedure here, we do *not* fine-tune
+    # the model during testing for simplicity.
+    # Most research papers using MAML for this task do an extra
+    # stage of fine-tuning here that should be added if you are
+    # adapting this code for research.
+    net.train()
+    n_test_iter = db.x_test.shape[0] // db.batchsz
+
+    qry_losses = []
+    qry_accs = []
+
+    for batch_idx in range(n_test_iter):
+        x_spt, y_spt, x_qry, y_qry = db.next('test')
+
+
+        task_num, setsz, c_, h, w = x_spt.size()
+        querysz = x_qry.size(1)
+
+        # TODO: Maybe pull this out into a separate module so it
+        # doesn't have to be duplicated between `train` and `test`?
+        n_inner_iter = 5
+        inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)
+
+        for i in range(task_num):
+            with higher.innerloop_ctx(net, inner_opt, track_higher_grads=False) as (fnet, diffopt):
+                # Optimize the likelihood of the support set by taking
+                # gradient steps w.r.t. the model's parameters.
+                # This adapts the model's meta-parameters to the task.
+                for _ in range(n_inner_iter):
+                    spt_logits = fnet(x_spt[i])
+                    spt_loss = F.cross_entropy(spt_logits, y_spt[i])
+                    diffopt.step(spt_loss)
+
+                # The query loss and acc induced by these parameters.
+                qry_logits = fnet(x_qry[i]).detach()
+                qry_loss = F.cross_entropy(
+                    qry_logits, y_qry[i], reduction='none')
+                qry_losses.append(qry_loss.detach())
+                qry_accs.append(
+                    (qry_logits.argmax(dim=1) == y_qry[i]).detach())
+
+    qry_losses = torch.cat(qry_losses).mean().item()
+    qry_accs = 100. * torch.cat(qry_accs).float().mean().item()
+    print(
+        f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}'
+    )
+    log.append({
+        'epoch': epoch + 1,
+        'loss': qry_losses,
+        'acc': qry_accs,
+        'mode': 'test',
+        'time': time.time(),
+    })
+
+
+
+
+def plot(log):
+    # Generally you should pull your plotting code out of your training
+    # script but we are doing it here for brevity.
+    df = pd.DataFrame(log)
+
+    fig, ax = plt.subplots(figsize=(6, 4))
+    train_df = df[df['mode'] == 'train']
+    test_df = df[df['mode'] == 'test']
+    ax.plot(train_df['epoch'], train_df['acc'], label='Train')
+    ax.plot(test_df['epoch'], test_df['acc'], label='Test')
+    ax.set_xlabel('Epoch')
+    ax.set_ylabel('Accuracy')
+    ax.set_ylim(70, 100)
+    fig.legend(ncol=2, loc='lower right')
+    fig.tight_layout()
+    fname = 'maml-accs.png'
+    print(f'--- Plotting accuracy to {fname}')
+    fig.savefig(fname)
+    plt.close(fig)
+
+
+# Won't need this after this PR is merged in:
+# https://github.com/pytorch/pytorch/pull/22245
+class Flatten(nn.Module):
+    def forward(self, input):
+        return input.view(input.size(0), -1)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/functorch/examples/maml_omniglot/maml-omniglot-ptonly.py b/functorch/examples/maml_omniglot/maml-omniglot-ptonly.py
new file mode 100755
index 0000000..c41235d
--- /dev/null
+++ b/functorch/examples/maml_omniglot/maml-omniglot-ptonly.py
@@ -0,0 +1,270 @@
+#!/usr/bin/env python3
+#
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This example shows how to do Model Agnostic Meta Learning (MAML) for few-shot
+Omniglot classification, writing the inner-loop updates by hand in plain PyTorch.
+For more details see the original MAML paper:
+https://arxiv.org/abs/1703.03400
+
+This code has been modified from Jackie Loong's PyTorch MAML implementation:
+https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot_train.py
+
+Our MAML++ fork and experiments are available at:
+https://github.com/bamos/HowToTrainYourMAMLPytorch
+"""
+
+import argparse
+import time
+import typing
+
+import pandas as pd
+import numpy as np
+import matplotlib as mpl
+mpl.use('Agg')
+import matplotlib.pyplot as plt
+plt.style.use('bmh')
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+import torch.optim as optim
+from functorch import make_functional_with_buffers
+
+import higher
+
+from support.omniglot_loaders import OmniglotNShot
+
+
+def main():
+    argparser = argparse.ArgumentParser()
+    argparser.add_argument('--n_way', type=int, help='n way', default=5)
+    argparser.add_argument(
+        '--k_spt', type=int, help='k shot for support set', default=5)
+    argparser.add_argument(
+        '--k_qry', type=int, help='k shot for query set', default=15)
+    argparser.add_argument(
+        '--device', type=str, help='device', default='cuda')
+    argparser.add_argument(
+        '--task_num',
+        type=int,
+        help='meta batch size, namely task num',
+        default=32)
+    argparser.add_argument('--seed', type=int, help='random seed', default=1)
+    args = argparser.parse_args()
+
+    torch.manual_seed(args.seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(args.seed)
+    np.random.seed(args.seed)
+
+    # Set up the Omniglot loader.
+    device = args.device
+    db = OmniglotNShot(
+        '/tmp/omniglot-data',
+        batchsz=args.task_num,
+        n_way=args.n_way,
+        k_shot=args.k_spt,
+        k_query=args.k_qry,
+        imgsz=28,
+        device=device,
+    )
+
+    # Create a vanilla PyTorch neural network. We will turn it into a
+    # functional (stateless) version below with make_functional_with_buffers
+    # and update its parameters manually in the inner loop.
+    net = nn.Sequential(
+        nn.Conv2d(1, 64, 3),
+        nn.BatchNorm2d(64, momentum=1, affine=True),
+        nn.ReLU(inplace=True),
+        nn.MaxPool2d(2, 2),
+        nn.Conv2d(64, 64, 3),
+        nn.BatchNorm2d(64, momentum=1, affine=True),
+        nn.ReLU(inplace=True),
+        nn.MaxPool2d(2, 2),
+        nn.Conv2d(64, 64, 3),
+        nn.BatchNorm2d(64, momentum=1, affine=True),
+        nn.ReLU(inplace=True),
+        nn.MaxPool2d(2, 2),
+        Flatten(),
+        nn.Linear(64, args.n_way)).to(device)
+
+    net.train()
+    params, buffers, fnet, _, _ = make_functional_with_buffers(net)
+
+    # We will use Adam to (meta-)optimize the initial parameters
+    # to be adapted.
+    meta_opt = optim.Adam(params, lr=1e-3)
+
+    log = []
+    for epoch in range(100):
+        train(db, [params, buffers, fnet], device, meta_opt, epoch, log)
+        test(db, [params, buffers, fnet], device, epoch, log)
+        plot(log)
+
+
+def train(db, net, device, meta_opt, epoch, log):
+    params, buffers, fnet = net
+    n_train_iter = db.x_train.shape[0] // db.batchsz
+
+    for batch_idx in range(n_train_iter):
+        start_time = time.time()
+        # Sample a batch of support and query images and labels.
+        x_spt, y_spt, x_qry, y_qry = db.next()
+
+        task_num, setsz, c_, h, w = x_spt.size()
+        querysz = x_qry.size(1)
+
+        # TODO: Maybe pull this out into a separate module so it
+        # doesn't have to be duplicated between `train` and `test`?
+
+        # Initialize the inner optimizer to adapt the parameters to
+        # the support set.
+        n_inner_iter = 5
+        # inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)
+
+        qry_losses = []
+        qry_accs = []
+        meta_opt.zero_grad()
+        for i in range(task_num):
+            # Optimize the likelihood of the support set by taking
+            # gradient steps w.r.t. the model's parameters.
+            # This adapts the model's meta-parameters to the task.
+            new_params = params
+            for _ in range(n_inner_iter):
+                spt_logits = fnet(new_params, buffers, (x_spt[i],))
+                spt_loss = F.cross_entropy(spt_logits, y_spt[i])
+                grads = torch.autograd.grad(spt_loss, new_params, create_graph=True)
+                new_params = [p - g * 1e-1 for p, g in zip(new_params, grads)]
+
+            # The final set of adapted parameters will induce some
+            # final loss and accuracy on the query dataset.
+            # These will be used to update the model's meta-parameters.
+            qry_logits = fnet(new_params, buffers, (x_qry[i],))
+            qry_loss = F.cross_entropy(qry_logits, y_qry[i])
+            qry_losses.append(qry_loss.detach())
+            qry_acc = (qry_logits.argmax(
+                dim=1) == y_qry[i]).sum().item() / querysz
+            qry_accs.append(qry_acc)
+
+            # Update the model's meta-parameters to optimize the query
+            # losses across all of the tasks sampled in this batch.
+            # This unrolls through the gradient steps.
+            qry_loss.backward()
+
+        meta_opt.step()
+        qry_losses = sum(qry_losses) / task_num
+        qry_accs = 100. * sum(qry_accs) / task_num
+        i = epoch + float(batch_idx) / n_train_iter
+        iter_time = time.time() - start_time
+        if batch_idx % 4 == 0:
+            print(
+                f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
+            )
+
+        log.append({
+            'epoch': i,
+            'loss': qry_losses,
+            'acc': qry_accs,
+            'mode': 'train',
+            'time': time.time(),
+        })
+
+
+def test(db, net, device, epoch, log):
+    # Crucially in our testing procedure here, we do *not* fine-tune
+    # the model during testing for simplicity.
+    # Most research papers using MAML for this task do an extra
+    # stage of fine-tuning here that should be added if you are
+    # adapting this code for research.
+    [params, buffers, fnet] = net
+    n_test_iter = db.x_test.shape[0] // db.batchsz
+
+    qry_losses = []
+    qry_accs = []
+
+    for batch_idx in range(n_test_iter):
+        x_spt, y_spt, x_qry, y_qry = db.next('test')
+        task_num, setsz, c_, h, w = x_spt.size()
+
+        # TODO: Maybe pull this out into a separate module so it
+        # doesn't have to be duplicated between `train` and `test`?
+        n_inner_iter = 5
+
+        for i in range(task_num):
+            new_params = params
+            for _ in range(n_inner_iter):
+                spt_logits = fnet(new_params, buffers, (x_spt[i],))
+                spt_loss = F.cross_entropy(spt_logits, y_spt[i])
+                grads = torch.autograd.grad(spt_loss, new_params)
+                new_params = [p - g * 1e-1 for p, g in zip(new_params, grads)]
+
+            # The query loss and acc induced by these parameters.
+            qry_logits = fnet(new_params, buffers, (x_qry[i],)).detach()
+            qry_loss = F.cross_entropy(
+                qry_logits, y_qry[i], reduction='none')
+            qry_losses.append(qry_loss.detach())
+            qry_accs.append(
+                (qry_logits.argmax(dim=1) == y_qry[i]).detach())
+
+    qry_losses = torch.cat(qry_losses).mean().item()
+    qry_accs = 100. * torch.cat(qry_accs).float().mean().item()
+    print(
+        f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}'
+    )
+    log.append({
+        'epoch': epoch + 1,
+        'loss': qry_losses,
+        'acc': qry_accs,
+        'mode': 'test',
+        'time': time.time(),
+    })
+
+
+
+
+def plot(log):
+    # Generally you should pull your plotting code out of your training
+    # script but we are doing it here for brevity.
+    df = pd.DataFrame(log)
+
+    fig, ax = plt.subplots(figsize=(6, 4))
+    train_df = df[df['mode'] == 'train']
+    test_df = df[df['mode'] == 'test']
+    ax.plot(train_df['epoch'], train_df['acc'], label='Train')
+    ax.plot(test_df['epoch'], test_df['acc'], label='Test')
+    ax.set_xlabel('Epoch')
+    ax.set_ylabel('Accuracy')
+    ax.set_ylim(70, 100)
+    fig.legend(ncol=2, loc='lower right')
+    fig.tight_layout()
+    fname = 'maml-accs.png'
+    print(f'--- Plotting accuracy to {fname}')
+    fig.savefig(fname)
+    plt.close(fig)
+
+
+# Won't need this after this PR is merged in:
+# https://github.com/pytorch/pytorch/pull/22245
+class Flatten(nn.Module):
+    def forward(self, input):
+        return input.view(input.size(0), -1)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/functorch/examples/maml_omniglot/maml-omniglot-transforms.py b/functorch/examples/maml_omniglot/maml-omniglot-transforms.py
new file mode 100755
index 0000000..c23cebc
--- /dev/null
+++ b/functorch/examples/maml_omniglot/maml-omniglot-transforms.py
@@ -0,0 +1,276 @@
+#!/usr/bin/env python3
+#
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This example shows how to use functorch's grad and vmap transforms to do
+Model Agnostic Meta Learning (MAML) for few-shot Omniglot classification.
+For more details see the original MAML paper:
+https://arxiv.org/abs/1703.03400
+
+This code has been modified from Jackie Loong's PyTorch MAML implementation:
+https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot_train.py
+
+Our MAML++ fork and experiments are available at:
+https://github.com/bamos/HowToTrainYourMAMLPytorch
+"""
+
+import argparse
+import time
+import typing
+import functools
+
+import pandas as pd
+import numpy as np
+import matplotlib as mpl
+mpl.use('Agg')
+import matplotlib.pyplot as plt
+plt.style.use('bmh')
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+import torch.optim as optim
+
+from functorch import make_functional_with_buffers, vmap, grad
+
+import higher
+
+from support.omniglot_loaders import OmniglotNShot
+# torch._C._debug_only_display_vmap_fallback_warnings(True)
+
+
+def main():
+    argparser = argparse.ArgumentParser()
+    argparser.add_argument('--n_way', type=int, help='n way', default=5)
+    argparser.add_argument(
+        '--k_spt', type=int, help='k shot for support set', default=5)
+    argparser.add_argument(
+        '--k_qry', type=int, help='k shot for query set', default=15)
+    argparser.add_argument(
+        '--device', type=str, help='device', default='cuda')
+    argparser.add_argument(
+        '--task_num',
+        type=int,
+        help='meta batch size, namely task num',
+        default=32)
+    argparser.add_argument('--seed', type=int, help='random seed', default=1)
+    args = argparser.parse_args()
+
+    torch.manual_seed(args.seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(args.seed)
+    np.random.seed(args.seed)
+
+    # Set up the Omniglot loader.
+    device = args.device
+    db = OmniglotNShot(
+        '/tmp/omniglot-data',
+        batchsz=args.task_num,
+        n_way=args.n_way,
+        k_shot=args.k_spt,
+        k_query=args.k_qry,
+        imgsz=28,
+        device=device,
+    )
+
+    # Create a vanilla PyTorch neural network.
+    # TODO: The prototype doesn't support in-place relu (and some other
+    # in-place operations); that can be fixed.
+    inplace_relu = False
+    net = nn.Sequential(
+        nn.Conv2d(1, 64, 3),
+        nn.BatchNorm2d(64, momentum=1, affine=True),
+        nn.ReLU(inplace=inplace_relu),
+        nn.MaxPool2d(2, 2),
+        nn.Conv2d(64, 64, 3),
+        nn.BatchNorm2d(64, momentum=1, affine=True),
+        nn.ReLU(inplace=inplace_relu),
+        nn.MaxPool2d(2, 2),
+        nn.Conv2d(64, 64, 3),
+        nn.BatchNorm2d(64, momentum=1, affine=True),
+        nn.ReLU(inplace=inplace_relu),
+        nn.MaxPool2d(2, 2),
+        Flatten(),
+        nn.Linear(64, args.n_way)).to(device)
+
+    net.train()
+
+    # Given this module we've created, rip out the parameters and buffers
+    # and return a functional version of the module. `fnet` is stateless
+    # and is called as `fnet(params, buffers, args)`, where `args` is a
+    # tuple of inputs.
+    params, buffers, fnet, _, _ = make_functional_with_buffers(net)
+
+    # We will use Adam to (meta-)optimize the initial parameters
+    # to be adapted.
+    meta_opt = optim.Adam(params, lr=1e-3)
+
+    log = []
+    for epoch in range(100):
+        train(db, [params, buffers, fnet], device, meta_opt, epoch, log)
+        test(db, [params, buffers, fnet], device, epoch, log)
+        plot(log)
+
+
+# Trains a model for n_inner_iter using the support and returns a loss
+# using the query.
+def loss_for_task(net, n_inner_iter, x_spt, y_spt, x_qry, y_qry):
+    params, buffers, fnet = net
+    querysz = x_qry.size(0)
+
+    def compute_loss(new_params, buffers, x, y):
+        logits = fnet(new_params, buffers, (x,))
+        loss = F.cross_entropy(logits, y)
+        return loss
+
+    new_params = params
+    for _ in range(n_inner_iter):
+        grads = grad(compute_loss)(new_params, buffers, x_spt, y_spt)
+        new_params = [p - g * 1e-1 for p, g in zip(new_params, grads)]
+
+    # The final set of adapted parameters will induce some
+    # final loss and accuracy on the query dataset.
+    # These will be used to update the model's meta-parameters.
+    qry_logits = fnet(new_params, buffers, (x_qry,))
+    qry_loss = F.cross_entropy(qry_logits, y_qry)
+    qry_acc = (qry_logits.argmax(
+        dim=1) == y_qry).sum() / querysz
+
+    return qry_loss, qry_acc
+
+
+def train(db, net, device, meta_opt, epoch, log):
+    params, buffers, fnet = net
+    n_train_iter = db.x_train.shape[0] // db.batchsz
+
+    for batch_idx in range(n_train_iter):
+        start_time = time.time()
+        # Sample a batch of support and query images and labels.
+        x_spt, y_spt, x_qry, y_qry = db.next()
+
+        task_num, setsz, c_, h, w = x_spt.size()
+
+        n_inner_iter = 5
+        meta_opt.zero_grad()
+
+        # In parallel, trains one model per task. There is a support (x, y)
+        # for each task and a query (x, y) for each task.
+        compute_loss_for_task = functools.partial(loss_for_task, net, n_inner_iter)
+        qry_losses, qry_accs = vmap(compute_loss_for_task)(x_spt, y_spt, x_qry, y_qry)
+
+        # Compute the maml loss by summing together the returned losses.
+        qry_losses.sum().backward()
+
+        meta_opt.step()
+        qry_losses = qry_losses.detach().sum() / task_num
+        qry_accs = 100. * qry_accs.sum() / task_num
+        i = epoch + float(batch_idx) / n_train_iter
+        iter_time = time.time() - start_time
+        if batch_idx % 4 == 0:
+            print(
+                f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
+            )
+
+        log.append({
+            'epoch': i,
+            'loss': qry_losses,
+            'acc': qry_accs,
+            'mode': 'train',
+            'time': time.time(),
+        })
+
+
+def test(db, net, device, epoch, log):
+    # Crucially in our testing procedure here, we do *not* fine-tune
+    # the model during testing for simplicity.
+    # Most research papers using MAML for this task do an extra
+    # stage of fine-tuning here that should be added if you are
+    # adapting this code for research.
+    [params, buffers, fnet] = net
+    n_test_iter = db.x_test.shape[0] // db.batchsz
+
+    qry_losses = []
+    qry_accs = []
+
+    for batch_idx in range(n_test_iter):
+        x_spt, y_spt, x_qry, y_qry = db.next('test')
+        task_num, setsz, c_, h, w = x_spt.size()
+
+        # TODO: Maybe pull this out into a separate module so it
+        # doesn't have to be duplicated between `train` and `test`?
+        n_inner_iter = 5
+
+        for i in range(task_num):
+            new_params = params
+            for _ in range(n_inner_iter):
+                spt_logits = fnet(new_params, buffers, (x_spt[i],))
+                spt_loss = F.cross_entropy(spt_logits, y_spt[i])
+                grads = torch.autograd.grad(spt_loss, new_params)
+                new_params = [p - g * 1e-1 for p, g in zip(new_params, grads)]
+
+            # The query loss and acc induced by these parameters.
+            qry_logits = fnet(new_params, buffers, (x_qry[i],)).detach()
+            qry_loss = F.cross_entropy(
+                qry_logits, y_qry[i], reduction='none')
+            qry_losses.append(qry_loss.detach())
+            qry_accs.append(
+                (qry_logits.argmax(dim=1) == y_qry[i]).detach())
+
+    qry_losses = torch.cat(qry_losses).mean().item()
+    qry_accs = 100. * torch.cat(qry_accs).float().mean().item()
+    print(
+        f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}'
+    )
+    log.append({
+        'epoch': epoch + 1,
+        'loss': qry_losses,
+        'acc': qry_accs,
+        'mode': 'test',
+        'time': time.time(),
+    })
+
+
+
+
+def plot(log):
+    # Generally you should pull your plotting code out of your training
+    # script but we are doing it here for brevity.
+    df = pd.DataFrame(log)
+
+    fig, ax = plt.subplots(figsize=(6, 4))
+    train_df = df[df['mode'] == 'train']
+    test_df = df[df['mode'] == 'test']
+    ax.plot(train_df['epoch'], train_df['acc'], label='Train')
+    ax.plot(test_df['epoch'], test_df['acc'], label='Test')
+    ax.set_xlabel('Epoch')
+    ax.set_ylabel('Accuracy')
+    ax.set_ylim(70, 100)
+    fig.legend(ncol=2, loc='lower right')
+    fig.tight_layout()
+    fname = 'maml-accs.png'
+    print(f'--- Plotting accuracy to {fname}')
+    fig.savefig(fname)
+    plt.close(fig)
+
+
+# Won't need this after this PR is merged in:
+# https://github.com/pytorch/pytorch/pull/22245
+class Flatten(nn.Module):
+    def forward(self, input):
+        return input.view(input.size(0), -1)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/functorch/examples/maml_omniglot/support/omniglot_loaders.py b/functorch/examples/maml_omniglot/support/omniglot_loaders.py
new file mode 100644
index 0000000..31118a0
--- /dev/null
+++ b/functorch/examples/maml_omniglot/support/omniglot_loaders.py
@@ -0,0 +1,303 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# These Omniglot loaders are from Jackie Loong's PyTorch MAML implementation:
+#     https://github.com/dragen1860/MAML-Pytorch
+#     https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot.py
+#     https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglotNShot.py
+
+import torchvision.transforms as transforms
+from PIL import Image
+import numpy as np
+
+import torch
+import torch.utils.data as data
+import os
+import os.path
+import errno
+
+
+class Omniglot(data.Dataset):
+    urls = [
+        'https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip',
+        'https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip'
+    ]
+    raw_folder = 'raw'
+    processed_folder = 'processed'
+    training_file = 'training.pt'
+    test_file = 'test.pt'
+
+    '''
+    The items are (filename,category). The index of all the categories can be found in self.idx_classes
+    Args:
+    - root: the directory where the dataset will be stored
+    - transform: how to transform the input
+    - target_transform: how to transform the target
+    - download: need to download the dataset
+    '''
+
+    def __init__(self, root, transform=None, target_transform=None,
+                 download=False):
+        self.root = root
+        self.transform = transform
+        self.target_transform = target_transform
+
+        if not self._check_exists():
+            if download:
+                self.download()
+            else:
+                raise RuntimeError('Dataset not found.' + ' You can use download=True to download it')
+
+        self.all_items = find_classes(os.path.join(self.root, self.processed_folder))
+        self.idx_classes = index_classes(self.all_items)
+
+    def __getitem__(self, index):
+        filename = self.all_items[index][0]
+        img = str.join('/', [self.all_items[index][2], filename])
+
+        target = self.idx_classes[self.all_items[index][1]]
+        if self.transform is not None:
+            img = self.transform(img)
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self):
+        return len(self.all_items)
+
+    def _check_exists(self):
+        return os.path.exists(os.path.join(self.root, self.processed_folder, "images_evaluation")) and \
+               os.path.exists(os.path.join(self.root, self.processed_folder, "images_background"))
+
+    def download(self):
+        from six.moves import urllib
+        import zipfile
+
+        if self._check_exists():
+            return
+
+        # download files
+        try:
+            os.makedirs(os.path.join(self.root, self.raw_folder))
+            os.makedirs(os.path.join(self.root, self.processed_folder))
+        except OSError as e:
+            if e.errno == errno.EEXIST:
+                pass
+            else:
+                raise
+
+        for url in self.urls:
+            print('== Downloading ' + url)
+            data = urllib.request.urlopen(url)
+            filename = url.rpartition('/')[2]
+            file_path = os.path.join(self.root, self.raw_folder, filename)
+            with open(file_path, 'wb') as f:
+                f.write(data.read())
+            file_processed = os.path.join(self.root, self.processed_folder)
+            print("== Unzip from " + file_path + " to " + file_processed)
+            zip_ref = zipfile.ZipFile(file_path, 'r')
+            zip_ref.extractall(file_processed)
+            zip_ref.close()
+        print("Download finished.")
+
+
+def find_classes(root_dir):
+    retour = []
+    for (root, dirs, files) in os.walk(root_dir):
+        for f in files:
+            if (f.endswith("png")):
+                r = root.split('/')
+                lr = len(r)
+                retour.append((f, r[lr - 2] + "/" + r[lr - 1], root))
+    print("== Found %d items " % len(retour))
+    return retour
+
+
+def index_classes(items):
+    idx = {}
+    for i in items:
+        if i[1] not in idx:
+            idx[i[1]] = len(idx)
+    print("== Found %d classes" % len(idx))
+    return idx
+
+
+class OmniglotNShot:
+
+    def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz, device=None):
+        """
+        :param root: directory where the Omniglot data is stored/cached
+        :param batchsz: task num (meta batch size)
+        :param n_way: number of classes per task
+        :param k_shot: number of support examples per class
+        :param k_query: number of query examples per class
+        :param imgsz: size to which images are resized
+        """
+
+        self.resize = imgsz
+        self.device = device
+        if not os.path.isfile(os.path.join(root, 'omniglot.npy')):
+            # if root/data.npy does not exist, just download it
+            self.x = Omniglot(
+                root, download=True,
+                transform=transforms.Compose(
+                    [lambda x: Image.open(x).convert('L'),
+                     lambda x: x.resize((imgsz, imgsz)),
+                     lambda x: np.reshape(x, (imgsz, imgsz, 1)),
+                     lambda x: np.transpose(x, [2, 0, 1]),
+                     lambda x: x/255.]),
+            )
+
+            temp = dict()  # maps label -> list of its 20 images; 1623 labels in total
+            for (img, label) in self.x:
+                if label in temp.keys():
+                    temp[label].append(img)
+                else:
+                    temp[label] = [img]
+
+            self.x = []
+            for label, imgs in temp.items():  # label info discarded; each label contains 20 imgs
+                self.x.append(np.array(imgs))
+
+            # as different classes may have different numbers of imgs
+            self.x = np.array(self.x).astype(np.float)  # [[20 imgs],..., 1623 classes in total]
+            # each character contains 20 imgs
+            print('data shape:', self.x.shape)  # [1623, 20, 84, 84, 1]
+            temp = []  # Free memory
+            # save all dataset into npy file.
+            np.save(os.path.join(root, 'omniglot.npy'), self.x)
+            print('write into omniglot.npy.')
+        else:
+            # if data.npy exists, just load it.
+            self.x = np.load(os.path.join(root, 'omniglot.npy'))
+            print('load from omniglot.npy.')
+
+        # [1623, 20, 84, 84, 1]
+        # TODO: can not shuffle here, we must keep training and test set distinct!
+        self.x_train, self.x_test = self.x[:1200], self.x[1200:]
+
+        # self.normalization()
+
+        self.batchsz = batchsz
+        self.n_cls = self.x.shape[0]  # 1623
+        self.n_way = n_way  # n way
+        self.k_shot = k_shot  # k shot
+        self.k_query = k_query  # k query
+        assert (k_shot + k_query) <=20
+
+        # save pointer of current read batch in total cache
+        self.indexes = {"train": 0, "test": 0}
+        self.datasets = {"train": self.x_train, "test": self.x_test}  # original data cached
+        print("DB: train", self.x_train.shape, "test", self.x_test.shape)
+
+        self.datasets_cache = {"train": self.load_data_cache(self.datasets["train"]),  # current epoch data cached
+                               "test": self.load_data_cache(self.datasets["test"])}
+
+    def normalization(self):
+        """
+        Normalizes our data to have a mean of 0 and std of 1
+        """
+        self.mean = np.mean(self.x_train)
+        self.std = np.std(self.x_train)
+        self.max = np.max(self.x_train)
+        self.min = np.min(self.x_train)
+        # print("before norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std)
+        self.x_train = (self.x_train - self.mean) / self.std
+        self.x_test = (self.x_test - self.mean) / self.std
+
+        self.mean = np.mean(self.x_train)
+        self.std = np.std(self.x_train)
+        self.max = np.max(self.x_train)
+        self.min = np.min(self.x_train)
+
+    # print("after norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std)
+
+    def load_data_cache(self, data_pack):
+        """
+        Collects several batches of data for N-shot learning
+        :param data_pack: [cls_num, 20, 84, 84, 1]
+        :return: A list with [support_set_x, support_set_y, target_x, target_y] ready to be fed to our networks
+        """
+        #  take 5 way 1 shot as example: 5 * 1
+        setsz = self.k_shot * self.n_way
+        querysz = self.k_query * self.n_way
+        data_cache = []
+
+        # print('preload next 50 caches of batchsz of batch.')
+        for sample in range(10):  # num of episodes
+
+            x_spts, y_spts, x_qrys, y_qrys = [], [], [], []
+            for i in range(self.batchsz):  # one batch means one set
+
+                x_spt, y_spt, x_qry, y_qry = [], [], [], []
+                selected_cls = np.random.choice(data_pack.shape[0], self.n_way, False)
+
+                for j, cur_class in enumerate(selected_cls):
+
+                    selected_img = np.random.choice(20, self.k_shot + self.k_query, False)
+
+                    # meta-training and meta-test
+                    x_spt.append(data_pack[cur_class][selected_img[:self.k_shot]])
+                    x_qry.append(data_pack[cur_class][selected_img[self.k_shot:]])
+                    y_spt.append([j for _ in range(self.k_shot)])
+                    y_qry.append([j for _ in range(self.k_query)])
+
+                # shuffle inside a batch
+                perm = np.random.permutation(self.n_way * self.k_shot)
+                x_spt = np.array(x_spt).reshape(self.n_way * self.k_shot, 1, self.resize, self.resize)[perm]
+                y_spt = np.array(y_spt).reshape(self.n_way * self.k_shot)[perm]
+                perm = np.random.permutation(self.n_way * self.k_query)
+                x_qry = np.array(x_qry).reshape(self.n_way * self.k_query, 1, self.resize, self.resize)[perm]
+                y_qry = np.array(y_qry).reshape(self.n_way * self.k_query)[perm]
+
+                # append [sptsz, 1, 84, 84] => [b, setsz, 1, 84, 84]
+                x_spts.append(x_spt)
+                y_spts.append(y_spt)
+                x_qrys.append(x_qry)
+                y_qrys.append(y_qry)
+
+
+            # [b, setsz, 1, 84, 84]
+            x_spts = np.array(x_spts).astype(np.float32).reshape(self.batchsz, setsz, 1, self.resize, self.resize)
+            y_spts = np.array(y_spts).astype(np.int).reshape(self.batchsz, setsz)
+            # [b, qrysz, 1, 84, 84]
+            x_qrys = np.array(x_qrys).astype(np.float32).reshape(self.batchsz, querysz, 1, self.resize, self.resize)
+            y_qrys = np.array(y_qrys).astype(np.int).reshape(self.batchsz, querysz)
+
+            x_spts, y_spts, x_qrys, y_qrys = [
+                torch.from_numpy(z).to(self.device) for z in
+                [x_spts, y_spts, x_qrys, y_qrys]
+            ]
+
+            data_cache.append([x_spts, y_spts, x_qrys, y_qrys])
+
+        return data_cache
+
+    def next(self, mode='train'):
+        """
+        Gets the next batch from the named dataset split.
+        :param mode: The name of the split (one of "train", "test")
+        :return:
+        """
+        # refresh the cache once the index runs past the number of cached batches
+        if self.indexes[mode] >= len(self.datasets_cache[mode]):
+            self.indexes[mode] = 0
+            self.datasets_cache[mode] = self.load_data_cache(self.datasets[mode])
+
+        next_batch = self.datasets_cache[mode][self.indexes[mode]]
+        self.indexes[mode] += 1
+
+        return next_batch
diff --git a/functorch/examples/maml_regression/evjang.py b/functorch/examples/maml_regression/evjang.py
new file mode 100644
index 0000000..cae6975
--- /dev/null
+++ b/functorch/examples/maml_regression/evjang.py
@@ -0,0 +1,113 @@
+import math
+import random
+import torch
+import numpy as np
+from torch import nn
+from torch.nn import functional as F
+import matplotlib as mpl
+mpl.use('Agg')
+import matplotlib.pyplot as plt
+
+def net(x, params):
+    x = F.linear(x, params[0], params[1])
+    x = F.relu(x)
+
+    x = F.linear(x, params[2], params[3])
+    x = F.relu(x)
+
+    x = F.linear(x, params[4], params[5])
+    return x
+
+params = [
+    torch.Tensor(40, 1).uniform_(-1., 1.).requires_grad_(),
+    torch.Tensor(40).zero_().requires_grad_(),
+
+    torch.Tensor(40, 40).uniform_(-1./math.sqrt(40), 1./math.sqrt(40)).requires_grad_(),
+    torch.Tensor(40).zero_().requires_grad_(),
+
+    torch.Tensor(1, 40).uniform_(-1./math.sqrt(40), 1./math.sqrt(40)).requires_grad_(),
+    torch.Tensor(1).zero_().requires_grad_(),
+]
+
+opt = torch.optim.Adam(params, lr=1e-3)
+alpha = 0.1
+
+K = 20
+losses = []
+num_tasks = 4
+def sample_tasks(outer_batch_size, inner_batch_size):
+    # Select amplitude and phase for the task
+    As = []
+    phases = []
+    for _ in range(outer_batch_size):
+        As.append(np.random.uniform(low=0.1, high=.5))
+        phases.append(np.random.uniform(low=0., high=np.pi))
+    def get_batch():
+        xs, ys = [], []
+        for A, phase in zip(As, phases):
+            x = np.random.uniform(low=-5., high=5., size=(inner_batch_size, 1))
+            y = A * np.sin(x + phase)
+            xs.append(x)
+            ys.append(y)
+        return torch.tensor(xs, dtype=torch.float), torch.tensor(ys, dtype=torch.float)
+    x1, y1 = get_batch()
+    x2, y2 = get_batch()
+    return x1, y1, x2, y2
+
+for it in range(20000):
+    loss2 = 0.0
+    opt.zero_grad()
+    def get_loss_for_task(x1, y1, x2, y2):
+        f = net(x1, params)
+        loss = F.mse_loss(f, y1)
+
+        # create_graph=True because computing grads here is part of the forward pass.
+        # We want to differentiate through the SGD update steps and get higher order
+        # derivatives in the backward pass.
+        grads = torch.autograd.grad(loss, params, create_graph=True)
+        new_params = [(params[i] - alpha*grads[i]) for i in range(len(params))]
+
+        v_f = net(x2, new_params)
+        return F.mse_loss(v_f, y2)
+
+    task = sample_tasks(num_tasks, K)
+    inner_losses = [get_loss_for_task(task[0][i], task[1][i], task[2][i], task[3][i]) for i in range(num_tasks)]
+    loss2 = sum(inner_losses)/len(inner_losses)
+    loss2.backward()
+
+    opt.step()
+
+    if it % 100 == 0:
+        print('Iteration %d -- Outer Loss: %.4f' % (it, loss2))
+    losses.append(loss2.item())  # store plain floats so np.convolve below can consume the list
+
+t_A = torch.tensor(0.0).uniform_(0.1, 0.5)
+t_b = torch.tensor(0.0).uniform_(0.0, math.pi)
+
+t_x = torch.empty(4, 1).uniform_(-5, 5)
+t_y = t_A*torch.sin(t_x + t_b)
+
+opt.zero_grad()
+
+t_params = params
+for k in range(5):
+    t_f = net(t_x, t_params)
+    t_loss = F.l1_loss(t_f, t_y)
+
+    grads = torch.autograd.grad(t_loss, t_params, create_graph=True)
+    t_params = [(t_params[i] - alpha*grads[i]) for i in range(len(params))]
+
+
+test_x = torch.arange(-2*math.pi, 2*math.pi, step=0.01).unsqueeze(1)
+test_y = t_A*torch.sin(test_x + t_b)
+
+test_f = net(test_x, t_params)
+
+plt.plot(test_x.data.numpy(), test_y.data.numpy(), label='sin(x)')
+plt.plot(test_x.data.numpy(), test_f.data.numpy(), label='net(x)')
+plt.plot(t_x.data.numpy(), t_y.data.numpy(), 'o', label='Examples')
+plt.legend()
+plt.savefig('maml-sine.png')
+plt.figure()
+plt.plot(np.convolve(losses, [.05]*20))
+plt.savefig('losses.png')
\ No newline at end of file
diff --git a/functorch/examples/maml_regression/evjang_transforms.py b/functorch/examples/maml_regression/evjang_transforms.py
new file mode 100644
index 0000000..f8c3472
--- /dev/null
+++ b/functorch/examples/maml_regression/evjang_transforms.py
@@ -0,0 +1,118 @@
+import math
+import random
+import torch
+import numpy as np
+from torch import nn
+from torch.nn import functional as F
+import matplotlib as mpl
+mpl.use('Agg')
+import matplotlib.pyplot as plt
+
+from functorch import grad, vmap
+
+def net(params, x):
+    x = F.linear(x, params[0], params[1])
+    x = F.relu(x)
+
+    x = F.linear(x, params[2], params[3])
+    x = F.relu(x)
+
+    x = F.linear(x, params[4], params[5])
+    return x
+
+params = [
+    torch.Tensor(40, 1).uniform_(-1., 1.).requires_grad_(),
+    torch.Tensor(40).zero_().requires_grad_(),
+
+    torch.Tensor(40, 40).uniform_(-1./math.sqrt(40), 1./math.sqrt(40)).requires_grad_(),
+    torch.Tensor(40).zero_().requires_grad_(),
+
+    torch.Tensor(1, 40).uniform_(-1./math.sqrt(40), 1./math.sqrt(40)).requires_grad_(),
+    torch.Tensor(1).zero_().requires_grad_(),
+]
+
+# The prototype doesn't like F.mse_loss.
+def mse_loss(x, y):
+    return torch.mean((x - y) ** 2)
+
+opt = torch.optim.Adam(params, lr=1e-3)
+alpha = 0.1
+
+K = 20
+losses = []
+num_tasks = 4
+def sample_tasks(outer_batch_size, inner_batch_size):
+    # Select amplitude and phase for the task
+    As = []
+    phases = []
+    for _ in range(outer_batch_size):
+        As.append(np.random.uniform(low=0.1, high=.5))
+        phases.append(np.random.uniform(low=0., high=np.pi))
+    def get_batch():
+        xs, ys = [], []
+        for A, phase in zip(As, phases):
+            x = np.random.uniform(low=-5., high=5., size=(inner_batch_size, 1))
+            y = A * np.sin(x + phase)
+            xs.append(x)
+            ys.append(y)
+        return torch.tensor(xs, dtype=torch.float), torch.tensor(ys, dtype=torch.float)
+    x1, y1 = get_batch()
+    x2, y2 = get_batch()
+    return x1, y1, x2, y2
+
+for it in range(20000):
+    loss2 = 0.0
+    opt.zero_grad()
+    def get_loss_for_task(x1, y1, x2, y2):
+        def inner_loss(params, x1, y1):
+            f = net(params, x1)
+            loss = mse_loss(f, y1)
+            return loss
+
+        grads = grad(inner_loss)(tuple(params), x1, y1)
+        new_params = [(params[i] - alpha*grads[i]) for i in range(len(params))]
+
+        v_f = net(new_params, x2)
+        return mse_loss(v_f, y2)
+
+    task = sample_tasks(num_tasks, K)
+    inner_losses = vmap(get_loss_for_task)(task[0], task[1], task[2], task[3])
+    loss2 = sum(inner_losses)/len(inner_losses)
+    loss2.backward()
+
+    opt.step()
+
+    if it % 100 == 0:
+        print('Iteration %d -- Outer Loss: %.4f' % (it, loss2))
+    losses.append(loss2.item())  # store plain floats so np.convolve below can consume the list
+
+t_A = torch.tensor(0.0).uniform_(0.1, 0.5)
+t_b = torch.tensor(0.0).uniform_(0.0, math.pi)
+
+t_x = torch.empty(4, 1).uniform_(-5, 5)
+t_y = t_A*torch.sin(t_x + t_b)
+
+opt.zero_grad()
+
+t_params = params
+for k in range(5):
+    t_f = net(t_x, t_params)
+    t_loss = F.l1_loss(t_f, t_y)
+
+    grads = torch.autograd.grad(t_loss, t_params, create_graph=True)
+    t_params = [(t_params[i] - alpha*grads[i]) for i in range(len(params))]
+
+
+test_x = torch.arange(-2*math.pi, 2*math.pi, step=0.01).unsqueeze(1)
+test_y = t_A*torch.sin(test_x + t_b)
+
+test_f = net(test_x, t_params)
+
+plt.plot(test_x.data.numpy(), test_y.data.numpy(), label='sin(x)')
+plt.plot(test_x.data.numpy(), test_f.data.numpy(), label='net(x)')
+plt.plot(t_x.data.numpy(), t_y.data.numpy(), 'o', label='Examples')
+plt.legend()
+plt.savefig('maml-sine.png')
+plt.figure()
+plt.plot(np.convolve(losses, [.05]*20))
+plt.savefig('losses.png')
diff --git a/functorch/examples/maml_regression/evjang_transforms_module.py b/functorch/examples/maml_regression/evjang_transforms_module.py
new file mode 100644
index 0000000..0ddb48a
--- /dev/null
+++ b/functorch/examples/maml_regression/evjang_transforms_module.py
@@ -0,0 +1,115 @@
+import math
+import random
+import torch
+import numpy as np
+from torch import nn
+from torch.nn import functional as F
+import matplotlib as mpl
+mpl.use('Agg')
+import matplotlib.pyplot as plt
+
+from functorch import grad, vmap, make_functional
+
+class ThreeLayerNet(nn.Module):
+    def __init__(self):
+        super(ThreeLayerNet, self).__init__()
+        self.fc1 = nn.Linear(1, 40)
+        self.relu1 = nn.ReLU()
+        self.fc2 = nn.Linear(40, 40)
+        self.relu2 = nn.ReLU()
+        self.fc3 = nn.Linear(40, 1)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.relu1(x)
+        x = self.fc2(x)
+        x = self.relu2(x)
+        x = self.fc3(x)
+        return x
+
+# The prototype doesn't like F.mse_loss.
+def mse_loss(x, y):
+    return torch.mean((x - y) ** 2)
+
+params, net, _ = make_functional(ThreeLayerNet())
+opt = torch.optim.Adam(params, lr=1e-3)
+alpha = 0.1
+
+K = 20
+losses = []
+num_tasks = 4
+def sample_tasks(outer_batch_size, inner_batch_size):
+    # Select amplitude and phase for the task
+    As = []
+    phases = []
+    for _ in range(outer_batch_size):
+        As.append(np.random.uniform(low=0.1, high=.5))
+        phases.append(np.random.uniform(low=0., high=np.pi))
+    def get_batch():
+        xs, ys = [], []
+        for A, phase in zip(As, phases):
+            x = np.random.uniform(low=-5., high=5., size=(inner_batch_size, 1))
+            y = A * np.sin(x + phase)
+            xs.append(x)
+            ys.append(y)
+        return torch.tensor(xs, dtype=torch.float), torch.tensor(ys, dtype=torch.float)
+    x1, y1 = get_batch()
+    x2, y2 = get_batch()
+    return x1, y1, x2, y2
+
+for it in range(20000):
+    loss2 = 0.0
+    opt.zero_grad()
+    def get_loss_for_task(x1, y1, x2, y2):
+        def inner_loss(params, x1, y1):
+            f = net(params, (x1,))
+            loss = mse_loss(f, y1)
+            return loss
+
+        grads = grad(inner_loss)(params, x1, y1)
+        new_params = [(params[i] - alpha*grads[i]) for i in range(len(params))]
+
+        v_f = net(new_params, (x2,))
+        return mse_loss(v_f, y2)
+
+    task = sample_tasks(num_tasks, K)
+    inner_losses = vmap(get_loss_for_task)(task[0], task[1], task[2], task[3])
+    loss2 = sum(inner_losses)/len(inner_losses)
+    loss2.backward()
+
+    opt.step()
+
+    if it % 100 == 0:
+        print('Iteration %d -- Outer Loss: %.4f' % (it, loss2))
+    losses.append(loss2.item())  # store plain floats so np.convolve below can consume the list
+
+t_A = torch.tensor(0.0).uniform_(0.1, 0.5)
+t_b = torch.tensor(0.0).uniform_(0.0, math.pi)
+
+t_x = torch.empty(4, 1).uniform_(-5, 5)
+t_y = t_A*torch.sin(t_x + t_b)
+
+opt.zero_grad()
+
+t_params = params
+for k in range(5):
+    t_f = net(t_params, (t_x,))
+    t_loss = F.l1_loss(t_f, t_y)
+
+    grads = torch.autograd.grad(t_loss, t_params, create_graph=True)
+    t_params = [(t_params[i] - alpha*grads[i]) for i in range(len(params))]
+
+
+test_x = torch.arange(-2*math.pi, 2*math.pi, step=0.01).unsqueeze(1)
+test_y = t_A*torch.sin(test_x + t_b)
+
+test_f = net(t_params, (test_x,))
+
+plt.plot(test_x.data.numpy(), test_y.data.numpy(), label='sin(x)')
+plt.plot(test_x.data.numpy(), test_f.data.numpy(), label='net(x)')
+plt.plot(t_x.data.numpy(), t_y.data.numpy(), 'o', label='Examples')
+plt.legend()
+plt.savefig('maml-sine.png')
+plt.figure()
+plt.plot(np.convolve(losses, [.05]*20))
+plt.savefig('losses.png')
diff --git a/functorch/functorch/__init__.py b/functorch/functorch/__init__.py
new file mode 100644
index 0000000..c62bfc7
--- /dev/null
+++ b/functorch/functorch/__init__.py
@@ -0,0 +1,6 @@
+import torch
+from . import _C
+
+from ._src.vmap import vmap
+from ._src.eager_transforms import grad, grad_with_value, vjp, jacrev
+from ._src.make_functional import make_functional, make_functional_with_buffers
diff --git a/functorch/functorch/_src/__init__.py b/functorch/functorch/_src/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/functorch/functorch/_src/__init__.py
diff --git a/functorch/functorch/_src/eager_transforms.py b/functorch/functorch/_src/eager_transforms.py
new file mode 100644
index 0000000..756081f
--- /dev/null
+++ b/functorch/functorch/_src/eager_transforms.py
@@ -0,0 +1,185 @@
+import torch
+from functools import partial
+import collections
+import torch.nn as nn
+import torch.nn.functional as F
+from .make_functional import make_functional, make_functional_with_buffers
+import gc
+
+from .vmap import vmap
+
+from functorch._C import (
+    _wrap_for_grad,
+    _unwrap_for_grad,
+    _grad_increment_nesting,
+    _grad_decrement_nesting,
+)
+
+# Quick smoke test, kept here as a comment:
+#   x = torch.ones(2, 3)
+#   y = torch.ones(2, 3)
+#   result = vmap(vmap(torch.add))(x, y)
+#   assert torch.allclose(result, x + y)
+
+# TODO: replace all of these with pytrees
+def _create_differentiable(tensor_or_tuple_of_tensors, level=None):
+    if isinstance(tensor_or_tuple_of_tensors, torch.Tensor):
+        tensor = tensor_or_tuple_of_tensors
+        aliased = tensor
+        return aliased.requires_grad_()
+    if isinstance(tensor_or_tuple_of_tensors, tuple):
+        return tuple(map(partial(_create_differentiable, level=level), tensor_or_tuple_of_tensors))
+    if isinstance(tensor_or_tuple_of_tensors, list):
+        return tuple(map(partial(_create_differentiable, level=level), tensor_or_tuple_of_tensors))
+    raise ValueError(f'Thing passed to transform API must be Tensor, List or Tuple, '
+                     f'got {type(tensor_or_tuple_of_tensors)}')
+
+def _undo_create_differentiable(tensor_or_tuple_of_tensors, level=None):
+    if isinstance(tensor_or_tuple_of_tensors, torch.Tensor):
+        tensor = tensor_or_tuple_of_tensors
+        return _unwrap_for_grad(tensor, level)
+    if isinstance(tensor_or_tuple_of_tensors, tuple):
+        return tuple(map(partial(_undo_create_differentiable, level=level), tensor_or_tuple_of_tensors))
+    if isinstance(tensor_or_tuple_of_tensors, list):
+        return tuple(map(partial(_undo_create_differentiable, level=level), tensor_or_tuple_of_tensors))
+    assert False
+
+def _any_differentiable(tensor_or_tuple_of_tensors):
+    if isinstance(tensor_or_tuple_of_tensors, torch.Tensor):
+        tensor = tensor_or_tuple_of_tensors
+        return tensor.requires_grad
+    if isinstance(tensor_or_tuple_of_tensors, tuple):
+        return any(tuple(map(_any_differentiable, tensor_or_tuple_of_tensors)))
+    if isinstance(tensor_or_tuple_of_tensors, list):
+        return any(tuple(map(_any_differentiable, tensor_or_tuple_of_tensors)))
+    return False
+
+def _wrap_all_tensors(tensor_or_tuple_of_tensors, level):
+    if isinstance(tensor_or_tuple_of_tensors, torch.Tensor):
+        tensor = tensor_or_tuple_of_tensors
+        return _wrap_for_grad(tensor, level)
+    if isinstance(tensor_or_tuple_of_tensors, tuple):
+        return tuple(map(partial(_wrap_all_tensors, level=level), tensor_or_tuple_of_tensors))
+    if isinstance(tensor_or_tuple_of_tensors, list):
+        return tuple(map(partial(_wrap_all_tensors, level=level), tensor_or_tuple_of_tensors))
+    return tensor_or_tuple_of_tensors
+
+# How do we increment and decrement the nesting? I don't think we can.
+def vjp(f, *primals):
+    level = _grad_increment_nesting()
+    try:
+        primals = _wrap_all_tensors(primals, level)
+        diff_primals = _create_differentiable(primals, level)
+        primals_out = f(*diff_primals)
+        results = _undo_create_differentiable(primals_out, level)
+
+        def wrapper(*cotangents, retain_graph=True, create_graph=True):
+            result = torch.autograd.grad(primals_out, diff_primals, cotangents,
+                                         retain_graph=retain_graph, create_graph=create_graph)
+            return result
+
+    finally:
+        _grad_decrement_nesting()
+
+    return results, wrapper
+
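+# A minimal usage sketch for `vjp` above (illustrative only; assumes the grad
+# wrappers behave as intended):
+#
+#   >>> x = torch.randn(3)
+#   >>> out, vjp_fn = vjp(torch.sin, x)
+#   >>> (grad_x,) = vjp_fn(torch.ones(3))   # grad_x == torch.cos(x)
+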
+def jacrev(f):
+    def wrapper_fn(primal):
+        output, vjp_fn = vjp(f, primal)
+        assert isinstance(output, torch.Tensor)
+        # TODO: does jacrev compose with vmap...? the eye call should make it so that it doesn't
+        basis = torch.eye(output.numel(), dtype=output.dtype, device=output.device) \
+                     .view(output.numel(), *output.shape)
+        result, = vmap(vjp_fn)(basis)
+        result = result.view(*output.shape, *primal.shape)
+        return result
+    return wrapper_fn
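+
+# A minimal usage sketch for `jacrev` above (illustrative only): for a function
+# f: R^N -> R^M applied to a 1-D input of size N, the result has shape (M, N).
+#
+#   >>> x = torch.randn(3)
+#   >>> jac = jacrev(torch.sin)(x)              # shape (3, 3)
+#   >>> expected = torch.diag(torch.cos(x))     # Jacobian of sin is diagonal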
+
+# 
+# 
+# def jacrev(f, diff_argnums=(0,)):
+#     def wrapper(*args):
+#         torch._C._grad_increment_nesting()
+#         output = None
+#         grad_outputs = None
+#         try:
+#             args = [_create_differentiable(arg) if i in diff_argnums else arg
+#                     for i, arg in enumerate(args)]
+#             output = f(*args)
+#             # Only support single tensor output for now
+#             assert isinstance(output, torch.Tensor)
+#             output_numel = output.numel()
+#             if output_numel != 0:
+#                 grad_output = torch.eye(output_numel).view(output_numel, *output.shape)
+# 
+#             diff_args = [args[i] for i in diff_argnums]
+#             single_diff_arg = isinstance(diff_args[0], torch.Tensor) and len(diff_args) == 1
+#             # TODO: quick hack...
+#             if len(diff_args) == 1 and isinstance(diff_args[0], tuple):
+#                 diff_args = diff_args[0]
+#             # NB: need create_graph so that backward pass isn't run in no_grad mode
+# 
+#             def compute_vjp(v):
+#                 return torch.autograd.grad(output, diff_args, v, create_graph=True)
+# 
+#             if output_numel == 0:
+#                 grad_input = compute_vjp(grad_output)
+#             else:
+#                 grad_input = vmap(compute_vjp)(grad_output)
+# 
+#             if single_diff_arg:
+#                 grad_input = grad_input[0]
+#         finally:
+#             _undo_create_differentiable(args)
+#             torch._C._grad_decrement_nesting()
+#         return grad_input, output
+#     return wrapper
+
+def grad_with_value(f, diff_argnums=(0,), has_aux=False):
+    def wrapper(*args):
+        level = _grad_increment_nesting()
+        output, aux, grad_input = None, None, None
+        try:
+            args = _wrap_all_tensors(args, level)
+            args = [_create_differentiable(arg, level) if i in diff_argnums else arg
+                    for i, arg in enumerate(args)]
+            # print("calling f(*args)")
+            output = f(*args)
+            # print("done with f(*args)")
+            if has_aux:
+                output, aux = output
+            # print("calling output.dim()")
+            assert output.dim() == 0
+            diff_args = [args[i] for i in diff_argnums]
+            single_diff_arg = isinstance(diff_args[0], torch.Tensor) and len(diff_args) == 1
+            # TODO: quick hack...
+            if len(diff_args) == 1 and isinstance(diff_args[0], tuple):
+                diff_args = diff_args[0]
+            # NB: need create_graph so that backward pass isn't run in no_grad mode
+            # import torchviz; import graphviz
+            # graph = torchviz.make_dot(output)
+            # graph.save("inner.dot")
+            # print("calling autograd.grad")
+            grad_input = torch.autograd.grad(
+                output, diff_args, create_graph=True)
+            # print("done-ish!")
+            if single_diff_arg:
+                grad_input = grad_input[0]
+        finally:
+            if grad_input is not None:
+                grad_input = _undo_create_differentiable(grad_input, level)
+            _grad_decrement_nesting()
+        if has_aux:
+            return grad_input, output, aux
+        return grad_input, output
+    return wrapper
+
+def grad(f, diff_argnums=(0,), has_aux=False):
+    def wrapper(*args):
+        results = grad_with_value(f, diff_argnums, has_aux=has_aux)(*args)
+        if has_aux:
+            return results[0], results[2]
+        return results[0]
+    return wrapper
+
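+# A minimal usage sketch for `grad` (illustrative only; `f` must return a
+# 0-dim Tensor, as asserted in grad_with_value):
+#
+#   >>> x = torch.randn(3)
+#   >>> f = lambda t: (t ** 2).sum()
+#   >>> gx = grad(f)(x)                               # gx == 2 * x
+#   >>> per_sample = vmap(grad(f))(torch.randn(5, 3)) # composes with vmap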
diff --git a/functorch/functorch/_src/make_functional.py b/functorch/functorch/_src/make_functional.py
new file mode 100644
index 0000000..ba6176e
--- /dev/null
+++ b/functorch/functorch/_src/make_functional.py
@@ -0,0 +1,99 @@
+import torch
+import torch.nn as nn
+from torch import Tensor
+from typing import List, Tuple
+import copy
+
+# Utilities to make an nn.Module "functional".
+# In particular, the goal is to provide a function that takes the parameters as
+# input and evaluates the nn.Module on fixed inputs.
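+#
+# A minimal usage sketch of make_functional (defined below); illustrative only:
+#
+#   >>> import torch
+#   >>> from torch import nn
+#   >>> model = nn.Linear(3, 3)
+#   >>> weights, func, descriptors = make_functional(model)
+#   >>> x = torch.randn(2, 3)
+#   >>> out = func(weights, (x,))   # note: the inputs are passed as a tuple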
+def _del_nested_attr(obj: nn.Module, names: List[str]) -> None:
+    """
+    Deletes the attribute specified by the given list of names.
+    For example, to delete the attribute obj.conv.weight,
+    use _del_nested_attr(obj, ['conv', 'weight'])
+    """
+    if len(names) == 1:
+        delattr(obj, names[0])
+    else:
+        _del_nested_attr(getattr(obj, names[0]), names[1:])
+
+def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) -> None:
+    """
+    Set the attribute specified by the given list of names to value.
+    For example, to set the attribute obj.conv.weight,
+    use _set_nested_attr(obj, ['conv', 'weight'], value)
+    """
+    if len(names) == 1:
+        setattr(obj, names[0], value)
+    else:
+        _set_nested_attr(getattr(obj, names[0]), names[1:], value)
+
+def extract_weights(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]:
+    """
+    This function removes all the Parameters from the model and
+    returns them as a tuple, along with their original attribute names.
+    The weights must be re-loaded with `load_weights` before the model
+    can be used again.
+    Note that this function modifies the model in place and after this
+    call, mod.parameters() will be empty.
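+
+    Example (illustrative only):
+        >>> mod = torch.nn.Linear(2, 2)
+        >>> params, names = extract_weights(mod)
+        >>> len(list(mod.parameters()))       # 0: the parameters were removed
+        >>> load_weights(mod, names, params)  # restores them as plain attributes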
+    """
+    orig_params = tuple(mod.parameters())
+    # Remove all the parameters in the model
+    names = []
+    for name, p in list(mod.named_parameters()):
+        _del_nested_attr(mod, name.split("."))
+        names.append(name)
+
+    # Return the parameters as a tuple (they are still nn.Parameter instances
+    # here; callers may detach them if plain Tensors are needed)
+    params = tuple(p for p in orig_params)
+    return params, names
+
+def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...], as_params=False) -> None:
+    """
+    Reload a set of weights so that `mod` can be used again to perform a forward pass.
+    Note that the `params` are regular Tensors (that can have history) and so are left
+    as Tensors. This means that mod.parameters() will still be empty after this call.
+    """
+    for name, p in zip(names, params):
+        if as_params:
+            p = nn.Parameter(p)
+        _set_nested_attr(mod, name.split("."), p)
+
+def extract_buffers(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]:
+    orig_params = tuple(mod.buffers())
+    # Remove all the buffers from the model
+    names = []
+    for name, p in list(mod.named_buffers()):
+        _del_nested_attr(mod, name.split("."))
+        names.append(name)
+
+    # Return the buffers as a tuple
+    params = tuple(p for p in orig_params)
+    return params, names
+
+def load_buffers(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...], as_params=False) -> None:
+    for name, p in zip(names, params):
+        _set_nested_attr(mod, name.split("."), p)
+
+def make_functional(model: nn.Module):
+    weights, descriptors = extract_weights(model)
+
+    def fun(weights, data):
+        mutable_model = copy.deepcopy(model)
+        load_weights(mutable_model, descriptors, weights)
+        return mutable_model(*data)
+
+    return weights, fun, descriptors
+
+def make_functional_with_buffers(model: nn.Module):
+    weights, weight_descriptors = extract_weights(model)
+    buffers, buf_descriptors = extract_buffers(model)
+
+    def fun(weights, buffers, data):
+        mutable_model = copy.deepcopy(model)
+        load_weights(mutable_model, weight_descriptors, weights)
+        load_buffers(mutable_model, buf_descriptors, buffers)
+        return mutable_model(*data)
+
+    return weights, buffers, fun, weight_descriptors, buf_descriptors
diff --git a/functorch/functorch/_src/vmap.py b/functorch/functorch/_src/vmap.py
new file mode 100644
index 0000000..a836b08
--- /dev/null
+++ b/functorch/functorch/_src/vmap.py
@@ -0,0 +1,270 @@
+import torch
+import functools
+from torch import Tensor
+from typing import Any, Callable, Optional, Tuple, Union, List
+from torch.utils._pytree import tree_flatten, tree_unflatten, _broadcast_to_and_flatten
+import warnings
+
+from functorch._C import (
+    _add_batch_dim,
+    _remove_batch_dim,
+    _vmapmode_decrement_nesting,
+    _vmapmode_increment_nesting,
+)
+
+in_dims_t = Union[int, Tuple]
+out_dims_t = Union[int, Tuple[int, ...]]
+
+# Checks that all args-to-be-batched have the same batch dim size
+def _validate_and_get_batch_size(
+        flat_in_dims: List[Optional[int]],
+        flat_args: List) -> int:
+    batch_sizes = [arg.size(in_dim) for in_dim, arg in zip(flat_in_dims, flat_args)
+                   if in_dim is not None]
+    if batch_sizes and any([size != batch_sizes[0] for size in batch_sizes]):
+        raise ValueError(
+            f'vmap: Expected all tensors to have the same size in the mapped '
+            f'dimension, got sizes {batch_sizes} for the mapped dimension')
+    return batch_sizes[0]
+
+def _num_outputs(batched_outputs: Union[Tensor, Tuple[Tensor, ...]]) -> int:
+    if isinstance(batched_outputs, tuple):
+        return len(batched_outputs)
+    return 1
+
+# If value is a tuple, check it has length `num_elements`.
+# If value is not a tuple, make a tuple with `value` repeated `num_elements` times
+def _as_tuple(value: Any, num_elements: int, error_message_lambda: Callable[[], str]) -> Tuple:
+    if not isinstance(value, tuple):
+        return (value,) * num_elements
+    if len(value) != num_elements:
+        raise ValueError(error_message_lambda())
+    return value
+
+# Creates BatchedTensors for every Tensor in arg that should be batched.
+# Returns the (potentially) batched arguments and the batch_size.
+def _create_batched_inputs(
+        in_dims: in_dims_t, args: Tuple, vmap_level: int, func: Callable) -> Tuple[Tuple, int]:
+    if not isinstance(in_dims, int) and not isinstance(in_dims, tuple):
+        raise ValueError(
+            f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
+            f'expected `in_dims` to be int or a (potentially nested) tuple '
+            f'matching the structure of inputs, got: {type(in_dims)}.')
+    if len(args) == 0:
+        raise ValueError(
+            f'vmap({_get_name(func)})(<inputs>): got no inputs. Maybe you forgot to add '
+            f'inputs, or you are trying to vmap over a function with no inputs. '
+            f'The latter is unsupported.')
+
+    flat_args, args_spec = tree_flatten(args)
+    flat_in_dims = _broadcast_to_and_flatten(in_dims, args_spec)
+    if flat_in_dims is None:
+        raise ValueError(
+            f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
+            f'in_dims is not compatible with the structure of `inputs`. '
+            f'in_dims has structure {tree_flatten(in_dims)[1]} but inputs '
+            f'has structure {args_spec}.')
+
+    for arg, in_dim in zip(flat_args, flat_in_dims):
+        if not isinstance(in_dim, int) and in_dim is not None:
+            raise ValueError(
+                f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
+                f'Got in_dim={in_dim} for an input but in_dim must be either '
+                f'an integer dimension or None.')
+        if isinstance(in_dim, int) and not isinstance(arg, Tensor):
+            raise ValueError(
+                f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
+                f'Got in_dim={in_dim} for an input but the input is of type '
+                f'{type(arg)}. We cannot vmap over non-Tensor arguments, '
+                f'please use None as the respective in_dim')
+        if in_dim is not None and (in_dim < 0 or in_dim >= arg.dim()):
+            raise ValueError(
+                f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
+                f'Got in_dim={in_dim} for some input, but that input is a Tensor '
+                f'of dimensionality {arg.dim()} so expected in_dim to satisfy '
+                f'0 <= in_dim < {arg.dim()}.')
+
+    batch_size = _validate_and_get_batch_size(flat_in_dims, flat_args)
+    # See NOTE [Ignored _remove_batch_dim, _add_batch_dim]
+    batched_inputs = [arg if in_dim is None else
+                      _add_batch_dim(arg, in_dim, vmap_level)  # type: ignore
+                      for in_dim, arg in zip(flat_in_dims, flat_args)]
+    return tree_unflatten(batched_inputs, args_spec), batch_size
+
+# Undoes the batching (and any batch dimensions) associated with the `vmap_level`.
+def _unwrap_batched(
+        batched_outputs: Union[Tensor, Tuple[Tensor, ...]],
+        out_dims: out_dims_t,
+        vmap_level: int, batch_size: int, func: Callable) -> Tuple:
+    num_outputs = _num_outputs(batched_outputs)
+    out_dims_as_tuple = _as_tuple(
+        out_dims, num_outputs,
+        lambda: f'vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must '
+                f'have one dim per output (got {num_outputs} outputs) of {_get_name(func)}.')
+
+    # NOTE [Ignored _remove_batch_dim, _add_batch_dim]
+    # There is something wrong with our type bindings for functions that begin
+    # with '_', see #40397.
+    if isinstance(batched_outputs, Tensor):
+        out_dim = out_dims_as_tuple[0]
+        return _remove_batch_dim(batched_outputs, vmap_level, batch_size, out_dim)  # type: ignore
+    return tuple(_remove_batch_dim(out, vmap_level, batch_size, out_dim)  # type: ignore
+                 for out, out_dim in zip(batched_outputs, out_dims_as_tuple))
+
+# Checks that `fn` returned one or more Tensors and nothing else.
+# NB: A Python function that returns multiple values actually returns a single tuple,
+# so we are effectively checking that `outputs` is a single Tensor or a tuple of
+# Tensors.
+def _validate_outputs(outputs: Any, func: Callable) -> None:
+    if isinstance(outputs, Tensor):
+        return
+    if not isinstance(outputs, tuple):
+        raise ValueError(f'vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return '
+                         f'Tensors, got type {type(outputs)} as the return.')
+    for idx, output in enumerate(outputs):
+        if isinstance(output, Tensor):
+            continue
+        raise ValueError(f'vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return '
+                         f'Tensors, got type {type(output)} for return {idx}.')
+
+def _check_out_dims_is_int_or_int_tuple(out_dims: out_dims_t, func: Callable) -> None:
+    if isinstance(out_dims, int):
+        return
+    if not isinstance(out_dims, tuple) or \
+            not all([isinstance(out_dim, int) for out_dim in out_dims]):
+        raise ValueError(
+            f'vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must be '
+            f'an int or a tuple of int representing where in the outputs the '
+            f'vmapped dimension should appear.')
+
+def _get_name(func: Callable):
+    if hasattr(func, '__name__'):
+        return func.__name__
+
+    # Not all callables have a __name__; for example, a callable created via
+    # functools.partial or an nn.Module doesn't have one.
+    return repr(func)
+
+# vmap(func)(inputs) wraps all Tensor inputs to be batched in BatchedTensors,
+# sends those into func, and then unwraps the output BatchedTensors. Operations
+# on BatchedTensors perform the batched operations that the user is asking for.
+def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Callable:
+    """
+    vmap is the vectorizing map. Returns a new function that maps `func` over some
+    dimension of the inputs. Semantically, vmap pushes the map into PyTorch
+    operations called by `func`, effectively vectorizing those operations.
+
+    vmap is useful for handling batch dimensions: one can write a function `func`
+    that runs on examples and then lift it to a function that can take batches of
+    examples with `vmap(func)`. vmap can also be used to compute batched
+    gradients when composed with autograd.
+
+    .. warning::
+        functorch.vmap is an experimental prototype that is subject to
+        change and/or deletion. Please use at your own risk.
+
+    .. note::
+        If you're interested in using vmap for your use case, please
+        `contact us! <https://github.com/pytorch/pytorch/issues/42368>`_
+        We're interested in gathering feedback from early adopters to inform
+        the design.
+
+    Args:
+        func (function): A Python function that takes one or more arguments.
+            Must return one or more Tensors.
+        in_dims (int or nested structure): Specifies which dimension of the
+            inputs should be mapped over. `in_dims` should have a structure
+            like the inputs. If the `in_dim` for a particular input is None,
+            then that indicates there is no map dimension. Default: 0.
+        out_dims (int or Tuple[int]): Specifies where the mapped dimension
+            should appear in the outputs. If `out_dims` is a Tuple, then it should
+            have one element per output. Default: 0.
+
+    Returns:
+        Returns a new "batched" function. It takes the same inputs as `func`,
+        except each input has an extra dimension at the index specified by `in_dims`.
+    It returns the same outputs as `func`, except each output has
+        an extra dimension at the index specified by `out_dims`.
+
+    .. warning::
+        vmap works best with functional-style code. Please do not perform any
+        side-effects in `func`, with the exception of in-place PyTorch operations.
+        Examples of side-effects include mutating Python data structures and
+        assigning values to variables not captured in `func`.
+
+    One example of using `vmap` is to compute batched dot products. PyTorch
+    doesn't provide a batched `torch.dot` API; instead of unsuccessfully
+    rummaging through docs, use `vmap` to construct a new function.
+
+        >>> torch.dot                            # [D], [D] -> []
+        >>> batched_dot = functorch.vmap(torch.dot)  # [N, D], [N, D] -> [N]
+        >>> x, y = torch.randn(2, 5), torch.randn(2, 5)
+        >>> batched_dot(x, y)
+
+    `vmap` can be helpful in hiding batch dimensions, leading to a simpler
+    model authoring experience.
+
+        >>> batch_size, feature_size = 3, 5
+        >>> weights = torch.randn(feature_size, requires_grad=True)
+        >>>
+        >>> def model(feature_vec):
+        >>>     # Very simple linear model with activation
+        >>>     return feature_vec.dot(weights).relu()
+        >>>
+        >>> examples = torch.randn(batch_size, feature_size)
+        >>> result = functorch.vmap(model)(examples)
+
+    `vmap` can also help vectorize computations that were previously difficult
+    or impossible to batch. One example is higher-order gradient computation.
+    The PyTorch autograd engine computes vjps (vector-Jacobian products).
+    Computing a full Jacobian matrix for some function f: R^N -> R^N usually
+    requires N calls to `autograd.grad`, one per Jacobian row. Using `vmap`,
+    we can vectorize the whole computation, computing the Jacobian in a single
+    call to `autograd.grad`.
+
+        >>> # Setup
+        >>> N = 5
+        >>> f = lambda x: x ** 2
+        >>> x = torch.randn(N, requires_grad=True)
+        >>> y = f(x)
+        >>> I_N = torch.eye(N)
+        >>>
+        >>> # Sequential approach
+        >>> jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
+        >>>                  for v in I_N.unbind()]
+        >>> jacobian = torch.stack(jacobian_rows)
+        >>>
+        >>> # vectorized gradient computation
+        >>> def get_vjp(v):
+        >>>     return torch.autograd.grad(y, x, v)
+        >>> jacobian = functorch.vmap(get_vjp)(I_N)
+
+    .. note::
+        vmap does not provide general autobatching or handle variable-length
+        sequences out of the box.
+    """
+    warnings.warn(
+        'functorch.vmap is an experimental prototype that is subject to '
+        'change and/or deletion. Please use at your own risk. There may be '
+        'unexpected performance cliffs due to certain operators not being '
+        'implemented. To see detailed performance warnings please use '
+        '`torch._C._debug_only_display_vmap_fallback_warnings(True)` '
+        'before the call to `vmap`.',
+        stacklevel=2)
+    return _vmap(func, in_dims, out_dims)
+
+# A version of vmap but without the initial "experimental prototype" warning
+def _vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Callable:
+    @functools.wraps(func)
+    def wrapped(*args):
+        _check_out_dims_is_int_or_int_tuple(out_dims, func)
+        vmap_level = _vmapmode_increment_nesting()
+        try:
+            batched_inputs, batch_size = _create_batched_inputs(in_dims, args, vmap_level, func)
+            batched_outputs = func(*batched_inputs)
+            _validate_outputs(batched_outputs, func)
+            return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
+        finally:
+            _vmapmode_decrement_nesting()
+    return wrapped
diff --git a/functorch/functorch/csrc/BatchedFallback.cpp b/functorch/functorch/csrc/BatchedFallback.cpp
new file mode 100644
index 0000000..6e54a5e
--- /dev/null
+++ b/functorch/functorch/csrc/BatchedFallback.cpp
@@ -0,0 +1,416 @@
+#include <functorch/csrc/BatchedFallback.h>
+#include <functorch/csrc/VmapTransforms.h>
+#include <functorch/csrc/Constants.h>
+#include <functorch/csrc/TensorWrapper.h>
+#include <functorch/csrc/DynamicLayer.h>
+
+#include <ATen/Context.h>
+#include <ATen/MatrixRef.h>
+#include <ATen/core/dispatch/Dispatcher.h>
+#include <c10/util/accumulate.h>
+#include <c10/util/llvmMathExtras.h>
+
+namespace at {
+namespace functorch {
+
+// Given a linear index, return the actual index.
+// Example: Given linear_idx = 3, sizes = [5, 2], we would return [1, 1]
+static at::SmallVector<indexing::TensorIndex,kVmapStaticDimVecSize>
+computeIndex(int64_t linear_idx, IntArrayRef sizes) {
+  at::SmallVector<indexing::TensorIndex,kVmapStaticDimVecSize> result;
+  result.reserve(sizes.size());
+  for (auto it = sizes.rbegin(); it != sizes.rend(); it++) {
+    auto remainder = linear_idx % *it;
+    result.push_back(remainder);
+    linear_idx -= remainder;
+    linear_idx /= *it;
+  }
+  std::reverse(std::begin(result), std::end(result));
+  return result;
+}
+
+static bool areAllReturnsTensors(const at::FunctionSchema& schema) {
+  return std::all_of(
+      schema.returns().begin(),
+      schema.returns().end(),
+      [] (const Argument& arg) { return arg.type() == TensorType::get(); });
+}
+
+static bool areAnyArgumentsTensorList(const at::FunctionSchema& schema) {
+  return std::any_of(
+      schema.arguments().begin(),
+      schema.arguments().end(),
+      [] (const Argument& arg) { return arg.type()->isSubtypeOf(ListType::ofTensors()); });
+}
+
+// Returns whether an operator is in-place. An operator is in-place if:
+// 1. The first argument is a Tensor and it is being written to
+// 2. The first argument is being returned
+// 3. No other arguments are aliased
+// Here is an example of an in-place operator:
+// add_(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
+static bool isInplaceOp(const FunctionSchema& schema) {
+  if (!schema.is_mutable() || schema.returns().size() != 1) {
+    return false;
+  }
+  // Check that the first argument is being written to
+  const auto& first_arg_alias_info = schema.arguments().begin()->alias_info();
+  if (!first_arg_alias_info || !first_arg_alias_info.value().isWrite()) {
+    return false;
+  }
+  // Check that none of the other args are being aliased
+  for (auto it = schema.arguments().begin() + 1; it != schema.arguments().end(); ++it) {
+    const auto& alias_info = it->alias_info();
+    if (alias_info) {
+      return false;
+    }
+  }
+  // Check that the first tensor is being returned (i.e., output has a (a!))
+  const auto& return_alias_info = schema.returns()[0].alias_info();
+  return return_alias_info && return_alias_info.value().isWrite();
+}
+
+static void warnFallback(const c10::FunctionSchema& schema, bool is_inplace) {
+  if (!globalContext().areVmapFallbackWarningsEnabled()) {
+    return;
+  }
+  auto uses_stack = is_inplace ? "" : " and stack";
+  TORCH_WARN("Batching rule not implemented for ", schema.operator_name(), " falling back "
+             "to slow (for loop", uses_stack, ") implementation");
+}
+
+// The general flow of the algorithm is as follows.
+// - First, we figure out which arguments are BatchedTensors and save them
+//   to a vector. We also store a vector of which index of the arguments list
+//   each BatchedTensor appears in. This will be useful for bookkeeping later.
+// - Next, we apply the MultiBatchVmapTransform to all of the BatchedTensors.
+//   This returns a vector of VmapPhysicalView that hold tensors that contain
+//   all of the collective batch dimensions at the front of the tensors.
+// - Then, we attempt to call `op` once per slice of the inputs. To do this,
+//   we repeatedly slice the input arguments (if they are BatchedTensors),
+//   put the sliced (or a not-sliced) version of the input onto the stack, invoke
+//   the operator, and then pop the results off the stack.
+void batchedTensorInplaceForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+  const auto& schema = op.schema();
+  const auto num_returns = schema.returns().size();
+  warnFallback(schema, /*in_place*/true);
+
+  const auto num_arguments = schema.arguments().size();
+  const auto arguments = torch::jit::last(stack, num_arguments);
+  const auto arguments_begin = stack->size() - num_arguments;
+
+  // `self` is the Tensor being modified in-place
+  Tensor self = arguments[0].toTensor();
+  const auto* self_impl = maybeGetBatchedImpl(self);
+  std::bitset<kVmapMaxTensorDims> self_vmap_levels;
+  if (self_impl) {
+    self_vmap_levels = createVmapLevelsBitset(self_impl->bdims());
+  }
+
+  // Figure out which arguments are BatchedTensor. Save them to a vector.
+  // For each BatchedTensor, also record what position of `arguments` they came from.
+  at::SmallVector<Tensor,kVmapTransformStaticInputSize> batched_tensor_inputs;
+  VmapDimVector batched_tensor_inputs_position;
+  for (int64_t idx = 0; idx < arguments.size(); ++idx) {
+    const auto& ivalue = arguments[idx];
+    if (!ivalue.isTensor()) {
+      continue;
+    }
+    const auto& tensor = ivalue.toTensor();
+    if (!tensor.defined()) {
+      continue;
+    }
+    const auto* batched = maybeGetBatchedImpl(tensor);
+    if (!batched) {
+      continue;
+    }
+
+    // NOTE: [vmap-incompatible in-place operations]
+    // In-place operations on `self` are not possible if there exists some vmap
+    // level `l` such that `self` is not being vmapped on that level but another
+    // argument is. For example, let B0 be a batch dim inside vmap and consider
+    // vmap(Tensor.add_, in_dims=(None, 0))(torch.ones(3), torch.ones(B0, 3))
+    // - self is torch.ones(3) and does not participate in this vmap
+    // - other is BatchedTensor(torch.ones(B0, 3))
+    // There's no way to do self.add_(other) because `other` has more elements
+    // than `self` due to being vmapped over.
+    //
+    // In the vmap fallback, we should error out when we detect this.
+    auto other_vmap_levels = createVmapLevelsBitset(batched->bdims());
+    if (self_vmap_levels != (self_vmap_levels | other_vmap_levels)) {
+      // Find one vmap level to complain about
+      auto additional_bdims = (self_vmap_levels | other_vmap_levels) ^ self_vmap_levels;
+      auto offending_level = llvm::findLastSet(additional_bdims.to_ulong());
+      // The following prints out "vmap: aten::add_(tensor, ...) is not possible",
+      // but it would be better to print out "tensor.add_(...) is not possible".
+      // Afaict there's no official way to get the add_ and there is no way to
+      // tell if an operator has method or function variants.
+      TORCH_CHECK(false,
+        "vmap: ", schema.name(), "(self, *extra_args) is not possible because ",
+        "there exists a Tensor `other` in extra_args that has more elements ",
+        "than `self`. This happened due to `other` being vmapped over but ",
+        "`self` not being vmapped over at level ", offending_level, ". ",
+        "Please try to use out-of-place operators instead of ", schema.name(), ". ",
+        "If said operator is being called inside the PyTorch framework, ",
+        "please file a bug report instead.");
+    }
+    batched_tensor_inputs.push_back(tensor);
+    batched_tensor_inputs_position.push_back(idx);
+  }
+  TORCH_INTERNAL_ASSERT(batched_tensor_inputs.size() > 0);
+
+  // MultiBatchVmapTransform the BatchedTensor arguments. This returns
+  // VmapPhysicalViews that contain all of the batch dimensions.
+  const auto input_physical_views = MultiBatchVmapTransform::logicalToPhysical(
+      batched_tensor_inputs);
+
+  // Compute the total number of batches
+  auto num_batch_dims = input_physical_views.front().numBatchDims();
+  auto first_physical_view_sizes = input_physical_views.front().tensor().sizes();
+  auto batch_sizes = ArrayRef<int64_t>(
+      first_physical_view_sizes.begin(), first_physical_view_sizes.begin() + num_batch_dims);
+  const auto num_batches = c10::multiply_integers(batch_sizes);
+  // Without a shape-checking API, we're unable to compute the correct shape of
+  // the output so we just error out.
+  TORCH_CHECK(num_batches > 0,
+      "Batching rule not implemented for ", schema.operator_name(), ". ",
+      "The fallback path does not support vmap over dims of size 0.");
+
+  // Strategy: For each batch, we are going to push slices (where applicable)
+  // of the arguments onto `stack`, and call `op`.
+  for (int64_t linear_idx = 0; linear_idx < num_batches; ++linear_idx) {
+    auto index = computeIndex(linear_idx, batch_sizes);
+    auto batched_tensor_inputs_pos_iter = batched_tensor_inputs_position.begin();
+    auto input_physical_views_iter = input_physical_views.begin();
+    for (int64_t arg_idx = 0; arg_idx < num_arguments; ++arg_idx) {
+      // We assume that torch::jit::Stack is backed by vector<IValue> for
+      // simplicity. When that is not the case, this code should be updated.
+      const auto& argument = (*stack)[arguments_begin + arg_idx];
+      if (batched_tensor_inputs_pos_iter == batched_tensor_inputs_position.end()
+          || arg_idx != *batched_tensor_inputs_pos_iter) {
+        // argument isn't a BatchedTensor
+        torch::jit::push(stack, argument);
+        continue;
+      }
+      // argument is a BatchedTensor
+      TORCH_INTERNAL_ASSERT(input_physical_views_iter != input_physical_views.end());
+      const auto& physical_view_for_argument = *input_physical_views_iter;
+      auto thing = physical_view_for_argument.tensor().index(index);
+      torch::jit::push(stack, thing);
+      batched_tensor_inputs_pos_iter++;
+      input_physical_views_iter++;
+    }
+
+    op.callBoxed(stack);
+    torch::jit::drop(stack, 1);
+  }
+
+  // Return the tensor that was written to in-place
+  torch::jit::drop(stack, num_arguments);
+  torch::jit::push(stack, self);
+}
+
+static Tensor safeStack(TensorList tensors) {
+  auto is_defined = [](const Tensor& t) { return t.defined(); };
+  if (std::all_of(tensors.begin(), tensors.end(), is_defined)) {
+    return at::stack(tensors);
+  }
+  // NOTE [vmap through backward and undefined grad]
+  // While vmapping through backward functions (to compute batched grad), it
+  // is possible for the backward function to return an undefined grad for some
+  // grad_input for each example. In that case, we return an undefined grad.
+  //
+  // It is theoretically possible for *some* of the examples to produce an
+  // undefined grad (a kernel could peek at the gradient values and return an
+  // undefined tensor if it determines the gradient is full of zeros). We
+  // could handle this by treating the undefined grad as a zero-filled tensor
+  // of the correct shape while stacking the tensors together. However I expect
+  // this to happen very rarely (I have not been able to find an example in our
+  // codebase) so we just error out in this case.
+  if (std::none_of(tensors.begin(), tensors.end(), is_defined)) {
+    return Tensor();
+  }
+  TORCH_CHECK(false,
+      "vmap: slow fallback received a mix of undefined and defined tensors ",
+      "as the result of an operation. This is not supported, please file us ",
+      "an issue on github.");
+}
+
+// TODO: dedup
+static bool participatesInCurrentLevel(const Tensor& self) {
+  auto maybe_level = maybeCurrentDynamicLayer();
+  TORCH_INTERNAL_ASSERT(maybe_level.has_value());
+  auto current_level = maybe_level->layerId();
+  auto* maybe_batched_impl = maybeGetBatchedImpl(self);
+  if (!maybe_batched_impl) {
+    return false;
+  }
+  const auto& bdims = maybe_batched_impl->bdims();
+  TORCH_INTERNAL_ASSERT(bdims.size() == 1);
+  auto self_level = bdims.back().level();
+  TORCH_INTERNAL_ASSERT(self_level <= current_level);
+  return self_level == current_level;
+}
+
+static bool ivalueParticipatesInCurrentLevel(const IValue& ivalue) {
+  if (!ivalue.isTensor()) {
+    return false;
+  }
+  return participatesInCurrentLevel(ivalue.toTensor());
+}
+
+// The general flow of the algorithm is as follows.
+// - First, we figure out which arguments are BatchedTensors and save them
+//   to a vector. We also store a vector of which index of the arguments list
+//   each BatchedTensor appears in. This will be useful for bookkeeping later.
+// - Next, we apply the MultiBatchVmapTransform to all of the BatchedTensors.
+//   This returns a vector of VmapPhysicalView that hold tensors that contain
+//   all of the collective batch dimensions at the front of the tensors.
+// - Then, we attempt to call `op` once per slice of the inputs. To do this,
+//   we repeatedly slice the input arguments (if they are BatchedTensors),
+//   put the sliced (or a not-sliced) version of the input onto the stack, invoke
+//   the operator, and then pop the results off the stack.
+// - Each result obtained from the previous step is a slice of the total result,
+//   so we stack those tensors together to form the final result.
+void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+  const auto& schema = op.schema();
+  const auto num_returns = schema.returns().size();
+  const auto num_arguments = schema.arguments().size();
+  const auto arguments = torch::jit::last(stack, num_arguments);
+
+  TORCH_CHECK(areAllReturnsTensors(schema) && !areAnyArgumentsTensorList(schema),
+              "Batching rule not implemented for ", schema.operator_name(), ". ",
+              "We could not generate a fallback.");
+
+  if (std::none_of(arguments.begin(), arguments.end(), ivalueParticipatesInCurrentLevel)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    op.callBoxed(stack); 
+    return;
+  }
+
+  if (isInplaceOp(schema)) {
+    batchedTensorInplaceForLoopFallback(op, stack);
+    return;
+  }
+  TORCH_CHECK(!schema.is_mutable() && !schema.hasAnyAliasInfo(),
+              "Batching rule not implemented for ", schema.operator_name(), "; ",
+              "the fallback path doesn't work on out= or view ops.");
+  TORCH_CHECK(num_returns >= 1,
+              "Batching rule not implemented for ", schema.operator_name(), ". ",
+              "The fallback path does not support operations with no returns.");
+  warnFallback(schema, /*in_place*/false);
+
+  const auto arguments_begin = stack->size() - num_arguments;
+
+  // Figure out which arguments are BatchedTensor. Save them to a vector.
+  // For each BatchedTensor, also record what position of `arguments` they came from.
+  at::SmallVector<Tensor,kVmapTransformStaticInputSize> batched_tensor_inputs;
+  VmapDimVector batched_tensor_inputs_position;
+  for (int64_t idx = 0; idx < arguments.size(); ++idx) {
+    const auto& ivalue = arguments[idx];
+    if (!ivalue.isTensor()) {
+      continue;
+    }
+    const auto& tensor = ivalue.toTensor();
+    if (!tensor.defined()) {
+      continue;
+    }
+    const auto* batched = maybeGetBatchedImpl(tensor);
+    if (!batched) {
+      continue;
+    }
+    batched_tensor_inputs.push_back(tensor);
+    batched_tensor_inputs_position.push_back(idx);
+  }
+  TORCH_INTERNAL_ASSERT(batched_tensor_inputs.size() > 0);
+
+  // MultiBatchVmapTransform the BatchedTensor arguments. This returns
+  // VmapPhysicalViews that contain all of the batch dimensions.
+  const auto input_physical_views = MultiBatchVmapTransform::logicalToPhysical(
+      batched_tensor_inputs);
+
+  // Compute the total number of batches
+  auto num_batch_dims = input_physical_views.front().numBatchDims();
+  auto some_sizes = input_physical_views.front().tensor().sizes();
+  auto batch_sizes = ArrayRef<int64_t>(some_sizes.begin(), some_sizes.begin() + num_batch_dims);
+  const auto num_batches = c10::multiply_integers(batch_sizes);
+  // Without a shape-checking API, we're unable to compute the correct shape of
+  // the output so we just error out.
+  TORCH_CHECK(num_batches > 0,
+      "Batching rule not implemented for ", schema.operator_name(), ". ",
+      "The fallback path does not support vmap over dims of size 0.");
+
+  // Strategy: For each batch, we are going to push slices (where applicable)
+  // of the arguments onto `stack`, call `op`, and store the result in
+  // `output_shards`.
+  //
+  // NOTE: [Output shards layout]
+  // Assume that the operator has three outputs: a, b, c.
+  // The layout of output_shards is as follows:
+  // [ a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3]
+  // This is so that we can call at::stack([a0...a3]), at::stack([b0...b3])
+  // more easily in the next step.
+  std::vector<Tensor> output_shards(num_batches * num_returns);
+
+  for (int64_t linear_idx = 0; linear_idx < num_batches; ++linear_idx) {
+    auto index = computeIndex(linear_idx, batch_sizes);
+    auto batched_tensor_inputs_pos_iter = batched_tensor_inputs_position.begin();
+    auto input_physical_views_iter = input_physical_views.begin();
+    for (int64_t arg_idx = 0; arg_idx < num_arguments; ++arg_idx) {
+      // We assume that torch::jit::Stack is backed by vector<IValue> for
+      // simplicity. When that is not the case, this code should be updated.
+      const auto& argument = (*stack)[arguments_begin + arg_idx];
+      if (batched_tensor_inputs_pos_iter == batched_tensor_inputs_position.end()
+          || arg_idx != *batched_tensor_inputs_pos_iter) {
+        // argument isn't a BatchedTensor
+        torch::jit::push(stack, argument);
+        continue;
+      }
+      // argument is a BatchedTensor
+      TORCH_INTERNAL_ASSERT(input_physical_views_iter != input_physical_views.end());
+      const auto& physical_view_for_argument = *input_physical_views_iter;
+      c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+      torch::jit::push(stack, physical_view_for_argument.tensor().index(index));
+      batched_tensor_inputs_pos_iter++;
+      input_physical_views_iter++;
+    }
+
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    op.callBoxed(stack);
+
+    // Store the result into `output_shards`. See NOTE: [Output shards layout]
+    // to learn about the details of how we store the shards.
+    const auto returns = torch::jit::last(stack, num_returns);
+    for (int64_t return_idx = 0; return_idx < returns.size(); ++return_idx) {
+      output_shards[num_batches * return_idx + linear_idx] = returns[return_idx].toTensor();
+    }
+    torch::jit::drop(stack, num_returns);
+  }
+
+  // For each output Tensor, stack the shards of the tensor together to form a return
+  torch::jit::drop(stack, num_arguments);
+  auto output_shards_chunks = MatrixRef<Tensor>(output_shards, num_batches);
+  for (int64_t return_idx = 0; return_idx < num_returns; ++return_idx) {
+    auto shards = output_shards_chunks[return_idx];
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    auto flat_output = safeStack(shards);
+    // See NOTE [vmap through backward and undefined grad]
+    if (!flat_output.defined()) {
+      torch::jit::push(stack, flat_output);
+      continue;
+    }
+    VmapDimVector output_sizes(batch_sizes);
+    output_sizes.insert(
+        output_sizes.end(),
+        flat_output.sizes().begin() + 1,
+        flat_output.sizes().end());
+    torch::jit::push(
+        stack,
+        input_physical_views.front().getPhysicalToLogicalMap().apply(flat_output.view(output_sizes)));
+  }
+}
+
+} // namespace functorch
+} // namespace at
diff --git a/functorch/functorch/csrc/BatchedFallback.h b/functorch/functorch/csrc/BatchedFallback.h
new file mode 100644
index 0000000..1b84a3d
--- /dev/null
+++ b/functorch/functorch/csrc/BatchedFallback.h
@@ -0,0 +1,26 @@
+#pragma once
+#include <ATen/ATen.h>
+#include <ATen/core/op_registration/op_registration.h>
+#include <torch/library.h>
+
+namespace at {
+namespace functorch {
+
+// If an operator doesn't have a batching rule implemented then we fallback
+// to this implementation. The fallback only works on out-of-place operators
+// that return only tensors with new memory (e.g., no in-place operators, no
+// view operations).
+//
+// The fallback effectively takes all of the BatchedTensors in `stack`, slices
+// them, and runs `op` on all of the corresponding slices to produce slices
+// of the outputs. The output slices then get `torch.stack`ed to create the
+// final returns.
+//
+// The performance of the fallback is not very good because it introduces an
+// extra copy from stacking the sliced outputs. Because of this, we prefer to
+// write batching rules for operators whenever possible.
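+//
+// Conceptually (an illustrative sketch, not actual code), for a single batch
+// dim of size B and a unary op with no batching rule:
+//   for i in 0..B-1: out_slices[i] = op(input.select(/*dim=*/0, i))
+//   output = at::stack(out_slices)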
+void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
+
+
+} // namespace functorch
+} // namespace at
diff --git a/functorch/functorch/csrc/BatchedTensorImpl.cpp b/functorch/functorch/csrc/BatchedTensorImpl.cpp
new file mode 100644
index 0000000..4a3058f
--- /dev/null
+++ b/functorch/functorch/csrc/BatchedTensorImpl.cpp
@@ -0,0 +1,169 @@
+#include <functorch/csrc/BatchedTensorImpl.h>
+
+#include <ATen/WrapDimUtils.h>
+#include <c10/util/Exception.h>
+
+#include <functorch/csrc/Constants.h>
+
+namespace at {
+namespace functorch {
+
+BatchedTensorImpl::BatchedTensorImpl(Tensor value, BatchDims bdims)
+  : TensorImpl(
+      c10::DispatchKeySet(kBatchedKey),
+      value.dtype(),
+      value.device()
+    )
+  , value_(std::move(value))
+  , bdims_(std::move(bdims))
+{
+  TORCH_INTERNAL_ASSERT(value_.defined());
+  set_storage_access_should_throw();
+  checkInvariants();
+
+  const auto public_dims = value_.dim() - bdims_.size();
+  const auto value_sizes = value_.sizes();
+  const auto value_strides = value_.strides();
+  sizes_and_strides_.resize(public_dims);
+  for (int64_t dim = 0; dim < public_dims; dim++) {
+    auto actual_dim = actualDim(dim, /*wrap_dim=*/false);
+    sizes_and_strides_.size_at_unchecked(dim) = value_sizes.at(actual_dim);
+    sizes_and_strides_.stride_at_unchecked(dim) = value_strides.at(actual_dim);
+  }
+  refresh_numel();
+  refresh_contiguous();
+}
+
+BatchedTensorImpl::BatchedTensorImpl(DispatchKeySet key_set, Tensor value, BatchDims bdims)
+  : TensorImpl(
+      key_set.add(kBatchedKey),
+      value.dtype(),
+      value.device()
+    )
+  , value_(std::move(value))
+  , bdims_(std::move(bdims))
+{
+  TORCH_INTERNAL_ASSERT(value_.defined());
+  checkInvariants();
+
+  TORCH_INTERNAL_ASSERT(bdims_.size() == 1);
+  refreshSizesAndStrides();
+}
+
+void BatchedTensorImpl::refreshSizesAndStrides() {
+  const auto public_dims = value_.dim() - bdims_.size();
+  const auto value_sizes = value_.sizes();
+  const auto value_strides = value_.strides();
+  sizes_and_strides_.resize(public_dims);
+  for (int64_t dim = 0; dim < public_dims; dim++) {
+    auto actual_dim = actualDim(dim, /*wrap_dim=*/false);
+    sizes_and_strides_.size_at_unchecked(dim) = value_sizes.at(actual_dim);
+    sizes_and_strides_.stride_at_unchecked(dim) = value_strides.at(actual_dim);
+  }
+  refresh_numel();
+  refresh_contiguous();
+}
+
+int64_t BatchedTensorImpl::actualDim(int64_t dim, bool wrap_dim) const {
+  if (wrap_dim) {
+    const auto ndim = sizes_and_strides_.size();
+    dim = maybe_wrap_dim(dim, ndim);
+  }
+  auto is_bdim = createBatchDimBitset(bdims_);
+
+  // Example: assume dim = 3, and is_bdim = 10010011000...
+  // The 1's are batch dims and 0's are normal dims of the underlying value_ Tensor.
+  // actualDim gives us the index of `dim` in the `value_` Tensor, which is equivalent
+  // to asking "where does the 3rd (0-indexed) zero occur in the bitset?".
+  // The answer to that is index 5.
+  //
+  // TODO(rzou): the PDEP instruction does exactly this
+  // (https://stackoverflow.com/questions/7669057/find-nth-set-bit-in-an-int)
+  // but it might require newer (>= ~2015) CPUs. We should clean this up
+  // if/when we have dropped support for older CPUs.
+  int64_t non_bdim_count = 0;
+  for (int64_t actual_dim = 0; actual_dim < kVmapMaxTensorDims; actual_dim++) {
+    if (is_bdim[actual_dim]) {
+      continue;
+    }
+    if (non_bdim_count == dim) {
+      return actual_dim;
+    }
+    non_bdim_count++;
+  }
+  // If we hit this assert, then that means
+  // `non_bdim_count` + #num_bdims > kVmapMaxTensorDims. We restrict the number
+  // of dims a BatchedTensorImpl can have to kVmapMaxTensorDims so this should
+  // never be hit.
+  TORCH_INTERNAL_ASSERT(false);
+}
+
+void BatchedTensorImpl::checkInvariants() const {
+  int64_t prev_level = -1;
+  for (const auto& bdim : bdims_) {
+    TORCH_INTERNAL_ASSERT(bdim.level() > prev_level);
+    prev_level = bdim.level();
+  }
+}
+
+// The following are publicly exposed as methods of Tensor
+bool BatchedTensorImpl::is_contiguous(at::MemoryFormat memory_format) const {
+  TORCH_CHECK(memory_format == MemoryFormat::Contiguous,
+      "NYI: querying is_contiguous inside of vmap for memory_format ",
+      "other than torch.contiguous_format");
+  return is_contiguous_;
+}
+
+// The following are some internal inherited methods that we do not support.
+// They should never get called.
+void BatchedTensorImpl::set_size(int64_t dim, int64_t new_size) {
+  TORCH_INTERNAL_ASSERT(false, "Can't set_size for BatchedTensorImpl");
+}
+void BatchedTensorImpl::set_stride(int64_t dim, int64_t new_stride) {
+  TORCH_INTERNAL_ASSERT(false, "Can't set_stride for BatchedTensorImpl");
+}
+void BatchedTensorImpl::set_storage_offset(int64_t storage_offset) {
+  TORCH_INTERNAL_ASSERT(false, "Can't set_storage_offset for BatchedTensorImpl");
+}
+#ifdef DEBUG
+bool BatchedTensorImpl::has_storage() const {
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!storage_, "BatchedTensorImpl assumes that storage_ is never set");
+  return false;
+}
+#endif
+
+const char* BatchedTensorImpl::tensorimpl_type_name() const {
+  return "BatchedTensorImpl";
+}
+
+Tensor makeBatched(const Tensor& tensor, BatchDims bdims) {
+  DispatchKeySet key_set;
+  if (tensor.is_cuda()) {
+    key_set = key_set.add(DispatchKey::CUDA);
+  }
+  return at::detail::make_tensor<BatchedTensorImpl>(key_set, tensor, std::move(bdims));
+}
+
+Tensor addBatchDim(const Tensor& tensor, int64_t level, int64_t dim) {
+  BatchDims new_bdims = { { level, dim } };
+  TORCH_INTERNAL_ASSERT(new_bdims.size() == 1);
+  return makeBatched(tensor, std::move(new_bdims));
+}
+
+bool inplaceIsVmapCompatible(const Tensor& self, const Tensor& other) {
+  const auto* other_batched = maybeGetBatchedImpl(other);
+  if (!other_batched) {
+    return true;
+  }
+  const auto* self_batched = maybeGetBatchedImpl(self);
+  if (!self_batched) {
+    // self is not batched but other is batched
+    return false;
+  }
+  auto self_levels = createVmapLevelsBitset(self_batched->bdims());
+  auto other_levels = createVmapLevelsBitset(other_batched->bdims());
+  return self_levels == (self_levels | other_levels);
+}
+
+} // namespace functorch
+} // namespace at
diff --git a/functorch/functorch/csrc/BatchedTensorImpl.h b/functorch/functorch/csrc/BatchedTensorImpl.h
new file mode 100644
index 0000000..e8a3be0
--- /dev/null
+++ b/functorch/functorch/csrc/BatchedTensorImpl.h
@@ -0,0 +1,158 @@
+#pragma once
+
+#include <bitset>
+
+#include <ATen/ArrayRef.h>
+#include <ATen/SmallVector.h>
+#include <ATen/Tensor.h>
+
+#include <functorch/csrc/Constants.h>
+
+namespace at {
+namespace functorch {
+
+using Tensor = at::Tensor;
+
+// We assume this in a few other places in the codebase,
+// but there isn't a centralized definition.
+constexpr int64_t kVmapMaxTensorDims = 64;
+
+// The valid vmap levels range from [0, 64). This effectively means that we
+// support a maximum of 64 nested vmaps.
+constexpr int64_t kVmapNumLevels = 64;
+
+// Store this number of elements of BatchDims on the stack. Most people will
+// probably use <= 5 nested vmaps, but adjust this number as necessary.
+constexpr int64_t kBatchDimsStackSize = 5;
+
+// a BatchDim represents a "private" dimension on a Tensor created inside of
+// vmap. It is a (level, dim) tuple, with the `dim` indicating which dimension
+// is being vmap'ed over and the `level` being an identifier for which vmap
+// said dimension was created inside. The `dim` corresponds to a "physical
+// dim" - it is a dimension index on the underlying physical tensor that is being
+// vmapped over.
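+//
+// For example, BatchDim(/*level=*/1, /*dim=*/0) says that physical dim 0 of the
+// underlying tensor is the batch dim introduced by the vmap at level 1.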
+struct BatchDim {
+  BatchDim(int64_t level, int64_t dim) : dim_(dim), level_(level) {}
+  int64_t dim() const {
+    return dim_;
+  }
+  int64_t level() const {
+    return level_;
+  }
+ private:
+  int64_t dim_;
+  int64_t level_;
+};
+
+using BatchDims = at::SmallVector<BatchDim, kBatchDimsStackSize>;
+using BatchDimsRef = at::ArrayRef<BatchDim>;
+
+// A BatchedTensorImpl holds an underlying Tensor and a list of BatchDim
+// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
+// BatchedTensorImpl.
+//
+// The batch dimensions are treated as being "private"; they are not user-visible.
+// For example, in the following Tensor,
+//    bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2, dim=1)])
+// dimensions 0 and 1 are batch dimensions.
+//
+// bt.sizes() returns (5, 7); bt.sum(0) performs a reduction over the (public)
+// dim 0, which is equivalent to dim 3 in the underlying ones(2, 3, 5, 7) tensor.
+struct BatchedTensorImpl : public c10::TensorImpl {
+  explicit BatchedTensorImpl(Tensor value, BatchDims bdims);
+  explicit BatchedTensorImpl(at::DispatchKeySet key_set, Tensor value, BatchDims bdims);
+
+  // Returns a reference to BatchDims that represent which dimensions of this
+  // tensor are private.
+  BatchDimsRef bdims() const { return bdims_; }
+
+  // BatchedTensorImpl wraps a Tensor
+  const Tensor& value() const { return value_; };
+
+  // Given a public dimension index, return the dimension index in the underlying
+  // value() tensor.
+  // For example, if we have
+  //    bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2, dim=2)])
+  // bt.actualDim(0) -> 1
+  // bt.actualDim(1) -> 3
+  // bt.actualDim(2) -> Error
+  int64_t actualDim(int64_t dim, bool wrap_dim = true) const;
+
+  // Override a bunch of methods inherited from TensorImpl to return error messages.
+  bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const override;
+  void set_size(int64_t dim, int64_t new_size) override;
+  void set_stride(int64_t dim, int64_t new_stride) override;
+  void set_storage_offset(int64_t storage_offset) override;
+#ifdef DEBUG
+  bool has_storage() const override;
+#endif
+
+  void refreshSizesAndStrides();
+
+ private:
+  // see NOTE: [BatchedTensorImpl levels invariant]
+  void checkInvariants() const;
+  const char* tensorimpl_type_name() const override;
+
+  Tensor value_;
+
+  // Note: [BatchedTensorImpl levels invariant]
+  // There is an invariant that the BatchDims must be stored in increasing `level`
+  // order. That is, for i < j, bdims_[i].level must be less than bdims_[j].level.
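+  // For example (illustration only), [(lvl=1, dim=0), (lvl=2, dim=1)] satisfies
+  // this invariant, while [(lvl=2, dim=0), (lvl=1, dim=1)] does not.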
+  BatchDims bdims_;
+};
+
+// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
+// BatchedTensorImpl.
+inline bool isBatchedTensor(const Tensor& tensor) {
+  return tensor.unsafeGetTensorImpl()->key_set().has(kBatchedKey);
+}
+
+// It is unsafe to call this on a Tensor that is not backed by a
+// BatchedTensorImpl. Please use `maybeGetBatchedImpl` whenever possible.
+inline BatchedTensorImpl* unsafeGetBatchedImpl(Tensor tensor) {
+  return static_cast<BatchedTensorImpl*>(tensor.unsafeGetTensorImpl());
+}
+
+inline BatchedTensorImpl* maybeGetBatchedImpl(Tensor tensor) {
+  if (!isBatchedTensor(tensor)) {
+    return nullptr;
+  }
+  return unsafeGetBatchedImpl(tensor);
+}
+
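+// Illustrative sketch (not part of this patch) of how these helpers are
+// typically used: a batching rule checks whether a Tensor is batched before
+// unwrapping it, e.g.
+//
+//   if (auto* batched = maybeGetBatchedImpl(tensor)) {
+//     const Tensor& value = batched->value();  // underlying physical tensor
+//     BatchDimsRef bdims = batched->bdims();   // its private batch dims
+//     // ... operate on `value`, then re-wrap the result with makeBatched ...
+//   } else {
+//     // `tensor` is a plain, unbatched Tensor
+//   }
+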
+// Returns a bitset. If bit i is set, then that means dim i is a batchdim.
+inline std::bitset<kVmapMaxTensorDims> createBatchDimBitset(BatchDimsRef bdims) {
+  std::bitset<kVmapMaxTensorDims> is_bdim;
+  for (const auto& bdim : bdims) {
+    is_bdim.set(bdim.dim());
+  }
+  return is_bdim;
+}
+
+// Creates a bitset for all of the levels present in `bdims`
+inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(BatchDimsRef bdims) {
+  std::bitset<kVmapNumLevels> result;
+  for (const auto& bdim : bdims) {
+    result.set(bdim.level());
+  }
+  return result;
+}
+
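+// Illustrative sketch (not part of this patch): this helper can be used to
+// check, for example, whether `self` carries every vmap level that `other`
+// does (cf. inplaceIsVmapCompatible):
+//
+//   auto self_levels = createVmapLevelsBitset(self_batched->bdims());
+//   auto other_levels = createVmapLevelsBitset(other_batched->bdims());
+//   bool self_has_all_of_others_levels = (self_levels == (self_levels | other_levels));
+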
+inline std::ostream& operator<<(std::ostream& out, const BatchDim& bdim) {
+  out << "(lvl=" << bdim.level() << ", dim=" << bdim.dim() << ")";
+  return out;
+}
+
+// Use this to construct a BatchedTensor from a regular Tensor
+TORCH_API Tensor makeBatched(const Tensor& tensor, BatchDims bdims);
+
+// Adds a batch dim to `tensor`, returning a BatchedTensor
+TORCH_API Tensor addBatchDim(const Tensor& tensor, int64_t level, int64_t dim);
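+// (Illustration only: conceptually, vmap-ing over dim 0 at level 1 wraps each
+// input `x` as `addBatchDim(x, /*level=*/1, /*dim=*/0)`.)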
+
+// Checks if an inplace operation on self and other is "vmap compatible".
+// See NOTE: [vmap-incompatible in-place operations] for the definition of this.
+TORCH_API bool inplaceIsVmapCompatible(const Tensor& self, const Tensor& other);
+
+} // namespace functorch
+} // namespace at
diff --git a/functorch/functorch/csrc/BatchingMetaprogramming.h b/functorch/functorch/csrc/BatchingMetaprogramming.h
new file mode 100644
index 0000000..3cb90f4
--- /dev/null
+++ b/functorch/functorch/csrc/BatchingMetaprogramming.h
@@ -0,0 +1,155 @@
+#pragma once
+#include <ATen/Tensor.h>
+
+namespace at {
+namespace functorch {
+
+// Metaprogramming things
+template <class... Items> using typelist = c10::guts::typelist::typelist<Items...>;
+template <class TypeList> using head_t = c10::guts::typelist::head_t<TypeList>;
+template <class TL1, class TL2> using concat_t = c10::guts::typelist::concat_t<TL1, TL2>;
+template <typename T> class debug_t;
+
+// tail operation
+template<class TypeList>
+struct tail final {
+    static_assert(c10::guts::false_t<TypeList>::value,
+                  "In typelist::tail<T>, the T argument must be typelist<...>.");
+};
+template<class Head, class... Tail>
+struct tail<typelist<Head, Tail...>> final {
+  using type = typelist<Tail...>;
+};
+template<class TypeList> using tail_t = typename tail<TypeList>::type;
+
+template <class First, class Second, class Next, class Tail>
+struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext {
+  using type = Next;
+};
+template <class Next, class Tail>
+struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<Tensor, optional<int64_t>, Next, Tail> {
+  using type = Tail;
+};
+template <class Next, class Tail>
+struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<const Tensor&, optional<int64_t>, Next, Tail> {
+  using type = Tail;
+};
+template <class Next, class Tail>
+struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<Tensor&, optional<int64_t>, Next, Tail> {
+  using type = Tail;
+};
+template <class TypeList> struct RemoveBatchDimAfterTensor {
+  using first = head_t<TypeList>;
+  using next = tail_t<TypeList>;
+  using second = head_t<next>;
+  using tail = tail_t<next>;
+
+  using type = concat_t<
+    typelist<first>,
+    typename RemoveBatchDimAfterTensor<
+      typename IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<first, second, next, tail>::type
+    >::type
+  >;
+};
+template <class Type> struct RemoveBatchDimAfterTensor<typelist<Type>> {
+  using type = typelist<Type>;
+};
+template <> struct RemoveBatchDimAfterTensor<typelist<>> {
+  using type = typelist<>;
+};
+template<class TypeList> using remove_batch_dim_after_tensor_t = typename RemoveBatchDimAfterTensor<TypeList>::type;
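+// Worked example (illustration only; not compiled as part of this patch):
+//
+//   static_assert(std::is_same<
+//       remove_batch_dim_after_tensor_t<
+//           typelist<const Tensor&, optional<int64_t>, int64_t>>,
+//       typelist<const Tensor&, int64_t>>::value, "");
+//
+// That is, each optional<int64_t> batch-dim argument that immediately follows a
+// Tensor argument is dropped, turning a batch rule's parameter list into the
+// corresponding operator's parameter list.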
+
+// TODO: get rid of these
+// Do I need templates on templates now?
+// template <typename func_t> struct LowerToNextLayer {};
+// template <typename Return, typename... Args> struct LowerToNextLayer<Return(Args...)> {
+//   // How to pass in batch_rule directly?
+//   static Return apply(Args... args);
+// };
+
+template <typename batch_rule_t, typename Result, typename... Args>
+Result lowerToNextLayer(batch_rule_t batch_rule, Args... args);
+
+//# Tensor lowerToNextLayer(
+//#     std::function<std::tuple<Tensor,optional<int64_t>>(const Tensor&, optional<int64_t>)> batch_rule,
+//#     const Tensor& tensor);
+std::tuple<Tensor,optional<int64_t>> abs_batch_rule(const Tensor& tensor, optional<int64_t> batch_dim);
+
+template<typename F, F Func, typename Return, typename TupleArgs> struct TORCH_API Dummy {};
+
+template<typename F, F Func, typename Return, typename...T> struct Dummy<F, Func, Return, std::tuple<T...>> {
+  static Return apply(T... args) {
+    return lowerToNextLayer(abs_batch_rule, std::forward<T>(args)...);
+  }
+};
+
+template <typename T> struct UnpackSingleItemTuple {
+  using type = T;
+};
+template <typename T> struct UnpackSingleItemTuple<std::tuple<T>> {
+  using type = T;
+};
+template <typename T> using unpack_single_item_tuple_t = typename UnpackSingleItemTuple<T>::type;
+
+template <typename Return, typename TupleArgs> struct BuildFunctionHelper;
+template <typename Return, typename... Args> struct BuildFunctionHelper<Return, std::tuple<Args...>> {
+  using type = Return(Args...);
+};
+template <typename Return, typename TL>
+struct BuildFunction {
+  using type = typename BuildFunctionHelper<Return, c10::guts::typelist::to_tuple_t<TL>>::type;
+};
+template <typename Return, typename TL> using build_function_t = typename BuildFunction<Return, TL>::type;
+
+
+// std::tuple<Tensor,optional<int64_t>> (*kAbsBatchRule)(const Tensor& Tensor, optional<int64_t>)
+//  = &abs_batch_rule;
+template <typename batch_rule_t> struct ToOperatorType {
+  using batch_rule_return_type = typename c10::guts::function_traits<batch_rule_t>::return_type;
+  using batch_rule_parameter_types = typename c10::guts::function_traits<batch_rule_t>::parameter_types;
+
+  using operator_parameter_types = remove_batch_dim_after_tensor_t<batch_rule_parameter_types>;
+  using operator_return_type =
+    unpack_single_item_tuple_t<
+      c10::guts::typelist::to_tuple_t<
+        remove_batch_dim_after_tensor_t<
+          c10::guts::typelist::from_tuple_t<batch_rule_return_type>>>>;
+
+  using type = build_function_t<operator_return_type, operator_parameter_types>;
+};
+template <typename batch_rule_t> using to_operator_t = typename ToOperatorType<batch_rule_t>::type;
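+// Worked example (illustration only): abs_batch_rule (declared above) has type
+//   std::tuple<Tensor,optional<int64_t>>(const Tensor&, optional<int64_t>)
+// so applying to_operator_t to it drops the batch-dim parameter and unpacks the
+// single-element return tuple, yielding
+//   Tensor(const Tensor&)
+// which is the schema of the underlying operator (here, abs).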
+
+template <typename F, F Func> struct TORCH_API PrimBatchRule3 {
+  using func_t = to_operator_t<typename std::remove_pointer<F>::type>;
+  using result_type = typename c10::guts::function_traits<func_t>::return_type;
+  using parameter_types = c10::guts::typelist::to_tuple_t<typename c10::guts::function_traits<func_t>::parameter_types>;
+  static auto apply = Dummy<F, Func, result_type, parameter_types>::apply;
+};
+
+template<typename Return, typename TypeList> struct TORCH_API PrimBatchRule5 {};
+template<typename Return, typename... T> struct PrimBatchRule5<Return, typelist<T...>> {
+  static inline Return apply(T... args) {
+    return lowerToNextLayer(abs_batch_rule, std::forward<T>(args)...);
+  }
+};
+
+template<typename func_t> struct PrimBatchRule6 {};
+template<typename Return, typename... Args> struct PrimBatchRule6<Return (Args...)> {
+  static inline Return apply(Args... args) {
+    return lowerToNextLayer(abs_batch_rule, std::forward<Args>(args)...);
+  }
+};
+
+// template<typename batch_rule_t, batch_rule_t BatchRule> struct PrimBatchRule7 {};
+// template<typename batch_rule_t, batch_rule_t BatchRule, typename BRReturn, typename... BRArgs>
+// struct PrimBatchRule7<BRReturn(*)(BRArgs...), BatchRule> {
+template<typename br_t, br_t BatchRule, typename func_t> struct PrimBatchRule7 {};
+template<typename br_t, br_t BatchRule, typename Return, typename... Args> struct PrimBatchRule7<
+br_t, BatchRule, Return (Args...)> {
+  static inline Return apply(Args... args) {
+    return lowerToNextLayer<br_t, Return, Args...>(BatchRule, std::forward<Args>(args)...);
+  }
+};
+
+} // namespace functorch
+} // namespace at
diff --git a/functorch/functorch/csrc/BatchingRegistrations.cpp b/functorch/functorch/csrc/BatchingRegistrations.cpp
new file mode 100644
index 0000000..a053c87
--- /dev/null
+++ b/functorch/functorch/csrc/BatchingRegistrations.cpp
@@ -0,0 +1,1700 @@
+#include <torch/library.h>
+#include <ATen/native/ResizeCommon.h>
+#include <ATen/ATen.h>
+#include <torch/csrc/autograd/variable.h>
+
+#include <functorch/csrc/DynamicLayer.h>
+#include <functorch/csrc/TensorWrapper.h>
+#include <functorch/csrc/BatchingMetaprogramming.h>
+#include <functorch/csrc/VmapTransforms.h>
+#include <functorch/csrc/BatchedFallback.h>
+#include <functorch/csrc/Constants.h>
+
+namespace at {
+namespace functorch {
+
+
+// NOTE: [What is a batching rule?]
+//
+// A *batching rule* implements the logic of how to call an operator on inputs
+// that have zero or more additional batch dimensions. When one does a vmap, the
+// dimension(s) being vmap'ed over get recorded as batch dimensions.
+//
+// For example, vmap(torch.add)(x, y)
+// 1. wraps `x` into batched_x = BatchedTensor(x, bdims=[(lvl=1, dim=0)]);
+// 2. wraps `y` into batched_y = BatchedTensor(y, bdims=[(lvl=1, dim=0)]);
+// 3. and then runs `torch.add(batched_x, batched_y)`.
+
+// NOTE: [When should I add a batching rule?]
+// When you are adding a new operator, you'll need to add a batching rule so
+// that vmap can work efficiently with said operator. If you do not, we'll attempt
+// to generate a slow fallback for the batching rule.
+
+// NOTE: [How to write batching rules?]
+// The signature of a batching rule should look exactly like the C++ signature
+// of its operator.
+//
+// First, see NOTE: [Logical vs physical args] in VmapTransforms.h for terminology.
+//
+// At a high level, what a batching rule does is the following:
+// 1. Converts (logical) BatchedTensors to views on physical tensors.
+// 2. Converts logical arguments (e.g. dimension indexes, shapes) to physical
+//    arguments that correspond to the physical tensors.
+// 3. Calls at:: operations on the physical tensors and arguments to produce
+//    some physical results.
+// 4. Converts physical results back to BatchedTensors.
+//
+// Steps 1, 2, and 4 differ for operators with different batching behaviors. When
+// writing a new batching rule, please select a VmapTransform that matches the
+// batching behavior of your operation. The VmapTransform provides helper functions
+// to do steps (1), (2), and (4).
+// (see NOTE: [What is an VmapTransform?] in VmapTransforms.h)
+
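+// As an illustration only (not part of this patch), a batching rule for a
+// hypothetical operator `my_op(Tensor self, int64_t dim)` that follows steps
+// 1-4 above would look roughly like:
+//
+//   Tensor my_op_batching_rule(const Tensor& self, int64_t dim) {
+//     // Step 1: convert the logical BatchedTensor to a view on a physical tensor
+//     auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+//     // Step 2: convert the logical dim to a physical dim
+//     auto dim_physical = self_physical.getPhysicalDim(dim);
+//     // Step 3: call the at:: operation on the physical tensor
+//     auto result = at::my_op(self_physical.tensor(), dim_physical);
+//     // Step 4: convert the physical result back to a BatchedTensor
+//     return self_physical.getPhysicalToLogicalMap().apply(result);
+//   }
+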
+// Note: [Future plans]
+// The API for writing a batching rule isn't stable. In the future, we'd like
+// to think about the problem of translating these batching rules to TorchScript.
+// Ideally batching rules in eager mode vs TorchScript would look pretty similar,
+// if not use the same mechanism. In order to accomplish that we might have to
+// do some refactoring.
+
+// PyTorch allows operations to specify dim 0 and dim -1 on a scalar tensor.
+static bool is_allowed_dim_on_scalar_tensor(int64_t dim) {
+  return dim == 0 || dim == -1;
+}
+
+// This check should probably go into the dispatcher...
+static bool participatesInCurrentLevel(const Tensor& self) {
+  auto maybe_level = maybeCurrentDynamicLayer();
+  TORCH_INTERNAL_ASSERT(maybe_level.has_value());
+  auto current_level = maybe_level->layerId();
+  auto* maybe_batched_impl = maybeGetBatchedImpl(self);
+  if (!maybe_batched_impl) {
+    return false;
+  }
+  const auto& bdims = maybe_batched_impl->bdims();
+  TORCH_INTERNAL_ASSERT(bdims.size() == 1);
+  auto self_level = bdims.back().level();
+  TORCH_INTERNAL_ASSERT(self_level <= current_level);
+  return self_level == current_level;
+}
+static bool participatesInCurrentLevel(const Tensor& self, const Tensor& other) {
+  return participatesInCurrentLevel(self) || participatesInCurrentLevel(other);
+}
+
+static bool participatesInCurrentLevel(TensorList self) {
+  for (const Tensor& tensor : self) {
+    if (participatesInCurrentLevel(tensor)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+Tensor mean_batching_rule(const Tensor& self, optional<ScalarType> dtype) {
+  if (!participatesInCurrentLevel(self)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return self.mean(dtype);
+  }
+  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+  VmapDimVector dims;
+  for (int64_t i = 1; i < self_physical.tensor().dim(); i++) {
+    dims.push_back(i);
+  }
+  auto result = at::mean(self_physical.tensor(), dims, /*keepdim*/false, dtype);
+  return self_physical.getPhysicalToLogicalMap().apply(result);
+}
+
+Tensor log_softmax_batching_rule(const Tensor& self, int64_t dim, optional<ScalarType> dtype) {
+  if (!participatesInCurrentLevel(self)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return at::log_softmax(self, dim, dtype);
+  }
+  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+  auto dim_physical = self_physical.getPhysicalDim(dim);
+  auto result = at::log_softmax(self_physical.tensor(), dim_physical, dtype);
+  return self_physical.getPhysicalToLogicalMap().apply(result);
+}
+
+Tensor _log_softmax_batching_rule(const Tensor& self, int64_t dim, bool half_to_float) {
+  if (!participatesInCurrentLevel(self)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return at::_log_softmax(self, dim, half_to_float);
+  }
+  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+  auto dim_physical = self_physical.getPhysicalDim(dim);
+  auto result = at::_log_softmax(self_physical.tensor(), dim_physical, half_to_float);
+  return self_physical.getPhysicalToLogicalMap().apply(result);
+}
+
+std::tuple<Tensor,Tensor> max_pool2d_with_indices_batching_rule(
+    const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride,
+    IntArrayRef padding, IntArrayRef dilation, bool ceil_mode) {
+  if (!participatesInCurrentLevel(self)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return at::max_pool2d_with_indices(
+        self, kernel_size, stride, padding, dilation, ceil_mode);
+  }
+  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+  TORCH_INTERNAL_ASSERT(self_physical.tensor().dim() == 5);
+
+  auto N = self_physical.tensor().size(0);
+  auto M = self_physical.tensor().size(1);
+  auto physical = self_physical.tensor().flatten(0, 1);
+
+  auto result = max_pool2d_with_indices_batching_rule(physical,
+      kernel_size, stride, padding, dilation, ceil_mode);
+
+  auto first = std::get<0>(result).unflatten(0, {N, M});
+  auto second = std::get<1>(result).unflatten(0, {N, M});
+
+  first = self_physical.getPhysicalToLogicalMap().apply(first);
+  second = self_physical.getPhysicalToLogicalMap().apply(second);
+  return std::make_tuple<Tensor, Tensor>(std::move(first), std::move(second));
+}
+
+Tensor sum_batching_rule(const Tensor& self, IntArrayRef dims, bool keepdim, optional<ScalarType> dtype) {
+  if (!participatesInCurrentLevel(self)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return self.sum(dims, keepdim, dtype);
+  }
+  // PyTorch has a special case where sum(scalar_tensor, dim=0) does not fail
+  // and instead returns a new scalar tensor (this also happens for dim=-1)
+  // If the following happens:
+  // >>> x = torch.randn(B0)  # the per-examples are all scalars
+  // >>> vmap(partial(torch.sum, dim=0))(x)
+  // then we replicate the behavior of sum(scalar_tensor, dim=0).
+  if (/*logical*/self.dim() == 0 && dims.size() == 1 && is_allowed_dim_on_scalar_tensor(dims[0])) {
+    return self.clone();
+  }
+  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+  auto dims_physical = self_physical.getPhysicalDims(dims);
+  auto result = at::sum(self_physical.tensor(), dims_physical, keepdim, dtype);
+  return self_physical.getPhysicalToLogicalMap().apply(result);
+}
+
+bool isPhysicalScalarTensor(const Tensor& logical_tensor) {
+  if (logical_tensor.dim() > 0) {
+    return false;
+  }
+  auto* batched = maybeGetBatchedImpl(logical_tensor);
+  if (batched) {
+    return false;
+  }
+  return true;
+}
+
+template <typename F, F Func, typename... ExtraArgs>
+Tensor binary_pointwise_batching_rule(
+    const Tensor& self, const Tensor& other, ExtraArgs... args) {
+  if (!participatesInCurrentLevel(self, other)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return Func(self, other, std::forward<ExtraArgs>(args)...);
+  }
+  if (self.dim() > 0 && other.dim() > 0) {
+    auto physical_args = BroadcastingVmapTransform::logicalToPhysical({self, other});
+    auto result = Func(physical_args[0].tensor(), physical_args[1].tensor(), args...);
+    return physical_args[0].getPhysicalToLogicalMap().apply(result);
+  }
+  if (isPhysicalScalarTensor(self)) {
+    auto other_physical = MultiBatchVmapTransform::logicalToPhysical(other);
+    auto result = Func(self, other_physical.tensor(), args...);
+    return other_physical.getPhysicalToLogicalMap().apply(result);
+  }
+  if (isPhysicalScalarTensor(other)) {
+    auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+    auto result = Func(self_physical.tensor(), other, args...);
+    return self_physical.getPhysicalToLogicalMap().apply(result);
+  }
+
+  // At this point, we know at least one of the operands is a logical Scalar tensor.
+  // Here we must emulate TensorIterator's special behavior on Scalars.
+  //
+  // As a motivating example, consider the following:
+  //   x = torch.randn(3, 10)
+  //   y = torch.randn(3, dtype=torch.double)
+  //   vmap(torch.mul)(x, y)
+  //
+  // At a per-example level, we are multiplying a FloatTensor[10] by a DoubleTensor[];
+  // Type Promotion dictates that the result should be FloatTensor[10].
+  // This means we cannot directly pass the physical tensors (x and y) to
+  // TensorIterator (if we did, it would promote them to DoubleTensor).
+  //
+  // FIXME(rzou): I didn't want to go down the slippery slope of emulating
+  // everything TensorIterator does (it would be better to refactor out the
+  // TensorIterator logic). The one thing that this code doesn't handle
+  // is cross-device logical scalar tensors.
+  //   cpu_tensor = torch.randn(3)
+  //   cuda_tensor = torch.randn(3, 10, device='cuda')
+  //   vmap(torch.mul)(cpu_tensor, cuda_tensor)
+  //
+  // At a per-example level, we are multiplying a CPUTensor[] by a CUDATensor[10].
+  // TensorIterator allows for this cross-device operation because one of the
+  // tensors is a Scalar CPU tensor. However, the following code will throw an
+  // error in that case. I don't expect to see many use cases for this, so
+  // this is probably fine as-is.
+  auto logical_self = self;
+  auto logical_other = other;
+  auto result_type = at::native::result_type(logical_self, logical_other);
+  if (logical_self.scalar_type() != result_type) {
+    logical_self = logical_self.to(result_type);
+  }
+  if (logical_other.scalar_type() != result_type) {
+    logical_other = logical_other.to(result_type);
+  }
+  auto physical_args = BroadcastingVmapTransform::logicalToPhysical(
+      {logical_self, logical_other});
+  auto result = Func(physical_args[0].tensor(), physical_args[1].tensor(), args...);
+  return physical_args[0].getPhysicalToLogicalMap().apply(result);
+}
+
+Tensor expand_batching_rule(const Tensor& self, IntArrayRef size, bool implicit) {
+  if (!participatesInCurrentLevel(self)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return self.expand(size, implicit);
+  }
+
+  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+  auto size_physical = self_physical.getPhysicalShape(size);
+  auto self_physical_dim = self_physical.tensor().dim();
+
+  TORCH_CHECK(self_physical_dim <= size_physical.size(),
+       "expand: the number of sizes provided (", /*logical*/size.size(), ") ",
+       "must be greater or equal to the number of dimensions in the tensor (",
+       /*logical dim*/self.dim(), ")");
+
+  if (self_physical_dim == size_physical.size()) {
+    auto result = self_physical.tensor().expand(size_physical, implicit);
+    return self_physical.getPhysicalToLogicalMap().apply(result);
+  }
+
+  TORCH_INTERNAL_ASSERT(self_physical_dim < size_physical.size());
+  // Here, we know we are expanding a (logical) tensor to a larger number
+  // of dimensions. We have to be careful because we can't call expand directly
+  // due to the presence of batch dimensions.
+  //
+  // As an example, let B0 be a batch dimension and consider expand(Tensor[B0, 3], [2, 3]).
+  // The result should be a tensor of size [B0, 2, 3].
+  // A physical view of size [B0, 3] can't directly be expanded to size [B0, 2, 3]
+  // so the strategy here is to view it first as a tensor of size [B0, 1, 3] and
+  // then expand.
+  auto self_physical_size = self_physical.tensor().sizes();
+  auto extra_dims = size_physical.size() - self_physical_dim;
+  VmapDimVector view_shape(size_physical.size(), 1);
+  std::copy(self_physical_size.begin(),
+            self_physical_size.begin() + self_physical.numBatchDims(),
+            view_shape.begin());
+  std::copy(self_physical_size.begin() + self_physical.numBatchDims(),
+            self_physical_size.end(),
+            view_shape.begin() + self_physical.numBatchDims() + extra_dims);
+  auto result = self_physical.tensor().view(view_shape).expand(size_physical, implicit);
+  return self_physical.getPhysicalToLogicalMap().apply(result);
+}
+
+std::vector<Tensor> chunk_batching_rule(const Tensor& self, int64_t chunks, int64_t dim) {
+  if (!participatesInCurrentLevel(self)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return self.chunk(chunks, dim);
+  }
+
+  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+  auto dim_physical = self_physical.getPhysicalDim(dim);
+  auto result = at::chunk(self_physical.tensor(), chunks, dim_physical);
+  self_physical.getPhysicalToLogicalMap().applyInplace(result);
+  return result;
+}
+
+Tensor clamp_batching_rule(const Tensor& self, optional<Scalar> min, optional<Scalar> max) {
+  if (!participatesInCurrentLevel(self)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return self.clamp(min, max);
+  }
+
+  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+  auto result = at::clamp(self_physical.tensor(), min, max);
+  return self_physical.getPhysicalToLogicalMap().apply(result);
+}
+
+Tensor clamp_min_batching_rule(const Tensor& self, Scalar min) {
+  if (!participatesInCurrentLevel(self)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return at::clamp_min(self, min);
+  }
+
+  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+  auto result = at::clamp_min(self_physical.tensor(), min);
+  return self_physical.getPhysicalToLogicalMap().apply(result);
+}
+
+Tensor clamp_max_batching_rule(const Tensor& self, Scalar max) {
+  if (!participatesInCurrentLevel(self)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return at::clamp_max(self, max);
+  }
+
+  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+  auto result = at::clamp_max(self_physical.tensor(), max);
+  return self_physical.getPhysicalToLogicalMap().apply(result);
+}
+
+std::vector<Tensor> tensor_split_sections_batching_rule(const Tensor& self, int64_t sections, int64_t dim) {
+  if (!participatesInCurrentLevel(self)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return at::tensor_split(self, sections, dim);
+  }
+  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+  auto dim_physical = self_physical.getPhysicalDim(dim);
+  auto result = at::tensor_split(self_physical.tensor(), sections, dim_physical);
+  self_physical.getPhysicalToLogicalMap().applyInplace(result);
+  return result;
+}
+
+std::vector<Tensor> tensor_split_indices_batching_rule(const Tensor& self, IntArrayRef indices, int64_t dim) {
+  if (!participatesInCurrentLevel(self)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return at::tensor_split(self, indices, dim);
+  }
+  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+  auto dim_physical = self_physical.getPhysicalDim(dim);
+  auto result = at::tensor_split(self_physical.tensor(), indices, dim_physical);
+  self_physical.getPhysicalToLogicalMap().applyInplace(result);
+  return result;
+}
+
+Tensor unsqueeze_batching_rule(const Tensor& self, int64_t dim) {
+  if (!participatesInCurrentLevel(self)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return at::unsqueeze(self, dim);
+  }
+  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+  // NB: unsqueeze has some special handling of its `dim` argument so we can't call
+  // self_physical.getPhysicalDim directly. In particular, native::unsqueeze
+  // wraps the dim to (the logical dimension) + 1, so we need to do that here too.
+  // https://github.com/pytorch/pytorch/blob/b623bdeabb0aa8da44285d303246e7f8ac06c2a9/aten/src/ATen/native/TensorShape.cpp#L1413
+  auto dim_physical =
+      self_physical.numBatchDims() + maybe_wrap_dim(dim, /*logical_dim*/self.dim() + 1);
+  auto result = self_physical.tensor().unsqueeze(dim_physical);
+  return self_physical.getPhysicalToLogicalMap().apply(result);
+}
+
+// Checks if the batch dims in `bdims` appear at the front of the tensor, in order.
+static bool areBdimsAtFrontInOrder(BatchDimsRef bdims) {
+  for (int64_t idx = 0; idx < bdims.size(); idx++) {
+    if (bdims[idx].dim() != idx) {
+      return false;
+    }
+  }
+  return true;
+}
+
+Tensor& squeeze_dim__batching_rule(Tensor& self, int64_t dim) {
+  if (!participatesInCurrentLevel(self)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return self.squeeze_(dim);
+  }
+  auto* batched = maybeGetBatchedImpl(self);
+  TORCH_CHECK(areBdimsAtFrontInOrder(batched->bdims()), "NYI: squeeze_ with bdims not at front");
+  auto num_bdims = batched->bdims().size();
+  auto logical_dim = self.dim();
+  auto dim_physical = num_bdims + maybe_wrap_dim(dim, logical_dim);
+  batched->value().squeeze_(dim_physical);
+
+  // Also need to change some metadata...
+  batched->refreshSizesAndStrides();
+  return self;
+}
+
+Tensor& unsqueeze__batching_rule(Tensor& self, int64_t dim) {
+  if (!participatesInCurrentLevel(self)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return self.unsqueeze_(dim);
+  }
+  auto* batched = maybeGetBatchedImpl(self);
+  TORCH_CHECK(areBdimsAtFrontInOrder(batched->bdims()), "NYI: unsqueeze_ with bdims not at front");
+  auto num_bdims = batched->bdims().size();
+  auto logical_dim = self.dim();
+  auto dim_physical = num_bdims + maybe_wrap_dim(dim, logical_dim + 1);
+  batched->value().unsqueeze_(dim_physical);
+
+  // Also need to change some metadata...
+  batched->refreshSizesAndStrides();
+  return self;
+}
+
+Tensor& fill_inplace_scalar_batching_rule(Tensor& self, Scalar value) {
+  if (!participatesInCurrentLevel(self)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return self.fill_(value);
+  }
+  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+  self_physical.tensor().fill_(value);
+  return self;
+}
+
+Tensor& fill_inplace_tensor_batching_rule(Tensor& self, const Tensor& value) {
+  auto value_batched = isBatchedTensor(value);
+
+  if (value_batched) {
+    auto physical_args =
+      BroadcastingVmapTransform::logicalToPhysical({self, value});
+    physical_args[0].tensor().copy_(physical_args[1].tensor());
+  } else {
+    auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+    self_physical.tensor().fill_(value);
+  }
+  return self;
+}
+
+Tensor& zero_inplace_batching_rule(Tensor &self) {
+  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+  self_physical.tensor().zero_();
+  return self;
+}
+
+Tensor squeeze_batching_rule(const Tensor& self) {
+  if (!participatesInCurrentLevel(self)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return self.squeeze();
+  }
+  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+  auto physical_sizes = self_physical.tensor().sizes();
+
+  // Don't squeeze the batch dims!
+  VmapDimVector squeezed_sizes;
+  int64_t num_batch_dims = self_physical.numBatchDims();
+  squeezed_sizes.insert(
+      squeezed_sizes.end(),
+      physical_sizes.begin(),
+      physical_sizes.begin() + num_batch_dims);
+  for (auto it = physical_sizes.begin() + num_batch_dims; it != physical_sizes.end(); ++it) {
+    if (*it != 1) {
+      squeezed_sizes.push_back(*it);
+    }
+  }
+
+  auto result = self_physical.tensor().view(squeezed_sizes);
+  return self_physical.getPhysicalToLogicalMap().apply(result);
+}
+
+Tensor squeeze_dim_batching_rule(const Tensor& self, int64_t dim) {
+  if (!participatesInCurrentLevel(self)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return self.squeeze(dim);
+  }
+  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+  auto dim_physical = self_physical.getPhysicalDim(dim);
+  auto result = self_physical.tensor().squeeze(dim_physical);
+  return self_physical.getPhysicalToLogicalMap().apply(result);
+}
+
+Tensor trace_batching_rule(const Tensor& self) {
+  if (!participatesInCurrentLevel(self)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return self.trace();
+  }
+  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+  // Batched Diagonal View
+  auto self_diag = at::diagonal(self_physical.tensor(), /*offset*/0, /*dim1*/-2, /*dim2*/-1);
+  auto result =  at::sum(self_diag, -1);
+  return self_physical.getPhysicalToLogicalMap().apply(result);
+}
+
+Tensor trace_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes) {
+  if (!participatesInCurrentLevel(grad)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return at::trace_backward(grad, input_sizes);
+  }
+  auto grad_physical = MultiBatchVmapTransform::logicalToPhysical(grad);
+  auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options());
+  // Batched Diagonal View
+  auto grad_input_diag = at::diagonal(grad_input, /*offset*/0, /*dim1*/-2, /*dim2*/-1);
+  // Append a dimension of size one to the grad output 
+  auto grad_physical_tensor = grad_physical.tensor().unsqueeze(-1);
+  grad_input_diag.copy_(grad_physical_tensor);
+  return grad_physical.getPhysicalToLogicalMap().apply(grad_input);
+}
+
+Tensor transpose_int_batching_rule(const Tensor& self, int64_t dim0, int64_t dim1) {
+  if (!participatesInCurrentLevel(self)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return at::transpose(self, dim0, dim1);
+  }
+  // PyTorch has a special case where scalar_tensor.transpose(dim0, dim1) works
+  // for dim0, dim1 in {0, -1} and returns the scalar tensor. If the following happens:
+  // >>> x = torch.randn(B0)  # the per-examples are all scalars
+  // >>> vmap(lambda x: x.transpose(0, -1))(x)
+  // then we replicate this behavior.
+  if (/*logical*/self.dim() == 0 && is_allowed_dim_on_scalar_tensor(dim0) &&
+      is_allowed_dim_on_scalar_tensor(dim1)) {
+    return self;
+  }
+  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+  auto dim0_physical = self_physical.getPhysicalDim(dim0);
+  auto dim1_physical = self_physical.getPhysicalDim(dim1);
+  auto result = self_physical.tensor().transpose(dim0_physical, dim1_physical);
+  return self_physical.getPhysicalToLogicalMap().apply(result);
+}
+
+Tensor permute_batching_rule(const Tensor& self, IntArrayRef dims) {
+  if (!participatesInCurrentLevel(self)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return self.permute(dims);
+  }
+
+  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+  auto dims_physical = self_physical.getPhysicalDims(dims);
+
+  VmapDimVector all_dims_physical;
+  all_dims_physical.reserve(self_physical.tensor().dim());
+  for (int64_t bdim = 0; bdim < self_physical.numBatchDims(); bdim++) {
+    all_dims_physical.push_back(bdim);
+  }
+  all_dims_physical.insert(
+      all_dims_physical.end(),
+      dims_physical.begin(),
+      dims_physical.end());
+  auto result = self_physical.tensor().permute(all_dims_physical);
+  return self_physical.getPhysicalToLogicalMap().apply(result);
+}
+
+Tensor select_batching_rule(const Tensor& self, int64_t dim, int64_t index) {
+  if (!participatesInCurrentLevel(self)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return at::select(self, dim, index);
+  }
+  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+  auto dim_physical = self_physical.getPhysicalDim(dim);
+  auto result = self_physical.tensor().select(dim_physical, index);
+  return self_physical.getPhysicalToLogicalMap().apply(result);
+}
+
+static int64_t getGradInputPhysicalDim(int64_t dim, IntArrayRef input_sizes, int64_t num_batch_dims) {
+  return maybe_wrap_dim(dim, input_sizes.size()) + num_batch_dims;
+}
+
+Tensor select_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t index) {
+  if (!participatesInCurrentLevel(grad)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return at::select_backward(grad, input_sizes, dim, index);
+  }
+  auto grad_physical = MultiBatchVmapTransform::logicalToPhysical(grad);
+  auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options());
+  auto physical_dim = getGradInputPhysicalDim(dim, input_sizes, grad_physical.numBatchDims());
+  grad_input.select(physical_dim, index).copy_(grad_physical.tensor());
+  return grad_physical.getPhysicalToLogicalMap().apply(grad_input);
+}
+
+Tensor slice_batching_rule(
+    const Tensor& self,
+    int64_t dim,
+    c10::optional<int64_t> start,
+    c10::optional<int64_t> end,
+    int64_t step) {
+  if (!participatesInCurrentLevel(self)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return at::slice(self, dim, start, end, step);
+  }
+  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+  auto dim_physical = self_physical.getPhysicalDim(dim);
+  auto result = self_physical.tensor().slice(dim_physical, start, end, step);
+  return self_physical.getPhysicalToLogicalMap().apply(result);
+}
+
+Tensor slice_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) {
+  if (!participatesInCurrentLevel(grad)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return at::slice_backward(grad, input_sizes, dim, start, end, step);
+  }
+  auto grad_physical = MultiBatchVmapTransform::logicalToPhysical(grad);
+  auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options());
+  auto physical_dim = getGradInputPhysicalDim(dim, input_sizes, grad_physical.numBatchDims());
+  grad_input.slice(physical_dim, start, end, step).copy_(grad_physical.tensor());
+  return grad_physical.getPhysicalToLogicalMap().apply(grad_input);
+}
+
+Tensor diagonal_batching_rule(const Tensor& self, int64_t offset, int64_t dim1, int64_t dim2) {
+  if (!participatesInCurrentLevel(self)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return at::diagonal(self, offset, dim1, dim2);
+  }
+  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+  auto dim1_physical = self_physical.getPhysicalDim(dim1);
+  auto dim2_physical = self_physical.getPhysicalDim(dim2);
+  auto result = at::diagonal(self_physical.tensor(), offset, dim1_physical, dim2_physical);
+  return self_physical.getPhysicalToLogicalMap().apply(result);
+}
+
+Tensor diagonal_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) {
+  if (!participatesInCurrentLevel(grad)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return at::diagonal_backward(grad, input_sizes, offset, dim1, dim2);
+  }
+  auto grad_physical = MultiBatchVmapTransform::logicalToPhysical(grad);
+  auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options());
+  auto dim1_physical = getGradInputPhysicalDim(dim1, input_sizes, grad_physical.numBatchDims());
+  auto dim2_physical = getGradInputPhysicalDim(dim2, input_sizes, grad_physical.numBatchDims());
+  grad_input.diagonal(offset, dim1_physical, dim2_physical).copy_(grad_physical.tensor());
+  return grad_physical.getPhysicalToLogicalMap().apply(grad_input);
+}
+
+Tensor movedim_batching_rule(const Tensor& self, IntArrayRef source, IntArrayRef destination) {
+  if (!participatesInCurrentLevel(self)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return at::movedim(self, source, destination);
+  }
+  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+  auto source_physical = self_physical.getPhysicalDims(source);
+  auto destination_physical = self_physical.getPhysicalDims(destination);
+  auto result = at::movedim(self_physical.tensor(), source_physical, destination_physical);
+  return self_physical.getPhysicalToLogicalMap().apply(result);
+}
+
+Tensor reshape_batching_rule(const Tensor& self, IntArrayRef shape) {
+  if (!participatesInCurrentLevel(self)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return at::reshape(self, shape);
+  }
+  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+  auto shape_physical = self_physical.getPhysicalShape(shape);
+  auto result = self_physical.tensor().reshape(shape_physical);
+  return self_physical.getPhysicalToLogicalMap().apply(result);
+}
+
+std::vector<Tensor> split_batching_rule(const Tensor& self, int64_t split_size, int64_t dim) {
+  if (!participatesInCurrentLevel(self)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return at::split(self, split_size, dim);
+  }
+  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+  auto dim_physical = self_physical.getPhysicalDim(dim);
+  auto result = at::split(self_physical.tensor(), split_size, dim_physical);
+  self_physical.getPhysicalToLogicalMap().applyInplace(result);
+  return result;
+}
+
+std::vector<Tensor> split_with_sizes_batching_rule(const Tensor& self, IntArrayRef split_sizes, int64_t dim) {
+  if (!participatesInCurrentLevel(self)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return at::split_with_sizes(self, split_sizes, dim);
+  }
+  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+  auto dim_physical = self_physical.getPhysicalDim(dim);
+  auto result = at::split_with_sizes(self_physical.tensor(), split_sizes, dim_physical);
+  self_physical.getPhysicalToLogicalMap().applyInplace(result);
+  return result;
+}
+
+std::vector<Tensor> unbind_batching_rule(const Tensor& self, int64_t dim) {
+  if (!participatesInCurrentLevel(self)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return at::unbind(self, dim);
+  }
+  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+  auto dim_physical = self_physical.getPhysicalDim(dim);
+  auto result = at::unbind(self_physical.tensor(), dim_physical);
+  self_physical.getPhysicalToLogicalMap().applyInplace(result);
+  return result;
+}
+
+Tensor unfold_batching_rule(const Tensor& self, int64_t dim, int64_t size, int64_t step) {
+  if (!participatesInCurrentLevel(self)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return self.unfold(dim, size, step);
+  }
+  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+  auto dim_physical = self_physical.getPhysicalDim(dim);
+  auto result = self_physical.tensor().unfold(dim_physical, size, step);
+  return self_physical.getPhysicalToLogicalMap().apply(result);
+}
+
+Tensor contiguous_batching_rule(const Tensor& self, MemoryFormat memory_format) {
+  if (!participatesInCurrentLevel(self)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return self.contiguous(memory_format);
+  }
+  TORCH_CHECK(memory_format == MemoryFormat::Contiguous,
+      "NYI: Tensor.contiguous(...) inside of vmap for memory_format other ",
+      "than torch.contiguous_format");
+  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
+  auto result = physical_view.tensor().contiguous(memory_format);
+  return physical_view.getPhysicalToLogicalMap().apply(result);
+}
+
+Tensor view_batching_rule(const Tensor& self, IntArrayRef size) {
+  if (!participatesInCurrentLevel(self)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return self.view(size);
+  }
+  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+  auto size_physical = self_physical.getPhysicalShape(size);
+  auto result = self_physical.tensor().view(size_physical);
+  return self_physical.getPhysicalToLogicalMap().apply(result);
+}
+
+Tensor flatten_batching_rule(const Tensor& self, int64_t start_dim, int64_t end_dim) {
+  if (!participatesInCurrentLevel(self)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return at::flatten(self, start_dim, end_dim);
+  }
+  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+  auto start_dim_physical = self_physical.getPhysicalDim(start_dim);
+  auto end_dim_physical = self_physical.getPhysicalDim(end_dim);
+  auto result = self_physical.tensor().flatten(start_dim_physical, end_dim_physical);
+  return self_physical.getPhysicalToLogicalMap().apply(result);
+}
+
+Tensor view_as_complex_batching_rule(const Tensor& self) {
+  if (!participatesInCurrentLevel(self)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return at::view_as_complex(self);
+  }
+  // guard against the user passing in a batch of scalar tensors with batch
+  // size equal to 2.
+  TORCH_CHECK(self.sizes().size() != 0, "Input tensor must have one or more dimensions");
+  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+  auto result = at::view_as_complex(self_physical.tensor());
+  return self_physical.getPhysicalToLogicalMap().apply(result);
+}
+
+// Checks that the smallest batch stride is at least as large as the largest
+// example stride. We could support the other case, but we choose not to
+// because it's potentially error-prone.
+static void checkBatchDimsAtFrontInLayout(IntArrayRef physical_strides, int64_t num_batch_dims) {
+  auto smallest_batch_stride = std::min_element(
+      physical_strides.begin(), physical_strides.begin() + num_batch_dims);
+  auto largest_example_stride = std::max_element(
+      physical_strides.begin() + num_batch_dims, physical_strides.end());
+  if (largest_example_stride == physical_strides.end()) {
+    // No example dimensions
+    return;
+  }
+  TORCH_CHECK(*smallest_batch_stride >= *largest_example_stride,
+    "vmap: Calling Tensor.as_strided is not supported unless the batch dims being ",
+    "vmapped over are at the front of the tensor (in memory layout). When they are ",
+    "not at the front of the tensor this operation can be error prone so we "
+    "actively discourage it; please file us a bug report and/or try to ",
+    "express the as_strided operation in terms of PyTorch view operations");
+}
+
+// given (sizes, strides, storage_offset) returns the maximum location that
+// can be indexed (or nullopt if such a location doesn't exist, e.g., tensors
+// with zero-size dims).
+static optional<int64_t> maximum_indexable_location(
+    IntArrayRef sizes, IntArrayRef strides, int64_t storage_offset) {
+  auto result = native::storage_size_for(sizes, strides);
+  if (result == 0) {
+    return nullopt;
+  }
+  return result + storage_offset;
+}
+
+// Let x be the "first slice" of physical_tensor.
+// This checks that the range of possible memory locations accessible by
+// x.as_strided(sizes, strides, maybe_storage_offset)
+// are within the bounds of possible memory locations accessible by x.
+static void checkBasicAsStridedValidForSlice(
+    const Tensor& physical_tensor,
+    int64_t num_batch_dims,
+    IntArrayRef sizes,
+    IntArrayRef strides,
+    optional<int64_t> maybe_storage_offset) {
+  auto slice_sizes = physical_tensor.sizes().slice(num_batch_dims);
+  auto slice_strides = physical_tensor.strides().slice(num_batch_dims);
+  auto base_offset = physical_tensor.storage_offset();
+
+  auto storage_offset = maybe_storage_offset.value_or(base_offset);
+
+  auto max_as_strided_loc = maximum_indexable_location(sizes, strides, storage_offset);
+  auto max_slice_loc = maximum_indexable_location(slice_sizes, slice_strides, base_offset);
+
+  if (!max_as_strided_loc.has_value()) {
+    return;
+  }
+  if (!max_slice_loc.has_value()) {
+    TORCH_CHECK(false,
+        "result = tensor.as_strided(", sizes, ",",  strides, ",", storage_offset, ")",
+        "can access memory outside of `tensor`. `tensor` has no storage but the ",
+        "passed-in (size, stride, storage_offset) imply a result with some storage. ",
+        "This is not supported inside of vmap, please try to rewrite the ",
+        "`as_strided` call as a sequence of PyTorch view operations");
+  }
+
+  TORCH_CHECK(
+      *max_as_strided_loc <= *max_slice_loc && base_offset <= storage_offset,
+      "result = tensor.as_strided(", sizes, ",",  strides, ",", storage_offset, ")",
+      "can access memory outside of `tensor`. `result` can access some",
+      "memory in range [", storage_offset, ", ", *max_as_strided_loc, "], but ",
+      "`tensor` can only access some memory in range [", base_offset, ", ",
+      *max_slice_loc, "]. This is not supported inside of vmap, please try to",
+      "rewrite the `as_strided` call as a sequence of PyTorch view operations");
+}
+
+// What are the semantics of as_strided inside of vmap?
+// y = vmap(lambda x: x.as_strided(sizes, strides, offset))(xs)
+// This returns a view `y` on `xs`, such that each y[i] has:
+// - sizes: `sizes`
+// - strides: `strides`
+// - storage_offset: offset + i * x.stride(batch_dim)
+//
+// In other words, it is as if we had treated each xs[i] as having storage
+// offset equal to xs.offset() and called as_strided(sizes, strides, offset).
+// (that is equivalent to xs[i].as_strided(
+//    sizes, strides, offset + xs[i].storage_offset() - xs.offset()) for all i)
+//
+// Note that this *may* be different from actually running as_strided
+// in a for-loop. This is because as_strided treats `offset` as an *absolute*
+// offset into the underlying storage. As an example, consider:
+// >>> x = torch.tensor([0., 1., 2., 3., 4.]).as_strided([4], [1], 1)
+// >>> z = [x[i].as_strided([1], [1], 1) for i in range(4)]
+// Each z[i] is actually the same view on x (z[i] == torch.tensor([1.]))!
+// However, we consider the above for-loop comprehension to be a user error:
+// a user should have written the following if they wanted to use as_strided
+// in a per-sample way:
+// >>> z = [x[i].as_strided([1], [1], 1 + x[i].storage_offset() - 1) for i in range(4)]
+Tensor as_strided_batching_rule(
+    const Tensor& tensor,
+    IntArrayRef sizes,
+    IntArrayRef strides,
+    optional<int64_t> storage_offset) {
+  if (!participatesInCurrentLevel(tensor)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return at::as_strided(tensor, sizes, strides, storage_offset);
+  }
+  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(tensor);
+  auto num_batch_dims = physical_view.numBatchDims();
+  auto physical_sizes = physical_view.getPhysicalShape(sizes);
+  const auto& physical_tensor = physical_view.tensor();
+
+  // We can't rely on the physical as_strided call to do this for us because
+  // we do some sanity checks on the size/strides before calling into as_strided.
+  TORCH_CHECK(sizes.size() == strides.size(),
+      "Tensor.as_strided(size, stride, ...): size and stride must have the ",
+      "same length! Got size ", sizes, " and stride ", strides);
+
+  // Sanity checks:
+  // 1. All batch dims are at the front in memory layout (not necessary for
+  // correctness, but we are worried the user might be doing crazy things)
+  // 2. as_strided(sizes, strides, storage_offset + tensor[i].offset() - tensor.offset())
+  // is valid for a slice of the input tensor.
+  // See Note: [When will the as_strided batching rule fail?] for details.
+  checkBatchDimsAtFrontInLayout(physical_tensor.strides(), num_batch_dims);
+  checkBasicAsStridedValidForSlice(
+      physical_tensor, num_batch_dims, sizes, strides, storage_offset);
+
+  // physical_strides = physical tensor's batch strides + (logical) strides
+  auto batch_strides = physical_tensor.strides().slice(0, num_batch_dims);
+  VmapDimVector physical_strides;
+  physical_strides.reserve(num_batch_dims + strides.size());
+  physical_strides.insert(
+      physical_strides.end(), batch_strides.begin(), batch_strides.end());
+  physical_strides.insert(
+      physical_strides.end(), strides.begin(), strides.end());
+
+  // If zi = xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
+  // is valid for all i, then it turns out that
+  // xs.as_strided(physical_sizes, physical_strides, offset) always succeeds
+  // and creates a tensor y such that each y[i] references the same memory
+  // locations as zi. See NOTE: [When will the as_strided batching rule fail?]
+  auto result = physical_view.tensor().as_strided(
+      physical_sizes, physical_strides, storage_offset);
+  return physical_view.getPhysicalToLogicalMap().apply(result);
+}
+
+// NOTE: [When will the as_strided batching rule fail?]
+// If zi = xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
+// is valid for all i, then it turns out that
+// xs.as_strided(physical_sizes, physical_strides, offset) always succeeds and
+// creates a tensor y such that each y[i] refers to the same memory as zi.
+//
+// Let's say we have xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()).
+// Furthermore, let's say that as a part of being "valid" this as_strided call
+// does not return a result that can index memory not indexable by xs[i].
+//
+// WLOG, assume that there's only one batch dim and it is at the front of the
+// `xs` tensor. Let B be the batch size and S be the stride of the batch dim.
+// - If the batch dim isn't at the front of the tensor, then we can just move it
+// to the front with movedim/permute. This is always valid because it just swaps
+// some strides around.
+// - This proof also works for tensors with multiple batch dims. We just have to
+// do a little accounting:
+//   - instead of [B], we'd have [B0, B1, ..., Bk].
+//   - instead of [S], we'd have [S0, S1, ..., Sk].
+//   - instead of i, we'd have a list of indices [I0, I1, ..., Ik]
+//   - instead of S * I, we'd have \sum_{i=0}^k S_i * I_i
+//
+// [Equation 1]
+// xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) has:
+// - sizes: sizes
+// - strides: strides
+// - offset: offset + S * i
+//
+// x.as_strided itself checks that:
+// - (sizes, strides, offset) are in bounds for `x`'s storage.
+// - strides are non-negative
+// - offset is non-negative
+//
+// Claim 1: if xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
+// is valid, then
+// ([B] + sizes, [S] + strides, offset + xs.offset()) are in bounds for `xs`'s storage.
+//
+// If we have the claim, then xs.as_strided([B] + sizes, [S] + strides, offset)
+// won't error out. So all we need to check is that the memory locations are
+// what we expected. See [Hand-wavy proof of Claim 1] for proof (it's not very important)
+//
+// xs.as_strided(physical_sizes, physical_strides, offset) is equivalent to
+// xs.as_strided([B] + sizes, [S] + strides, offset)
+//
+// xs.as_strided([B] + sizes, [S] + strides, offset) has:
+// - sizes: [B] + sizes
+// - strides: [S] + strides
+// - offset: offset
+//
+// xs.as_strided([B] + sizes, [S] + strides, offset)[i] has:
+// - sizes: sizes
+// - strides: strides
+// - offset: offset + S * i
+// These memory locations are exactly the same as what we got for [Equation 1],
+// so the xs.as_strided([B] + sizes, [S] + strides, offset) is valid.
+//
+// [Hand-wavy proof of Claim 1]
+// Part of our definition of being valid is that xs[i].as_strided(...)
+// must return a tensor that only uses memory indexable by xs[i].
+// This means that (sizes, strides, offset + xs[i].offset() - xs.offset()) satisfies:
+//    offset + xs[i].offset() - xs.offset() + 1 + \sum_j (sizes[j] - 1) * strides[j]
+//    <= xs[i].offset() + 1 + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j)
+// (the largest-index memory location of xs[i].as_strided(...) must be \leq
+// the largest-index memory location of xs[i])
+//
+// Fiddling that inequality gives us:
+//    offset - xs.offset() + 1 + \sum_j (sizes[j] - 1) * strides[j]
+//    <= 1 + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j)
+//
+//    offset - xs.offset() + 1 + (B-1)*S + \sum_j (sizes[j] - 1) * strides[j]
+//    <= 1 + (B-1)*S + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j)
+//
+//    offset - xs.offset() + 1 + (B-1)*S + \sum_j (sizes[j] - 1) * strides[j]
+//    <= 1 + \sum_j (xs.size(j) - 1) * xs.stride(j)
+//
+//    offset + 1 + (B-1)*S + \sum_j (sizes[j] - 1) * strides[j]
+//    <= xs.offset() + 1 + \sum_j (xs.size(j) - 1) * xs.stride(j)
+// (the largest-index memory location of xs.as_strided(size, stride, offset)
+// is \leq than the largest-index memory location of xs)
+// Under the assumptions we've made, the lower bound (lowest indexed memory)
+// is trivially within the storage.
+//
+// Therefore ([B] + sizes, [S] + strides, offset) are in bounds for
+// `xs`'s storage.
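+//
+// A concrete instance of the above (illustration only): let `xs` be a contiguous
+// tensor of shape [B=2, 5] (strides [5, 1], storage offset 0) and take
+// sizes=[3], strides=[1], offset=1. Each per-example call
+//   xs[i].as_strided([3], [1], 1 + i*5)
+// reads locations 1+5i through 3+5i, which lie inside xs[i]'s locations
+// 5i through 4+5i. The batched call
+//   xs.as_strided([B] + sizes, [S] + strides, offset) = xs.as_strided([2, 3], [5, 1], 1)
+// succeeds, and its i-th slice reads exactly those same locations.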
+
+template <typename F, F Func, typename... ExtraArgs>
+Tensor unwrap_and_call(const Tensor& input, ExtraArgs... args) {
+  if (!participatesInCurrentLevel(input)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return Func(input, args...);
+  }
+  auto* input_batched = unsafeGetBatchedImpl(input);
+  auto output_physical = Func(input_batched->value(), args...);
+  auto old_bdims = input_batched->bdims();
+  return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
+}
+
+template <typename F, F Func, typename... ExtraArgs>
+Tensor unwrap_and_call_method(const Tensor& input, ExtraArgs... extra_args) {
+  if (!participatesInCurrentLevel(input)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return (input.*Func)(extra_args...);
+  }
+  auto* input_batched = unsafeGetBatchedImpl(input);
+  auto output_physical = (input_batched->value().*Func)(extra_args...);
+  auto old_bdims = input_batched->bdims();
+  return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
+}
+
+Tensor pow_scalar_Tensor_batching_rule(Scalar other, const Tensor& self) {
+  if (!participatesInCurrentLevel(self)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return at::pow(other, self);
+  }
+  auto* self_batched = unsafeGetBatchedImpl(self);
+  auto output_physical = at::pow(other, self_batched->value());
+  auto old_bdims = self_batched->bdims();
+  return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
+}
+
+// Tensor ones_like_batching_rule(const Tensor& self, optional<MemoryFormat> memory_format) {
+//   if (!participatesInCurrentLevel(self)) {
+//     c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+//     return at::ones_like(self, memory_format);
+//   }
+// 
+//   TORCH_CHECK(!memory_format.has_value() || memory_format == MemoryFormat::Preserve
+//       || memory_format == MemoryFormat::Contiguous,
+//       "NYI: Tensor.clone(memory_format) inside vmap is only supported with ",
+//       "memory_format torch.preserve_format or torch.contiguous_format (got ",
+//       *memory_format, ")");
+// 
+//   if (memory_format == MemoryFormat::Contiguous) {
+//     auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
+//     auto output_physical = at::clone(physical_view.tensor(), memory_format);
+//     return physical_view.getPhysicalToLogicalMap().apply(output_physical);
+//   }
+// 
+//   TORCH_INTERNAL_ASSERT(!memory_format.has_value() || memory_format == MemoryFormat::Preserve);
+//   auto* self_batched = unsafeGetBatchedImpl(self);
+//   auto output_physical = at::clone(self_batched->value(), memory_format);
+//   auto old_bdims = self_batched->bdims();
+//   return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
+// }
+
+Tensor clone_batching_rule(const Tensor& self, optional<MemoryFormat> memory_format) {
+  if (!participatesInCurrentLevel(self)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return at::clone(self, memory_format);
+  }
+  // Memory format support is a little tricky because vmap is allowed to move
+  // around batch dimensions and some memory formats are rank-dependent.
+  // Another weird case is:
+  // - a tensor with MemoryFormat::ChannelsLast MUST have 4 dimensions. Do we
+  //   allow the user to clone a Tensor with 3 logical dimensions and 1 batch
+  //   dim into a ChannelsLast Tensor? What about a Tensor with 3 logical dims
+  //   and N>1 batch dims?
+  TORCH_CHECK(!memory_format.has_value() || memory_format == MemoryFormat::Preserve
+      || memory_format == MemoryFormat::Contiguous,
+      "NYI: Tensor.clone(memory_format) inside vmap is only supported with ",
+      "memory_format torch.preserve_format or torch.contiguous_format (got ",
+      *memory_format, ")");
+
+  if (memory_format == MemoryFormat::Contiguous) {
+    // There is an ambiguity here when the batch dims are not at the front of
+    // the tensor.
+    // >>> x = torch.randn(3, B0, 5)
+    // >>> y = vmap(lambda x: x.clone(torch.contiguous_format), in_dims=1, out_dims=0)(x)
+    // >>> y[0].is_contiguous()
+    // ???
+    // Should we make the whole tensor contiguous, or should we
+    // make the non-batch dims contiguous? We've chosen the latter because
+    // philosophically vmap hides the batch dims and operates on a per-sample level.
+    auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
+    auto output_physical = at::clone(physical_view.tensor(), memory_format);
+    return physical_view.getPhysicalToLogicalMap().apply(output_physical);
+  }
+
+  TORCH_INTERNAL_ASSERT(!memory_format.has_value() || memory_format == MemoryFormat::Preserve);
+  auto* self_batched = unsafeGetBatchedImpl(self);
+  auto output_physical = at::clone(self_batched->value(), memory_format);
+  auto old_bdims = self_batched->bdims();
+  return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
+}
+
+// Note [Batching rules for matmul-like operators]
+// at::matmul doesn't "de-expand" arguments to get better performance (maybe
+// it should). In the batching rules for matmul-like operators (dot, mv, mm),
+// we should be careful not to expand any unnecessary dimensions. e.g., if
+// only one of the two arguments is a BatchedTensor, then we should try
+// not to expand batch dimensions onto the other arg.
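+//
+// For example (illustrative shapes): in mv, if only `self` is batched with
+// physical shape [B, L, K] and `other` is a plain [K] vector, we call
+// at::matmul([B, L, K], [K]) -> [B, L] directly instead of first expanding
+// `other` to [B, K].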
+Tensor mv_batching_rule(const Tensor& self, const Tensor& other) {
+  auto self_batched = isBatchedTensor(self);
+  auto other_batched = isBatchedTensor(other);
+
+  // A shape checking API would be nice...
+  TORCH_CHECK(self.dim() == 2 && other.dim() == 1,
+      "mv(self, other): Shape mismatch: expected matrix "
+      "(got `self` of size ", self.sizes(), ") ",
+      "and vector (got `other` of size ", other.sizes(), ")");
+
+  // See Note [Batching rules for matmul-like operators] for why we have cases
+  if (self_batched && !other_batched) {
+    auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+    auto result = at::matmul(self_physical.tensor(), other);
+    return self_physical.getPhysicalToLogicalMap().apply(result);
+  }
+  if (!self_batched && other_batched) {
+    // self_physical: [L, K], other_physical: [..., K]
+    // We view the tensors as [L, K], [..., K, 1], perform matmul to get
+    // a tensor of size [..., L, 1], and squeeze the last dim.
+    auto other_physical = MultiBatchVmapTransform::logicalToPhysical(other);
+    auto result = at::matmul(self, other_physical.tensor().unsqueeze(-1));
+    return other_physical.getPhysicalToLogicalMap().apply(result.squeeze(-1));
+  }
+  if (self_batched && other_batched) {
+    // self_physical: [..., L, K], other_physical: [..., K]
+    // We view the tensors as [..., L, K], [..., K, 1], perform matmul to get
+    // a tensor of size [..., L, 1], and squeeze the last dim.
+    auto physical_args = MultiBatchVmapTransform::logicalToPhysical({self, other});
+    auto result = at::matmul(
+        physical_args[0].tensor(),
+        physical_args[1].tensor().unsqueeze(-1));
+    return physical_args[0].getPhysicalToLogicalMap().apply(result.squeeze(-1));
+  }
+  TORCH_INTERNAL_ASSERT(false, "either self or other must be a BatchedTensor");
+}
+
+Tensor dot_batching_rule(const Tensor& self, const Tensor& other) {
+  auto self_batched = isBatchedTensor(self);
+  auto other_batched = isBatchedTensor(other);
+
+  TORCH_CHECK(/*logical*/self.dim() == 1 && /*logical*/other.dim() == 1,
+      "dot(self, other): Shape mismatch: vector "
+      "(got `self` of size ", self.sizes(), ") ",
+      "and vector (got `other` of size ", other.sizes(), ")");
+
+  // See Note [Batching rules for matmul-like operators] for why we have cases
+  if (self_batched && !other_batched) {
+    // self_physical: [..., K], other_physical: [K]
+    // View the tensors as [..., 1, K] and [K], perform matmul, and squeeze the last dim.
+    auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+    auto result = at::matmul(self_physical.tensor().unsqueeze(-2), other);
+    return self_physical.getPhysicalToLogicalMap().apply(result.squeeze(-1));
+  }
+  if (!self_batched && other_batched) {
+    // self_physical: [K], other_physical: [..., K]
+    // View the tensors as [K] and [..., K, 1], perform matmul, and squeeze the last dim.
+    auto other_physical = MultiBatchVmapTransform::logicalToPhysical(other);
+    auto result = at::matmul(self, other_physical.tensor().unsqueeze(-1));
+    return other_physical.getPhysicalToLogicalMap().apply(result.squeeze(-1));
+  }
+  if (self_batched && other_batched) {
+    // self_physical: [..., K], other_physical: [..., K]
+    // View the tensors as [..., 1, K] and [..., K, 1], perform matmul, and squeeze the last two dims.
+    auto physical_args = MultiBatchVmapTransform::logicalToPhysical({self, other});
+    auto result = at::matmul(
+        physical_args[0].tensor().unsqueeze(-2),
+        physical_args[1].tensor().unsqueeze(-1));
+    return physical_args[0].getPhysicalToLogicalMap().apply(result.squeeze(-1).squeeze(-1));
+  }
+  TORCH_INTERNAL_ASSERT(false, "either self or other must be a BatchedTensor");
+}
+
+Tensor bmm_batching_rule(const Tensor& self, const Tensor& other) {
+  TORCH_CHECK(/*logical*/self.dim() == 3 && /*logical*/other.dim() == 3,
+      "bmm(self, other): Shape mismatch: expected 3D `self` "
+      "(got `self` of size ", self.sizes(), ") ",
+      "and 3D `other` (got `other` of size ", other.sizes(), ")");
+
+  auto physical_args = BroadcastingVmapTransform::logicalToPhysical({self, other});
+  auto result = at::matmul(physical_args[0].tensor(), physical_args[1].tensor());
+  return physical_args[0].getPhysicalToLogicalMap().apply(result);
+}
+
+Tensor mm_batching_rule(const Tensor& self, const Tensor& other) {
+  if (!participatesInCurrentLevel(self, other)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return at::mm(self, other);
+  }
+
+  auto self_batched = participatesInCurrentLevel(self);
+  auto other_batched = participatesInCurrentLevel(other);
+
+  TORCH_CHECK(/*logical*/self.dim() == 2 && /*logical*/other.dim() == 2,
+      "mm(self, other): Shape mismatch: expected matrix "
+      "(got `self` of size ", self.sizes(), ") ",
+      "and matrix (got `other` of size ", other.sizes(), ")");
+
+  // See Note [Batching rules for matmul-like operators] for why we have cases
+  if (self_batched && !other_batched) {
+    auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    auto result = at::matmul(self_physical.tensor(), other);
+    result = self_physical.getPhysicalToLogicalMap().apply(result);
+    TORCH_INTERNAL_ASSERT(result.dim() == 2);
+    return result;
+  }
+  if (!self_batched && other_batched) {
+    auto other_physical = MultiBatchVmapTransform::logicalToPhysical(other);
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    auto result = at::matmul(self, other_physical.tensor());
+    result = other_physical.getPhysicalToLogicalMap().apply(result);
+    TORCH_INTERNAL_ASSERT(result.dim() == 2);
+    return result;
+  }
+  if (self_batched && other_batched) {
+    auto physical_args = MultiBatchVmapTransform::logicalToPhysical({self, other});
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    auto result = at::matmul(physical_args[0].tensor(), physical_args[1].tensor());
+    TORCH_INTERNAL_ASSERT(result.dim() == 3);
+    result = physical_args[0].getPhysicalToLogicalMap().apply(result);
+    TORCH_INTERNAL_ASSERT(result.dim() == 2);
+    return result;
+  }
+  TORCH_INTERNAL_ASSERT(false, "either self or other must be a BatchedTensor");
+}
+
+Tensor cat_batching_rule(TensorList tensors, int64_t dim) {
+  if (!participatesInCurrentLevel(tensors)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return at::cat(tensors, dim);
+  }
+  auto physical_views = MultiBatchVmapTransform::logicalToPhysical(tensors);
+  auto physical_tensors = fmap(
+      physical_views, [](const VmapPhysicalView& view) -> Tensor { return view.tensor(); });
+  TORCH_INTERNAL_ASSERT(
+      tensors.size() > 0, "The dispatcher should not have dispatched here otherwise.");
+  auto result = at::cat(physical_tensors, physical_views[0].getPhysicalDim(dim));
+  return physical_views[0].getPhysicalToLogicalMap().apply(result);
+}
+
+Tensor stack_batching_rule(TensorList tensors, int64_t dim) {
+  if (!participatesInCurrentLevel(tensors)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return at::stack(tensors, dim);
+  }
+  auto physical_views = MultiBatchVmapTransform::logicalToPhysical(tensors);
+  auto physical_tensors = fmap(
+      physical_views, [](const VmapPhysicalView& view) -> Tensor { return view.tensor(); });
+  TORCH_INTERNAL_ASSERT(
+      tensors.size() > 0, "The dispatcher should not have dispatched here otherwise.");
+  // NB: stack wraps the dimensionality to (logical dim + 1), so we have to
+  // manually handle that here.
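+  // For example (hypothetical values): with logical dim 2 and dim = -1,
+  // maybe_wrap_dim(-1, 3) = 2; with one batch dim at the front, the physical
+  // stack dimension is 1 + 2 = 3.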
+  auto dim_physical =
+      physical_views[0].numBatchDims() + maybe_wrap_dim(dim, /*logical*/tensors[0].dim() + 1);
+  auto result = at::stack(physical_tensors, dim_physical);
+  return physical_views[0].getPhysicalToLogicalMap().apply(result);
+}
+
+// I am quite sad that we need to register operators with exploded TensorOptions,
+// even though the native:: implementations can use TensorOptions&.
+// This also makes it hard to metaprogram: e.g., we can't use
+// unwrap_and_call<..., at::to> because at::to takes a TensorOptions& (!!)
+Tensor to_dtype_layout_batching_rule(
+    const Tensor& self,
+    optional<ScalarType> dtype,
+    optional<Layout> layout,
+    optional<Device> device,
+    optional<bool> pin_memory,
+    bool non_blocking, bool copy,
+    optional<MemoryFormat> memory_format) {
+  auto options = TensorOptions()
+    .dtype(dtype)
+    .layout(layout)
+    .device(device)
+    .pinned_memory(pin_memory);
+  auto* input_batched = unsafeGetBatchedImpl(self);
+  auto output_physical = input_batched->value().to(options, non_blocking, copy, memory_format);
+  auto old_bdims = input_batched->bdims();
+  return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
+}
+
+Tensor new_zeros_batching_rule(
+    const Tensor& self,
+    IntArrayRef size,
+    optional<ScalarType> dtype,
+    optional<Layout> layout,
+    optional<Device> device,
+    optional<bool> pin_memory) {
+  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
+  auto physical_size = physical_view.getPhysicalShape(size);
+  auto options = TensorOptions()
+    .dtype(dtype)
+    .layout(layout)
+    .device(device)
+    .pinned_memory(pin_memory);
+  auto result = physical_view.tensor().new_zeros(physical_size, options);
+  return physical_view.getPhysicalToLogicalMap().apply(result);
+}
+
+Tensor new_empty_batching_rule(
+    const Tensor& self,
+    IntArrayRef size,
+    c10::optional<ScalarType> dtype,
+    c10::optional<Layout> layout,
+    c10::optional<Device> device,
+    c10::optional<bool> pin_memory) {
+  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
+  auto physical_size = physical_view.getPhysicalShape(size);
+  auto result = physical_view.tensor().new_empty(physical_size, TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory));
+  return physical_view.getPhysicalToLogicalMap().apply(result);
+}
+
+Tensor addmm_batching_rule(const Tensor& self, const Tensor& mat1, const Tensor& mat2, Scalar beta, Scalar alpha) {
+  // Decomposition that is probably not very fast...
+  return at::add(self * beta, at::mm(mat1, mat2), alpha);
+}
+
+Tensor ones_like_batching_rule(
+    const Tensor& self,
+    optional<ScalarType> dtype,
+    optional<Layout> layout,
+    optional<Device> device,
+    optional<bool> pin_memory,
+    optional<MemoryFormat> memory_format) {
+  if (!participatesInCurrentLevel(self)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    auto options = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);
+    return at::ones_like(self, options, memory_format);
+  }
+  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
+  auto options = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);
+  auto result = at::ones_like(physical_view.tensor(), options, memory_format);
+  return physical_view.getPhysicalToLogicalMap().apply(result);
+}
+
+Tensor new_empty_strided_batching_rule(
+    const Tensor& self,
+    IntArrayRef size,
+    IntArrayRef stride,
+    optional<ScalarType> dtype,
+    optional<Layout> layout,
+    optional<Device> device,
+    optional<bool> pin_memory) {
+  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
+  auto physical_size = physical_view.getPhysicalShape(size);
+
+  // Let [B0, B1, B2] be the shape of the batch dims. We're going to create
+  // the batch dimensions at the front of the tensor (in memory layout),
+  // irrespective of whether or not they are actually at the front (in memory layout)
+  // in the original `self` tensor. This is because when a user calls
+  // `new_empty_strided` in general, the `strides` they provide are for a new
+  // tensor and have no relation to the strides of the original tensor.
+  //
+  // So, the physical shape of the result should be ([B0, B1, B2] + size),
+  // but what about the physical strides?
+  //
+  // We're actually free to pick whatever stride we want:
+  // e.g., for size=[5, 3], stride=[0, 1], we could decide to
+  // use
+  // - physical size: [B0, B1, B2, 5, 3]
+  // - physical stride: [9999*B1*B2, 9999*B2, 9999, 0, 1]
+  //
+  // Let's select some reasonable strides such that:
+  // - The batch dims are "contiguous" with respect to each other
+  // - if empty_strided(size, stride) would have created a contiguous Tensor,
+  // then this new physical Tensor (with batch dims) is also contiguous
+  //
+  // Let S be the size of the storage if one were to construct a tensor
+  // with `size` and `stride` via empty_strided(size, stride).
+  // Then the physical sizes/strides should be:
+  // - physical size: [B0, B1, B2, 5, 3]
+  // - physical stride: [B1 * B2 * S, B2 * S, S, 0, 1]
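+  //
+  // Concrete numbers (illustrative only): with batch shape [B0, B1, B2] = [2, 3, 4]
+  // and size = [5, 3], stride = [0, 1], the storage needed by
+  // empty_strided(size, stride) is S = 1 + (5-1)*0 + (3-1)*1 = 3, so
+  // - physical size:   [2, 3, 4, 5, 3]
+  // - physical stride: [3*4*3, 4*3, 3, 0, 1] = [36, 12, 3, 0, 1]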
+  auto batch_shape = IntArrayRef(
+      physical_view.tensor().sizes().begin(), physical_view.numBatchDims());
+
+  // physical_strides = [B1 * B2 * S, B2 * S, S]
+  auto physical_strides = at::detail::defaultStrides(batch_shape);
+  TORCH_CHECK(size.size() == stride.size(),
+        "new_empty_strided(sizes, strides): dimensionality of sizes (",
+        size.size(), ") must match dimensionality of strides (",
+        stride.size(), ")");
+  auto storage_size = native::storage_size_for(size, stride);
+  for (auto& physical_stride : physical_strides) {
+    physical_stride *= storage_size;
+  }
+
+  // physical_strides = [B1 * B2 * S, B2 * S, S] + strides
+  physical_strides.insert(physical_strides.end(), stride.begin(), stride.end());
+
+  auto result = physical_view.tensor().new_empty_strided(
+      physical_size, physical_strides, dtype, layout, device, pin_memory);
+  return physical_view.getPhysicalToLogicalMap().apply(result);
+}
+
+template <typename F, F Func>
+Tensor comparison_pointwise_batching_rule(const Tensor& self, const Tensor& other) {
+  auto physical_args = BroadcastingVmapTransform::logicalToPhysical({self, other});
+  auto result = Func(physical_args[0].tensor(), physical_args[1].tensor());
+  return physical_args[0].getPhysicalToLogicalMap().apply(result);
+}
+
+bool BatchedTensor_is_leaf(const Tensor& self) {
+  if (torch::autograd::impl::get_autograd_meta(self)) {
+    return torch::autograd::impl::get_autograd_meta(self)->grad_fn_ == nullptr;
+  } else {
+    return true;
+  }
+}
+
+Tensor& BatchedTensor_requires_grad_(Tensor& self, bool requires_grad) {
+  self.set_requires_grad(requires_grad);
+  return self;
+}
+
+
+TORCH_LIBRARY_IMPL(_, BatchedOutOfTree, m) {
+  m.fallback(torch::CppFunction::makeFromBoxedFunction<&batchedTensorForLoopFallback>());
+}
+
+// // debug_t<tail_t<tail_t<typelist<Tensor, optional<int64_t>>>>> dt;
+// debug_t<remove_batch_dim_after_tensor_t<typelist<Tensor, optional<int64_t>>>> dt;
+
+static Tensor makeBatched(const Tensor& tensor, optional<int64_t> bdim, int64_t level) {
+  if (bdim.has_value()) {
+    return makeBatched(tensor, {{level, bdim.value()}});
+  }
+  return tensor;
+}
+
+
+std::tuple<Tensor,optional<int64_t>> abs_batch_rule(const Tensor& tensor, optional<int64_t> batch_dim) {
+  return {tensor.abs(), batch_dim};
+}
+
+std::tuple<Tensor, optional<int64_t>> unwrapTensorAtLevel(const Tensor& tensor, int64_t level) {
+  auto* batched = maybeGetBatchedImpl(tensor);
+  if (!batched) {
+    return {tensor, nullopt};
+  }
+  TORCH_INTERNAL_ASSERT(batched->bdims().size() == 1);
+  auto batched_level = batched->bdims().back().level();
+  if (batched_level == level) {
+    auto bdim = batched->bdims().back().dim();
+    return {batched->value(), bdim};
+  }
+  return {tensor, nullopt};
+}
+
+typedef std::tuple<Tensor, optional<int64_t>> (*something_t)(const Tensor&, optional<int64_t>);
+
+template <>
+Tensor lowerToNextLayer<something_t, Tensor, const Tensor&>(
+    something_t batch_rule,
+    const Tensor& tensor) {
+  c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+  auto maybe_layer = maybeCurrentDynamicLayer();
+  TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
+  int64_t cur_level = maybe_layer->layerId();
+  auto unwrapped = unwrapTensorAtLevel(tensor, cur_level);
+  auto unwrapped_result = batch_rule(std::get<0>(unwrapped), std::get<1>(unwrapped));
+  return makeBatched(std::get<0>(unwrapped_result), std::get<1>(unwrapped_result), cur_level);
+}
+
+// Tensor absBatchRule(const Tensor& tensor) {
+//   return lowerToNextLayer(abs_batch_rule, tensor);
+// }
+// 
+//template <typename Result, typename... Args>
+//Result primBatchRule(Args... args) {
+//  return lowerToNextLayer(abs_batch_rule, std::forward<Args>(args)...);
+//}
+
+// template <typename func_t>
+// typename c10::guts::function_traits<func_t>::result_type primBatchRule(const Tensor& tensor) {
+//   return lowerToNextLayer(abs_batch_rule, tensor);
+// }
+
+// std::function<Tensor(const Tensor&)> kFunc = primBatchRule<Tensor, const Tensor&>;
+// // std::function<Tensor(const Tensor&)> kFunc3 = PrimBatchRule3<decltype(&abs_batch_rule), &abs_batch_rule>::apply;
+// 
+// std::function<Tensor(const Tensor&)> kFunc3 = PrimBatchRule3<decltype(&abs_batch_rule), &abs_batch_rule>::apply;
+//
+// debug_t<decltype(abs_batch_rule)> foobar;
+
+TORCH_LIBRARY_IMPL(aten, BatchedOutOfTree, m) {
+  // m.impl("abs", kFunc3);
+  //m.impl("abs", PrimBatchRule6<Tensor(const Tensor&)>::apply);
+  // m.impl("abs", PrimBatchRule7<decltype(&abs_batch_rule), &abs_batch_rule, Tensor(const Tensor&)>::apply);
+  m.impl("abs", PrimBatchRule7<decltype(&abs_batch_rule), &abs_batch_rule, to_operator_t<decltype(abs_batch_rule)>>::apply);
+
+  // NB: Ideally we would like some operators, like size.int, to "fallthrough"
+  // to the underlying implementation. However, because a BatchedTensor is a
+  // Tensor wrapper, it only has one dispatch key (Batched) on it. The resolution
+  // here is to just directly call the underlying implementation.
+  m.impl("size.int", static_cast<int64_t (*)(const Tensor&, int64_t)>(native::size));
+  m.impl("_add_batch_dim", native::_add_batch_dim);
+  m.impl("_remove_batch_dim", native::_remove_batch_dim);
+
+  m.impl("max_pool2d", at::native::max_pool2d); // composite
+  m.impl("max_pool2d_with_indices", max_pool2d_with_indices_batching_rule);
+
+  m.impl("mean", mean_batching_rule);
+  m.impl("sum.dim_IntList", sum_batching_rule);
+  m.impl("log_softmax.int", log_softmax_batching_rule);
+  m.impl("_log_softmax", _log_softmax_batching_rule);
+  m.impl("is_complex", native::is_complex);
+  m.impl("conj", native::conj);
+  m.impl("flatten.using_ints", flatten_batching_rule);
+  m.impl("cross_entropy_loss", native::cross_entropy_loss);
+// 
+//   // inplace operations
+//   m.impl("fill_.Scalar", fill_inplace_scalar_batching_rule);
+//   m.impl("fill_.Tensor", fill_inplace_tensor_batching_rule);
+//   m.impl("zero_", zero_inplace_batching_rule);
+
+//   // autograd things...
+//   m.impl("is_leaf", BatchedTensor_is_leaf);
+//   m.impl("requires_grad_", BatchedTensor_requires_grad_);
+
+  // view operations
+  m.impl("as_strided", as_strided_batching_rule);
+  m.impl("chunk", chunk_batching_rule);
+  m.impl("tensor_split.sections", tensor_split_sections_batching_rule);
+  m.impl("tensor_split.indices", tensor_split_indices_batching_rule);
+  m.impl("diagonal", diagonal_batching_rule);
+  m.impl("expand", expand_batching_rule);
+  m.impl("expand_as", native::expand_as); // composite wrt autograd
+  m.impl("movedim.intlist", movedim_batching_rule);
+  m.impl("movedim.int", static_cast<Tensor(*)(const Tensor&,int64_t,int64_t)>(native::movedim)); // composite wrt autograd
+  // NB: static_cast because there's another variant of narrow. However, we don't
+  // want to support the other variant yet because it isn't documented...
+  m.impl("narrow", static_cast<Tensor(*)(const Tensor&,int64_t,int64_t,int64_t)>(native::narrow)); // composite wrt autograd
+  m.impl("numpy_T", native::numpy_T); // composite wrt autograd
+  m.impl("permute", permute_batching_rule);
+  m.impl("reshape", reshape_batching_rule);
+  m.impl("reshape_as", native::reshape_as); // composite wrt autograd
+  m.impl("select.int", select_batching_rule);
+  m.impl("slice.Tensor", slice_batching_rule);
+  m.impl("split.Tensor", split_batching_rule);
+  m.impl("split_with_sizes", split_with_sizes_batching_rule);
+  m.impl("squeeze", squeeze_batching_rule);
+  m.impl("squeeze.dim", squeeze_dim_batching_rule);
+  m.impl("squeeze_.dim", squeeze_dim__batching_rule);
+  m.impl("t", native::t); // composite wrt autograd
+  m.impl("trace", trace_batching_rule);
+  m.impl("transpose.int", transpose_int_batching_rule);
+  m.impl("unbind.int", unbind_batching_rule);
+  m.impl("unfold", unfold_batching_rule);
+  m.impl("unsqueeze", unsqueeze_batching_rule);
+  m.impl("unsqueeze_", unsqueeze__batching_rule);
+  m.impl("view", view_batching_rule);
+  m.impl("view_as", native::view_as); // composite wrt autograd
+
+//   m.impl("addmm", addmm_batching_rule);
+// 
+  // clamp operations
+//   m.impl("clamp", clamp_batching_rule);
+//   m.impl("clamp_min", clamp_min_batching_rule);
+//   m.impl("clamp_max", clamp_max_batching_rule);
+
+// unary pointwise, out-of-place, no additional arguments.
+#define UNARY_POINTWISE(op) m.impl(#op, \
+    unwrap_and_call<Tensor (*)(const Tensor&), at::op>);
+  // m.impl("abs", PrimBatchRule3<decltype(&abs_batch_rule), &abs_batch_rule>::apply);
+  // UNARY_POINTWISE(abs);
+  UNARY_POINTWISE(acos);
+  UNARY_POINTWISE(asin);
+  UNARY_POINTWISE(atan);
+  UNARY_POINTWISE(ceil);
+  UNARY_POINTWISE(cos);
+  UNARY_POINTWISE(cosh);
+  UNARY_POINTWISE(_conj);
+  UNARY_POINTWISE(digamma);
+  UNARY_POINTWISE(exp);
+  UNARY_POINTWISE(expm1);
+  UNARY_POINTWISE(floor);
+  UNARY_POINTWISE(frac);
+  UNARY_POINTWISE(lgamma);
+  UNARY_POINTWISE(log);
+  UNARY_POINTWISE(log10);
+  UNARY_POINTWISE(log1p);
+  UNARY_POINTWISE(log2);
+  UNARY_POINTWISE(neg);
+  UNARY_POINTWISE(reciprocal);
+  UNARY_POINTWISE(relu);
+  UNARY_POINTWISE(round);
+  UNARY_POINTWISE(rsqrt);
+  UNARY_POINTWISE(sigmoid);
+  UNARY_POINTWISE(sign);
+  UNARY_POINTWISE(sin);
+  UNARY_POINTWISE(sinh);
+  UNARY_POINTWISE(sqrt);
+  UNARY_POINTWISE(tan);
+  UNARY_POINTWISE(tanh);
+  UNARY_POINTWISE(trunc);
+#undef UNARY_POINTWISE
+#define TO_BATCHING_RULE(name, ...) \
+  { \
+    using to_type = Tensor(Tensor::*)(__VA_ARGS__) const; \
+    m.impl(name, unwrap_and_call_method< \
+        to_type, &Tensor::to, __VA_ARGS__>);\
+  }
+  TO_BATCHING_RULE("to.device", Device, ScalarType, bool, bool, optional<MemoryFormat>)
+  TO_BATCHING_RULE("to.dtype", ScalarType, bool, bool, optional<MemoryFormat>)
+  TO_BATCHING_RULE("to.other", const Tensor&, bool, bool, optional<MemoryFormat>)
+  m.impl("to.dtype_layout", to_dtype_layout_batching_rule);
+#undef TO_BATCHING_RULE
+  m.impl("clone", clone_batching_rule);
+  // m.impl("ones_like", ones_like_batching_rule);
+
+  using TensorTensorScalarType = Tensor (*)(const Tensor&, const Tensor&, Scalar);
+  using TensorTensorType = Tensor (*)(const Tensor&, const Tensor&);
+  using TensorScalarType = Tensor (*)(const Tensor&, Scalar);
+
+#define BINARY_POINTWISE(op) \
+  m.impl(#op".Tensor", binary_pointwise_batching_rule<TensorTensorType, at::op>); \
+  m.impl(#op".Scalar", unwrap_and_call<TensorScalarType, at::op, Scalar>);
+#define BINARY_POINTWISE_VA(op, ...) \
+  { \
+    using Binop = Tensor (*)(const Tensor&, const Tensor&, __VA_ARGS__); \
+    using Unop = Tensor (*)(const Tensor&, Scalar, __VA_ARGS__); \
+    m.impl(#op".Tensor", binary_pointwise_batching_rule<Binop, at::op, __VA_ARGS__>); \
+    m.impl(#op".Scalar", unwrap_and_call<Unop, at::op, Scalar, __VA_ARGS__>); \
+  }
+
+//   BINARY_POINTWISE_VA(add, Scalar);
+//   BINARY_POINTWISE_VA(sub, Scalar);
+//   BINARY_POINTWISE_VA(rsub, Scalar);
+//   BINARY_POINTWISE(mul);
+//   BINARY_POINTWISE(div);
+// 
+//   // at::pow has three out-of-place overloads
+//   m.impl("pow.Tensor_Tensor", binary_pointwise_batching_rule<TensorTensorType, at::pow>);
+//   m.impl("pow.Tensor_Scalar", unwrap_and_call<TensorScalarType, at::pow, Scalar>);
+//   m.impl("pow.Scalar", pow_scalar_Tensor_batching_rule);
+// 
+//   m.impl("sigmoid_backward", binary_pointwise_batching_rule<TensorTensorType, at::sigmoid_backward>);
+//   m.impl(
+//       "threshold_backward",
+//       binary_pointwise_batching_rule<
+//           TensorTensorScalarType,
+//           at::threshold_backward,
+//           Scalar>);
+// 
+  // for at::result_type, call the native::result_type implementation.
+  // We don't have to do anything special because native::result_type operates
+  // on the logical shape of the tensors.
+  m.impl("result_type.Tensor", static_cast<ScalarType (*)(const Tensor&, const Tensor&)>(native::result_type));
+  m.impl("result_type.Scalar", static_cast<ScalarType (*)(const Tensor&, const Scalar&)>(native::result_type));
+  m.impl("result_type.Scalar_Tensor", static_cast<ScalarType (*)(const Scalar&, const Tensor&)>(native::result_type));
+  m.impl("result_type.Scalar_Scalar", static_cast<ScalarType (*)(const Scalar&, const Scalar&)>(native::result_type));
+// 
+// #undef BINARY_POINTWISE_VA
+// #undef BINARY_POINTWISE
+// 
+// 
+#define TRIVIAL_OP(op) m.impl(#op, \
+    unwrap_and_call<Tensor (*)(const Tensor&), at::op>);
+  // complex number view operators
+  TRIVIAL_OP(imag)
+  TRIVIAL_OP(real);
+  TRIVIAL_OP(view_as_real);
+  m.impl("view_as_complex", view_as_complex_batching_rule);
+// #undef TRIVIAL
+// // 
+// //   // matmul-like operators
+// //   m.impl("mv", mv_batching_rule);
+// //   m.impl("dot", dot_batching_rule);
+// //   m.impl("bmm", bmm_batching_rule);
+  m.impl("mm", mm_batching_rule);
+// // 
+  // cat/stack
+  m.impl("cat", cat_batching_rule);
+  m.impl("stack", stack_batching_rule);
+// // 
+// //   // backward operators
+// //   m.impl("select_backward", select_backward_batching_rule);
+// //   m.impl("slice_backward", slice_backward_batching_rule);
+// //   m.impl("trace_backward", trace_backward_batching_rule);
+// //   m.impl("diagonal_backward", diagonal_backward_batching_rule);
+// // 
+// //   // Tensor.new_* operators
+//   m.impl("ones_like", ones_like_batching_rule);
+// //   m.impl("new_empty", new_empty_batching_rule);
+//   m.impl("new_empty_strided", new_empty_strided_batching_rule);
+// //   m.impl("new_zeros", new_zeros_batching_rule);
+// // 
+// //   m.impl("contiguous", contiguous_batching_rule);
+// // 
+// //   // Comparison ops
+// // #define COMPARISON_POINTWISE(op) \
+// //   m.impl(#op".Tensor", comparison_pointwise_batching_rule<TensorTensorType, at::op>); \
+// //   m.impl(#op".Scalar", unwrap_and_call<TensorScalarType, at::op, Scalar>);
+// // 
+// //   COMPARISON_POINTWISE(eq);
+// //   COMPARISON_POINTWISE(gt);
+// //   COMPARISON_POINTWISE(ge);
+// //   COMPARISON_POINTWISE(le);
+// //   COMPARISON_POINTWISE(lt);
+// //   COMPARISON_POINTWISE(ne);
+// // 
+// #undef COMPARISON_POINTWISE
+}
+
+}
+} // namespace at
diff --git a/functorch/functorch/csrc/Constants.h b/functorch/functorch/csrc/Constants.h
new file mode 100644
index 0000000..b0edd68
--- /dev/null
+++ b/functorch/functorch/csrc/Constants.h
@@ -0,0 +1,9 @@
+#pragma once
+#include <c10/core/DispatchKey.h>
+
+namespace at {
+namespace functorch {
+
+constexpr auto kBatchedKey = c10::DispatchKey::BatchedOutOfTree;
+
+}} // namespace at::functorch
diff --git a/functorch/functorch/csrc/DynamicLayer.cpp b/functorch/functorch/csrc/DynamicLayer.cpp
new file mode 100644
index 0000000..6b9283e
--- /dev/null
+++ b/functorch/functorch/csrc/DynamicLayer.cpp
@@ -0,0 +1,377 @@
+#include <functorch/csrc/DynamicLayer.h>
+#include <functorch/csrc/TensorWrapper.h>
+
+#include <torch/library.h>
+#include <c10/core/impl/LocalDispatchKeySet.h>
+#include <ATen/core/dispatch/Dispatcher.h>
+#include <torch/csrc/autograd/variable.h>
+#include <c10/util/ThreadLocalDebugInfo.h>
+
+namespace at {
+namespace functorch {
+
+// Initial autograd layer, because autograd is always "on"
+// thread_local std::vector<DynamicLayer> dynamicLayerStack = { DynamicLayer(DispatchKey::Autograd, 1) };
+
+using DynmetaData = std::unordered_map<int64_t, std::shared_ptr<bool>>;
+DynmetaData kDynMetaDataSingleton;
+
+static DynmetaData& getGlobalDynmetaData() {
+  return kDynMetaDataSingleton;
+}
+
+class DynamicLayerStackHolder : public c10::DebugInfoBase {
+ public:
+  DynamicLayerStackHolder() {}
+  virtual ~DynamicLayerStackHolder() {}
+
+  std::vector<DynamicLayer> dynamicLayerStack = { DynamicLayer(DispatchKey::Autograd, 1) };
+};
+
+thread_local std::shared_ptr<DynamicLayerStackHolder> kDynamicLayerStack;
+
+static std::vector<DynamicLayer>& dynamicLayerStackAccessor() {
+  if (kDynamicLayerStack == nullptr) {
+    kDynamicLayerStack = std::make_shared<DynamicLayerStackHolder>();
+    c10::ThreadLocalDebugInfo::_push(
+        // TODO: this isn't a PRODUCER_INFO, but there's nothing else we can use
+        c10::DebugInfoKind::PRODUCER_INFO,
+        kDynamicLayerStack);
+  }
+  TORCH_INTERNAL_ASSERT(kDynamicLayerStack != nullptr);
+  // TODO: figure out how to memoize this. std::call_once with thread_local?
+  return kDynamicLayerStack->dynamicLayerStack;
+}
+
+std::shared_ptr<bool> getLifeHandleForLevel(int64_t level) {
+  auto it = getGlobalDynmetaData().find(level);
+  TORCH_INTERNAL_ASSERT(it != kDynMetaDataSingleton.end(), "level should be alive");
+  return it->second;
+}
+
+optional<DynamicLayer> maybeCurrentDynamicLayer() {
+  auto& dynamicLayerStack = dynamicLayerStackAccessor();
+  // NB: Exception for regular autograd, maybe tweak this
+  if (dynamicLayerStack.size() <= 1) {
+    return {};
+  }
+  return dynamicLayerStack.back();
+}
+
+const std::vector<DynamicLayer>& getDynamicLayerStack() {
+  return dynamicLayerStackAccessor();
+}
+
+void setDynamicLayerStack(const std::vector<DynamicLayer>& stack) {
+  dynamicLayerStackAccessor() = stack;
+}
+
+static DynamicLayer popDynamicLayer() {
+  auto& dynamicLayerStack = dynamicLayerStackAccessor();
+  TORCH_INTERNAL_ASSERT(dynamicLayerStack.size() > 0);
+  auto result = dynamicLayerStack.back();
+  TORCH_INTERNAL_ASSERT(result.key() != DispatchKey::Undefined);
+  dynamicLayerStack.pop_back();
+
+  if (dynamicLayerStack.size() == 0) {
+    // std::cout << "DynamicLayer off" << std::endl;
+    c10::impl::tls_set_dispatch_key_included(DispatchKey::DynamicLayerFront, false);
+    c10::impl::tls_set_dispatch_key_included(DispatchKey::DynamicLayerBack, false);
+  }
+
+  return result;
+}
+
+static int64_t pushDynamicLayer(DispatchKey key) {
+  auto& dynamicLayerStack = dynamicLayerStackAccessor();
+  TORCH_INTERNAL_ASSERT(key != DispatchKey::Undefined);
+  TORCH_INTERNAL_ASSERT(key != DispatchKey::Batched);
+  auto layerId = 1 + dynamicLayerStack.size();
+  dynamicLayerStack.emplace_back(key, layerId);
+
+  if (layerId == 2) {
+    // std::cout << "DynamicLayer on" << std::endl;
+    c10::impl::tls_set_dispatch_key_included(DispatchKey::DynamicLayerFront, true);
+    c10::impl::tls_set_dispatch_key_included(DispatchKey::DynamicLayerBack, true);
+  }
+
+  return layerId;
+}
+
+int64_t initAndPushDynamicLayer(DispatchKey key) {
+  auto layerId = pushDynamicLayer(key);
+  auto& data = getGlobalDynmetaData();
+  TORCH_INTERNAL_ASSERT(data.find(layerId) == data.end());
+  data[layerId] = std::make_shared<bool>(true);
+  return layerId;
+}
+
+DynamicLayer popDynamicLayerAndDeleteMetadata() {
+  auto result = popDynamicLayer();
+  auto level = result.layerId();
+
+  // TODO: is this lock safe? No one else should be writing to the same bucket
+  if (c10::show_dispatch_trace_enabled()) {
+    std::cout << "deleting metadata" << std::endl;
+  }
+  auto& data = getGlobalDynmetaData();
+  auto it = data.find(level);
+  if (it == data.end()) {
+    return result;
+  }
+  if (c10::show_dispatch_trace_enabled()) {
+    std::cout << "deleted metadata for level " << level << std::endl;
+  }
+  // invalidate the life handle for this level, marking its wrappers as dead
+  *(it->second) = false;
+  data.erase(level);
+  return result;
+}
+
+static Tensor materializeGradWrappers(const Tensor& tensor, const std::vector<DynamicLayer>& dynlayerStack) {
+  if (!tensor.defined()) {
+    return tensor;
+  }
+  // TODO: First entry in the stack is a default autograd key.
+  // We should clean up the logic
+  if (dynlayerStack.size() <= 1) {
+    return tensor;
+  }
+  if (dynlayerStack.back().key() != DispatchKey::Autograd) {
+    return tensor;
+  }
+  auto cur_level = dynlayerStack.back().layerId();
+  auto* wrapper = maybeGetTensorWrapper(tensor);
+  if (!wrapper) {
+    return makeTensorWrapper(tensor, cur_level);
+  }
+  TORCH_INTERNAL_ASSERT(wrapper->level().value() <= cur_level, "escaped?");
+  if (wrapper->level().value() == cur_level) {
+    TORCH_INTERNAL_ASSERT(tensor.defined());
+    return tensor;
+  }
+  return makeTensorWrapper(tensor, cur_level);
+}
+
+static Tensor unwrapIfDead(const Tensor& tensor) {
+  auto* wrapped = maybeGetTensorWrapper(tensor);
+  if (!wrapped) {
+    return tensor;
+  }
+  if (wrapped->is_alive()) {
+    return tensor;
+  }
+  return wrapped->value();
+}
+
+static void foreachTensorInplace(std::vector<IValue>& args, int64_t begin, int64_t end,
+    std::function<Tensor(const Tensor&)> func) {
+  TORCH_INTERNAL_ASSERT(begin >= 0);
+  TORCH_INTERNAL_ASSERT(end >= 0);
+  TORCH_INTERNAL_ASSERT(begin <= end);
+  for (int64_t idx = begin; idx < end; idx++) {
+    auto ivalue = args[idx];
+    if (ivalue.isTensorList()) {
+      auto list = ivalue.toTensorList();
+      for (int64_t list_idx = 0; list_idx < list.size(); list_idx++) {
+        list[list_idx] = func(list[list_idx]);
+      }
+      args[idx] = list;
+    }
+    TORCH_INTERNAL_ASSERT(!ivalue.isGenericDict(), "No operators can accept GenericDict");
+    if (!ivalue.isTensor()) {
+      continue;
+    }
+    Tensor value = ivalue.toTensor();
+    Tensor replacement = func(value);
+    args[idx] = std::move(replacement);
+    // sanity checks
+    if (ivalue.toTensor().defined()) {
+      TORCH_INTERNAL_ASSERT(args[idx].toTensor().defined());
+    }
+  }
+}
+
+constexpr DispatchKeySet all_dynlayer_keyset = DispatchKeySet({
+  DispatchKey::DynamicLayerFront,
+  DispatchKey::DynamicLayerBack,
+  DispatchKey::TensorWrapper,
+  // DispatchKey::Batched,
+  DispatchKey::BatchedOutOfTree,
+  DispatchKey::InplaceOrView
+}) | autograd_dispatch_keyset;
+
+static void sanityCheckStack(torch::jit::Stack* stack) {
+  if (stack->size() > 0) {
+    auto last_ivalue = (*stack)[stack->size() - 1];
+    if (last_ivalue.isTensor()) {
+      auto tensor = last_ivalue.toTensor();
+      auto* wrapper = maybeGetTensorWrapper(tensor);
+      TORCH_INTERNAL_ASSERT(wrapper == nullptr);
+      TORCH_INTERNAL_ASSERT(tensor.has_storage());
+    }
+  }
+}
+
+void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+  auto& dynamicLayerStack = dynamicLayerStackAccessor();
+  if (c10::show_dispatch_trace_enabled()) {
+    std::cout << "DLS size: " << dynamicLayerStack.size() << std::endl;
+  }
+  if (dynamicLayerStack.size() == 0) {
+    sanityCheckStack(stack);
+    c10::impl::ExcludeDispatchKeyGuard guard(all_dynlayer_keyset);
+    op.callBoxed(stack);
+    return;
+  }
+
+  // Unwrap dead GradWrappers, materialize live ones
+  auto maybeTransformGradWrappers = [](const Tensor& tensor) {
+    auto result = unwrapIfDead(tensor);
+    return materializeGradWrappers(result, getDynamicLayerStack());
+  };
+  auto num_args = op.schema().arguments().size();
+  foreachTensorInplace(*stack, stack->size() - num_args, stack->size(), maybeTransformGradWrappers);
+
+  auto layer = dynamicLayerStack.back();
+
+  DispatchKeySet exclude = DispatchKeySet::FULL;
+  exclude = exclude.remove(DispatchKey::DynamicLayerBack);
+  if (layer.key() == DispatchKey::Autograd) {
+    exclude = exclude - autograd_dispatch_keyset;
+    exclude = exclude.remove(DispatchKey::InplaceOrView);
+  // } else if (layer.key() == DispatchKey::Batched) {
+  //   exclude = exclude.remove(DispatchKey::Batched);
+  } else if (layer.key() == DispatchKey::BatchedOutOfTree) {
+    exclude = exclude.remove(DispatchKey::BatchedOutOfTree);
+  } else {
+    TORCH_INTERNAL_ASSERT(false);
+  }
+  c10::impl::ExcludeDispatchKeyGuard guard(exclude);
+
+  // Re-dispatch
+  op.callBoxed(stack);
+}
+
+struct WithoutTop {
+  WithoutTop(): layer_(popDynamicLayer()) {
+  }
+  ~WithoutTop() {
+    pushDynamicLayer(layer_.key());
+  }
+
+  bool prev_grad_enabled_;
+  DynamicLayer layer_;
+};
+
+void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+  auto cur_level = getDynamicLayerStack().back().layerId();
+  auto cur_key = getDynamicLayerStack().back().key();
+
+  auto unwrap = [&](const Tensor& tensor) {
+    if (!tensor.defined()) {
+      return tensor;
+    }
+    auto* maybe_tensor_wrapper = maybeGetTensorWrapper(tensor);
+    if (!maybe_tensor_wrapper) {
+      return tensor;
+    }
+    if (maybe_tensor_wrapper->level().value() == cur_level) {
+      return maybe_tensor_wrapper->value();
+    }
+    if (c10::show_dispatch_trace_enabled()) {
+      std::cout << "unwrap " << cur_level << std::endl;
+    }
+    return tensor;
+  };
+  auto wrap = [&](const Tensor& tensor) {
+    if (!tensor.defined()) {
+      return tensor;
+    }
+    if (cur_level == 1) {
+      return tensor;
+    }
+    if (c10::show_dispatch_trace_enabled()) {
+      std::cout << "wrap " << cur_level << std::endl;
+    }
+    return makeTensorWrapper(tensor, cur_level);
+  };
+
+  // TODO: we only need to do the following (marked with !) on in-place functions
+  // that modify sizes or strides. There aren't many of them.
+  // If autograd dispatch key:
+  // 1. (!) Put a copy of all of the args onto the stack
+  // 2. Unwrap all the args in the copy set
+  // 3. Call the operator
+  // 4. Wrap the output
+  // 5. (!) refreshSizesAndStrides for all the args in the original set
+  // 6. (!) Pop those args off.
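+  //
+  // For example (hypothetical op): an in-place resize_ on a wrapped tensor
+  // changes the sizes of the underlying value(), so the TensorWrapper's
+  // cached sizes/strides must be refreshed afterwards (step 5).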
+
+  // Step 1 & 2
+  if (cur_key == DispatchKey::Autograd) {
+    auto args_size = op.schema().arguments().size();
+    // Step 1
+    auto front = stack->size() - args_size;
+    for (int64_t arg_idx = 0; arg_idx < args_size; arg_idx++) {
+      stack->push_back((*stack)[front + arg_idx]);
+    }
+    // Step 2
+    foreachTensorInplace(*stack, stack->size() - args_size, stack->size(), unwrap);
+  }
+
+  // pop the top layer. Put it back on dtor.
+  WithoutTop guard;
+
+  // "reset exclude set"
+  // TODO: There is still a problem with composability and AutoNonVariableTypeGuard.
+  // Users cannot use torch.no_grad; otherwise there will be problems.
+  auto keyset = c10::impl::PODLocalDispatchKeySet();
+  c10::impl::_force_tls_local_dispatch_key_set(keyset);
+  c10::impl::tls_set_dispatch_key_included(DispatchKey::DynamicLayerFront, true);
+  c10::impl::tls_set_dispatch_key_included(DispatchKey::DynamicLayerBack, true);
+
+  // Re-dispatch
+  op.callBoxed(stack);
+
+  // Step 4, 5, 6
+  if (cur_key == DispatchKey::Autograd) {
+    // Step 4
+    auto ret_size = op.schema().returns().size();
+    foreachTensorInplace(*stack, stack->size() - ret_size, stack->size(), wrap);
+
+    // Step 5
+    auto args_size = op.schema().arguments().size();
+    auto args_front = stack->size() - args_size - ret_size;
+    for (int64_t arg_idx = 0; arg_idx < args_size; arg_idx++) {
+      auto& ivalue = (*stack)[args_front + arg_idx];
+      if (!ivalue.isTensor()) {
+        continue; 
+      }
+      auto maybe_tensor_wrapper = maybeGetTensorWrapper(ivalue.toTensor());
+      if (!maybe_tensor_wrapper) {
+        continue;
+      }
+      maybe_tensor_wrapper->refreshSizesAndStrides();
+    }
+
+    // Step 6
+    stack->erase(stack->end() - (args_size + ret_size), stack->end() - ret_size);
+  }
+}
+
+TORCH_LIBRARY_IMPL(_, DynamicLayerFront, m) {
+  m.fallback(torch::CppFunction::makeFromBoxedFunction<&dynamicLayerFrontFallback>());
+}
+
+TORCH_LIBRARY_IMPL(_, DynamicLayerBack, m) {
+  m.fallback(torch::CppFunction::makeFromBoxedFunction<&dynamicLayerBackFallback>());
+}
+
+// TORCH_LIBRARY_IMPL(aten, DynamicLayerFront, m) {
+//   m.impl("_unwrap_for_grad", native::_unwrap_for_grad);
+//   m.impl("dump_tensor", native::dump_tensor);
+//   m.impl("dlevel", native::dlevel);
+// }
+
+}
+} // namespace at
diff --git a/functorch/functorch/csrc/DynamicLayer.h b/functorch/functorch/csrc/DynamicLayer.h
new file mode 100644
index 0000000..037494f
--- /dev/null
+++ b/functorch/functorch/csrc/DynamicLayer.h
@@ -0,0 +1,33 @@
+#pragma once
+#include <c10/core/DispatchKey.h>
+#include <c10/util/Optional.h>
+#include <unordered_map>
+#include <mutex>
+
+// Forward declared bc I am lazy
+namespace c10 { struct AutogradMetaInterface; }
+
+namespace at {
+namespace functorch {
+
+struct TORCH_API DynamicLayer {
+  DynamicLayer(DispatchKey key, int64_t layerId): key_(key), layerId_(layerId) {}
+
+  DispatchKey key() const { return key_; }
+  int64_t layerId() const { return layerId_; }
+ private:
+  DispatchKey key_;
+  int64_t layerId_;
+};
+
+TORCH_API int64_t initAndPushDynamicLayer(DispatchKey key);
+TORCH_API DynamicLayer popDynamicLayerAndDeleteMetadata();
+TORCH_API c10::optional<DynamicLayer> maybeCurrentDynamicLayer();
+TORCH_API const std::vector<DynamicLayer>& getDynamicLayerStack();
+TORCH_API void setDynamicLayerStack(const std::vector<DynamicLayer>& stack);
+
+// NB: not lock safe. TODO: does it need a lock?
+TORCH_API std::shared_ptr<bool> getLifeHandleForLevel(int64_t level);
+
+}
+} // namespace at
diff --git a/functorch/functorch/csrc/TensorWrapper.cpp b/functorch/functorch/csrc/TensorWrapper.cpp
new file mode 100644
index 0000000..984a4aa
--- /dev/null
+++ b/functorch/functorch/csrc/TensorWrapper.cpp
@@ -0,0 +1,232 @@
+#include <functorch/csrc/TensorWrapper.h>
+#include <functorch/csrc/DynamicLayer.h>
+#include <functorch/csrc/BatchedTensorImpl.h>
+
+#include <torch/library.h>
+#include <ATen/core/dispatch/Dispatcher.h>
+
+namespace at {
+namespace functorch {
+
+void dumpTensor(std::ostream& ss, const Tensor& tensor) {
+  auto* wrapped = maybeGetTensorWrapper(tensor);
+  if (!wrapped) {
+    auto* batched = maybeGetBatchedImpl(tensor);
+    if (batched) {
+      ss << "Batched[" << batched->bdims() << ", ";
+      dumpTensor(ss, batched->value());
+      ss << "]";
+      return;
+    }
+    ss << "Tensor" << tensor.sizes();
+    return;
+  }
+  ss << "Wrapper[";
+  if (wrapped->level().has_value()) {
+    ss << wrapped->level().value() << ", ";
+  } else {
+    ss << "dead, ";
+  }
+  dumpTensor(ss, wrapped->value());
+  ss << "]";
+}
+
+void TensorWrapper::refreshSizesAndStrides() {
+  auto dim = value_.dim();
+  auto sizes = value_.sizes();
+  auto strides = value_.strides();
+  sizes_and_strides_.resize(value_.dim());
+  for (int64_t i = 0; i < dim; i++) {
+    sizes_and_strides_.size_at_unchecked(i) = sizes[i];
+    sizes_and_strides_.stride_at_unchecked(i) = strides[i];
+  }
+
+  refresh_numel();
+  refresh_contiguous();
+}
+
+void dumpTensorCout(const Tensor& tensor) {
+  dumpTensor(std::cout, tensor);
+  std::cout << std::endl;
+}
+
+c10::intrusive_ptr<TensorWrapper> makeTensorWrapperPtr(const Tensor& tensor, int64_t level, bool should_be_alive) {
+  // TODO: denylist non-cuda/cpu backends to avoid funny business
+  DispatchKeySet key_set;
+  if (tensor.is_cuda()) {
+    key_set = key_set.add(DispatchKey::CUDA);
+    key_set = key_set.add(DispatchKey::AutogradCUDA);
+  } else {
+    key_set = key_set.add(DispatchKey::CPU);
+    key_set = key_set.add(DispatchKey::AutogradCPU);
+  }
+  key_set = key_set.add(DispatchKey::TensorWrapper);
+  if (should_be_alive) {
+    return c10::make_intrusive<TensorWrapper>(key_set, tensor, level, getLifeHandleForLevel(level));
+  } else {
+    return c10::make_intrusive<TensorWrapper>(key_set, tensor, level, std::make_shared<bool>(false));
+  }
+}
+
+Tensor makeTensorWrapper(const Tensor& tensor, int64_t level) {
+  auto wrapped = maybeGetTensorWrapper(tensor);
+  if (wrapped) {
+    TORCH_INTERNAL_ASSERT(wrapped->level() < level);
+  }
+
+  // TODO: denylist non-cuda/cpu backends to avoid funny business
+  DispatchKeySet key_set;
+  if (tensor.is_cuda()) {
+    key_set = key_set.add(DispatchKey::CUDA);
+    key_set = key_set.add(DispatchKey::AutogradCUDA);
+  } else {
+    key_set = key_set.add(DispatchKey::CPU);
+    key_set = key_set.add(DispatchKey::AutogradCPU);
+  }
+  key_set = key_set.add(DispatchKey::TensorWrapper);
+  auto life_handle = getLifeHandleForLevel(level);
+  auto result = at::detail::make_tensor<TensorWrapper>(key_set, tensor, level, std::move(life_handle));
+  TORCH_INTERNAL_ASSERT(result.key_set().has(DispatchKey::TensorWrapper));
+  return result;
+}
+
+bool TensorWrapper::is_alive() const {
+  return *is_alive_;
+}
+
+c10::intrusive_ptr<TensorImpl> TensorWrapper::shallow_copy_and_detach(
+    const c10::VariableVersion& version_counter,
+    bool allow_tensor_metadata_change) const {
+  auto dest_impl = makeTensorWrapperPtr(value(), level_, is_alive());
+  dest_impl->set_version_counter(version_counter);
+
+  // TODO: is this even right?
+  dest_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
+  return dest_impl;
+}
+
+c10::intrusive_ptr<TensorImpl> TensorWrapper::shallow_copy_and_detach(
+    c10::VariableVersion&& version_counter,
+    bool allow_tensor_metadata_change) const {
+  auto dest_impl = makeTensorWrapperPtr(value(), level_, is_alive());
+  dest_impl->set_version_counter(version_counter);
+
+  // TODO: is this even right?
+  dest_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
+  return dest_impl;
+}
+
+void TensorWrapper::shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) {
+  TORCH_INTERNAL_ASSERT(false, "NYI");
+}
+
+TensorWrapper::TensorWrapper(
+    c10::DispatchKeySet key_set,
+    Tensor value,
+    int64_t level,
+    std::shared_ptr<bool> is_alive,
+    bool use_value_sizes_strides)
+  : TensorImpl(key_set, value.dtype(), value.device())
+  , value_(std::move(value))
+  , level_(level)
+  , is_alive_(std::move(is_alive))
+{
+  TORCH_INTERNAL_ASSERT(value_.defined());
+  set_storage_access_should_throw();
+
+  // TODO: need to reset sizes/strides on mutation
+  TORCH_INTERNAL_ASSERT(use_value_sizes_strides);
+  refreshSizesAndStrides();
+}
+
+// The following are some internal inherited methods that we do not support.
+// They should never get called.
+void TensorWrapper::set_size(int64_t dim, int64_t new_size) {
+  TORCH_INTERNAL_ASSERT(false, "Can't set_size for TensorWrapper");
+}
+void TensorWrapper::set_stride(int64_t dim, int64_t new_stride) {
+  TORCH_INTERNAL_ASSERT(false, "Can't set_stride for TensorWrapper");
+}
+void TensorWrapper::set_storage_offset(int64_t storage_offset) {
+  TORCH_INTERNAL_ASSERT(false, "Can't set_storage_offset for TensorWrapper");
+}
+
+const char* TensorWrapper::tensorimpl_type_name() const {
+  return "TensorWrapper";
+}
+
+
+TensorWrapper* maybeGetTensorWrapper(const Tensor& tensor) {
+  if (!tensor.key_set().has(DispatchKey::TensorWrapper)) {
+    return nullptr;
+  }
+  return (TensorWrapper*)(tensor.unsafeGetTensorImpl());
+}
+
+static void foreachTensorInplace(std::vector<IValue>& args, int64_t begin, int64_t end,
+    std::function<Tensor(const Tensor&)> func) {
+  TORCH_INTERNAL_ASSERT(begin >= 0);
+  TORCH_INTERNAL_ASSERT(end >= 0);
+  TORCH_INTERNAL_ASSERT(begin <= end);
+  for (int64_t idx = begin; idx < end; idx++) {
+    auto ivalue = args[idx];
+    if (ivalue.isTensorList()) {
+      TORCH_INTERNAL_ASSERT(false, "NYI: TensorList");
+    }
+    if (!ivalue.isTensor()) {
+      continue;
+    }
+    Tensor value = ivalue.toTensor();
+    Tensor replacement = func(value);
+    args[idx] = replacement; // TODO: std::move?
+    if (ivalue.toTensor().defined()) {
+      TORCH_INTERNAL_ASSERT(args[idx].toTensor().defined());
+    }
+  }
+}
+
+static Tensor unwrapIfDead(const Tensor& tensor) {
+  auto* wrapped = maybeGetTensorWrapper(tensor);
+  if (!wrapped) {
+    return tensor;
+  }
+  if (wrapped->is_alive()) {
+    return tensor;
+  }
+  return wrapped->value();
+}
+
+void dead_tensor_wrapper_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+  auto args_size = op.schema().arguments().size();
+  int64_t unwrapped_count = 0;
+  auto unwrapIfDeadAndIncrement = [&](const Tensor& tensor) {
+    auto* wrapped = maybeGetTensorWrapper(tensor);
+    if (!wrapped) {
+      return tensor;
+    }
+    if (wrapped->is_alive()) {
+      return tensor;
+    }
+    unwrapped_count++;
+    return wrapped->value();
+  };
+
+  foreachTensorInplace(*stack, stack->size() - args_size, stack->size(), unwrapIfDeadAndIncrement);
+  TORCH_INTERNAL_ASSERT(unwrapped_count > 0, "Should have at least one dead wrapper");
+
+  // re-dispatch
+  op.callBoxed(stack);
+}
+
+// TensorWrapper backend fallback: Unwrap and fallthrough.
+
+TORCH_LIBRARY_IMPL(_, TensorWrapper, m) {
+  m.fallback(torch::CppFunction::makeFromBoxedFunction<&dead_tensor_wrapper_fallback>());
+}
+
+}
+} // namespace at
diff --git a/functorch/functorch/csrc/TensorWrapper.h b/functorch/functorch/csrc/TensorWrapper.h
new file mode 100644
index 0000000..e4533e5
--- /dev/null
+++ b/functorch/functorch/csrc/TensorWrapper.h
@@ -0,0 +1,61 @@
+#pragma once
+
+#include <ATen/Tensor.h>
+
+namespace at {
+namespace functorch {
+
+struct TORCH_API TensorWrapper : public c10::TensorImpl {
+  explicit TensorWrapper(
+      c10::DispatchKeySet key_set,
+      Tensor value,
+      int64_t level,
+      std::shared_ptr<bool> is_alive,
+      bool use_value_sizes_strides = true);
+
+  // Override a bunch of methods inherited from TensorImpl to return error messages
+  void set_size(int64_t dim, int64_t new_size) override;
+  void set_stride(int64_t dim, int64_t new_stride) override;
+  void set_storage_offset(int64_t storage_offset) override;
+
+  void refreshSizesAndStrides();
+
+  const Tensor& value() const {
+    return value_;
+  }
+  optional<int64_t> level() const {
+    if (is_alive()) {
+      return level_;
+    }
+    return {};
+  }
+  bool is_alive() const;
+
+  // Overrides necessary for autograd
+  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
+    const c10::VariableVersion& version_counter,
+    bool allow_tensor_metadata_change) const override;
+  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
+      c10::VariableVersion&& version_counter,
+      bool allow_tensor_metadata_change) const override;
+  void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;
+
+ private:
+  const char* tensorimpl_type_name() const override;
+  Tensor value_;
+  int64_t level_;
+
+  // When we exit the level, this wrapper may be marked as "not alive".
+  // Wrappers that are not alive:
+  // 1) May still have autograd metadata on them
+  // 2) Forward dispatches to the underlying value()
+  std::shared_ptr<bool> is_alive_;
+};
+
+TORCH_API Tensor makeTensorWrapper(const Tensor& tensor, int64_t level);
+TORCH_API TensorWrapper* maybeGetTensorWrapper(const Tensor& tensor);
+TORCH_API void dumpTensor(std::ostream & ss, const Tensor& tensor);
+TORCH_API void dumpTensorCout(const Tensor& tensor);
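+
+// A minimal usage sketch (illustrative only; assumes a dynamic layer at this
+// level was already pushed via initAndPushDynamicLayer -- see DynamicLayer.h):
+//   Tensor x = at::randn({3});
+//   Tensor w = makeTensorWrapper(x, /*level=*/2);
+//   dumpTensorCout(w);  // e.g. prints: Wrapper[2, Tensor[3]]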
+
+}
+} // namespace at
diff --git a/functorch/functorch/csrc/VmapMode.cpp b/functorch/functorch/csrc/VmapMode.cpp
new file mode 100644
index 0000000..8d61b01
--- /dev/null
+++ b/functorch/functorch/csrc/VmapMode.cpp
@@ -0,0 +1,60 @@
+#include <functorch/csrc/DynamicLayer.h>
+#include <functorch/csrc/VmapMode.h>
+#include <functorch/csrc/Constants.h>
+
+namespace at {
+namespace functorch {
+namespace impl {
+
+/// thread_local is a feature that is not enabled by the Caffe2 mobile
+/// build (e.g. iOS). Therefore, we only provide `VmapMode`
+/// when we are not in a mobile build or when FEATURE_TORCH_MOBILE
+/// is on.
+#if !defined(C10_MOBILE) || defined(FEATURE_TORCH_MOBILE)
+
+thread_local int64_t VmapMode_current_vmap_level = 0;
+
+int64_t VmapMode::current_vmap_level() {
+  return VmapMode_current_vmap_level;
+}
+
+int64_t VmapMode::increment_nesting() {
+  VmapMode_current_vmap_level++;
+
+  auto level = initAndPushDynamicLayer(kBatchedKey);
+  if (VmapMode_current_vmap_level == 1) {
+    c10::impl::tls_set_dispatch_key_included(DispatchKey::VmapMode, true);
+  }
+  return level;
+}
+
+int64_t VmapMode::decrement_nesting() {
+  VmapMode_current_vmap_level--;
+  auto layer = popDynamicLayerAndDeleteMetadata();
+  TORCH_INTERNAL_ASSERT(layer.key() == kBatchedKey);
+  if (VmapMode_current_vmap_level == 0) {
+    c10::impl::tls_set_dispatch_key_included(DispatchKey::VmapMode, false);
+  }
+  // TODO: this return value should never be used
+  return VmapMode_current_vmap_level;
+}
+
+#else
+
+int64_t VmapMode::current_vmap_level() {
+  TORCH_CHECK(false, "VmapMode is not supported on mobile");
+}
+
+int64_t VmapMode::increment_nesting() {
+  TORCH_CHECK(false, "VmapMode is not supported on mobile");
+}
+
+int64_t VmapMode::decrement_nesting() {
+  TORCH_CHECK(false, "VmapMode is not supported on mobile");
+}
+
+#endif
+
+} // namespace impl
+}
+} // namespace at
diff --git a/functorch/functorch/csrc/VmapMode.h b/functorch/functorch/csrc/VmapMode.h
new file mode 100644
index 0000000..c1f86d3
--- /dev/null
+++ b/functorch/functorch/csrc/VmapMode.h
@@ -0,0 +1,30 @@
+#pragma once
+
+#include <c10/core/impl/LocalDispatchKeySet.h>
+
+namespace at {
+namespace functorch {
+namespace impl {
+
+// VmapMode contains a thread local count of how many nested vmaps
+// we are currently inside. That number is known as the `vmap level`.
+// VmapMode is used in the implementation of the Python `torch.vmap` API.
+//
+// NOTE: this is NOT the c++ api for torch.vmap. That doesn't exist yet.
+
+struct TORCH_API VmapMode {
+  // Returns the vmap level, aka the count of how many nested vmaps we're in.
+  static int64_t current_vmap_level();
+
+  // Increment the count of nested vmaps. If this causes the vmap level to be
+  // greater than 0, then it enables DispatchKey::VmapMode on all tensors.
+  static int64_t increment_nesting();
+
+  // Decrements the count of nested vmaps. If this causes the vmap level to be
+  // equal to 0, then it disables DispatchKey::VmapMode on all tensors.
+  static int64_t decrement_nesting();
+};
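+
+// A minimal usage sketch (hypothetical call site, error handling omitted):
+//
+//   auto level = VmapMode::increment_nesting();   // entering one vmap
+//   // ... run the vmapped computation at `level` ...
+//   VmapMode::decrement_nesting();                // exiting that vmap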
+
+} // namespace impl
+}
+} // namespace at
diff --git a/functorch/functorch/csrc/VmapTransforms.cpp b/functorch/functorch/csrc/VmapTransforms.cpp
new file mode 100644
index 0000000..40831e9
--- /dev/null
+++ b/functorch/functorch/csrc/VmapTransforms.cpp
@@ -0,0 +1,323 @@
+#include <functorch/csrc/VmapTransforms.h>
+#include <functorch/csrc/DynamicLayer.h>
+
+#include <ATen/ATen.h>
+
+namespace at {
+namespace functorch {
+
+// Checks if the batch dims in `bdims` appear at the front of the tensor.
+static bool areBdimsAtFrontInOrder(BatchDimsRef bdims) {
+  for (int64_t idx = 0; idx < bdims.size(); idx++) {
+    if (bdims[idx].dim() != idx) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// Takes a BatchedTensorImpl, permutes all of the batch dims to the front,
+// and then returns a physical version of the Tensor.
+static Tensor permuteBatchDimsToFront(BatchedTensorImpl* batched) {
+  auto bdims = batched->bdims();
+  const Tensor& physical_tensor = batched->value();
+  if (areBdimsAtFrontInOrder(bdims)) {
+    return physical_tensor;
+  }
+  const auto sizes = physical_tensor.sizes();
+  VmapDimVector permutation(sizes.size(), 0);
+  permutation.reserve(sizes.size());
+  const auto is_bdim = createBatchDimBitset(bdims);
+  int64_t idx = 0;
+  for (const auto& bdim : bdims) {
+    permutation[idx++] = bdim.dim();
+  }
+  for (int64_t ptr = 0; idx < sizes.size(); ptr++) {
+    if (is_bdim[ptr]) {
+      continue;
+    }
+    permutation[idx++] = ptr;
+  }
+  return physical_tensor.permute(permutation);
+}
+
+VmapPhysicalView MultiBatchVmapTransform::logicalToPhysical(const Tensor& logical_tensor) {
+  auto* batched = maybeGetBatchedImpl(logical_tensor);
+  TORCH_INTERNAL_ASSERT(
+      batched,
+      "logicalToPhysical(tensor) should only be passed a BatchedTensor");
+  return { permuteBatchDimsToFront(batched), createVmapLevelsBitset(batched->bdims()) };
+}
+
+int64_t VmapPhysicalView::numBatchDims() const {
+  return levels_.count();
+}
+
+int64_t VmapPhysicalView::numLogicalDims() const {
+  return /*physical*/tensor_.dim() - numBatchDims();
+}
+
+VmapDimVector VmapPhysicalView::getPhysicalDims(IntArrayRef logical_dims) const {
+  auto logical_ndim = numLogicalDims();
+  // NB: fmap doesn't have a SmallVector variant, so we don't use it here.
+  VmapDimVector result;
+  result.reserve(logical_ndim);
+  for (auto dim : logical_dims) {
+    result.push_back(maybe_wrap_dim(dim, logical_ndim) + numBatchDims());
+  }
+  return result;
+}
+
+int64_t VmapPhysicalView::getPhysicalDim(int64_t logical_dim) const {
+  auto logical_ndim = numLogicalDims();
+  return maybe_wrap_dim(logical_dim, logical_ndim) + numBatchDims();
+}
+
+VmapDimVector VmapPhysicalView::getPhysicalShape(IntArrayRef logical_shape) const {
+  VmapDimVector result;
+  result.reserve(logical_shape.size() + numBatchDims());
+  auto tensor_sizes = tensor_.sizes();
+  result.insert(result.end(), tensor_sizes.begin(), tensor_sizes.begin() + numBatchDims());
+  result.insert(result.end(), logical_shape.begin(), logical_shape.end());
+  return result;
+}
+
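+// Creates BatchDims that place the i-th set level in `levels_bitset` at
+// batch dim i, i.e. all batch dims end up at the front, in level order.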
+static BatchDims computeFrontBatchDimsFromLevels(std::bitset<kVmapNumLevels> levels_bitset) {
+  BatchDims bdims;
+  int64_t dim = 0;
+  for (int64_t level = 0; level < kVmapNumLevels; level++) {
+    if (!levels_bitset[level]) {
+      continue;
+    }
+    bdims.emplace_back(level, dim++);
+  }
+  return bdims;
+}
+
+// Given a Tensor or a BatchedTensor, returns the underlying physical tensor
+// with all vmapped dimensions permuted to the front, if they exist, and a
+// bitset of vmap levels that were present in the tensor.
+static std::pair<Tensor,std::bitset<kVmapNumLevels>>
+getPhysicalTensorAndLevels(const Tensor& self) {
+  auto* batched = maybeGetBatchedImpl(self);
+  if (batched) {
+    return {permuteBatchDimsToFront(batched), createVmapLevelsBitset(batched->bdims())};
+  }
+  return {self, 0};
+}
+
+// Given a Tensor or a BatchedTensor, creates a physical view of the tensor
+// such that it has a batch dimension for each level in `requested_levels`
+// and `requested_example_dim` number of non-batch-dimensions.
+//
+// This function is useful in preparing physical views on tensors that can
+// then be passed into broadcasting operations. For example, when adding
+// two BatchedTensors of sizes [B0, 3] and [B0, B1, 2, 3], where the Bi are the
+// batch dimensions, we must align the batch dimensions and non-batch-dimensions
+// (henceforth referred to as the "example" dimensions) separately to produce
+// tensors of size [B0, 1, 1, 3] and [B0, B1, 2, 3] so that they can be added.
+//
+// Here's a direct example of using alignBatchDimsAtFront on the above two tensors.
+//
+// 1) alignBatchDimsAtFront([B0, 3], requested_levels={0, 1}, requested_example_dim=2)
+// returns a physical view of size [B0, 1, 1, 3] by adding an extra dimension for
+// level 1 and another extra dimension to pad the example dimensions to 2.
+//
+// 2) alignBatchDimsAtFront([B0, B1, 2, 3], requested_levels={0, 1}, requested_example_dim=2)
+// returns a physical view of size [B0, B1, 2, 3]
+static Tensor alignBatchDimsAtFront(
+    const Tensor& self,
+    std::bitset<kVmapNumLevels> requested_levels,
+    int64_t requested_example_dim) {
+  Tensor physical_tensor;
+  std::bitset<kVmapNumLevels> tensor_levels;
+  std::tie(physical_tensor, tensor_levels) = getPhysicalTensorAndLevels(self);
+
+  TORCH_INTERNAL_ASSERT(
+    (tensor_levels | requested_levels) == requested_levels,
+    "`requested_levels` must be a superset of `self`'s levels");
+
+  auto physical_sizes = physical_tensor.sizes();
+
+  auto tensor_example_dim = physical_sizes.size() - /*num_batch_dims*/tensor_levels.count();
+  TORCH_INTERNAL_ASSERT(tensor_example_dim <= requested_example_dim);
+
+  if (tensor_levels == requested_levels && tensor_example_dim == requested_example_dim) {
+    // Optimization: no need to do another view if the physical tensor is
+    // already the correct shape
+    return physical_tensor;
+  }
+
+  VmapDimVector aligned_sizes(requested_levels.count() + requested_example_dim, 1);
+
+  // align the example dims (non-bdims dims) first
+  // aligned_sizes[-tensor_example_dim:] = tensor_sizes[-tensor_example_dim:]
+  std::copy(
+      physical_sizes.rbegin(),
+      physical_sizes.rbegin() + tensor_example_dim,
+      aligned_sizes.rbegin());
+
+  // align the bdims
+  int64_t level = 0;
+  int64_t tensor_dim = 0;
+  for (int64_t bdim = 0; bdim < requested_levels.count(); bdim++) {
+    // Determine the level of the bdim
+    while (!requested_levels[level]) level++;
+    if (tensor_levels[level]) {
+      aligned_sizes[bdim] = physical_sizes[tensor_dim++];
+    }
+    level++;
+  }
+  return physical_tensor.view(aligned_sizes);
+}
+
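+// If `dim` is provided, moves that dimension to the front. Otherwise,
+// unsqueezes a new leading dimension and expands it to `size`.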
+static Tensor moveDimToFrontAndExpand(Tensor tensor, optional<int64_t> dim, int64_t size) {
+  if (dim) {
+    tensor = tensor.movedim(*dim, 0);
+  } else {
+    tensor = tensor.unsqueeze(0);
+    auto expanded_sizes = tensor.sizes().vec();
+    expanded_sizes[0] = size;
+    tensor = tensor.expand(expanded_sizes);
+  }
+  return tensor;
+}
+
+// The algorithm is as follows:
+// 1. Figure out what the collective levels of `logical_tensors` are.
+// 2. Move all batch dims to the front of the tensors and add extra dims
+//    of size 1. At this point, every tensor will have a dimension for
+//    each of the collective levels.
+// 3. Compute the batch_sizes.
+// 4. Expand each physical tensor so that it has output batch size equal
+//    to `batch_sizes`.
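+//
+// For example, with a current-level batch size of B: a BatchedTensor whose
+// physical value has size [B, 3] (bdim at dim 0) stays size [B, 3], while a
+// plain tensor of size [3] is unsqueezed and expanded to size [B, 3].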
+VmapPhysicalViewVec
+MultiBatchVmapTransform::logicalToPhysical(TensorList logical_tensors) {
+  auto cur_level = maybeCurrentDynamicLayer().value().layerId();
+  auto bdim_size = -1;
+
+  // Figure out the batch size first
+  for (const auto& logical_tensor : logical_tensors) {
+    auto* batched = maybeGetBatchedImpl(logical_tensor);
+    if (!batched) {
+      continue;
+    }
+    TORCH_INTERNAL_ASSERT(batched->bdims().size() == 1);
+    if (batched->bdims().back().level() != cur_level) {
+      continue;
+    }
+    bdim_size = batched->value().size(batched->bdims().back().dim());
+  }
+  TORCH_INTERNAL_ASSERT(bdim_size != -1);
+
+  std::bitset<kVmapNumLevels> levels;
+  levels[cur_level] = 1;
+
+  VmapPhysicalViewVec result;
+  for (const auto& logical_tensor : logical_tensors) {
+    auto* batched = maybeGetBatchedImpl(logical_tensor);
+    if (!batched || (batched->bdims().back().level() != cur_level)) {
+      // Unsqueeze dim 0, expand it to the correct shape
+      c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+      auto value = moveDimToFrontAndExpand(logical_tensor, {}, bdim_size);
+      result.emplace_back(std::move(value), levels);
+      continue;
+    }
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    auto physical = batched->value();
+    auto value = moveDimToFrontAndExpand(physical, batched->bdims().back().dim(), bdim_size);
+    result.emplace_back(std::move(value), levels);
+  }
+
+  return result;
+}
+
+static std::pair<std::bitset<kVmapNumLevels>,int64_t>
+getLevelsAndLargestLogicalDim(TensorList logical_tensors) {
+  TORCH_INTERNAL_ASSERT(logical_tensors.size() > 0);
+  std::bitset<kVmapNumLevels> levels;
+  int64_t largest_logical_dim = -1;
+  for (const auto& tensor : logical_tensors) {
+    auto* batched = maybeGetBatchedImpl(tensor);
+    if (batched) {
+      levels = levels | createVmapLevelsBitset(batched->bdims());
+    }
+    auto tensor_logical_dim = /*logical dim*/tensor.dim();
+    if (tensor_logical_dim > largest_logical_dim) {
+      largest_logical_dim = tensor_logical_dim;
+    }
+  }
+  return { levels, largest_logical_dim };
+}
+
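+// If `dim` is provided, moves that dimension to the front; otherwise unsqueezes
+// a new leading dimension. Then pads with size-1 dims (after the front dim)
+// until there are `example_ndim` non-batch dimensions.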
+static Tensor moveDimToFrontAndUnsqueeze(Tensor tensor, optional<int64_t> dim, int64_t example_ndim) {
+  if (dim) {
+    tensor = tensor.movedim(*dim, 0);
+  } else {
+    tensor = tensor.unsqueeze(0);
+  }
+  auto ndim = tensor.dim() - 1;
+  for (int64_t i = 0; i < example_ndim - ndim; i++) {
+    tensor = tensor.unsqueeze(1);
+  }
+  return tensor;
+}
+
+VmapPhysicalViewVec BroadcastingVmapTransform::logicalToPhysical(TensorList logical_tensors) {
+  auto cur_level = maybeCurrentDynamicLayer().value().layerId();
+  auto bdim_size = -1;
+
+  // Figure out the batch size first
+  for (const auto& logical_tensor : logical_tensors) {
+    auto* batched = maybeGetBatchedImpl(logical_tensor);
+    if (!batched || (batched->bdims().back().level() != cur_level)) {
+      continue;
+    }
+    bdim_size = batched->value().size(batched->bdims().back().dim());
+  }
+  TORCH_INTERNAL_ASSERT(bdim_size != -1);
+
+  std::bitset<kVmapNumLevels> levels;
+  levels[cur_level] = 1;
+
+  // figure out the example ndim
+  int64_t max_example_dim = -1;
+  for (const auto& logical_tensor : logical_tensors) {
+    max_example_dim = std::max(logical_tensor.dim(), max_example_dim);
+  }
+
+  VmapPhysicalViewVec result;
+  for (const auto& logical_tensor : logical_tensors) {
+    auto* batched = maybeGetBatchedImpl(logical_tensor);
+    if (!batched || (batched->bdims().back().level() != cur_level)) {
+      // Unsqueeze dim 0, expand it to the correct shape
+      c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+      auto value = moveDimToFrontAndUnsqueeze(logical_tensor, {}, max_example_dim);
+      result.emplace_back(std::move(value), levels);
+      continue;
+    }
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    auto physical = batched->value();
+    auto value = moveDimToFrontAndUnsqueeze(physical, batched->bdims().back().dim(), max_example_dim);
+    result.emplace_back(std::move(value), levels);
+  }
+
+  return result;
+}
+
+VmapPhysicalToLogicalMap VmapPhysicalView::getPhysicalToLogicalMap() const {
+  return VmapPhysicalToLogicalMap(levels_);
+}
+
+Tensor VmapPhysicalToLogicalMap::apply(const Tensor& physical_tensor) const {
+  return makeBatched(physical_tensor, computeFrontBatchDimsFromLevels(levels_));
+}
+
+void VmapPhysicalToLogicalMap::applyInplace(std::vector<Tensor>& physical_tensors) const {
+  for (int64_t idx = 0; idx < physical_tensors.size(); ++idx) {
+    physical_tensors[idx] = apply(physical_tensors[idx]);
+  }
+}
+
+}
+} // namespace at
diff --git a/functorch/functorch/csrc/VmapTransforms.h b/functorch/functorch/csrc/VmapTransforms.h
new file mode 100644
index 0000000..9f7dfac
--- /dev/null
+++ b/functorch/functorch/csrc/VmapTransforms.h
@@ -0,0 +1,177 @@
+#pragma once
+
+#include <functorch/csrc/BatchedTensorImpl.h>
+
+namespace at {
+namespace functorch {
+
+// This file contains abstractions used for transforming *logical* vmap arguments
+// into *physical* arguments. (Keep reading for definitions of these terms).
+
+// NOTE: [Logical vs physical args]
+// Consider the following vmap.
+//   vmap(vmap(func, in_dims=(2,)), in_dims=(0,))(torch.ones(2, 3, 4))
+// This would produce a BatchedTensor wrapping a Tensor of size [2, 3, 4],
+// with batch dims 0 and 2:
+//   BatchedTensor(ones(2, 3, 4), bdims=[(lvl=1,dim=0),(lvl=2,dim=2)])
+//
+// We say the *logical* view of the tensor has size [3] -- tensors inside
+// `func` appear to have size [3].
+// However, the *physical* underlying tensor (the one passed to vmap) has size
+// [2, 3, 4].
+//
+// This notion of logical vs physical also extends to non-tensor arguments.
+// Consider the previous tensor; let's assume the user called
+// `torch.sum(tensor, dim=0)` inside of `func`. Then the logical
+// dimension they are reducing over is dim 0 but the physical dim is dim 1
+// (the first non-batch dimension)
+
+// Forward declared; see NOTE: [What is a VmapPhysicalView?]
+struct VmapPhysicalView;
+
+// Most PyTorch operators take 4 or fewer inputs.
+constexpr int64_t kVmapTransformStaticInputSize = 4;
+using VmapPhysicalViewVec = SmallVector<VmapPhysicalView, kVmapTransformStaticInputSize>;
+
+// PyTorch generally advertises good performance for <= 5 dims
+// (see ATen/core/DimVector.h). We add a few extra dims (~3) for vmap
+// dimensions to get 8. Adjust this number as necessary.
+constexpr int64_t kVmapStaticDimVecSize = 8;
+using VmapDimVector = SmallVector<int64_t, kVmapStaticDimVecSize>;
+
+// NOTE: [What is a VmapTransform?]
+// A *VmapTransform* converts logical views of tensors to physical views.
+//
+// Batching rules use VmapTransforms to convert logical arguments to
+// physical arguments, then call one or more at:: operators that handle the
+// physical arguments, and then convert the physical result back to a logical
+// argument.
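+//
+// A rough sketch of how a batching rule might use these abstractions
+// (hypothetical rule; real registration and signatures are omitted):
+//
+//   Tensor sum_batching_rule(const Tensor& self, int64_t dim) {
+//     auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+//     auto physical_dim = self_physical.getPhysicalDim(dim);
+//     auto result = self_physical.tensor().sum(physical_dim);
+//     return self_physical.getPhysicalToLogicalMap().apply(result);
+//   }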
+
+// VmapTransform for operators that take tensors with multiple batch dims.
+// Given one or more logical views on Tensors, `logicalToPhysical` 
+// permutes all of the batch dims to the front of the tensor, aligns
+// and expands the batch dims to match each other (according to their `level`),
+// and returns a VmapPhysicalView on the tensor(s).
+struct TORCH_API MultiBatchVmapTransform {
+  static VmapPhysicalView logicalToPhysical(const Tensor& logical_tensor);
+  static VmapPhysicalViewVec logicalToPhysical(TensorList logical_tensors);
+};
+
+// VmapTransform for operators that broadcast all inputs.
+// Given some logical views on Tensors, `logicalToPhysical`:
+// - permutes all of the batch dims to the front of the tensors
+// - aligns all the batch dims to the collective levels of all of the tensors.
+//   If a tensor does not have a batch dim for a vmap level, then it receives
+//   a size-one dimension for said level.
+// - aligns the non-batch dims to have the same dimensionality, adding extra
+//   size-1 dimensions in between the batch dimensions and the non-batch dimensions
+//   so that the batch dimensions are lined up from the right.
+//
+// For example: given inputs of size (B, 2) and (B, 3, 2) where B is the batch
+// dimension, BroadcastingVmapTransform returns VmapPhysicalViews that wrap tensors
+// of size (B, 1, 2) and (B, 3, 2).
+//
+// Given inputs of size (B, 2) and (2,), BroadcastingVmapTransform returns
+// VmapPhysicalViews wrapping tensors of size (B, 2) and (1, 2). We don't
+// actually *need* to return a tensor of size (1, 2) for the second tensor
+// because the broadcasting operation takes care of that for us, but we do
+// it anyways to keep things simple.
+struct TORCH_API BroadcastingVmapTransform {
+  static VmapPhysicalViewVec logicalToPhysical(TensorList logical_tensors);
+};
+
+// Forward declared; if you're reading this file head to toe, don't worry about
+// it yet.
+struct VmapPhysicalToLogicalMap;
+
+// NOTE: [What is a VmapPhysicalView?]
+// VmapPhysicalView represents a physical view on a Tensor.
+//
+// One can use it to further convert logical dimension indices, logical shapes,
+// and more to their physical variants, or convert a new (physical) tensor into
+// a logical BatchedTensor. (TODO(rzou): some of these are not yet implemented).
+//
+// VmapPhysicalView stores a physical tensor with all of its batch dimensions at
+// the front and some levels that correspond to said batch dimensions.
+//
+// The levels bitset specifies which vmap levels correspond to the batch
+// dimensions at the front of the tensor. In particular, the number of set bits
+// corresponds to the number of batch dimensions on `tensor` and the rightmost
+// bit of `levels` specifies the maximum number of nested vmaps we are in at
+// this point in time.
+// For example, given:
+//   physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5, 6), levels={1, 3})
+//
+// The rightmost set bit of `levels` is 3, indicating that the number of nested
+// vmaps we are in is less than or equal to 3.
+//   bitset: 010100
+//              ^
+//              |
+//   levels: 012345
+struct TORCH_API VmapPhysicalView {
+  VmapPhysicalView(Tensor&& tensor, std::bitset<kVmapNumLevels> levels)
+      : levels_(levels), tensor_(tensor) {
+    // TORCH_INTERNAL_ASSERT(!isBatchedTensor(tensor));
+  }
+
+  Tensor& tensor() { return tensor_; }
+  const Tensor& tensor() const { return tensor_; }
+
+  // Maps logical dim indices to physical dim indices. Also does dim wrapping.
+  //
+  // For example, given:
+  //   physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5), levels={1, 3})
+  //
+  // Then physical_view.getPhysicalDims({0, 1}) returns {2, 3}.
+  // This is because `levels` has two set bits, which tells us that the first
+  // two dimensions of `tensor_` are batch dimensions, so a logical dim of `n`
+  // is actually a physical dim of `n + 2`.
+  VmapDimVector getPhysicalDims(IntArrayRef logical_dims) const;
+  int64_t getPhysicalDim(int64_t logical_dim) const;
+
+  // Returns a VmapPhysicalToLogicalMap object. This can be used for
+  // mapping a physical tensor to a new logical tensor (BatchedTensor)
+  VmapPhysicalToLogicalMap getPhysicalToLogicalMap() const;
+
+  // Maps a logical shape to a physical shape by pre-pending the batch
+  // sizes to the logical shape.
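+  //
+  // For example, with tensor=ones(2, 3, 4, 5) and levels={1, 3} (batch sizes
+  // 2 and 3), getPhysicalShape({7, 11}) returns {2, 3, 7, 11}.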
+  VmapDimVector getPhysicalShape(IntArrayRef logical_shape) const;
+
+  int64_t numBatchDims() const;
+
+ private:
+  int64_t numLogicalDims() const;
+
+  std::bitset<kVmapNumLevels> levels_;
+  Tensor tensor_;
+};
+
+// Convenience struct used for mapping a physical tensor (a non-BatchedTensor)
+// to a logical one (BatchedTensor). It holds some levels that are used to do the
+// mapping and assumes that the batch dimensions in the physical tensor all
+// occur at the front of the tensor.
+struct TORCH_API VmapPhysicalToLogicalMap {
+  VmapPhysicalToLogicalMap(std::bitset<kVmapNumLevels> levels): levels_(levels) {}
+
+  // Maps a physical tensor to a new logical tensor (BatchedTensor).
+  // Assumes that all of the "batch dimensions" are at the front
+  // of the physical tensor. For example, given:
+  // - x = rank-4 Tensor with size 2, 3, 5, 7
+  // - levels = (2, 4)
+  // Returns:
+  // - BatchedTensor(x, bdims=[(dim=0,lvl=2), (dim=1, lvl=4)])
+  Tensor apply(const Tensor& physical_tensor) const;
+
+  // Given a vector of physical tensors,
+  // 1. maps each tensor to a new logical tensor. Assumes that all of the
+  //    "batch dimensions" are at the front of the physical tensors.
+  // 2. stores the new logical tensors back into the passed-in vector. This is
+  //    to avoid additional dynamic allocations.
+  void applyInplace(std::vector<Tensor>& physical_tensors) const;
+
+  std::bitset<kVmapNumLevels> levels_;
+};
+
+
+}
+} // namespace at
diff --git a/functorch/functorch/csrc/init.cpp b/functorch/functorch/csrc/init.cpp
new file mode 100644
index 0000000..38b7529
--- /dev/null
+++ b/functorch/functorch/csrc/init.cpp
@@ -0,0 +1,176 @@
+#include <torch/extension.h>
+#include <ATen/WrapDimUtils.h>
+
+#include <functorch/csrc/TensorWrapper.h>
+#include <functorch/csrc/DynamicLayer.h>
+#include <functorch/csrc/BatchedTensorImpl.h>
+#include <functorch/csrc/VmapTransforms.h>
+#include <functorch/csrc/VmapMode.h>
+
+namespace at {
+namespace functorch {
+
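+// Returns true if `self` is a BatchedTensor with a batch dim at vmap level `level`.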
+static bool has_level(const Tensor& self, int64_t level) {
+  const auto* batched = maybeGetBatchedImpl(self);
+  if (!batched) {
+    return false;
+  }
+  auto bdims = batched->bdims();
+  auto* it = std::find_if(bdims.begin(), bdims.end(), [&](const BatchDim& bdim) {
+    return bdim.level() == level;
+  });
+  return it != bdims.end();
+}
+
+Tensor _add_batch_dim(const Tensor& self, int64_t batch_dim, int64_t level) {
+  return addBatchDim(self, level, batch_dim);
+}
+
+static std::pair<Tensor,int64_t> remove_existing_batch_dim(
+    const BatchedTensorImpl* batched, int64_t level) {
+  auto bdims = batched->bdims();
+  if (bdims.size() == 1) {
+    TORCH_INTERNAL_ASSERT(bdims[0].level() == level);
+    return std::make_pair(batched->value(), bdims[0].dim());
+  }
+  BatchDims new_bdims;
+  int64_t newly_exposed_physical_dim = -1;
+  new_bdims.reserve(bdims.size() - 1);
+  for (const auto& bdim : bdims) {
+    if (bdim.level() == level) {
+      newly_exposed_physical_dim = bdim.dim();
+    } else {
+      new_bdims.push_back(bdim);
+    }
+  }
+  // Because a BatchDim with level `level` must exist inside `batched`,
+  // we should have found a `newly_exposed_physical_dim`.
+  TORCH_INTERNAL_ASSERT(newly_exposed_physical_dim != -1);
+  int64_t num_batch_dims_before_newly_exposed_physical_dim = std::count_if(
+      new_bdims.begin(), new_bdims.end(),
+      [&](const BatchDim& bdim) {
+        return bdim.dim() < newly_exposed_physical_dim;
+      });
+  int64_t newly_exposed_logical_dim =
+      newly_exposed_physical_dim - num_batch_dims_before_newly_exposed_physical_dim;
+  auto result_tensor = makeBatched(batched->value(), std::move(new_bdims));
+  return std::make_pair(std::move(result_tensor), newly_exposed_logical_dim);
+}
+
+// Poor man's version of np.moveaxis. Moves the dimension at `src` to `dst`
+// while preserving the order of the other existing dimensions.
+// We should probably add np.moveaxis (it is more general) to PyTorch. (#36048)
+// When we do, replace the following with it.
+static Tensor _movedim(const Tensor& self, int64_t src, int64_t dst) {
+  auto logical_dim = self.dim();
+  src = maybe_wrap_dim(src, logical_dim);
+  dst = maybe_wrap_dim(dst, logical_dim);
+  if (src == dst) {
+    return self;
+  }
+  VmapDimVector permutation;
+  permutation.reserve(logical_dim);
+  for (int64_t dim = 0; dim < logical_dim; dim++) {
+    if (dim == src) {
+      continue;
+    }
+    permutation.push_back(dim);
+  }
+  permutation.insert(permutation.begin() + dst, src);
+  return self.permute(permutation);
+}
+
+// Removes the batch dim with level `level` from `self`. If this causes the
+// last batch dim to be removed from a BatchedTensor, then this returns a
+// regular Tensor.
+//
+// If the `level` of the batch dim to remove does not exist in `self`, then we
+// add the batch dim in. This can happen if `self` didn't interact with a tensor
+// inside the vmap level, for example,
+//     self = torch.randn(3)
+//     y = torch.randn(5)
+//     out = vmap(lambda x: vmap(lambda y: x)(y))(self)
+//     assert out.shape == (3, 5)
+// Inside the inner vmap, `x` is a BatchedTensor with a single batch dimension
+// corresponding to the *outer* vmap level and it doesn't have any dimensions that
+// correspond to the inner vmap level so we need to create one for the user.
+//
+// `out_dim` controls where we should put the batch dimension in the output tensor.
+Tensor _remove_batch_dim(const Tensor& self, int64_t level, int64_t batch_size, int64_t out_dim) {
+  if (!has_level(self, level)) {
+    auto self_sizes = self.sizes();
+    VmapDimVector expanded_sizes(self_sizes.begin(), self_sizes.end());
+    expanded_sizes.insert(expanded_sizes.begin() + out_dim, batch_size);
+    return self.expand(expanded_sizes);
+  }
+
+  // If has_level(self, level) returned true, `self` must be a BatchedTensor.
+  const auto* batched = maybeGetBatchedImpl(self);
+  TORCH_INTERNAL_ASSERT(batched != nullptr);
+
+  Tensor self_without_bdim;
+  int64_t newly_exposed_logical_dim;
+  std::tie(self_without_bdim, newly_exposed_logical_dim) = remove_existing_batch_dim(batched, level);
+  return _movedim(self_without_bdim, newly_exposed_logical_dim, out_dim);
+}
+
+Tensor _wrap_for_grad(const Tensor& self, int64_t level) {
+  // NB: different behavior inside??
+  // return self;
+  // TORCH_INTERNAL_ASSERT(!maybeGetTensorWrapper(self));
+  // TORCH_INTERNAL_ASSERT(self.has_storage());
+  return makeTensorWrapper(self, level);
+}
+
+Tensor _unwrap_for_grad(const Tensor& self, int64_t level) {
+  auto* result = maybeGetTensorWrapper(self);
+  if (!result) {
+    return self;
+  }
+  TORCH_INTERNAL_ASSERT(result->level().has_value());
+  if (result->level() == level) {
+    return result->value();
+  }
+  return self;
+}
+
+int64_t dlevel(const Tensor& tensor) {
+  auto* wrapped = maybeGetTensorWrapper(tensor);
+  if (!wrapped) {
+    return 0;
+  }
+  if (!wrapped->is_alive()) {
+    return -1;
+  }
+  return wrapped->level().value();
+}
+
+bool dump_tensor(const Tensor& self) {
+  dumpTensorCout(self);
+  return true;
+}
+
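+// Push/pop an Autograd dynamic layer; both return the affected layer's id.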
+int64_t _grad_increment_nesting() {
+  return initAndPushDynamicLayer(at::DispatchKey::Autograd);
+}
+
+int64_t _grad_decrement_nesting() {
+  return popDynamicLayerAndDeleteMetadata().layerId();
+}
+
+
+} // namespace functorch
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("_add_batch_dim", &at::functorch::_add_batch_dim, "add batch dim");
+  m.def("_remove_batch_dim", &at::functorch::_remove_batch_dim, "remove batch dim");
+  m.def("_vmapmode_increment_nesting", &at::functorch::impl::VmapMode::increment_nesting, "add batch dim");
+  m.def("_vmapmode_decrement_nesting", &at::functorch::impl::VmapMode::decrement_nesting, "remove batch dim");
+  m.def("_grad_increment_nesting", &at::functorch::_grad_increment_nesting, "remove batch dim");
+  m.def("_grad_decrement_nesting", &at::functorch::_grad_decrement_nesting, "remove batch dim");
+  m.def("_wrap_for_grad", &at::functorch::_wrap_for_grad, "add batch dim");
+  m.def("_unwrap_for_grad", &at::functorch::_unwrap_for_grad, "add batch dim");
+  m.def("dlevel", &at::functorch::dlevel, "add batch dim");
+  m.def("dump_tensor", &at::functorch::dump_tensor, "add batch dim");
+}
diff --git a/functorch/setup.py b/functorch/setup.py
new file mode 100644
index 0000000..e2929b9
--- /dev/null
+++ b/functorch/setup.py
@@ -0,0 +1,71 @@
+import distutils
+import shutil
+import glob
+import os
+from setuptools import setup, find_packages
+from torch.utils.cpp_extension import (
+    CppExtension,
+    BuildExtension,
+)
+
+
+# class clean(distutils.command.clean.clean):
+#     def run(self):
+#         with open(".gitignore", "r") as f:
+#             ignores = f.read()
+#             for wildcard in filter(None, ignores.split("\n")):
+#                 for filename in glob.glob(wildcard):
+#                     try:
+#                         os.remove(filename)
+#                     except OSError:
+#                         shutil.rmtree(filename, ignore_errors=True)
+# 
+#         # It's an old-style class in Python 2.7...
+#         distutils.command.clean.clean.run(self)
+
+
+def get_extensions():
+    extension = CppExtension
+
+    define_macros = []
+
+    extra_link_args = []
+    extra_compile_args = {"cxx": ["-O3", "-g", "-std=c++14"]}
+    if int(os.environ.get("DEBUG", 0)):
+        extra_compile_args = {
+            "cxx": ["-O0", "-fno-inline", "-g", "-std=c++14"]}
+        extra_link_args = ["-O0", "-g"]
+
+    this_dir = os.path.dirname(os.path.abspath(__file__))
+    extensions_dir = os.path.join(this_dir, "functorch", "csrc")
+
+    extension_sources = set(
+        os.path.join(extensions_dir, p)
+        for p in glob.glob(os.path.join(extensions_dir, "*.cpp"))
+    )
+    sources = list(extension_sources)
+    include_dirs = [extensions_dir]
+
+    ext_modules = [
+        extension(
+            "functorch._C",
+            sources,
+            include_dirs=[this_dir],
+            define_macros=define_macros,
+            extra_compile_args=extra_compile_args,
+            extra_link_args=extra_link_args,
+        )
+    ]
+
+    return ext_modules
+
+
+setup(
+    name='functorch',
+    url="https://github.com/zou3519/functorch",
+    packages=find_packages(),
+    ext_modules=get_extensions(),
+    cmdclass={
+        # "clean": clean,
+        "build_ext": BuildExtension
+    })
diff --git a/functorch/test/test_eager_transforms.py b/functorch/test/test_eager_transforms.py
new file mode 100644
index 0000000..975c4eb
--- /dev/null
+++ b/functorch/test/test_eager_transforms.py
@@ -0,0 +1,438 @@
+from torch.testing._internal.common_utils import TestCase, run_tests
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import unittest
+import functools
+import itertools
+import warnings
+from torch.testing._internal.common_device_type import instantiate_device_type_tests, \
+    skipCUDAIfNoMagma
+import types
+from functools import partial
+
+import functorch
+from functorch import grad, vjp, vmap, make_functional, jacrev
+
+
+class TestGradTransform(TestCase):
+    def test_primitive(self):
+        x = torch.randn([])
+        result = grad(torch.sin)(x)
+        self.assertEqual(result, torch.cos(x))
+
+    def test_composite_simple(self):
+        x = torch.randn(2, 3, 4)
+        result = grad(lambda x: torch.flatten(x).sum())(x)
+        self.assertEqual(result, torch.ones_like(x))
+
+    def test_composite_complicated(self):
+        x = torch.randn(3)
+        y = torch.randn(3, 5)
+
+        def foo(x, y):
+            result = x @ y
+            return result.sum()
+
+        result = grad(foo)(x, y)
+
+        x.requires_grad_()
+        out = foo(x, y)
+        expected, = torch.autograd.grad(out, x)
+
+        self.assertEqual(result, expected)
+
+    def test_composite_two_ops(self):
+        N, C = 2, 5
+        y = torch.randn(N, C)
+        targets = torch.randint(0, C, (N,))
+
+        def foo(y, targets):
+            return F.cross_entropy(y, targets)
+
+        result = grad(foo)(y, targets)
+
+        y.requires_grad_()
+        expected, = torch.autograd.grad(foo(y, targets), y)
+
+        self.assertEqual(result, expected)
+
+    def _test_attributes(self, get_attr_lambda):
+        x = torch.randn(2, 3, 5, dtype=torch.double)
+        expected = get_attr_lambda(x)
+
+        def foo(x):
+            self.assertEqual(get_attr_lambda(x), expected)
+            return x.sum()
+
+        grad(foo)(x)
+
+    def test_shape(self):
+        self._test_attributes(lambda x: x.shape)
+
+    def test_dtype(self):
+        self._test_attributes(lambda x: x.dtype)
+
+    def test_is_cuda(self):
+        self._test_attributes(lambda x: x.is_cuda)
+
+    def test_numel(self):
+        self._test_attributes(lambda x: x.numel())
+
+    def test_inplace(self):
+        x = torch.randn([])
+
+        def foo(x):
+            return x.clone().sin_()
+
+        result = grad(foo)(x)
+        self.assertEqual(result, x.cos())
+
+    def test_inplace_on_view(self):
+        x = torch.randn(3)
+
+        def foo(x):
+            y = x.clone()
+            y0 = y[0]
+            y0.sin_()
+            return y.sum()
+
+        result = grad(foo)(x)
+
+        x.requires_grad_()
+        out = foo(x)
+        expected, = torch.autograd.grad(out, x)
+
+        self.assertEqual(result, expected)
+
+    def test_inplace_on_view_base(self):
+        x = torch.randn(3)
+
+        def foo(x):
+            y = x.clone()
+            y0 = y[0]
+            y.sin_()
+            return y0
+
+        result = grad(foo)(x)
+
+        x.requires_grad_()
+        out = foo(x)
+        expected, = torch.autograd.grad(out, x)
+
+        self.assertEqual(result, expected)
+
+    def test_nesting_simple(self):
+        x = torch.randn([])
+        result = grad(grad(torch.sin))(x)
+        self.assertEqual(result, -torch.sin(x))
+
+    def test_escaped_wrappers_are_marked_as_dead(self):
+        x = torch.randn([])
+        escaped = []
+        def foo(x):
+            y = x.sin()
+            escaped.append(y)
+            return y
+
+        result = grad(foo)(x)
+        self.assertEqual(escaped[0].dlevel(), -1)
+
+    def test_escaped_wrappers_are_ignored(self):
+        x = torch.randn([])
+        escaped = []
+        def foo(x):
+            y = x.sin()
+            escaped.append(y)
+            return y
+
+        result = grad(foo)(x)
+
+        something = escaped[0].sum()
+        self.assertEqual(something.dlevel(), 0)
+        self.assertEqual(something, x.sin().sum())
+
+    def test_vjp(self):
+        x = torch.randn([])
+        out, vjp_fn = vjp(torch.sin, x)
+        self.assertEqual(out, x.sin())
+
+        v = torch.randn([])
+        result, = vjp_fn(v)
+        self.assertEqual(result, v * x.cos())
+
+    def test_composed_with_autograd(self):
+        x = torch.randn([], requires_grad=True)
+
+        y = grad(torch.sin)(x)
+        result, = torch.autograd.grad(y, x)
+        self.assertEqual(result, -x.sin())
+
+    def test_grad_of_vjp_composition(self):
+        x = torch.randn([])
+        y = torch.randn([])
+
+        def foo(x, y):
+            out, vjp_fn = vjp(torch.sin, x)
+            return grad(lambda y: vjp_fn(y)[0])(y)
+
+        result = foo(x, y)
+        expected = x.cos()
+        self.assertEqual(result, expected)
+
+    def test_vjp_of_grad_composition(self):
+        x = torch.randn([])
+        y = torch.randn([])
+
+        def foo(x, y):
+            out, vjp_fn = vjp(grad(torch.sin), x)
+            return vjp_fn(y)[0]
+
+        result = foo(x, y)
+        expected = -y * x.sin()
+        self.assertEqual(result, expected)
+
+    def test_grad_of_vjp_of_grad_composition(self):
+        x = torch.randn([])
+        y = torch.randn([])
+
+        def foo(x, y):
+            df, vjp_fn = vjp(grad(lambda x: -torch.cos(x)), x)
+            return grad(lambda y: vjp_fn(y)[0])(y)
+
+        result = foo(x, y)
+        expected = x.cos()
+        self.assertEqual(result, expected)
+
+    def test_views(self):
+        x = torch.randn([], requires_grad=True)
+        y = torch.randn([], requires_grad=True)
+
+        def silly_sin(x):
+            x = x.view([])
+            x = x.sin()
+            return x
+
+        def foo(x, y):
+            z1 = grad(silly_sin)(x)
+            z2 = torch.cos(y)
+            return z1 + z2
+
+        result = foo(x, y)
+        grads = torch.autograd.grad(result, [x, y])
+        self.assertEqual(grads[0], -x.sin())
+        self.assertEqual(grads[1], -y.sin())
+
+    def test_view_inplace_simple(self):
+        def foo(x):
+            x = x.clone()
+            x.view([]).sin_()
+            return x
+
+        x = torch.randn([], requires_grad=True)
+        result = grad(foo)(x)
+        self.assertEqual(result, x.cos())
+
+
+class TestVmapOfGrad(TestCase):
+    def test_per_sample_grads_inplace_view(self):
+        def compute_loss(weight, x, t):
+            x = x.mm(weight)
+            y = x.squeeze_(0)
+            return (y - t).sum()
+
+        weight = torch.randn(16, 2)
+        x = torch.randn(64, 1, 16)
+        t = torch.randn(64, 2)
+        result = vmap(partial(grad(compute_loss), weight))(x, t)
+        expected = [grad(compute_loss)(weight, x[i], t[i]) for i in range(64)]
+        expected = torch.stack(expected)
+        # TODO: Check if the rtol is a problem
+        self.assertEqual(result, expected, atol=0, rtol=5e-4)
+
+    def test_new_zeros_materializes_tensor(self):
+        N = 3
+        C = 5
+
+        def foo(x, y):
+            result = x.new_zeros((C,))
+            result.copy_(y)
+            return result.sum()
+
+        x = torch.randn(N)
+        y = torch.randn(N, C)
+        result = vmap(grad(foo))(x, y)
+
+    def test_per_sample_grads_simple(self):
+        def compute_loss(weight, x, t):
+            y = x @ weight
+            return ((y - t) ** 2).sum()
+
+        weight = torch.randn(16, 2)
+        x = torch.randn(64, 16)
+        t = torch.randn(64, 2)
+        result = vmap(partial(grad(compute_loss), weight))(x, t)
+        expected = [grad(compute_loss)(weight, x[i], t[i]) for i in range(64)]
+        expected = torch.stack(expected)
+        # TODO: Check if the rtol is a problem
+        self.assertEqual(result, expected, atol=0, rtol=5e-4)
+
+    def test_per_sample_grads_embeddingnet(self):
+        class SampleNet(nn.Module):
+            def __init__(self, vocab_size: int):
+                super().__init__()
+                self.emb = nn.Embedding(vocab_size, 16)
+                self.fc1 = nn.Linear(16, 16)
+                self.fc2 = nn.Linear(16, 2)
+
+            def forward(self, x):
+                x = self.emb(x)
+                x = torch.transpose(x, -1, -2)
+                x = torch.mean(x, -1)
+                x = self.fc1(x)
+                x = F.relu(x)
+                x = self.fc2(x)
+                return x
+
+            def name(self):
+                return "SampleNet"
+
+        # Create our inputs...
+        vocab_size = 1000
+        batch_shape = [64]
+        words_per_sentence = 5
+        data = torch.randint(0, vocab_size, (*batch_shape, words_per_sentence))
+        targets = torch.randint(0, 1, (*batch_shape,))
+
+        # Construct our module
+        net = SampleNet(vocab_size)
+        criterion = nn.CrossEntropyLoss()
+
+        params = dict(net.named_parameters())
+        weights, net_func, _ = make_functional(net)
+
+        def compute_loss(weights, data, target):
+            output = net_func(weights, (data,))
+            result = criterion(output, target)
+            return result
+
+        expected = [grad(compute_loss)(weights, data[i], targets[i]) for i in range(64)]
+        expected = zip(*expected)
+        expected = tuple(torch.stack(shards) for shards in expected)
+
+        result = vmap(partial(grad(compute_loss), weights))(data, targets)
+        for r, e in zip(result, expected):
+            # TODO: Check if the rtol is a problem
+            self.assertEqual(r, e, atol=0, rtol=1e-4)
+
+class TestJacrev(TestCase):
+    def test_simple(self):
+        x = torch.randn(3)
+        y = jacrev(torch.sin)(x)
+        expected = torch.diagflat(x.cos())
+        assert torch.allclose(y, expected)
+
+    def test_simple_not_flat(self):
+        x = torch.randn(2, 3)
+        y = jacrev(torch.sin)(x)
+        expected = torch.diagflat(x.view(-1).cos())
+        expected = expected.view(2, 3, 2, 3)
+        assert torch.allclose(y, expected)
+
+    def test_vmap_on_jacrev_simple(self):
+        x = torch.randn(2, 3)
+        y = vmap(jacrev(torch.sin))(x)
+        expected = torch.stack([torch.diagflat(x[i].cos()) for i in range(2)])
+        assert torch.allclose(y, expected)
+
+    def test_hessian_simple(self):
+        def foo(x):
+            return x.sin().sum()
+
+        x = torch.randn(3)
+        y = jacrev(jacrev(foo))(x)
+        expected = torch.diagflat(-x.sin())
+        assert torch.allclose(y, expected)
+
+
+class TestComposability(TestCase):
+    def test_grad_grad(self):
+        x = torch.randn([])
+        y = grad(grad(torch.sin))(x)
+        self.assertEqual(y, -x.sin())
+
+    def test_grad_vmap(self):
+        def foo(x):
+            y = vmap(torch.sin)(x)
+            return y.sum()
+
+        x = torch.randn(3)
+        y = grad(foo)(x)
+        self.assertEqual(y, x.cos())
+
+    def test_grad_vjp(self):
+        x = torch.randn(3)
+
+        def foo(x):
+            _, vjp_fn = vjp(torch.sin, x)
+            return vjp_fn(x)[0].sum()
+
+        y = grad(foo)(x)
+        expected = grad(lambda x: (x * x.cos()).sum())(x)
+        self.assertEqual(y, expected)
+
+    def test_vmap_grad(self):
+        x = torch.randn(3)
+        y = vmap(grad(torch.sin))(x)
+        self.assertEqual(y, x.cos())
+
+    def test_vmap_vmap(self):
+        x = torch.randn(2, 3)
+        y = vmap(vmap(torch.sin))(x)
+        self.assertEqual(y, x.sin())
+
+    def test_vmap_vjp(self):
+        x = torch.randn(3)
+        _, vjp_fn = vjp(torch.sin, x)
+
+        def foo(x):
+            _, vjp_fn = vjp(torch.sin, x)
+            return vjp_fn(x)
+
+        y = vmap(foo)(x)
+        self.assertEqual(y, vjp_fn(x))
+
+        xs = torch.randn(5, 3)
+        expected = torch.stack([vjp_fn(x)[0] for x in xs])
+        self.assertEqual(vmap(lambda x: vjp_fn(x)[0])(xs), expected)
+
+    def test_vjp_grad(self):
+        x = torch.randn([])
+        y, vjp_fn = vjp(grad(torch.sin), x)
+        self.assertEqual(y, x.cos())
+
+        v = torch.randn([])
+        self.assertEqual(vjp_fn(v)[0], -x.sin() * v)
+
+    def test_vjp_vmap(self):
+        x = torch.randn(3)
+        y, vjp_fn = vjp(vmap(torch.sin), x)
+        self.assertEqual(y, x.sin())
+
+        v = torch.randn(3)
+        self.assertEqual(vjp_fn(v)[0], x.cos() * v)
+
+    def test_vjp_vjp(self):
+        x = torch.randn(3)
+        y, vjp_fn = vjp(torch.sin, x)
+        self.assertEqual(y, x.sin())
+
+        y, vjp_fn = vjp(lambda x: vjp_fn(x)[0], x)
+        self.assertEqual(y, x * x.cos())
+
+        y = vjp_fn(x)[0]
+        # The exact value isn't checked here; this just verifies that the
+        # composition runs without error.
+
+
+if __name__ == '__main__':
+    run_tests()
diff --git a/functorch/test/test_vmap.py b/functorch/test/test_vmap.py
new file mode 100644
index 0000000..7eaad7e
--- /dev/null
+++ b/functorch/test/test_vmap.py
@@ -0,0 +1,2516 @@
+from torch.testing._internal.common_utils import TestCase, run_tests
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+import functools
+import itertools
+import warnings
+import unittest
+from torch.testing._internal.common_device_type import instantiate_device_type_tests, \
+    skipCUDAIfNoMagma
+import types
+
+from functorch import vmap
+
+
+FALLBACK_REGEX = r'falling back to slow \(for loop( and stack)?\) implementation'
+
+class EnableVmapFallbackWarnings:
+    def __enter__(self):
+        self.prev_state = torch._C._debug_only_are_vmap_fallback_warnings_enabled()
+        torch._C._debug_only_display_vmap_fallback_warnings(True)
+
+    def __exit__(self, *ignored):
+        torch._C._debug_only_display_vmap_fallback_warnings(self.prev_state)
+
+class TestVmapAPI(TestCase):
+    def test_non_tensor_output_raises(self):
+        with self.assertRaisesRegex(ValueError, "got type <class 'float'> as the return"):
+            output = vmap(lambda x: 3.14)(torch.ones(3))
+
+        def multiple_outputs(x):
+            return x, 3
+
+        with self.assertRaisesRegex(ValueError, "got type <class 'int'> for return 1"):
+            vmap(multiple_outputs)(torch.ones(3))
+
+    def test_different_map_dim_size_raises(self):
+        x = torch.randn(2)
+        y = torch.randn(3)
+        expected_msg = 'Expected all tensors to have the same size in the mapped dimension'
+        with self.assertRaisesRegex(ValueError, expected_msg):
+            vmap(torch.mul)(x, y)
+        with self.assertRaisesRegex(ValueError, expected_msg):
+            vmap(lambda z: z[0] + z[1], in_dims=((0, 0),))((x, y))
+        with self.assertRaisesRegex(ValueError, expected_msg):
+            vmap(lambda z: z['x'] + z['y'], in_dims=({'x': 0, 'y': 0},))({'x': x, 'y': y})
+
+    def test_func_with_no_inputs(self):
+        expected_msg = 'got no inputs'
+
+        def foo():
+            return torch.randn(3)
+
+        def bar(x):
+            return torch.randn(3)
+
+        with self.assertRaisesRegex(ValueError, expected_msg):
+            vmap(foo)()
+
+        with self.assertRaisesRegex(ValueError, expected_msg):
+            vmap(bar)()
+
+    def test_constant_function(self):
+        output = vmap(lambda x: torch.tensor(3.14))(torch.ones(3))
+        self.assertEqual(output, torch.tensor([3.14, 3.14, 3.14]))
+
+    def test_single_input(self):
+        x = torch.randn(2, 3)
+
+        def square(x):
+            return x * x
+
+        output = vmap(square)(x)
+        self.assertEqual(output, x * x)
+
+    def test_multiple_inputs(self):
+        x = torch.randn(2, 3)
+        y = torch.randn(2, 3)
+        output = vmap(torch.mul)(x, y)
+        self.assertEqual(output, x * y)
+
+    def test_multiple_outputs(self):
+        def foo(x):
+            return x * x, x * x * x
+
+        x = torch.randn(3)
+        outputs = vmap(foo)(x)
+        self.assertEqual(outputs[0], x * x)
+        self.assertEqual(outputs[1], x * x * x)
+
+    def test_multiple_outputs_error_cases(self):
+        # This is the same thing as
+        # def returns_tuple_of_tensors(x):
+        #     return x, x
+        def returns_tuple_of_tensors(x):
+            return (x, x)
+
+        def returns_list_of_two_tensors(x):
+            return [x, x]
+
+        def returns_list_of_one_tensor(x):
+            return [x]
+
+        x = torch.randn(3)
+
+        # should not throw
+        vmap(returns_tuple_of_tensors)(x)
+
+        # jax supports these, but we don't yet
+        msg = "must only return Tensors, got type <class 'list'>"
+        with self.assertRaisesRegex(ValueError, msg):
+            vmap(returns_list_of_two_tensors)(x)
+        with self.assertRaisesRegex(ValueError, msg):
+            vmap(returns_list_of_one_tensor)(x)
+
+    def test_nested_with_same_map_dim(self):
+        x = torch.randn(2, 3, 5)
+        y = torch.randn(2, 3, 5)
+        output = vmap(vmap(torch.mul))(x, y)
+        self.assertEqual(output, x * y)
+
+        output = vmap(vmap(vmap(torch.mul)))(x, y)
+        self.assertEqual(output, x * y)
+
+    def test_nested_with_different_map_dim(self):
+        x = torch.randn(2, 3)
+        y = torch.randn(5, 3)
+        output = vmap(lambda x: vmap(lambda y: x * y)(y))(x)
+        self.assertEqual(output.shape, (2, 5, 3))
+        self.assertEqual(output, x.view(2, 1, 3) * y)
+
+        z = torch.randn(7, 3)
+        output = vmap(lambda x: vmap(lambda y: vmap(lambda z: x * y * z)(z))(y))(x)
+        self.assertEqual(output.shape, (2, 5, 7, 3))
+        self.assertEqual(output, x.view(2, 1, 1, 3) * y.view(5, 1, 3) * z)
+
+    def test_noop_in_inner_vmap(self):
+        x = torch.randn(3)
+        y = torch.randn(5)
+        output = vmap(lambda x: vmap(lambda y: x)(y))(x)
+        self.assertEqual(output, x.view(3, 1).expand(3, 5))
+
+    def test_unsupported_op_err_msg(self):
+        # Unsupported view op
+        tensor = torch.randn(2, 3)
+        msg = (
+            r"Batching rule not implemented for aten::.+; the "
+            r"fallback path doesn't work on out= or view ops"
+        )
+        with self.assertRaisesRegex(RuntimeError, msg):
+            vmap(torch.ravel)(tensor)
+
+        def out_op(x, y):
+            return torch.abs(x, out=y)
+
+        with self.assertRaisesRegex(RuntimeError, msg):
+            vmap(out_op)(tensor, tensor)
+
+        tensor = torch.randn(2)
+        # The fallback doesn't support TensorList
+        with self.assertRaisesRegex(RuntimeError, 'Batching rule not implemented'):
+            vmap(lambda t: torch.atleast_1d([t]))(tensor)
+
+        # Don't support non-tensor returns. This is a limitation of vmap;
+        # functions that don't return tensors must be special cased
+        with self.assertRaisesRegex(RuntimeError, 'Batching rule not implemented'):
+            vmap(torch.Tensor.item)(tensor)
+
+    def test_nonzero_out_dims(self):
+        # Basic test
+        tensor = torch.randn(2, 3)
+        result = vmap(lambda x: x, out_dims=1)(tensor)
+        self.assertEqual(result, tensor.permute(1, 0))
+        self.assertEqual(result.data_ptr(), tensor.data_ptr())
+
+        # Test that the batch dimension gets permuted to dim 2
+        tensor = torch.randn(2, 3, 5, 7)
+        result = vmap(lambda x: x, out_dims=2)(tensor)
+        self.assertEqual(result, tensor.permute(1, 2, 0, 3))
+        self.assertEqual(result.data_ptr(), tensor.data_ptr())
+
+        # negative out_dim
+        tensor = torch.randn(2, 3, 5, 7)
+        result = vmap(lambda x: x, out_dims=-1)(tensor)
+        self.assertEqual(result, tensor.permute(1, 2, 3, 0))
+        self.assertEqual(result.data_ptr(), tensor.data_ptr())
+
+        # check that out_dims works on ALL outputs
+        tensor = torch.randn(2, 3, 5, 7)
+        other = torch.randn(2, 3, 5, 7)
+        result = vmap(lambda x, y: (x, y), out_dims=2)(tensor, other)
+        self.assertEqual(result, (tensor.permute(1, 2, 0, 3), other.permute(1, 2, 0, 3)))
+
+        # use out_dims with the maximum vmap-able tensor dims (64 dims)
+        ndims = 64
+        shape = [2] + [1] * (ndims - 1)
+        expected_shape = [1, 1, 2] + [1] * (ndims - 3)
+        tensor = torch.randn(shape)
+        result = vmap(lambda x: x, out_dims=2)(tensor)
+        self.assertEqual(result.shape, expected_shape)
+
+        # test something that is not the identity function
+        def foo(x, y):
+            return x, x * y, x * y * y
+        x = torch.randn(2, 3, 5)
+        y = torch.randn(2, 3, 5)
+        result = vmap(foo, out_dims=1)(x, y)
+        self.assertEqual(
+            result,
+            (x.permute(1, 0, 2), (x * y).permute(1, 0, 2), (x * y * y).permute(1, 0, 2)))
+
+    def test_multiple_out_dims(self):
+        def foo(x):
+            return x, x
+
+        def bar(x, y):
+            return x, x, x, x * y
+
+        x = torch.randn(2, 3, 5)
+        y = torch.randn(2, 3, 5)
+        result = vmap(foo, out_dims=(0, 1))(x)
+        self.assertEqual(result, (x, x.permute(1, 0, 2)))
+
+        result = vmap(bar, out_dims=(-1, 0, 1, 2))(x, y)
+        expected = (
+            x.permute(1, 2, 0),
+            x,
+            x.permute(1, 0, 2),
+            (x * y).permute(1, 2, 0),
+        )
+        self.assertEqual(result, expected)
+
+    def test_nested_out_dims(self):
+        y = torch.randn(2, 3, 5, 7)
+
+        # Inner vmap has non-zero out_dim
+        result = vmap(lambda y: vmap(lambda x: x, out_dims=1)(y))(y)
+        self.assertEqual(result.shape, (2, 5, 3, 7))
+        self.assertEqual(result, y.permute(0, 2, 1, 3))
+
+        # all vmaps have non-zero out_dim
+        result = vmap(lambda y: vmap(lambda x: x, out_dims=1)(y), out_dims=1)(y)
+        self.assertEqual(result.shape, (5, 2, 3, 7))
+        self.assertEqual(result, y.permute(2, 0, 1, 3))
+
+        # throwing in some negative out_dims
+        result = vmap(lambda y: vmap(lambda x: x, out_dims=-1)(y), out_dims=-1)(y)
+        self.assertEqual(result.shape, (5, 7, 3, 2))
+        self.assertEqual(result, y.permute(2, 3, 1, 0))
+
+        # testing fn that isn't the identity
+        x = torch.randn(2, 3)
+        y = torch.randn(5, 3)
+        result = vmap(lambda y: vmap(lambda x: x * y, out_dims=1)(x), out_dims=-1)(y)
+        self.assertEqual(result.shape, (3, 2, 5))
+        self.assertEqual(result, (y.view(5, 1, 3) * x).permute(2, 1, 0))
+
+    def test_out_dims_edge_case(self):
+        def foo(x):
+            return x
+
+        # Test that we accept out_dims=(1,) for a function with one output.
+        tensor = torch.randn(2, 3)
+        expected = vmap(foo, out_dims=1)(tensor)
+        result = vmap(foo, out_dims=(1,))(tensor)
+        self.assertEqual(result, expected)
+
+    def test_out_dims_must_be_int_or_tuple_of_int_err_msg(self):
+        msg = '`out_dims` must be an int or a tuple of int'
+        tensor = torch.randn(2, 3)
+        with self.assertRaisesRegex(ValueError, msg):
+            vmap(lambda x: x, out_dims='lol')(tensor)
+        with self.assertRaisesRegex(ValueError, msg):
+            vmap(lambda x: x, out_dims=('lol',))(tensor)
+        with self.assertRaisesRegex(ValueError, msg):
+            vmap(lambda x: x, out_dims=None)(tensor)
+        with self.assertRaisesRegex(ValueError, msg):
+            vmap(lambda x: x, out_dims=(None,))(tensor)
+
+    def test_out_dims_and_num_outputs_mismatch_err_msg(self):
+        msg = '`out_dims` must have one dim per output'
+        x = torch.randn(2, 3, 5)
+
+        # Too many out_dims
+        with self.assertRaisesRegex(ValueError, msg):
+            vmap(lambda x: x, out_dims=(0, 0))(x)
+        with self.assertRaisesRegex(ValueError, msg):
+            vmap(lambda x: (x, x, x), out_dims=(0, 0, 0, 0))(x)
+
+        # Too few out_dims
+        with self.assertRaisesRegex(ValueError, msg):
+            vmap(lambda x: (x, x), out_dims=(0,))(x)
+        with self.assertRaisesRegex(ValueError, msg):
+            vmap(lambda x: (x, x, x), out_dims=(0, 0))(x)
+
+    def test_out_dim_out_of_bounds_err_msg(self):
+        # TODO(rzou): This error message isn't that great. It comes straight
+        # from maybe_wrap_dim. Consider doing a try-catch-(add some context) to
+        # the error message in the future in C++
+        msg = 'Dimension out of range'
+        x = torch.randn(2, 3, 5)
+        with self.assertRaisesRegex(IndexError, msg):
+            vmap(lambda x: x, out_dims=3)(x)
+        with self.assertRaisesRegex(IndexError, msg):
+            vmap(lambda x: x, out_dims=-4)(x)
+
+    def test_non_zero_in_dims(self):
+        tensor = torch.randn(2, 3, 5)
+
+        # Implicit out_dims = 0; vmap will move the batch dim to the front.
+        output = vmap(lambda x: x, (1,))(tensor)
+        self.assertEqual(output, tensor.permute(1, 0, 2))
+        self.assertEqual(output.data_ptr(), tensor.data_ptr())
+
+        x = torch.randn(2, 3)
+        y = torch.randn(3, 2)
+        output = vmap(torch.mul, (0, 1))(x, y)
+        self.assertEqual(output, x * y.t())
+        output = vmap(torch.mul, (1, 0))(x, y)
+        self.assertEqual(output, x.t() * y)
+
+    def test_none_in_dims(self):
+        x = torch.randn(2, 3)
+        y = torch.randn(2, 3)
+
+        # None in_dim for a Tensor means we don't map over it
+        output = vmap(torch.mul, (0, None))(x, y)
+        self.assertEqual(output.shape, (2, 2, 3))
+        self.assertEqual(output, x.view(2, 1, 3) * y)
+
+        # None in_dim for non-tensor arguments
+        output = vmap(torch.mul, (0, None))(x, 2)
+        self.assertEqual(output, x * 2)
+
+    def test_nested_non_default_in_dims(self):
+        x = torch.rand(5, 2, 3)
+        y = torch.rand(3, 5, 2)
+        result = vmap(vmap(vmap(torch.mul), (1, 0)), (1, 2))(x, y)
+        self.assertEqual(result, x.permute(1, 2, 0) * y.permute(2, 0, 1))
+
+    def test_non_default_in_dims_out_dims(self):
+        x = torch.randn(2, 3, 5)
+
+        # Same in_dim as out_dim, vmap over identity
+        result = vmap(lambda x: x, in_dims=1, out_dims=1)(x)
+        self.assertEqual(result, x)
+        self.assertEqual(result.data_ptr(), x.data_ptr())
+
+        # Different in_dim from out_dim, vmap over identity
+        result = vmap(lambda x: x, in_dims=2, out_dims=1)(x)
+        self.assertEqual(result.shape, (2, 5, 3))
+        self.assertEqual(result, x.transpose(1, 2))
+        self.assertEqual(result.data_ptr(), x.data_ptr())
+
+        def foo(x):
+            return x * 2
+
+        # Same in_dim as out_dim, vmap over operation
+        result = vmap(foo, in_dims=1, out_dims=1)(x)
+        self.assertEqual(result, x * 2)
+
+        # Different in_dim as out_dim, vmap over operation
+        result = vmap(foo, in_dims=2, out_dims=1)(x)
+        self.assertEqual(result.shape, (2, 5, 3))
+        self.assertEqual(result, (x * 2).transpose(1, 2))
+
+        # Basic nested test.
+        result = vmap(vmap(foo, 1, 1), 1, 1)(x)
+        self.assertEqual(result, x * 2)
+
+    def test_accepts_nested_inputs(self):
+        B0 = 2
+        x = torch.randn(2, 3)
+        y = torch.randn(2, 3)
+
+        # Single layer of nesting
+        out = vmap(lambda z: z[0] + z[1])((x, y))
+        self.assertEqual(out, x + y)
+        out = vmap(lambda z: z[0] + z[1], in_dims=(0,))((x, y))
+        self.assertEqual(out, x + y)
+        out = vmap(lambda z: z[0] + z[1], in_dims=((0, 0),))((x, y))
+        self.assertEqual(out, x + y)
+
+        out = vmap(lambda z: z[0] + z[1])([x, y])
+        self.assertEqual(out, x + y)
+        out = vmap(lambda z: z[0] + z[1], in_dims=(0,))([x, y])
+        self.assertEqual(out, x + y)
+        out = vmap(lambda z: z[0] + z[1], in_dims=([0, 0],))([x, y])
+        self.assertEqual(out, x + y)
+
+        out = vmap(lambda z: z['x'] + z['y'])({'x': x, 'y': y})
+        self.assertEqual(out, x + y)
+        out = vmap(lambda z: z['x'] + z['y'], in_dims=(0,))({'x': x, 'y': y})
+        self.assertEqual(out, x + y)
+        out = vmap(lambda z: z['x'] + z['y'], in_dims=({'x': 0, 'y': 0},))({'x': x, 'y': y})
+        self.assertEqual(out, x + y)
+
+        # Multiple layers of nesting
+        out_fn = vmap(lambda z: z['x'][0] + z['x'][1][0] + z['y'][0] + z['y'][1])
+        out = out_fn({'x': [x, (x,)], 'y': [y, y]})
+        self.assertEqual(out, x + x + y + y)
+
+    def test_in_dims_wrong_type_err_msg(self):
+        x = torch.randn(3)
+        y = torch.randn(3)
+        msg = r'expected `in_dims` to be int or a \(potentially nested\) tuple'
+        with self.assertRaisesRegex(ValueError, msg):
+            vmap(torch.mul, [0, 0])(x, y)
+        with self.assertRaisesRegex(ValueError, msg):
+            vmap(torch.mul, set({0, 0}))(x, y)
+        with self.assertRaisesRegex(ValueError, msg):
+            vmap(torch.mul, 'lol')(x, y)
+        with self.assertRaisesRegex(ValueError, msg):
+            vmap(lambda z: z[0] + z[1], in_dims=[0, 0])([x, y])
+        # The following should not throw
+        vmap(torch.mul, (0, 0))(x, y)
+
+    def test_not_enough_in_dims_err_msg(self):
+        x = torch.randn(3)
+        y = torch.randn(3)
+        msg = r'in_dims is not compatible with the structure of `inputs`'
+
+        with self.assertRaisesRegex(ValueError, msg):
+            vmap(torch.mul, (0,))(x, y)
+        with self.assertRaisesRegex(ValueError, msg):
+            vmap(torch.mul, (0, 0, 0))(x, y)
+        with self.assertRaisesRegex(ValueError, msg):
+            vmap(lambda z: z[0] + z[1], in_dims=([0],))([x, y])
+        with self.assertRaisesRegex(ValueError, msg):
+            vmap(lambda z: z[0] + z[1], in_dims=((0, 0),))([x, y])
+        # The following should not throw
+        vmap(torch.mul, (0, 0))(x, y)
+
+    def test_integer_in_dim_but_not_tensor_input_err_msg(self):
+        def foo(xy):
+            return xy[0] * xy[1]
+
+        def bar(x, yz):
+            return x * yz[0] * yz[1]
+
+        x = torch.randn(2, 3)
+        y = torch.randn(2, 3)
+
+        # the following are errors in jax (and will always be errors)
+        msg = 'Got in_dim=0 for an input but the input is of type'
+        with self.assertRaisesRegex(ValueError, msg):
+            vmap(torch.sum)(x, 0)
+        with self.assertRaisesRegex(ValueError, msg):
+            vmap(torch.sum, (0, 0))(x, 0)
+        with self.assertRaisesRegex(ValueError, msg):
+            vmap(lambda z: z[0] + z[1], in_dims=([0, 0],))([x, 1])
+        # The following should not throw
+        vmap(torch.sum, (0, None))(x, 0)
+
+    def test_in_dim_not_in_tensor_err_msg(self):
+        def foo(x):
+            return x * x
+
+        x = torch.randn(2, 3)
+        y = torch.randn(2, 3)
+
+        msg = r'Got in_dim=-?\w for some input, but that input is a Tensor of dimensionality \w'
+        with self.assertRaisesRegex(ValueError, msg):
+            vmap(foo)(torch.randn([]))
+        with self.assertRaisesRegex(ValueError, msg):
+            vmap(foo, in_dims=(0,))(torch.randn([]))
+        with self.assertRaisesRegex(ValueError, msg):
+            vmap(foo, in_dims=(-1,))(x)
+        with self.assertRaisesRegex(ValueError, msg):
+            vmap(foo, in_dims=(2,))(y)
+        with self.assertRaisesRegex(ValueError, msg):
+            vmap(lambda z: z[0] + z[1], in_dims=([3, 0],))([x, y])
+        # the following should not throw
+        vmap(foo, in_dims=(0,))(torch.randn(2, 3))
+        vmap(foo, in_dims=(1,))(torch.randn(2, 3))
+
+    def test_fallback_does_not_warn_by_default(self):
+        # NB: One day we will implement a batching rule for torch.atan2.
+        # If/when we do, this test should be replaced to test the fallback
+        # path on another operator to avoid bitrot.
+        op = torch.atan2
+        x = torch.randn(11)
+        y = torch.randn(11)
+        with warnings.catch_warnings(record=True) as wa:
+            result = vmap(op)(x, y)
+            # The single warning here is the "vmap is experimental"
+            # warning, not a warning from the vmap fallback path.
+            self.assertEqual(len(wa), 1)
+
+    def test_fallback_warns_when_warnings_are_enabled(self):
+        # NB: One day we will implement a batching rule for torch.atan2.
+        # If/when we do, this test should be replaced to test the fallback
+        # path on another operator to avoid bitrot.
+        op = torch.atan2
+        x = torch.randn(11)
+        y = torch.randn(11)
+        with warnings.catch_warnings(record=True) as wa:
+            with EnableVmapFallbackWarnings():
+                result = vmap(op)(x, y)
+            self.assertEqual(len(wa), 2)
+            self.assertRegex(str(wa[-1].message), FALLBACK_REGEX)
+
+    def _assert_uses_vmap_fallback(self, vmap_args, inputs):
+        # NB: This check is currently disabled by the early return below;
+        # the code after it is kept for reference.
+        return
+        with warnings.catch_warnings(record=True) as wa:
+            with EnableVmapFallbackWarnings():
+                result = vmap(*vmap_args)(*inputs)
+            self.assertEqual(len(wa), 2)
+            self.assertRegex(str(wa[-1].message), FALLBACK_REGEX)
+
+    def test_fallback_zero_dim(self):
+        # NB: One day we will implement a batching rule for torch.atan2.
+        # If/when we do, this test should be replaced to test the fallback
+        # path on another operator to avoid bitrot.
+        op = torch.atan2
+        x = torch.randn(11)
+        y = torch.randn(11)
+        self._assert_uses_vmap_fallback((op,), (x, y))
+
+        B0, B1 = 0, 3
+        x = torch.randn(B0, 11)
+        y = torch.randn(11)
+
+        msg = 'The fallback path does not support vmap over dims of size 0'
+
+        with self.assertRaisesRegex(RuntimeError, msg):
+            vmap(op, (0, None))(x, y)
+        with self.assertRaisesRegex(RuntimeError, msg):
+            vmap(op, (None, 0))(y, x)
+        with self.assertRaisesRegex(RuntimeError, msg):
+            vmap(op)(x, x)
+
+        x = torch.randn(B0, B1, 11)
+        y = torch.randn(B1, 11)
+        with self.assertRaisesRegex(RuntimeError, msg):
+            vmap(op, (0, None))(x, y)
+        with self.assertRaisesRegex(RuntimeError, msg):
+            vmap(op, (None, 0))(y, x)
+        with self.assertRaisesRegex(RuntimeError, msg):
+            vmap(op)(x, x)
+
+    def test_fallback_atan2(self):
+        # NB: One day we will implement a batching rule for torch.atan2.
+        # If/when we do, this test should be replaced to test the fallback
+        # path on another operator to avoid bitrot.
+        op = torch.atan2
+
+        x = torch.randn(5, 7, 11)
+        y = torch.randn(5, 7, 11)
+
+        self._assert_uses_vmap_fallback((op,), (x, y))
+
+        # fallback on torch.atan2
+        x = torch.randn(7, 11, 5)
+        y = torch.randn(5, 7, 11)
+        result = vmap(op, (2, 0))(x, y)
+        self.assertEqual(result, op(x.permute(2, 0, 1), y))
+
+        # fallback on torch.atan2, nested vmap
+        x = torch.randn(7, 11, 5)
+        y = torch.randn(5, 7, 11)
+        result = vmap(vmap(op), (2, 0))(x, y)
+        self.assertEqual(result, op(x.permute(2, 0, 1), y))
+
+        # big batch size (total 10000)
+        x = torch.randn(100, 10, 10, 5)
+        y = torch.randn(100, 10, 10)
+        result = vmap(vmap(vmap(op)))(x, y)
+        self.assertEqual(result, op(x, y.view(100, 10, 10, 1)))
+
+    def test_fallback_masked_fill(self):
+        # NB: One day we will implement a batching rule for masked_fill
+        # If/when we do, this test should be replaced to test the fallback
+        # path on another operator to avoid bitrot.
+        def run_test(batch_size):
+            B0 = batch_size
+            x = torch.randn(B0, 7, 11, 13)
+            dim = 0
+            index = torch.tensor([0, 4, 2])
+            values = torch.randn(B0, 3, 13)
+
+            self._assert_uses_vmap_fallback((torch.index_add, (0, None, None, 0)), (x, dim, index, values))
+
+            result = vmap(torch.index_add, (0, None, None, 0))(x, dim, index, values)
+            expected = torch.index_add(
+                x, dim + 1, index, values.view(B0, 3, 1, 13))
+            self.assertEqual(result, expected)
+
+        run_test(batch_size=5)
+        run_test(batch_size=1237)
+
+    def test_fallback_multiple_returns(self):
+        # NB: One day we will implement a batching rule for torch.var_mean
+        # If/when we do, this test should be replaced to test the fallback
+        # path on another operator to avoid bitrot.
+        B0, B1, B2 = 2, 3, 1237
+        tensor = torch.randn(B0, 10)
+
+        self._assert_uses_vmap_fallback((torch.var_mean,), (tensor,))
+
+        # fallback correctness on torch.var_mean
+        result = vmap(torch.var_mean)(tensor)
+        expected = torch.var_mean(tensor, dim=1)
+        self.assertEqual(result, expected)
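+        # (Illustrative note: inside vmap each per-example call reduces over
+        # the whole 10-element slice, which matches
+        # torch.var_mean(tensor, dim=1) on the batched tensor.)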
+
+        # nested vmap
+        tensor = torch.randn(B0, B1, 10)
+        result = vmap(vmap(torch.var_mean))(tensor)
+        expected = torch.var_mean(tensor, dim=2)
+        self.assertEqual(result, expected)
+
+        # big batch size, nested vmap
+        tensor = torch.randn(B0, B1, B2, 10)
+        result = vmap(vmap(vmap(torch.var_mean)))(tensor)
+        expected = torch.var_mean(tensor, dim=3)
+        self.assertEqual(result, expected)
+
+    def test_inplace_fallback_unary(self):
+        # Test the in-place fallback on an in-place method that takes no
+        # additional Tensor arguments. This is the simplest case of the fallback.
+        # NB: One day we will implement a batching rule for acos_.
+        # If/when we do, this test should be replaced to test the fallback
+        # path on another operator to avoid bitrot.
+        op = Tensor.acos_
+        B0, B1, B2 = 2, 3, 10000
+
+        x = torch.randn(B0, 5)
+        self._assert_uses_vmap_fallback((op,), (x,))
+
+        # Single vmap
+        x_orig = torch.rand(B0, 5)
+        x = x_orig.clone()
+        result = vmap(op)(x)
+        self.assertTrue(result is x)
+        self.assertEqual(result, x_orig.acos())
+
+        # Single vmap + different out_dim produces a view(!)
+        x_orig = torch.rand(B0, 5)
+        x = x_orig.clone()
+        result = vmap(op, out_dims=(1,))(x)
+        self.assertTrue(result._base is x)
+        self.assertEqual(result, x_orig.t().acos())
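+        # (Illustrative note: with out_dims=(1,) the batch dim is moved from
+        # position 0 to position 1, so the returned tensor is a transposed
+        # view of x rather than x itself; hence result._base is x.)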
+
+        # Nested vmap
+        x_orig = torch.randn(B0, B1, 5)
+        x = x_orig.clone()
+        result = vmap(vmap(op))(x)
+        self.assertTrue(result is x)
+        self.assertEqual(result, x_orig.acos())
+
+        # Nested vmap, large batch size
+        x_orig = torch.randn(B0, B1, B2, 5)
+        x = x_orig.clone()
+        result = vmap(vmap(vmap(op)))(x)
+        self.assertTrue(result is x)
+        self.assertEqual(result, x_orig.acos())
+
+    def test_inplace_fallback_nary_same_levels(self):
+        # NB: One day we will implement a batching rule for atan2_
+        # If/when we do, this test should be replaced to test the fallback
+        # path on another operator to avoid bitrot.
+        op = Tensor.atan2_
+        outplace_op = torch.atan2
+
+        x = torch.randn(5, 7, 11)
+        y = torch.randn(5, 7, 11)
+        self._assert_uses_vmap_fallback((op,), (x, y))
+
+        # Single vmap
+        B0 = 5
+        x_orig = torch.randn(7, 11, B0)
+        x = x_orig.clone()
+        y = torch.randn(B0, 7, 11)
+        vmap(op, (2, 0))(x, y)
+        self.assertEqual(x, outplace_op(x_orig, y.movedim(0, 2)))
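+        # (Illustrative note: x's batch dim lives at position 2, so the
+        # expected value aligns y by moving its batch dim from 0 to 2 before
+        # applying the out-of-place atan2.)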
+
+        # Nested vmap
+        B0, B1 = 5, 7
+        x_orig = torch.randn(B1, 11, B0)
+        x = x_orig.clone()
+        y = torch.randn(B0, B1, 11)
+        vmap(vmap(op), (2, 0))(x, y)
+        self.assertEqual(x, outplace_op(x_orig, y.movedim([0, 1], [2, 0])))
+
+        # big batch size (total 10000)
+        B0, B1, B2 = 100, 10, 10
+        x_orig = torch.randn(B0, B1, B2, 5)
+        x = x_orig.clone()
+        y = torch.randn(B0, B1, B2)
+        result = vmap(vmap(vmap(op)))(x, y)
+        self.assertEqual(x, outplace_op(x_orig, y.view(B0, B1, B2, 1)))
+
+    # Fallback isInplaceVmapCompatible check is broken
+    @unittest.expectedFailure
+    def test_inplace_fallback_nary_different_levels(self):
+        # NB: One day we will implement a batching rule for atan2_
+        # If/when we do, this test should be replaced to test the fallback
+        # path on another operator to avoid bitrot.
+        op = Tensor.atan2_
+        outplace_op = torch.atan2
+        B0, B1, B2 = 2, 3, 5
+
+        x = torch.rand(B0, 7)
+        y = torch.rand(7)
+        self._assert_uses_vmap_fallback((op, (0, None)), (x, y))
+
+        # op(left, right): All of the levels in right are found in left
+        x_orig = torch.rand(B0, 7)
+        x = x_orig.clone()
+        y = torch.rand(7)
+        vmap(op, in_dims=(0, None))(x, y)
+        self.assertEqual(x, outplace_op(x_orig, y))
+
+        x_orig = torch.rand(B0, B1, 7)
+        x = x_orig.clone()
+        y = torch.rand(B0, 7)
+        vmap(vmap(op, in_dims=(0, None)))(x, y)
+        self.assertEqual(x, outplace_op(x_orig, y.view(B0, 1, 7)))
+
+        # op(left, right): Some of the levels in right are not found in left
+        msg = r'vmap: aten::atan2_\(self, \*extra_args\) is not possible'
+        x = torch.rand(7)
+        y = torch.rand(B0, 7)
+        with self.assertRaisesRegex(RuntimeError, msg):
+            vmap(op, in_dims=(None, 0))(x, y)
+
+        x = torch.rand(B1, 7)
+        y = torch.rand(B0, 7)
+        with self.assertRaisesRegex(RuntimeError, msg):
+            vmap(vmap(op, in_dims=(0, None)), in_dims=(None, 0))(x, y)
+
+        x = torch.rand(B1, 7)
+        y = torch.rand(7, B0)
+        with self.assertRaisesRegex(RuntimeError, msg):
+            vmap(vmap(op, in_dims=(0, None)), in_dims=(None, 1))(x, y)
+
+        x = torch.rand(B0, 7)
+        y = torch.rand(B0, B1, 7)
+        with self.assertRaisesRegex(RuntimeError, msg):
+            vmap(vmap(op, in_dims=(None, 0)))(x, y)
+
+    def test_backward_unsupported_interaction(self):
+        x = torch.randn(3, requires_grad=True)
+        y = torch.randn(5)
+        grad = torch.randn_like(x)
+        err_msg = r'backward\(\) called inside torch.vmap'
+
+        def backward_on_vmapped_tensor(x):
+            x.sum().backward()
+
+        with self.assertRaisesRegex(RuntimeError, err_msg):
+            vmap(backward_on_vmapped_tensor)(x)
+
+        def backward_with_vmapped_grad(x, grad):
+            x.backward(grad)
+
+        with self.assertRaisesRegex(RuntimeError, err_msg):
+            vmap(backward_with_vmapped_grad)(x, grad)
+
+        def completely_unrelated_backward(y):
+            x.sum().backward()
+
+        with self.assertRaisesRegex(RuntimeError, err_msg):
+            vmap(completely_unrelated_backward)(y)
+
+    def test_grad_unsupported_interaction(self):
+        input_tensor = torch.randn(3, requires_grad=True)
+        err_msg = 'autograd.grad.* called inside torch.vmap'
+
+        captured = torch.randn(3, requires_grad=True)
+
+        def output_to_grad_is_vmapped(input_tensor):
+            output = (captured * input_tensor).sum()
+            return torch.autograd.grad([output], [captured])[0]
+
+        with self.assertRaisesRegex(RuntimeError, err_msg):
+            vmap(output_to_grad_is_vmapped)(input_tensor)
+
+        output = (input_tensor ** 2).sum()
+
+        def input_to_grad_is_vmapped(input_tensor):
+            return torch.autograd.grad([output], [input_tensor])[0]
+
+        with self.assertRaisesRegex(RuntimeError, err_msg):
+            vmap(input_to_grad_is_vmapped)(input_tensor)
+
+    def test_batched_gradient_basic(self):
+        N = 3
+        x = torch.randn(N, requires_grad=True)
+        y = torch.randn(N)
+
+        def vjp_mul(v):
+            return torch.autograd.grad([x * y], [x], grad_outputs=[v])[0]
+
+        batched_v = torch.eye(N)
+        jacobian = vmap(vjp_mul)(batched_v)
+        self.assertEqual(jacobian, torch.diagflat(y))
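+        # (Illustrative note: vmapping the VJP over the rows of the identity
+        # matrix computes one row of the Jacobian per basis vector; for the
+        # elementwise product x * y, the Jacobian with respect to x is
+        # diag(y).)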
+
+    def test_functools_partial(self):
+        x = torch.randn(3)
+        y = torch.randn(2, 3)
+        result = vmap(functools.partial(torch.mul, x))(y)
+        self.assertEqual(result, x * y)
+
+    def test_nn_module(self):
+        tensor = torch.randn(2, 3)
+        model = torch.nn.Linear(3, 3, bias=False)
+        result = vmap(model)(tensor)
+        self.assertEqual(result, model(tensor))
+
+    def test_fallback_with_undefined_grad(self):
+        B0 = 7
+        x = torch.randn(2, 3, 4, 5, requires_grad=True)
+        weight = torch.randn(3, 3, 1, 1)
+        v = torch.randn(B0, 2, 3, 4, 5)
+
+        def get_vjp(v):
+            result = torch.nn.functional.conv2d(x, weight)
+            grad_x, = torch.autograd.grad(result, x, v)
+            return grad_x
+
+        # Runs vmap(get_vjp)(v), which should not error out.
+        # The backward formula for convolution returns an undefined
+        # Tensor for grad_bias because the original bias does not exist.
+        #
+        # In the future we'll probably add a batching rule for convolution
+        # backward. When this happens, we should modify this test to use a
+        # different op (and/or create and use a dummy operator) to avoid bitrot.
+        self._assert_uses_vmap_fallback([get_vjp], [v])
+
+def slice_inputs(inputs, bdims, i):
+    result = []
+    for inp, bdim in zip(inputs, bdims):
+        if bdim is None:
+            result.append(inp)
+        else:
+            result.append(inp.select(bdim, i))
+    return tuple(result)
+
+
+def reference_vmap(op, inputs, in_dims=0, out_dims=0):
+    if isinstance(in_dims, int):
+        in_dims = (in_dims,) * len(inputs)
+    bdim_sizes = [inp.size(dim) for inp, dim in zip(inputs, in_dims) if dim is not None]
+    assert all(bdim_size == bdim_sizes[0] for bdim_size in bdim_sizes)
+    bdim_size = bdim_sizes[0]
+    results = tuple(op(*slice_inputs(inputs, in_dims, i)) for i in range(bdim_size))
+
+    assert len(results) > 0
+    op_has_single_return = not isinstance(results[0], tuple)
+    if op_has_single_return:
+        assert all(isinstance(result, torch.Tensor) for result in results)
+        if isinstance(out_dims, int):
+            out_dims = (out_dims,)
+        return torch.stack(results, dim=out_dims[0])
+
+    assert all(isinstance(result, tuple) for result in results)
+    num_returns = len(results[0])
+    assert all(len(result) == num_returns for result in results)
+    if isinstance(out_dims, int):
+        out_dims = (out_dims,) * num_returns
+    return tuple(torch.stack(result_shards, out_dim)
+                 for result_shards, out_dim in zip(zip(*results), out_dims))
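+
+# Illustrative sketch (not used by the tests): for op=torch.mul and
+# inputs=(x, y) with x.shape == y.shape == (B, 3) and in_dims=0,
+# reference_vmap stacks op(x[i], y[i]) for i in range(B) along dim 0,
+# which is exactly what vmap(op)(x, y) should produce.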
+
+
+class TensorFactory:
+    @staticmethod
+    def rand(size, device='cpu', dtype=torch.float):
+        return torch.rand(size, device=device, dtype=dtype)
+
+    @staticmethod
+    def randn(size, device='cpu', dtype=torch.float):
+        return torch.randn(size, device=device, dtype=dtype)
+
+    @staticmethod
+    def randp1(size, device='cpu', dtype=torch.float):
+        return torch.rand(size, device=device, dtype=dtype) + 1
+
+# Tests vmap(op, in_dims, out_dims)(*inputs) by comparing the output to a
+# (slow) sequential map+stack fallback.
+#
+# check_view: Test that all returned outputs are views of the first input.
+# check_propagates_grad: Test that the operation propagates the requires_grad flag.
+def _vmap_test(self, op, inputs, in_dims=0, out_dims=0,
+               check_view=False, check_propagates_grad=True):
+    result = vmap(op, in_dims, out_dims)(*inputs)
+    reference_result = reference_vmap(op, inputs, in_dims, out_dims)
+    self.assertEqual(result, reference_result)
+    op_has_single_return = not isinstance(result, tuple)
+
+    if check_view:
+        result_as_tuple = (result,) if op_has_single_return else result
+        for output in result_as_tuple:
+            input0_base = inputs[0] if inputs[0]._base is None else inputs[0]._base
+            self.assertTrue(output._base is input0_base,
+                            msg="result was not a view of the first input!")
+
+    if not check_propagates_grad:
+        return
+    # Assuming inputs[0] is a floating-point tensor, check that the vmap
+    # operation propagates the requires_grad flag to the zeroth output.
+    # Some vmap operators are implemented in a way that assumes they are
+    # composite with respect to autograd. If an operator is ever changed to
+    # not be composite with respect to autograd, the following check should
+    # fail.
+    inputs_clone = list(inputs)
+    inputs_clone[0] = inputs[0].clone().requires_grad_()
+    result = vmap(op, in_dims, out_dims)(*inputs_clone)
+    result_as_tuple = (result,) if op_has_single_return else result
+    self.assertTrue(result_as_tuple[0].requires_grad)
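+
+# Illustrative usage (assuming a TestCase instance `self`):
+#     _vmap_test(self, torch.mul, (torch.randn(5, 3), torch.randn(5, 3)))
+# compares vmap(torch.mul) against the sequential reference above.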
+
+def should_allow_vmap_fallback_usage(fn):
+    return getattr(fn, '_allow_vmap_fallback_usage', False)
+
+def allowVmapFallbackUsage(fn):
+    fn._allow_vmap_fallback_usage = True
+    return fn
+
+# All tests of TestVmapBase check that the slow vmap fallback is never invoked.
+# This is so that we can incrementally add batching rules for operators to
+# replace the slow vmap fallback path for said operators. To skip this check,
+# please use the allowVmapFallbackUsage decorator.
+#
+# NB: Don't add tests to TestVmapBase directly, unless you want them to run
+# on every subclass of TestVmapBase. Add them to e.g. TestVmapOperators.
+#
+# NB: TestVmapBase is a nested class. This prevents test runners from picking
+# it up and running it.
+class Namespace:
+    class TestVmapBase(TestCase):
+        def __init__(self, method_name='runTest'):
+            super().__init__(method_name)
+
+            test_method = getattr(self, method_name, None)
+            if test_method is None:
+                return
+
+            if not should_allow_vmap_fallback_usage(test_method):
+                setattr(self, method_name,
+                        self._wrap_method_with_vmap_fallback_check(test_method))
+
+        def _wrap_method_with_vmap_fallback_check(self, method):
+            msg = (
+                'Expected the test to not invoke the vmap fallback path, i.e., '
+                'all of the operators being tested in this test should have batching '
+                'rules implemented. If you are intentionally testing something to '
+                'do with the fallback path, use allowVmapFallbackUsage. Otherwise, '
+                'please make sure that batching rules are implemented for the '
+                'operator(s) being tested.'
+            )
+
+            @functools.wraps(method)
+            def wrapper(self, *args, **kwargs):
+                with warnings.catch_warnings(record=True) as wa:
+                    warnings.simplefilter('always')
+                    with EnableVmapFallbackWarnings():
+                        method(*args, **kwargs)
+                    # NB: The fallback check is currently disabled. When it is
+                    # re-enabled, it asserts that no captured warning matches
+                    # FALLBACK_REGEX:
+                    # for captured_warning in wa:
+                    #     self.assertNotRegex(str(captured_warning.message), FALLBACK_REGEX, msg)
+            return types.MethodType(wrapper, self)
+
+        @allowVmapFallbackUsage
+        def test_vmap_fallback_check_ok(self):
+            # One day we'll implement a batching rule for torch.var_mean.
+            # When that happens, please change the example to use an
+            # operator that doesn't have a batching rule implemented.
+            op_using_fallback = torch.var_mean
+            vmap(op_using_fallback)(torch.rand(3))
+
+        @unittest.expectedFailure
+        def test_vmap_fallback_check(self):
+            @self._wrap_method_with_vmap_fallback_check
+            def no_fallback(self):
+                pass
+
+            # One day we'll implement a batching rule for torch.var_mean.
+            # When that happens, please change the example to use an
+            # operator that doesn't have a batching rule implemented.
+            op_using_fallback = torch.var_mean
+
+            @self._wrap_method_with_vmap_fallback_check
+            def uses_fallback(self):
+                vmap(op_using_fallback)(torch.rand(3))
+
+            no_fallback(self)
+
+            with self.assertRaises(AssertionError):
+                uses_fallback(self)
+
+
+class TestVmapOperators(Namespace.TestVmapBase):
+    def _vmap_test(self, *args, **kwargs):
+        return _vmap_test(self, *args, **kwargs)
+
+    def _vmap_view_test(self, *args, **kwargs):
+        self._vmap_test(*args, **kwargs, check_view=True)
+
+    def _test_unary(self, op, getter, device, *args, **kwargs):
+        test = functools.partial(self._vmap_test, *args, **kwargs)
+        B0, B1 = 7, 11
+
+        # Single vmap, various in_dims / out_dims
+        test(op, [getter([B0, 3], device)])
+        test(op, [getter([2, 5, B0, 3], device)], in_dims=2)
+        test(op, [getter([2, 5, B0, 3], device)], in_dims=2, out_dims=2)
+
+        # Doubly nested vmap
+        test(vmap(op), [getter([B0, B1], device)])
+        test(vmap(op), [getter([B1, 2, 5, B0, 3], device)], in_dims=2)
+        test(vmap(op, in_dims=2), [getter([2, 5, B0, B1, 3], device)],
+             in_dims=2, out_dims=2)
+
+    def test_unary_pointwise_ops(self):
+        cases = [
+            (torch.abs, TensorFactory.randn),
+            (torch.acos, TensorFactory.rand),
+            (torch.asin, TensorFactory.rand),
+            (torch.atan, TensorFactory.rand),
+            (torch.ceil, TensorFactory.randn),
+            (torch.cos, TensorFactory.rand),
+            (torch.cosh, TensorFactory.rand),
+            (torch.digamma, TensorFactory.rand),
+            (torch.exp, TensorFactory.randn),
+            (torch.expm1, TensorFactory.randn),
+            (torch.floor, TensorFactory.randn),
+            (torch.frac, TensorFactory.randn),
+            (torch.lgamma, TensorFactory.rand),
+            (torch.log, TensorFactory.randp1),
+            (torch.log10, TensorFactory.randp1),
+            (torch.log1p, TensorFactory.randp1),
+            (torch.log2, TensorFactory.randp1),
+            (torch.neg, TensorFactory.randn),
+            (torch.reciprocal, TensorFactory.randp1),
+            (torch.relu, TensorFactory.randn),
+            (torch.round, TensorFactory.randn),
+            (torch.rsqrt, TensorFactory.randp1),
+            (torch.sigmoid, TensorFactory.randn),
+            (torch.sign, TensorFactory.randn),
+            (torch.sin, TensorFactory.rand),
+            (torch.sinh, TensorFactory.rand),
+            (torch.sqrt, TensorFactory.rand),
+            (torch.tan, TensorFactory.rand),
+            (torch.tanh, TensorFactory.rand),
+            (torch.trunc, TensorFactory.randn),
+        ]
+        for op, getter in cases:
+            self._test_unary(op, getter, 'cpu')
+
+    def test_clone(self):
+        # Some basic tests
+        self._test_unary(lambda x: x.clone(), TensorFactory.randn, 'cpu')
+        self._test_unary(lambda x: x.clone(memory_format=torch.preserve_format),
+                         TensorFactory.randn, 'cpu')
+        self._test_unary(lambda x: x.clone(memory_format=torch.contiguous_format),
+                         TensorFactory.randn, 'cpu')
+
+        # Test that the per-example outputs are contiguous when using torch.contiguous_format
+        def clone_contiguous(x):
+            return x.clone(memory_format=torch.contiguous_format)
+
+        B0, B1 = 3, 5
+        x = torch.randn(2, B0, 7)
+        y = vmap(clone_contiguous, in_dims=1, out_dims=1)(x)
+        self.assertTrue(y.movedim(1, 0).is_contiguous())
+        self.assertTrue(y[:, 0, :].is_contiguous())
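+        # (Illustrative note: the per-example clones are laid out batch-first
+        # in memory, so moving the batch dim back to the front yields a
+        # contiguous tensor, and each per-example slice y[:, i, :] is
+        # contiguous as well.)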
+
+        x = torch.randn(2, B0, 7, B1)
+        y = vmap(vmap(clone_contiguous, in_dims=2), in_dims=1)(x)
+        self.assertTrue(y.is_contiguous())
+        self.assertTrue(y[0][0].is_contiguous())
+
+        msg = r'only supported with memory_format torch.preserve_format or torch.contiguous_format'
+        with self.assertRaisesRegex(RuntimeError, msg):
+            vmap(lambda x: x.clone(memory_format=torch.channels_last))(torch.randn(B0))
+        with self.assertRaisesRegex(RuntimeError, msg):
+            vmap(lambda x: x.clone(memory_format=torch.channels_last_3d))(torch.randn(B0))
+
+    def test_binary_pointwise_ops(self):
+        def get_number(getter):
+            return getter([]).item()
+
+        def make_case(op, input_getter=TensorFactory.randn):
+            return (op, input_getter)
+
+        cases = [
+            # Basic arithmetic
+            make_case(torch.add),
+            make_case(lambda x, y: x + y),
+            make_case(torch.sub),
+            make_case(lambda x, y: x - y),
+            make_case(torch.mul),
+            make_case(lambda x, y: x * y),
+            make_case(torch.div, input_getter=TensorFactory.randp1),
+            make_case(lambda x, y: x / y, input_getter=TensorFactory.randp1),
+            make_case(torch.pow, input_getter=TensorFactory.randp1),
+            make_case(lambda x, y: x ** y, input_getter=TensorFactory.randp1),
+        ]
+        test = self._vmap_test
+
+        for op, getter in cases:
+            device = 'cpu'
+            B0, B1 = 7, 11
+
+            # Single vmap: op(Tensor, Tensor)
+            test(op, (getter([B0, 3], device), getter([B0, 3], device)))
+            test(op, (getter([B0], device), getter([B0, 2, 3], device)))
+            test(op, (getter([B0], device), getter([2, B0, 3], device)), in_dims=(0, 1))
+            test(op, (getter([B0], device), getter([2, B0, 3], device)),
+                 in_dims=(0, 1), out_dims=1)
+            test(op, (getter([B0], device), getter([2, 3], device)), in_dims=(0, None))
+            test(op, (getter([2, 3], device), getter([B0, 3], device)), in_dims=(0, None))
+
+            # Nested vmap: op(Tensor, Tensor)
+            test(vmap(op), (getter([B0, B1, 2, 3], device), getter([B0, B1, 3], device)))
+            test(vmap(op, in_dims=(None, 0)),
+                 (getter([B0, 2, 3], device), getter([B1, 3], device)), in_dims=(0, None))
+
+            # Python number overload: op(Tensor, Number) (and vice-versa)
+            number = get_number(getter)
+            self._test_unary(lambda t: op(t, number), getter, device)
+            number = get_number(getter)
+            self._test_unary(lambda t: op(number, t), getter, device)
+
+            # Type promotion: op(Logical Scalar Tensor, Logical Scalar Tensor)
+            test(op, (getter([B0], device), getter([B0], device, dtype=torch.double)))
+            test(op, (getter([B0], device, dtype=torch.double), getter([B0], device)))
+            test(op, (getter([B0], device), getter([B0], device)))
+
+            # Type promotion: op(Tensor, Logical Scalar Tensor) (and vice-versa)
+            test(op, (getter([B0, 2], device), getter([B0], device, torch.double)))
+            test(op, (getter([B0], device, torch.double), getter([B0, 2], device)))
+
+            if not torch.cuda.is_available():
+                continue
+
+            # TODO(rzou): fix the following
+            # # Test cross-device scalars
+            # number = get_number(getter)
+            # self._test_unary(lambda t: op(t, number), getter, device='cuda')
+            # self._test_unary(lambda t: op(number, t), getter, device='cuda')
+            # self._test_unary(lambda t: op(t, torch.tensor(number)), getter, device='cuda')
+
+    # TODO: as_strided BR
+    @unittest.expectedFailure
+    def test_as_strided(self):
+        def _test(sizes, strides, offset, tensor, lambd):
+            result = vmap(lambda t: t.as_strided(sizes, strides, offset))(tensor)
+            expected = vmap(lambd)(tensor)
+            self.assertTrue(result._base is expected._base)
+            self.assertEqual(result, expected)
+
+        # single vmap test
+        B0 = 5
+        tensors = [
+            # contiguous
+            torch.randn(B0, 2, 3),
+            # non-contiguous
+            torch.randn(B0, 3, 2).transpose(1, 2),
+            # non-zero storage offset
+            torch.randn(2, B0, 2, 3)[1],
+            # non-contiguous strides, zero storage offset
+            torch.randn(B0, 2, 4, 3, 7)[:, :, 0, :, 0],
+            # non-contiguous strides, non-zero storage offset
+            torch.randn(B0, 2, 4, 3, 7)[:, :, 2, :, 1],
+        ]
+
+        for x in tensors:
+            S0, S1 = x.stride()[1:]
+            offset = x.storage_offset()
+
+            # Broadcast
+            _test([5, 5, 2, 3], [0, 0, S0, S1], offset, x, lambda x: x.expand(5, 5, 2, 3))
+            # transpose
+            _test([3, 2], [S1, S0], offset, x, lambda x: x.transpose(0, 1))
+            # select
+            _test([2], [S0], offset + S1, x, lambda x: x[:, 1])
+
+        # Nested vmap test
+        B1 = 7
+        x = torch.randn(B1, B0, 2, 3)
+        S0, S1 = x.stride()[2:]
+        result = vmap(vmap(lambda t: t.as_strided([5, 5, 2, 3], [0, 0, S0, S1])), in_dims=1)(x)
+        expected = vmap(vmap(lambda t: t.expand(5, 5, 2, 3)), in_dims=1)(x)
+        self.assertTrue(result._base is expected._base)
+        self.assertEqual(result, expected)
+
+        # Check that malformed sizes/strides don't crash
+        with self.assertRaisesRegex(RuntimeError, 'size and stride must have the same length'):
+            x = torch.randn(B0, 2, 3).transpose(0, 1)
+            vmap(lambda x: x.as_strided([1, 1, 1], [1, 1]))(x)
+
+        # Sanity check #1: we require the batch dims to be at the front of the
+        # tensor (in memory layout).
+        msg = 'batch dims being vmapped over are at the front of the tensor'
+        with self.assertRaisesRegex(RuntimeError, msg):
+            x = torch.randn(2, B0, 3).transpose(0, 1)
+            vmap(lambda x: x.as_strided([2, 3], [B0 * 3, 1]))(x)
+        with self.assertRaisesRegex(RuntimeError, msg):
+            x = torch.randn(B0, 2, 3, B1).movedim(3, 1)
+            vmap(vmap(lambda x: x.as_strided([2, 3], [B1 * 3, B1])))(x)
+
+        # All the Sanity check #2{a,b,c} cases check that
+        # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
+        # doesn't index memory that is out of bounds of xs[i]. This condition
+        # is important to the correctness of the as_strided batching rule
+        # (see NOTE: [When will the as_strided_batching_rule fail?])
+
+        # Sanity check #2a: The maximum indexable location of
+        # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
+        # is less than or equal to the maximum indexable location of xs[i].
+        msg = 'This is not supported inside of vmap'
+        with self.assertRaisesRegex(RuntimeError, msg):
+            x = torch.randn(B0, 3)
+            vmap(lambda x: x.as_strided([3], [1], 1))(x)
+        with self.assertRaisesRegex(RuntimeError, msg):
+            x = torch.randn(B0, 3, 5)
+            vmap(lambda x: x.as_strided([4, 4], [4, 1], 0))(x)
+        with self.assertRaisesRegex(RuntimeError, msg):
+            x = torch.randn(B0, B1, 3, 5)
+            vmap(vmap(lambda x: x.as_strided([4, 4], [4, 1], 0)))(x)
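+        # (Illustrative arithmetic for the first case above: each xs[i] spans
+        # relative storage offsets 0..2, but as_strided([3], [1], 1) would read
+        # offsets 1..3, and offset 3 already belongs to the next per-example
+        # slice, or lies past the end of the buffer for the last one.)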
+
+        # Sanity check #2b: The min indexable location of
+        # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
+        # is greater than or equal to the min indexable location of xs[i].
+        with self.assertRaisesRegex(RuntimeError, msg):
+            x = torch.randn(2, B0, 3)[1]
+            vmap(lambda x: x.as_strided([3], [1], B0 * 3 - 1))(x)
+
+        # Sanity check #2c:
+        # xs[i] is a zero-dim tensor, but
+        # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
+        # is not
+        with self.assertRaisesRegex(RuntimeError, msg):
+            x = torch.randn(B0, 0, 3)
+            vmap(lambda x: x.as_strided([3], [1]))(x)
+
+    def test_bmm(self):
+        op = torch.bmm
+        test = self._vmap_test
+        B0, B1 = 7, 11
+
+        # shape mismatch; an empty pattern matches any error message
+        msg = ""
+        with self.assertRaisesRegex(RuntimeError, msg):
+            vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2))
+        with self.assertRaisesRegex(RuntimeError, msg):
+            vmap(op, in_dims=(0, None))(torch.randn(B0, 3, 3, 2), torch.randn(2, 2))
+        with self.assertRaisesRegex(RuntimeError, msg):
+            vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2, 2, 2))
+
+        # left arg is vmapped
+        test(op, (torch.rand(B0, 2, 3, 5), torch.rand(2, 5, 3)), in_dims=(0, None))
+        test(vmap(op, in_dims=(0, None)), (torch.rand(B1, B0, 2, 3, 5), torch.rand(2, 5, 3)),
+             in_dims=(1, None))
+
+        # right arg is vmapped
+        test(op, (torch.rand(2, 5, 3), torch.rand(B0, 2, 3, 5)), in_dims=(None, 0))
+        test(vmap(op, in_dims=(None, 0)), (torch.rand(2, 5, 3), torch.rand(B1, B0, 2, 3, 5)),
+             in_dims=(None, 1))
+
+        # both args are vmapped
+        test(op, (torch.rand(B0, 2, 3, 5), torch.rand(B0, 2, 5, 3)))
+        test(vmap(op), (torch.rand(B1, B0, 2, 3, 5), torch.rand(B0, B1, 2, 5, 3)), in_dims=(1, 0))
+        test(vmap(op, in_dims=(0, None)),
+             (torch.rand(B1, 2, 3, 5), torch.rand(B0, 2, 5, 3)), in_dims=(None, 0))
+
+    def test_cat(self):
+        test = self._vmap_test
+        B0, B1 = 5, 7
+
+        # Quick hack b/c vmap can't accept a list of tensors as an argument
+        def get_op(dim):
+            def op(*tensors):
+                return torch.cat(tensors, dim=dim)
+            return op
+
+        test(get_op(0), (torch.rand(B0, 2), torch.rand(B0, 3)))
+        test(get_op(0), (torch.rand(2), torch.rand(B0, 3)), in_dims=(None, 0))
+        test(get_op(0), (torch.rand(2, 17), torch.rand(3, 17, B0)), in_dims=(None, 2))
+        test(get_op(-1), (torch.rand(17, 2), torch.rand(17, 3, B0)), in_dims=(None, 2))
+        test(vmap(get_op(0), in_dims=(0, None)),
+             (torch.rand(B1, 2), torch.rand(B0, 3)), in_dims=(None, 0))
+        test(vmap(get_op(0), in_dims=(0, 0)),
+             (torch.rand(B1, 2), torch.rand(B0, B1, 3)), in_dims=(None, 0))
+
+    def test_conj(self):
+        op = torch.conj
+
+        def run_test(dtype):
+            def get(shape):
+                return torch.randn(shape, dtype=dtype)
+            B0, B1 = 7, 11
+            test = self._vmap_test
+
+            # Single vmap, various in_dims / out_dims
+            test(op, [get([B0, 3])])
+            test(op, [get([2, 5, B0, 3])], in_dims=2)
+            test(op, [get([2, 5, B0, 3])], in_dims=2, out_dims=2)
+
+            # Doubly nested vmap
+            test(vmap(op), [get([B0, B1])])
+            test(vmap(op), [get([B1, 2, 5, B0, 3])], in_dims=2)
+            test(vmap(op, in_dims=2), [get([2, 5, B0, B1, 3])],
+                 in_dims=2, out_dims=2)
+
+        # correctness tests
+        run_test(torch.float)
+        run_test(torch.cfloat)
+
+        # check that torch.conj on a non-complex tensor returns the same tensor
+        real_tensor = torch.randn(3)
+        result = vmap(op)(real_tensor)
+        self.assertEqual(result.data_ptr(), real_tensor.data_ptr())
+
+    @unittest.expectedFailure
+    def test_contiguous(self):
+        op = Tensor.contiguous
+
+        self._test_unary(op, TensorFactory.randn, 'cpu')
+
+        # check that contiguous returns the original tensor if the per-example
+        # tensors are already contiguous
+        B0 = 3
+        x = torch.randn(B0, 2, 5, 7)
+        x = x.movedim(0, 2)
+        result = vmap(Tensor.contiguous, in_dims=2, out_dims=2)(x)
+        self.assertTrue(result is x)
+
+        msg = 'NYI: querying is_contiguous inside of vmap for memory_format'
+        tensor = torch.randn(B0, 3)
+        with self.assertRaisesRegex(RuntimeError, msg):
+            vmap(functools.partial(op, memory_format=torch.channels_last))(tensor)
+        with self.assertRaisesRegex(RuntimeError, msg):
+            vmap(functools.partial(op, memory_format=torch.channels_last_3d))(tensor)
+
+    def test_stride(self):
+        B0 = 3
+
+        x = torch.randn(B0, 2, 5, 7)
+
+        def foo(x):
+            assert x.stride() == (7 * 5, 7, 1)
+            return x
+
+        vmap(foo)(x)
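+        # (Illustrative note: x is contiguous with shape (B0, 2, 5, 7), so each
+        # per-example view x[i] keeps the underlying strides (5 * 7, 7, 1).)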
+
+        x = torch.randn(2, B0, 5, 7).movedim(1, 0)
+
+        def bar(x):
+            assert x.stride() == (7 * 5 * B0, 7, 1)
+            return x
+
+        vmap(bar)(x)
+
+    def test_chunk(self):
+        test = self._vmap_view_test
+        op = torch.chunk
+        B0, B1, B2 = 7, 11, 13
+
+        # tests for torch.chunk(self, chunks: int, dim)
+        test(op, (torch.rand(B0, 2, 1024), 15, -1), in_dims=(0, None, None))
+        test(op, (torch.rand(2, B0, 1024), 9, 1), in_dims=(1, None, None))
+        test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 1023, B0, 5), 4, 0),
+             in_dims=(2, None, None))
+        test(vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)),
+             (torch.rand(B1, 2, B0, 64, B2),), in_dims=2)
+
+    def test_clamp(self):
+        clamp_cases = (
+            (lambda t: t.clamp(min=-0.5), TensorFactory.randn),
+            (lambda t: t.clamp(max=0.5), TensorFactory.randn),
+            (lambda t: t.clamp(min=-0.5, max=0.5), TensorFactory.randn),
+            (lambda t: t.clamp_min(min=-0.5), TensorFactory.randn),
+            (lambda t: t.clamp_max(max=0.5), TensorFactory.randn),
+        )
+        for op, getter in clamp_cases:
+            self._test_unary(op, getter, 'cpu')
+
+    def test_comparison_ops(self):
+        test = functools.partial(self._vmap_test, check_propagates_grad=False)
+
+        getter = TensorFactory.randn
+        B0, B1 = 7, 11
+
+        ops = (
+            torch.eq, lambda x, y: x == y,
+            torch.gt, lambda x, y: x > y,
+            torch.ge, lambda x, y: x >= y,
+            torch.le, lambda x, y: x <= y,
+            torch.lt, lambda x, y: x < y,
+            torch.ne, lambda x, y: x != y,
+        )
+
+        for op in ops:
+            # Single vmap: op(Tensor, Tensor)
+            test(op, (getter([B0, 3]), getter([B0, 3])))
+            test(op, (getter([B0]), getter([B0, 2, 3])))
+            test(op, (getter([B0]), getter([2, B0, 3])), in_dims=(0, 1))
+            test(op, (getter([B0]), getter([2, B0, 3])), in_dims=(0, 1), out_dims=1)
+            test(op, (getter([B0]), getter([2, 3])), in_dims=(0, None))
+            test(op, (getter([2, 3]), getter([B0, 3])), in_dims=(0, None))
+
+            # Nested vmap: op(Tensor, Tensor)
+            test(vmap(op), (getter([B0, B1, 2, 3]), getter([B0, B1, 3])))
+            test(vmap(op, in_dims=(None, 0)),
+                 (getter([B0, 2, 3]), getter([B1, 3])), in_dims=(0, None))
+
+            # test number as inputs
+            number = getter([]).item()
+            self._test_unary(lambda t: op(t, number), getter, 'cpu', check_propagates_grad=False)
+
+    def test_diagonal(self):
+        tensor = torch.randn(3, 5, 7, 11, 13)
+        test = self._vmap_view_test
+        op = torch.diagonal
+        test(op, (tensor, 1, 0, 1), in_dims=(0, None, None, None))
+        test(op, (tensor, 0, 2, -1), in_dims=(0, None, None, None))
+        test(op, (tensor, 2, 1, 2), in_dims=(1, None, None, None))
+        test(op, (tensor, 0, -2, -1), in_dims=(1, None, None, None), out_dims=1)
+        test(vmap(lambda t: op(t, 0, 0, -1)), (tensor,), in_dims=1, out_dims=1)
+        test(vmap(vmap(lambda t: op(t, 0, 0, 1), in_dims=1), in_dims=3),
+             (tensor,), in_dims=1, out_dims=1)
+
+    def test_dot(self):
+        op = torch.dot
+        test = self._vmap_test
+        B0, B1 = 7, 11
+
+        # shape mismatch
+        msg = ""
+        with self.assertRaisesRegex(RuntimeError, msg):
+            vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2))
+        with self.assertRaisesRegex(RuntimeError, msg):
+            vmap(op, in_dims=(0, None))(torch.randn(B0, 2), torch.randn(2, 2))
+        with self.assertRaisesRegex(RuntimeError, msg):
+            vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2))
+
+        # left arg is vmapped
+        test(op, (torch.rand(B0, 5), torch.rand(5)), in_dims=(0, None))
+        test(vmap(op, in_dims=(0, None)), (torch.rand(B1, B0, 5), torch.rand(5)),
+             in_dims=(1, None))
+
+        # right arg is vmapped
+        test(op, (torch.rand(5), torch.rand(B0, 5)), in_dims=(None, 0))
+        test(vmap(op, in_dims=(None, 0)), (torch.rand(5), torch.rand(B1, B0, 5)),
+             in_dims=(None, 1))
+
+        # both args are vmapped
+        test(op, (torch.rand(B0, 5), torch.rand(B0, 5)))
+        test(vmap(op), (torch.rand(B1, B0, 5), torch.rand(B0, B1, 5)), in_dims=(1, 0))
+        test(vmap(op, in_dims=(0, None)),
+             (torch.rand(B1, 5), torch.rand(B0, 5)), in_dims=(None, 0))
+
+    def test_expand_as(self):
+        op = torch.Tensor.expand_as
+        test = self._vmap_view_test
+        B0, B1, B2 = 7, 11, 13
+        test(op, (torch.rand(B0, 1, 5), torch.rand(B0, 2, 3, 5)))
+        test(op, (torch.rand(B0, 1, 5), torch.rand(2, 3, 5)), in_dims=(0, None))
+        test(op, (torch.rand(1, 5), torch.rand(B0, 2, 3, 5)), in_dims=(None, 0))
+        test(vmap(op), (torch.rand(B0, B1, 1, 5), torch.rand(B0, B1, 2, 3, 5)))
+        test(vmap(op), (torch.rand(B0, B1, 1, 5), torch.rand(B1, B0, 2, 3, 5)), in_dims=(0, 1))
+        test(vmap(op), (torch.rand(B0, B1), torch.rand(B1, 2, 3, 5)), in_dims=(0, None))
+        test(vmap(vmap(op)), (torch.rand(B0, B1, B2), torch.rand(B0, B1, B2, 2, 3, 5)))
+
+    def test_fill_and_zero_inplace(self):
+        test = functools.partial(self._vmap_test, check_propagates_grad=False)
+        B0, B1 = 7, 11
+        ops = (
+            lambda t: t.fill_(0.1),
+            lambda t: t.fill_(torch.tensor(0.2)),
+            lambda t: t.zero_(),
+        )
+
+        for op in ops:
+            # Single vmap, various in_dims / out_dims
+            test(op, [TensorFactory.randn([B0, 3])])
+            test(op, [TensorFactory.randn([2, 5, B0, 3])], in_dims=2)
+            test(op, [TensorFactory.randn([2, 5, B0, 3])], in_dims=2, out_dims=2)
+
+            # Doubly nested vmap
+            test(vmap(op), [TensorFactory.randn([B0, B1])])
+            test(vmap(op), [TensorFactory.randn([B1, 2, 5, B0, 3])], in_dims=2)
+            test(vmap(op, in_dims=2), [TensorFactory.randn([2, 5, B0, B1, 3])],
+                 in_dims=2, out_dims=2)
+
+        # test when value is a batched tensor for fill_ operator
+        B0, B1 = 3, 5
+        test(Tensor.fill_, [TensorFactory.randn([B0, B1]), TensorFactory.randn(B0)])
+
+        # A RuntimeError is thrown when the tensor being written to isn't
+        # being vmapped over.
+        with self.assertRaisesRegex(RuntimeError, ""):
+            vmap(Tensor.fill_, (None, 0))(TensorFactory.randn([B0, B1]),
+                                          TensorFactory.randn([B0]))
+
+    def _test_complex_views(self, op, dtypes):
+        test = self._vmap_view_test
+
+        def run_test(op, dtype):
+            def get(shape):
+                return torch.randn(shape, dtype=dtype)
+
+            B0, B1 = 7, 11
+
+            # Single vmap, various in_dims / out_dims
+            test(op, [get([B0, 3])])
+            test(op, [get([3, B0])], in_dims=1)
+            test(op, [get([2, 5, B0, 3])], in_dims=2)
+            test(op, [get([2, 5, B0, 3])], in_dims=2, out_dims=2)
+
+            # Doubly nested vmap
+            test(vmap(op), [get([B0, B1])])
+            test(vmap(op), [get([B1, 2, 5, 3, B0])], in_dims=4)
+            test(vmap(op, in_dims=2), [get([2, 5, B0, B1, 3])],
+                 in_dims=2, out_dims=2)
+
+        for dtype in dtypes:
+            run_test(op, dtype)
+
+    def test_real(self):
+        self._test_complex_views(torch.real, dtypes=[torch.cfloat, torch.cdouble])
+
+    def test_imag(self):
+        self._test_complex_views(torch.imag, dtypes=[torch.cfloat, torch.cdouble])
+
+    def test_view_as_real(self):
+        self._test_complex_views(torch.view_as_real, dtypes=[torch.cfloat, torch.cdouble])
+
+    def test_view_as_complex(self):
+        def run_test(dtype):
+            def get(shape):
+                return torch.randn(shape, dtype=dtype)
+
+            op = torch.view_as_complex
+            test = self._vmap_view_test
+            B0, B1 = 7, 11
+
+            # Single vmap, various in_dims / out_dims
+            test(op, [get([B0, 3, 2])])
+            test(op, [get([2, 5, B0, 3, 2])], in_dims=2)
+            test(op, [get([2, 5, B0, 3, 2])], in_dims=2, out_dims=2)
+
+            # Doubly nested vmap
+            test(vmap(op), [get([B0, B1, 2])])
+            test(vmap(op), [get([B1, 2, 5, B0, 3, 2])], in_dims=2)
+            test(vmap(op, in_dims=2), [get([2, 5, B0, B1, 3, 2])],
+                 in_dims=2, out_dims=2)
+
+            # Interesting case #1: Batch dim directly before dim of size 2
+            test(op, [get([3, B0, 2])], in_dims=1)
+            test(vmap(op, in_dims=1), [get([3, B1, B0, 2])], in_dims=2)
+
+            # Interesting case #2: Batch dim at end of tensor, success cases
+            # view_as_complex requires that the dim with size 2 have stride 1
+            # in order for the view to function properly
+            test(op, [get([B0, 2]).transpose(0, 1)], in_dims=1)
+            test(vmap(op, in_dims=1), [get([B0, B1, 2]).movedim(1, 2)])
+            test(vmap(op, in_dims=2), [get([B0, 3, B1, 2]).movedim(2, 3)])
+
+            # Interesting case #3: Batch dim at end of tensor, failure cases
+            msg = "Tensor must have a last dimension with stride 1"
+            with self.assertRaisesRegex(RuntimeError, msg):
+                vmap(op, in_dims=1)(get([2, B0]))
+            with self.assertRaisesRegex(RuntimeError, msg):
+                vmap(vmap(op, in_dims=1), in_dims=1)(get([2, B0, B1]))
+
+            # Invalid input: no dimension of size 2
+            msg = 'Input tensor must have one or more dimensions'
+            with self.assertRaisesRegex(RuntimeError, msg):
+                vmap(op)(get([B0]))
+            with self.assertRaisesRegex(RuntimeError, msg):
+                vmap(vmap(op))(get([B0, B1]))
+
+            # Invalid input: Batch dim has size 2, but the logical last dim does
+            # not have size 2
+            msg = 'Tensor must have a last dimension of size 2'
+            with self.assertRaisesRegex(RuntimeError, msg):
+                vmap(op, in_dims=1)(get([3, 2]))
+
+        for dtype in [torch.float, torch.double]:
+            run_test(dtype)
+
+    def test_is_complex(self):
+        ctensor = torch.randn(3, dtype=torch.cfloat)
+        tensor = torch.randn(3)
+
+        def foo(x):
+            if x.is_complex():
+                return torch.tensor(1)
+            else:
+                return torch.tensor(0)
+
+        self.assertEqual(vmap(foo)(ctensor), torch.tensor([1, 1, 1]))
+        self.assertEqual(vmap(foo)(tensor), torch.tensor([0, 0, 0]))
+
+    def test_is_floating_point(self):
+        float_tensor = torch.tensor([1., 2., 3.])
+        long_tensor = torch.tensor([1, 2, 3])
+
+        def foo(x):
+            if x.is_floating_point():
+                return torch.tensor(1)
+            else:
+                return torch.tensor(0)
+
+        self.assertEqual(vmap(foo)(float_tensor), torch.tensor([1, 1, 1]))
+        self.assertEqual(vmap(foo)(long_tensor), torch.tensor([0, 0, 0]))
+
+    def test_is_contiguous(self):
+        def foo(x):
+            if x.is_contiguous():
+                return torch.tensor(1.)
+            else:
+                return torch.tensor(0.)
+
+        B0, B1 = 3, 5
+
+        # Single batch dim
+        contig = torch.randn(B0, 2, 7)
+        self.assertEqual(vmap(foo)(contig), torch.ones(B0))
+
+        noncontig = torch.randn(2, B0, 7)
+        self.assertEqual(vmap(foo, in_dims=1)(noncontig), torch.zeros(B0))
+
+        noncontig = torch.randn(2, B0, 7).movedim(1, 0)
+        self.assertEqual(vmap(foo)(noncontig), torch.zeros(B0))
+
+        noncontig = torch.randn(2, 7, B0)
+        self.assertEqual(vmap(foo, in_dims=2)(noncontig), torch.zeros(B0))
+
+        # Multiple batch dims
+        contig = torch.randn(B0, B1, 3)
+        self.assertEqual(vmap(vmap(foo))(contig), torch.ones(B0, B1))
+
+        contig = torch.randn(B1, B0, 3)
+        self.assertEqual(vmap(vmap(foo), in_dims=1)(contig), torch.ones(B0, B1))
+
+        contig = torch.randn(B1, B0, 3).movedim(0, 1)
+        self.assertEqual(vmap(vmap(foo))(contig), torch.ones(B0, B1))
+
+        noncontig = torch.randn(B0, 3, B1)
+        self.assertEqual(vmap(vmap(foo, in_dims=1))(noncontig), torch.zeros(B0, B1))
+
+        # is_contiguous on empty tensor is True
+        def bar(x):
+            assert x.is_contiguous()
+            return x
+
+        vmap(bar)(torch.randn(B0, 0, 3))
+        vmap(bar, in_dims=1)(torch.randn(0, B0, 3))
+        vmap(bar)(torch.randn(B0, 0, 3).transpose(-1, -2))
+
+        # is_contiguous with other memory formats
+        def baz(x, memory_format):
+            x.is_contiguous(memory_format=memory_format)
+            return x
+
+        msg = 'NYI: querying is_contiguous inside of vmap for memory_format'
+        tensor = torch.randn(B0, 2, 7, 3)
+        with self.assertRaisesRegex(RuntimeError, msg):
+            vmap(functools.partial(baz, memory_format=torch.channels_last))(tensor)
+        with self.assertRaisesRegex(RuntimeError, msg):
+            vmap(functools.partial(baz, memory_format=torch.channels_last_3d))(tensor)
+
+    def test_movedim(self):
+        op = torch.movedim
+        test = self._vmap_view_test
+        B0, B1, B2 = 7, 11, 13
+
+        # movedim(tensor, int, int) variant
+        test(op, (torch.rand(B0, 2, 5), 0, 1), in_dims=(0, None, None))
+        test(op, (torch.rand(2, B0, 5), 0, 1), in_dims=(1, None, None))
+        test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 2, B0, 5), 0, 1), in_dims=(2, None, None))
+        test(vmap(vmap(op, in_dims=(2, None, None)), in_dims=(0, None, None)),
+             (torch.rand(B1, 2, B0, 5, B2), 0, 1), in_dims=(2, None, None))
+
+        # movedim(tensor, intlist, intlist) variant
+        test(op, (torch.rand(B0, 2, 3, 5), [1, 0], [0, 2]), in_dims=(0, None, None))
+        test(op, (torch.rand(2, 3, B0, 5), [1, 0], [0, 2]), in_dims=(1, None, None))
+        test(vmap(op, in_dims=(0, None, None)),
+             (torch.rand(B1, 2, B0, 5), [0, 1], [1, 0]), in_dims=(2, None, None))
+        test(vmap(vmap(op, in_dims=(2, None, None)), in_dims=(0, None, None)),
+             (torch.rand(B1, 2, B0, 5, B2), [0, 1], [1, 0]), in_dims=(2, None, None))
+
+    def test_mm(self):
+        op = torch.mm
+        test = self._vmap_test
+        B0, B1 = 7, 11
+
+        # shape mismatch
+        msg = "Shape mismatch"
+        with self.assertRaisesRegex(RuntimeError, msg):
+            vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2))
+        with self.assertRaisesRegex(RuntimeError, msg):
+            vmap(op, in_dims=(0, None))(torch.randn(B0, 2), torch.randn(2, 2))
+        with self.assertRaisesRegex(RuntimeError, msg):
+            vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2, 2, 2))
+
+        # left arg is vmapped
+        test(op, (torch.rand(B0, 2, 5), torch.rand(5, 2)), in_dims=(0, None))
+        test(vmap(op, in_dims=(0, None)), (torch.rand(B1, B0, 2, 5), torch.rand(5, 2)),
+             in_dims=(1, None))
+
+        # right arg is vmapped
+        test(op, (torch.rand(2, 5), torch.rand(B0, 5, 2)), in_dims=(None, 0))
+        test(vmap(op, in_dims=(None, 0)), (torch.rand(2, 5), torch.rand(B1, B0, 5, 2)),
+             in_dims=(None, 1))
+
+        # both args are vmapped
+        test(op, (torch.rand(B0, 2, 5), torch.rand(B0, 5, 2)))
+        test(vmap(op), (torch.rand(B1, B0, 2, 5), torch.rand(B0, B1, 5, 2)), in_dims=(1, 0))
+        test(vmap(op, in_dims=(0, None)),
+             (torch.rand(B1, 2, 5), torch.rand(B0, 5, 2)), in_dims=(None, 0))
+
+    def test_mv(self):
+        op = torch.mv
+        test = self._vmap_test
+        B0, B1 = 7, 11
+
+        # shape mismatch
+        msg = ""
+        with self.assertRaisesRegex(RuntimeError, msg):
+            vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2))
+        with self.assertRaisesRegex(RuntimeError, msg):
+            vmap(op, in_dims=(0, None))(torch.randn(B0, 2, 2), torch.randn(2, 2))
+        with self.assertRaisesRegex(RuntimeError, msg):
+            vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2, 2))
+
+        # left arg is vmapped
+        test(op, (torch.rand(B0, 2, 5), torch.rand(5)), in_dims=(0, None))
+        test(vmap(op, in_dims=(0, None)), (torch.rand(B1, B0, 2, 5), torch.rand(5)),
+             in_dims=(1, None))
+
+        # right arg is vmapped
+        test(op, (torch.rand(2, 5), torch.rand(B0, 5)), in_dims=(None, 0))
+        test(vmap(op, in_dims=(None, 0)), (torch.rand(2, 5), torch.rand(B1, B0, 5)),
+             in_dims=(None, 1))
+
+        # both args are vmapped
+        test(op, (torch.rand(B0, 2, 5), torch.rand(B0, 5)))
+        test(vmap(op), (torch.rand(B1, B0, 2, 5), torch.rand(B0, B1, 5)), in_dims=(1, 0))
+        test(vmap(op, in_dims=(0, None)),
+             (torch.rand(B1, 2, 5), torch.rand(B0, 5)), in_dims=(None, 0))
+
+    def test_narrow(self):
+        op = torch.narrow
+        test = self._vmap_view_test
+        B0, B1, B2 = 7, 11, 13
+
+        test(op, (torch.rand(B0, 2, 5), -1, 1, 3), in_dims=(0, None, None, None))
+        test(op, (torch.rand(2, B0, 5), 1, 1, 3), in_dims=(1, None, None, None))
+        test(vmap(op, in_dims=(0, None, None, None)),
+             (torch.rand(B1, 2, B0, 5), 1, 0, 0), in_dims=(2, None, None, None))
+        test(vmap(vmap(op, in_dims=(2, None, None, None)), in_dims=(0, None, None, None)),
+             (torch.rand(B1, 2, B0, 5, B2), -1, 2, 3), in_dims=(2, None, None, None))
+
+    def test_new_empty(self):
+        # Empty is non-deterministic so we just check that the shape of the
+        # output tensor is what we expect and that the vmap fallback isn't used.
+        op = Tensor.new_empty
+
+        B0, B1 = 7, 11
+
+        result = vmap(lambda x: op(x, [2, 3]))(torch.randn(B0))
+        self.assertEqual(result.shape, [B0, 2, 3])
+
+        result = vmap(lambda x: op(x, []))(torch.randn(B0))
+        self.assertEqual(result.shape, [B0])
+
+        result = vmap(vmap(lambda x: op(x, [2, 3])))(torch.randn(B0, B1))
+        self.assertEqual(result.shape, [B0, B1, 2, 3])
+
+    # TODO: implement a batching rule for new_empty_strided
+    @unittest.expectedFailure
+    def test_new_empty_strided(self):
+        # Empty is non-deterministic so we just check that the shape and strides
+        # of the output are what we expect and that the vmap fallback isn't used.
+        B0, B1 = 7, 11
+
+        def _test_single_vmap(size, stride, B0):
+            x = torch.randn(B0)
+            result = vmap(lambda x: x.new_empty_strided(size, stride))(x)
+            S = torch.empty_strided(size, stride).storage().size()
+            self.assertEqual(result.shape, [B0] + size)
+            self.assertEqual(result.stride(), [S] + stride)
+
+        def _test_double_vmap(size, stride, B0, B1):
+            x = torch.randn(B0, B1)
+            result = vmap(vmap(lambda x: x.new_empty_strided(size, stride)))(x)
+            S = torch.empty_strided(size, stride).storage().size()
+            self.assertEqual(result.shape, [B0, B1] + size)
+            self.assertEqual(result.stride(), [B1 * S, S] + stride)
+
+            x = torch.randn(B1, B0)
+            result = vmap(vmap(lambda x: x.new_empty_strided(size, stride)), in_dims=1)(x)
+            S = x.new_empty_strided(size, stride).storage().size()
+            self.assertEqual(result.shape, [B0, B1] + size)
+            self.assertEqual(result.stride(), [B1 * S, S] + stride)
+
+        # contiguous case
+        _test_single_vmap([2, 3, 5], [3 * 5, 5, 1], B0)
+        _test_double_vmap([2, 3, 5], [3 * 5, 5, 1], B0, B1)
+
+        # expanded
+        _test_single_vmap([2, 3, 5], [0, 5, 1], B0)
+        _test_double_vmap([2, 3, 5], [0, 5, 1], B0, B1)
+
+        # Some of these cases are pretty strange; we're just verifying that if
+        # empty_strided allows them, then BatchedTensor.new_empty_strided
+        # does as well.
+        for shape in [[2, 3, 4], [0, 2, 0]]:
+            for strides in [[12, 4, 1], [2, 4, 6], [0, 0, 0]]:
+                _test_single_vmap(shape, strides, B0)
+                _test_double_vmap(shape, strides, B0, B1)
+
+    def test_new_zeros(self):
+        op = Tensor.new_zeros
+        test = functools.partial(self._vmap_test, check_propagates_grad=False)
+        B0, B1 = 7, 11
+
+        test(lambda x: op(x, 2, 3), (torch.rand(B0),))
+        test(lambda x: op(x, []), (torch.rand(B0),))
+        test(vmap(lambda x: op(x, 3, 5)), (torch.rand(B0, B1),))
+
+    def test_select(self):
+        op = torch.select
+        test = self._vmap_view_test
+        B0, B1, B2 = 7, 11, 13
+        test(op, (torch.rand(B0, 2, 5), 0, 0), in_dims=(0, None, None))
+        test(op, (torch.rand(2, B0, 5), 1, 1), in_dims=(1, None, None))
+        test(vmap(lambda t: op(t, 1, 1)), (torch.rand(B1, 2, B0, 5),), in_dims=2)
+        test(vmap(vmap(lambda t: op(t, 1, 1), in_dims=1)), (torch.rand(B1, 2, B0, B2, 5),), in_dims=2)
+
+    def test_stack(self):
+        test = self._vmap_test
+        B0, B1 = 5, 7
+
+        # Quick hack b/c vmap can't accept a list of tensors as an argument
+        def get_op(dim):
+            def op(*tensors):
+                return torch.stack(tensors, dim=dim)
+            return op
+
+        test(get_op(0), (torch.rand(B0, 3), torch.rand(B0, 3)))
+        test(get_op(0), (torch.rand(3), torch.rand(B0, 3)), in_dims=(None, 0))
+        test(get_op(0), (torch.rand(2, 17), torch.rand(2, 17, B0)), in_dims=(None, 2))
+        test(get_op(-1), (torch.rand(2, 17), torch.rand(2, 17, B0)), in_dims=(None, 2))
+        test(vmap(get_op(0), in_dims=(0, None)),
+             (torch.rand(B1, 2), torch.rand(B0, 2)), in_dims=(None, 0))
+        test(vmap(get_op(0), in_dims=(0, 0)),
+             (torch.rand(B1, 2), torch.rand(B0, B1, 2)), in_dims=(None, 0))
+
+    def test_slice(self):
+        test = self._vmap_view_test
+        B0, B1, B2 = 7, 11, 13
+        test(lambda t: t[0:1], (torch.rand(B0, 3, 5),))
+        test(lambda t: t[:, 1:3], (torch.rand(3, 5, B0),), in_dims=2)
+        test(vmap(lambda t: t[:, 0:1], in_dims=2), (torch.rand(3, 5, B0, B1),), in_dims=2)
+        test(vmap(vmap(lambda t: t[0:1], in_dims=2), in_dims=2),
+             (torch.rand(3, 5, B0, B1, B2),), in_dims=2)
+
+    def test_squeeze(self):
+        test = self._vmap_view_test
+        op = torch.squeeze
+        B0, B1 = 1, 11
+        test(op, (torch.rand(B0),))
+        test(op, (torch.rand(B0, 3, 5),))
+        test(op, (torch.rand(1, B0, 5),), in_dims=1)
+        test(op, (torch.rand(B0, 0, 1, 5, 1),))
+        test(op, (torch.rand(B0, 1, 1, 1, 1),))
+        test(vmap(op), (torch.rand(B0, B1, 1),))
+        test(vmap(op), (torch.rand(B1, 1, B0),), in_dims=2)
+
+    def test_sum_dim(self):
+        test = self._vmap_test
+        B0, B1 = 5, 7
+
+        # Single vmap, various in_dims / out_dims
+        test(lambda x: x.sum(0), [torch.randn([B0])])
+        test(lambda x: x.sum(-1), [torch.randn([B0])])
+        test(lambda x: x.sum(0), [torch.randn([B0, 3])])
+        test(lambda x: x.sum(-1), [torch.randn([2, 5, B0, 3])], in_dims=2)
+        test(lambda x: x.sum(2), [torch.randn([2, 5, B0, 3])], in_dims=2, out_dims=2)
+
+        # Doubly nested vmap
+        test(vmap(lambda x: x.sum(0)), [torch.randn([B0, B1])])
+        test(vmap(lambda x: x.sum(-1)), [torch.randn([B0, B1])])
+        test(vmap(lambda x: x.sum(-2)), [torch.randn([B1, 2, 5, B0, 3])], in_dims=2)
+        test(vmap(lambda x: x.sum(2), in_dims=2), [torch.randn([2, 5, B0, B1, 3])],
+             in_dims=2, out_dims=2)
+
+    def test_reshape(self):
+        test = self._vmap_test
+        B0, B1, B2 = 7, 11, 13
+        op = torch.reshape
+        test(op, (torch.rand(B0, 2 * 5), [2, 5]), in_dims=(0, None), check_view=True)
+        test(op, (torch.rand(2, B0, 5), [1, 1, 10]), in_dims=(1, None), check_view=False)
+        test(vmap(lambda t: t.reshape([-1])), (torch.rand(B0, B1, 2, 5),), check_view=True)
+        test(vmap(vmap(lambda t: t.reshape([-1]), in_dims=2), in_dims=1),
+             (torch.rand(3, B1, 2, B2, 5, B0),), in_dims=5, check_view=False)
+
+    def test_reshape_as(self):
+        test = self._vmap_test
+        B0, B1, B2 = 7, 11, 13
+        op = torch.Tensor.reshape_as
+        test(op, (torch.rand(B0, 2 * 5), torch.rand(B0, 2, 5)), check_view=True)
+        test(op, (torch.rand(2 * 5), torch.rand(B0, 2, 5)), in_dims=(None, 0), check_view=True)
+        test(op, (torch.rand(B0, 2 * 5), torch.rand(2, 5)), in_dims=(0, None), check_view=True)
+
+        test(op, (torch.rand(2, B0, 5), torch.rand(1, 1, 10)), in_dims=(1, None), check_view=False)
+
+        test(vmap(op), (torch.rand(B0, B1, 2, 5), torch.randn(B0, B1, 10)), check_view=True)
+        test(vmap(vmap(op, in_dims=(2, None)), in_dims=(1, None)),
+             (torch.rand(3, B1, 2, B2, 5, B0), torch.rand(B0, 3 * 2 * 5)),
+             in_dims=(5, 0), check_view=False)
+
+    def test_result_type(self):
+        def scalar_tensor_with_dtype(op):
+            def wrapped(*args, **kwargs):
+                dtype = op(*args, **kwargs)
+                return torch.ones([], dtype=dtype)
+            return wrapped
+
+        test = self._vmap_test
+        op = scalar_tensor_with_dtype(torch.result_type)
+
+        B0 = 2
+
+        test(op, (torch.randn(B0), torch.randn(B0, dtype=torch.float64)),
+             check_propagates_grad=False)
+        test(op, (torch.randn(B0), torch.randint(10, [B0], dtype=torch.int64)),
+             check_propagates_grad=False)
+
+        test(lambda x: op(x, 1), (torch.randn(B0),), check_propagates_grad=False)
+        test(lambda x: op(x, 1.6), (torch.randn(B0),), check_propagates_grad=False)
+
+        test(lambda x: op(x, torch.tensor(1)), (torch.randn(B0),),
+             check_propagates_grad=False)
+        test(lambda x: op(x, torch.tensor(1.6, dtype=torch.double)),
+             (torch.randn(B0),), check_propagates_grad=False)
+
+        test(op, (torch.randn(B0, 2), torch.randn(B0, 2, dtype=torch.float64)),
+             check_propagates_grad=False)
+        test(op, (torch.randn(B0, 2), torch.randint(10, [B0, 2], dtype=torch.int64)),
+             check_propagates_grad=False)
+
+        test(lambda x: op(x, 1), (torch.randn(B0, 2),), check_propagates_grad=False)
+        test(lambda x: op(x, 1.6), (torch.randn(B0, 2),), check_propagates_grad=False)
+
+        test(lambda x: op(x, torch.tensor(1)), (torch.randn(B0, 2),),
+             check_propagates_grad=False)
+        test(lambda x: op(x, torch.tensor(1.6, dtype=torch.double)),
+             (torch.randn(B0, 2),), check_propagates_grad=False)
+
+        test(op, (torch.randn(B0, 2), torch.randn(B0, dtype=torch.float64)),
+             check_propagates_grad=False)
+        test(op, (torch.randn(B0, 2), torch.randint(10, [B0], dtype=torch.int64)),
+             check_propagates_grad=False)
+
+    def test_tensor_split(self):
+        test = self._vmap_view_test
+        op = torch.tensor_split
+        B0, B1, B2 = 7, 11, 13
+
+        # tests for torch.tensor_split(self, indices_or_sections: int, dim)
+        test(op, (torch.rand(B0, 2, 1024), 5, -1), in_dims=(0, None, None))
+        test(op, (torch.rand(2, B0, 1024), 150, 1), in_dims=(1, None, None))
+        test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 1023, B0, 5), 256, 0),
+             in_dims=(2, None, None))
+        test(vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)),
+             (torch.rand(B1, 2, B0, 64, B2),), in_dims=2)
+
+        # tests for torch.tensor_split(self, indices_or_sections: List[int], dim)
+        test(op, (torch.rand(B0, 2, 1024), [50, 100, 378, 890], -1), in_dims=(0, None, None))
+        test(op, (torch.rand(2, B0, 1024), [50, 100, 212, 345, 0, 378, 890], 1), in_dims=(1, None, None))
+        test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 1023, B0, 5), [50, 100, 212, 345, 0, 378, 890], 0),
+             in_dims=(2, None, None))
+        test(vmap(vmap(lambda t: op(t, [4, 8, 9, 34, 29], 1), in_dims=2)),
+             (torch.rand(B1, 2, B0, 64, B2),), in_dims=2)
+
+    def test_split(self):
+        test = self._vmap_view_test
+        op = torch.split
+        B0, B1, B2 = 7, 11, 13
+
+        # tests for torch.split(self, split_size: int, dim)
+        test(op, (torch.rand(B0, 2, 1024), 101, -1), in_dims=(0, None, None))
+        test(op, (torch.rand(2, B0, 1024), 130, 1), in_dims=(1, None, None))
+        test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 1023, B0, 5), 256, 0),
+             in_dims=(2, None, None))
+        test(vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)),
+             (torch.rand(B1, 2, B0, 64, B2),), in_dims=2)
+
+        # tests for torch.split(self, split_size: List[int], dim)
+        test(op, (torch.rand(B0, 2, 1024), [1, 1020, 3], -1), in_dims=(0, None, None))
+        test(op, (torch.rand(2, B0, 1024), [100] * 10 + [24], 1), in_dims=(1, None, None))
+        test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 1023, B0, 5), [256] * 3 + [255], 0),
+             in_dims=(2, None, None))
+        test(vmap(vmap(lambda t: op(t, [4] * 8 + [8] * 4, 1), in_dims=2)),
+             (torch.rand(B1, 2, B0, 64, B2),), in_dims=2)
+
+    def test_trace(self):
+        op = torch.trace
+        test = self._vmap_test
+        B0, B1, B2 = 7, 11, 13
+
+        test(op, (torch.rand(B0, 2, 5),))
+        test(op, (torch.rand(2, B0, 5),), in_dims=1)
+        test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2)
+        test(vmap(vmap(op, in_dims=2)), (torch.rand(B1, 2, B0, 5, B2),), in_dims=2)
+
+    def test_transpose(self):
+        op = torch.transpose
+        test = self._vmap_view_test
+
+        B0, B1, B2 = 7, 11, 13
+        test(lambda x: op(x, 0, 1), (torch.rand(B0, 2, 5),))
+        test(lambda x: op(x, -1, -2), (torch.rand(B0, 2, 5),))
+        test(lambda x: op(x, 3, 1), (torch.rand(B0, 2, 5, 4, 6),))
+        test(lambda x: op(x, 1, 0), (torch.rand(2, B0, 5),), in_dims=1)
+        test(vmap(lambda x: op(x, 0, 1)), (torch.rand(B1, 2, B0, 5),), in_dims=2)
+        test(vmap(vmap(lambda x: op(x, 0, 1), in_dims=2)),
+             (torch.rand(B1, 2, B0, 5, B2),), in_dims=2)
+
+        # Special case: scalar tensor
+        for dim1, dim2 in itertools.product([0, -1], [0, -1]):
+            x = torch.rand(B0)
+            result = vmap(lambda x: op(x, dim1, dim2))(x)
+            self.assertTrue(result is x)
+
+    def test_t(self):
+        op = torch.t
+        test = self._vmap_view_test
+        B0, B1, B2 = 7, 11, 13
+        test(op, (torch.rand(B0, 2, 5),))
+        test(op, (torch.rand(2, B0, 5),), in_dims=1)
+        test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2)
+        test(vmap(vmap(op, in_dims=2)), (torch.rand(B1, 2, B0, 5, B2),), in_dims=2)
+
+    def test_T_numpy(self):
+        def op(t):
+            return t.T
+
+        test = self._vmap_view_test
+        B0, B1, B2 = 7, 11, 13
+        test(op, (torch.rand(B0, 2, 3, 5),))
+        test(op, (torch.rand(B0),))
+        test(op, (torch.rand(2, B0, 3, 5),), in_dims=1)
+        test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2)
+        test(vmap(op), (torch.rand(B1, 2, B0, 3, 5),), in_dims=2)
+        test(vmap(vmap(op, in_dims=2)), (torch.rand(B1, 2, B0, 3, B2, 5),), in_dims=2)
+
+    def test_to(self):
+        test = self._vmap_test
+        B0, B1 = 7, 11
+
+        test(lambda t: t.to('cpu'), (torch.rand(B0),))
+        test(lambda t: t.to(torch.double), (torch.rand(B0),))
+        test(lambda t, o: t.to(o), (torch.rand(B0), torch.randn(B0, dtype=torch.float64)))
+        test(lambda t, o: t.to(o),
+             (torch.rand(B0), torch.randn(B0, dtype=torch.float64)),
+             in_dims=(0, None))
+        test(vmap(lambda t: t.to(torch.double)), (torch.rand(B0, B1, 3),))
+
+        # also test some casting methods
+        test(lambda t: t.double(), (torch.rand(B0),))
+        test(lambda t: t.float(), (torch.rand(B0),))
+        test(lambda t: t.int(), (torch.rand(B0),), check_propagates_grad=False)
+        test(lambda t: t.long(), (torch.rand(B0),), check_propagates_grad=False)
+
+    def test_unfold(self):
+        op = torch.Tensor.unfold
+        test = self._vmap_view_test
+        B0, B1, B2 = 3, 2, 5
+
+        test(op, (torch.rand(B0, 7, 11), 0, 2, 1), in_dims=(0, None, None, None))
+        test(op, (torch.rand(7, B0, 11), 1, 4, 2), in_dims=(1, None, None, None))
+        test(vmap(op, in_dims=(0, None, None, None)),
+             (torch.rand(B1, 7, B0, 11), 1, 5, 1), in_dims=(2, None, None, None))
+        test(vmap(vmap(op, in_dims=(2, None, None, None)), in_dims=(0, None, None, None)),
+             (torch.rand(B1, 7, B0, 11, B2), -1, 2, 4), in_dims=(2, None, None, None))
+
+    def test_unbind(self):
+        test = self._vmap_view_test
+        op = torch.unbind
+        B0, B1, B2 = 7, 11, 13
+
+        test(op, (torch.rand(B0, 2, 1024), -1), in_dims=(0, None))
+        test(op, (torch.rand(B0, 2, 0),))
+        test(op, (torch.rand(2, B0, 7), 0), in_dims=(1, None))
+        test(vmap(op, in_dims=(0, None)), (torch.rand(B1, 1023, B0, 5), 1),
+             in_dims=(2, None))
+        test(vmap(vmap(lambda t: op(t, dim=1), in_dims=2)),
+             (torch.rand(B1, 2, B0, 32, B2),), in_dims=2)
+
+    def test_view(self):
+        test = self._vmap_view_test
+        B0, B1, B2 = 7, 11, 13
+        op = torch.Tensor.view
+
+        # We should error out if the view would produce an incorrect result
+        with self.assertRaises(RuntimeError):
+            vmap(op, in_dims=(1, None))(torch.rand(2, B0, 5), [10])
+
+        test(op, (torch.rand(B0, 2 * 5), [2, 5]), in_dims=(0, None))
+        test(op, (torch.rand(B0, 4, 5), [1, 2, 1, 10]), in_dims=(0, None))
+        test(vmap(lambda t: t.view([-1])), (torch.rand(B0, B1, 2, 5, 3),))
+        test(vmap(vmap(lambda t: t.reshape([-1])), in_dims=1),
+             (torch.rand(B2, B0, B1, 3, 2, 5),), in_dims=1)
+
+    def test_view_as(self):
+        test = self._vmap_view_test
+        B0, B1, B2 = 7, 11, 13
+        op = torch.Tensor.view_as
+
+        # We should error out if the view would produce an incorrect result
+        with self.assertRaises(RuntimeError):
+            vmap(op, in_dims=(1, 0))(torch.rand(2, B0, 5), torch.rand(B0, 10))
+
+        test(op, (torch.rand(B0, 2 * 5), torch.rand(B0, 2, 5)))
+        test(op, (torch.rand(2 * 5), torch.rand(B0, 2, 5)), in_dims=(None, 0))
+        test(op, (torch.rand(B0, 2 * 5), torch.rand(2, 5)), in_dims=(0, None))
+
+        test(op, (torch.rand(B0, 4, 5), torch.rand(2, 1, 1, 10)), in_dims=(0, None))
+
+        test(vmap(op), (torch.rand(B0, B1, 2, 5), torch.randn(B0, B1, 10)))
+        test(vmap(vmap(op, in_dims=(0, None)), in_dims=(0, None)),
+             (torch.rand(B1, B2, B0, 3, 2, 5), torch.rand(B0, 3 * 2 * 5)),
+             in_dims=(2, 0))
+
+    # TODO: reenable the random op failures
+    @unittest.expectedFailure
+    def test_no_random_op_support(self):
+        B0 = 2
+
+        captured = torch.rand(3)
+
+        random_ops = [
+            # out-of-place on BatchedTensor
+            (torch.bernoulli, (torch.rand(B0, 1),)),
+            (lambda t: torch.bernoulli(t, p=0.5), (torch.rand(B0, 1),)),
+            (lambda t: torch.multinomial(t, 2), (torch.rand(B0, 3),)),
+            (torch.normal, (torch.randn(B0, 1), torch.randn(B0, 1))),
+            (lambda t: torch.normal(t, 1.), (torch.randn(B0, 1),)),
+            (lambda t: torch.normal(0., t), (torch.randn(B0, 1),)),
+            (torch.poisson, (torch.rand(B0, 1),)),
+            (torch.rand_like, (torch.rand(B0, 1),)),
+            (torch.randn_like, (torch.rand(B0, 1),)),
+            (lambda t: torch.randint_like(t, 2), (torch.rand(B0, 1),)),
+            (lambda t: torch.randint_like(t, 0, 2), (torch.rand(B0, 1),)),
+
+            # out-of-place on captured tensor
+            (lambda t: torch.bernoulli(captured), (torch.rand(B0),)),
+            (lambda t: torch.bernoulli(captured, p=0.5), (torch.rand(B0),)),
+            (lambda t: torch.multinomial(captured, 2), (torch.rand(B0),)),
+            (lambda t: torch.normal(captured, captured), (torch.randn(B0),)),
+            (lambda t: torch.normal(captured, 1.), (torch.randn(B0),)),
+            (lambda t: torch.normal(0., captured), (torch.randn(B0),)),
+            (lambda t: torch.poisson(captured), (torch.rand(B0),)),
+            (lambda t: torch.rand_like(captured), (torch.rand(B0),)),
+            (lambda t: torch.randn_like(captured), (torch.rand(B0),)),
+            (lambda t: torch.randint_like(captured, 2), (torch.rand(B0),)),
+            (lambda t: torch.randint_like(captured, 0, 2), (torch.rand(B0),)),
+
+            # in-place on BatchedTensor
+            (lambda t: t.bernoulli_(), (torch.randn(B0, 1),)),
+            (lambda t: t.cauchy_(), (torch.randn(B0, 1),)),
+            (lambda t: t.exponential_(), (torch.randn(B0, 1),)),
+            (lambda t: t.geometric_(0.5), (torch.randn(B0, 1),)),
+            (lambda t: t.log_normal_(), (torch.randn(B0, 1),)),
+            (lambda t: t.normal_(), (torch.randn(B0, 1),)),
+            (lambda t: t.random_(), (torch.randn(B0, 1),)),
+            (lambda t: t.random_(0, 2), (torch.randn(B0, 1),)),
+            (lambda t: t.random_(2), (torch.randn(B0, 1),)),
+            (lambda t: t.uniform_(), (torch.randn(B0, 1),)),
+
+            # in-place on captured tensor
+            (lambda t: captured.bernoulli_(), (torch.randn(B0),)),
+            (lambda t: captured.cauchy_(), (torch.randn(B0),)),
+            (lambda t: captured.exponential_(), (torch.randn(B0),)),
+            (lambda t: captured.geometric_(0.5), (torch.randn(B0),)),
+            (lambda t: captured.log_normal_(), (torch.randn(B0),)),
+            (lambda t: captured.normal_(), (torch.randn(B0),)),
+            (lambda t: captured.random_(), (torch.randn(B0),)),
+            (lambda t: captured.random_(0, 2), (torch.randn(B0),)),
+            (lambda t: captured.random_(2), (torch.randn(B0),)),
+            (lambda t: captured.uniform_(), (torch.randn(B0),)),
+
+            # factory functions
+            (lambda t: torch.rand(1), (torch.randn(B0),)),
+            (lambda t: torch.randn(1), (torch.randn(B0),)),
+            (lambda t: torch.randint(5, [1]), (torch.randn(B0),)),
+            (lambda t: torch.randperm(5), (torch.randn(B0),)),
+        ]
+        for op, args in random_ops:
+            with self.assertRaisesRegex(RuntimeError,
+                                        'vmap: We do not yet support calling random operations'):
+                vmap(op)(*args)
+
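+# Constructs a batch of `batch_size` random cotangent vectors that match
+# `output`'s per-sample shape, dtype, and device; these are fed to
+# torch.autograd.grad as grad_outputs in the batched-gradient tests below.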
+def construct_v(output, batch_size):
+    return torch.randn(batch_size, *output.shape,
+                       dtype=output.dtype, device=output.device)
+
+def as_tuple(x):
+    if isinstance(x, tuple):
+        return x
+    elif isinstance(x, list):
+        return tuple(x)
+    else:
+        return x,
+
+def differentiable(args):
+    return tuple(arg for arg in as_tuple(args)
+                 if isinstance(arg, torch.Tensor) and arg.requires_grad)
+
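+# Like torch.rand, but clamps values away from zero so that ops like division
+# and log have well-conditioned gradients in the tests below.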
+def _get_rand_no_zeros(*args, **kwargs):
+    requires_grad = kwargs.get('requires_grad', False)
+    kwargs_without_requires_grad = kwargs.copy()
+    kwargs_without_requires_grad['requires_grad'] = False
+    result = torch.rand(*args, **kwargs_without_requires_grad)
+    return result.clamp_min_(0.1).requires_grad_(requires_grad)
+
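+# Purely illustrative sketch (not used by the tests): _batched_grad_test in the
+# class below vmaps a vector-Jacobian product over a batch of cotangent vectors
+# and checks it against a sequential map+stack computation along these lines.
+# The helper name here is ours, not part of the test suite.
+def _reference_batched_vjp(outputs, inputs, batched_vectors):
+    batch_size = batched_vectors[0].shape[0]
+    # One VJP per batch index, then stack the per-sample grads along a new
+    # leading batch dimension (one stacked tensor per input).
+    per_sample = [
+        torch.autograd.grad(outputs, inputs,
+                            tuple(v[i] for v in batched_vectors),
+                            retain_graph=True)
+        for i in range(batch_size)
+    ]
+    return tuple(torch.stack(grads) for grads in zip(*per_sample))
+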
+class TestVmapBatchedGradient(Namespace.TestVmapBase):
+    def _vmap_test(self, *args, **kwargs):
+        return _vmap_test(self, *args, **kwargs)
+
+    # Tests batched gradient computation of outputs = op(*args, **kwargs)
+    # by comparing it to a sequential map+stack fallback.
+    #
+    # output_process_fn: a function that maps the outputs to the part
+    #       that should be differentiated.
+    # batch_size: the batch dim size for the batched grad
+    def _batched_grad_test(self, op, args, kwargs=None, output_process_fn=lambda x: x, batch_size=3):
+        if kwargs is None:
+            kwargs = {}
+        outputs = op(*args, **kwargs)
+        outputs = differentiable(output_process_fn(outputs))
+        batched_vectors = tuple(construct_v(out, batch_size) for out in outputs)
+
+        def vector_jacobian_product(*vectors):
+            return torch.autograd.grad(outputs, differentiable(args), vectors,
+                                       retain_graph=True)
+        self._vmap_test(vector_jacobian_product, batched_vectors,
+                        check_propagates_grad=False)
+
+    # Tests batched second-order gradient computation of outputs = op(*args, **kwargs)
+    # by comparing it to a sequential map+stack fallback.
+    #
+    # output_process_fn: a function that maps the outputs to the part
+    #       that should be differentiated.
+    # batch_size: the batch dim size for the batched grad
+    #
+    # NB: we only test computing batched gradients in the second gradient
+    # computation. One specific use case is computing the Hessian matrix of a
+    # scalar-valued function; this is useful in Bayesian logistic regression
+    # (an illustrative sketch appears after this class).
+    # It might be useful to have a test that computes batched first gradients and
+    # then uses those to compute batched second gradients in the future.
+    def _batched_grad_grad_test(self, op, args, kwargs=None, output_process_fn=lambda x: x, batch_size=3):
+        if kwargs is None:
+            kwargs = {}
+        outputs = op(*args, **kwargs)
+        outputs = differentiable(output_process_fn(outputs))
+        ones = tuple(torch.ones_like(out) for out in outputs)
+        # Same thing as summing together all of the outputs and calling .backward()
+        first_grads = torch.autograd.grad(outputs, differentiable(args), ones,
+                                          create_graph=True)
+        first_grads = differentiable(first_grads)
+        self.assertNotEqual(
+            len(first_grads), 0, "None of the first grads depend on the input!")
+
+        batched_vectors = tuple(construct_v(grad, batch_size) for grad in first_grads)
+
+        def vector_hessian_product(*vectors):
+            outputs = torch.autograd.grad(first_grads, differentiable(args), vectors,
+                                          retain_graph=True, allow_unused=True)
+            outputs = tuple(out for out in outputs if out is not None)
+            assert len(outputs) > 0
+            return outputs
+
+        self._vmap_test(vector_hessian_product, batched_vectors,
+                        check_propagates_grad=False)
+
+    def _test_arithmetic(self, op, device, test_grad_grad=True):
+        x = torch.randn(2, 3, requires_grad=True, device=device)
+        y = _get_rand_no_zeros(2, 3, device=device, requires_grad=True)
+        scalar = 3.14
+        self._batched_grad_test(op, (x, y))
+        self._batched_grad_test(op, (scalar, y))
+        self._batched_grad_test(op, (x, scalar))
+
+        if test_grad_grad:
+            self._batched_grad_grad_test(op, (x, y))
+
+    def test_add(self, device):
+        self._test_arithmetic(torch.add, device, test_grad_grad=False)
+        self._test_arithmetic(lambda x, y: x + y, device, test_grad_grad=False)
+
+    def test_sub(self, device):
+        self._test_arithmetic(torch.sub, device, test_grad_grad=False)
+        self._test_arithmetic(lambda x, y: x - y, device, test_grad_grad=False)
+
+    def test_mul(self, device):
+        self._test_arithmetic(torch.mul, device)
+        self._test_arithmetic(lambda x, y: x * y, device)
+
+    def test_div(self, device):
+        self._test_arithmetic(torch.div, device)
+        self._test_arithmetic(lambda x, y: x / y, device)
+
+    @allowVmapFallbackUsage
+    def test_binary_cross_entropy(self, device):
+        x = torch.sigmoid(torch.randn(3, 2, device=device, requires_grad=True))
+        target = torch.rand(3, 2, device=device)
+
+        op = functools.partial(F.binary_cross_entropy, target=target)
+
+        self._batched_grad_test(op, (x,), {})
+        self._batched_grad_grad_test(op, (x,), {})
+
+    def test_expand(self, device):
+        x = torch.randn(2, 3, device=device, requires_grad=True)
+
+        def op(x):
+            return x.expand(5, 5, 2, 3)
+        self._batched_grad_test(op, (x,))
+
+    @allowVmapFallbackUsage
+    def test_index(self, device):
+        x = torch.randn(2, 3, requires_grad=True, device=device)
+        index = torch.tensor([[0, 0], [1, 1]], device=device)
+
+        def op(x):
+            y = x * x
+            return y[index]
+
+        self._batched_grad_test(op, (x,))
+        self._batched_grad_grad_test(op, (x,))
+
+    def test_lgamma(self, device):
+        x = torch.randn(2, 3, requires_grad=True, device=device)
+        self._batched_grad_test(Tensor.lgamma, (x,))
+        self._batched_grad_grad_test(Tensor.lgamma, (x,))
+
+    def test_log(self, device):
+        x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True)
+        self._batched_grad_test(torch.log, (x,))
+        self._batched_grad_grad_test(torch.log, (x,))
+
+    def test_logsumexp(self, device):
+        x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True)
+
+        def op(x):
+            return torch.logsumexp(x, -1)
+
+        self._batched_grad_test(op, (x,))
+        self._batched_grad_grad_test(op, (x,))
+
+    def test_log1p(self, device):
+        x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True)
+        self._batched_grad_test(torch.log1p, (x,))
+        self._batched_grad_grad_test(torch.log1p, (x,))
+
+    @allowVmapFallbackUsage
+    def test_max(self, device):
+        x = torch.randn(2, 3, requires_grad=True, device=device)
+        self._batched_grad_test(torch.max, (x,))
+
+    @allowVmapFallbackUsage
+    def test_median(self, device):
+        x = torch.randn(2, 3, requires_grad=True, device=device)
+        self._batched_grad_test(torch.median, (x,))
+
+    @allowVmapFallbackUsage
+    def test_min(self, device):
+        x = torch.randn(2, 3, requires_grad=True, device=device)
+        self._batched_grad_test(torch.min, (x,))
+
+    def test_permute(self, device):
+        x = torch.randn(2, 3, 5, requires_grad=True, device=device)
+
+        def op(x):
+            return x.permute(2, 0, 1)
+
+        self._batched_grad_test(op, (x,))
+
+    def test_reshape(self, device):
+        x = torch.randn(2, 3, 5, requires_grad=True, device=device)
+
+        def op(x):
+            return x.reshape([2 * 3, 5])
+
+        self._batched_grad_test(op, (x,))
+
+    def test_sigmoid(self, device):
+        x = torch.randn(2, 3, requires_grad=True, device=device)
+        self._batched_grad_test(Tensor.sigmoid, (x,))
+        self._batched_grad_grad_test(Tensor.sigmoid, (x,))
+
+    def test_stack(self, device):
+        x = torch.randn(2, 3, device=device, requires_grad=True)
+        y = torch.randn(2, 3, device=device, requires_grad=True)
+
+        def op(x, y):
+            return torch.stack([x, y])
+        self._batched_grad_test(op, (x, y))
+
+    def test_select(self, device):
+        x = torch.randn(2, 3, device=device, requires_grad=True)
+        self._batched_grad_test(lambda x: x[1], (x,))
+        self._batched_grad_test(lambda x: x.select(1, 2), (x,))
+        self._batched_grad_test(lambda x: x.select(-1, 0), (x,))
+
+    def test_slice(self, device):
+        x = torch.randn(2, 3, 5, device=device, requires_grad=True)
+        self._batched_grad_test(lambda x: x[0:1], (x,))
+        self._batched_grad_test(lambda x: x[:, 1:3], (x,))
+        self._batched_grad_test(lambda x: x[..., 1:3], (x,))
+
+    def test_trace(self, device):
+        x = torch.randn(2, 3, device=device, requires_grad=True)
+        self._batched_grad_test(Tensor.trace, (x,))
+
+    @skipCUDAIfNoMagma
+    @allowVmapFallbackUsage
+    def test_symeig(self, device):
+        def op(x):
+            return torch.symeig(x, eigenvectors=True)[0]
+
+        x = torch.randn(3, 3, device=device, requires_grad=True)
+        self._batched_grad_test(op, (x,), {})
+        self._batched_grad_grad_test(op, (x,), {})
+
+    def test_threshold(self, device):
+        x = torch.randn(2, 3, device=device, requires_grad=True)
+        self._batched_grad_test(lambda x: F.threshold(x, 0.5, 0.0), (x,))
+
+    @allowVmapFallbackUsage
+    def test_inplace_view(self, device):
+        leaf = torch.randn(4, 5, requires_grad=True)
+
+        def func(leaf):
+            # Make sure the function is non-trivially twice differentiable
+            base = leaf * leaf
+            view = base[0]
+            view.cos_()
+            return view
+
+        self._batched_grad_test(func, (leaf,), {})
+        self._batched_grad_grad_test(func, (leaf,), {})
+
+    @allowVmapFallbackUsage
+    def test_inplace_manyview(self, device):
+        leaf = torch.randn(4, 4, 5, requires_grad=True)
+
+        def func(leaf):
+            # Make sure the function is non-trivially twice differentiable
+            base = leaf * leaf
+            view = base.transpose(0, 2)
+            view = view[1]
+            view = view.diagonal()
+            view = view[::2]
+            view.cos_()
+            return view
+
+        self._batched_grad_test(func, (leaf,), {})
+        self._batched_grad_grad_test(func, (leaf,), {})
+
+    def test_diagonal(self, device):
+        x = torch.randn(4, 5, device=device, requires_grad=True)
+        self._batched_grad_test(lambda x: x.diagonal(1, 0, 1), (x,))
+
+        x = torch.randn(3, 4, 5, device=device, requires_grad=True)
+        self._batched_grad_test(lambda x: x.diagonal(0, -1, -2), (x,))
+
+    @allowVmapFallbackUsage
+    def test_unrelated_output(self, device):
+        B0 = 3
+        x = torch.randn([], requires_grad=True)
+        y = torch.randn([], requires_grad=True)
+        gy = torch.randn(B0, requires_grad=True)
+
+        def vjp(v):
+            res, = torch.autograd.grad(y, x, v, allow_unused=True)
+            return torch.zeros_like(x) if res is None else res
+
+        result = vmap(vjp)(gy)
+        self.assertEqual(result, torch.zeros(B0, *x.shape, device=device))
+
+    @allowVmapFallbackUsage
+    def test_unrelated_output_multiple_grad(self, device):
+        B0 = 3
+        x = torch.randn([], requires_grad=True)
+        y = torch.randn([], requires_grad=True)
+        gy = torch.randn(B0, requires_grad=True)
+
+        def vjp(v):
+            res, = torch.autograd.grad(y, x, v, allow_unused=True)
+            return torch.zeros_like(x) if res is None else res
+
+        _ = vjp(gy[0])
+        result = vmap(vjp)(gy)
+        self.assertEqual(result, torch.zeros(B0, *x.shape, device=device))
+
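+# Purely illustrative sketch (not used above): _batched_grad_grad_test mentions
+# computing the Hessian of a scalar-valued function as a use case for batched
+# second gradients. Under that assumption, one way to do it is to vmap a
+# vector-Hessian product over the rows of an identity matrix. The helper name
+# is ours, not part of the test suite.
+def _hessian_via_vmap_sketch(f, x):
+    x = x.detach().requires_grad_(True)
+    out = f(x)  # f is assumed to return a scalar
+    grad, = torch.autograd.grad(out, x, create_graph=True)
+
+    def vector_hessian_product(v):
+        result, = torch.autograd.grad(grad.flatten(), x, v, retain_graph=True)
+        return result
+
+    basis = torch.eye(x.numel(), dtype=x.dtype, device=x.device)
+    return vmap(vector_hessian_product)(basis).reshape(x.numel(), x.numel())
+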
+instantiate_device_type_tests(
+    TestVmapBatchedGradient,
+    globals(),
+    None,
+)
+
+if __name__ == '__main__':
+    run_tests()