Add OpInfo based meta tensor tests [RELAND]

PR #75994 was taking too long to ship, so I extracted the CrossRef gadget
and had it run only on a simple OpInfo invocation.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77008

Approved by: https://github.com/ngimel
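
For context, the cross-referencing idea can be sketched in a few lines (a
simplified illustration only, not the code in this PR; the real
MetaCrossRefMode below also handles views, conjugation, autograd state, and
tensor subclasses, and this sketch assumes the op has a meta kernel, as
torch.add does):

    import torch

    def crossref(func, *args):
        r = func(*args)  # run the operator for real
        # rerun it with meta copies of the tensor inputs
        meta_args = [a.to("meta") if isinstance(a, torch.Tensor) else a for a in args]
        meta_r = func(*meta_args)
        # the meta result should agree with the real one on all metadata
        assert meta_r.shape == r.shape
        assert meta_r.dtype == r.dtype
        assert meta_r.stride() == r.stride()
        return r

    crossref(torch.add, torch.ones(2, 3), torch.ones(2, 3))
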
diff --git a/test/test_meta.py b/test/test_meta.py
new file mode 100644
index 0000000..971aa08
--- /dev/null
+++ b/test/test_meta.py
@@ -0,0 +1,886 @@
+# Owner(s): ["module: primTorch"]
+
+import torch
+from torch.utils._pytree import tree_map, tree_flatten
+from torch.testing._internal.common_utils import (
+    TestCase,
+    skipIfCrossRef,
+    suppress_warnings,
+    TEST_WITH_ASAN,
+    run_tests,
+)
+from torch.overrides import push_torch_function_mode
+from torch.testing._internal.common_device_type import (
+    onlyNativeDeviceTypes,
+    ops,
+    instantiate_device_type_tests,
+)
+from torch.testing._internal.common_methods_invocations import op_db
+
+import functools
+import re
+from functools import partial
+import unittest
+import warnings
+
+RE_NOT_IMPLEMENTED_MSG = re.compile(r"Could not run '([^']+)' with arguments ")
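+# e.g., for a NotImplementedError whose message starts with
+# "Could not run 'aten::histc' with arguments ...", group(1) is "aten::histc";
+# used by the (commented-out) pedestrian-omission filter in MetaCrossRefMode.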
+
+# These ops just need a meta tensor implementation; once you implement
+# it, remove the op from this set.  When doing comprehensive testing,
+# we will verify that these raise errors when run on meta tensors
+# under OpInfo
+meta_exclude_set = {
+    torch.Tensor.__lshift__,  # MISSING aten::__lshift__.Scalar
+    torch.Tensor.__lshift__,  # MISSING aten::__lshift__.Tensor
+    torch.Tensor.__reversed__,  # MISSING aten::flip
+    torch.Tensor.__rmatmul__,  # MISSING aten::dot
+    torch.Tensor.__rshift__,  # MISSING aten::__rshift__.Scalar
+    torch.Tensor.__rshift__,  # MISSING aten::__rshift__.Tensor
+    torch.Tensor.abs,  # MISSING aten::abs.out
+    torch.Tensor.abs_,  # MISSING aten::abs.out
+    torch.Tensor.absolute,  # MISSING aten::abs.out
+    torch.Tensor.absolute_,  # MISSING aten::abs.out
+    torch.Tensor.addbmm,  # MISSING aten::addbmm
+    torch.Tensor.addcmul,  # MISSING aten::_local_scalar_dense
+    torch.Tensor.angle,  # MISSING aten::angle
+    torch.Tensor.argsort,  # MISSING aten::sort
+    torch.Tensor.bincount,  # MISSING aten::bincount
+    torch.Tensor.cholesky,  # MISSING aten::cholesky
+    torch.Tensor.cholesky_inverse,  # MISSING aten::cholesky_inverse
+    torch.Tensor.cholesky_solve,  # MISSING aten::_cholesky_solve_helper
+    torch.Tensor.clamp,  # MISSING aten::clamp.Tensor
+    torch.Tensor.clamp_,  # MISSING aten::clamp.Tensor_out
+    torch.Tensor.clip,  # MISSING aten::clamp.Tensor
+    torch.Tensor.clip_,  # MISSING aten::clamp.Tensor_out
+    torch.Tensor.conj_physical,  # MISSING aten::conj_physical.out
+    torch.Tensor.corrcoef,  # MISSING aten::_local_scalar_dense
+    torch.Tensor.count_nonzero,  # MISSING aten::count_nonzero.dim_IntList
+    torch.Tensor.cov,  # MISSING aten::_local_scalar_dense
+    torch.Tensor.cummax,  # MISSING aten::_cummax_helper
+    torch.Tensor.cummin,  # MISSING aten::_cummin_helper
+    torch.Tensor.cumprod_,  # MISSING aten::logical_and.out
+    torch.Tensor.dequantize,  # MISSING aten::dequantize.self
+    torch.Tensor.det,  # MISSING aten::_det_lu_based_helper
+    torch.Tensor.diag,  # MISSING aten::diag.out
+    torch.Tensor.diagflat,  # MISSING aten::diag.out
+    torch.Tensor.dot,  # MISSING aten::dot
+    torch.Tensor.eig,  # MISSING aten::abs.out
+    torch.Tensor.equal,  # MISSING aten::equal
+    torch.Tensor.flip,  # MISSING aten::flip
+    torch.Tensor.fliplr,  # MISSING aten::flip
+    torch.Tensor.flipud,  # MISSING aten::flip
+    torch.Tensor.floor_divide,  # MISSING aten::floor_divide
+    torch.Tensor.frexp,  # MISSING aten::frexp.Tensor_out
+    torch.Tensor.geqrf,  # MISSING aten::geqrf
+    torch.Tensor.histc,  # MISSING aten::histc
+    torch.Tensor.histogram,  # MISSING aten::histogram.bin_ct
+    torch.Tensor.index_select,  # MISSING aten::index_select
+    torch.Tensor.inverse,  # MISSING aten::_local_scalar_dense
+    torch.Tensor.is_set_to,  # MISSING aten::is_set_to
+    torch.Tensor.isclose,  # MISSING aten::abs.out
+    torch.Tensor.isnan,  # MISSING aten::isnan
+    torch.Tensor.istft,  # MISSING aten::view_as_complex
+    torch.Tensor.kthvalue,  # MISSING aten::kthvalue.values
+    torch.Tensor.logcumsumexp,  # MISSING aten::_logcumsumexp
+    torch.Tensor.logdet,  # MISSING aten::abs.out
+    torch.Tensor.logical_and,  # MISSING aten::logical_and.out
+    torch.Tensor.logical_and_,  # MISSING aten::logical_and.out
+    torch.Tensor.logical_not,  # MISSING aten::logical_not.out
+    torch.Tensor.logical_or,  # MISSING aten::logical_or.out
+    torch.Tensor.logical_or_,  # MISSING aten::logical_or.out
+    torch.Tensor.logical_xor,  # MISSING aten::logical_xor.out
+    torch.Tensor.logical_xor_,  # MISSING aten::logical_xor.out
+    torch.Tensor.logit,  # MISSING aten::logit
+    torch.Tensor.logsumexp,  # MISSING aten::abs.out
+    torch.Tensor.lstsq,  # MISSING aten::lstsq
+    torch.Tensor.masked_select,  # MISSING aten::masked_select
+    torch.Tensor.matmul,  # MISSING aten::dot
+    torch.Tensor.matrix_exp,  # MISSING aten::linalg_matrix_exp
+    torch.Tensor.matrix_power,  # MISSING aten::eye.m_out
+    torch.Tensor.max,  # MISSING aten::max
+    torch.Tensor.median,  # MISSING aten::median
+    torch.Tensor.median,  # MISSING aten::median.dim_values
+    torch.Tensor.min,  # MISSING aten::min
+    torch.Tensor.mode,  # MISSING aten::mode
+    torch.Tensor.msort,  # MISSING aten::sort
+    torch.Tensor.multinomial,  # MISSING aten::multinomial
+    torch.Tensor.mvlgamma,  # MISSING aten::_local_scalar_dense
+    torch.Tensor.mvlgamma_,  # MISSING aten::_local_scalar_dense
+    torch.Tensor.nan_to_num,  # MISSING aten::nan_to_num.out
+    torch.Tensor.nan_to_num_,  # MISSING aten::nan_to_num.out
+    torch.Tensor.nanmean,  # MISSING aten::logical_not.out
+    torch.Tensor.nanmedian,  # MISSING aten::nanmedian
+    torch.Tensor.nanmedian,  # MISSING aten::nanmedian.dim_values
+    torch.Tensor.nanquantile,  # MISSING aten::sort
+    torch.Tensor.nansum,  # MISSING aten::nansum
+    torch.Tensor.narrow,  # MISSING aten::_local_scalar_dense
+    torch.Tensor.nonzero,  # MISSING aten::nonzero
+    torch.Tensor.orgqr,  # MISSING aten::linalg_householder_product
+    torch.Tensor.ormqr,  # MISSING aten::ormqr
+    torch.Tensor.pinverse,  # MISSING aten::where.self
+    torch.Tensor.prod,  # MISSING aten::prod
+    torch.Tensor.qr,  # MISSING aten::_linalg_qr_helper
+    torch.Tensor.quantile,  # MISSING aten::sort
+    torch.Tensor.relu,  # MISSING aten::relu
+    torch.Tensor.renorm_,  # MISSING aten::_local_scalar_dense
+    torch.Tensor.repeat_interleave,  # MISSING aten::repeat_interleave.Tensor
+    torch.Tensor.roll,  # MISSING aten::roll
+    torch.Tensor.rot90,  # MISSING aten::flip
+    torch.Tensor.slogdet,  # MISSING aten::linalg_slogdet
+    torch.Tensor.solve,  # MISSING aten::_solve_helper
+    torch.Tensor.sort,  # MISSING aten::sort
+    torch.Tensor.std,  # MISSING aten::std.correction
+    torch.Tensor.stft,  # MISSING aten::_fft_r2c
+    torch.Tensor.symeig,  # MISSING aten::_symeig_helper
+    torch.Tensor.take,  # MISSING aten::take
+    torch.Tensor.to_mkldnn,  # MISSING aten::to_mkldnn
+    torch.Tensor.to_sparse,  # MISSING aten::to_sparse
+    torch.Tensor.to_sparse_csr,  # MISSING aten::to_sparse_csr
+    torch.Tensor.topk,  # MISSING aten::_local_scalar_dense
+    torch.Tensor.trace,  # MISSING aten::trace
+    torch.Tensor.unique,  # MISSING aten::_unique2
+    torch.Tensor.unique_consecutive,  # MISSING aten::unique_consecutive
+    torch.Tensor.unsqueeze,  # MISSING aten::_local_scalar_dense
+    torch.Tensor.var,  # MISSING aten::var.correction
+    torch.Tensor.vdot,  # MISSING aten::vdot
+    torch.Tensor.where,  # MISSING aten::where.self
+    torch._add_relu,  # MISSING aten::_add_relu.Tensor
+    torch._aminmax,  # MISSING aten::_aminmax
+    torch._assert_async,  # MISSING aten::_assert_async
+    torch._choose_qparams_per_tensor,  # MISSING aten::min
+    torch._compute_linear_combination,  # MISSING aten::_compute_linear_combination
+    torch._det_lu_based_helper,  # MISSING aten::_det_lu_based_helper
+    torch._dirichlet_grad,  # MISSING aten::_dirichlet_grad
+    torch._fake_quantize_learnable_per_channel_affine,  # MISSING aten::_fake_quantize_learnable_per_channel_affine
+    torch._fake_quantize_learnable_per_tensor_affine,  # MISSING aten::_fake_quantize_learnable_per_tensor_affine
+    torch._fake_quantize_per_tensor_affine_cachemask_tensor_qparams,  # MISSING aten::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams  # noqa: E501
+    torch._foreach_abs,  # MISSING aten::_foreach_abs
+    torch._foreach_abs_,  # MISSING aten::_foreach_abs_
+    torch._foreach_acos,  # MISSING aten::_foreach_acos
+    torch._foreach_acos_,  # MISSING aten::_foreach_acos_
+    torch._foreach_add,  # MISSING aten::_foreach_add.Scalar
+    torch._foreach_add_,  # MISSING aten::_foreach_add_.Scalar
+    torch._foreach_addcdiv,  # MISSING aten::_foreach_addcdiv.Scalar
+    torch._foreach_addcdiv_,  # MISSING aten::_foreach_addcdiv_.Scalar
+    torch._foreach_addcmul,  # MISSING aten::_foreach_addcmul.Scalar
+    torch._foreach_addcmul_,  # MISSING aten::_foreach_addcmul_.Scalar
+    torch._foreach_asin,  # MISSING aten::_foreach_asin
+    torch._foreach_asin_,  # MISSING aten::_foreach_asin_
+    torch._foreach_atan,  # MISSING aten::_foreach_atan
+    torch._foreach_atan_,  # MISSING aten::_foreach_atan_
+    torch._foreach_ceil,  # MISSING aten::_foreach_ceil
+    torch._foreach_ceil_,  # MISSING aten::_foreach_ceil_
+    torch._foreach_cos,  # MISSING aten::_foreach_cos
+    torch._foreach_cos_,  # MISSING aten::_foreach_cos_
+    torch._foreach_cosh,  # MISSING aten::_foreach_cosh
+    torch._foreach_cosh_,  # MISSING aten::_foreach_cosh_
+    torch._foreach_div,  # MISSING aten::_foreach_div.Scalar
+    torch._foreach_div_,  # MISSING aten::_foreach_div_.ScalarList
+    torch._foreach_erf,  # MISSING aten::_foreach_erf
+    torch._foreach_erf_,  # MISSING aten::_foreach_erf_
+    torch._foreach_erfc,  # MISSING aten::_foreach_erfc
+    torch._foreach_erfc_,  # MISSING aten::_foreach_erfc_
+    torch._foreach_exp,  # MISSING aten::_foreach_exp
+    torch._foreach_exp_,  # MISSING aten::_foreach_exp_
+    torch._foreach_expm1,  # MISSING aten::_foreach_expm1
+    torch._foreach_expm1_,  # MISSING aten::_foreach_expm1_
+    torch._foreach_floor,  # MISSING aten::_foreach_floor
+    torch._foreach_floor_,  # MISSING aten::_foreach_floor_
+    torch._foreach_frac,  # MISSING aten::_foreach_frac
+    torch._foreach_frac_,  # MISSING aten::_foreach_frac_
+    torch._foreach_log,  # MISSING aten::_foreach_log
+    torch._foreach_log10,  # MISSING aten::_foreach_log10
+    torch._foreach_log10_,  # MISSING aten::_foreach_log10_
+    torch._foreach_log1p,  # MISSING aten::_foreach_log1p
+    torch._foreach_log1p_,  # MISSING aten::_foreach_log1p_
+    torch._foreach_log2,  # MISSING aten::_foreach_log2
+    torch._foreach_log2_,  # MISSING aten::_foreach_log2_
+    torch._foreach_log_,  # MISSING aten::_foreach_log_
+    torch._foreach_maximum,  # MISSING aten::_foreach_maximum.List
+    torch._foreach_minimum,  # MISSING aten::_foreach_minimum.List
+    torch._foreach_mul,  # MISSING aten::_foreach_mul.Scalar
+    torch._foreach_mul_,  # MISSING aten::_foreach_mul_.ScalarList
+    torch._foreach_neg,  # MISSING aten::_foreach_neg
+    torch._foreach_neg_,  # MISSING aten::_foreach_neg_
+    torch._foreach_norm,  # MISSING aten::_foreach_norm.Scalar
+    torch._foreach_reciprocal,  # MISSING aten::_foreach_reciprocal
+    torch._foreach_reciprocal_,  # MISSING aten::_foreach_reciprocal_
+    torch._foreach_round,  # MISSING aten::_foreach_round
+    torch._foreach_round_,  # MISSING aten::_foreach_round_
+    torch._foreach_sigmoid,  # MISSING aten::_foreach_sigmoid
+    torch._foreach_sigmoid_,  # MISSING aten::_foreach_sigmoid_
+    torch._foreach_sin,  # MISSING aten::_foreach_sin
+    torch._foreach_sin_,  # MISSING aten::_foreach_sin_
+    torch._foreach_sinh,  # MISSING aten::_foreach_sinh
+    torch._foreach_sinh_,  # MISSING aten::_foreach_sinh_
+    torch._foreach_sqrt,  # MISSING aten::_foreach_sqrt
+    torch._foreach_sqrt_,  # MISSING aten::_foreach_sqrt_
+    torch._foreach_sub,  # MISSING aten::_foreach_sub.Scalar
+    torch._foreach_sub_,  # MISSING aten::_foreach_sub_.ScalarList
+    torch._foreach_tan,  # MISSING aten::_foreach_tan
+    torch._foreach_tan_,  # MISSING aten::_foreach_tan_
+    torch._foreach_tanh,  # MISSING aten::_foreach_tanh
+    torch._foreach_tanh_,  # MISSING aten::_foreach_tanh_
+    torch._foreach_trunc,  # MISSING aten::_foreach_trunc
+    torch._foreach_trunc_,  # MISSING aten::_foreach_trunc_
+    torch._foreach_zero_,  # MISSING aten::_foreach_zero_
+    torch._fused_moving_avg_obs_fq_helper,  # MISSING aten::_fused_moving_avg_obs_fq_helper
+    torch._make_per_tensor_quantized_tensor,  # MISSING aten::_make_per_tensor_quantized_tensor
+    torch._masked_softmax,  # MISSING aten::_masked_softmax
+    torch._sample_dirichlet,  # MISSING aten::_sample_dirichlet
+    torch._standard_gamma,  # MISSING aten::_standard_gamma
+    torch._unique,  # MISSING aten::_unique
+    torch._unique2,  # MISSING aten::_unique2
+    torch.abs,  # MISSING aten::abs.out
+    torch.absolute,  # MISSING aten::abs.out
+    torch.addbmm,  # MISSING aten::addbmm
+    torch.angle,  # MISSING aten::angle
+    torch.batch_norm,  # MISSING aten::native_batch_norm
+    torch.bernoulli,  # MISSING aten::bernoulli.out
+    torch.bincount,  # MISSING aten::bincount
+    torch.binomial,  # MISSING aten::binomial
+    torch.bucketize,  # MISSING aten::bucketize.Tensor
+    torch.cholesky,  # MISSING aten::cholesky
+    torch.cholesky_inverse,  # MISSING aten::cholesky_inverse
+    torch.cholesky_solve,  # MISSING aten::_cholesky_solve_helper
+    torch.clip,  # MISSING aten::clamp.Tensor
+    torch.combinations,  # MISSING aten::masked_select
+    torch.complex,  # MISSING aten::complex.out
+    torch.conj_physical,  # MISSING aten::conj_physical.out
+    torch.corrcoef,  # MISSING aten::_local_scalar_dense
+    torch.count_nonzero,  # MISSING aten::count_nonzero.dim_IntList
+    torch.cov,  # MISSING aten::_local_scalar_dense
+    torch.cummax,  # MISSING aten::_cummax_helper
+    torch.cummin,  # MISSING aten::_cummin_helper
+    torch.det,  # MISSING aten::_det_lu_based_helper
+    torch.diag,  # MISSING aten::diag.out
+    torch.diagflat,  # MISSING aten::diag.out
+    torch.dot,  # MISSING aten::dot
+    torch.eig,  # MISSING aten::abs.out
+    torch.embedding,  # MISSING aten::index_select
+    torch.equal,  # MISSING aten::equal
+    torch.eye,  # MISSING aten::eye.m_out
+    torch.fake_quantize_per_channel_affine,  # MISSING aten::fake_quantize_per_channel_affine_cachemask
+    torch.fake_quantize_per_tensor_affine,  # MISSING aten::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams
+    torch.fft.fft,  # MISSING aten::_fft_r2c
+    torch.fft.fft2,  # MISSING aten::_fft_c2c
+    torch.fft.fftn,  # MISSING aten::_fft_c2c
+    torch.fft.fftshift,  # MISSING aten::roll
+    torch.fft.hfft2,  # MISSING aten::_fft_c2c
+    torch.fft.hfftn,  # MISSING aten::_fft_c2c
+    torch.fft.ifft,  # MISSING aten::_fft_r2c
+    torch.fft.ifft2,  # MISSING aten::_fft_c2c
+    torch.fft.ifftn,  # MISSING aten::_fft_c2c
+    torch.fft.ifftshift,  # MISSING aten::roll
+    torch.fft.ihfft,  # MISSING aten::_fft_r2c
+    torch.fft.ihfft2,  # MISSING aten::_fft_r2c
+    torch.fft.ihfftn,  # MISSING aten::_fft_r2c
+    torch.fft.irfft,  # MISSING aten::_fft_c2r
+    torch.fft.irfft2,  # MISSING aten::_fft_c2r
+    torch.fft.irfftn,  # MISSING aten::_fft_c2r
+    torch.fft.rfft,  # MISSING aten::_fft_r2c
+    torch.fft.rfft2,  # MISSING aten::_fft_r2c
+    torch.fft.rfftn,  # MISSING aten::_fft_r2c
+    torch.flip,  # MISSING aten::flip
+    torch.fliplr,  # MISSING aten::flip
+    torch.flipud,  # MISSING aten::flip
+    torch.floor_divide,  # MISSING aten::floor_divide
+    torch.frexp,  # MISSING aten::frexp.Tensor_out
+    torch.functional.cdist,  # MISSING aten::_cdist_forward
+    torch.functional.einsum,  # MISSING aten::dot
+    torch.functional.istft,  # MISSING aten::view_as_complex
+    torch.functional.pca_lowrank,  # MISSING aten::_linalg_qr_helper
+    torch.functional.stft,  # MISSING aten::_fft_r2c
+    torch.functional.svd_lowrank,  # MISSING aten::_linalg_qr_helper
+    torch.functional.tensordot,  # MISSING aten::tensordot.out
+    torch.functional.unique,  # MISSING aten::_unique2
+    torch.functional.unique_consecutive,  # MISSING aten::unique_consecutive
+    torch.fused_moving_avg_obs_fake_quant,  # MISSING aten::_fused_moving_avg_obs_fq_helper
+    torch.geqrf,  # MISSING aten::geqrf
+    torch.group_norm,  # MISSING aten::native_batch_norm
+    torch.histc,  # MISSING aten::histc.out
+    torch.histogram,  # MISSING aten::histogram.bin_ct
+    torch.histogramdd,  # MISSING aten::_histogramdd_bin_edges
+    torch.index_select,  # MISSING aten::index_select
+    torch.inner,  # MISSING aten::tensordot.out
+    torch.inverse,  # MISSING aten::_local_scalar_dense
+    torch.isnan,  # MISSING aten::isnan
+    torch.kthvalue,  # MISSING aten::kthvalue.values
+    torch.layer_norm,  # MISSING aten::native_batch_norm
+    torch.linalg.cholesky,  # MISSING aten::linalg_cholesky_ex
+    torch.linalg.cholesky_ex,  # MISSING aten::linalg_cholesky_ex
+    torch.linalg.det,  # MISSING aten::_det_lu_based_helper
+    torch.linalg.eig,  # MISSING aten::linalg_eig
+    torch.linalg.eig,  # MISSING aten::linalg_eig.out
+    torch.linalg.eigh,  # MISSING aten::linalg_eigh
+    torch.linalg.eigvals,  # MISSING aten::linalg_eig
+    torch.linalg.eigvalsh,  # MISSING aten::linalg_eigh
+    torch.linalg.eigvalsh,  # MISSING aten::linalg_eigvalsh.out
+    torch.linalg.householder_product,  # MISSING aten::linalg_householder_product
+    torch.linalg.inv,  # MISSING aten::_local_scalar_dense
+    torch.linalg.lstsq,  # MISSING aten::linalg_lstsq.out
+    torch.linalg.lu_factor,  # MISSING aten::_local_scalar_dense
+    torch.linalg.matmul,  # MISSING aten::dot
+    torch.linalg.matrix_exp,  # MISSING aten::linalg_matrix_exp
+    torch.linalg.matrix_norm,  # MISSING aten::abs.out
+    torch.linalg.matrix_power,  # MISSING aten::_local_scalar_dense
+    torch.linalg.matrix_power,  # MISSING aten::eye.m_out
+    torch.linalg.norm,  # MISSING aten::linalg_vector_norm
+    torch.linalg.pinv,  # MISSING aten::where.self
+    torch.linalg.qr,  # MISSING aten::_linalg_qr_helper
+    torch.linalg.slogdet,  # MISSING aten::linalg_slogdet
+    torch.linalg.solve,  # MISSING aten::linalg_solve
+    torch.linalg.solve_triangular,  # MISSING aten::linalg_solve_triangular
+    torch.linalg.tensorinv,  # MISSING aten::_local_scalar_dense
+    torch.linalg.tensorsolve,  # MISSING aten::linalg_solve
+    torch.linalg.vector_norm,  # MISSING aten::linalg_vector_norm
+    torch.logcumsumexp,  # MISSING aten::_logcumsumexp
+    torch.logdet,  # MISSING aten::abs.out
+    torch.logical_and,  # MISSING aten::logical_and.out
+    torch.logical_not,  # MISSING aten::logical_not.out
+    torch.logical_or,  # MISSING aten::logical_or.out
+    torch.logical_xor,  # MISSING aten::logical_xor.out
+    torch.logit,  # MISSING aten::logit
+    torch.logsumexp,  # MISSING aten::abs.out
+    torch.lstsq,  # MISSING aten::lstsq
+    torch.masked_select,  # MISSING aten::masked_select
+    torch.matmul,  # MISSING aten::dot
+    torch.matrix_exp,  # MISSING aten::linalg_matrix_exp
+    torch.matrix_power,  # MISSING aten::eye.m_out
+    torch.matrix_rank,  # MISSING aten::linalg_eigvalsh.out
+    torch.median,  # MISSING aten::median
+    torch.median,  # MISSING aten::median.dim_values
+    torch.mode,  # MISSING aten::mode
+    torch.multinomial,  # MISSING aten::multinomial
+    torch.mvlgamma,  # MISSING aten::_local_scalar_dense
+    torch.nan_to_num,  # MISSING aten::nan_to_num.out
+    torch.nanmean,  # MISSING aten::logical_not.out
+    torch.nanmedian,  # MISSING aten::nanmedian
+    torch.nanmedian,  # MISSING aten::nanmedian.dim_values
+    torch.nansum,  # MISSING aten::nansum
+    torch.nn.functional.adaptive_avg_pool1d,  # MISSING aten::_adaptive_avg_pool2d
+    torch.nn.functional.adaptive_avg_pool2d,  # MISSING aten::_adaptive_avg_pool2d
+    torch.nn.functional.adaptive_avg_pool3d,  # MISSING aten::_adaptive_avg_pool3d
+    torch.nn.functional.batch_norm,  # MISSING aten::native_batch_norm
+    torch.nn.functional.binary_cross_entropy,  # MISSING aten::binary_cross_entropy
+    torch.nn.functional.channel_shuffle,  # MISSING aten::channel_shuffle
+    torch.nn.functional.cosine_embedding_loss,  # MISSING aten::clamp_min.out
+    torch.nn.functional.cross_entropy,  # MISSING aten::_local_scalar_dense
+    torch.nn.functional.cross_entropy,  # MISSING aten::nll_loss2d_forward
+    torch.nn.functional.ctc_loss,  # MISSING aten::_ctc_loss
+    torch.nn.functional.embedding,  # MISSING aten::index_select
+    torch.nn.functional.embedding_bag,  # MISSING aten::_embedding_bag
+    torch.nn.functional.fold,  # MISSING aten::col2im
+    torch.nn.functional.gaussian_nll_loss,  # MISSING aten::_local_scalar_dense
+    torch.nn.functional.grid_sample,  # MISSING aten::grid_sampler_2d
+    torch.nn.functional.group_norm,  # MISSING aten::native_batch_norm
+    torch.nn.functional.hardswish,  # MISSING aten::hardswish
+    torch.nn.functional.hardtanh,  # MISSING aten::hardtanh
+    torch.nn.functional.hinge_embedding_loss,  # MISSING aten::clamp_min.out
+    torch.nn.functional.huber_loss,  # MISSING aten::huber_loss
+    torch.nn.functional.instance_norm,  # MISSING aten::native_batch_norm
+    torch.nn.functional.kl_div,  # MISSING aten::where.self
+    torch.nn.functional.l1_loss,  # MISSING aten::abs.out
+    torch.nn.functional.layer_norm,  # MISSING aten::native_batch_norm
+    torch.nn.functional.logsigmoid,  # MISSING aten::log_sigmoid_forward
+    torch.nn.functional.lp_pool1d,  # MISSING aten::abs.out
+    torch.nn.functional.lp_pool2d,  # MISSING aten::abs.out
+    torch.nn.functional.max_pool3d,  # MISSING aten::max_pool3d_with_indices
+    torch.nn.functional.max_pool3d_with_indices,  # MISSING aten::max_pool3d_with_indices
+    torch.nn.functional.max_unpool1d,  # MISSING aten::max_unpool2d
+    torch.nn.functional.max_unpool2d,  # MISSING aten::max_unpool2d
+    torch.nn.functional.max_unpool3d,  # MISSING aten::max_unpool3d
+    torch.nn.functional.multi_head_attention_forward,  # MISSING aten::logical_or.out
+    torch.nn.functional.multi_margin_loss,  # MISSING aten::multi_margin_loss
+    torch.nn.functional.multilabel_margin_loss,  # MISSING aten::multilabel_margin_loss_forward
+    torch.nn.functional.multilabel_soft_margin_loss,  # MISSING aten::log_sigmoid_forward
+    torch.nn.functional.nll_loss,  # MISSING aten::nll_loss2d_forward
+    torch.nn.functional.one_hot,  # MISSING aten::min
+    torch.nn.functional.pdist,  # MISSING aten::_pdist_forward
+    torch.nn.functional.prelu,  # MISSING aten::prelu
+    torch.nn.functional.relu,  # MISSING aten::relu
+    torch.nn.functional.relu6,  # MISSING aten::hardtanh
+    torch.nn.functional.rrelu,  # MISSING aten::rrelu_with_noise
+    torch.nn.functional.softsign,  # MISSING aten::abs.out
+    torch.nn.functional.unfold,  # MISSING aten::im2col
+    torch.nonzero,  # MISSING aten::nonzero
+    torch.normal,  # MISSING aten::min
+    torch.orgqr,  # MISSING aten::linalg_householder_product
+    torch.ormqr,  # MISSING aten::ormqr
+    torch.pinverse,  # MISSING aten::where.self
+    torch.poisson,  # MISSING aten::poisson
+    torch.polar,  # MISSING aten::polar.out
+    torch.prod,  # MISSING aten::prod
+    torch.qr,  # MISSING aten::_linalg_qr_helper
+    torch.quantize_per_channel,  # MISSING aten::quantize_per_channel
+    torch.quantize_per_tensor,  # MISSING aten::quantize_per_tensor
+    torch.quantize_per_tensor_dynamic,  # MISSING aten::quantize_per_tensor_dynamic
+    torch.relu,  # MISSING aten::relu
+    torch.repeat_interleave,  # MISSING aten::repeat_interleave.Tensor
+    torch.rnn_relu,  # MISSING aten::relu
+    torch.rnn_relu_cell,  # MISSING aten::relu
+    torch.roll,  # MISSING aten::roll
+    torch.rot90,  # MISSING aten::flip
+    torch.rsub,  # MISSING aten::rsub.Tensor
+    torch.searchsorted,  # MISSING aten::searchsorted.Tensor
+    torch.slogdet,  # MISSING aten::linalg_slogdet
+    torch.solve,  # MISSING aten::_solve_helper
+    torch.special.logit,  # MISSING aten::logit
+    torch.special.logsumexp,  # MISSING aten::abs.out
+    torch.special.multigammaln,  # MISSING aten::_local_scalar_dense
+    torch.square,  # MISSING aten::square.out
+    torch.std,  # MISSING aten::std.correction
+    torch.std_mean,  # MISSING aten::std_mean.correction
+    torch.symeig,  # MISSING aten::_symeig_helper
+    torch.take,  # MISSING aten::take
+    torch.threshold,  # MISSING aten::_local_scalar_dense
+    torch.trace,  # MISSING aten::trace
+    torch.var,  # MISSING aten::var.correction
+    torch.var_mean,  # MISSING aten::var_mean.correction
+    torch.vdot,  # MISSING aten::vdot
+    torch.where,  # MISSING aten::where.self
+    torch.quantile,  # MISSING aten::isnan
+    torch.nanquantile,  # MISSING aten::isnan
+}
+
+# Only some overloads/configurations of these ops are covered by meta tensors,
+# so we can't use them to toggle expected failure.  Try to prioritize these.
+overload_exclude_set = {
+    torch.clamp,  # MISSING aten::clamp.Tensor
+    torch.max,  # MISSING aten::max
+    torch.min,  # MISSING aten::min
+    torch.nn.functional.interpolate,  # MISSING aten::upsample_nearest3d.vec
+    torch.nn.functional.upsample_nearest,  # MISSING aten::upsample_nearest3d.vec
+    torch.nn.functional.pad,  # MISSING aten::reflection_pad2d
+    torch.remainder,  # MISSING aten::remainder.Scalar_Tensor
+    torch.linalg.matrix_rank,  # MISSING aten::linalg_eigh
+    torch.Tensor.isinf,  # MISSING aten::abs.out
+    torch.isinf,  # MISSING aten::abs.out
+    torch.Tensor.isfinite,  # MISSING aten::abs.out
+    torch.isfinite,  # MISSING aten::abs.out
+    torch.diff,  # MISSING aten::logical_xor.out
+}
+
+# These are fine in OpInfo tests, but triggered errors in full test suite
+# crossref testing, which means there is probably not enough coverage from
+# OpInfo.  Patch in https://github.com/pytorch/pytorch/pull/75994 and find
+# out where these failures come from.
+suspicious_exclude_set = {
+    torch.add,  # MISSING aten::_local_scalar_dense
+    torch.cat,  # MISSING aten::_local_scalar_dense
+    torch.cumprod,  # MISSING aten::logical_and.out
+    torch.cumsum,  # MISSING aten::_local_scalar_dense
+    torch.functional.norm,  # MISSING aten::isnan
+    torch.linalg.cond,  # MISSING aten::abs.out
+    torch.sgn,  # MISSING aten::abs.out
+
+    # RuntimeError: Expected 3D or 4D (batch mode) tensor with optional 0 dim
+    # batch size for input, but got:[1, 1, 0]
+    # in test_nn.py TestNNDeviceTypeCPU.test_max_pool1d_corner_cases_cpu_float64
+    torch.nn.functional.max_pool1d,
+
+    # Factory functions need tricky kwarg handling
+    torch.zeros_like,
+}
+
+# These are also known not to work, but they fail in a more special way
+# than the regular "Meta not implemented for aten op" way
+meta_exclude_set |= {
+    # Convolutions have a special error message
+    torch.nn.functional.conv1d,
+    torch.nn.functional.conv2d,
+    torch.nn.functional.conv3d,
+    torch.nn.functional.conv_transpose1d,
+    torch.nn.functional.conv_transpose2d,
+    torch.nn.functional.conv_transpose3d,
+    # complex stuff is handled specially
+    torch.view_as_complex,
+    torch.view_as_real,
+    # These operators occur very frequently; although they should
+    # work with meta, we intentionally don't test them in order to
+    # speed up the test suite
+    torch.Tensor.__getitem__,
+    torch.Tensor.__rsub__,
+    torch.Tensor.__setitem__,
+    torch.Tensor.add,
+    torch.Tensor.add_,
+    torch.Tensor.clone,
+    torch.Tensor.detach,
+    torch.Tensor.div,
+    torch.Tensor.gt,
+    torch.Tensor.lt,
+    torch.Tensor.mul,
+    torch.Tensor.reshape,
+    torch.Tensor.sub,
+    torch.Tensor.sum,
+    torch.rand,
+    # These correctly report NotImplemented but they don't print
+    # correctly from resolve_name
+    torch.ops.quantized.linear_dynamic,
+    torch._VF.unique_dim,
+    torch._C._nn.binary_cross_entropy,
+    torch._C._nn.adaptive_avg_pool2d,
+    torch._C._nn._test_optional_filled_intlist,
+    torch._C._nn._test_optional_floatlist,
+    torch._C._nn._test_optional_intlist,
+    # Meta tensors don't support storage Python bindings at the
+    # moment, to be fixed
+    torch.Tensor.storage,
+    torch.Tensor.storage_type,
+    torch.Tensor.share_memory_,
+    # Weird stuff that hypothetically should work, but is too weird to test here
+    torch._make_dual,
+    torch._unpack_dual,  # fails because we don't preserve forward ad tangent in test code
+    # These functions cannot, even in principle, be implemented on meta
+    # tensors (because they involve accessing data somehow), so don't test
+    # them.
+    torch.Tensor.__bool__,
+    torch.Tensor.__float__,
+    torch.Tensor.__int__,
+    torch.Tensor.__complex__,
+    torch.Tensor.__index__,
+    torch.Tensor.__contains__,
+    torch.Tensor.cpu,
+    torch.isclose,
+    torch.Tensor.to,
+    torch.Tensor.tolist,
+    torch.Tensor.unbind,
+    torch.Tensor.item,
+    torch.Tensor.is_nonzero,
+    torch.Tensor.copy_,
+    torch.Tensor.numpy,
+    torch.Tensor.allclose,
+    torch.Tensor.argwhere,
+    torch.allclose,
+    torch.argwhere,
+    torch.Tensor.__array__,  # doesn't raise NotImplementedError
+    torch.Tensor.__dlpack_device__,  # doesn't raise NotImplementedError
+    torch.Tensor.__dlpack__,  # doesn't raise NotImplementedError
+    torch.to_dlpack,  # doesn't raise NotImplementedError
+    # Utility functions that get frequently invoked; don't test
+    torch.Tensor.__format__,
+    torch.Tensor.__repr__,
+    # These are getters/setters for properties on tensors; it's not
+    # really useful to test meta tensors on them
+    torch.Tensor.device.__get__,
+    torch.Tensor.dtype.__get__,
+    torch.Tensor.grad.__get__,
+    torch.Tensor.grad.__set__,
+    torch.Tensor.is_sparse.__get__,
+    torch.Tensor.layout.__get__,
+    torch.Tensor.shape.__get__,
+    torch.Tensor.requires_grad.__get__,
+    torch.Tensor.requires_grad.__set__,
+    torch.Tensor.data.__get__,
+    torch.Tensor.data.__set__,
+    torch.Tensor._base.__get__,
+    torch.Tensor.is_shared,
+    torch.Tensor.imag.__get__,
+    torch.Tensor.real.__get__,
+    torch.Tensor.__setstate__,
+    torch.Tensor.is_complex,
+    torch.Tensor.is_floating_point,
+    torch.Tensor.numel,
+    torch.Tensor.requires_grad_,
+    torch.Tensor.size,
+    # These perturb RNG and can cause tests to fail, so don't run
+    # them (TODO: this is not a complete list)
+    torch.randint,
+    torch.randn,
+    # Indirect use of conjugate fallback
+    torch.fft.hfft,
+    # These don't raise NotImplementedError, which suggests something
+    # is wrong with how they're registered with the dispatcher
+    torch.fbgemm_pack_gemm_matrix_fp16,
+    torch.fbgemm_pack_quantized_matrix,
+    torch.fbgemm_linear_fp16_weight,
+    torch._empty_per_channel_affine_quantized,
+    torch.fbgemm_linear_int8_weight,
+    torch._grid_sampler_2d_cpu_fallback,  # WAT
+    torch._nnpack_spatial_convolution,
+    torch.lstm,
+    torch.Tensor.conj_physical_,
+    torch.rnn_tanh,
+    torch.fbgemm_linear_quantize_weight,
+    torch._reshape_from_tensor,
+    torch.gru,
+    torch.Tensor.unflatten,
+    torch._saturate_weight_to_fp16,
+    torch.choose_qparams_optimized,
+    torch._validate_sparse_coo_tensor_args,
+    torch.sparse.mm,
+    torch.Tensor.new,
+    torch.Tensor.resize,  # WTF is this
+    torch._sobol_engine_initialize_state_,
+    torch._sobol_engine_draw,
+    torch._sobol_engine_scramble_,
+    torch._sobol_engine_ff_,
+    torch.tensor_split,
+    torch.Tensor.tensor_split,
+    torch._pack_padded_sequence,
+    torch._pad_packed_sequence,
+    torch.sparse_coo_tensor,
+    torch.linalg.ldl_factor,
+    torch._index_reduce,
+    # IndexError: select() cannot be applied to a 0-dim tensor.
+    # e.g. test_fn_fwgrad_bwgrad_index_add_cpu_complex128 (__main__.TestGradientsCPU)
+    torch.index_add,
+    torch.Tensor.index_add,
+    torch.Tensor.index_add_,
+    # Can't copy out of meta tensor
+    torch.linalg.eigvals,
+    torch.linalg.lu_factor,
+    torch.nn.functional.ctc_loss,
+    # Our conversion to meta is not accurate enough (e.g., it doesn't
+    # preserve storage_offset)
+    torch.Tensor.as_strided,
+    # This one segfaults when you call it
+    torch.Tensor.type,
+    # We don't clone autograd history, so this will generally not work
+    torch.autograd.grad,
+    torch.Tensor.backward,
+    torch.Tensor.__deepcopy__,
+    # Don't do factories
+    torch.ones,
+    torch.full,
+    torch.empty,
+    torch.randperm,
+    torch.logspace,
+    torch.zeros,
+    torch.arange,
+    torch.vander,
+    torch.as_tensor,
+    torch.tensor,
+    torch.randn_like,
+    torch.sparse_csr_tensor,
+    torch._sparse_coo_tensor_unsafe,
+    torch._sparse_csr_tensor_unsafe,
+    torch._validate_sparse_csr_tensor_args,
+}
+
+# This is a __torch_function__ mode that, when enabled, interposes on every
+# Torch API call, runs the operator as normal, then reruns it with meta
+# inputs and checks that everything about the output agrees.
+# Most of the logic deals with faithfully replicating the original tensor
+# as a meta tensor, which is nontrivial because there are a lot of subsystems
+# that may potentially be exercised.
+#
+# That being said, this class is a little overkill for what it is doing in
+# this test file (since I could have just inlined __torch_function__ on the
+# OpInfo call, and OpInfos generally have very regular inputs), but it will be
+# useful for more comprehensive testing e.g., as seen in
+# https://github.com/pytorch/pytorch/pull/75994
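+#
+# Rough usage sketch (a hypothetical standalone example; the real invocation
+# lives in TestMeta.test_meta below):
+#
+#   with push_torch_function_mode(partial(MetaCrossRefMode, test_case, run_excludes_anyway=False)):
+#       torch.add(x, y)  # runs for real, then again on meta copies, and compares metadata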
+class MetaCrossRefMode(torch.overrides.TorchFunctionMode):
+    test_case: TestCase
+    run_excludes_anyway: bool
+
+    def __init__(self, test_case, *, run_excludes_anyway):
+        self.test_case = test_case
+        self.run_excludes_anyway = run_excludes_anyway
+
+    def __torch_function__(self, func, types, args=(), kwargs=None):
+        kwargs = kwargs or {}
+
+        hit = 0
+        miss = 0
+
+        # Doesn't actually return a storage: returns a 1-D meta tensor that stands in for it
+        @functools.lru_cache(None)
+        def meta_storage(s):
+            return torch.empty(s.size(), dtype=s.dtype, device='meta')
+
+        def safe_is_leaf(t):
+            try:
+                return t.is_leaf
+            except RuntimeError:
+                # inference mode can trigger this
+                return False
+
+        @functools.lru_cache(None)
+        def meta_tensor(t):
+            with torch.inference_mode(t.is_inference()):
+                s = meta_storage(t.storage())
+                is_leaf = safe_is_leaf(t)
+                if is_leaf or not t._is_view():
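+                    # Leaf or non-view: rebuild the tensor directly on top of
+                    # the meta stand-in for the storage; if the original was a
+                    # non-leaf requiring grad, clone under enable_grad so the
+                    # copy is a non-leaf too.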
+                    r = torch.empty(
+                        (0,), dtype=t.dtype, device='meta'
+                    )
+                    r.set_(s, t.storage_offset(), t.size(), t.stride())
+                    r.requires_grad = t.requires_grad
+                    if not is_leaf and t.requires_grad:
+                        with torch.enable_grad():
+                            r = r.clone()
+                else:
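+                    # Non-leaf view: rebuild the base on the meta storage
+                    # (cloning it under enable_grad if the original base was a
+                    # non-leaf), then re-derive the view with as_strided so the
+                    # result is a view for autograd as well.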
+                    base = torch.empty(
+                        (0,), dtype=t.dtype, device='meta'
+                    )
+                    base.set_(s, 0, s.size(), (1,))
+                    base.requires_grad = t.requires_grad
+                    with torch.enable_grad():
+                        if t._is_view() and not safe_is_leaf(t._base):
+                            base = base.clone()
+                        r = base.as_strided(t.size(), t.stride(), t.storage_offset())
+                torch._C._set_conj(r, t.is_conj())
+                torch._C._set_neg(r, t.is_neg())
+            return r
+
+        def to_meta(t):
+            nonlocal hit, miss
+            # TODO: zero tensors?  We appear to have eliminated them by
+            # excluding complex for now
+            if type(t) is torch.Tensor or type(t) is torch.nn.Parameter:
+                if any([
+                    t.is_sparse_csr, t.is_sparse, t.is_mkldnn, t.is_quantized,
+                    t.is_nested, torch._is_functional_tensor(t),
+                    # these are supported in meta conversion but the fallbacks
+                    # don't work
+                    t.is_neg(), t.is_conj(),
+                    # conjugate fallback does not support meta tensors
+                    t.dtype in (torch.complex128, torch.complex64),
+                ]):
+                    # TODO: sparse should support meta
+                    # NB: technically to('meta') does work, but our logging
+                    # instrumentation would see the meta conversions and the
+                    # tests would all break, so we just exclude these.  In any
+                    # case the to() conversion isn't really right anyhow.
+                    miss += 1
+                    return t
+                elif any([
+                    t.device.type in ("lazy", "meta"), t.is_complex(),
+                    # We need a way to test if a tensor is batched but there
+                    # is no official API to do it
+                    # torch._C._is_batched(t),
+                ]):
+                    # TODO: this stuff should support storage
+                    # (well, maybe not batched)
+                    hit += 1
+                    return t.to("meta")
+                else:
+                    hit += 1
+                    r = meta_tensor(t)
+                    if type(t) is torch.nn.Parameter:
+                        r = torch.nn.Parameter(r, requires_grad=r.requires_grad)
+                    return r
+            elif torch.overrides.is_tensor_like(t):
+                # Blindly converting tensor subclasses to meta can cause
+                # unpredictable problems; e.g., FX tests will trace meta
+                # tensors into their trace / some subclasses don't correctly
+                # support meta.  Trying to YOLO this is more trouble than it's
+                # worth.
+                miss += 1
+                return t
+            else:
+                # non-Tensor types don't count as hit or miss
+                return t
+
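+        # Only cross-reference when the func isn't excluded (unless the caller
+        # asked to run excludes anyway, to confirm they still fail), we aren't
+        # tracing with TorchScript, and func isn't a ScriptMethod.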
+        do_meta = (
+            (self.run_excludes_anyway or func not in meta_exclude_set) and
+            not torch.jit.is_tracing() and
+            not isinstance(func, torch.ScriptMethod)
+        )
+
+        if do_meta:
+            try:
+                meta_args = tree_map(to_meta, args)
+                meta_kwargs = tree_map(to_meta, kwargs)
+            except Exception as e:
+                raise RuntimeError(
+                    f"failed to convert args to meta; "
+                    f"originally (*{args}, **{kwargs})") from e
+
+        rs = func(*args, **kwargs)
+
+        # TODO: also handle cases where func raises an exception
+
+        # For now, only attempt if we managed to convert all tensor types
+        # (if any of them failed, we're in a mixed device situation and
+        # this isn't well supported)
+        if do_meta and hit > 0 and miss == 0:
+            try:
+                # suppress warnings
+                with warnings.catch_warnings():
+                    warnings.simplefilter("ignore")
+                    meta_rs = func(*meta_args, **meta_kwargs)
+            except Exception as e:
+                suppress = False
+                """
+                # This code can be helpful for the full crossref test, to
+                # filter out "pedestrian" omissions
+                if isinstance(e, NotImplementedError):
+                    m = RE_NOT_IMPLEMENTED_MSG.search(e.args[0])
+                    if m and m.group(1) not in ("aten::_efficientzerotensor", "aten::view_as_real"):
+                        suppress = True
+                """
+                if not suppress:
+                    raise RuntimeError(f"""\
+failed to run: {func}(
+    *{meta_args},
+    **{meta_kwargs}
+  )""") from e
+            else:
+                def test_assert(cond, msg):
+                    if not cond:
+                        raise RuntimeError(f"""\
+meta disagrees with real impl:
+{func}(
+    *{meta_args},
+    **{meta_kwargs}
+) = {meta_r}
+{msg}
+""")
+                flat_meta_rs, _ = tree_flatten(meta_rs)
+                flat_rs, _ = tree_flatten(rs)
+                self.test_case.assertEqual(len(flat_meta_rs), len(flat_rs))
+                for i, meta_r, r in zip(range(len(flat_rs)), flat_meta_rs, flat_rs):
+                    if isinstance(r, torch.Tensor):
+                        test_assert(isinstance(meta_r, torch.Tensor), f"but real {i}th result is Tensor")
+                        test_assert(meta_r.dtype == r.dtype, f"but real dtype was {r.dtype}")
+                        test_assert(meta_r.shape == r.shape, f"but real shape was {r.shape}")
+                        test_assert(meta_r.stride() == r.stride(), f"but real stride was {r.stride()}")
+                        test_assert(
+                            meta_r.storage_offset() == r.storage_offset(),
+                            f"but real storage_offset was {r.storage_offset()}")
+                        test_assert(meta_r.requires_grad == r.requires_grad, f"but real requires_grad was {r.requires_grad}")
+                        test_assert(meta_r.is_conj() == r.is_conj(), f"but real is_conj was {r.is_conj()}")
+                        test_assert(meta_r.is_neg() == r.is_neg(), f"but real is_neg was {r.is_neg()}")
+
+        return rs
+
+class TestMeta(TestCase):
+    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
+    @onlyNativeDeviceTypes
+    @skipIfCrossRef
+    @suppress_warnings
+    @ops(op_db)
+    def test_meta(self, device, dtype, op):
+        # Run the OpInfo sample inputs, cross-referencing them against the
+        # meta implementation, and check that the results are the same.  All
+        # the heavy lifting happens in MetaCrossRefMode.
+        func = op.get_op()
+
+        def do_test(run_excludes_anyway=False):
+            samples = op.sample_inputs(device, dtype, requires_grad=False)
+            for sample_input in samples:
+                args = [sample_input.input] + list(sample_input.args)
+                kwargs = sample_input.kwargs
+                with push_torch_function_mode(partial(MetaCrossRefMode, self, run_excludes_anyway=run_excludes_anyway)):
+                    expected = func(*args, **kwargs)
+                    if isinstance(expected, torch.Tensor) and op.supports_out:
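+                        # exercise the out= variant as well, reusing the real
+                        # result as the out buffer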
+                        func(*args, **kwargs, out=expected)
+
+        if func in overload_exclude_set:
+            self.skipTest('permanently excluded')
+        elif func in meta_exclude_set and dtype not in (torch.complex128, torch.complex64):
+            try:
+                do_test(run_excludes_anyway=True)
+            except Exception:
+                pass
+            else:
+                self.fail('expected failure, but succeeded')
+        else:
+            do_test()
+
+instantiate_device_type_tests(TestMeta, globals())
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 08d680c..b788fd7 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -9709,6 +9709,7 @@
                # Reference: https://github.com/pytorch/pytorch/issues/50747
                DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager',
                             dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16)),
+               DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta', dtypes=(torch.bool,)),
            ),
            sample_inputs_func=sample_inputs_addr,
            gradcheck_nondet_tol=GRADCHECK_NONDET_TOL),
@@ -11019,6 +11020,8 @@
                    skips=(
                        # Skip since real and imag don't have out variants.
                        DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_out_arg_all_dtypes'),
+                       DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta',
+                                    dtypes=(torch.complex32,)),
                    )),
     OpInfo('gradient',
            dtypes=floating_and_complex_types_and(torch.int8, torch.int16,
@@ -12143,6 +12146,7 @@
            skips=(
                # AssertionError: Resizing an out= argument with no elements threw a resize warning!
                DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='cpu'),
+               DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta', device_type='cpu'),
            )),
     OpInfo('as_strided',
            op=lambda x, size, stride, storage_offset=0:
@@ -12868,6 +12872,8 @@
                # Pre-existing condition; Needs to be fixed
                DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cpu'),
                # RuntimeError: "max_pool1d_impl" not implemented for 'BFloat16'
+               DecorateInfo(unittest.skip("Works on some configs"), 'TestMeta',
+                            'test_meta', dtypes=(torch.bfloat16,)),
                DecorateInfo(unittest.skip("Works on some configs"), 'TestNNCOpInfo',
                             'test_nnc_correctness', dtypes=(torch.bfloat16,)),
                DecorateInfo(unittest.skip("Works on some conifgs"), 'TestCudaFuserOpInfo',
@@ -13899,6 +13905,8 @@
                    skips=(
                        # Skip since real and imag don't have out variants.
                        DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_out_arg_all_dtypes'),
+                       DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta',
+                                    dtypes=(torch.complex32,)),
                    )),
     OpInfo('roll',
            ref=np.roll,
@@ -14779,6 +14787,9 @@
                DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad', dtypes=(torch.complex128,)),
                DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
                DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'),
+               # stride mismatch
+               DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta', device_type='cuda',
+                            dtypes=(torch.float32, torch.float64), active_if=not TEST_WITH_ROCM),
            )),
     OpInfo('linalg.svd',
            op=torch.linalg.svd,
@@ -14797,6 +14808,9 @@
                DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad', dtypes=(torch.complex128,)),
                DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
                DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'),
+               # stride mismatch
+               DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta', device_type='cuda',
+                            dtypes=(torch.float32, torch.float64), active_if=not TEST_WITH_ROCM),
            )),
     OpInfo('linalg.svdvals',
            op=torch.linalg.svdvals,
@@ -14835,6 +14849,8 @@
                DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
                DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
                DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'),
+               # stride mismatch
+               DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta', device_type='cuda', active_if=not TEST_WITH_ROCM),
            )),
     OpInfo('pca_lowrank',
            op=lambda *args, **kwargs: wrapper_set_seed(
@@ -14859,6 +14875,8 @@
                DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
                DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
                DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'),
+               # stride mismatch
+               DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta', device_type='cuda', active_if=not TEST_WITH_ROCM),
            )),
     BinaryUfuncInfo('polar',
                     dtypes=floating_types(),