# blob: 8ec2eefbd655dff407cc395e06e2d27bac5b3f84 [file] [log] [blame]
"""This file exports ONNX ops for opset 17.
Note [ONNX Operators that are added/updated in opset 17]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-17-of-the-default-onnx-operator-set
New operators:
BlackmanWindow
DFT
HammingWindow
HannWindow
LayerNormalization
MelWeightMatrix
STFT
SequenceMap
"""
from typing import Sequence
from torch import _C
from torch.onnx import symbolic_helper
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py
@symbolic_helper.parse_args("v", "is", "v", "v", "f", "none")
def layer_norm(
    g,
    input: _C.Value,
    normalized_shape: Sequence[int],
    weight: _C.Value,
    bias: _C.Value,
    eps: float,
    cudnn_enable: bool,
):
    """Emit an ONNX ``LayerNormalization`` node for a layer-norm call.

    Normalization covers the last ``D`` dimensions of ``input``, where ``D``
    is ``len(normalized_shape)``; the ONNX ``axis`` attribute (the first
    normalized dimension) is therefore ``-D``.

    Args:
        g: Graph context; only its ``op`` method is used.
        input: Value to normalize.
        normalized_shape: Sizes of the trailing dimensions being normalized.
        weight: Scale value. May be absent (``None`` or a ``None`` constant),
            in which case a constant of ones is substituted.
        bias: Bias value. May be absent (``None`` or a ``None`` constant),
            in which case a constant of zeros is substituted.
        eps: Forwarded as the ``epsilon`` attribute.
        cudnn_enable: Accepted for signature compatibility; not used.

    Returns:
        Whatever ``g.op("LayerNormalization", ...)`` returns.
    """
    # Imported locally: at file level this module only binds ``torch._C``.
    import torch

    def _is_absent(value) -> bool:
        # A missing optional tensor arrives either as a Python ``None`` or as
        # a graph value produced by a ``None`` constant node.
        return value is None or (
            isinstance(value, _C.Value) and value.node().mustBeNone()
        )

    def _filled_like_input(fill: float):
        # Constant of shape ``normalized_shape``; CastLike (opset >= 15)
        # matches its element type to ``input`` because LayerNormalization
        # requires X, Scale and B to share one type.
        constant = g.op(
            "Constant",
            value_t=torch.full(tuple(normalized_shape), fill),
        )
        return g.op("CastLike", constant, input)

    # ONNX LayerNormalization requires the Scale input, so an absent weight
    # must be replaced by the identity scale instead of being forwarded.
    if _is_absent(weight):
        weight = _filled_like_input(1.0)
    # A zero bias is numerically equivalent to no bias and keeps the
    # positional inputs explicit.
    if _is_absent(bias):
        bias = _filled_like_input(0.0)

    axis = -len(normalized_shape)
    return g.op(
        "LayerNormalization",
        input,
        weight,
        bias,
        epsilon_f=eps,
        axis_i=axis,
    )