| # Copyright (c) Facebook, Inc. and its affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| import torch |
| |
# Pointwise operators can go through a faster pathway.

# Base names of Tensor magic methods.
tensor_magic_methods = [
    "add",
    # NOTE(review): this empty entry looks like an unfinished placeholder
    # (formatting it as f"__{m}__" would give "____"). The list appears unused
    # in this module; confirm no importer relies on it before cleaning it up.
    "",
]
# Binary magic-method base names that also have a reflected ("r"-prefixed)
# form, e.g. "add" stands for both __add__ and __radd__.
pointwise_magic_methods_with_reverse = (
    "add",
    "sub",
    "mul",
    "floordiv",
    "div",
    "truediv",
    "mod",
    "pow",
    "lshift",
    "rshift",
    "and",
    "or",
    "xor",
)

# All pointwise magic-method base names. Each binary name above is followed by
# its reflected form ("add", "radd", "sub", "rsub", ...), then come
# comparisons, unary operators, in-place variants and scalar conversions.
# Keep every name to a single entry: duplicates propagate into
# ``pointwise_methods`` and into anything built from it.
pointwise_magic_methods = (
    *(x for m in pointwise_magic_methods_with_reverse for x in (m, "r" + m)),
    # Comparisons.
    "eq",
    "gt",
    "le",
    "lt",
    "ge",
    "ne",
    # Unary operators.
    "neg",
    "pos",
    "abs",
    "invert",
    # In-place variants.
    "iadd",
    "isub",
    "imul",
    "ifloordiv",
    "idiv",
    "itruediv",
    "imod",
    "ipow",
    "ilshift",
    "irshift",
    "iand",
    "ior",
    "ixor",
    # Scalar conversions.
    "int",
    "long",
    "float",
    "complex",
)

