blob: 78a174cd6d831674792bc03a743c6c20d103b5fb [file] [log] [blame]
import torch
import torch.onnx.symbolic_helper as sym_help
from torch.onnx.symbolic_helper import parse_args, _unimplemented
from torch.onnx.symbolic_helper import _black_list_in_opset
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py
# This file exports ONNX ops for opset 11
# Operators with no opset-11 symbolic in this module. Each name is bound at
# module scope to the stub returned by ``_black_list_in_opset`` for that name.
black_listed_operators = [
    "eq",
    "ne",
    "scatter",
    "clamp",
    "clamp_min",
    "clamp_max",
    "sort",
    "topk",
    "hardtanh",
]
for black_listed_op in black_listed_operators:
    # At module scope vars() and globals() are the same dict; globals() makes
    # the intent (defining module-level names dynamically) explicit.
    globals()[black_listed_op] = _black_list_in_opset(black_listed_op)
@parse_args('v', 'i')
def pixel_shuffle(g, self, upscale_factor):
    """Export ``pixel_shuffle`` as ONNX ``DepthToSpace`` using mode "CRD".

    Args:
        g: graph the ONNX node is emitted into.
        self: input value; its type must report exactly four sizes.
        upscale_factor: block size, parsed as a constant int.

    Returns:
        The ``DepthToSpace`` output, or the result of ``_unimplemented`` when
        the input does not report four sizes.
    """
    rank = len(self.type().sizes())
    if rank == 4:
        return g.op("DepthToSpace", self, blocksize_i=upscale_factor, mode_s="CRD")
    return _unimplemented("pixel_shuffle", "only support 4d input")
@parse_args('v', 'i', 'none')
def cumsum(g, self, dim, dtype=None):
    """Export ``torch.cumsum`` as an ONNX ``CumSum`` node.

    Args:
        g: graph the ONNX nodes are emitted into.
        self: input value to accumulate.
        dim: axis to accumulate along, parsed as a constant int.
        dtype: optional requested dtype, passed through unparsed ('none').

    Returns:
        The output value of the emitted ``CumSum`` node.
    """
    dim_tensor = g.op("Constant", value_t=torch.tensor(dim))
    cumsum_input = self
    # NOTE(review): an omitted dtype presumably arrives as a prim::Constant
    # (None) node, which this check skips; confirm against the exporter.
    if dtype and dtype.node().kind() != 'prim::Constant':
        parsed_dtype = sym_help._get_const(dtype, 'i', 'dtype')
        # torch.cumsum casts the *input* to dtype before accumulating (this is
        # what avoids overflow for narrow integer types), so the Cast must
        # precede CumSum rather than be applied to its result.
        cumsum_input = g.op("Cast", self, to_i=sym_help.scalar_type_to_onnx[parsed_dtype])
    return g.op("CumSum", cumsum_input, dim_tensor)