[primTorch] Implement two-dimensional fft transforms (#80736)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80736
Approved by: https://github.com/mruberry
diff --git a/torch/_refs/fft.py b/torch/_refs/fft.py
index 137fa7e..b0fc40e 100644
--- a/torch/_refs/fft.py
+++ b/torch/_refs/fft.py
@@ -29,6 +29,12 @@
     "irfftn",
     "hfftn",
     "ihfftn",
+    "fft2",
+    "ifft2",
+    "rfft2",
+    "irfft2",
+    "hfft2",
+    "ihfft2",
 ]
 
 NormType = Union[None, Literal["forward"], Literal["backward"], Literal["ortho"]]
@@ -474,3 +480,69 @@
     tmp = prims.conj_physical(tmp)
     out = prims.fft_c2r(tmp, dim=dim[-1:], last_dim_size=last_dim_size)
     return _apply_norm(out, norm, last_dim_size, forward=True)
+
+
+@register_decomposition(torch.ops.aten.fft_fft2)
+@out_wrapper()
+def fft2(
+    input: TensorLikeType,
+    s: Optional[ShapeType] = None,
+    dim: Optional[DimsType] = (-2, -1),
+    norm: NormType = None,
+) -> TensorLikeType:
+    return torch.fft.fftn(input, s=s, dim=dim, norm=norm)
+
+
+@register_decomposition(torch.ops.aten.fft_ifft2)
+@out_wrapper()
+def ifft2(
+    input: TensorLikeType,
+    s: Optional[ShapeType] = None,
+    dim: Optional[DimsType] = (-2, -1),
+    norm: NormType = None,
+) -> TensorLikeType:
+    return torch.fft.ifftn(input, s=s, dim=dim, norm=norm)
+
+
+@register_decomposition(torch.ops.aten.fft_rfft2)
+@out_wrapper()
+def rfft2(
+    input: TensorLikeType,
+    s: Optional[ShapeType] = None,
+    dim: Optional[DimsType] = (-2, -1),
+    norm: NormType = None,
+) -> TensorLikeType:
+    return torch.fft.rfftn(input, s=s, dim=dim, norm=norm)
+
+
+@register_decomposition(torch.ops.aten.fft_irfft2)
+@out_wrapper()
+def irfft2(
+    input: TensorLikeType,
+    s: Optional[ShapeType] = None,
+    dim: Optional[DimsType] = (-2, -1),
+    norm: NormType = None,
+) -> TensorLikeType:
+    return torch.fft.irfftn(input, s=s, dim=dim, norm=norm)
+
+
+@register_decomposition(torch.ops.aten.fft_hfft2)
+@out_wrapper()
+def hfft2(
+    input: TensorLikeType,
+    s: Optional[ShapeType] = None,
+    dim: Optional[DimsType] = (-2, -1),
+    norm: NormType = None,
+) -> TensorLikeType:
+    return torch.fft.hfftn(input, s=s, dim=dim, norm=norm)
+
+
+@register_decomposition(torch.ops.aten.fft_ihfft2)
+@out_wrapper()
+def ihfft2(
+    input: TensorLikeType,
+    s: Optional[ShapeType] = None,
+    dim: Optional[DimsType] = (-2, -1),
+    norm: NormType = None,
+) -> TensorLikeType:
+    return torch.fft.ihfftn(input, s=s, dim=dim, norm=norm)
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index defc82a..c27d8b8 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -21518,6 +21518,36 @@
         torch_opinfo_name="fft.ihfftn",
         supports_nvfuser=False,
     ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.fft2",
+        torch_opinfo_name="fft.fft2",
+        supports_nvfuser=False,
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.ifft2",
+        torch_opinfo_name="fft.ifft2",
+        supports_nvfuser=False,
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.rfft2",
+        torch_opinfo_name="fft.rfft2",
+        supports_nvfuser=False,
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.irfft2",
+        torch_opinfo_name="fft.irfft2",
+        supports_nvfuser=False,
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.hfft2",
+        torch_opinfo_name="fft.hfft2",
+        supports_nvfuser=False,
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.ihfft2",
+        torch_opinfo_name="fft.ihfft2",
+        supports_nvfuser=False,
+    ),
 ]
 
 # Common operator groupings