"""This file exports ONNX ops for opset 20.

Note [ONNX Operators that are added/updated in opset 20]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-20-of-the-default-onnx-operator-set
New or updated operators:
    AffineGrid
    ConstantOfShape
    DFT
    Gelu
    GridSample
    ImageDecoder
    IsInf
    IsNaN
    ReduceMax
    ReduceMin
    RegexFullMatch
    StringConcat
    StringSplit
"""
| |
| import functools |
| |
| import torch.nn.functional as F |
| |
| from torch import _C |
| from torch.onnx import symbolic_helper |
| from torch.onnx._internal import _beartype, jit_utils, registration |
| |
| # EDITING THIS FILE? READ THIS FIRST! |
| # see Note [Edit Symbolic Files] in symbolic_helper.py |
| |
# Names exported by a star-import of this module: the two symbolic functions
# defined below. The helper convert_grid_sample_mode is not listed here.
__all__ = ["_grid_sampler", "_affine_grid_generator"]
| |
| |
def convert_grid_sample_mode(mode_s):
    """Translate a grid_sample interpolation mode name for the opset 20 export.

    ``"bilinear"`` maps to ``"linear"`` and ``"bicubic"`` maps to ``"cubic"``.
    Any other value is handed back unchanged.
    """
    if mode_s == "bilinear":
        return "linear"
    if mode_s == "bicubic":
        return "cubic"
    return mode_s
| |
| |
# Decorator factory used throughout this file: registration.onnx_symbolic with
# opset=20 pre-bound, so each symbolic function only supplies its op name.
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=20)
| |
| |
@_onnx_symbolic("aten::grid_sampler")
@symbolic_helper.parse_args("v", "v", "i", "i", "b")
@_beartype.beartype
def _grid_sampler(
    g: jit_utils.GraphContext,
    input: _C.Value,
    grid: _C.Value,
    mode_enum: int,
    padding_mode_enum: int,
    align_corners: bool,
):
    """Emit an ONNX ``GridSample`` node for ``aten::grid_sampler``.

    Args:
        g: Graph context the node is added to.
        input: Value forwarded as the first ``GridSample`` input.
        grid: Value forwarded as the second ``GridSample`` input.
        mode_enum: Integer looked up among the values of
            ``F.GRID_SAMPLE_INTERPOLATION_MODES`` to recover the mode name.
        padding_mode_enum: Integer looked up among the values of
            ``F.GRID_SAMPLE_PADDING_MODES`` to recover the padding mode name.
        align_corners: Flag written to the node as an integer attribute.

    Returns:
        Whatever ``g.op`` returns for the created ``GridSample`` node.

    Raises:
        KeyError: If either enum is not a value in its lookup table.
    """
    # The tables in torch.nn.functional map name -> enum; flip them so the
    # enum received from the graph can be turned back into its name.
    torch_mode = {  # type: ignore[call-arg, index]
        enum_value: name
        for name, enum_value in F.GRID_SAMPLE_INTERPOLATION_MODES.items()
    }[mode_enum]
    # mode string changes at https://onnx.ai/onnx/operators/text_diff_GridSample_16_20.html
    onnx_mode = convert_grid_sample_mode(torch_mode)
    torch_padding = {  # type: ignore[call-arg, index]
        enum_value: name
        for name, enum_value in F.GRID_SAMPLE_PADDING_MODES.items()
    }[padding_mode_enum]

    attributes = {
        "align_corners_i": int(align_corners),
        "mode_s": onnx_mode,
        "padding_mode_s": torch_padding,
    }
    return g.op("GridSample", input, grid, **attributes)
| |
| |
@_onnx_symbolic("aten::affine_grid_generator")
@symbolic_helper.parse_args("v", "v", "b")
@_beartype.beartype
def _affine_grid_generator(
    g: jit_utils.GraphContext,
    theta: _C.Value,
    size: _C.Value,
    align_corners: bool,
):
    """Emit an ONNX ``AffineGrid`` node for ``aten::affine_grid_generator``.

    Args:
        g: Graph context the node is added to.
        theta: Value forwarded as the first ``AffineGrid`` input.
        size: Value forwarded as the second ``AffineGrid`` input.
        align_corners: Flag written to the node as an integer attribute.

    Returns:
        Whatever ``g.op`` returns for the created ``AffineGrid`` node.
    """
    # The attribute is emitted as an int (the ``_i`` suffix), so convert the
    # bool explicitly.
    corners_flag = int(align_corners)
    return g.op("AffineGrid", theta, size, align_corners_i=corners_flag)