|  | # Copyright (c) Meta Platforms, Inc. and affiliates | 
|  |  | 
|  | import torch | 
|  |  | 
|  | from .core import _map_mt_args_kwargs, _wrap_result | 
|  |  | 
# Intentionally empty: this module exports no names via ``import *``.
# The quoted annotation gives type checkers the element type without a
# ``type: ignore`` and without evaluating ``list[str]`` at runtime.
__all__: "list[str]" = []
|  |  | 
|  |  | 
# Names of ``torch.ops.aten`` unary ops that MaskedTensor handles by running
# the op on the underlying data and carrying the mask over unchanged.
UNARY_NAMES = [
    "abs",
    "absolute",
    "acos",
    "arccos",
    "acosh",
    "arccosh",
    "angle",
    "asin",
    "arcsin",
    "asinh",
    "arcsinh",
    "atan",
    "arctan",
    "atanh",
    "arctanh",
    "bitwise_not",
    "ceil",
    "clamp",
    "clip",
    "conj_physical",
    "cos",
    "cosh",
    "deg2rad",
    "digamma",
    "erf",
    "erfc",
    "erfinv",
    "exp",
    "exp2",
    "expm1",
    "fix",
    "floor",
    "frac",
    "lgamma",
    "log",
    "log10",
    "log1p",
    "log2",
    "logit",
    "i0",
    "isnan",
    "nan_to_num",
    "neg",
    "negative",
    "positive",
    "pow",
    "rad2deg",
    "reciprocal",
    "round",
    "rsqrt",
    "sigmoid",
    "sign",
    "sgn",
    "signbit",
    "sin",
    "sinc",
    "sinh",
    "sqrt",
    "square",
    "tan",
    "tanh",
    "trunc",
]

# Ops in UNARY_NAMES that get no in-place ("<name>_") entry. Presumably aten
# has no such in-place overload for them -- confirm before editing this set.
_NO_INPLACE_VARIANT = frozenset({"angle", "positive", "signbit", "isnan"})

# In-place counterparts ("abs_", "sqrt_", ...) of UNARY_NAMES.
# Built by filtering UNARY_NAMES in order, instead of iterating a set
# difference, so the list order is stable across interpreter runs.
# ``dict.fromkeys`` keeps the de-duplication the set-based version provided.
INPLACE_UNARY_NAMES = [
    name + "_"
    for name in dict.fromkeys(UNARY_NAMES)
    if name not in _NO_INPLACE_VARIANT
]
|  |  | 
# Functions MaskedTensor is known NOT to support yet, listed explicitly so
# the gap is visible. The reason may be missing code generation or
# semantics that are more complex than a simple data/mask split.
UNARY_NAMES_UNSUPPORTED = [
    "atan2", "arctan2",
    "bitwise_left_shift", "bitwise_right_shift",
    "copysign", "float_power", "fmod", "frexp", "gradient", "imag",
    "ldexp", "lerp", "logical_not", "hypot",
    "igamma", "igammac", "mvlgamma", "nextafter", "polygamma",
    "real", "remainder", "true_divide", "xlogy",
]
|  |  | 
|  |  | 
def _unary_helper(fn, args, kwargs, inplace):
    """Run the aten op ``fn`` on a MaskedTensor's data, keeping its mask.

    For sparse COO / CSR data, ``fn`` is applied only to the explicitly
    stored values. The original sparsity pattern (indices) and size are
    reused to rebuild the result.

    Args:
        fn: The ``torch.ops.aten`` op to run on the underlying data.
        args: Positional arguments for ``fn``. ``args[0]`` is the MaskedTensor
            operand. Any further arguments must not be Tensors.
        kwargs: Keyword arguments; must be empty.
        inplace: If True, store the result back into ``args[0]`` and return
            it. Otherwise return a new wrapped result.

    Returns:
        ``args[0]`` when ``inplace`` is True; otherwise ``_wrap_result`` of the
        new data and the mask of ``args[0]``.

    Raises:
        ValueError: If ``kwargs`` is non-empty.
        TypeError: If any argument after the first is a Tensor.
    """
    if len(kwargs) != 0:
        raise ValueError("MaskedTensor unary ops require that len(kwargs) == 0. "
                         "If you need support for this, please open an issue on Github.")
    for a in args[1:]:
        if torch.is_tensor(a):
            raise TypeError("MaskedTensor unary ops do not support additional Tensor arguments")

    # kwargs is empty at this point, so only the mapped positional args are used.
    mask_args, _ = _map_mt_args_kwargs(args, kwargs, lambda x: x._masked_mask)
    data_args, _ = _map_mt_args_kwargs(args, kwargs, lambda x: x._masked_data)

    layout = args[0].layout
    if layout == torch.sparse_coo:
        # indices()/values() require a coalesced tensor. coalesce() once is
        # enough; calling it again on a coalesced tensor is redundant.
        data = data_args[0].coalesce()
        size = data.size()
        indices = data.indices()
        data_args[0] = data.values()
        values = fn(*data_args)
        result_data = torch.sparse_coo_tensor(indices, values, size=size)
    elif layout == torch.sparse_csr:
        data = data_args[0]
        size = data.size()
        crow = data.crow_indices()
        col = data.col_indices()
        data_args[0] = data.values()
        values = fn(*data_args)
        # Pass the size explicitly. Without it, sparse_csr_tensor infers the
        # smallest shape that holds the stored entries, which drops trailing
        # all-empty columns and would no longer match the mask.
        result_data = torch.sparse_csr_tensor(crow, col, values, size=size)
    else:
        result_data = fn(*data_args)

    if inplace:
        args[0]._set_data_mask(result_data, mask_args[0])
        return args[0]
    return _wrap_result(result_data, mask_args[0])
