import torch
from torch.distributions.distribution import Distribution


class ExponentialFamily(Distribution):
    r"""
    ExponentialFamily is the abstract base class for probability distributions belonging to an
    exponential family, whose probability mass/density function has the form

    .. math::

        p_{F}(x; \theta) = \exp(\langle t(x), \theta\rangle - F(\theta) + k(x))

    where :math:`\theta` denotes the natural parameters, :math:`t(x)` denotes the sufficient
    statistic, :math:`F(\theta)` is the log normalizer function for a given family and
    :math:`k(x)` is the carrier measure.
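
    For example, the exponential distribution with rate :math:`\lambda` fits this form with
    natural parameter :math:`\theta = -\lambda`, sufficient statistic :math:`t(x) = x`,
    log normalizer :math:`F(\theta) = -\log(-\theta)` and carrier measure :math:`k(x) = 0`.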

    Note:
        This class is an intermediary between the `Distribution` class and distributions which
        belong to an exponential family, mainly to check the correctness of the `.entropy()` and
        analytic KL divergence methods. We use this class to compute the entropy and KL divergence
        using the AD framework and Bregman divergences (courtesy of: Frank Nielsen and Richard
        Nock, Entropies and Cross-entropies of Exponential Families).
    """

    @property
    def _natural_params(self):
        """
        Abstract method for natural parameters. Returns a tuple of Tensors holding the
        natural parameters of the distribution.
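
        For a Gaussian with mean ``mu`` and variance ``sigma**2``, for instance, the
        natural parameters are ``(mu / sigma**2, -1 / (2 * sigma**2))``.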
|  | """ | 
|  | raise NotImplementedError | 
|  |  | 
    def _log_normalizer(self, *natural_params):
        """
        Abstract method for the log normalizer function. Returns the log normalizer
        evaluated at the given natural parameters.
        """
        raise NotImplementedError

    @property
    def _mean_carrier_measure(self):
        """
        Abstract method for the expected carrier measure, which is required for computing
        entropy.
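
        For families whose carrier measure ``k(x)`` is identically zero (the exponential
        distribution, for example), this is simply ``0``.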
|  | """ | 
|  | raise NotImplementedError | 
|  |  | 
    def entropy(self):
        """
        Method to compute the entropy using the Bregman divergence of the log normalizer.
        """
        # Entropy of an exponential family distribution:
        #     H = F(theta) - <theta, grad F(theta)> - E[k(x)]
        # where grad F(theta) = E[t(x)] is obtained via autograd.
        result = -self._mean_carrier_measure
        nparams = [p.detach().requires_grad_() for p in self._natural_params]
        lg_normal = self._log_normalizer(*nparams)
        gradients = torch.autograd.grad(lg_normal.sum(), nparams, create_graph=True)
        result += lg_normal
        for np, g in zip(nparams, gradients):
            result -= np * g
        return result
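

# ---------------------------------------------------------------------------
# A minimal sketch (an illustration, not part of the module above) of how a
# concrete distribution plugs into ExponentialFamily. The class `_Exponential`
# below is a hypothetical re-implementation of the exponential distribution;
# PyTorch ships its own `torch.distributions.Exponential`. For rate `lam`,
# theta = -lam, t(x) = x, F(theta) = -log(-theta) and k(x) = 0, so the
# Bregman-divergence entropy should match the closed form 1 - log(lam).
if __name__ == "__main__":
    from torch.distributions import constraints

    class _Exponential(ExponentialFamily):
        arg_constraints = {"rate": constraints.positive}
        support = constraints.positive

        def __init__(self, rate):
            self.rate = rate
            super().__init__(batch_shape=rate.shape)

        @property
        def _natural_params(self):
            # theta = -rate
            return (-self.rate,)

        def _log_normalizer(self, x):
            # F(theta) = -log(-theta)
            return -torch.log(-x)

        @property
        def _mean_carrier_measure(self):
            # k(x) = 0 for this family, so E[k(x)] = 0
            return 0

    rate = torch.tensor([0.5, 1.0, 2.0])
    dist = _Exponential(rate)
    print(dist.entropy())       # via autograd + Bregman divergence
    print(1 - torch.log(rate))  # closed form, should agree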