# blob: 82219d173fadcc5acbed6fed7c3073983a353b96 [file] [log] [blame]
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py
# This file exports ONNX ops for opset 16
# Note [ONNX Operators that are added/updated in opset 16]
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-16-of-the-default-onnx-operator-set
# New operators:
# GridSample https://github.com/onnx/onnx/pull/3557
#
# Updated operators:
# Identity
# If
# LeakyRelu
# Loop
# PRelu
# RoiAlign
# Scan
# ScatterElements
# ScatterND
# Where
# GreaterOrEqual
# LessOrEqual
# SequenceMap
from torch.nn.functional import (
GRID_SAMPLE_INTERPOLATION_MODES,
GRID_SAMPLE_PADDING_MODES,
)
from torch.onnx.symbolic_helper import parse_args
# note (mkozuki): Why `grid_sampler` instead of `grid_sample`?
# Because `torch.nn.functional.grid_sample` calls `torch.grid_sampler`.
@parse_args("v", "v", "i", "i", "b")
def grid_sampler(g, input, grid, mode_enum, padding_mode_enum, align_corners):
    """Emit an ONNX ``GridSample`` node for ``torch.grid_sampler``.

    Args:
        g: Graph context; only its ``op`` method is used here.
        input: Forwarded unchanged as the first input of the node.
        grid: Forwarded unchanged as the second input of the node.
        mode_enum: Integer code that must appear as a value in
            ``GRID_SAMPLE_INTERPOLATION_MODES``.
        padding_mode_enum: Integer code that must appear as a value in
            ``GRID_SAMPLE_PADDING_MODES``.
        align_corners: Flag converted with ``int()`` and stored in the
            ``align_corners`` attribute.

    Returns:
        Whatever ``g.op("GridSample", ...)`` returns.

    Raises:
        KeyError: If either integer code is missing from its table.
    """
    # The torch tables map mode name -> integer code, whereas this symbolic
    # receives the integer codes, so both tables are inverted before lookup.
    interpolation_by_code = {code: name for name, code in GRID_SAMPLE_INTERPOLATION_MODES.items()}  # type: ignore[call-arg]
    padding_by_code = {code: name for name, code in GRID_SAMPLE_PADDING_MODES.items()}  # type: ignore[call-arg]

    interpolation_name = interpolation_by_code[mode_enum]
    padding_name = padding_by_code[padding_mode_enum]

    # The ``_i`` / ``_s`` suffixes are part of the keyword names handed to
    # ``g.op``; keep them (and their order) exactly as-is.
    attributes = {
        "align_corners_i": int(align_corners),
        "mode_s": interpolation_name,
        "padding_mode_s": padding_name,
    }
    return g.op("GridSample", input, grid, **attributes)