| import torch | 
 | from typing import Optional | 
 |  | 
 |  | 
 | class SobolEngine: | 
 |     r""" | 
 |     The :class:`torch.quasirandom.SobolEngine` is an engine for generating | 
 |     (scrambled) Sobol sequences. Sobol sequences are an example of low | 
 |     discrepancy quasi-random sequences. | 
 |  | 
 |     This implementation of an engine for Sobol sequences is capable of | 
 |     sampling sequences up to a maximum dimension of 21201. It uses direction | 
 |     numbers from https://web.maths.unsw.edu.au/~fkuo/sobol/ obtained using the | 
 |     search criterion D(6) up to the dimension 21201. This is the recommended | 
 |     choice by the authors. | 
 |  | 
 |     References: | 
 |       - Art B. Owen. Scrambling Sobol and Niederreiter-Xing points. | 
 |         Journal of Complexity, 14(4):466-489, December 1998. | 
 |  | 
 |       - I. M. Sobol. The distribution of points in a cube and the accurate | 
 |         evaluation of integrals. | 
 |         Zh. Vychisl. Mat. i Mat. Phys., 7:784-802, 1967. | 
 |  | 
 |     Args: | 
 |         dimension (Int): The dimensionality of the sequence to be drawn | 
 |         scramble (bool, optional): Setting this to ``True`` will produce | 
 |                                    scrambled Sobol sequences. Scrambling is | 
 |                                    capable of producing better Sobol | 
 |                                    sequences. Default: ``False``. | 
 |         seed (Int, optional): This is the seed for the scrambling. The seed | 
 |                               of the random number generator is set to this, | 
 |                               if specified. Otherwise, it uses a random seed. | 
 |                               Default: ``None`` | 
 |  | 
 |     Examples:: | 
 |  | 
 |         >>> # xdoctest: +SKIP("unseeded random state") | 
 |         >>> soboleng = torch.quasirandom.SobolEngine(dimension=5) | 
 |         >>> soboleng.draw(3) | 
 |         tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000], | 
 |                 [0.5000, 0.5000, 0.5000, 0.5000, 0.5000], | 
 |                 [0.7500, 0.2500, 0.2500, 0.2500, 0.7500]]) | 
 |     """ | 
 |     MAXBIT = 30 | 
 |     MAXDIM = 21201 | 
 |  | 
 |     def __init__(self, dimension, scramble=False, seed=None): | 
 |         if dimension > self.MAXDIM or dimension < 1: | 
 |             raise ValueError("Supported range of dimensionality " | 
 |                              f"for SobolEngine is [1, {self.MAXDIM}]") | 
 |  | 
 |         self.seed = seed | 
 |         self.scramble = scramble | 
 |         self.dimension = dimension | 
 |  | 
 |         cpu = torch.device("cpu") | 
 |  | 
 |         self.sobolstate = torch.zeros(dimension, self.MAXBIT, device=cpu, dtype=torch.long) | 
 |         torch._sobol_engine_initialize_state_(self.sobolstate, self.dimension) | 
 |  | 
 |         if not self.scramble: | 
 |             self.shift = torch.zeros(self.dimension, device=cpu, dtype=torch.long) | 
 |         else: | 
 |             self._scramble() | 
 |  | 
 |         self.quasi = self.shift.clone(memory_format=torch.contiguous_format) | 
 |         self._first_point = (self.quasi / 2 ** self.MAXBIT).reshape(1, -1) | 
 |         self.num_generated = 0 | 
 |  | 
 |     def draw(self, n: int = 1, out: Optional[torch.Tensor] = None, | 
 |              dtype: torch.dtype = torch.float32) -> torch.Tensor: | 
 |         r""" | 
 |         Function to draw a sequence of :attr:`n` points from a Sobol sequence. | 
 |         Note that the samples are dependent on the previous samples. The size | 
 |         of the result is :math:`(n, dimension)`. | 
 |  | 
 |         Args: | 
 |             n (Int, optional): The length of sequence of points to draw. | 
 |                                Default: 1 | 
 |             out (Tensor, optional): The output tensor | 
 |             dtype (:class:`torch.dtype`, optional): the desired data type of the | 
 |                                                     returned tensor. | 
 |                                                     Default: ``torch.float32`` | 
 |         """ | 
 |         if self.num_generated == 0: | 
 |             if n == 1: | 
 |                 result = self._first_point.to(dtype) | 
 |             else: | 
 |                 result, self.quasi = torch._sobol_engine_draw( | 
 |                     self.quasi, n - 1, self.sobolstate, self.dimension, self.num_generated, dtype=dtype, | 
 |                 ) | 
 |                 result = torch.cat((self._first_point, result), dim=-2) | 
 |         else: | 
 |             result, self.quasi = torch._sobol_engine_draw( | 
 |                 self.quasi, n, self.sobolstate, self.dimension, self.num_generated - 1, dtype=dtype, | 
 |             ) | 
 |  | 
 |         self.num_generated += n | 
 |  | 
 |         if out is not None: | 
 |             out.resize_as_(result).copy_(result) | 
 |             return out | 
 |  | 
 |         return result | 
 |  | 
 |     def draw_base2(self, m: int, out: Optional[torch.Tensor] = None, | 
 |                    dtype: torch.dtype = torch.float32) -> torch.Tensor: | 
 |         r""" | 
 |         Function to draw a sequence of :attr:`2**m` points from a Sobol sequence. | 
 |         Note that the samples are dependent on the previous samples. The size | 
 |         of the result is :math:`(2**m, dimension)`. | 
 |  | 
 |         Args: | 
 |             m (Int): The (base2) exponent of the number of points to draw. | 
 |             out (Tensor, optional): The output tensor | 
 |             dtype (:class:`torch.dtype`, optional): the desired data type of the | 
 |                                                     returned tensor. | 
 |                                                     Default: ``torch.float32`` | 
 |         """ | 
 |         n = 2 ** m | 
 |         total_n = self.num_generated + n | 
 |         if not (total_n & (total_n - 1) == 0): | 
 |             raise ValueError("The balance properties of Sobol' points require " | 
 |                              f"n to be a power of 2. {self.num_generated} points have been " | 
 |                              f"previously generated, then: n={self.num_generated}+2**{m}={total_n}. " | 
 |                              "If you still want to do this, please use " | 
 |                              "'SobolEngine.draw()' instead." | 
 |                              ) | 
 |         return self.draw(n=n, out=out, dtype=dtype) | 
 |  | 
 |     def reset(self): | 
 |         r""" | 
 |         Function to reset the ``SobolEngine`` to base state. | 
 |         """ | 
 |         self.quasi.copy_(self.shift) | 
 |         self.num_generated = 0 | 
 |         return self | 
 |  | 
 |     def fast_forward(self, n): | 
 |         r""" | 
 |         Function to fast-forward the state of the ``SobolEngine`` by | 
 |         :attr:`n` steps. This is equivalent to drawing :attr:`n` samples | 
 |         without using the samples. | 
 |  | 
 |         Args: | 
 |             n (Int): The number of steps to fast-forward by. | 
 |         """ | 
 |         if self.num_generated == 0: | 
 |             torch._sobol_engine_ff_(self.quasi, n - 1, self.sobolstate, self.dimension, self.num_generated) | 
 |         else: | 
 |             torch._sobol_engine_ff_(self.quasi, n, self.sobolstate, self.dimension, self.num_generated - 1) | 
 |         self.num_generated += n | 
 |         return self | 
 |  | 
 |     def _scramble(self): | 
 |         g: Optional[torch.Generator] = None | 
 |         if self.seed is not None: | 
 |             g = torch.Generator() | 
 |             g.manual_seed(self.seed) | 
 |  | 
 |         cpu = torch.device("cpu") | 
 |  | 
 |         # Generate shift vector | 
 |         shift_ints = torch.randint(2, (self.dimension, self.MAXBIT), device=cpu, generator=g) | 
 |         self.shift = torch.mv(shift_ints, torch.pow(2, torch.arange(0, self.MAXBIT, device=cpu))) | 
 |  | 
 |         # Generate lower triangular matrices (stacked across dimensions) | 
 |         ltm_dims = (self.dimension, self.MAXBIT, self.MAXBIT) | 
 |         ltm = torch.randint(2, ltm_dims, device=cpu, generator=g).tril() | 
 |  | 
 |         torch._sobol_engine_scramble_(self.sobolstate, ltm, self.dimension) | 
 |  | 
 |     def __repr__(self): | 
 |         fmt_string = [f'dimension={self.dimension}'] | 
 |         if self.scramble: | 
 |             fmt_string += ['scramble=True'] | 
 |         if self.seed is not None: | 
 |             fmt_string += [f'seed={self.seed}'] | 
 |         return self.__class__.__name__ + '(' + ', '.join(fmt_string) + ')' |