|  |  | 
|  |  | 
|  | def _torch_unary(fn_name): | 
|  | fn = getattr(torch.ops.aten, fn_name) | 
|  |  | 
|  | def unary_fn(*args, **kwargs): | 
|  | return _unary_helper(fn, args, kwargs, inplace=False) | 
|  |  | 
|  | return unary_fn | 
|  |  | 
|  |  | 
|  | def _torch_inplace_unary(fn_name): | 
|  | fn = getattr(torch.ops.aten, fn_name) | 
|  |  | 
|  | def unary_fn(*args, **kwargs): | 
|  | return _unary_helper(fn, args, kwargs, inplace=True) | 
|  |  | 
|  | return unary_fn | 
|  |  | 
|  |  | 
# Dispatch tables: each ``torch.ops.aten`` op object maps to the MaskedTensor
# handler built for it. The *_FNS lists hold the same ops (the dict keys), in
# insertion order.
NATIVE_UNARY_MAP = {
    getattr(torch.ops.aten, op_name): _torch_unary(op_name)
    for op_name in UNARY_NAMES
}
NATIVE_UNARY_FNS = list(NATIVE_UNARY_MAP)

NATIVE_INPLACE_UNARY_MAP = {
    getattr(torch.ops.aten, op_name): _torch_inplace_unary(op_name)
    for op_name in INPLACE_UNARY_NAMES
}
NATIVE_INPLACE_UNARY_FNS = list(NATIVE_INPLACE_UNARY_MAP)
|  |  | 
|  |  | 
def _is_native_unary(fn):
    """Return True if ``fn`` has a registered out-of-place or in-place unary handler.

    Membership is checked on the handler dicts (hash lookup) rather than by
    scanning the NATIVE_*_FNS key lists, which hold the same ops.
    """
    return fn in NATIVE_UNARY_MAP or fn in NATIVE_INPLACE_UNARY_MAP
|  |  | 
|  |  | 
def _apply_native_unary(fn, *args, **kwargs):
    """Dispatch ``fn`` to its registered MaskedTensor unary handler.

    Out-of-place handlers take precedence over in-place ones. Each table is
    consulted with a single hash lookup instead of a list scan followed by a
    second dict lookup.

    Returns:
        The handler's result, or ``NotImplemented`` if ``fn`` has no handler.
    """
    # Registered handlers are closures, never None, so None means "absent".
    handler = NATIVE_UNARY_MAP.get(fn)
    if handler is None:
        handler = NATIVE_INPLACE_UNARY_MAP.get(fn)
    if handler is None:
        return NotImplemented
    return handler(*args, **kwargs)