import torch
from .Module import Module


class SoftMax(Module):

    def __init__(self, dim=None):
        super(SoftMax, self).__init__()
        if dim is not None:
            self.dim = dim

    def _get_dim(self, input):
        # Default to dim 0 for 1D/3D inputs and dim 1 for 2D/4D (batched)
        # inputs when no explicit dim was given at construction time.
        return getattr(self, 'dim', 0 if input.dim() == 1 or input.dim() == 3 else 1)

    def updateOutput(self, input):
        # Forward pass: the backend writes softmax(input) along the chosen
        # dimension into self.output.
        self._backend.SoftMax_updateOutput(
            self._backend.library_state,
            input,
            self.output,
            self._get_dim(input)
        )
        return self.output

    def updateGradInput(self, input, gradOutput):
        # Backward pass: the backend computes the gradient w.r.t. the input
        # from gradOutput and the cached forward output.
        self._backend.SoftMax_updateGradInput(
            self._backend.library_state,
            input,
            gradOutput,
            self.gradInput,
            self.output,
            self._get_dim(input)
        )
        return self.gradInput
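

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): assumes this file is the
# legacy torch.legacy.nn.SoftMax and that a legacy THNN backend is available,
# as in pre-1.0 PyTorch releases. forward() on a legacy Module dispatches to
# updateOutput().
#
#     import torch
#     from torch.legacy import nn
#
#     m = nn.SoftMax()            # dim defaults to 1 for a 2D (batched) input
#     x = torch.randn(4, 10)
#     y = m.forward(x)            # same as m.updateOutput(x)
#     print(y.sum(1))             # each row of the output sums to 1
# ---------------------------------------------------------------------------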