import copy
from typing import Optional, Any, Union, Callable

import torch
import warnings
from torch import Tensor
from .. import functional as F
from .module import Module
from .activation import MultiheadAttention
from .container import ModuleList
from ..init import xavier_uniform_
from .dropout import Dropout
from .linear import Linear
from .normalization import LayerNorm

__all__ = ['Transformer', 'TransformerEncoder', 'TransformerDecoder', 'TransformerEncoderLayer', 'TransformerDecoderLayer']

def _generate_square_subsequent_mask(
        sz: int,
        device: torch.device = torch.device(torch._C._get_default_device()),  # torch.device('cpu'),
        dtype: torch.dtype = torch.get_default_dtype(),
) -> Tensor:
    r"""Generate a square causal mask for the sequence.

    The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0).
    """
    return torch.triu(
        torch.full((sz, sz), float('-inf'), dtype=dtype, device=device),
        diagonal=1,
    )
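
# For illustration: ``_generate_square_subsequent_mask(3)`` produces
#
#   tensor([[0., -inf, -inf],
#           [0., 0., -inf],
#           [0., 0., 0.]])
#
# so, when added to the attention scores, position ``i`` can only attend to
# positions ``<= i``.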


def _get_seq_len(
        src: Tensor,
        batch_first: bool
) -> Optional[int]:

    if src.is_nested:
        return None
    else:
        src_size = src.size()
        if len(src_size) == 2:
            # unbatched: S, E
            return src_size[0]
        else:
            # batched: B, S, E if batch_first else S, B, E
            seq_len_pos = 1 if batch_first else 0
            return src_size[seq_len_pos]
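
# For illustration, assuming a batched dense tensor:
#
#   _get_seq_len(torch.rand(10, 32, 512), batch_first=False)  # -> 10  (S, N, E)
#   _get_seq_len(torch.rand(32, 10, 512), batch_first=True)   # -> 10  (N, S, E)
#
# Nested-tensor inputs have no single sequence length, so ``None`` is returned.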


class Transformer(Module):
    r"""A transformer model.

    Users can modify the attributes as needed. The architecture
    is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
    Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
    Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
    Processing Systems, pages 6000-6010.

    Args:
        d_model: the number of expected features in the encoder/decoder inputs (default=512).
        nhead: the number of heads in the multiheadattention models (default=8).
        num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
        num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of encoder/decoder intermediate layer, can be a string
            ("relu" or "gelu") or a unary callable. Default: relu
        custom_encoder: custom encoder (default=None).
        custom_decoder: custom decoder (default=None).
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, encoder and decoder layers will perform LayerNorms before
            other attention and feedforward operations, otherwise after. Default: ``False`` (after).
        bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
            bias. Default: ``True``.

    Examples::
        >>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
        >>> src = torch.rand((10, 32, 512))
        >>> tgt = torch.rand((20, 32, 512))
        >>> out = transformer_model(src, tgt)

    Note: A full example of how to apply the nn.Transformer module for a word language model is available at
    https://github.com/pytorch/examples/tree/master/word_language_model
    """

    def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6,
                 num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 custom_encoder: Optional[Any] = None, custom_decoder: Optional[Any] = None,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 bias: bool = True, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")

        if custom_encoder is not None:
            self.encoder = custom_encoder
        else:
            encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,
                                                    activation, layer_norm_eps, batch_first, norm_first,
                                                    bias, **factory_kwargs)
            encoder_norm = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
            self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        if custom_decoder is not None:
            self.decoder = custom_decoder
        else:
            decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout,
                                                    activation, layer_norm_eps, batch_first, norm_first,
                                                    bias, **factory_kwargs)
            decoder_norm = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
            self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

        self.batch_first = batch_first

    def forward(self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None,
                src_is_causal: Optional[bool] = None, tgt_is_causal: Optional[bool] = None,
                memory_is_causal: bool = False) -> Tensor:
        r"""Take in and process masked source/target sequences.

        Args:
            src: the sequence to the encoder (required).
            tgt: the sequence to the decoder (required).
            src_mask: the additive mask for the src sequence (optional).
            tgt_mask: the additive mask for the tgt sequence (optional).
            memory_mask: the additive mask for the encoder output (optional).
            src_key_padding_mask: the Tensor mask for src keys per batch (optional).
            tgt_key_padding_mask: the Tensor mask for tgt keys per batch (optional).
            memory_key_padding_mask: the Tensor mask for memory keys per batch (optional).
            src_is_causal: If specified, applies a causal mask as ``src_mask``.
                Default: ``None``; try to detect a causal mask.
                Warning:
                ``src_is_causal`` provides a hint that ``src_mask`` is
                the causal mask. Providing an incorrect hint can result in
                incorrect execution, including broken forward and backward
                compatibility.
            tgt_is_causal: If specified, applies a causal mask as ``tgt_mask``.
                Default: ``None``; try to detect a causal mask.
                Warning:
                ``tgt_is_causal`` provides a hint that ``tgt_mask`` is
                the causal mask. Providing an incorrect hint can result in
                incorrect execution, including broken forward and backward
                compatibility.
            memory_is_causal: If specified, applies a causal mask as
                ``memory_mask``.
                Default: ``False``.
                Warning:
                ``memory_is_causal`` provides a hint that
                ``memory_mask`` is the causal mask. Providing an incorrect
                hint can result in incorrect execution, including broken
                forward and backward compatibility.

        Shape:
            - src: :math:`(S, E)` for unbatched input, :math:`(S, N, E)` if `batch_first=False` or
              `(N, S, E)` if `batch_first=True`.
            - tgt: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or
              `(N, T, E)` if `batch_first=True`.
            - src_mask: :math:`(S, S)` or :math:`(N\cdot\text{num\_heads}, S, S)`.
            - tgt_mask: :math:`(T, T)` or :math:`(N\cdot\text{num\_heads}, T, T)`.
            - memory_mask: :math:`(T, S)`.
            - src_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`.
            - tgt_key_padding_mask: :math:`(T)` for unbatched input otherwise :math:`(N, T)`.
            - memory_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`.

            Note: [src/tgt/memory]_mask ensures that position :math:`i` is allowed to attend to the unmasked
            positions. If a BoolTensor is provided, positions with ``True``
            are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
            is provided, it will be added to the attention weight.
            [src/tgt/memory]_key_padding_mask provides specified elements in the key to be ignored by
            the attention. If a BoolTensor is provided, the positions with the
            value of ``True`` will be ignored while the positions with the value of ``False`` will be unchanged.

            - output: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or
              `(N, T, E)` if `batch_first=True`.

            Note: Due to the multi-head attention architecture in the transformer model,
            the output sequence length of a transformer is the same as the input sequence
            (i.e. target) length of the decoder.

            where S is the source sequence length, T is the target sequence length, N is the
            batch size, E is the feature number

        Examples:
            >>> # xdoctest: +SKIP
            >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
        """
        is_batched = src.dim() == 3
        if not self.batch_first and src.size(1) != tgt.size(1) and is_batched:
            raise RuntimeError("the batch number of src and tgt must be equal")
        elif self.batch_first and src.size(0) != tgt.size(0) and is_batched:
            raise RuntimeError("the batch number of src and tgt must be equal")

        if src.size(-1) != self.d_model or tgt.size(-1) != self.d_model:
            raise RuntimeError("the feature number of src and tgt must be equal to d_model")

        memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask,
                              is_causal=src_is_causal)
        output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
                              tgt_key_padding_mask=tgt_key_padding_mask,
                              memory_key_padding_mask=memory_key_padding_mask,
                              tgt_is_causal=tgt_is_causal, memory_is_causal=memory_is_causal)
        return output

    @staticmethod
    def generate_square_subsequent_mask(
            sz: int,
            device: torch.device = torch.device(torch._C._get_default_device()),  # torch.device('cpu'),
            dtype: torch.dtype = torch.get_default_dtype(),
    ) -> Tensor:
        r"""Generate a square causal mask for the sequence.

        The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0).
        """
        return _generate_square_subsequent_mask(sz, dtype=dtype, device=device)
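
    # Illustrative usage (a sketch, not part of the public API surface): assuming
    # batch_first=False inputs of shape (T, N, E), the generated mask is typically
    # passed as ``tgt_mask`` so that decoder position ``i`` cannot attend to
    # future positions:
    #
    #   tgt_mask = Transformer.generate_square_subsequent_mask(tgt.size(0))
    #   out = transformer_model(src, tgt, tgt_mask=tgt_mask)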

    def _reset_parameters(self):
        r"""Initialize parameters in the transformer model."""
        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)


class TransformerEncoder(Module):
    r"""TransformerEncoder is a stack of N encoder layers.

    Users can build the BERT (https://arxiv.org/abs/1810.04805) model with corresponding parameters.

    Args:
        encoder_layer: an instance of the TransformerEncoderLayer() class (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        norm: the layer normalization component (optional).
        enable_nested_tensor: if True, input will automatically be converted to a nested tensor
            (and converted back on output). This will improve the overall performance of
            TransformerEncoder when the padding rate is high. Default: ``True`` (enabled).

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 32, 512)
        >>> out = transformer_encoder(src)
    """

    __constants__ = ['norm']

    def __init__(self, encoder_layer, num_layers, norm=None, enable_nested_tensor=True, mask_check=True):
        super().__init__()
        torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        # this attribute saves the value provided at object construction
        self.enable_nested_tensor = enable_nested_tensor
        # this attribute controls whether nested tensors are used
        self.use_nested_tensor = enable_nested_tensor
        self.mask_check = mask_check

        enc_layer = "encoder_layer"
        why_not_sparsity_fast_path = ''
        if not isinstance(encoder_layer, torch.nn.TransformerEncoderLayer):
            why_not_sparsity_fast_path = f"{enc_layer} was not TransformerEncoderLayer"
        elif encoder_layer.norm_first:
            why_not_sparsity_fast_path = f"{enc_layer}.norm_first was True"
        elif not encoder_layer.self_attn.batch_first:
            why_not_sparsity_fast_path = (f"{enc_layer}.self_attn.batch_first was not True" +
                                          " (use batch_first for better inference performance)")
        elif not encoder_layer.self_attn._qkv_same_embed_dim:
            why_not_sparsity_fast_path = f"{enc_layer}.self_attn._qkv_same_embed_dim was not True"
        elif not encoder_layer.activation_relu_or_gelu:
            why_not_sparsity_fast_path = f"{enc_layer}.activation_relu_or_gelu was not True"
        elif not (encoder_layer.norm1.eps == encoder_layer.norm2.eps):
            why_not_sparsity_fast_path = f"{enc_layer}.norm1.eps was not equal to {enc_layer}.norm2.eps"
        elif encoder_layer.self_attn.num_heads % 2 == 1:
            why_not_sparsity_fast_path = f"{enc_layer}.self_attn.num_heads is odd"

        if enable_nested_tensor and why_not_sparsity_fast_path:
            warnings.warn(f"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}")
            self.use_nested_tensor = False

    def forward(
            self,
            src: Tensor,
            mask: Optional[Tensor] = None,
            src_key_padding_mask: Optional[Tensor] = None,
            is_causal: Optional[bool] = None) -> Tensor:
        r"""Pass the input through the encoder layers in turn.

        Args:
            src: the sequence to the encoder (required).
            mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
            is_causal: If specified, applies a causal mask as ``mask``.
                Default: ``None``; try to detect a causal mask.
                Warning:
                ``is_causal`` provides a hint that ``mask`` is the
                causal mask. Providing an incorrect hint can result in
                incorrect execution, including broken forward and backward
                compatibility.

        Shape:
            see the docs in Transformer class.
        """
        src_key_padding_mask = F._canonical_mask(
            mask=src_key_padding_mask,
            mask_name="src_key_padding_mask",
            other_type=F._none_or_dtype(mask),
            other_name="mask",
            target_type=src.dtype
        )

        mask = F._canonical_mask(
            mask=mask,
            mask_name="mask",
            other_type=None,
            other_name="",
            target_type=src.dtype,
            check_other=False,
        )

        output = src
        convert_to_nested = False
        first_layer = self.layers[0]
        src_key_padding_mask_for_layers = src_key_padding_mask
        why_not_sparsity_fast_path = ''
        str_first_layer = "self.layers[0]"
        batch_first = first_layer.self_attn.batch_first
        if not hasattr(self, "use_nested_tensor"):
            why_not_sparsity_fast_path = "use_nested_tensor attribute not present"
        elif not self.use_nested_tensor:
            why_not_sparsity_fast_path = "self.use_nested_tensor (set in init) was not True"
        elif first_layer.training:
            why_not_sparsity_fast_path = f"{str_first_layer} was in training mode"
        elif not src.dim() == 3:
            why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}"
        elif src_key_padding_mask is None:
            why_not_sparsity_fast_path = "src_key_padding_mask was None"
        elif (((not hasattr(self, "mask_check")) or self.mask_check)
                and not torch._nested_tensor_from_mask_left_aligned(src, src_key_padding_mask.logical_not())):
            why_not_sparsity_fast_path = "mask_check enabled, and src and src_key_padding_mask were not left aligned"
        elif output.is_nested:
            why_not_sparsity_fast_path = "NestedTensor input is not supported"
        elif mask is not None:
            why_not_sparsity_fast_path = "src_key_padding_mask and mask were both supplied"
        elif torch.is_autocast_enabled():
            why_not_sparsity_fast_path = "autocast is enabled"

        if not why_not_sparsity_fast_path:
            tensor_args = (
                src,
                first_layer.self_attn.in_proj_weight,
                first_layer.self_attn.in_proj_bias,
                first_layer.self_attn.out_proj.weight,
                first_layer.self_attn.out_proj.bias,
                first_layer.norm1.weight,
                first_layer.norm1.bias,
                first_layer.norm2.weight,
                first_layer.norm2.bias,
                first_layer.linear1.weight,
                first_layer.linear1.bias,
                first_layer.linear2.weight,
                first_layer.linear2.bias,
            )
            _supported_device_type = ["cpu", "cuda", torch.utils.backend_registration._privateuse1_backend_name]
            if torch.overrides.has_torch_function(tensor_args):
                why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
            elif src.device.type not in _supported_device_type:
                why_not_sparsity_fast_path = f"src device is not one of {_supported_device_type}"
            elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
                why_not_sparsity_fast_path = ("grad is enabled and at least one of query or the "
                                              "input/output projection weights or biases requires_grad")

            if (not why_not_sparsity_fast_path) and (src_key_padding_mask is not None):
                convert_to_nested = True
                output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)
                src_key_padding_mask_for_layers = None

        seq_len = _get_seq_len(src, batch_first)
        is_causal = _detect_is_causal_mask(mask, is_causal, seq_len)

        for mod in self.layers:
            output = mod(output, src_mask=mask, is_causal=is_causal, src_key_padding_mask=src_key_padding_mask_for_layers)

        if convert_to_nested:
            output = output.to_padded_tensor(0., src.size())

        if self.norm is not None:
            output = self.norm(output)

        return output
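
    # A minimal usage sketch (assuming batch_first=False inputs of shape (S, N, E)):
    #
    #   encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
    #   encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
    #   src = torch.rand(10, 32, 512)
    #   # True marks padded (ignored) key positions, one row per batch element
    #   src_key_padding_mask = torch.zeros(32, 10, dtype=torch.bool)
    #   out = encoder(src, src_key_padding_mask=src_key_padding_mask)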


class TransformerDecoder(Module):
    r"""TransformerDecoder is a stack of N decoder layers.

    Args:
        decoder_layer: an instance of the TransformerDecoderLayer() class (required).
        num_layers: the number of sub-decoder-layers in the decoder (required).
        norm: the layer normalization component (optional).

    Examples::
        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
        >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
        >>> memory = torch.rand(10, 32, 512)
        >>> tgt = torch.rand(20, 32, 512)
        >>> out = transformer_decoder(tgt, memory)
    """

    __constants__ = ['norm']

    def __init__(self, decoder_layer, num_layers, norm=None):
        super().__init__()
        torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None, tgt_is_causal: Optional[bool] = None,
                memory_is_causal: bool = False) -> Tensor:
        r"""Pass the inputs (and mask) through the decoder layers in turn.

        Args:
            tgt: the sequence to the decoder (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).
            tgt_is_causal: If specified, applies a causal mask as ``tgt_mask``.
                Default: ``None``; try to detect a causal mask.
                Warning:
                ``tgt_is_causal`` provides a hint that ``tgt_mask`` is
                the causal mask. Providing an incorrect hint can result in
                incorrect execution, including broken forward and backward
                compatibility.
            memory_is_causal: If specified, applies a causal mask as
                ``memory_mask``.
                Default: ``False``.
                Warning:
                ``memory_is_causal`` provides a hint that
                ``memory_mask`` is the causal mask. Providing an incorrect
                hint can result in incorrect execution, including broken
                forward and backward compatibility.

        Shape:
            see the docs in Transformer class.
        """
        output = tgt

        seq_len = _get_seq_len(tgt, self.layers[0].self_attn.batch_first)
        tgt_is_causal = _detect_is_causal_mask(tgt_mask, tgt_is_causal, seq_len)

        for mod in self.layers:
            output = mod(output, memory, tgt_mask=tgt_mask,
                         memory_mask=memory_mask,
                         tgt_key_padding_mask=tgt_key_padding_mask,
                         memory_key_padding_mask=memory_key_padding_mask,
                         tgt_is_causal=tgt_is_causal,
                         memory_is_causal=memory_is_causal)

        if self.norm is not None:
            output = self.norm(output)

        return output
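
    # A minimal usage sketch (assuming batch_first=False inputs of shape (T, N, E)):
    #
    #   tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(0))
    #   out = transformer_decoder(tgt, memory, tgt_mask=tgt_mask, tgt_is_causal=True)
    #
    # ``tgt_is_causal=True`` is only a hint and must match the supplied ``tgt_mask``.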

class TransformerEncoderLayer(Module):
    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.

    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    it in a different way during application.

    TransformerEncoderLayer can handle either traditional torch.tensor inputs,
    or Nested Tensor inputs.  Derived classes are expected to similarly accept
    both input formats.  (Not all combinations of inputs are currently
    supported by TransformerEncoderLayer while Nested Tensor is in prototype
    state.)

    If you are implementing a custom layer, you may derive it either from
    the Module or TransformerEncoderLayer class.  If your custom layer
    supports both torch.Tensor and Nested Tensor inputs, make its
    implementation a derived class of TransformerEncoderLayer. If your custom
    layer supports only torch.Tensor inputs, derive its implementation from
    Module.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer, can be a string
            ("relu" or "gelu") or a unary callable. Default: relu
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, layer norm is done prior to attention and feedforward
            operations, respectively. Otherwise it's done after. Default: ``False`` (after).
        bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
            bias. Default: ``True``.

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)

    Alternatively, when ``batch_first`` is ``True``:
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
        >>> src = torch.rand(32, 10, 512)
        >>> out = encoder_layer(src)

    Fast path:
        forward() will use a special optimized implementation described in
        `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`_ if all of the following
        conditions are met:

        - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor
          argument ``requires_grad``
        - training is disabled (using ``.eval()``)
        - batch_first is ``True`` and the input is batched (i.e., ``src.dim() == 3``)
        - activation is one of: ``"relu"``, ``"gelu"``, ``torch.nn.functional.relu``, or ``torch.nn.functional.gelu``
        - at most one of ``src_mask`` and ``src_key_padding_mask`` is passed
        - if src is a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_, neither ``src_mask``
          nor ``src_key_padding_mask`` is passed
        - the two ``LayerNorm`` instances have a consistent ``eps`` value (this will naturally be the case
          unless the caller has manually modified one without modifying the other)

        If the optimized implementation is in use, a
        `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be
        passed for ``src`` to represent padding more efficiently than using a padding
        mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ will be
        returned, and an additional speedup proportional to the fraction of the input that
        is padding can be expected.

    .. _`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`:
     https://arxiv.org/abs/2205.14135

    """

    __constants__ = ['norm_first']

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 bias: bool = True, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout,
                                            bias=bias, batch_first=batch_first,
                                            **factory_kwargs)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            activation = _get_activation_fn(activation)

        # We can't test self.activation in forward() in TorchScript,
        # so stash some information about it instead.
        if activation is F.relu or isinstance(activation, torch.nn.ReLU):
            self.activation_relu_or_gelu = 1
        elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
            self.activation_relu_or_gelu = 2
        else:
            self.activation_relu_or_gelu = 0
        self.activation = activation

    def __setstate__(self, state):
        super().__setstate__(state)
        if not hasattr(self, 'activation'):
            self.activation = F.relu

    def forward(
            self,
            src: Tensor,
            src_mask: Optional[Tensor] = None,
            src_key_padding_mask: Optional[Tensor] = None,
            is_causal: bool = False) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
            is_causal: If specified, applies a causal mask as ``src_mask``.
                Default: ``False``.
                Warning:
                ``is_causal`` provides a hint that ``src_mask`` is the
                causal mask. Providing an incorrect hint can result in
                incorrect execution, including broken forward and backward
                compatibility.

        Shape:
            see the docs in Transformer class.
        """
        src_key_padding_mask = F._canonical_mask(
            mask=src_key_padding_mask,
            mask_name="src_key_padding_mask",
            other_type=F._none_or_dtype(src_mask),
            other_name="src_mask",
            target_type=src.dtype
        )

        src_mask = F._canonical_mask(
            mask=src_mask,
            mask_name="src_mask",
            other_type=None,
            other_name="",
            target_type=src.dtype,
            check_other=False,
        )

        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
        why_not_sparsity_fast_path = ''
        if not src.dim() == 3:
            why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}"
        elif self.training:
            why_not_sparsity_fast_path = "training is enabled"
        elif not self.self_attn.batch_first:
            why_not_sparsity_fast_path = "self_attn.batch_first was not True"
        elif not self.self_attn._qkv_same_embed_dim:
            why_not_sparsity_fast_path = "self_attn._qkv_same_embed_dim was not True"
        elif not self.activation_relu_or_gelu:
            why_not_sparsity_fast_path = "activation_relu_or_gelu was not True"
        elif not (self.norm1.eps == self.norm2.eps):
            why_not_sparsity_fast_path = "norm1.eps is not equal to norm2.eps"
        elif src.is_nested and (src_key_padding_mask is not None or src_mask is not None):
            why_not_sparsity_fast_path = "src_key_padding_mask and src_mask are not supported with NestedTensor input"
        elif self.self_attn.num_heads % 2 == 1:
            why_not_sparsity_fast_path = "num_heads is odd"
        elif torch.is_autocast_enabled():
            why_not_sparsity_fast_path = "autocast is enabled"
        if not why_not_sparsity_fast_path:
            tensor_args = (
                src,
                self.self_attn.in_proj_weight,
                self.self_attn.in_proj_bias,
                self.self_attn.out_proj.weight,
                self.self_attn.out_proj.bias,
                self.norm1.weight,
                self.norm1.bias,
                self.norm2.weight,
                self.norm2.bias,
                self.linear1.weight,
                self.linear1.bias,
                self.linear2.weight,
                self.linear2.bias,
            )

            # We have to use list comprehensions below because TorchScript does not support
            # generator expressions.
            _supported_device_type = ["cpu", "cuda", torch.utils.backend_registration._privateuse1_backend_name]
            if torch.overrides.has_torch_function(tensor_args):
                why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
            elif not all((x.device.type in _supported_device_type) for x in tensor_args):
                why_not_sparsity_fast_path = ("some Tensor argument's device is not one of "
                                              f"{_supported_device_type}")
            elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
                why_not_sparsity_fast_path = ("grad is enabled and at least one of query or the "
                                              "input/output projection weights or biases requires_grad")

            if not why_not_sparsity_fast_path:
                merged_mask, mask_type = self.self_attn.merge_masks(src_mask, src_key_padding_mask, src)
                return torch._transformer_encoder_layer_fwd(
                    src,
                    self.self_attn.embed_dim,
                    self.self_attn.num_heads,
                    self.self_attn.in_proj_weight,
                    self.self_attn.in_proj_bias,
                    self.self_attn.out_proj.weight,
                    self.self_attn.out_proj.bias,
                    self.activation_relu_or_gelu == 2,
                    self.norm_first,
                    self.norm1.eps,
                    self.norm1.weight,
                    self.norm1.bias,
                    self.norm2.weight,
                    self.norm2.bias,
                    self.linear1.weight,
                    self.linear1.bias,
                    self.linear2.weight,
                    self.linear2.bias,
                    merged_mask,
                    mask_type,
                )

        x = src
        if self.norm_first:
            x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask, is_causal=is_causal)
            x = x + self._ff_block(self.norm2(x))
        else:
            x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask, is_causal=is_causal))
            x = self.norm2(x + self._ff_block(x))

        return x

    # self-attention block
    def _sa_block(self, x: Tensor,
                  attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
        x = self.self_attn(x, x, x,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           need_weights=False, is_causal=is_causal)[0]
        return self.dropout1(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)
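
    # Dataflow sketch of the two residual arrangements implemented above
    # (illustrative only), where SA is the self-attention block and FF the
    # feedforward block:
    #
    #   norm_first=True  (pre-norm):  x = x + SA(LN1(x));  x = x + FF(LN2(x))
    #   norm_first=False (post-norm): x = LN1(x + SA(x));  x = LN2(x + FF(x))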


class TransformerDecoderLayer(Module):
    r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.

    This standard decoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    it in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer, can be a string
            ("relu" or "gelu") or a unary callable. Default: relu
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, layer norm is done prior to self attention, multihead
            attention and feedforward operations, respectively. Otherwise it's done after.
            Default: ``False`` (after).
        bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
            bias. Default: ``True``.

    Examples::
        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
        >>> memory = torch.rand(10, 32, 512)
        >>> tgt = torch.rand(20, 32, 512)
        >>> out = decoder_layer(tgt, memory)

    Alternatively, when ``batch_first`` is ``True``:
        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)
        >>> memory = torch.rand(32, 10, 512)
        >>> tgt = torch.rand(32, 20, 512)
        >>> out = decoder_layer(tgt, memory)
    """

    __constants__ = ['norm_first']

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 bias: bool = True, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                            bias=bias, **factory_kwargs)
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                                 bias=bias, **factory_kwargs)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            self.activation = _get_activation_fn(activation)
        else:
            self.activation = activation

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super().__setstate__(state)

    def forward(
        self,
        tgt: Tensor,
        memory: Tensor,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        tgt_is_causal: bool = False,
        memory_is_causal: bool = False,
    ) -> Tensor:
        r"""Pass the inputs (and mask) through the decoder layer.

        Args:
            tgt: the sequence to the decoder layer (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).
            tgt_is_causal: If specified, applies a causal mask as ``tgt_mask``.
                Default: ``False``.
                Warning:
                ``tgt_is_causal`` provides a hint that ``tgt_mask`` is
                the causal mask. Providing an incorrect hint can result in
                incorrect execution, including broken forward and backward
                compatibility.
            memory_is_causal: If specified, applies a causal mask as
                ``memory_mask``.
                Default: ``False``.
                Warning:
                ``memory_is_causal`` provides a hint that
                ``memory_mask`` is the causal mask. Providing an incorrect
                hint can result in incorrect execution, including broken
                forward and backward compatibility.

        Shape:
            see the docs in Transformer class.
        """
        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf

        x = tgt
        if self.norm_first:
            x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal)
            x = x + self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask, memory_is_causal)
            x = x + self._ff_block(self.norm3(x))
        else:
            x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal))
            x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask, memory_is_causal))
            x = self.norm3(x + self._ff_block(x))

        return x

    # self-attention block
    def _sa_block(self, x: Tensor,
                  attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
        x = self.self_attn(x, x, x,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           is_causal=is_causal,
                           need_weights=False)[0]
        return self.dropout1(x)

    # multihead attention block
    def _mha_block(self, x: Tensor, mem: Tensor,
                   attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
        x = self.multihead_attn(x, mem, mem,
                                attn_mask=attn_mask,
                                key_padding_mask=key_padding_mask,
                                is_causal=is_causal,
                                need_weights=False)[0]
        return self.dropout2(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout3(x)
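
    # A minimal usage sketch (assuming batch_first=False inputs):
    #
    #   decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
    #   memory = torch.rand(10, 32, 512)   # (S, N, E) encoder output
    #   tgt = torch.rand(20, 32, 512)      # (T, N, E)
    #   tgt_mask = nn.Transformer.generate_square_subsequent_mask(20)
    #   out = decoder_layer(tgt, memory, tgt_mask=tgt_mask, tgt_is_causal=True)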


def _get_clones(module, N):
    # FIXME: copy.deepcopy() is not defined on nn.module
    return ModuleList([copy.deepcopy(module) for i in range(N)])
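
# Each clone is a deep copy, so the N layers in the returned ModuleList have
# independent (non-shared) parameters; e.g. ``_get_clones(encoder_layer, 6)``
# builds six distinct layers, as used by the default TransformerEncoder stack.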


def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu

    raise RuntimeError(f"activation should be relu/gelu, not {activation}")
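
# For illustration: ``_get_activation_fn("gelu") is F.gelu`` evaluates to True,
# while any other string (e.g. "tanh") raises the RuntimeError above.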


def _detect_is_causal_mask(
        mask: Optional[Tensor],
        is_causal: Optional[bool] = None,
        size: Optional[int] = None,
) -> bool:
    """Return whether the given attention mask is causal.

    Warning:
        If ``is_causal`` is not ``None``, its value will be returned as is. If a
        user supplies an incorrect ``is_causal`` hint:

        * ``is_causal=False`` when the mask is in fact a causal attention mask
          may lead to reduced performance relative to what would be achievable
          with ``is_causal=True``;
        * ``is_causal=True`` when the mask is in fact not a causal attention mask
          may lead to incorrect and unpredictable execution - in some scenarios,
          a causal mask may be applied based on the hint, in other execution
          scenarios the specified mask may be used. The choice may not appear
          to be deterministic, in that a number of factors like alignment,
          hardware SKU, etc. influence the decision whether to use a mask or
          rely on the hint.

    If ``size`` is not None, this checks whether the mask is a causal mask of the
    provided size; otherwise, it checks for any causal mask.
    """
    # Prevent type refinement
    make_causal = (is_causal is True)

    if is_causal is None and mask is not None:
        sz = size if size is not None else mask.size(-2)
        causal_comparison = _generate_square_subsequent_mask(
            sz, device=mask.device, dtype=mask.dtype)

        # Do not use `torch.equal` so we handle batched masks by
        # broadcasting the comparison.
        if mask.size() == causal_comparison.size():
            make_causal = bool((mask == causal_comparison).all())
        else:
            make_causal = False

    return make_causal
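
# For illustration:
#
#   _detect_is_causal_mask(_generate_square_subsequent_mask(5))   # -> True
#   _detect_is_causal_mask(torch.zeros(5, 5))                     # -> False (all positions visible)
#   _detect_is_causal_mask(torch.zeros(5, 5), is_causal=True)     # -> True (hint is trusted as-is)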