[3/N] Non-Tensor: Support string parameter for aten operations (#125831)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125831
Approved by: https://github.com/jansel, https://github.com/jgong5
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index c7a1282..baa6024 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -804,7 +804,7 @@
@skipCUDAIf(not SM80OrLater, "Requires sm80")
@skip_if_halide # aoti
- def test_eager_aoti_support_out(self):
+ def test_aoti_eager_support_out(self):
ns = "aten"
op_name = "clamp"
dispatch_key = "CPU"
@@ -857,7 +857,45 @@
@skipCUDAIf(not SM80OrLater, "Requires sm80")
@skip_if_halide # aoti
- def test_eager_aoti_cache_hit(self):
+ def test_aoti_eager_support_str(self):
+ ns = "aten"
+ op_name = "div"
+ dispatch_key = "CPU"
+ device = "cpu"
+ if self.device.lower() == "cuda":
+ dispatch_key = "CUDA"
+ device = "cuda"
+
+ a = torch.randn(128, dtype=torch.float, device=device)
+ b = torch.randn(128, dtype=torch.float, device=device)
+ rounding_mode_list = ["trunc", "floor"]
+ with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl:
+ # Get the reference results from eager mode, one per rounding mode.
+ ref_value_list = []
+ for rounding_mode in rounding_mode_list:
+ ref_value = getattr(torch.ops.aten, op_name)(
+ a, b, rounding_mode=rounding_mode
+ )
+ ref_value_list.append(ref_value)
+
+ register_ops_with_aoti_compile(
+ ns, [op_name], dispatch_key, torch_compile_op_lib_impl
+ )
+
+ # Invoke the pre-compiled kernel once per rounding mode and collect the results.
+ res_value_list = []
+ for rounding_mode in rounding_mode_list:
+ res_value = getattr(torch.ops.aten, op_name)(
+ a, b, rounding_mode=rounding_mode
+ )
+ res_value_list.append(res_value)
+
+ for ref_value, res_value in zip(ref_value_list, res_value_list):
+ self.assertEqual(ref_value, res_value)
+
+ @skipCUDAIf(not SM80OrLater, "Requires sm80")
+ @skip_if_halide # aoti
+ def test_aoti_eager_cache_hit(self):
ns = "aten"
op_name = "abs"
dispatch_key = "CPU"
@@ -899,7 +937,7 @@
@skipCUDAIf(not SM80OrLater, "Requires sm80")
@skip_if_halide # aoti
- def test_eager_aoti_with_persistent_cache(self):
+ def test_aoti_eager_with_persistent_cache(self):
def fn(a):
return torch.abs(a)
@@ -944,7 +982,7 @@
@skipCUDAIf(not SM80OrLater, "Requires sm80")
@skip_if_halide # aoti
- def test_eager_aoti_with_scalar(self):
+ def test_aoti_eager_with_scalar(self):
namespace_name = "aten"
op_name = "add"
op_overload_name = "Tensor"
@@ -1015,7 +1053,7 @@
@skipCUDAIf(not SM80OrLater, "Requires sm80")
@skip_if_halide # aoti
- def test_eager_aoti_override_registration(self):
+ def test_aoti_eager_override_registration(self):
namespace_name = "aten"
dispatch_key = "CPU"
device = torch.device("cpu")
diff --git a/torch/_inductor/aoti_eager.py b/torch/_inductor/aoti_eager.py
index 5bd2027..83d4e98 100644
--- a/torch/_inductor/aoti_eager.py
+++ b/torch/_inductor/aoti_eager.py
@@ -1,7 +1,7 @@
import json
import os
from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple
from unittest import mock
import torch
@@ -47,12 +47,14 @@
return []
for metadata in item["meta_info"]:
- assert not metadata[
- "is_dynamic"
- ], "Only support static shape for now"
- if metadata["device_type"] == "cpu":
+ if metadata.get("is_dynamic"):
+ raise NotImplementedError("Only support static shape for now")
+ if "device_type" in metadata and metadata["device_type"] == "cpu":
metadata["device_index"] = -1
- metadata["dtype"] = getattr(torch, metadata["dtype"].split(".")[-1])
+ if "dtype" in metadata:
+ metadata["dtype"] = getattr(
+ torch, metadata["dtype"].split(".")[-1]
+ )
return json_data
@@ -63,8 +65,7 @@
def supported_scalar_types() -> Tuple[type, ...]:
type_to_torch_dtype = supported_builtin_dtype_torch_dtype()
- supported_scalar_types = tuple(type_to_torch_dtype.keys())
- return supported_scalar_types
+ return tuple(type_to_torch_dtype.keys())
def extract_tensor_metadata(dynamic: bool, input: torch.Tensor) -> Dict[str, Any]:
@@ -99,9 +100,7 @@
return metadata
-def extract_scalar_metadata(
- device_type: str, input: Union[int, float, bool]
-) -> Dict[str, Any]:
+def extract_scalar_metadata(device_type: str, input: Any) -> Dict[str, Any]:
assert isinstance(input, supported_scalar_types())
metadata: Dict[str, Any] = {}
metadata["is_dynamic"] = False
@@ -114,6 +113,13 @@
return metadata
+def extract_string_metadata(input: str) -> Dict[str, Any]:
+ assert isinstance(input, str)
+ metadata: Dict[str, Any] = {}
+ metadata["string_value"] = input
+ return metadata
+
+
def aoti_compile_with_persistent_cache(
ns: str,
op_func_name_with_overload: str,
@@ -132,11 +138,9 @@
Compile the given function with persistent cache for AOTI eager mode.
"""
assert not dynamic, "Only support static shape for now"
- type_to_torch_dtype = {int: torch.int32, float: torch.float, bool: torch.bool}
- supported_scalar_types = tuple(type_to_torch_dtype.keys())
flattened_inputs = list(args) + list(kwargs.values())
if not all(
- isinstance(input, (supported_scalar_types, torch.Tensor, list))
+ isinstance(input, (supported_scalar_types(), torch.Tensor, list, str))
for input in flattened_inputs
):
raise NotImplementedError(
@@ -185,8 +189,12 @@
elif isinstance(input, list):
assert all(isinstance(item, torch.Tensor) for item in input)
metadata = extract_tensor_list_metadata(dynamic, input)
- else:
+ elif isinstance(input, supported_scalar_types()):
metadata = extract_scalar_metadata(device_type, input)
+ elif isinstance(input, str):
+ metadata = extract_string_metadata(input)
+ else:
+ raise NotImplementedError(f"Unsupported input type: {type(input)}")
metadata["arg_order"] = idx
kernel_metadata_items.append(metadata)
diff --git a/torch/csrc/inductor/aoti_eager/kernel_holder.cpp b/torch/csrc/inductor/aoti_eager/kernel_holder.cpp
index 0809c7d..41a6ea8 100644
--- a/torch/csrc/inductor/aoti_eager/kernel_holder.cpp
+++ b/torch/csrc/inductor/aoti_eager/kernel_holder.cpp
@@ -147,6 +147,9 @@
}
} else if (stack[idx].isTensor()) {
inputs_metadata.push_back(ParameterMetadata(stack[idx].toTensor(), idx));
+ } else if (stack[idx].isString()) {
+ inputs_metadata.push_back(
+ ParameterMetadata(stack[idx].toStringRef(), idx));
} else {
TORCH_CHECK_NOT_IMPLEMENTED(
false,
@@ -309,6 +312,7 @@
uint64_t arg_idx = metadata["arg_order"].cast<uint64_t>();
bool is_scalar = metadata.contains("scalar_value");
bool is_tensor_list = metadata.contains("tensor_list");
+ bool is_string = metadata.contains("string_value");
if (is_tensor_list) {
// Tensor List
@@ -332,6 +336,12 @@
auto scalar_value = metadata["scalar_value"].cast<double>();
parameter_metadata_list.push_back(
ParameterMetadata(c10::Scalar(scalar_value), arg_idx));
+ } else if (is_string) {
+ // String
+ auto metadata = item_metadata.cast<py::dict>();
+ auto str_value = metadata["string_value"].cast<std::string>();
+ parameter_metadata_list.push_back(
+ ParameterMetadata(str_value, arg_idx));
} else {
// Tensor
auto metadata = item_metadata.cast<py::dict>();
diff --git a/torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp b/torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp
index 95cc29b..801ea59 100644
--- a/torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp
+++ b/torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp
@@ -151,6 +151,13 @@
value_ = scalar;
}
+ParameterMetadata::ParameterMetadata(
+ const std::string& str,
+ uint64_t input_order)
+ : tag_(STRING), order_(input_order) {
+ value_ = str;
+}
+
bool ParameterMetadata::operator==(const ParameterMetadata& other) const {
// Same type
if (tag_ != other.tag_) {
@@ -174,6 +181,9 @@
std::get<c10::Scalar>(other.value_).isFloatingPoint() ||
std::get<c10::Scalar>(other.value_).isIntegral(true /*includeBool*/));
return equal_to(std::get<c10::Scalar>(other.value_));
+ case STRING:
+ return std::get<std::string>(value_) ==
+ std::get<std::string>(other.value_);
default:
return false;
}
diff --git a/torch/csrc/inductor/aoti_eager/kernel_meta_info.h b/torch/csrc/inductor/aoti_eager/kernel_meta_info.h
index d07814d..c5a858f 100644
--- a/torch/csrc/inductor/aoti_eager/kernel_meta_info.h
+++ b/torch/csrc/inductor/aoti_eager/kernel_meta_info.h
@@ -89,13 +89,17 @@
TENSOR_LIST,
TENSOR_LIST_OPTIONAL,
SCALAR,
+ STRING,
INVALID,
};
// ParameterMetadataValue is to represent the value of the input parameters of a
// aten operation.
-using ParameterMetadataValue =
- std::variant<TensorMetadata, std::vector<TensorMetadata>, c10::Scalar>;
+using ParameterMetadataValue = std::variant<
+ TensorMetadata,
+ std::vector<TensorMetadata>,
+ c10::Scalar,
+ std::string>;
// ParameterMetadata is to represent the metadata of the input parameters of a
// aten operation. It includes the tag of the parameter, the value of the
@@ -122,6 +126,7 @@
const std::vector<TensorMetadata>& tensor_metadata_list,
uint64_t input_order);
ParameterMetadata(const c10::Scalar& scalar, uint64_t input_order);
+ ParameterMetadata(const std::string& string_value, uint64_t input_order);
bool operator==(const ParameterMetadata& other) const;