[NNC] Fix HalfChecker when half present but unused (#48068)

Summary:
Fixes an internally reported issue in the tensorexpr fuser when using FP16 on CUDA. The HalfChecker analysis, which determines whether we need to define the Half type, searches the IR for expressions that use Half. If one of the parameters is of type Half but neither it nor any other Half expression appears in the IR, the checker returns a false negative. Fix this by passing the parameter list to the HalfChecker.
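
A minimal standalone sketch of the new seeding logic (simplified stand-in types, not the real tensorexpr classes; see the half_support.h hunk below for the actual change):

    #include <vector>

    enum class ScalarType { Float, Half };

    struct BufferArg {
      ScalarType scalar_type;
    };

    struct HalfChecker {
      // Seed from the parameter list so an unused Half argument is not missed.
      explicit HalfChecker(const std::vector<BufferArg>& args) {
        for (const auto& arg : args) {
          hasHalf_ |= arg.scalar_type == ScalarType::Half;
        }
      }
      bool hasHalf() const { return hasHalf_; }
      // ... IRVisitor overrides that detect Half loads/stores in the IR ...
      bool hasHalf_{false};
    };

    int main() {
      // "b" is Half-typed but never used in the kernel body; the checker
      // must still report Half.
      std::vector<BufferArg> args = {{ScalarType::Float}, {ScalarType::Half}};
      return HalfChecker(args).hasHalf() ? 0 : 1;
    }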

Pull Request resolved: https://github.com/pytorch/pytorch/pull/48068

Reviewed By: ZolotukhinM

Differential Revision: D25009680

Pulled By: nickgg

fbshipit-source-id: 24fddef06821f130db3d3f45d6d041c7f34a6ab0
diff --git a/test/cpp/tensorexpr/test_cuda.cpp b/test/cpp/tensorexpr/test_cuda.cpp
index df5359d..9fefa12 100644
--- a/test/cpp/tensorexpr/test_cuda.cpp
+++ b/test/cpp/tensorexpr/test_cuda.cpp
@@ -1034,6 +1034,63 @@
   cudaFree(reluDev);
 }
 
+void testCudaUnusedHalfArgument() {
+  KernelScope kernel_scope;
+  Placeholder a("a", kFloat, {4});
+  auto half = ToDtype<at::Half>();
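+  // "b" is Half-typed but never read in the kernel body; HalfChecker must
+  // still report Half so the half support literal is emitted.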
+  Placeholder b("b", half, {4});
+  Tensor* relu = Compute("relu", {{4, "n"}}, [&](const VarHandle& i) {
+    return Max::make(a.load(i), ExprHandle(new FloatImm(0)), true);
+  });
+
+  LoopNest l({relu});
+  l.prepareForCodegen();
+  Stmt* s = l.root_stmt();
+  CudaCodeGen cg(s, {a, b, relu});
+
+  std::ostringstream oss;
+  oss << *cg.stmt();
+
+  // Check the types used by the Max are Float.
+  const std::string& verification_pattern =
+      R"IR(
+# CHECK: for (
+# CHECK:  float v = a[n];
+# CHECK:  relu[n] = Max(v, 0.f
+# CHECK: })IR";
+
+  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
+
+  // Sanity check:
+  std::vector<float> aData(4, 2.0f);
+  std::vector<at::Half> bData(4, 2.0f);
+  std::vector<float> reluData(4, 0.0f);
+  float* aDev = nullptr;
+  at::Half* bDev = nullptr;
+  float* reluDev = nullptr;
+  auto aSize = aData.size() * sizeof(aData[0]);
+  auto bSize = bData.size() * sizeof(bData[0]);
+  auto reluSize = reluData.size() * sizeof(reluData[0]);
+
+  cudaMalloc(&aDev, aSize);
+  cudaMalloc(&bDev, bSize);
+  cudaMalloc(&reluDev, reluSize);
+  cudaMemcpy(aDev, aData.data(), aSize, cudaMemcpyHostToDevice);
+  cudaMemcpy(bDev, bData.data(), bSize, cudaMemcpyHostToDevice);
+  cudaMemcpy(reluDev, reluData.data(), reluSize, cudaMemcpyHostToDevice);
+  cudaDeviceSynchronize();
+
+  cg.call({aDev, bDev, reluDev});
+  cudaMemcpy(reluData.data(), reluDev, reluSize, cudaMemcpyDeviceToHost);
+  cudaDeviceSynchronize();
+
+  assertAllEqual(aData, reluData);
+
+  cudaFree(aDev);
+  cudaFree(bDev);
+  cudaFree(reluDev);
+}
+
 void testCudaPrioritizeDependents() {
   KernelScope kernel_scope;
   Placeholder a("a", kFloat, {10});
diff --git a/test/cpp/tensorexpr/tests.h b/test/cpp/tensorexpr/tests.h
index e765cab..8587b52 100644
--- a/test/cpp/tensorexpr/tests.h
+++ b/test/cpp/tensorexpr/tests.h
@@ -567,6 +567,7 @@
   _(CudaHalfCast)                          \
   _(CudaHalfSupport)                       \
   _(CudaHalfPropagation)                   \
+  _(CudaUnusedHalfArgument)                \
   _(CudaPrioritizeDependents)              \
   _(CudaMaskBlockDim)                      \
   _(CudaMaskThreadDim)                     \
diff --git a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp
index c674c05..3a6d094 100644
--- a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp
+++ b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp
@@ -923,7 +923,7 @@
   // Check whether the statement uses the Half type; if so, add the
   // half_support_literal.
   Stmt* stmt_v = stmt();
-  HalfChecker halfChecker;
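+  // Pass the buffer args so a Half-typed parameter counts as Half usage
+  // even when no expression in the IR touches it.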
+  HalfChecker halfChecker(buffer_args());
   stmt_v->accept(&halfChecker);
   if (halfChecker.hasHalf()) {
     os() << fuser::cuda::half_support_literal << std::endl;
diff --git a/torch/csrc/jit/tensorexpr/half_support.h b/torch/csrc/jit/tensorexpr/half_support.h
index c20d355..571f809 100644
--- a/torch/csrc/jit/tensorexpr/half_support.h
+++ b/torch/csrc/jit/tensorexpr/half_support.h
@@ -1,5 +1,6 @@
 #pragma once
 
+#include <torch/csrc/jit/tensorexpr/codegen.h>
 #include <torch/csrc/jit/tensorexpr/ir.h>
 #include <torch/csrc/jit/tensorexpr/ir_visitor.h>
 #include <torch/csrc/jit/tensorexpr/tensor.h>
@@ -11,6 +12,12 @@
 // Walk the Statement looking for Half size loads/stores.
 class HalfChecker : public IRVisitor {
  public:
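+  // Seed hasHalf_ from the kernel's buffer arguments: a Half-typed argument
+  // requires Half support even if the IR never loads or stores it.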
+  HalfChecker(const std::vector<CodeGen::BufferArg>& args) {
+    for (const auto& BA : args) {
+      hasHalf_ |= BA.dtype().scalar_type() == ScalarType::Half;
+    }
+  }
+
   bool hasHalf() {
     return hasHalf_;
   }