blob: 4478c251e0202b8331c1e5fb3d542743373e2155 [file] [log] [blame]
import torch
from .Module import Module
class Transpose(Module):
    """Swap pairs of tensor dimensions, one pair after another.

    Each positional argument is a pair of dimension indices that is
    unpacked into ``Tensor.transpose``; in the forward pass the pairs are
    applied left to right.  For example::

        m = Transpose((0, 3), (0, 2))

    first swaps dimensions 0 and 3 of the input, then swaps dimensions
    0 and 2 of that intermediate result.  Dimension indices are passed
    to ``Tensor.transpose`` unchanged, so they follow its (0-based)
    convention.
    """

    def __init__(self, *args):
        super(Transpose, self).__init__()
        # Kept exactly as received: a tuple of index pairs.
        self.permutations = args

    def updateOutput(self, input):
        """Apply every stored transpose in order; result goes to ``self.output``."""
        view = input
        for dims in self.permutations:
            view = view.transpose(*dims)
        # Copy the (possibly non-contiguous) transposed view's values into
        # the module-owned output buffer, resized to the view's shape.
        self.output.resize_as_(view).copy_(view)
        return self.output

    def updateGradInput(self, input, gradOutput):
        """Map ``gradOutput`` back to the input's dimension layout.

        Swapping the same two dimensions twice is the identity, so replaying
        the stored pairs in reverse order undoes the forward pass.  ``input``
        is part of the module interface but is not read here.
        """
        grad = gradOutput
        for dims in reversed(self.permutations):
            grad = grad.transpose(*dims)
        self.gradInput.resize_as_(grad).copy_(grad)
        return self.gradInput