blob: b7253af8dd109aab555fc056c64cd1af4f4a4f23 [file] [log] [blame]
import torch
from .Module import Module
class Select(Module):
    """Pick a single slice of the input along one dimension.

    The forward pass copies ``input.select(dimension, index)`` into
    ``self.output``. The backward pass writes ``gradOutput`` into the
    selected slice of a zero-filled tensor shaped like ``input``. Every
    position outside that slice therefore receives zero gradient.

    Args:
        dimension: the dimension along which to select.
        index: the position to pick along ``dimension``. A negative value
            counts back from the end of that dimension (``-1`` is the last
            slice).
    """

    def __init__(self, dimension, index):
        super(Select, self).__init__()
        self.dimension = dimension
        self.index = index

    def _resolve_index(self, input):
        """Return ``self.index`` as a non-negative position into ``input``.

        Shared by the forward and backward passes so both always address
        the same slice.
        """
        if self.index >= 0:
            return self.index
        # Negative indices count back from the end of the selected dimension.
        return input.size(self.dimension) + self.index

    def updateOutput(self, input):
        """Copy the selected slice of ``input`` into ``self.output`` and return it."""
        index = self._resolve_index(input)
        selected = input.select(self.dimension, index)
        # self.output is assumed to be a tensor buffer set up by the Module
        # base class (defined outside this file). It is resized and filled in
        # place rather than rebound to the view, so it does not alias input.
        self.output.resizeAs_(selected)
        return self.output.copy_(selected)

    def updateGradInput(self, input, gradOutput):
        """Return a gradient shaped like ``input`` with ``gradOutput`` in the selected slice."""
        index = self._resolve_index(input)
        # self.gradInput is likewise assumed to come from the Module base class.
        self.gradInput.resizeAs_(input)
        # Positions that were not selected in the forward pass get zero gradient.
        self.gradInput.zero_()
        self.gradInput.select(self.dimension, index).copy_(gradOutput)
        return self.gradInput