Add OpInfo tests for torch.{dot, vdot, bmm, mv} (#56409)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/56409

Reviewed By: nikithamalgifb

Differential Revision: D27870769

Pulled By: anjali411

fbshipit-source-id: a1a0e89856529a4739c7612c5b1e3c5ed2569126
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 2a2cf06..99f6c8b 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -5385,10 +5385,8 @@
                 'expand', 'rot90', 'transpose',
                 'permute', 'squeeze', 'unsqueeze', 'resize', 'resize_as', 'tril', 'triu',
                 'chunk', 'split', 'split_with_sizes', 'zero_',
-                '__radd__', 'sum', 'mul',
-                '__rmul__', 'dot', 'vdot', 'matmul',
-                'bmm', 'mv', 'ger', 'diagonal', 'fill_', 'sub',
-                'mean', 'inverse', 'linalg.tensorinv', 'matrix_exp',
+                '__radd__', 'mul', '__rmul__', 'matmul',
+                'diagonal', 'fill_', 'sub',
                 'narrow', 'swapaxes', 'swapdims', 'tensor_split',
                 'baddbmm'] + complex_list_filter + separate_complex_tests
 
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 9752960..80b3384 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -658,6 +658,36 @@
     else:
         return (input, )
 
+def sample_inputs_mv(self, device, dtype, requires_grad, **kwargs):
+    return (
+        SampleInput(
+            make_tensor((S, M, ), device, dtype, low=None, high=None, requires_grad=requires_grad),
+            args=(
+                make_tensor((M, ), device, dtype, low=None, high=None, requires_grad=requires_grad),
+            )
+        ),
+    )
+
+def sample_inputs_bmm(self, device, dtype, requires_grad, **kwargs):
+    return (
+        SampleInput(
+            make_tensor((M, S, M, ), device, dtype, low=None, high=None, requires_grad=requires_grad),
+            args=(
+                make_tensor((M, M, S, ), device, dtype, low=None, high=None, requires_grad=requires_grad),
+            )
+        ),
+    )
+
+def sample_inputs_dot_vdot(self, device, dtype, requires_grad, **kwargs):
+    return (
+        SampleInput(
+            make_tensor((S, ), device, dtype, low=None, high=None, requires_grad=requires_grad),
+            args=(
+                make_tensor((S, ), device, dtype, low=None, high=None, requires_grad=requires_grad),
+            )
+        ),
+    )
+
 def sample_inputs_addmv(op_info, device, dtype, requires_grad, **kwargs):
     test_cases = (((S,), (S, M), (M,), 1, 1, False),
                   ((S,), (S, M), (M,), 0.2, 0.6, False),
@@ -3047,8 +3077,7 @@
     OpInfo('addbmm',
            dtypes=floating_types(),
            dtypesIfCPU=all_types_and_complex_and(torch.float16, torch.bfloat16),
-           dtypesIfCUDA=floating_types_and(torch.float16, torch.complex64, torch.complex128,
-                                           *[torch.bfloat16] if CUDA11OrLater else []),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []),
            dtypesIfROCM=floating_types_and(torch.half),
            skips=(
                # addbmm does not correctly warn when resizing out= inputs
@@ -3061,6 +3090,50 @@
                SkipInfo('TestOpInfo', 'test_supported_backward', dtypes=(torch.bfloat16, ),
                         device_type='cuda', active_if=not SM53OrLater)),
            sample_inputs_func=sample_inputs_addbmm),
+    OpInfo('dot',
+           dtypes=all_types_and_complex_and(torch.float16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16),
+           skips=(
+               # dot does not handle correctly out= dtypes
+               # https://github.com/pytorch/pytorch/issues/55561
+               SkipInfo('TestCommon', 'test_out'),
+           ),
+           assert_autodiffed=True,
+           sample_inputs_func=sample_inputs_dot_vdot),
+    OpInfo('vdot',
+           dtypes=all_types_and_complex_and(torch.float16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16),
+           skips=(
+               # vdot does not handle correctly out= dtypes
+               # https://github.com/pytorch/pytorch/issues/55561
+               SkipInfo('TestCommon', 'test_out'),
+           ),
+           sample_inputs_func=sample_inputs_dot_vdot),
+    OpInfo('bmm',
+           dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []),
+           assert_autodiffed=True,
+           skips=(
+               # bmm does not correctly warn when resizing out= inputs
+               SkipInfo('TestCommon', 'test_out'),
+               # cuda gradchecks are slow
+               # see discussion https://github.com/pytorch/pytorch/pull/47761#issuecomment-747316775
+               SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'),
+               SkipInfo('TestOpInfo', 'test_supported_backward', dtypes=(torch.bfloat16, ),
+                        device_type='cuda', active_if=not SM53OrLater)),
+           sample_inputs_func=sample_inputs_bmm),
+    OpInfo('mv',
+           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []),
+           skips=(
+               # bmm does not correctly warn when resizing out= inputs
+               SkipInfo('TestCommon', 'test_out'),
+               SkipInfo('TestOpInfo', 'test_supported_backward', dtypes=(torch.float16,)),
+               # mv calls into addmv which doesn't fully support float16
+               # RuntimeError: "addmv_impl_cpu" not implemented for 'Half'
+               SkipInfo('TestOpInfo', 'test_supported_dtypes', dtypes=(torch.float16,)),),
+           assert_autodiffed=True,
+           sample_inputs_func=sample_inputs_mv),
     OpInfo('addr',
            dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
            # Reference: https://github.com/pytorch/pytorch/issues/50747
@@ -5270,10 +5343,6 @@
         ('baddbmm', (), ((S, S, S), (S, S, M)), 'scalar_broadcast_lhs'),
         ('baddbmm', (), ((S, S, S), (S, S, M)), 'scalar_broadcast_lhs_coef', (), (), (), ident,
          {'beta': 0.2, 'alpha': 0.6}),
-        ('dot', (L,), ((L,),), '', (True,)),
-        ('vdot', (L,), ((L,),),),
-        ('bmm', (M, S, M), ((M, M, S),), '', (True,)),
-        ('mv', (S, M), ((M,),), '', (True,)),
         ('mvlgamma', torch.empty(S,).uniform_(0.5, 1), [1], "p=1"),
         ('mvlgamma', torch.empty(S,).uniform_(1, 2), [2], "p=2"),
         ('mvlgamma', torch.empty(S, S).uniform_(1.5, 3), [3], "p=3"),