Add inf norm support for _foreach_norm (#118441)

Fixes #117803

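A minimal sketch of the user-facing behavior this adds, assuming a CUDA device so
the fused fast path is exercised (the example tensors are arbitrary):

    import torch

    xs = [torch.randn(8, device="cuda"), torch.randn(16, device="cuda")]
    # ord=inf now runs through the fused multi_tensor_apply kernel instead of the slow per-tensor fallback
    fused = torch._foreach_norm(xs, float("inf"))
    # Reference: the same values computed one tensor at a time
    ref = [torch.linalg.vector_norm(x, float("inf")) for x in xs]
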
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118441
Approved by: https://github.com/mlazos
diff --git a/aten/src/ATen/native/cuda/ForeachReduceOp.cu b/aten/src/ATen/native/cuda/ForeachReduceOp.cu
index d8af951..5e0a9d8 100644
--- a/aten/src/ATen/native/cuda/ForeachReduceOp.cu
+++ b/aten/src/ATen/native/cuda/ForeachReduceOp.cu
@@ -20,16 +20,16 @@
 
 namespace at::native {
 
+// _foreach_norm supports only L1, L2, and inf norm
+enum class NormType { L1, L2, LInf };
+
 template <
     typename T,
-    int NormType,
+    NormType norm_type,
     int depth = 1,
     int r_args_depth = 1,
     int res_arg_index = 0>
 struct LpNormFunctor {
-  static_assert(
-      NormType == 1 || NormType == 2,
-      "foreach_norm supports only L1 and L2 norm");
   using opmath_t = typename at::opmath_type<T>;
   __device__ __forceinline__ void operator()(
       int chunk_size,
@@ -61,7 +61,11 @@
 #pragma unroll
         for (int ii = 0; ii < kILP; ii++) {
           opmath_t next = static_cast<opmath_t>(r_x[ii]);
-          vals[ii] += NormType == 1 ? ::abs(next) : next * next;
+          if constexpr (norm_type == NormType::LInf) {
+            vals[ii] = max_propagate_nan(vals[ii], ::abs(next));
+          } else {
+            vals[ii] += norm_type == NormType::L1 ? ::abs(next) : next * next;
+          }
         }
       }
     } else {
@@ -72,7 +76,11 @@
           int i = i_start + threadIdx.x + ii * blockDim.x;
           if (i < n && i < chunk_size) {
             opmath_t next = static_cast<opmath_t>(x[i]);
-            vals[ii] += NormType == 1 ? ::abs(next) : next * next;
+            if constexpr (norm_type == NormType::LInf) {
+              vals[ii] = max_propagate_nan(vals[ii], ::abs(next));
+            } else {
+              vals[ii] += norm_type == NormType::L1 ? ::abs(next) : next * next;
+            }
           }
         }
       }
@@ -80,19 +88,28 @@
 
     auto val = opmath_t(0);
     for (int i = 0; i < kILP; i++) {
-      val += vals[i];
+      if constexpr (norm_type == NormType::LInf) {
+        val = max_propagate_nan(val, vals[i]);
+      } else {
+        val += vals[i];
+      }
     }
-    auto final = at::native::cuda_utils::BlockReduceSum(val, s_vals);
+    auto final_val = norm_type == NormType::L1 || norm_type == NormType::L2
+        ? at::native::cuda_utils::BlockReduceSum(val, s_vals)
+        : at::native::cuda_utils::BlockReduceMax(val, s_vals);
 
     if (threadIdx.x == 0) {
       output_per_tensor
           [(tl.start_tensor_this_launch + tensor_loc) * max_chunks_per_tensor +
-           chunk_idx] = final;
+           chunk_idx] = final_val;
     }
   }
 };
 
-template <typename T, int NormType, typename opmath_t = at::opmath_type<T>>
+template <
+    typename T,
+    NormType norm_type,
+    typename opmath_t = at::opmath_type<T>>
 __global__ void lpnorm_cleanup(
     const opmath_t* output_per_tensor,
     T* ret_per_tensor,
@@ -103,11 +120,20 @@
       output_per_tensor + blockIdx.x * max_chunks_per_tensor;
   opmath_t val = 0;
   for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x) {
-    val += output_this_tensor[i];
+    if constexpr (norm_type == NormType::LInf) {
+      val = max_propagate_nan(val, output_this_tensor[i]);
+    } else {
+      val += output_this_tensor[i];
+    }
   }
-  opmath_t final = at::native::cuda_utils::BlockReduceSum<opmath_t>(val, vals);
+  opmath_t final_val = norm_type == NormType::L1 || norm_type == NormType::L2
+      ? at::native::cuda_utils::BlockReduceSum<opmath_t>(val, vals)
+      : at::native::cuda_utils::BlockReduceMax(val, vals);
   if (threadIdx.x == 0) {
-    ret_per_tensor[blockIdx.x] = NormType == 1 ? final : ::sqrt(final);
+    ret_per_tensor[blockIdx.x] =
+        norm_type == NormType::L1 || norm_type == NormType::LInf
+        ? final_val
+        : ::sqrt(final_val);
   }
 }
 
@@ -135,7 +161,8 @@
             at::isComplexType(scalar_type);
       });
   if (!can_use_fast_route(tensors) || has_int_or_complex ||
-      !(p == static_cast<double>(1) || p == static_cast<double>(2))) {
+      !(p == static_cast<double>(1) || p == static_cast<double>(2) ||
+        p == std::numeric_limits<double>::infinity())) {
     return foreach_tensor_norm_slow(tensors, ord);
   }
 
@@ -166,14 +193,14 @@
           using opmath_t = typename at::opmath_type<scalar_t>;
           multi_tensor_apply<1>(
               tensor_lists,
-              LpNormFunctor<scalar_t, 1>(),
+              LpNormFunctor<scalar_t, NormType::L1>(),
               output_per_tensor.mutable_data_ptr<opmath_t>(),
               max_chunks_per_tensor);
           C10_CUDA_KERNEL_LAUNCH_CHECK();
           const at::cuda::OptionalCUDAGuard device_guard(
               device_of(output_per_tensor));
           auto stream = at::cuda::getCurrentCUDAStream();
-          lpnorm_cleanup<scalar_t, 1><<<ntensors, 512, 0, stream>>>(
+          lpnorm_cleanup<scalar_t, NormType::L1><<<ntensors, 512, 0, stream>>>(
               output_per_tensor.const_data_ptr<opmath_t>(),
               ret_per_tensor.mutable_data_ptr<scalar_t>(),
               max_chunks_per_tensor);
@@ -189,19 +216,43 @@
           using opmath_t = typename at::opmath_type<scalar_t>;
           multi_tensor_apply<1>(
               tensor_lists,
-              LpNormFunctor<scalar_t, 2>(),
+              LpNormFunctor<scalar_t, NormType::L2>(),
               output_per_tensor.mutable_data_ptr<opmath_t>(),
               max_chunks_per_tensor);
           C10_CUDA_KERNEL_LAUNCH_CHECK();
           const at::cuda::OptionalCUDAGuard device_guard(
               device_of(output_per_tensor));
           auto stream = at::cuda::getCurrentCUDAStream();
-          lpnorm_cleanup<scalar_t, 2><<<ntensors, 512, 0, stream>>>(
+          lpnorm_cleanup<scalar_t, NormType::L2><<<ntensors, 512, 0, stream>>>(
               output_per_tensor.const_data_ptr<opmath_t>(),
               ret_per_tensor.mutable_data_ptr<scalar_t>(),
               max_chunks_per_tensor);
           C10_CUDA_KERNEL_LAUNCH_CHECK();
         });
+  } else if (p == std::numeric_limits<double>::infinity()) {
+    AT_DISPATCH_FLOATING_TYPES_AND2(
+        kHalf,
+        kBFloat16,
+        tensor_lists[0][0].scalar_type(),
+        "foreach_tensor_norm_cuda",
+        [&]() {
+          using opmath_t = typename at::opmath_type<scalar_t>;
+          multi_tensor_apply<1>(
+              tensor_lists,
+              LpNormFunctor<scalar_t, NormType::LInf>(),
+              output_per_tensor.mutable_data_ptr<opmath_t>(),
+              max_chunks_per_tensor);
+          C10_CUDA_KERNEL_LAUNCH_CHECK();
+          const at::cuda::OptionalCUDAGuard device_guard(
+              device_of(output_per_tensor));
+          auto stream = at::cuda::getCurrentCUDAStream();
+          lpnorm_cleanup<scalar_t, NormType::LInf>
+              <<<ntensors, 512, 0, stream>>>(
+                  output_per_tensor.const_data_ptr<opmath_t>(),
+                  ret_per_tensor.mutable_data_ptr<scalar_t>(),
+                  max_chunks_per_tensor);
+          C10_CUDA_KERNEL_LAUNCH_CHECK();
+        });
   } else {
     TORCH_CHECK(
         false,
diff --git a/aten/src/ATen/native/cuda/block_reduce.cuh b/aten/src/ATen/native/cuda/block_reduce.cuh
index fa75c71..a21588a 100644
--- a/aten/src/ATen/native/cuda/block_reduce.cuh
+++ b/aten/src/ATen/native/cuda/block_reduce.cuh
@@ -29,6 +29,19 @@
   return val;
 }
 
+// Picks out the maximum `val` across all threads in a warp.
+//
+// Assumptions:
+//   - The size of each block should be a multiple of `C10_WARP_SIZE`
+template <typename T>
+__inline__ __device__ T WarpReduceMax(T val) {
+#pragma unroll
+  for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) {
+    val = max_propagate_nan(val, WARP_SHFL_DOWN(val, offset));
+  }
+  return val;
+}
+
 struct Block1D {
     static __forceinline__ __device__ int Tid() { return threadIdx.x; }
 
@@ -72,6 +85,33 @@
   return val;
 }
 
+// Picks out the maximum `val` across all threads in a block.
+//
+// Warning: the return value is only valid for thread 0.
+// Assumptions:
+//   - The size of each block should be a multiple of `C10_WARP_SIZE`
+//   - `shared` should be a pointer to shared memory with size of, at least,
+//     `sizeof(T) * number_of_warps`
+template <typename T, typename B = Block1D>
+__inline__ __device__ T BlockReduceMax(T val, T* shared) {
+  const int tid = B::Tid();
+  const int lid = tid % C10_WARP_SIZE;
+  const int wid = tid / C10_WARP_SIZE;
+  val = WarpReduceMax(val);
+  __syncthreads(); // prevent races when BlockReduces are called in a row.
+  if (lid == 0) {
+    shared[wid] = val;
+  }
+  __syncthreads();
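+  // T(0) fills lanes past the last active warp, so this assumes the values
+  // being maxed are non-negative (true for the |x| maxima reduced here).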
+  val = (tid < B::Warps()) ? shared[lid] : T(0);
+  if (wid == 0) {
+    val = WarpReduceMax(val);
+  }
+  return val;
+}
+
 template <typename T, class ReduceOp>
 __inline__ __device__ T WarpReduce(T val, const ReduceOp& op) {
 #pragma unroll
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 8bba46d..3607db5 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -9146,10 +9146,10 @@
         assert "num_input_tensors" not in kwargs
         _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()}
         _foreach_inputs_kwargs["requires_grad"] = requires_grad
-        for ord in (0, 1, 2, -1, -2):
+        for ord in (0, 1, 2, -1, -2, float('inf'), float('-inf')):
             input = sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, zero_size=True, **_foreach_inputs_kwargs)
             disable_fastpath = True
-            if ord in (1, 2) and dtype in floating_types_and(torch.half, torch.bfloat16):
+            if ord in (1, 2, float('inf')) and dtype in floating_types_and(torch.half, torch.bfloat16):
                 disable_fastpath = False
             yield ForeachSampleInput(input, ord=ord, disable_fastpath=disable_fastpath)
 
@@ -9159,13 +9159,30 @@
         _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()}
         _foreach_inputs_kwargs["requires_grad"] = requires_grad
 
-        for num_tensors, ord in product(num_input_tensors, (0, 1, 2, -1, -2)):
+        for num_tensors, ord in product(num_input_tensors, (0, 1, 2, -1, -2, float('inf'), float('-inf'))):
             input = sample_inputs_foreach(None, device, dtype, num_tensors, zero_size=False, **_foreach_inputs_kwargs)
             disable_fastpath = True
-            if ord in (1, 2) and dtype in floating_types_and(torch.half, torch.bfloat16):
+            if ord in (1, 2, float('inf')) and dtype in floating_types_and(torch.half, torch.bfloat16):
                 disable_fastpath = False
             yield ForeachSampleInput(input, ord=ord, disable_fastpath=disable_fastpath)
 
+        # Also test nan propagation with a single tensor, but skip autograd testing
+        if not requires_grad:
+            nan_inputs = [
+                [float('nan')],
+                [float('nan'), 1.0],
+                [1.0, float('nan')],
+                [1.0, 2.0, 3.0, float('nan'), float('nan'), 7.0, float('nan'), float('nan'), -1.5, 6.0],
+                [7.0, 3.0, float('nan'), float('nan'), -1.5, 6.0],
+                [3.0, float('nan'), float('nan'), -1.5, 6.0],
+            ]
+            for ord, input in product((0, 1, 2, -1, -2, float('inf'), float('-inf')), nan_inputs):
+                x = torch.tensor(input, device=device, dtype=dtype)
+                disable_fastpath = True
+                if ord in (1, 2, float('inf')) and dtype in floating_types_and(torch.half, torch.bfloat16):
+                    disable_fastpath = False
+                yield ForeachSampleInput([x], ord=ord, disable_fastpath=disable_fastpath)
+
 
 class foreach_lerp_sample_func(foreach_inputs_sample_func):
     def _sample_rightmost_arg(self, opinfo, rightmost_arg_type, device, dtype, num_tensors, **_foreach_inputs_kwargs):