Make JIT Aliasing Test Less Brittle (#65493)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65493
Added a last resolve to use whatever ATen operator that has Tensor outputs in the graph as the operator node to check alias annotation.
Test Plan: python test/test_ops.py -k test_variant_consistency_jit
Reviewed By: mrshenli
Differential Revision: D31321221
Pulled By: alanwaketan
fbshipit-source-id: f4a5cbfd36bd0867d8c1bf9de9a65365ee7c35d6
diff --git a/torch/csrc/jit/passes/utils/check_alias_annotation.cpp b/torch/csrc/jit/passes/utils/check_alias_annotation.cpp
index 2777b99..28b70b0 100644
--- a/torch/csrc/jit/passes/utils/check_alias_annotation.cpp
+++ b/torch/csrc/jit/passes/utils/check_alias_annotation.cpp
@@ -148,7 +148,7 @@
const Graph& g,
const std::string& unqualifiedOpName) {
const auto opName = Symbol::fromQualString("aten::" + unqualifiedOpName);
- for (const auto node : g.nodes()) {
+ for (const auto* node : g.nodes()) {
if (node->kind() == opName) {
return node;
}
@@ -156,10 +156,29 @@
// Check for alias-ed operator names
const auto aliasOp = torch::jit::getOperatorAliasMap().find(opName);
- AT_ASSERT(aliasOp != torch::jit::getOperatorAliasMap().end());
- for (const auto node : g.nodes()) {
- if (node->kind() == aliasOp->second) {
- return node;
+ if (aliasOp != torch::jit::getOperatorAliasMap().end()) {
+ for (const auto* node : g.nodes()) {
+ if (node->kind() == aliasOp->second) {
+ return node;
+ }
+ }
+ }
+
+ // Ideally, there will be only one ATen operator that has tensor outputs in
+ // the graph. Let's use that as the last resolve to make checkAliasAnnotation
+ // more robust.
+ for (const auto* node : g.nodes()) {
+ if (!node->maybeOperator()) {
+ continue;
+ }
+ if (!node->getOperator().isC10Op()) {
+ continue;
+ }
+
+ for (const auto* output : node->outputs()) {
+ if (output->type()->kind() == TypeKind::TensorType) {
+ return node;
+ }
}
}
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index fee181c..3bb3cfe 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -6490,6 +6490,9 @@
supports_forward_ad=True,
skips=(
# JIT does not support variadic tensors.
+ # RuntimeError: input->type()->kind() == TypeKind::OptionalType
+ # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252,
+ # please report a bug to PyTorch.
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]),
),
sample_inputs_func=sample_inputs_broadcast_tensors),
@@ -6499,6 +6502,9 @@
supports_forward_ad=True,
skips=(
# JIT does not support variadic tensors.
+ # RuntimeError: input->type()->kind() == TypeKind::OptionalType
+ # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252,
+ # please report a bug to PyTorch.
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]),
),
sample_inputs_func=sample_inputs_block_diag),
@@ -6625,6 +6631,9 @@
torch.bfloat16, torch.half),
supports_forward_ad=True,
skips=(
+ # RuntimeError: inputSet && outputSet
+ # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":118,
+ # please report a bug to PyTorch.
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32, )),
)),
OpInfo('resolve_conj',
@@ -6713,8 +6722,16 @@
sample_inputs_func=sample_inputs_cov,
supports_out=False,
supports_forward_ad=True,
- # JIT test not working for tensor kwargs (https://github.com/pytorch/pytorch/issues/58507)
- skips=(DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),)),
+ skips=(
+ # JIT test not working for tensor kwargs (https://github.com/pytorch/pytorch/issues/58507)
+ # RuntimeError:
+ # undefined value tensor:
+ # File "<string>", line 3
+ # def the_method(i0):
+ # return torch.cov(i0, correction=0, fweights=None, aweights=tensor([0.0518, 0.4681], dtype=torch.float32, requires_grad=True)) # noqa: B950
+ # ~~~~~~ <--- HERE
+ DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+ )),
OpInfo('cross',
dtypes=all_types_and_complex(),
dtypesIfCPU=all_types_and_complex_and(torch.bfloat16),
@@ -6790,10 +6807,6 @@
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
sample_inputs_func=partial(sample_inputs_binary_pwise, rounding_mode="trunc"),
supports_forward_ad=True,
- skips=(
- # Reference: https://github.com/pytorch/pytorch/issues/59174
- DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
- ),
assert_autodiffed=True,
rhs_make_tensor_kwargs=dict(exclude_zero=True)),
BinaryUfuncInfo('div',
@@ -6802,10 +6815,6 @@
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
sample_inputs_func=partial(sample_inputs_binary_pwise, rounding_mode="floor"),
supports_forward_ad=True,
- skips=(
- # Reference: https://github.com/pytorch/pytorch/issues/59174
- DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
- ),
assert_autodiffed=True,
rhs_make_tensor_kwargs=dict(exclude_zero=True)),
BinaryUfuncInfo('true_divide',
@@ -6846,9 +6855,6 @@
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
supports_forward_ad=True,
sample_inputs_func=sample_inputs_expand_as,
- skips=(
- # Because expand_as does not have a function variant.
- DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),),
supports_out=False),
OpInfo('diag',
dtypes=all_types_and_complex_and(torch.bool),
@@ -7160,8 +7166,9 @@
skips=(
# following tests give a runtime error with undefined value tensor
# see discussion : https://github.com/pytorch/pytorch/issues/56660
- DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',
- dtypes=(torch.float32, torch.complex64)),
+ # RuntimeError:
+ # Arguments for call are not valid.
+ DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32, torch.complex64)), # noqa: B950
),
supports_inplace_autograd=False,
sample_inputs_func=sample_inputs_gradient),
@@ -7309,10 +7316,7 @@
supports_out=True,
sample_inputs_func=sample_inputs_linalg_lstsq,
supports_autograd=False,
- decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
- skips=(
- DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
- )),
+ decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack]),
OpInfo('linalg.matrix_power',
aliases=('matrix_power',),
aten_name='linalg_matrix_power',
@@ -7490,6 +7494,12 @@
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack],
skips=(
# we skip jit tests because `lu` is a torch function
+ # RuntimeError:
+ # 'Tensor (inferred)' object has no attribute or method 'lu'.:
+ # File "<string>", line 3
+ # def the_method(i0):
+ # return i0.lu(True, True)
+ # ~~~~~ <--- HERE
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
)),
OpInfo('lu_solve',
@@ -7622,6 +7632,9 @@
sample_inputs_func=partial(sample_inputs_meshgrid, variant='variadic'),
skips=[
# JIT does not support variadic tensors.
+ # RuntimeError: input->type()->kind() == TypeKind::OptionalType
+ # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252,
+ # please report a bug to PyTorch.
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
# meshgrid is defined in torch.functional to take a
# variadic list of tensors. Variadic parameters are not
@@ -7742,13 +7755,7 @@
OpInfo('nn.functional.normalize',
dtypesIfCPU=floating_and_complex_types_and(torch.bfloat16),
dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
- sample_inputs_func=sample_inputs_normalize,
- skips=(
- # RuntimeError: aliasOp != torch::jit::getOperatorAliasMap().end()
- # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":159,
- # please report a bug to PyTorch.
- DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',),
- )),
+ sample_inputs_func=sample_inputs_normalize),
OpInfo('aminmax',
ref=lambda x, dim=None, keepdim=False: (np.amin(x, axis=dim, keepdims=keepdim), np.amax(x, axis=dim, keepdims=keepdim)),
dtypes=all_types_and(torch.bool),
@@ -7771,6 +7778,14 @@
dtypes=floating_types(),
dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
skips=(
+ # RuntimeError:
+ # adaptive_avg_pool2d(Tensor input, int[2] output_size) -> (Tensor):
+ # Expected a value of type 'List[int]' for argument 'output_size' but instead found type 'Tuple[NoneType, int]'.
+ # :
+ # File "<string>", line 3
+ # def the_method(i0):
+ # return torch.nn.functional.adaptive_avg_pool2d(i0, (None, 7))
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
),
supports_out=False,
@@ -7840,10 +7855,9 @@
dtypesIfCUDA=floating_and_complex_types_and(torch.half),
sample_inputs_func=partial(sample_inputs_nn_pad, mode='reflect'),
skips=(
- # op name not found in JIT graph
- # There are multiple aten ops, namely reflection_pad_{1,2,3}d
- # so we can't use aten_name argument in opinfo
- # RuntimeError: aliasOp != torch::jit::getOperatorAliasMap().end()
+ # Doesn't have a corresponding aten operator.
+ # RuntimeError: falseINTERNAL ASSERT FAILED at
+ # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, please report a bug to PyTorch.
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
),
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
@@ -7854,10 +7868,9 @@
dtypesIfCUDA=floating_and_complex_types_and(torch.half),
sample_inputs_func=partial(sample_inputs_nn_pad, mode='replicate'),
skips=(
- # op name not found in JIT graph
- # There are multiple aten ops, namely replication_pad_{1,2,3}d
- # so we can't use aten_name argument in opinfo
- # RuntimeError: aliasOp != torch::jit::getOperatorAliasMap().end()
+ # Doesn't have a corresponding aten operator.
+ # RuntimeError: falseINTERNAL ASSERT FAILED at
+ # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, please report a bug to PyTorch.
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
),
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
@@ -7870,7 +7883,8 @@
check_batched_grad=False,
skips=(
# Doesn't have a corresponding aten operator.
- # RuntimeError: aliasOp != torch::jit::getOperatorAliasMap().end()
+ # RuntimeError: falseINTERNAL ASSERT FAILED at
+ # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, please report a bug to PyTorch.
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
),
supports_out=False),
@@ -7890,7 +7904,9 @@
dtypesIfCPU=floating_types_and(torch.half, torch.bfloat16),
sample_inputs_func=sample_inputs_nn_unfold,
skips=(
- # JIT alias info internal asserts here
+ # RuntimeError: false
+ # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185,
+ # please report a bug to PyTorch.
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
),
supports_out=False),
@@ -7902,7 +7918,9 @@
dtypesIfCUDA=floating_types_and(torch.half, torch.uint8),
sample_inputs_func=partial(sample_inputs_interpolate, 'nearest'),
skips=(
- # JIT alias info internal asserts here
+ # RuntimeError: false
+ # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185,
+ # please report a bug to PyTorch.
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
),
supports_out=False),
@@ -7913,7 +7931,9 @@
dtypesIfCUDA=floating_types_and(torch.half),
sample_inputs_func=partial(sample_inputs_interpolate, 'linear'),
skips=(
- # JIT alias info internal asserts here
+ # RuntimeError: false
+ # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185,
+ # please report a bug to PyTorch.
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
),
supports_out=False),
@@ -7925,7 +7945,9 @@
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
sample_inputs_func=partial(sample_inputs_interpolate, 'bilinear'),
skips=(
- # JIT alias info internal asserts here
+ # RuntimeError: false
+ # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185,
+ # please report a bug to PyTorch.
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
),
supports_out=False),
@@ -7937,7 +7959,9 @@
sample_inputs_func=partial(sample_inputs_interpolate, 'bicubic'),
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
skips=(
- # JIT alias info internal asserts here
+ # RuntimeError: false
+ # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185,
+ # please report a bug to PyTorch.
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
),
supports_out=False),
@@ -7949,7 +7973,9 @@
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
sample_inputs_func=partial(sample_inputs_interpolate, 'trilinear'),
skips=(
- # JIT alias info internal asserts here
+ # RuntimeError: false
+ # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185,
+ # please report a bug to PyTorch.
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
),
supports_out=False),
@@ -7961,7 +7987,9 @@
sample_inputs_func=partial(sample_inputs_interpolate, 'area'),
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
skips=(
- # JIT alias info internal asserts here
+ # RuntimeError: false
+ # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185,
+ # please report a bug to PyTorch.
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
),
supports_out=False),
@@ -8094,6 +8122,7 @@
supports_inplace_autograd=False,
skips=(
# test does not work with passing lambda for op
+ # AssertionError: False is not true : Tensors failed to compare as equal!
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
# test fails are we permute the arguments function variant
# but not for inplace or method.
@@ -8121,6 +8150,7 @@
supports_inplace_autograd=False,
skips=(
# test does not work with passing lambda for op
+ # AssertionError: False is not true : Tensors failed to compare as equal!
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
# test fails are we permute the arguments function variant
# but not for inplace or method.
@@ -8414,7 +8444,15 @@
dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool),
sample_inputs_func=sample_inputs_rbinops,
supports_out=False,
- skips=(DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',),),
+ skips=(
+ # RuntimeError:
+ # object has no attribute __radd__:
+ # File "<string>", line 3
+ # def the_method(i0):
+ # return torch.__radd__(i0, 3.14j)
+ # ~~~~~~~~~~~~~~ <--- HERE
+ DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',),
+ ),
assert_autodiffed=True,
supports_forward_ad=True,
autodiff_nonfusible_nodes=['aten::add'],),
@@ -8423,7 +8461,15 @@
dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool),
sample_inputs_func=sample_inputs_rbinops,
supports_out=False,
- skips=(DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',),),
+ skips=(
+ # RuntimeError:
+ # object has no attribute __rdiv__:
+ # File "<string>", line 3
+ # def the_method(i0):
+ # return torch.__rdiv__(i0, 3.14j)
+ # ~~~~~~~~~~~~~~ <--- HERE
+ DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',),
+ ),
supports_forward_ad=True,
assert_autodiffed=True,
autodiff_nonfusible_nodes=['aten::mul', 'aten::reciprocal'],),
@@ -8432,7 +8478,15 @@
dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool),
sample_inputs_func=sample_inputs_rbinops,
supports_out=False,
- skips=(DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',),),
+ skips=(
+ # RuntimeError:
+ # object has no attribute __rmul__:
+ # File "<string>", line 3
+ # def the_method(i0):
+ # return torch.__rmul__(i0, 3.14j)
+ # ~~~~~~~~~~~~~~ <--- HERE
+ DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',),
+ ),
assert_autodiffed=True,
supports_forward_ad=True,
autodiff_nonfusible_nodes=['aten::mul'],),
@@ -8441,7 +8495,6 @@
dtypes=integral_types_and(torch.bool),
sample_inputs_func=sample_inputs_rbinops,
supports_out=False,
- skips=(DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_jit',),),
supports_autograd=False,
supports_forward_ad=True,),
OpInfo('__ror__',
@@ -8449,7 +8502,6 @@
dtypes=integral_types_and(torch.bool),
sample_inputs_func=sample_inputs_rbinops,
supports_out=False,
- skips=(DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_jit',),),
supports_autograd=False,
supports_forward_ad=True,),
OpInfo('__rxor__',
@@ -8457,7 +8509,6 @@
dtypes=integral_types_and(torch.bool),
sample_inputs_func=sample_inputs_rbinops,
supports_out=False,
- skips=(DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_jit',),),
supports_autograd=False,
supports_forward_ad=True,),
OpInfo('__rmatmul__',
@@ -8477,6 +8528,12 @@
toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}),
'TestMathBits', 'test_conj_view')],
skips=(
+ # RuntimeError:
+ # object has no attribute __rmatmul__:
+ # File "<string>", line 3
+ # def the_method(i0, i1):
+ # return torch.__rmatmul__(i0, i1)
+ # ~~~~~~~~~~~~~~ <--- HERE
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',),
)),
OpInfo('__rmod__',
@@ -8486,7 +8543,15 @@
dtypesIfCUDA=all_types_and(torch.bfloat16, torch.half, torch.bool),
sample_inputs_func=sample_inputs_rbinops,
supports_out=False,
- skips=(DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',),),
+ skips=(
+ # RuntimeError:
+ # object has no attribute __rmod__:
+ # File "<string>", line 3
+ # def the_method(i0):
+ # return torch.__rmod__(i0, 3.14)
+ # ~~~~~~~~~~~~~~ <--- HERE
+ DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',),
+ ),
# Support autograd after torch.remainder(Tensor, Tensor) supports
# autograd of the second argument.
# https://github.com/pytorch/pytorch/pull/58476/files#r637167630
@@ -8503,7 +8568,14 @@
supports_out=False,
supports_forward_ad=True,
skips=(
- DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',),),
+ # RuntimeError:
+ # object has no attribute __rpow__:
+ # File "<string>", line 3
+ # def the_method(i0):
+ # return torch.__rpow__(i0, 3.14j)
+ # ~~~~~~~~~~~~~~ <--- HERE
+ DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',),
+ ),
assert_autodiffed=True,
autodiff_nonfusible_nodes=['aten::pow'],),
OpInfo('__rsub__',
@@ -8511,7 +8583,15 @@
dtypes=all_types_and_complex_and(torch.bfloat16, torch.half),
sample_inputs_func=sample_inputs_rbinops,
supports_out=False,
- skips=(DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',),),
+ skips=(
+ # RuntimeError:
+ # object has no attribute __rsub__:
+ # File "<string>", line 3
+ # def the_method(i0):
+ # return torch.__rsub__(i0, 3.14j)
+ # ~~~~~~~~~~~~~~ <--- HERE
+ DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',),
+ ),
assert_autodiffed=True,
autodiff_nonfusible_nodes=['aten::rsub'],),
OpInfo('rsub',
@@ -8522,8 +8602,10 @@
skips=(
# Reference: https://github.com/pytorch/pytorch/issues/53797
# JIT doesn't understand complex literals
- DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',
- dtypes=[torch.cfloat, torch.cdouble]),
+ # RuntimeError: false
+ # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":52,
+ # please report a bug to PyTorch.
+ DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.cfloat, torch.cdouble]), # noqa: B950
),
sample_inputs_func=partial(sample_inputs_rsub, variant='tensor'),),
OpInfo('rsub',
@@ -8535,8 +8617,11 @@
skips=(
# Reference: https://github.com/pytorch/pytorch/issues/53797
# JIT doesn't understand complex literals
- DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',
- dtypes=all_types_and_complex_and(torch.bfloat16, torch.half)),),
+ # RuntimeError: false
+ # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":52,
+ # please report a bug to PyTorch.
+ DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.cfloat, torch.cdouble]), # noqa: B950
+ ),
assert_autodiffed=True,),
OpInfo('select',
dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool),
@@ -8847,6 +8932,7 @@
skips=(
# test does not work with passing lambda for op
# there's a test `test_einsum` in `test_jit.py` to handle this case
+ # AssertionError: JIT Test does not execute any logic
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
)),
OpInfo('svd',
@@ -8895,21 +8981,9 @@
supports_forward_ad=True,
sample_inputs_func=sample_inputs_polygamma,
skips=(
- # Probably related to the way the function is
- # scripted for JIT tests (or maybe not).
- # RuntimeError:
- # Arguments for call are not valid.
- # The following variants are available:
- # aten::polygamma(int n, Tensor self) -> (Tensor):
- # Expected a value of type 'Tensor' for argument 'self' but instead found type 'int'.
- # aten::polygamma.out(int n, Tensor self, *, Tensor(a!) out) -> (Tensor(a!)):
- # Expected a value of type 'Tensor' for argument 'self' but instead found type 'int'.
- # The original call is:
- # File "<string>", line 3
- # def the_method(i0):
- # return torch.polygamma(i0, 1)
- # ~~~~~~~~~~~~~~~ <--- HERE
- DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),),
+ # AssertionError: JIT Test does not execute any logic
+ DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+ ),
sample_kwargs=lambda device, dtype, input: ({'n': 0}, {'n': 0})),
# A separate OpInfo entry for special.polygamma is needed to reorder the arguments
# for the alias. See the discussion here: https://github.com/pytorch/pytorch/pull/59691#discussion_r650261939
@@ -8923,21 +8997,9 @@
supports_forward_ad=True,
sample_inputs_func=sample_inputs_polygamma,
skips=(
- # Probably related to the way the function is
- # scripted for JIT tests (or maybe not).
- # RuntimeError:
- # Arguments for call are not valid.
- # The following variants are available:
- # aten::polygamma(int n, Tensor self) -> (Tensor):
- # Expected a value of type 'Tensor' for argument 'self' but instead found type 'int'.
- # aten::polygamma.out(int n, Tensor self, *, Tensor(a!) out) -> (Tensor(a!)):
- # Expected a value of type 'Tensor' for argument 'self' but instead found type 'int'.
- # The original call is:
- # File "<string>", line 3
- # def the_method(i0):
- # return torch.polygamma(i0, 1)
- # ~~~~~~~~~~~~~~~ <--- HERE
- DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),),
+ # AssertionError: JIT Test does not execute any logic
+ DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+ ),
sample_kwargs=lambda device, dtype, input: ({'n': 0}, {'n': 0})),
UnaryUfuncInfo('polygamma',
op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs),
@@ -9059,9 +9121,6 @@
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
- skips=(
- # Because view_as does not have a function variant.
- DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),),
sample_inputs_func=sample_inputs_view_as_reshape_as,
),
OpInfo('pinverse',
@@ -9111,7 +9170,10 @@
supports_inplace_autograd=False,
supports_scripting=False,
op=torch.Tensor.__getitem__,
- skips=(DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', device_type='cuda'),),
+ skips=(
+ # AssertionError: False is not true : Scalars failed to compare as equal! 0 != 104448
+ DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', device_type='cuda'),
+ ),
assert_jit_shape_analysis=False, # TODO: support index.Tensor()
sample_inputs_func=sample_inputs_getitem,),
OpInfo('index_put',
@@ -9122,6 +9184,13 @@
test_neg_view=False,
sample_inputs_func=sample_inputs_index_put,
skips=(
+ # RuntimeError: The following operation failed in the TorchScript interpreter.
+ # Traceback of TorchScript (most recent call last):
+ # File "<string>", line 3, in forward
+ # def the_method(i0, i1: List[torch.Tensor], i2):
+ # return torch.index_put(i0, i1, i2, accumulate=False)
+ # ~~~~~~~~~~~~~~~ <--- HERE
+ # RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
)),
OpInfo('sort',
@@ -9181,7 +9250,14 @@
skips=(
# JIT tests don't work with Tensor keyword arguments
# https://github.com/pytorch/pytorch/issues/58507
- DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),),),
+ # RuntimeError:
+ # undefined value tensor:
+ # File "<string>", line 3
+ # def the_method(i0):
+ # return torch.histogram(i0, 1, weight=tensor(-0.5735, dtype=torch.float32), density=False)
+ # ~~~~~~ <--- HERE
+ DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+ )),
OpInfo('cat',
ref=lambda input_seq, dim=0, **kwargs: np.concatenate(input_seq, axis=dim, **kwargs),
aliases=('concat',),
@@ -9222,9 +9298,6 @@
supports_forward_ad=True,
check_batched_gradgrad=False,
skips=(
- # torch.unfold does not exist so we get a RuntimeError.
- DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',
- dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16)),
# Skip operator schema test because this is a functional and not an operator
DecorateInfo(unittest.skip("Skipped!"), 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
),
@@ -9255,11 +9328,6 @@
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
- skips=(
- # torch.repeat does not exist so we get a RuntimeError.
- DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',
- dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16)),
- ),
sample_inputs_func=sample_repeat_tile),
OpInfo('squeeze',
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
@@ -9277,6 +9345,7 @@
supports_out=False,
skips=(
# JIT has issue when op is passed as lambda
+ # AssertionError: JIT Test does not execute any logic
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
),
sample_inputs_func=sample_inputs_fill_),
@@ -9347,6 +9416,7 @@
supports_forward_ad=True,
skips=(
# JIT has issue when op is passed as lambda
+ # AssertionError: JIT Test does not execute any logic
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
),
sample_inputs_func=sample_inputs_zero_),
@@ -9373,6 +9443,7 @@
safe_casts_outputs=True,
skips=(
# Lambda doesn't work in JIT test
+ # AssertionError: JIT Test does not execute any logic
DecorateInfo(unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"),
),
sample_inputs_func=sample_inputs_zeta),
@@ -9430,9 +9501,6 @@
supports_forward_ad=True,
sample_inputs_func=sample_inputs_tensordot,
skips=(
- # Currently failing due to an INTERNAL_ASSERT_FAILED error.
- # Reference: https://github.com/pytorch/pytorch/issues/56314
- DecorateInfo(unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit", dtypes=[torch.float32]),
# Skip operator schema test because this is a functional and not an operator.
# Reference: https://github.com/pytorch/pytorch/issues/54574
DecorateInfo(unittest.skip("Skipped!"), 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
@@ -9451,6 +9519,7 @@
# TODO: FIXME: complex inputs requiring grad error in forward
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes'),
# JIT has issue when op is passed as lambda
+ # NotImplementedError: Cannot access storage of SparseTensorImpl
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
)
),
@@ -9623,6 +9692,8 @@
supports_out=False,
skips=(
# test does not work with passing lambda for op
+ # AssertionError: False is not true :
+ # Failure in testing nodes' autodifferentiation.
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
),
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16)),
@@ -9651,11 +9722,7 @@
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'),
# RuntimeError:
# Arguments for call are not valid.
- DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.complex64,)),
- # RuntimeError: aliasOp != torch::jit::getOperatorAliasMap().end()
- # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":157,
- # please report a bug to PyTorch.
- DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
+ DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.complex64, torch.float32,)), # noqa: B950
)
),
OpInfo('norm',
@@ -9669,11 +9736,7 @@
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'),
# RuntimeError:
# Arguments for call are not valid.
- DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.complex64,)),
- # RuntimeError: aliasOp != torch::jit::getOperatorAliasMap().end()
- # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":157,
- # please report a bug to PyTorch.
- DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
+ DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.complex64, torch.float32,)), # noqa: B950
)
),
OpInfo('norm',
@@ -9682,13 +9745,9 @@
dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
backward_dtypesIfCPU=floating_and_complex_types_and(torch.float16, torch.bfloat16),
skips=(
- # following 3 tests failed intermittenly
- DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',
- device_type='cpu', dtypes=(torch.complex64,)),
- DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_grad',
- device_type='cpu', dtypes=(torch.complex128,)),
- DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_gradgrad',
- device_type='cpu', dtypes=(torch.complex128,)),
+ # following 2 tests failed intermittenly
+ DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_grad', device_type='cpu', dtypes=(torch.complex128,)), # noqa: B950
+ DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_gradgrad', device_type='cpu', dtypes=(torch.complex128,)), # noqa: B950
)
),
OpInfo('t',
@@ -9741,24 +9800,11 @@
dtypesIfCPU=floating_types(),
dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16),
supports_out=False,
- skips=(
- DecorateInfo(unittest.skip("Skipped!"),
- "TestJit",
- "test_variant_consistency_jit",
- dtypes=(torch.float32,),
- ),
- ),
),
OpInfo(
"linalg.tensorinv",
ref=np.linalg.tensorinv,
dtypes=floating_and_complex_types(),
- skips=(
- # RuntimeError: aliasOp != torch::jit::getOperatorAliasMap().end()
- # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":159,
- # please report a bug to PyTorch.
- DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
- ),
sample_inputs_func=sample_inputs_tensorinv,
supports_forward_ad=True,
),
@@ -9771,11 +9817,10 @@
backward_dtypesIfCPU=floating_types(),
dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16),
skips=(
- DecorateInfo(unittest.skip("Skipped!"),
- "TestJit",
- "test_variant_consistency_jit",
- dtypes=(torch.float32,),
- ),
+ # RuntimeError: input->type()->kind() == TypeKind::OptionalType
+ # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252,
+ # please report a bug to PyTorch.
+ DecorateInfo(unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit", dtypes=(torch.float32,),),
),
),
OpInfo(
@@ -9787,13 +9832,6 @@
sample_inputs_func=sample_inputs_grid_sample,
supports_gradgrad=False,
gradcheck_nondet_tol=1e-15,
- skips=(
- DecorateInfo(unittest.skip("Skipped!"),
- "TestJit",
- "test_variant_consistency_jit",
- dtypes=(torch.float32,),
- ),
- ),
),
ReductionOpInfo(
'all',
@@ -9939,8 +9977,8 @@
sample_inputs_func=sample_inputs_nan_reduction(supports_multiple_dims=True),
ref=reference_reduction_numpy(np.nanmean),
skips=(
- # RuntimeError: deepEquals(input.iValue, deepCopiedInput)INTERNAL ASSERT FAILED at
- # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":142, please report a bug to PyTorch.
+ # AssertionError: False is not true :
+ # Failure in testing nodes' autodifferentiation.
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
# FIXME: prod reduces all dimensions when dim=[]
DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
@@ -10106,10 +10144,13 @@
supports_out=False,
sample_inputs_func=sample_inputs_nll_loss,
skips=(
- DecorateInfo(unittest.skip("Skipped!"),
- "TestJit",
- "test_variant_consistency_jit",
- dtypes=(torch.float32,),),
+ # RuntimeError:
+ # undefined value tensor:
+ # File "<string>", line 3
+ # def the_method(i0, i1):
+ # return torch.nn.functional.nll_loss(i0, i1, weight=tensor([8.4784, 1.7658, 4.3228], dtype=torch.float32))
+ # ~~~~~~ <--- HERE
+ DecorateInfo(unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit", dtypes=(torch.float32,),),
),
),
OpInfo(