blob: b3ed622e8b2963bc19d1b33f7ef46f8484968b93 [file]
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import functools
import logging
import typing
from typing import Callable, Union

import numpy as np
import torch
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def distance(fn: Callable[[np.ndarray, np.ndarray], float]) -> Callable[
    [
        # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
        Union[np.ndarray, torch.Tensor],
        # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
        Union[np.ndarray, torch.Tensor],
    ],
    float,
]:
    """Turn a pairwise metric on finite arrays into a robust tensor distance.

    The returned wrapper accepts numpy arrays or torch tensors. Before
    delegating to ``fn`` (e.g. an RMS or maximum-abs-diff metric), it performs
    the checks every distance needs:

    * both inputs are converted with ``to_np_arr_fp64``;
    * differing shapes yield NaN;
    * empty inputs yield 0.0;
    * inf/NaN must occur at the same positions in both inputs, and matching
      infinities must have the same sign, otherwise the result is NaN;
    * those non-finite positions are dropped (a warning is logged), and
      ``fn`` is called on the remaining values as 1-D arrays;
    * if no finite values remain, the result is 0.0.

    ``functools.wraps`` keeps the metric's name and docstring on the wrapper.
    """

    @functools.wraps(fn)
    def wrapper(
        # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
        a: Union[np.ndarray, torch.Tensor],
        # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
        b: Union[np.ndarray, torch.Tensor],
    ) -> float:
        a = to_np_arr_fp64(a)
        b = to_np_arr_fp64(b)
        if a.shape != b.shape:
            return np.nan
        if a.size == 0:
            return 0.0
        # Compute the special-value masks once and reuse them below.
        a_inf, b_inf = np.isinf(a), np.isinf(b)
        a_nan, b_nan = np.isnan(a), np.isnan(b)
        # Special values must line up exactly. Otherwise a metric over the
        # remaining numbers would hide the discrepancy.
        if np.any(a_inf != b_inf) or np.any(a_nan != b_nan):
            return np.nan
        # Matching positions is not enough for infinities: +inf vs -inf is a
        # mismatch, not a zero-distance pair.
        if np.any(a[a_inf] != b[a_inf]):
            return np.nan
        mask = a_inf | a_nan
        if np.any(mask):
            logging.warning("Found inf/nan in tensor when calculating the distance")
        # Boolean indexing also flattens, so ``fn`` always receives 1-D arrays.
        a_finite = a[~mask]
        b_finite = b[~mask]
        if a_finite.size == 0:
            return 0.0
        return fn(a_finite, b_finite)

    return wrapper
@distance
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def rms(a: np.ndarray, b: np.ndarray) -> float:
    """Root-mean-square of the element-wise difference ``a - b``.

    The ``@distance`` wrapper performs conversion, the shape check and
    inf/nan masking. This body therefore only sees equal-length float64
    arrays of finite values.
    """
    residual = np.subtract(a, b)
    mean_square = np.mean(np.square(residual))
    return mean_square ** 0.5
@distance
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def max_abs_diff(a: np.ndarray, b: np.ndarray) -> float:
    """Largest absolute element-wise difference between ``a`` and ``b``.

    Inputs arrive already converted and stripped of inf/nan by ``@distance``.
    """
    return np.max(np.absolute(np.subtract(a, b)))
@distance
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def max_rel_diff(x: np.ndarray, x_ref: np.ndarray) -> float:
    """Largest element-wise relative difference ``|x - x_ref| / |x_ref|``.

    Elements where ``x == x_ref`` contribute 0 even when ``x_ref`` is 0. The
    naive quotient would be 0/0 = NaN there and poison the maximum, so
    identical inputs containing zeros would report NaN. A non-zero
    difference against a zero reference contributes ``inf``.

    Inputs arrive already converted and stripped of inf/nan by ``@distance``.
    """
    abs_diff = np.abs(x - x_ref)
    # Silence the 0/0 and x/0 warnings; both cases are resolved explicitly:
    # 0/0 is replaced by 0 below, and x/0 is intentionally left as inf.
    with np.errstate(divide="ignore", invalid="ignore"):
        rel_diff = np.where(abs_diff == 0, 0.0, abs_diff / np.abs(x_ref))
    return rel_diff.max()
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def to_np_arr_fp64(x: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
    """Return a new float64 numpy copy of ``x``.

    Torch tensors are detached, moved to CPU and cast to float64 *inside
    torch* before conversion. This is necessary because ``Tensor.numpy()``
    cannot represent some dtypes, notably ``torch.bfloat16``. Any other
    input (numpy array, nested sequence, Python scalar) goes through
    ``np.array(..., dtype=np.float64)``. The result never shares memory with
    the input.
    """
    if isinstance(x, torch.Tensor):
        # copy=True guarantees a fresh buffer even when x is already a float64
        # CPU tensor, matching the always-copying numpy path below.
        return x.detach().to(device="cpu", dtype=torch.float64, copy=True).numpy()
    return np.array(x, dtype=np.float64)
def normalized_rms(
    # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
    predicted: Union[np.ndarray, torch.Tensor],
    # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
    ground_truth: Union[np.ndarray, torch.Tensor],
) -> float:
    """Return ``rms(predicted, ground_truth)`` divided by the L2 norm of ``ground_truth``.

    NOTE: the denominator is the Euclidean norm of the whole reference, not
    its RMS. The value therefore also scales with 1/sqrt(number of elements).

    Returns 0.0 when the RMS is exactly zero, which includes empty inputs.
    If ``rms`` returns NaN (shape mismatch or misaligned inf/nan), the NaN
    propagates. A non-zero RMS against an all-zero reference divides by zero.
    """
    num = rms(predicted, ground_truth)
    if num == 0:
        return 0.0
    ref = to_np_arr_fp64(ground_truth)
    # ``rms`` skips matched inf/nan entries. Exclude them from the norm too,
    # so a matched non-finite entry does not turn the result into NaN or 0.
    den = np.linalg.norm(ref[np.isfinite(ref)])
    return np.float64(num) / np.float64(den)