[fx] Don't use __module__ to test if a function is bound from C++
The new `test_public_bindings.py` test means `__module__` will be set
correctly in future, even for functions bound from C++. Instead, just
test directly that the function is of the `BuiltinFunctionType` which
only passes for functions exported with the CPython API.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75896
Approved by: https://github.com/albanD
diff --git a/torch/fx/operator_schemas.py b/torch/fx/operator_schemas.py
index 131dea5..5c2d484 100644
--- a/torch/fx/operator_schemas.py
+++ b/torch/fx/operator_schemas.py
@@ -1,6 +1,7 @@
import torch
import inspect
import numbers
+import types
import typing
import enum
import warnings
@@ -256,7 +257,7 @@
if kwargs is None:
kwargs = {}
new_args_and_kwargs = None
- if target in boolean_dispatched or target.__module__ in ['torch.nn.functional', 'torch.functional']:
+ if not isinstance(target, types.BuiltinFunctionType):
target_for_analysis = target
if target in boolean_dispatched:
# HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index e88838b..b03daed 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -10545,7 +10545,6 @@
skips=(
# tests do not work with passing lambda for op
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
- DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
# At this time ROCm uses magma instead of rocSolver, and the test passes
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward', active_if=(not TEST_WITH_ROCM)),
@@ -11394,7 +11393,6 @@
check_inplace_batched_forward_grad=False,
sample_inputs_func=sample_inputs_as_strided,
skips=(
- DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
# AssertionError: False is not true : Tensors failed to compare as equal!
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples'),
# AssertionError: False is not true : Scalars failed to compare as equal!
@@ -15003,7 +15001,6 @@
backward_dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
supports_out=False,
skips=(
- DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
# JIT has issue when op is passed as lambda
# AssertionError: JIT Test does not execute any logic
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
@@ -15042,7 +15039,6 @@
skips=(
# Cannot resize variables that require grad
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes'),
- DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
DecorateInfo(unittest.skip('Allowed exemption'), 'TestCompositeCompliance', 'test_operator'),
),
sample_inputs_func=sample_inputs_resize_ops),
@@ -15117,8 +15113,6 @@
supports_fwgrad_bwgrad=True,
supports_gradgrad=True,
skips=(
- # lambda impl
- DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
),
sample_inputs_func=sample_inputs_zero_),