[functorch] Fix a bunch of linalg coverage issues (pytorch/functorch#765)
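
With these fixes, linalg ops such as torch.linalg.inv and torch.linalg.cholesky
can be vmapped directly (their entries are dropped from the xfail lists below).
A minimal sketch of the kind of call this enables, assuming functorch's vmap API;
the shapes and the SPD construction are illustrative only:

    import torch
    from functorch import vmap

    # a @ a^T + 3*I is symmetric positive-definite, hence invertible.
    a = torch.randn(4, 3, 3)
    x = a @ a.transpose(-2, -1) + 3 * torch.eye(3)

    # vmap applies torch.linalg.inv to each 3x3 matrix in the batch.
    out = vmap(torch.linalg.inv)(x)
    assert out.shape == (4, 3, 3)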
diff --git a/functorch/codegen/gen_vmap_plumbing.py b/functorch/codegen/gen_vmap_plumbing.py
index 28162b7..8b8e439 100644
--- a/functorch/codegen/gen_vmap_plumbing.py
+++ b/functorch/codegen/gen_vmap_plumbing.py
@@ -164,6 +164,27 @@
}}"""
+def gen_vmap_plumbing_no_returns(native_function: NativeFunction) -> str:
+ schema = native_function.func
+ sig = DispatcherSignature.from_schema(schema)
+ cur_level_var = 'cur_level'
+
+ unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
+ bdims_all_none_case = gen_case_where_all_bdims_are_none(schema, cur_level_var)
+
+ return f"""\
+template <typename batch_rule_t, batch_rule_t batch_rule>
+{sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{
+ c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+ auto maybe_layer = maybeCurrentDynamicLayer();
+ TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
+ int64_t {cur_level_var} = maybe_layer->layerId();
+{textwrap.indent(bdims_all_none_case, " ")}
+{textwrap.indent(unwraps, " ")}
+ batch_rule({', '.join(unwrapped_arg_list)});
+}}"""
+
+
def gen_vmap_plumbing(native_function: NativeFunction) -> str:
schema = native_function.func
sig = DispatcherSignature.from_schema(schema)
@@ -171,7 +192,7 @@
# Only support cases where all returns are Tensors or vector<Tensor>
if len(returns) == 0:
- return None
+ return gen_vmap_plumbing_no_returns(native_function)
if not all(ret.type.is_tensor_like() for ret in returns):
return None
if not accepts_at_least_one_tensor_input(schema):
@@ -256,6 +277,7 @@
'logaddexp',
'logaddexp2',
'lcm',
+ '_linalg_check_errors',
'maximum',
'minimum',
'mul.Tensor',
diff --git a/functorch/functorch/csrc/BatchRulesLinearAlgebra.cpp b/functorch/functorch/csrc/BatchRulesLinearAlgebra.cpp
index 7720b82..79ec97e 100644
--- a/functorch/functorch/csrc/BatchRulesLinearAlgebra.cpp
+++ b/functorch/functorch/csrc/BatchRulesLinearAlgebra.cpp
@@ -156,6 +156,12 @@
return at::add(self * beta, at::mm(mat1, mat2), alpha);
}
+void _linalg_check_errors_batch_rule(const Tensor& info, optional<int64_t> info_bdim, c10::string_view api_name, bool is_matrix) {
+ auto info_ = moveBatchDimToFront(info, info_bdim);
+ // Under vmap, info carries an extra batch dimension, so it is always a batch
+ // of matrices here; pass is_matrix=false so the check treats it as such.
+ at::_linalg_check_errors(info_, api_name, false);
+}
+
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
VMAP_SUPPORT(bmm, bmm_batch_rule);
m.impl("addmv", addmv_decomp);
@@ -167,6 +173,8 @@
VMAP_SUPPORT(mm, mm_batch_rule);
m.impl("linear", linear_decomp);
+ VMAP_SUPPORT(_linalg_check_errors, _linalg_check_errors_batch_rule);
+
VARIADIC_BDIMS_BOXED(cholesky_solve);
VARIADIC_BDIMS_BOXED(linalg_cholesky_ex);
VARIADIC_BDIMS_BOXED(linalg_eig);
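
The batch rule above always passes is_matrix=false: under vmap the info tensor
carries an extra batch dimension, so the check always sees a batch of matrices.
A rough sketch of the behaviour this enables, assuming torch.linalg.cholesky
routes its info tensor through _linalg_check_errors (shapes are illustrative):

    import torch
    from functorch import vmap

    # Symmetric positive-definite inputs so the factorization succeeds.
    a = torch.randn(4, 3, 3)
    spd = a @ a.transpose(-2, -1) + 3 * torch.eye(3)

    # Per-matrix Cholesky under vmap; the internal info tensor is batched and
    # the batch rule checks it as a batch of matrices.
    chol = vmap(torch.linalg.cholesky)(spd)

    # Sanity check against the natively batched implementation.
    torch.testing.assert_close(chol, torch.linalg.cholesky(spd))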
diff --git a/functorch/functorch/csrc/VmapGeneratedPlumbing.h b/functorch/functorch/csrc/VmapGeneratedPlumbing.h
index 289cfc7..db8259a 100644
--- a/functorch/functorch/csrc/VmapGeneratedPlumbing.h
+++ b/functorch/functorch/csrc/VmapGeneratedPlumbing.h
@@ -3642,6 +3642,20 @@
return makeBatched(std::get<0>(results), std::get<1>(results), cur_level);
}
template <typename batch_rule_t, batch_rule_t batch_rule>
+void _linalg_check_errors_generated_plumbing(const at::Tensor & info, c10::string_view api_name, bool is_matrix) {
+ c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+ auto maybe_layer = maybeCurrentDynamicLayer();
+ TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
+ int64_t cur_level = maybe_layer->layerId();
+ if (!isBatchedAtLevel(info, cur_level)) {
+ return at::_ops::_linalg_check_errors::call(info, api_name, is_matrix);
+ }
+ Tensor info_value;
+ optional<int64_t> info_bdim;
+ std::tie(info_value, info_bdim) = unwrapTensorAtLevel(info, cur_level);
+ batch_rule(info_value, info_bdim, api_name, is_matrix);
+}
+template <typename batch_rule_t, batch_rule_t batch_rule>
at::Tensor cholesky_generated_plumbing(const at::Tensor & self, bool upper) {
c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
auto maybe_layer = maybeCurrentDynamicLayer();
diff --git a/functorch/test/test_ops.py b/functorch/test/test_ops.py
index 7382097..7773de3 100644
--- a/functorch/test/test_ops.py
+++ b/functorch/test/test_ops.py
@@ -569,13 +569,18 @@
skip('normal', 'number_mean'), # randomness
xfail('nn.functional.dropout'), # randomness
xfail('as_strided'), # as_strided is too wild for us to support, wontfix
+ xfail('index_put', ''), # not possible due to dynamic shapes; we support a subset
+ xfail('masked_scatter'), # dynamic
+ xfail('nn.functional.fractional_max_pool2d'), # random
+ xfail('nn.functional.fractional_max_pool3d'), # random
+ xfail('take'), # dynamic
# All of the following are bugs and need to be fixed
skip('linalg.svdvals'),  # passes the correctness check but fails has_batch_rule
xfail('__getitem__', ''),
xfail('_masked.prod'), # calls aten::item
xfail('block_diag'),
- xfail('eig'),
+ xfail('eig'), # calls aten::item
xfail('fft.ihfft'),
xfail('fft.ihfft2'),
@@ -585,35 +590,25 @@
xfail('fft.rfft2'),
xfail('fft.rfftn'),
xfail('index_copy'),
- xfail('index_put', ''),
- xfail('linalg.cholesky'),
- xfail('linalg.det', ''),
+ xfail('linalg.det', ''), # calls .item()
xfail('linalg.eig'), # Uses aten::allclose
- xfail('linalg.eigh'),
- xfail('linalg.householder_product'),
- xfail('linalg.inv'),
- xfail('linalg.lu_factor', ''),
+ xfail('linalg.eigh'), # needs diag_scatter
+ xfail('linalg.householder_product'), # needs select_scatter
xfail('linalg.matrix_norm'),
- xfail('linalg.matrix_power'),
xfail('linalg.norm'),
xfail('linalg.norm', 'subgradients_at_zero'),
- xfail('linalg.slogdet'),
- xfail('linalg.tensorinv'),
- xfail('logdet'),
- xfail('lu_solve'),
- xfail('lu_unpack'),
- xfail('masked_scatter'),
- xfail('matrix_exp'),
- xfail('nanquantile'),
- xfail('nn.functional.fractional_max_pool2d'),
- xfail('nn.functional.fractional_max_pool3d'),
- xfail('nn.functional.gaussian_nll_loss'),
- xfail('prod'),
+ xfail('linalg.slogdet'), # calls .item()
+ xfail('logdet'), # calls .item()
+ xfail('lu_solve'), # requires .contiguous() call somewhere rather than fallback
+ xfail('lu_unpack'), # would benefit from narrow_scatter
+ xfail('matrix_exp'), # would benefit from narrow_scatter
+ xfail('nanquantile'), # checks q via a .item() call
+ xfail('nn.functional.gaussian_nll_loss'), # checks whether any element of var is < 0
+ xfail('prod'), # calls nonzero
xfail('put'),
- xfail('quantile'),
+ xfail('quantile'), # checks q via a .item() call
xfail('stft'),
- xfail('symeig'),
- xfail('take'),
+ xfail('symeig'), # would benefit from diag_scatter
xfail('view_as_complex'),
# required rank 4 tensor to use channels_last format
@@ -751,9 +746,6 @@
xfail('nn.functional.soft_margin_loss', ''),
xfail('linalg.norm', 'subgradients_at_zero'),
xfail('nn.functional.binary_cross_entropy_with_logits', ''),
- xfail('linalg.inv'),
- xfail('linalg.tensorinv'),
- xfail('linalg.matrix_power'),
xfail('linalg.norm'),
xfail('linalg.householder_product'),
xfail('tensor_split'),
@@ -761,7 +753,6 @@
xfail('var_mean'),
xfail('as_strided'),
xfail('fill_'),
- xfail('linalg.cholesky'),
xfail('nn.functional.gaussian_nll_loss'),
xfail('std_mean'),
xfail('block_diag'),
@@ -772,7 +763,6 @@
xfail('view_as_complex'),
xfail('prod'),
- xfail('linalg.lu_factor', ''),
skip('nn.functional.dropout2d', ''),
skip('nn.functional.feature_alpha_dropout', 'without_train'),
skip('pca_lowrank', ''),
@@ -823,6 +813,8 @@
xfail('cumprod'),
xfail('lu_solve'),
xfail('linalg.lstsq', 'grad_oriented'),
+ xfail('linalg.cholesky'),
+ xfail('linalg.qr'),
xfail('cross'),
xfail('qr'),
xfail('linalg.pinv'),
@@ -922,16 +914,13 @@
xfail('linalg.householder_product'),
xfail('linalg.lstsq', ''),
xfail('linalg.lstsq', 'grad_oriented'),
- xfail('linalg.inv'),
xfail('linalg.matrix_norm'),
- xfail('linalg.matrix_power'),
xfail('linalg.norm'),
xfail('linalg.pinv'),
xfail('linalg.qr'),
xfail('linalg.pinv', 'hermitian'),
xfail('linalg.slogdet'),
xfail('linalg.solve'),
- xfail('linalg.tensorinv'),
xfail('logdet'),
xfail('lu'),
xfail('lu_solve'),
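
Several of the xfails removed above cover composed transforms. A rough sketch of
the vmap-plus-autodiff pattern those tests exercise, using one of the newly
covered ops; assumes functorch's grad and vmap, and the function and shapes are
illustrative only:

    import torch
    from functorch import grad, vmap

    def f(x):
        # Scalar output so functorch.grad applies; linalg.inv is newly covered.
        return torch.linalg.inv(x).sum()

    a = torch.randn(4, 3, 3)
    batch = a @ a.transpose(-2, -1) + 3 * torch.eye(3)

    # Per-sample gradients of f across the batch of matrices.
    grads = vmap(grad(f))(batch)
    assert grads.shape == batch.shape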