# Full dunder attribute names, e.g. "add" -> "__add__".
pointwise_methods = tuple(f"__{m}__" for m in pointwise_magic_methods)
| |
# Callables treated as pointwise operations: the torch.Tensor dunder methods
# named in ``pointwise_methods``, followed by Tensor-method, torch-function and
# torch.nn.functional spellings of elementwise ops. Keep each spelling to a
# single entry.
#
# The getattr below has no default, so importing this module raises
# AttributeError if any listed dunder is missing from torch.Tensor.
#
# NOTE(review): a few entries are not elementwise in the usual sense (e.g.
# torch.trapz, kron/outer/addr, nn.functional.glu, torch.rand_like). They are
# presumably listed on purpose for whichever consumer reads this tuple;
# confirm before pruning.
pointwise = (
    *(getattr(torch.Tensor, m) for m in pointwise_methods),
    torch.nn.functional.dropout,
    torch.where,
    torch.Tensor.abs,
    torch.abs,
    torch.Tensor.acos,
    torch.acos,
    torch.Tensor.acosh,
    torch.acosh,
    torch.Tensor.add,
    torch.add,
    torch.Tensor.addcdiv,
    torch.addcdiv,
    torch.Tensor.addcmul,
    torch.addcmul,
    torch.Tensor.addr,
    torch.addr,
    torch.Tensor.angle,
    torch.angle,
    torch.Tensor.asin,
    torch.asin,
    torch.Tensor.asinh,
    torch.asinh,
    torch.Tensor.atan,
    torch.atan,
    torch.Tensor.atan2,
    torch.atan2,
    torch.Tensor.atanh,
    torch.atanh,
    torch.Tensor.bitwise_and,
    torch.bitwise_and,
    torch.Tensor.bitwise_left_shift,
    torch.bitwise_left_shift,
    torch.Tensor.bitwise_not,
    torch.bitwise_not,
    torch.Tensor.bitwise_or,
    torch.bitwise_or,
    torch.Tensor.bitwise_right_shift,
    torch.bitwise_right_shift,
    torch.Tensor.bitwise_xor,
    torch.bitwise_xor,
    torch.Tensor.ceil,
    torch.ceil,
    torch.celu,
    torch.nn.functional.celu,
    torch.Tensor.clamp,
    torch.clamp,
    torch.Tensor.clamp_max,
    torch.clamp_max,
    torch.Tensor.clamp_min,
    torch.clamp_min,
    torch.Tensor.copysign,
    torch.copysign,
    torch.Tensor.cos,
    torch.cos,
    torch.Tensor.cosh,
    torch.cosh,
    torch.Tensor.deg2rad,
    torch.deg2rad,
    torch.Tensor.digamma,
    torch.digamma,
    torch.Tensor.div,
    torch.div,
    # torch.nn.functional.dropout is already listed near the top of the tuple.
    torch.dropout,
    torch.nn.functional.elu,
    torch.Tensor.eq,
    torch.eq,
    torch.Tensor.erf,
    torch.erf,
    torch.Tensor.erfc,
    torch.erfc,
    torch.Tensor.erfinv,
    torch.erfinv,
    torch.Tensor.exp,
    torch.exp,
    torch.Tensor.exp2,
    torch.exp2,
    torch.Tensor.expm1,
    torch.expm1,
    torch.feature_dropout,
    torch.Tensor.float_power,
    torch.float_power,
    torch.Tensor.floor,
    torch.floor,
    torch.Tensor.floor_divide,
    torch.floor_divide,
    torch.Tensor.fmod,
    torch.fmod,
    torch.Tensor.frac,
    torch.frac,
    torch.Tensor.frexp,
    torch.frexp,
    torch.Tensor.gcd,
    torch.gcd,
    torch.Tensor.ge,
    torch.ge,
    torch.nn.functional.gelu,
    torch.nn.functional.glu,
    torch.Tensor.gt,
    torch.gt,
    torch.Tensor.hardshrink,
    torch.hardshrink,
    torch.nn.functional.hardshrink,
    torch.nn.functional.hardsigmoid,
    torch.nn.functional.hardswish,
    torch.nn.functional.hardtanh,
    torch.Tensor.heaviside,
    torch.heaviside,
    torch.Tensor.hypot,
    torch.hypot,
    torch.Tensor.i0,
    torch.i0,
    torch.Tensor.igamma,
    torch.igamma,
    torch.Tensor.igammac,
    torch.igammac,
    torch.Tensor.isclose,
    torch.isclose,
    torch.Tensor.isfinite,
    torch.isfinite,
    torch.Tensor.isinf,
    torch.isinf,
    torch.Tensor.isnan,
    torch.isnan,
    torch.Tensor.isneginf,
    torch.isneginf,
    torch.Tensor.isposinf,
    torch.isposinf,
    torch.Tensor.isreal,
    torch.isreal,
    torch.Tensor.kron,
    torch.kron,
    torch.Tensor.lcm,
    torch.lcm,
    torch.Tensor.ldexp,
    torch.ldexp,
    torch.Tensor.le,
    torch.le,
    torch.nn.functional.leaky_relu,
    torch.Tensor.lerp,
    torch.lerp,
    torch.Tensor.lgamma,
    torch.lgamma,
    torch.Tensor.log,
    torch.log,
    torch.Tensor.log10,
    torch.log10,
    torch.Tensor.log1p,
    torch.log1p,
    torch.Tensor.log2,
    torch.log2,
    torch.nn.functional.logsigmoid,
    torch.Tensor.logical_and,
    torch.logical_and,
    torch.Tensor.logical_not,
    torch.logical_not,
    torch.Tensor.logical_or,
    torch.logical_or,
    torch.Tensor.logical_xor,
    torch.logical_xor,
    torch.Tensor.logit,
    torch.logit,
    torch.Tensor.lt,
    torch.lt,
    torch.Tensor.maximum,
    torch.maximum,
    torch.Tensor.minimum,
    torch.minimum,
    torch.nn.functional.mish,
    torch.Tensor.mvlgamma,
    torch.mvlgamma,
    torch.Tensor.nan_to_num,
    torch.nan_to_num,
    torch.Tensor.ne,
    torch.ne,
    torch.Tensor.neg,
    torch.neg,
    torch.Tensor.nextafter,
    torch.nextafter,
    torch.Tensor.outer,
    torch.outer,
    torch.polar,
    torch.Tensor.polygamma,
    torch.polygamma,
    torch.Tensor.positive,
    torch.positive,
    torch.Tensor.pow,
    torch.pow,
    torch.Tensor.prelu,
    torch.prelu,
    torch.nn.functional.prelu,
    torch.Tensor.rad2deg,
    torch.rad2deg,
    torch.Tensor.reciprocal,
    torch.reciprocal,
    torch.Tensor.relu,
    torch.relu,
    torch.nn.functional.relu,
    torch.nn.functional.relu6,
    torch.Tensor.remainder,
    torch.remainder,
    torch.Tensor.round,
    torch.round,
    torch.rrelu,
    torch.nn.functional.rrelu,
    torch.Tensor.rsqrt,
    torch.rsqrt,
    torch.rsub,
    torch.selu,
    torch.nn.functional.selu,
    torch.Tensor.sgn,
    torch.sgn,
    torch.Tensor.sigmoid,
    torch.sigmoid,
    torch.nn.functional.sigmoid,
    torch.Tensor.sign,
    torch.sign,
    torch.Tensor.signbit,
    torch.signbit,
    torch.nn.functional.silu,
    torch.Tensor.sin,
    torch.sin,
    torch.Tensor.sinc,
    torch.sinc,
    torch.Tensor.sinh,
    torch.sinh,
    torch.nn.functional.softplus,
    torch.nn.functional.softshrink,
    torch.Tensor.sqrt,
    torch.sqrt,
    torch.Tensor.square,
    torch.square,
    torch.Tensor.sub,
    torch.sub,
    torch.Tensor.tan,
    torch.tan,
    torch.Tensor.tanh,
    torch.tanh,
    torch.nn.functional.tanh,
    torch.threshold,
    torch.nn.functional.threshold,
    torch.trapz,
    torch.Tensor.true_divide,
    torch.true_divide,
    torch.Tensor.trunc,
    torch.trunc,
    torch.Tensor.xlogy,
    torch.xlogy,
    torch.rand_like,
)