|  | """This file exports ONNX ops for opset 15. | 
|  |  | 
|  | Note [ONNX operators that are added/updated in opset 15] | 
|  | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | 
|  | https://github.com/onnx/onnx/blob/master/docs/Changelog.md#version-15-of-the-default-onnx-operator-set | 
|  | New operators: | 
|  | Bernoulli | 
|  | CastLike | 
|  | Optional | 
|  | OptionalGetElement | 
|  | OptionalHasElement | 
|  |  | 
|  | Updated operators: | 
|  | BatchNormalization https://github.com/onnx/onnx/pull/3545 | 
|  | Backwards compatible | 
|  | TODO: test coverage for mixed types inputs. | 
|  | Pow                https://github.com/onnx/onnx/pull/3412 | 
|  | Backwards compatible | 
|  | TODO: bfloat16 support. | 
|  | Shape              https://github.com/onnx/onnx/pull/3580 | 
|  | Backwards compatible | 
|  | TODO: optional start/end attribute. | 
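
Example (a minimal sketch; the ``Sampler`` module and the output file name
are only illustrative)::

    import torch

    class Sampler(torch.nn.Module):
        def forward(self, probs):
            # torch.bernoulli maps to the ONNX Bernoulli op at opset 15.
            return torch.bernoulli(probs)

    torch.onnx.export(
        Sampler(), (torch.rand(2, 3),), "sampler.onnx", opset_version=15
    )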
|  | """ | 
|  |  | 
|  | # EDITING THIS FILE? READ THIS FIRST! | 
|  | # see Note [Edit Symbolic Files] in README.md | 
|  |  | 
|  | import functools | 
|  |  | 
|  | import torch | 
|  | from torch import _C | 
|  | from torch.onnx import symbolic_helper, symbolic_opset9 as opset9 | 
|  | from torch.onnx._internal import _beartype, jit_utils, registration | 
|  |  | 
|  | _onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=15) | 
|  |  | 
|  |  | 
|  | @_onnx_symbolic("aten::__is_") | 
|  | @_beartype.beartype | 
|  | def aten__is_(g: jit_utils.GraphContext, self, other): | 
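    """Converts the ``is`` comparison (``aten::__is_``) to ONNX.

    When ``other`` is None: if ``self`` is an ONNX Optional, this lowers to
    ``OptionalHasElement`` followed by ``Not``; otherwise a constant ``False``
    is emitted. When ``other`` is not None, this falls back to the opset 9
    equality export.
    """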
    if symbolic_helper._is_none(other):
        if isinstance(self.type(), _C.OptionalType):
            none = g.op("OptionalHasElement", self)
            return g.op("Not", none)
        else:
            return g.op("Constant", value_t=torch.BoolTensor([0]))
    return opset9.eq(g, self, other)


@_onnx_symbolic("aten::__isnot_")
@opset9.wrap_logical_op_with_negation  # type: ignore[has-type]
@_beartype.beartype
def aten__isnot_(g: jit_utils.GraphContext, self, other):
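    """Converts ``is not`` (``aten::__isnot_``) by negating ``aten__is_``."""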
    return aten__is_(g, self, other)


@_onnx_symbolic("aten::bernoulli")
@_beartype.beartype
def bernoulli(g: jit_utils.GraphContext, input, p=None, generator=None, out=None):
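    """Exports ``aten::bernoulli`` using the ONNX ``Bernoulli`` op (new in opset 15).

    The ``out`` and ``generator`` arguments are not supported. When a scalar
    probability ``p`` is given, the export falls back to the opset 9
    implementation.
    """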
    if out is not None and not symbolic_helper._is_none(out):
        symbolic_helper._unimplemented(
            "Bernoulli", "out parameter is not supported for bernoulli", input
        )
    if generator is not None and not symbolic_helper._is_none(generator):
        symbolic_helper._unimplemented(
            "Bernoulli", "generator is not supported for bernoulli", input
        )
    if p is None or symbolic_helper._is_none(p):
        return g.op("Bernoulli", input)
    return opset9.bernoulli(g, input, p, generator, out)


@_onnx_symbolic("prim::unchecked_cast")
@_beartype.beartype
def prim_unchecked_cast(g: jit_utils.GraphContext, self):
    # prim::unchecked_cast exists only to refine the type of a Value:
    # if x is Optional[Tensor], unchecked_cast casts x to Tensor, so the
    # rest of the graph knows that x is a Tensor.
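    # For illustration (a hypothetical TorchScript snippet, not part of this
    # module): scripting a function like the one below narrows `x` from
    # Optional[Tensor] to Tensor once the None check passes, which shows up
    # in the IR as a prim::unchecked_cast node.
    #
    #     @torch.jit.script
    #     def f(x: Optional[torch.Tensor]) -> torch.Tensor:
    #         assert x is not None
    #         return x + 1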
    if isinstance(self.type(), _C.OptionalType):
        return g.op("OptionalGetElement", self)

    return self