| r""" |
| The ``distributions`` package contains parameterizable probability distributions |
| and sampling functions. This allows the construction of stochastic computation |
| graphs and stochastic gradient estimators for optimization. This package |
| generally follows the design of the `TensorFlow Distributions`_ package. |
| |
| .. _`TensorFlow Distributions`: |
| https://arxiv.org/abs/1711.10604 |
| |
| It is not possible to directly backpropagate through random samples. However, |
| there are two main methods for creating surrogate functions that can be |
backpropagated through. These are the score function estimator (also known as
the likelihood ratio estimator or REINFORCE) and the pathwise derivative
estimator. REINFORCE is commonly seen as the basis for policy gradient methods
in reinforcement learning, and the pathwise derivative estimator is commonly
seen in the reparameterization trick in variational autoencoders. Whilst the
score function only requires the value of samples :math:`f(x)`, the pathwise
derivative requires the derivative :math:`f'(x)`. The next sections discuss
these two methods in a reinforcement learning example. For more details see
`Gradient Estimation Using Stochastic Computation Graphs`_.

.. _`Gradient Estimation Using Stochastic Computation Graphs`:
    https://arxiv.org/abs/1506.05254
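
To make the distinction concrete, the following is a minimal, self-contained
sketch (the tensor names are arbitrary) showing that gradients flow back to the
distribution parameters through
:meth:`~torch.distributions.Distribution.rsample` but not through
:meth:`~torch.distributions.Distribution.sample`::

    import torch
    from torch.distributions import Normal

    loc = torch.tensor(0.0, requires_grad=True)
    scale = torch.tensor(1.0, requires_grad=True)

    # Reparameterized sample: differentiable with respect to loc and scale.
    Normal(loc, scale).rsample().backward()
    print(loc.grad, scale.grad)  # both gradients are populated

    # Plain sample: drawn without tracking gradients, so it is detached from
    # the parameters; a surrogate loss built from log_prob is needed instead.
    x = Normal(loc, scale).sample()
    print(x.requires_grad)  # False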

Score function
^^^^^^^^^^^^^^

When the probability density function is differentiable with respect to its
parameters, we only need :meth:`~torch.distributions.Distribution.sample` and
:meth:`~torch.distributions.Distribution.log_prob` to implement REINFORCE:

.. math::

    \Delta\theta = \alpha r \frac{\partial\log p(a|\pi^\theta(s))}{\partial\theta}

where :math:`\theta` are the parameters, :math:`\alpha` is the learning rate,
:math:`r` is the reward and :math:`p(a|\pi^\theta(s))` is the probability of
taking action :math:`a` in state :math:`s` given policy :math:`\pi^\theta`.

In practice we would sample an action from the output of a network, apply this
action in an environment, and then use ``log_prob`` to construct an equivalent
loss function. Note that we use a negative because optimizers use gradient
descent, whilst the rule above assumes gradient ascent. With a categorical
policy, the code for implementing REINFORCE would be as follows::

    probs = policy_network(state)
    # Note that this is equivalent to what used to be called multinomial
    m = Categorical(probs)
    action = m.sample()
    next_state, reward = env.step(action)
    loss = -m.log_prob(action) * reward
    loss.backward()
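
In practice the surrogate loss above is combined with an optimizer step. The
following sketch fills in hypothetical stand-ins for the pieces that are not
part of this package (a toy ``policy_network`` and a constant ``reward`` in
place of a real environment transition)::

    import torch
    from torch.distributions import Categorical

    # Hypothetical policy: 4-dimensional observations, 2 discrete actions.
    policy_network = torch.nn.Sequential(
        torch.nn.Linear(4, 32),
        torch.nn.ReLU(),
        torch.nn.Linear(32, 2),
        torch.nn.Softmax(dim=-1),
    )
    optimizer = torch.optim.SGD(policy_network.parameters(), lr=1e-2)

    state = torch.randn(4)      # stand-in for an observation
    probs = policy_network(state)
    m = Categorical(probs)
    action = m.sample()
    reward = torch.tensor(1.0)  # stand-in for env.step(action)

    optimizer.zero_grad()
    loss = -m.log_prob(action) * reward
    loss.backward()
    optimizer.step()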

Pathwise derivative
^^^^^^^^^^^^^^^^^^^

The other way to implement these stochastic/policy gradients is to use the
reparameterization trick via the
:meth:`~torch.distributions.Distribution.rsample` method, where the
parameterized random variable can be constructed via a parameterized
deterministic function of a parameter-free random variable. The reparameterized
sample therefore becomes differentiable. The code for implementing the pathwise
derivative would be as follows::

    params = policy_network(state)
    m = Normal(*params)
    # Any distribution with .has_rsample == True could work based on the application
    action = m.rsample()
    next_state, reward = env.step(action)  # Assuming that reward is differentiable
    loss = -reward
    loss.backward()
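
To see concretely what it means to construct the sample as a deterministic
function of a parameter-free random variable, the sketch below builds a
reparameterized Normal sample by hand (illustrative only; whether a given
distribution supports reparameterized sampling is exposed through its
``has_rsample`` attribute)::

    import torch
    from torch.distributions import Normal

    loc = torch.tensor(0.0, requires_grad=True)
    scale = torch.tensor(1.0, requires_grad=True)

    eps = torch.randn_like(loc)   # parameter-free standard normal noise
    sample = loc + eps * scale    # deterministic, differentiable transform
    sample.backward()             # gradients reach loc and scale, not eps
    print(loc.grad, scale.grad)

    # Distributions advertise support for this via has_rsample.
    print(Normal(loc, scale).has_rsample)  # True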
| """ |
| |
| from .bernoulli import Bernoulli |
| from .beta import Beta |
| from .binomial import Binomial |
| from .categorical import Categorical |
| from .cauchy import Cauchy |
| from .chi2 import Chi2 |
| from .constraint_registry import biject_to, transform_to |
| from .continuous_bernoulli import ContinuousBernoulli |
| from .dirichlet import Dirichlet |
| from .distribution import Distribution |
| from .exp_family import ExponentialFamily |
| from .exponential import Exponential |
| from .fishersnedecor import FisherSnedecor |
| from .gamma import Gamma |
| from .geometric import Geometric |
| from .gumbel import Gumbel |
| from .half_cauchy import HalfCauchy |
| from .half_normal import HalfNormal |
| from .independent import Independent |
| from .kl import kl_divergence, register_kl, _add_kl_info |
| from .kumaraswamy import Kumaraswamy |
| from .laplace import Laplace |
| from .lkj_cholesky import LKJCholesky |
| from .log_normal import LogNormal |
| from .logistic_normal import LogisticNormal |
| from .lowrank_multivariate_normal import LowRankMultivariateNormal |
| from .mixture_same_family import MixtureSameFamily |
| from .multinomial import Multinomial |
| from .multivariate_normal import MultivariateNormal |
| from .negative_binomial import NegativeBinomial |
| from .normal import Normal |
| from .one_hot_categorical import OneHotCategorical, OneHotCategoricalStraightThrough |
| from .pareto import Pareto |
| from .poisson import Poisson |
| from .relaxed_bernoulli import RelaxedBernoulli |
| from .relaxed_categorical import RelaxedOneHotCategorical |
| from .studentT import StudentT |
| from .transformed_distribution import TransformedDistribution |
| from .transforms import * # noqa: F403 |
| from .uniform import Uniform |
| from .von_mises import VonMises |
| from .weibull import Weibull |
| from .wishart import Wishart |
| from . import transforms |
| |
| _add_kl_info() |
| del _add_kl_info |
| |
| __all__ = [ |
| 'Bernoulli', |
| 'Beta', |
| 'Binomial', |
| 'Categorical', |
| 'Cauchy', |
| 'Chi2', |
| 'ContinuousBernoulli', |
| 'Dirichlet', |
| 'Distribution', |
| 'Exponential', |
| 'ExponentialFamily', |
| 'FisherSnedecor', |
| 'Gamma', |
| 'Geometric', |
| 'Gumbel', |
| 'HalfCauchy', |
| 'HalfNormal', |
| 'Independent', |
| 'Kumaraswamy', |
| 'LKJCholesky', |
| 'Laplace', |
| 'LogNormal', |
| 'LogisticNormal', |
| 'LowRankMultivariateNormal', |
| 'MixtureSameFamily', |
| 'Multinomial', |
| 'MultivariateNormal', |
| 'NegativeBinomial', |
| 'Normal', |
| 'OneHotCategorical', |
| 'OneHotCategoricalStraightThrough', |
| 'Pareto', |
| 'RelaxedBernoulli', |
| 'RelaxedOneHotCategorical', |
| 'StudentT', |
| 'Poisson', |
| 'Uniform', |
| 'VonMises', |
| 'Weibull', |
| 'Wishart', |
| 'TransformedDistribution', |
| 'biject_to', |
| 'kl_divergence', |
| 'register_kl', |
| 'transform_to', |
| ] |
| __all__.extend(transforms.__all__) |