[3/N] Non-Tensor: Support string parameter for aten operations (#125831)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125831
Approved by: https://github.com/jansel, https://github.com/jgong5
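
This change extends AOTI eager mode to cover string parameters of aten operations. On the Python side, `aoti_compile_with_persistent_cache` now accepts `str` inputs and records them through a new `extract_string_metadata` helper; on the C++ side, `ParameterMetadata` gains a `STRING` tag so cached kernels can be matched against string arguments such as the `rounding_mode` of `aten.div`. The existing `test_eager_aoti_*` tests are renamed to `test_aoti_eager_*` for consistency, and a new `test_aoti_eager_support_str` exercises `div` with `rounding_mode="trunc"` and `"floor"`.

A minimal sketch of the flow the new test exercises (`_scoped_library` comes from `torch.library`; the import path of `register_ops_with_aoti_compile` is assumed here, since the diff only shows it being called):

```python
import torch
from torch.library import _scoped_library

# register_ops_with_aoti_compile is the helper called by the tests in
# test_torchinductor.py; its module path is assumed for this sketch.
from torch._inductor.utils import register_ops_with_aoti_compile

a = torch.randn(128, dtype=torch.float, device="cpu")
b = torch.randn(128, dtype=torch.float, device="cpu")

# Plain eager reference result.
ref = torch.ops.aten.div(a, b, rounding_mode="floor")

with _scoped_library("aten", "IMPL") as lib:
    # Route aten.div through AOTI-compiled eager kernels on CPU.
    register_ops_with_aoti_compile("aten", ["div"], "CPU", lib)
    # The string kwarg now participates in kernel lookup.
    res = torch.ops.aten.div(a, b, rounding_mode="floor")

torch.testing.assert_close(ref, res)
```
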
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index c7a1282..baa6024 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -804,7 +804,7 @@
 
     @skipCUDAIf(not SM80OrLater, "Requires sm80")
     @skip_if_halide  # aoti
-    def test_eager_aoti_support_out(self):
+    def test_aoti_eager_support_out(self):
         ns = "aten"
         op_name = "clamp"
         dispatch_key = "CPU"
@@ -857,7 +857,45 @@
 
     @skipCUDAIf(not SM80OrLater, "Requires sm80")
     @skip_if_halide  # aoti
-    def test_eager_aoti_cache_hit(self):
+    def test_aoti_eager_support_str(self):
+        ns = "aten"
+        op_name = "div"
+        dispatch_key = "CPU"
+        device = "cpu"
+        if self.device.lower() == "cuda":
+            dispatch_key = "CUDA"
+            device = "cuda"
+
+        a = torch.randn(128, dtype=torch.float, device=device)
+        b = torch.randn(128, dtype=torch.float, device=device)
+        rounding_mode_list = ["trunc", "floor"]
+        with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl:
+            # Get ref result from eager
+            ref_value_list = []
+            for rounding_mode in rounding_mode_list:
+                ref_value = getattr(torch.ops.aten, op_name)(
+                    a, b, rounding_mode=rounding_mode
+                )
+                ref_value_list.append(ref_value)
+
+            register_ops_with_aoti_compile(
+                ns, [op_name], dispatch_key, torch_compile_op_lib_impl
+            )
+
+            # Invoke the pre-compiled kernel and collect the results.
+            res_value_list = []
+            for rounding_mode in rounding_mode_list:
+                res_value = getattr(torch.ops.aten, op_name)(
+                    a, b, rounding_mode=rounding_mode
+                )
+                res_value_list.append(res_value)
+
+            for ref_value, res_value in zip(ref_value_list, res_value_list):
+                self.assertEqual(ref_value, res_value)
+
+    @skipCUDAIf(not SM80OrLater, "Requires sm80")
+    @skip_if_halide  # aoti
+    def test_aoti_eager_cache_hit(self):
         ns = "aten"
         op_name = "abs"
         dispatch_key = "CPU"
@@ -899,7 +937,7 @@
 
     @skipCUDAIf(not SM80OrLater, "Requires sm80")
     @skip_if_halide  # aoti
-    def test_eager_aoti_with_persistent_cache(self):
+    def test_aoti_eager_with_persistent_cache(self):
         def fn(a):
             return torch.abs(a)
 
@@ -944,7 +982,7 @@
 
     @skipCUDAIf(not SM80OrLater, "Requires sm80")
     @skip_if_halide  # aoti
-    def test_eager_aoti_with_scalar(self):
+    def test_aoti_eager_with_scalar(self):
         namespace_name = "aten"
         op_name = "add"
         op_overload_name = "Tensor"
@@ -1015,7 +1053,7 @@
 
     @skipCUDAIf(not SM80OrLater, "Requires sm80")
     @skip_if_halide  # aoti
-    def test_eager_aoti_override_registration(self):
+    def test_aoti_eager_override_registration(self):
         namespace_name = "aten"
         dispatch_key = "CPU"
         device = torch.device("cpu")
diff --git a/torch/_inductor/aoti_eager.py b/torch/_inductor/aoti_eager.py
index 5bd2027..83d4e98 100644
--- a/torch/_inductor/aoti_eager.py
+++ b/torch/_inductor/aoti_eager.py
@@ -1,7 +1,7 @@
 import json
 import os
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple
 from unittest import mock
 
 import torch
@@ -47,12 +47,14 @@
                     return []
 
                 for metadata in item["meta_info"]:
-                    assert not metadata[
-                        "is_dynamic"
-                    ], "Only support static shape for now"
-                    if metadata["device_type"] == "cpu":
+                    if metadata.get("is_dynamic"):
+                        raise NotImplementedError("Only support static shape for now")
+                    if "device_type" in metadata and metadata["device_type"] == "cpu":
                         metadata["device_index"] = -1
-                    metadata["dtype"] = getattr(torch, metadata["dtype"].split(".")[-1])
+                    if "dtype" in metadata:
+                        metadata["dtype"] = getattr(
+                            torch, metadata["dtype"].split(".")[-1]
+                        )
 
             return json_data
 
@@ -63,8 +65,7 @@
 
 def supported_scalar_types() -> Tuple[type, ...]:
     type_to_torch_dtype = supported_builtin_dtype_torch_dtype()
-    supported_scalar_types = tuple(type_to_torch_dtype.keys())
-    return supported_scalar_types
+    return tuple(type_to_torch_dtype.keys())
 
 
 def extract_tensor_metadata(dynamic: bool, input: torch.Tensor) -> Dict[str, Any]:
@@ -99,9 +100,7 @@
     return metadata
 
 
-def extract_scalar_metadata(
-    device_type: str, input: Union[int, float, bool]
-) -> Dict[str, Any]:
+def extract_scalar_metadata(device_type: str, input: Any) -> Dict[str, Any]:
     assert isinstance(input, supported_scalar_types())
     metadata: Dict[str, Any] = {}
     metadata["is_dynamic"] = False
@@ -114,6 +113,13 @@
     return metadata
 
 
+def extract_string_metadata(input: str) -> Dict[str, Any]:
+    assert isinstance(input, str)
+    metadata: Dict[str, Any] = {}
+    metadata["string_value"] = input
+    return metadata
+
+
 def aoti_compile_with_persistent_cache(
     ns: str,
     op_func_name_with_overload: str,
@@ -132,11 +138,9 @@
     Compile the given function with persistent cache for AOTI eager mode.
     """
     assert not dynamic, "Only support static shape for now"
-    type_to_torch_dtype = {int: torch.int32, float: torch.float, bool: torch.bool}
-    supported_scalar_types = tuple(type_to_torch_dtype.keys())
     flattened_inputs = list(args) + list(kwargs.values())
     if not all(
-        isinstance(input, (supported_scalar_types, torch.Tensor, list))
+        isinstance(input, (supported_scalar_types(), torch.Tensor, list, str))
         for input in flattened_inputs
     ):
         raise NotImplementedError(
@@ -185,8 +189,12 @@
                 elif isinstance(input, list):
                     assert all(isinstance(item, torch.Tensor) for item in input)
                     metadata = extract_tensor_list_metadata(dynamic, input)
-                else:
+                elif isinstance(input, supported_scalar_types()):
                     metadata = extract_scalar_metadata(device_type, input)
+                elif isinstance(input, str):
+                    metadata = extract_string_metadata(input)
+                else:
+                    raise NotImplementedError(f"Unsupported input type: {type(input)}")
 
                 metadata["arg_order"] = idx
                 kernel_metadata_items.append(metadata)
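
For reference, the dict shapes involved, read off the Python code above (the `arg_order` values are illustrative):

```python
# extract_string_metadata(rounding_mode) returns {"string_value": rounding_mode};
# the loop above then attaches the argument position, so the recorded item for
# rounding_mode="floor" in div(a, b, rounding_mode="floor") becomes:
string_item = {"string_value": "floor", "arg_order": 2}

# A tensor argument recorded alongside it carries, among other fields, the
# keys that load_aoti_eager_cache special-cases when reading the cache back:
tensor_item = {
    "is_dynamic": False,        # truthy values now raise NotImplementedError
    "device_type": "cpu",       # CPU entries get device_index rewritten to -1
    "dtype": "torch.float32",   # mapped back to a torch.dtype via getattr
    "arg_order": 0,
}
```
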
diff --git a/torch/csrc/inductor/aoti_eager/kernel_holder.cpp b/torch/csrc/inductor/aoti_eager/kernel_holder.cpp
index 0809c7d..41a6ea8 100644
--- a/torch/csrc/inductor/aoti_eager/kernel_holder.cpp
+++ b/torch/csrc/inductor/aoti_eager/kernel_holder.cpp
@@ -147,6 +147,9 @@
       }
     } else if (stack[idx].isTensor()) {
       inputs_metadata.push_back(ParameterMetadata(stack[idx].toTensor(), idx));
+    } else if (stack[idx].isString()) {
+      inputs_metadata.push_back(
+          ParameterMetadata(stack[idx].toStringRef(), idx));
     } else {
       TORCH_CHECK_NOT_IMPLEMENTED(
           false,
@@ -309,6 +312,7 @@
       uint64_t arg_idx = metadata["arg_order"].cast<uint64_t>();
       bool is_scalar = metadata.contains("scalar_value");
       bool is_tensor_list = metadata.contains("tensor_list");
+      bool is_string = metadata.contains("string_value");
 
       if (is_tensor_list) {
         // Tensor List
@@ -332,6 +336,12 @@
         auto scalar_value = metadata["scalar_value"].cast<double>();
         parameter_metadata_list.push_back(
             ParameterMetadata(c10::Scalar(scalar_value), arg_idx));
+      } else if (is_string) {
+        // String
+        auto metadata = item_metadata.cast<py::dict>();
+        auto str_value = metadata["string_value"].cast<std::string>();
+        parameter_metadata_list.push_back(
+            ParameterMetadata(str_value, arg_idx));
       } else {
         // Tensor
         auto metadata = item_metadata.cast<py::dict>();
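
The loader keys off the presence of `string_value` in each per-argument dict, mirroring the existing `tensor_list` and `scalar_value` markers. The dispatch order, rendered in Python for brevity (a sketch; the real code constructs `ParameterMetadata` objects rather than tags):

```python
def classify(item_metadata: dict) -> str:
    if "tensor_list" in item_metadata:
        return "TENSOR_LIST"
    if "scalar_value" in item_metadata:
        return "SCALAR"
    if "string_value" in item_metadata:  # new in this PR
        return "STRING"
    return "TENSOR"  # fallback: plain tensor metadata
```
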
diff --git a/torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp b/torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp
index 95cc29b..801ea59 100644
--- a/torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp
+++ b/torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp
@@ -151,6 +151,13 @@
   value_ = scalar;
 }
 
+ParameterMetadata::ParameterMetadata(
+    const std::string& str,
+    uint64_t input_order)
+    : tag_(STRING), order_(input_order) {
+  value_ = str;
+}
+
 bool ParameterMetadata::operator==(const ParameterMetadata& other) const {
   // Same type
   if (tag_ != other.tag_) {
@@ -174,6 +181,9 @@
           std::get<c10::Scalar>(other.value_).isFloatingPoint() ||
           std::get<c10::Scalar>(other.value_).isIntegral(true /*includeBool*/));
       return equal_to(std::get<c10::Scalar>(other.value_));
+    case STRING:
+      return std::get<std::string>(value_) ==
+          std::get<std::string>(other.value_);
     default:
       return false;
   }
diff --git a/torch/csrc/inductor/aoti_eager/kernel_meta_info.h b/torch/csrc/inductor/aoti_eager/kernel_meta_info.h
index d07814d..c5a858f 100644
--- a/torch/csrc/inductor/aoti_eager/kernel_meta_info.h
+++ b/torch/csrc/inductor/aoti_eager/kernel_meta_info.h
@@ -89,13 +89,17 @@
   TENSOR_LIST,
   TENSOR_LIST_OPTIONAL,
   SCALAR,
+  STRING,
   INVALID,
 };
 
 // ParameterMetadataValue is to represent the value of the input parameters of a
 // aten operation.
-using ParameterMetadataValue =
-    std::variant<TensorMetadata, std::vector<TensorMetadata>, c10::Scalar>;
+using ParameterMetadataValue = std::variant<
+    TensorMetadata,
+    std::vector<TensorMetadata>,
+    c10::Scalar,
+    std::string>;
 
 // ParameterMetadata is to represent the metadata of the input parameters of a
 // aten operation. It includes the tag of the parameter, the value of the
@@ -122,6 +126,7 @@
       const std::vector<TensorMetadata>& tensor_metadata_list,
       uint64_t input_order);
   ParameterMetadata(const c10::Scalar& scalar, uint64_t input_order);
+  ParameterMetadata(const std::string& string_value, uint64_t input_order);
 
   bool operator==(const ParameterMetadata& other) const;
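
To make the shape of the header change concrete: the variant simply grows a fourth alternative, and equality for the new tag is plain string comparison. A compact Python analogue (names mirror the header; this is a sketch, not the real implementation):

```python
from dataclasses import dataclass
from typing import Any

@dataclass
class PyParameterMetadata:
    # Mirrors ParameterMetadataTag: TENSOR, TENSOR_LIST, TENSOR_LIST_OPTIONAL,
    # SCALAR, STRING (new in this PR), INVALID.
    tag: str
    value: Any  # TensorMetadata | list[TensorMetadata] | Scalar | str
    order: int

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, PyParameterMetadata) or self.tag != other.tag:
            return False
        if self.tag == "STRING":
            # New case: string parameters match iff their values are equal.
            return self.value == other.value
        raise NotImplementedError("other tags elided from this sketch")
```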