Add inf norm support for _foreach_norm (#118441)
Fixes #117803
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118441
Approved by: https://github.com/mlazos
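
A minimal usage sketch (illustrative, not part of the patch), assuming a CUDA-enabled build; the tensor values and the comparison against `torch.linalg.vector_norm` are only meant to show what the new fast path is expected to cover:

```python
import torch

# With this change, ord=float('inf') should take the fused CUDA fast path for
# floating dtypes and agree with a per-tensor reference norm, including NaN
# propagation through the max reduction.
if torch.cuda.is_available():
    tensors = [
        torch.randn(1000, device="cuda"),
        torch.tensor([1.0, float("nan"), -3.0], device="cuda"),
    ]
    fused = torch._foreach_norm(tensors, ord=float("inf"))
    reference = [torch.linalg.vector_norm(t, ord=float("inf")) for t in tensors]
    for f, r in zip(fused, reference):
        torch.testing.assert_close(f, r, equal_nan=True)
```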
diff --git a/aten/src/ATen/native/cuda/ForeachReduceOp.cu b/aten/src/ATen/native/cuda/ForeachReduceOp.cu
index d8af951..5e0a9d8 100644
--- a/aten/src/ATen/native/cuda/ForeachReduceOp.cu
+++ b/aten/src/ATen/native/cuda/ForeachReduceOp.cu
@@ -20,16 +20,16 @@
namespace at::native {
+// _foreach_norm supports only L1, L2, and inf norm
+enum class NormType { L1, L2, LInf };
+
template <
typename T,
- int NormType,
+ NormType norm_type,
int depth = 1,
int r_args_depth = 1,
int res_arg_index = 0>
struct LpNormFunctor {
- static_assert(
- NormType == 1 || NormType == 2,
- "foreach_norm supports only L1 and L2 norm");
using opmath_t = typename at::opmath_type<T>;
__device__ __forceinline__ void operator()(
int chunk_size,
@@ -61,7 +61,11 @@
#pragma unroll
for (int ii = 0; ii < kILP; ii++) {
opmath_t next = static_cast<opmath_t>(r_x[ii]);
- vals[ii] += NormType == 1 ? ::abs(next) : next * next;
+ if constexpr (norm_type == NormType::LInf) {
+ vals[ii] = max_propagate_nan(vals[ii], ::abs(next));
+ } else {
+ vals[ii] += norm_type == NormType::L1 ? ::abs(next) : next * next;
+ }
}
}
} else {
@@ -72,7 +76,11 @@
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
opmath_t next = static_cast<opmath_t>(x[i]);
- vals[ii] += NormType == 1 ? ::abs(next) : next * next;
+ if constexpr (norm_type == NormType::LInf) {
+ vals[ii] = max_propagate_nan(vals[ii], ::abs(next));
+ } else {
+ vals[ii] += norm_type == NormType::L1 ? ::abs(next) : next * next;
+ }
}
}
}
@@ -80,19 +88,28 @@
auto val = opmath_t(0);
for (int i = 0; i < kILP; i++) {
- val += vals[i];
+ if constexpr (norm_type == NormType::LInf) {
+ val = max_propagate_nan(val, vals[i]);
+ } else {
+ val += vals[i];
+ }
}
- auto final = at::native::cuda_utils::BlockReduceSum(val, s_vals);
+ auto final_val = norm_type == NormType::L1 || norm_type == NormType::L2
+ ? at::native::cuda_utils::BlockReduceSum(val, s_vals)
+ : at::native::cuda_utils::BlockReduceMax(val, s_vals);
if (threadIdx.x == 0) {
output_per_tensor
[(tl.start_tensor_this_launch + tensor_loc) * max_chunks_per_tensor +
- chunk_idx] = final;
+ chunk_idx] = final_val;
}
}
};
-template <typename T, int NormType, typename opmath_t = at::opmath_type<T>>
+template <
+ typename T,
+ NormType norm_type,
+ typename opmath_t = at::opmath_type<T>>
__global__ void lpnorm_cleanup(
const opmath_t* output_per_tensor,
T* ret_per_tensor,
@@ -103,11 +120,20 @@
output_per_tensor + blockIdx.x * max_chunks_per_tensor;
opmath_t val = 0;
for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x) {
- val += output_this_tensor[i];
+ if constexpr (norm_type == NormType::LInf) {
+ val = max_propagate_nan(val, output_this_tensor[i]);
+ } else {
+ val += output_this_tensor[i];
+ }
}
- opmath_t final = at::native::cuda_utils::BlockReduceSum<opmath_t>(val, vals);
+ opmath_t final_val = norm_type == NormType::L1 || norm_type == NormType::L2
+ ? at::native::cuda_utils::BlockReduceSum<opmath_t>(val, vals)
+ : at::native::cuda_utils::BlockReduceMax(val, vals);
if (threadIdx.x == 0) {
- ret_per_tensor[blockIdx.x] = NormType == 1 ? final : ::sqrt(final);
+ ret_per_tensor[blockIdx.x] =
+ norm_type == NormType::L1 || norm_type == NormType::LInf
+ ? final_val
+ : ::sqrt(final_val);
}
}
@@ -135,7 +161,8 @@
at::isComplexType(scalar_type);
});
if (!can_use_fast_route(tensors) || has_int_or_complex ||
- !(p == static_cast<double>(1) || p == static_cast<double>(2))) {
+ !(p == static_cast<double>(1) || p == static_cast<double>(2) ||
+ p == std::numeric_limits<double>::infinity())) {
return foreach_tensor_norm_slow(tensors, ord);
}
@@ -166,14 +193,14 @@
using opmath_t = typename at::opmath_type<scalar_t>;
multi_tensor_apply<1>(
tensor_lists,
- LpNormFunctor<scalar_t, 1>(),
+ LpNormFunctor<scalar_t, NormType::L1>(),
output_per_tensor.mutable_data_ptr<opmath_t>(),
max_chunks_per_tensor);
C10_CUDA_KERNEL_LAUNCH_CHECK();
const at::cuda::OptionalCUDAGuard device_guard(
device_of(output_per_tensor));
auto stream = at::cuda::getCurrentCUDAStream();
- lpnorm_cleanup<scalar_t, 1><<<ntensors, 512, 0, stream>>>(
+ lpnorm_cleanup<scalar_t, NormType::L1><<<ntensors, 512, 0, stream>>>(
output_per_tensor.const_data_ptr<opmath_t>(),
ret_per_tensor.mutable_data_ptr<scalar_t>(),
max_chunks_per_tensor);
@@ -189,19 +216,43 @@
using opmath_t = typename at::opmath_type<scalar_t>;
multi_tensor_apply<1>(
tensor_lists,
- LpNormFunctor<scalar_t, 2>(),
+ LpNormFunctor<scalar_t, NormType::L2>(),
output_per_tensor.mutable_data_ptr<opmath_t>(),
max_chunks_per_tensor);
C10_CUDA_KERNEL_LAUNCH_CHECK();
const at::cuda::OptionalCUDAGuard device_guard(
device_of(output_per_tensor));
auto stream = at::cuda::getCurrentCUDAStream();
- lpnorm_cleanup<scalar_t, 2><<<ntensors, 512, 0, stream>>>(
+ lpnorm_cleanup<scalar_t, NormType::L2><<<ntensors, 512, 0, stream>>>(
output_per_tensor.const_data_ptr<opmath_t>(),
ret_per_tensor.mutable_data_ptr<scalar_t>(),
max_chunks_per_tensor);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
+ } else if (p == std::numeric_limits<double>::infinity()) {
+ AT_DISPATCH_FLOATING_TYPES_AND2(
+ kHalf,
+ kBFloat16,
+ tensor_lists[0][0].scalar_type(),
+ "foreach_tensor_norm_cuda",
+ [&]() {
+ using opmath_t = typename at::opmath_type<scalar_t>;
+ multi_tensor_apply<1>(
+ tensor_lists,
+ LpNormFunctor<scalar_t, NormType::LInf>(),
+ output_per_tensor.mutable_data_ptr<opmath_t>(),
+ max_chunks_per_tensor);
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+ const at::cuda::OptionalCUDAGuard device_guard(
+ device_of(output_per_tensor));
+ auto stream = at::cuda::getCurrentCUDAStream();
+ lpnorm_cleanup<scalar_t, NormType::LInf>
+ <<<ntensors, 512, 0, stream>>>(
+ output_per_tensor.const_data_ptr<opmath_t>(),
+ ret_per_tensor.mutable_data_ptr<scalar_t>(),
+ max_chunks_per_tensor);
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+ });
} else {
TORCH_CHECK(
false,
diff --git a/aten/src/ATen/native/cuda/block_reduce.cuh b/aten/src/ATen/native/cuda/block_reduce.cuh
index fa75c71..a21588a 100644
--- a/aten/src/ATen/native/cuda/block_reduce.cuh
+++ b/aten/src/ATen/native/cuda/block_reduce.cuh
@@ -29,6 +29,19 @@
return val;
}
+// Picks the maximum `val` across all threads in a warp.
+//
+// Assumptions:
+// - The size of each block should be a multiple of `C10_WARP_SIZE`
+template <typename T>
+__inline__ __device__ T WarpReduceMax(T val) {
+#pragma unroll
+ for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) {
+ val = max_propagate_nan(val, WARP_SHFL_DOWN(val, offset));
+ }
+ return val;
+}
+
struct Block1D {
static __forceinline__ __device__ int Tid() { return threadIdx.x; }
@@ -72,6 +85,31 @@
return val;
}
+// Picks out the maximum `val` across all threads in a block.
+//
+// Warning: the return value is only valid for thread 0.
+// Assumptions:
+// - The size of each block should be a multiple of `C10_WARP_SIZE`
+// - `shared` should be a pointer to shared memory with size of, at least,
+// `sizeof(T) * number_of_warps`
+template <typename T, typename B = Block1D>
+__inline__ __device__ T BlockReduceMax(T val, T* shared) {
+ const int tid = B::Tid();
+ const int lid = tid % C10_WARP_SIZE;
+ const int wid = tid / C10_WARP_SIZE;
+ val = WarpReduceMax(val);
+ __syncthreads(); // prevent races when BlockReduces are called in a row.
+ if (lid == 0) {
+ shared[wid] = val;
+ }
+ __syncthreads();
+ val = (tid < B::Warps()) ? shared[lid] : T(0);
+ if (wid == 0) {
+ val = WarpReduceMax(val);
+ }
+ return val;
+}
+
template <typename T, class ReduceOp>
__inline__ __device__ T WarpReduce(T val, const ReduceOp& op) {
#pragma unroll
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 8bba46d..3607db5 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -9146,10 +9146,10 @@
assert "num_input_tensors" not in kwargs
_foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()}
_foreach_inputs_kwargs["requires_grad"] = requires_grad
- for ord in (0, 1, 2, -1, -2):
+ for ord in (0, 1, 2, -1, -2, float('inf'), float('-inf')):
input = sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, zero_size=True, **_foreach_inputs_kwargs)
disable_fastpath = True
- if ord in (1, 2) and dtype in floating_types_and(torch.half, torch.bfloat16):
+ if ord in (1, 2, float('inf')) and dtype in floating_types_and(torch.half, torch.bfloat16):
disable_fastpath = False
yield ForeachSampleInput(input, ord=ord, disable_fastpath=disable_fastpath)
@@ -9159,13 +9159,32 @@
_foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()}
_foreach_inputs_kwargs["requires_grad"] = requires_grad
- for num_tensors, ord in product(num_input_tensors, (0, 1, 2, -1, -2)):
+ for num_tensors, ord in product(num_input_tensors, (0, 1, 2, -1, -2, float('inf'), float('-inf'))):
input = sample_inputs_foreach(None, device, dtype, num_tensors, zero_size=False, **_foreach_inputs_kwargs)
disable_fastpath = True
- if ord in (1, 2) and dtype in floating_types_and(torch.half, torch.bfloat16):
+ if ord in (1, 2, float('inf')) and dtype in floating_types_and(torch.half, torch.bfloat16):
disable_fastpath = False
yield ForeachSampleInput(input, ord=ord, disable_fastpath=disable_fastpath)
+ # Also test nan propagation with a single tensor, but skip autograd testing
+ if not requires_grad:
+ nan_inputs = [
+ [float('nan')],
+ [float('nan'), 1.0],
+ [1.0, float('nan')],
+ [1.0, 2.0, 3.0, float('nan'), float('nan'), 7.0, float('nan'), float('nan'), -1.5, 6.0],
+ [7.0, 3.0, float('nan'), float('nan'), -1.5, 6.0],
+ [3.0, float('nan'), float('nan'), -1.5, 6.0],
+ ]
+ for input in nan_inputs:
+ x = torch.tensor(input, device=device)
+ disable_fastpath = True
+ if ord in (1, 2, float('inf')) and dtype in floating_types_and(torch.half, torch.bfloat16):
+ disable_fastpath = False
+ yield ForeachSampleInput([x], ord=ord, disable_fastpath=disable_fastpath)
+
+
+
class foreach_lerp_sample_func(foreach_inputs_sample_func):
def _sample_rightmost_arg(self, opinfo, rightmost_arg_type, device, dtype, num_tensors, **_foreach_inputs_kwargs):