| # EDITING THIS FILE? READ THIS FIRST! |
| # see Note [Edit Symbolic Files] in symbolic_helper.py |
| |
| # This file exports ONNX ops for opset 16 |
| |
| # Note [ONNX Operators that are added/updated in opset 16] |
| # |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| # https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-16-of-the-default-onnx-operator-set |
| # New operators: |
| # GridSample https://github.com/onnx/onnx/pull/3557 |
| # |
| # Updated operators: |
| # Identity |
| # If |
| # LeakyRelu |
| # Loop |
| # PRelu |
| # RoiAlign |
| # Scan |
| # ScatterElements |
| # ScatterND |
| # Where |
| # GreaterOrEqual |
| # LessOrEqual |
| # SequenceMap |
| |
from torch.nn.functional import (
    GRID_SAMPLE_INTERPOLATION_MODES,
    GRID_SAMPLE_PADDING_MODES,
)
from torch.onnx.symbolic_helper import _get_tensor_rank, parse_args
| |
| |
# note (mkozuki): Why `grid_sampler` instead of `grid_sample`?
# Because `torch.nn.functional.grid_sample` calls `torch.grid_sampler`.
@parse_args("v", "v", "i", "i", "b")
def grid_sampler(g, input, grid, mode_enum, padding_mode_enum, align_corners):
    """Export ``aten::grid_sampler`` as an ONNX opset-16 ``GridSample`` node.

    Args:
        g: The graph being built; ``g.op`` emits the ONNX node.
        input: Value for the tensor to sample from.
        grid: Value for the sampling grid.
        mode_enum: Integer interpolation mode, encoded as in
            ``GRID_SAMPLE_INTERPOLATION_MODES`` (mode name -> int).
        padding_mode_enum: Integer padding mode, encoded as in
            ``GRID_SAMPLE_PADDING_MODES`` (mode name -> int).
        align_corners: Parsed as a bool; emitted as the integer attribute
            ``align_corners``.

    Returns:
        The result of ``g.op("GridSample", ...)``.

    Raises:
        RuntimeError: If ``input`` is known to be 5-D (volumetric). Opset-16
            ``GridSample`` only defines 4-D inputs, so exporting it would
            produce an invalid model.
        KeyError: If either enum value has no entry in its mode table.
    """
    # F.grid_sample accepts 5-D input, but GridSample-16 does not. Fail at
    # export time instead of emitting a node the ONNX checker/runtime rejects.
    # An unknown rank (None) is let through unchanged.
    if _get_tensor_rank(input) == 5:
        raise RuntimeError(
            "Unsupported: ONNX export of operator GridSample with 5D volumetric "
            "input in opset 16. Only 4-D inputs are supported."
        )
    # The tables map mode name -> enum int; invert them to recover the name,
    # which ONNX expects as a string attribute.
    mode_s = {v: k for k, v in GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum]  # type: ignore[call-arg]
    padding_mode_s = {v: k for k, v in GRID_SAMPLE_PADDING_MODES.items()}[padding_mode_enum]  # type: ignore[call-arg]
    return g.op(
        "GridSample",
        input,
        grid,
        align_corners_i=int(align_corners),
        mode_s=mode_s,
        padding_mode_s=padding_mode_s,
    )