[custom_op] Support string default values in schema (#129179)
Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129179
Approved by: https://github.com/albanD
ghstack dependencies: #129177, #129178
diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py
index 56db888..3decba2 100644
--- a/test/test_custom_ops.py
+++ b/test/test_custom_ops.py
@@ -2411,13 +2411,14 @@
b: float = 3.14,
c: bool = True,
d: int = 3,
+ e: str = "foo",
) -> Tensor:
- defaults.extend([a, b, c, d])
+ defaults.extend([a, b, c, d, e])
return x.clone()
x = torch.randn(3)
f(x)
- self.assertEqual(defaults, [None, 3.14, True, 3])
+ self.assertEqual(defaults, [None, 3.14, True, 3, "foo"])
def test_mutated_error(self):
with self.assertRaisesRegex(
diff --git a/torch/_library/infer_schema.py b/torch/_library/infer_schema.py
index c4f7b8e..36b3998 100644
--- a/torch/_library/infer_schema.py
+++ b/torch/_library/infer_schema.py
@@ -74,15 +74,17 @@
if param.default is inspect.Parameter.empty:
params.append(f"{schema_type} {name}")
else:
- if param.default is not None and not isinstance(
- param.default, (int, float, bool)
- ):
+ default_repr = None
+ if param.default is None or isinstance(param.default, (int, float, bool)):
+ default_repr = str(param.default)
+ elif isinstance(param.default, str):
+ default_repr = f'"{param.default}"'
+ else:
error_fn(
- f"Parameter {name} has an unsupported default value (we only support "
- f"int, float, bool, None). Please file an issue on GitHub so we can "
- f"prioritize this."
+ f"Parameter {name} has an unsupported default value type {type(param.default)}. "
+ f"Please file an issue on GitHub so we can prioritize this."
)
- params.append(f"{schema_type} {name}={param.default}")
+ params.append(f"{schema_type} {name}={default_repr}")
mutates_args_not_seen = set(mutates_args) - seen_args
if len(mutates_args_not_seen) > 0:
error_fn(