|  | from typing import Optional | 
|  | import warnings | 
|  |  | 
|  | # NB: Keep this file in sync with enums in aten/src/ATen/core/Reduction.h | 
|  |  | 
|  |  | 
def get_enum(reduction: str) -> int:
    """Translate a reduction name into its integer code.

    Mapping: ``'none'`` -> 0, ``'mean'`` -> 1, ``'sum'`` -> 2. The deprecated
    spelling ``'elementwise_mean'`` is still accepted; it emits a warning and
    maps to the same code as ``'mean'``.

    Args:
        reduction: name of the reduction mode.

    Returns:
        The integer code for ``reduction``.

    Raises:
        ValueError: if ``reduction`` is not one of the names listed above.
    """
    # Placeholder so the result name is bound on every path, including the
    # one that raises.
    code = -1  # TODO: remove once JIT exceptions support control flow
    if reduction == 'sum':
        code = 2
    elif reduction == 'none':
        code = 0
    elif reduction == 'mean':
        code = 1
    elif reduction == 'elementwise_mean':
        # Deprecated alias of 'mean': warn, then behave exactly like 'mean'.
        warnings.warn("reduction='elementwise_mean' is deprecated, please use reduction='mean' instead.")
        code = 1
    else:
        raise ValueError("{} is not a valid value for reduction".format(reduction))
    return code
|  |  | 
|  | # In order to support previous versions, accept boolean size_average and reduce | 
|  | # and convert them into the new constants for now | 
|  |  | 
|  |  | 
|  | # We use these functions in torch/legacy as well, in which case we'll silence the warning | 
def legacy_get_string(size_average: Optional[bool], reduce: Optional[bool], emit_warning: bool = True) -> str:
    """Convert the legacy ``size_average`` / ``reduce`` flags to a reduction name.

    A flag passed as ``None`` is treated as ``True``.

    Args:
        size_average: legacy flag; ``None`` counts as ``True``.
        reduce: legacy flag; ``None`` counts as ``True``.
        emit_warning: when true, warn and name the equivalent ``reduction``
            value to use instead.

    Returns:
        ``'mean'`` when both flags are true, ``'sum'`` when only ``reduce``
        is true, and ``'none'`` when ``reduce`` is false.
    """
    averaging = True if size_average is None else size_average
    reducing = True if reduce is None else reduce

    if not reducing:
        # With reduce off, size_average has no influence on the result.
        name = 'none'
    elif averaging:
        name = 'mean'
    else:
        name = 'sum'

    if emit_warning:
        warnings.warn(
            "size_average and reduce args will be deprecated, please use reduction='{}' instead.".format(name)
        )
    return name
|  |  | 
|  |  | 
def legacy_get_enum(size_average: Optional[bool], reduce: Optional[bool], emit_warning: bool = True) -> int:
    """Convert the legacy ``size_average`` / ``reduce`` flags to a reduction code.

    The flags are first converted to a reduction name by
    ``legacy_get_string`` (which receives ``emit_warning`` unchanged), and
    that name is then converted to its integer code by ``get_enum``.
    """
    reduction = legacy_get_string(size_average, reduce, emit_warning)
    return get_enum(reduction)