# blob: 014658a2657a81063d32601de7a46228524f187a [file] [log] [blame]
import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch._thnn import type2backend
from .thnn.auto import function_by_name
import torch.backends.cudnn as cudnn
def affine_grid_generator(theta, size):
    """Build an affine sampling grid, using cuDNN when it is applicable.

    The cuDNN kernel is used only when ``theta`` lives on a CUDA device,
    cuDNN is enabled and reports ``theta`` as acceptable, and ``size`` has
    exactly four entries ``(N, C, H, W)``. Every other case is delegated to
    the Python ``AffineGridGenerator`` autograd function.

    Args:
        theta: Tensor of affine parameters.
        size: Target output size; four entries select the spatial case.

    Returns:
        The grid produced by whichever implementation was selected.
    """
    # Keep the checks in this order: the later ones are only evaluated when
    # the earlier ones hold (short-circuit), exactly as a chained ``and``.
    use_cudnn = (
        theta.data.is_cuda
        and cudnn.enabled
        and cudnn.is_acceptable(theta.data)
        and len(size) == 4
    )
    if not use_cudnn:
        return AffineGridGenerator.apply(theta, size)

    N, C, H, W = size
    return torch.cudnn_affine_grid_generator(theta, N, C, H, W)
# TODO: Port these completely into C++
class AffineGridGenerator(Function):
    """Autograd function mapping affine parameters to a sampling grid.

    For a 4-entry ``size`` ``(N, C, H, W)`` the forward pass multiplies a
    base grid of homogeneous coordinates ``(x, y, 1)`` by ``theta`` (used as
    a batch of ``2 x 3`` matrices) and returns a grid of shape
    ``(N, H, W, 2)``. For a 5-entry ``size`` ``(N, C, D, H, W)`` the base
    grid holds ``(x, y, z, 1)``, ``theta`` is used as a batch of ``3 x 4``
    matrices, and the result has shape ``(N, D, H, W, 3)``.

    Coordinates along each axis are evenly spaced in ``[-1, 1]``; an axis of
    length 1 uses the single value ``-1``.
    """

    @staticmethod
    def _axis_coords(steps):
        """Return ``steps`` evenly spaced coordinates in ``[-1, 1]``.

        An axis with a single step collapses to the value ``-1``.
        """
        if steps > 1:
            return torch.linspace(-1, 1, steps)
        return torch.Tensor([-1])

    @staticmethod
    def forward(ctx, theta, size):
        """Compute the sampling grid for ``theta`` and the requested ``size``.

        Args:
            ctx: Autograd context; receives ``size``, ``is_cuda`` and the
                base grid for use in ``backward``.
            theta: Batch of affine matrices, ``(N, 2, 3)`` for 4-entry sizes
                or ``(N, 3, 4)`` for 5-entry sizes.
            size: ``torch.Size`` (or any sequence of ints convertible to
                one) with 4 or 5 entries.

        Returns:
            Grid of shape ``(N, H, W, 2)`` or ``(N, D, H, W, 3)``.

        Raises:
            TypeError: If ``size`` cannot be converted to ``torch.Size``.
            RuntimeError: If ``size`` has neither 4 nor 5 entries.
        """
        # Accept plain tuples/lists as well: the cuDNN path only unpacks
        # ``size`` and therefore already works with any sequence.
        if not isinstance(size, torch.Size):
            size = torch.Size(size)
        ctx.size = size
        ctx.is_cuda = theta.is_cuda
        coords = AffineGridGenerator._axis_coords

        if len(size) == 5:
            N, C, D, H, W = size
            base_grid = theta.new(N, D, H, W, 4)
            # x varies along W, y along H, z along D; the trailing
            # unsqueeze calls line each axis up for broadcasting.
            base_grid[:, :, :, :, 0] = coords(W)
            base_grid[:, :, :, :, 1] = coords(H).unsqueeze(-1)
            base_grid[:, :, :, :, 2] = coords(D).unsqueeze(-1).unsqueeze(-1)
            base_grid[:, :, :, :, 3] = 1
            grid = torch.bmm(base_grid.view(N, D * H * W, 4), theta.transpose(1, 2))
            grid = grid.view(N, D, H, W, 3)
        elif len(size) == 4:
            N, C, H, W = size
            base_grid = theta.new(N, H, W, 3)
            # x varies along W, y along H; last channel is the homogeneous 1.
            base_grid[:, :, :, 0] = coords(W)
            base_grid[:, :, :, 1] = coords(H).unsqueeze(-1)
            base_grid[:, :, :, 2] = 1
            grid = torch.bmm(base_grid.view(N, H * W, 3), theta.transpose(1, 2))
            grid = grid.view(N, H, W, 2)
        else:
            raise RuntimeError("AffineGridGenerator needs 4d (spatial) or 5d (volumetric) inputs.")

        # backward only needs the base grid: d(grid)/d(theta) is linear in it.
        ctx.base_grid = base_grid
        return grid

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_grid):
        """Return the gradient with respect to ``theta``.

        Args:
            ctx: Autograd context populated by ``forward``.
            grad_grid: Gradient of the loss with respect to the grid; must
                have the same shape as the grid returned by ``forward``.

        Returns:
            ``(grad_theta, None)``; ``size`` is not differentiable.
        """
        assert ctx.is_cuda == grad_grid.is_cuda
        base_grid = ctx.base_grid
        # The incoming gradient may be a non-contiguous view (for example
        # when the grid was transposed downstream); ``view`` below requires
        # a layout that ``contiguous`` guarantees.
        grad_grid = grad_grid.contiguous()

        if len(ctx.size) == 5:
            N, C, D, H, W = ctx.size
            assert grad_grid.size() == torch.Size([N, D, H, W, 3])
            grad_theta = torch.bmm(
                base_grid.view(N, D * H * W, 4).transpose(1, 2),
                grad_grid.view(N, D * H * W, 3))
        elif len(ctx.size) == 4:
            N, C, H, W = ctx.size
            assert grad_grid.size() == torch.Size([N, H, W, 2])
            grad_theta = torch.bmm(
                base_grid.view(N, H * W, 3).transpose(1, 2),
                grad_grid.view(N, H * W, 2))
        else:
            # forward rejects every other rank, so this cannot be reached
            # through normal use; raise explicitly so it survives ``-O``.
            raise AssertionError("unexpected size rank stored by forward")

        # bmm produced (N, 3, 2) / (N, 4, 3); theta is laid out transposed.
        grad_theta = grad_theta.transpose(1, 2)
        return grad_theta, None