| # Copyright (c) Meta Platforms, Inc. and affiliates |
| |
| import torch |
| |
| from .core import _map_mt_args_kwargs, _wrap_result |
| |
# This module exports no public names via ``from ... import *``.
__all__ = []  # type: ignore[var-annotated]
| |
| |
# Names of the ops (looked up on ``torch.ops.aten`` further down) that get an
# out-of-place MaskedTensor unary wrapper. Related names share a line; the
# overall order is significant only in that it fixes dict insertion order.
# fmt: off
UNARY_NAMES = [
    "abs", "absolute",
    "acos", "arccos", "acosh", "arccosh",
    "angle",
    "asin", "arcsin", "asinh", "arcsinh",
    "atan", "arctan", "atanh", "arctanh",
    "bitwise_not",
    "ceil", "clamp", "clip", "conj_physical",
    "cos", "cosh",
    "deg2rad", "digamma",
    "erf", "erfc", "erfinv",
    "exp", "exp2", "expm1",
    "fix", "floor", "frac",
    "lgamma",
    "log", "log10", "log1p", "log2", "logit",
    "i0", "isnan",
    "nan_to_num", "neg", "negative",
    "positive", "pow",
    "rad2deg", "reciprocal", "round", "rsqrt",
    "sigmoid", "sign", "sgn", "signbit",
    "sin", "sinc", "sinh",
    "sqrt", "square",
    "tan", "tanh", "trunc",
]
# fmt: on
| |
# In-place variants ("<name>_") of the unary ops, minus the four names
# filtered out below (presumably because they have no in-place aten variant;
# confirm against torch.ops.aten).
# UNARY_NAMES is filtered directly instead of round-tripping through a set:
# set iteration order over strings depends on hash randomization, which made
# the order of this list differ from one interpreter run to the next.
INPLACE_UNARY_NAMES = [
    n + "_"
    for n in UNARY_NAMES
    if n not in {"angle", "positive", "signbit", "isnan"}
]
| |
# Functions known to be unsupported at the moment, tracked explicitly.
# The cause may be missing code gen or semantics that are too complex.
# fmt: off
UNARY_NAMES_UNSUPPORTED = [
    "atan2", "arctan2",
    "bitwise_left_shift", "bitwise_right_shift",
    "copysign", "float_power", "fmod", "frexp",
    "gradient", "imag", "ldexp", "lerp", "logical_not",
    "hypot", "igamma", "igammac",
    "mvlgamma", "nextafter", "polygamma",
    "real", "remainder", "true_divide", "xlogy",
]
# fmt: on
| |
| |
| def _unary_helper(fn, args, kwargs, inplace): |
| if len(kwargs) != 0: |
| raise ValueError("MaskedTensor unary ops require that len(kwargs) == 0. " |
| "If you need support for this, please open an issue on Github.") |
| for a in args[1:]: |
| if torch.is_tensor(a): |
| raise TypeError("MaskedTensor unary ops do not support additional Tensor arguments") |
| |
| mask_args, mask_kwargs = _map_mt_args_kwargs( |
| args, kwargs, lambda x: x._masked_mask |
| ) |
| data_args, data_kwargs = _map_mt_args_kwargs( |
| args, kwargs, lambda x: x._masked_data |
| ) |
| |
| if args[0].layout == torch.sparse_coo: |
| data_args[0] = data_args[0].coalesce() |
| s = data_args[0].size() |
| i = data_args[0].indices() |
| data_args[0] = data_args[0].coalesce().values() |
| v = fn(*data_args) |
| result_data = torch.sparse_coo_tensor(i, v, size=s) |
| |
| elif args[0].layout == torch.sparse_csr: |
| crow = data_args[0].crow_indices() |
| col = data_args[0].col_indices() |
| data_args[0] = data_args[0].values() |
| v = fn(*data_args) |
| result_data = torch.sparse_csr_tensor(crow, col, v) |
| |
| else: |
| result_data = fn(*data_args) |
| |
| if inplace: |
| args[0]._set_data_mask(result_data, mask_args[0]) |
| return args[0] |
| else: |
| return _wrap_result(result_data, mask_args[0]) |
| |
| |
| def _torch_unary(fn_name): |
| fn = getattr(torch.ops.aten, fn_name) |
| |
| def unary_fn(*args, **kwargs): |
| return _unary_helper(fn, args, kwargs, inplace=False) |
| |
| return unary_fn |
| |
| |
| def _torch_inplace_unary(fn_name): |
| fn = getattr(torch.ops.aten, fn_name) |
| |
| def unary_fn(*args, **kwargs): |
| return _unary_helper(fn, args, kwargs, inplace=True) |
| |
| return unary_fn |
| |
| |
# Dispatch tables: the attribute found under each name on torch.ops.aten is
# the key, and the wrapper built for that same name is the value.
NATIVE_UNARY_MAP = dict(
    (getattr(torch.ops.aten, op_name), _torch_unary(op_name))
    for op_name in UNARY_NAMES
)
NATIVE_INPLACE_UNARY_MAP = dict(
    (getattr(torch.ops.aten, op_name), _torch_inplace_unary(op_name))
    for op_name in INPLACE_UNARY_NAMES
)

# The tables' keys as lists, in insertion order.
NATIVE_UNARY_FNS = [*NATIVE_UNARY_MAP]
NATIVE_INPLACE_UNARY_FNS = [*NATIVE_INPLACE_UNARY_MAP]
| |
| |
def _is_native_unary(fn):
    """Return whether ``fn`` has an entry in either unary dispatch table.

    Membership is tested on the dispatch dicts themselves (a hashed lookup)
    instead of on the key lists copied from them, which would be scanned
    element by element. Both hold the same keys, so the answer is unchanged.
    """
    return fn in NATIVE_UNARY_MAP or fn in NATIVE_INPLACE_UNARY_MAP
| |
| |
def _apply_native_unary(fn, *args, **kwargs):
    """Call the wrapper registered for ``fn`` with the given arguments.

    The out-of-place table is consulted first, then the in-place table.
    Each table is probed once with ``dict.get`` rather than scanning its key
    list and then indexing the dict.

    Returns:
        The wrapper's result, or ``NotImplemented`` when ``fn`` is in
        neither table.
    """
    for table in (NATIVE_UNARY_MAP, NATIVE_INPLACE_UNARY_MAP):
        handler = table.get(fn)
        # Table values are always wrapper functions, so None means "absent".
        if handler is not None:
            return handler(*args, **kwargs)
    return NotImplemented