Add support for stochastic functions in autograd (#294)
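Stochastic functions (bernoulli, multinomial, normal) return samples that are marked non-differentiable; a reward has to be attached to every sampled output with reinforce() before backward() can be called through it. A minimal usage sketch (variable names are illustrative):

    import torch
    from torch.autograd import Variable

    probs = Variable(torch.rand(10), requires_grad=True)
    samples = probs.bernoulli()          # stochastic, non-differentiable output
    noisy = torch.normal(samples, 2)     # a second stochastic node downstream

    samples.reinforce(torch.randn(10))   # every stochastic output needs a reward
    noisy.reinforce(torch.randn(10))
    noisy.backward()                     # REINFORCE gradients accumulate in probs.grad
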
diff --git a/test/test_autograd.py b/test/test_autograd.py
index cefe8d7..b578874 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -457,6 +457,54 @@
y.sum().backward()
self.assertEqual(x.grad, x.data.clone().fill_(1))
+ def test_stochastic(self):
+ x = Variable(torch.rand(10), requires_grad=True)
+ stddevs = Variable(torch.rand(10) * 5, requires_grad=True)
+ y = (x * 2).clamp(0, 1)
+ y = y / y.sum().expand_as(y)
+ samples_multi = y.multinomial(5)
+ samples_bernoulli = y.bernoulli()
+ samples_norm = torch.normal(y)
+ samples_norm_std = torch.normal(y, stddevs)
+ z = samples_multi * 2 + 4
+ z = torch.cat([z, z])
+ z = z.double()
+ z = z + samples_bernoulli + samples_norm + samples_norm_std
+ last_sample = torch.normal(z, 4)
+ z = last_sample + 2
+ self.assertFalse(z.requires_grad)
+
+ self.assertRaises(RuntimeError, lambda: z.backward(retain_variables=True))
+ samples_multi.reinforce(torch.randn(5))
+ self.assertRaises(RuntimeError, lambda: z.backward(retain_variables=True))
+ samples_bernoulli.reinforce(torch.randn(10))
+ self.assertRaises(RuntimeError, lambda: z.backward(retain_variables=True))
+ samples_norm.reinforce(torch.randn(10))
+ self.assertRaises(RuntimeError, lambda: z.backward(retain_variables=True))
+ samples_norm_std.reinforce(torch.randn(10))
+ self.assertRaises(RuntimeError, lambda: z.backward(retain_variables=True))
+ last_sample.reinforce(torch.randn(10))
+
+ last_sample.backward(retain_variables=True)
+ z.backward()
+
+ self.assertGreater(x.grad.abs().sum(), 0)
+
+ def test_stochastic_sequence(self):
+ x = Variable(torch.rand(10), requires_grad=True)
+ b = x.bernoulli()
+ n1 = torch.normal(b, x)
+ n2 = torch.normal(n1, 2)
+
+ b.reinforce(torch.randn(10))
+ n1.reinforce(torch.randn(10))
+ n2.reinforce(torch.randn(10))
+
+ n2.backward()
+
+ self.assertGreater(x.grad.abs().sum(), 0)
+
+
def index_variable(num_indices, max_indices):
index = torch.randperm(max_indices)[:num_indices].long()
diff --git a/test/test_torch.py b/test/test_torch.py
index 63649f8..2532ab3 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -2403,6 +2403,47 @@
t.bernoulli_(p)
self.assertTrue(isBinary(t))
+ q = torch.rand(5, 5)
+ self.assertTrue(isBinary(q.bernoulli()))
+
+ def test_normal(self):
+ q = torch.Tensor(50, 50)
+ q.normal_()
+ self.assertEqual(q.mean(), 0, 0.1)
+ self.assertEqual(q.std(), 1, 0.1)
+
+ q.normal_(2, 3)
+ self.assertEqual(q.mean(), 2, 0.1)
+ self.assertEqual(q.std(), 3, 0.1)
+
+ mean = torch.Tensor(100, 100)
+ std = torch.Tensor(100, 100)
+ mean[:50] = 0
+ mean[50:] = 1
+ std[:,:50] = 4
+ std[:,50:] = 1
+
+ r = torch.normal(mean)
+ self.assertEqual(r[:50].mean(), 0, 0.2)
+ self.assertEqual(r[50:].mean(), 1, 0.2)
+ self.assertEqual(r.std(), 1, 0.2)
+
+ r = torch.normal(mean, 3)
+ self.assertEqual(r[:50].mean(), 0, 0.2)
+ self.assertEqual(r[50:].mean(), 1, 0.2)
+ self.assertEqual(r.std(), 3, 0.2)
+
+ r = torch.normal(2, std)
+ self.assertEqual(r.mean(), 2, 0.2)
+ self.assertEqual(r[:,:50].std(), 4, 0.2)
+ self.assertEqual(r[:,50:].std(), 1, 0.2)
+
+ r = torch.normal(mean, std)
+ self.assertEqual(r[:50].mean(), 0, 0.2)
+ self.assertEqual(r[50:].mean(), 1, 0.2)
+ self.assertEqual(r[:,:50].std(), 4, 0.2)
+ self.assertEqual(r[:,50:].std(), 1, 0.2)
+
def test_serialization(self):
a = [torch.randn(5, 5).float() for i in range(2)]
b = [a[i % 2] for i in range(4)]
diff --git a/torch/autograd/__init__.py b/torch/autograd/__init__.py
index 638f54a..c681189 100644
--- a/torch/autograd/__init__.py
+++ b/torch/autograd/__init__.py
@@ -2,5 +2,6 @@
from .variable import Variable
from .function import Function, NestedIOFunction
+from .stochastic_function import StochasticFunction
assert torch._C._autograd_init()
diff --git a/torch/autograd/function.py b/torch/autograd/function.py
index 18fa27a..53014ee 100644
--- a/torch/autograd/function.py
+++ b/torch/autograd/function.py
@@ -44,6 +44,7 @@
super(InplaceFunction, self).__init__()
self.inplace = inplace
+
def _nested_map(condition, fn):
def _map(obj):
if condition(obj):
diff --git a/torch/autograd/functions/__init__.py b/torch/autograd/functions/__init__.py
index 9527f62..9f82344 100644
--- a/torch/autograd/functions/__init__.py
+++ b/torch/autograd/functions/__init__.py
@@ -4,4 +4,5 @@
from .reduce import *
from .linalg import *
from .blas import *
+from .stochastic import *
diff --git a/torch/autograd/functions/stochastic.py b/torch/autograd/functions/stochastic.py
new file mode 100644
index 0000000..c03307e
--- /dev/null
+++ b/torch/autograd/functions/stochastic.py
@@ -0,0 +1,83 @@
+from ..stochastic_function import StochasticFunction
+
+# Gradient formulas are based on "Simple Statistical Gradient-Following
+# Algorithms for Connectionist Reinforcement Learning" (Williams, 1992),
+# available at http://incompleteideas.net/sutton/williams-92.pdf
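+#
+# Each backward() below receives the reward passed to Variable.reinforce()
+# in place of a gradient, and returns -reward * d(log p(sample)) / d(input),
+# so that a plain gradient descent step moves towards higher expected reward.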
+
+
+class Multinomial(StochasticFunction):
+
+ def __init__(self, num_samples):
+ super(Multinomial, self).__init__()
+ self.num_samples = num_samples
+
+ def forward(self, probs):
+ samples = probs.multinomial(self.num_samples)
+ self.save_for_backward(probs, samples)
+ self.mark_non_differentiable(samples)
+ return samples
+
+ def backward(self, reward):
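+ # REINFORCE: the gradient wrt the probability of each drawn index is
+ # -reward / probs[index]; all other entries get zero gradient.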
+ probs, samples = self.saved_tensors
+ grad_probs = probs.new().resize_as_(probs).zero_()
+ output_probs = probs.index_select(0, samples)
+ output_probs.add_(1e-6).cinv_()
+ output_probs.neg_().mul_(reward)
+ grad_probs.index_add_(0, samples, output_probs)
+ return grad_probs
+
+
+class Bernoulli(StochasticFunction):
+
+ def forward(self, probs):
+ samples = probs.new().resize_as_(probs).bernoulli_(probs)
+ self.save_for_backward(probs, samples)
+ self.mark_non_differentiable(samples)
+ return samples
+
+ def backward(self, reward):
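+ # REINFORCE: -reward * d(log p(samples))/d(probs)
+ # = reward * (probs - samples) / (probs * (1 - probs))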
+ probs, samples = self.saved_tensors
+ rev_probs = probs.neg().add_(1)
+ return (probs - samples) / (probs * rev_probs + 1e-6) * reward
+
+
+class Normal(StochasticFunction):
+
+ def __init__(self, stddev=None):
+ super(Normal, self).__init__()
+ self.stddev = stddev
+ assert stddev is None or stddev > 0
+
+ def forward(self, means, stddevs=None):
+ output = means.new().resize_as_(means)
+ output.normal_()
+ if self.stddev is not None:
+ output.mul_(self.stddev)
+ elif stddevs is not None:
+ output.mul_(stddevs)
+ else:
+ raise RuntimeError("Normal function requires specifying either a "
+ "common stddev or per-sample stddevs")
+ output.add_(means)
+ self.save_for_backward(output, means, stddevs)
+ self.mark_non_differentiable(output)
+ return output
+
+ def backward(self, reward):
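+ # REINFORCE for a Gaussian sample x:
+ # -reward * d(log N(x; mean, std))/d(mean) = reward * (mean - x) / std^2
+ # -reward * d(log N(x; mean, std))/d(std) = reward * (std^2 - (x - mean)^2) / std^3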
+ output, means, stddevs = self.saved_tensors
+ grad_stddevs = None
+ grad_means = means - output # == -(output - means)
+ assert self.stddev is not None or stddevs is not None
+ if self.stddev is not None:
+ grad_means /= 1e-6 + self.stddev ** 2
+ else:
+ stddevs_sq = stddevs * stddevs
+ stddevs_cb = stddevs_sq * stddevs
+ stddevs_sq += 1e-6
+ stddevs_cb += 1e-6
+ grad_stddevs = (stddevs_sq - grad_means * grad_means) / stddevs_cb
+ grad_stddevs *= reward
+ grad_means /= stddevs_sq
+ grad_means *= reward
+ return grad_means, grad_stddevs
+
diff --git a/torch/autograd/stochastic_function.py b/torch/autograd/stochastic_function.py
new file mode 100644
index 0000000..74d5982
--- /dev/null
+++ b/torch/autograd/stochastic_function.py
@@ -0,0 +1,21 @@
+from .function import Function
+
+_NOT_PROVIDED = object()
+
+class StochasticFunction(Function):
+
+ def __init__(self):
+ self.reward = _NOT_PROVIDED
+
+ def _do_backward(self, grad_output, retain_variables):
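+ # grad_output from downstream differentiable ops is ignored here; the
+ # reward attached via reinforce() is passed to backward() instead.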
+ if self.reward is _NOT_PROVIDED:
+ raise RuntimeError("differentiating stochastic functions requires "
+ "providing a reward")
+ result = super(StochasticFunction, self)._do_backward((self.reward,), retain_variables)
+ if not retain_variables:
+ self.reward = None
+ return result
+
+ def _reinforce(self, reward):
+ self.reward = reward
+
diff --git a/torch/autograd/variable.py b/torch/autograd/variable.py
index b6d5d9cc..73ae8cf 100644
--- a/torch/autograd/variable.py
+++ b/torch/autograd/variable.py
@@ -83,9 +83,7 @@
def backward(self, gradient=None, retain_variables=False):
if self.volatile:
raise RuntimeError('calling backward on a volatile variable')
- if not self.requires_grad:
- raise RuntimeError("calling backward on a variable that doesn't require gradient")
- if gradient is None:
+ if gradient is None and self.requires_grad:
if self.data.numel() != 1:
raise RuntimeError('backward should be called only on a scalar (i.e. 1-element tensor) or with gradient w.r.t. the variable')
gradient = self.data.new(1).fill_(1)
@@ -127,6 +125,12 @@
self.grad.add_(unpacked_grad)
return tuple()
+ def reinforce(self, reward):
+ if not isinstance(self.creator, StochasticFunction):
+ raise RuntimeError("reinforce() can only be called on outputs "
+ "of stochastic functions")
+ self.creator._reinforce(reward)
+
def no_grad(self):
return NoGrad()(self)
@@ -555,6 +559,12 @@
def triu(self, diagonal_idx=0):
return Triu(diagonal_idx)(self)
+ def multinomial(self, num_samples=1):
+ return Multinomial(num_samples)(self)
+
+ def bernoulli(self):
+ return Bernoulli()(self)
+
def __add__(self, other):
return self.add(other)
__radd__ = __add__
@@ -608,6 +618,13 @@
return Concat(dim)(*iterable)
@staticmethod
+ def normal(means, stddev=1):
+ if isinstance(stddev, Variable):
+ return Normal()(means, stddev)
+ else:
+ return Normal(stddev)(means)
+
+ @staticmethod
def _blas(cls, args, inplace):
num_args = len(args)
alpha = beta = 1
diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp
index f87dc81..1de0791 100644
--- a/torch/csrc/autograd/engine.cpp
+++ b/torch/csrc/autograd/engine.cpp
@@ -33,29 +33,36 @@
size_t buffer_id;
};
+// used for the queue of nodes ready for processing
+using ready_queue_type = std::deque<std::pair<THPFunction *, grad_buffer_type>>;
// Computes graph dependencies (using a super simple topological sort)
-dependencies_type THPEngine_compute_dependencies(THPFunction *function)
+void THPEngine_compute_dependencies(THPFunction *function,
+ dependencies_type& dependencies, ready_queue_type& ready)
{
- dependencies_type dependencies;
std::set<THPFunction *> seen;
std::vector<THPFunction *> queue = {function};
while (queue.size() > 0) {
THPFunction *fn = queue.back(); queue.pop_back();
for (int i = 0; i < fn->num_inputs; i++) {
THPFunction *prev_fn = (THPFunction*)fn->previous_functions[i].get();
+ // Stochastic functions are ready for backward immediately
// We can ignore variables (their backprop is called every time we have
- // gradient ready) and functions that don't require gradient.
- if (THPVariable_Check((PyObject*)prev_fn) || !prev_fn->requires_grad)
+ // gradient ready).
+ if (THPVariable_Check((PyObject*)prev_fn))
continue;
- dependencies[prev_fn] += 1;
+ if (PyObject_IsInstance((PyObject*)prev_fn, THPStochasticFunctionClass) &&
+ seen.count(prev_fn) == 0) {
+ ready.emplace_back(prev_fn, grad_buffer_type(0));
+ } else {
+ dependencies[prev_fn] += 1;
+ }
if (seen.count(prev_fn) == 0) {
seen.insert(prev_fn);
queue.push_back(prev_fn);
}
}
}
- return dependencies;
}
// Frees backward dependency and returns true if prev_fn is ready for backward
@@ -117,17 +124,24 @@
Py_RETURN_NONE;
}
- std::deque<std::pair<THPFunction *, grad_buffer_type>> ready;
+ ready_queue_type ready;
std::unordered_map<THPFunction *, grad_buffer_type> not_ready;
+ dependencies_type dependencies;
buffer_set_type need_copy;
// Initialize the queue
- grad_buffer_type buf(next_buf_id++, ((THPFunction*)variable->creator)->num_outputs);
- Py_INCREF(grad_variable);
- buf[variable->output_nr] = grad_variable;
- ready.emplace_front((THPFunction*)variable->creator, std::move(buf));
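+ // The starting variable may be the output of a stochastic function, in which
+ // case it doesn't require grad itself, but its creator still has to be seeded.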
+ if (variable->requires_grad ||
+ PyObject_IsInstance(variable->creator, THPStochasticFunctionClass)) {
+ grad_buffer_type buf(next_buf_id++, ((THPFunction*)variable->creator)->num_outputs);
+ Py_INCREF(grad_variable);
+ buf[variable->output_nr] = grad_variable;
+ ready.emplace_front((THPFunction*)variable->creator, std::move(buf));
+ }
- dependencies_type dependencies = THPEngine_compute_dependencies((THPFunction*)variable->creator);
+ THPEngine_compute_dependencies((THPFunction*)variable->creator, dependencies, ready);
+
+ THPUtils_assert(ready.size() > 0, "there are no graph nodes that require "
+ "computing gradients");
while (ready.size() > 0) {
std::pair<THPFunction *, grad_buffer_type> ready_pair =
diff --git a/torch/csrc/autograd/function.cpp b/torch/csrc/autograd/function.cpp
index 87fae15..269bd7f 100644
--- a/torch/csrc/autograd/function.cpp
+++ b/torch/csrc/autograd/function.cpp
@@ -19,6 +19,7 @@
PyObject *THPFunctionClass = NULL;
+PyObject *THPStochasticFunctionClass = NULL;
// Traverse and clear are required for supporting Python's GC cycle handling.
static int THPFunction_traverse(THPFunction *self, visitproc visit, void *arg)
@@ -392,7 +393,8 @@
_mark_dirty(self, t2var);
_wrap_outputs(self, t2var, raw_output, outputs);
_join_version_counters(self, t2var);
- if (self->requires_grad) {
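+ // Stochastic functions need their saved variables (and non-differentiable
+ // marks) to apply REINFORCE, even though their outputs don't require grad.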
+ if (self->requires_grad ||
+ PyObject_IsInstance((PyObject*)self, THPStochasticFunctionClass)) {
_save_variables(self, t2var);
_mark_non_differentiable(self, t2var);
}
diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h
index a10253c..0800d4a 100644
--- a/torch/csrc/autograd/function.h
+++ b/torch/csrc/autograd/function.h
@@ -54,6 +54,7 @@
bool THPFunction_initModule(PyObject *module);
extern PyObject *THPFunctionClass;
+extern PyObject *THPStochasticFunctionClass;
#define THPFunction_Check(obj) PyObject_IsInstance(obj, THPFunctionClass)
diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp
index 799415d..15442a3 100644
--- a/torch/csrc/autograd/init.cpp
+++ b/torch/csrc/autograd/init.cpp
@@ -12,10 +12,13 @@
THPVariableClass = PyMapping_GetItemString(autograd_dict,(char*)"Variable");
THPFunctionClass = PyMapping_GetItemString(autograd_dict,(char*)"Function");
+ THPStochasticFunctionClass = PyMapping_GetItemString(autograd_dict,(char*)"StochasticFunction");
THPUtils_assert(THPVariableClass, "couldn't find Variable class in "
"torch.autograd module");
THPUtils_assert(THPFunctionClass, "couldn't find Function class in "
"torch.autograd module");
+ THPUtils_assert(THPStochasticFunctionClass, "couldn't find "
+ "StochasticFunction class in torch.autograd module");
Py_RETURN_TRUE;
}
diff --git a/torch/csrc/generic/methods/TensorRandom.cwrap b/torch/csrc/generic/methods/TensorRandom.cwrap
index 4bd69ae..b204338 100644
--- a/torch/csrc/generic/methods/TensorRandom.cwrap
+++ b/torch/csrc/generic/methods/TensorRandom.cwrap
@@ -81,6 +81,65 @@
default: 1
]]
+#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
+static void THTensor_(normal_means)(THTensor *self, THGenerator *gen, THTensor *means, real stddev)
+{
+ THTensor_(resizeAs)(self, means);
+ THTensor_(normal)(self, gen, 0, stddev);
+ THTensor_(cadd)(self, self, 1, means);
+}
+
+static void THTensor_(normal_stddevs)(THTensor *self, THGenerator *gen, real mean, THTensor *stddevs)
+{
+ THTensor_(resizeAs)(self, stddevs);
+ THTensor_(normal)(self, gen, 0, 1);
+ THTensor_(cmul)(self, self, stddevs);
+ THTensor_(add)(self, self, mean);
+}
+
+static void THTensor_(normal_means_stddevs)(THTensor *self, THGenerator *gen, THTensor *means, THTensor *stddevs)
+{
+ THTensor_(resizeAs)(self, means);
+ THTensor_(normal)(self, gen, 0, 1);
+ THTensor_(cmul)(self, self, stddevs);
+ THTensor_(cadd)(self, self, 1, means);
+}
+#endif
+
+[[
+ name: normal
+ defined_if: defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
+ return: argument 0
+ only_stateless: True
+ options:
+ - cname: normal_means
+ arguments:
+ - arg: THTensor* output
+ allocate: True
+ - arg: THGenerator* generator
+ default: THPDefaultGenerator->cdata
+ - THTensor* means
+ - arg: real stddev
+ default: 1
+ - cname: normal_stddevs
+ arguments:
+ - arg: THTensor* output
+ allocate: True
+ - arg: THGenerator* generator
+ default: THPDefaultGenerator->cdata
+ - arg: real mean
+ default: 0
+ - THTensor* stddevs
+ - cname: normal_means_stddevs
+ arguments:
+ - arg: THTensor* output
+ allocate: True
+ - arg: THGenerator* generator
+ default: THPDefaultGenerator->cdata
+ - THTensor* means
+ - THTensor* stddevs
+]]
+
[[
name: normal_
defined_if: defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
@@ -195,6 +254,67 @@
default: 1
]]
+#if CUDA_FLOAT || CUDA_DOUBLE || CUDA_HALF
+static void THTensor_(normal_means)(THCState *state, THTensor *self, THTensor *means, double stddev)
+{
+ THTensor_(resizeAs)(LIBRARY_STATE self, means);
+ THTensor_(normal)(LIBRARY_STATE self, 0, stddev);
+ THTensor_(cadd)(LIBRARY_STATE self, self, AS_REAL(1), means);
+}
+
+static void THTensor_(normal_stddevs)(THCState *state, THTensor *self, double mean, THTensor *stddevs)
+{
+ THTensor_(resizeAs)(LIBRARY_STATE self, stddevs);
+ THTensor_(normal)(LIBRARY_STATE self, 0, 1);
+ THTensor_(cmul)(LIBRARY_STATE self, self, stddevs);
+ THTensor_(add)(LIBRARY_STATE self, self, AS_REAL(mean));
+}
+
+static void THTensor_(normal_means_stddevs)(THCState *state, THTensor *self, THTensor *means, THTensor *stddevs)
+{
+ THTensor_(resizeAs)(LIBRARY_STATE self, means);
+ THTensor_(normal)(LIBRARY_STATE self, 0, 1);
+ THTensor_(cmul)(LIBRARY_STATE self, self, stddevs);
+ THTensor_(cadd)(LIBRARY_STATE self, self, AS_REAL(1), means);
+}
+#endif
+
+[[
+ name: normal
+ defined_if: CUDA_FLOAT || CUDA_DOUBLE || CUDA_HALF
+ return: argument 0
+ only_stateless: True
+ options:
+ - cname: normal
+ arguments:
+ - arg: THTensor* output
+ allocate: True
+ - arg: double mean
+ default: 0
+ - arg: double stddev
+ default: 1
+ - cname: normal_means
+ arguments:
+ - arg: THTensor* output
+ allocate: True
+ - THTensor* means
+ - arg: double stddev
+ default: 1
+ - cname: normal_stddevs
+ arguments:
+ - arg: THTensor* output
+ allocate: True
+ - arg: double mean
+ default: 0
+ - THTensor* stddevs
+ - cname: normal_means_stddevs
+ arguments:
+ - arg: THTensor* output
+ allocate: True
+ - THTensor* means
+ - THTensor* stddevs
+]]
+
[[
name: normal_
defined_if: CUDA_FLOAT || CUDA_DOUBLE || CUDA_HALF
@@ -282,6 +402,28 @@
- double p
]]
+#define THDoubleTensor_BERNOULLI_TENSOR THDoubleTensor_bernoulli_DoubleTensor
+#define THFloatTensor_BERNOULLI_TENSOR THFloatTensor_bernoulli_FloatTensor
+
+[[
+ name: bernoulli
+ defined_if: defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
+ return: argument 0
+ with_stateless: True
+ before_call:
+ THTensor_(resizeAs)(LIBRARY_STATE ((THPTensor*)$arg0)->cdata, ((THPTensor*)$arg2)->cdata);
+ cname: BERNOULLI_TENSOR
+ arguments:
+ - arg: THTensor* output
+ allocate: True
+ - arg: THGenerator* generator
+ default: THPDefaultGenerator->cdata
+ - THTensor* self
+]]
+
+#undef THDoubleTensor_BERNOULLI_TENSOR
+#undef THFloatTensor_BERNOULLI_TENSOR
+
[[
name: bernoulli_
defined_if: "!IS_CUDA"
@@ -318,6 +460,26 @@
- double p
]]
+#define THCudaDoubleTensor_BERNOULLI_TENSOR THCudaDoubleTensor_bernoulli_DoubleTensor
+#define THCudaTensor_BERNOULLI_TENSOR THCudaTensor_bernoulli_FloatTensor
+
+[[
+ name: bernoulli
+ defined_if: CUDA_FLOAT || CUDA_DOUBLE
+ return: argument 0
+ with_stateless: True
+ cname: BERNOULLI_TENSOR
+ before_call:
+ THTensor_(resizeAs)(LIBRARY_STATE ((THPTensor*)$arg0)->cdata, ((THPTensor*)$arg1)->cdata);
+ arguments:
+ - arg: THTensor* output
+ allocate: True
+ - THTensor* self
+]]
+
+#undef THCudaDoubleTensor_BERNOULLI_TENSOR
+#undef THCudaTensor_BERNOULLI_TENSOR
+
[[
name: bernoulli_
defined_if: CUDA_FLOAT || CUDA_DOUBLE || CUDA_HALF