blob: 229246b694a0a5cfa557e08b77600d7a9f6e65b4 [file] [log] [blame]
import torch
from ..function import Function
class Multinomial(Function):
    """Draw category indices from ``probs`` using ``Tensor.multinomial``.

    The sampled indices are discrete, so they are marked non-differentiable
    and no gradient flows back to any of the inputs.
    """

    @staticmethod
    def forward(ctx, probs, num_samples, with_replacement):
        """Return ``probs.multinomial(num_samples, with_replacement)``.

        Args:
            ctx: autograd context object.
            probs: tensor of (unnormalized) category weights, as accepted by
                ``Tensor.multinomial``.
            num_samples: number of indices to draw.
            with_replacement: whether an index may be drawn more than once.
        """
        indices = probs.multinomial(num_samples, with_replacement)
        # Sampled indices are not a differentiable function of probs.
        ctx.mark_non_differentiable(indices)
        return indices

    @staticmethod
    def backward(ctx, grad_output):
        """Return no gradient for each of the three forward inputs."""
        # One slot each for probs, num_samples and with_replacement.
        return (None,) * 3
class Bernoulli(Function):
    """Draw element-wise Bernoulli samples whose success rates are ``probs``.

    The output has the same shape as ``probs``. It is marked
    non-differentiable, so no gradient reaches ``probs``.
    """

    @staticmethod
    def forward(ctx, probs):
        """Return a tensor shaped like ``probs`` filled via ``bernoulli_(probs)``.

        Args:
            ctx: autograd context object.
            probs: tensor of per-element success probabilities.
        """
        # Allocate an empty tensor of the same type as probs, size it to match,
        # then fill it in place with Bernoulli draws parameterized by probs.
        draws = probs.new()
        draws.resize_as_(probs)
        draws.bernoulli_(probs)
        ctx.mark_non_differentiable(draws)
        return draws

    @staticmethod
    def backward(ctx, grad_output):
        """Return no gradient for ``probs``."""
        return None
class Normal(Function):
    """Draw samples from normal distributions centred on ``means``.

    The output is marked non-differentiable, so no gradient reaches
    ``means`` or ``stddevs``.
    """

    @staticmethod
    def forward(ctx, means, stddevs=None):
        """Return ``torch.normal`` samples for ``means`` / ``stddevs``.

        Args:
            ctx: autograd context object.
            means: tensor of per-element means.
            stddevs: per-element standard deviations, passed through to
                ``torch.normal``. When omitted (``None``), a standard
                deviation of 1.0 is used.
        """
        if stddevs is None:
            # None is not a valid std argument to torch.normal, so use the
            # single-tensor overload, whose std defaults to 1.0.
            samples = torch.normal(means)
        else:
            samples = torch.normal(means, stddevs)
        ctx.mark_non_differentiable(samples)
        return samples

    @staticmethod
    def backward(ctx, grad_output):
        """Return no gradient for either forward input."""
        return None, None