[primTorch] Implement two-dimensional fft transforms (#80736)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80736
Approved by: https://github.com/mruberry
diff --git a/torch/_refs/fft.py b/torch/_refs/fft.py
index 137fa7e..b0fc40e 100644
--- a/torch/_refs/fft.py
+++ b/torch/_refs/fft.py
@@ -29,6 +29,12 @@
"irfftn",
"hfftn",
"ihfftn",
+ "fft2",
+ "ifft2",
+ "rfft2",
+ "irfft2",
+ "hfft2",
+ "ihfft2",
]
NormType = Union[None, Literal["forward"], Literal["backward"], Literal["ortho"]]
@@ -474,3 +480,69 @@
tmp = prims.conj_physical(tmp)
out = prims.fft_c2r(tmp, dim=dim[-1:], last_dim_size=last_dim_size)
return _apply_norm(out, norm, last_dim_size, forward=True)
+
+
+@register_decomposition(torch.ops.aten.fft_fft2)
+@out_wrapper()
+def fft2(
+ input: TensorLikeType,
+ s: Optional[ShapeType] = None,
+ dim: Optional[DimsType] = (-2, -1),
+ norm: NormType = None,
+) -> TensorLikeType:
+ return torch.fft.fftn(input, s=s, dim=dim, norm=norm)
+
+
+@register_decomposition(torch.ops.aten.fft_ifft2)
+@out_wrapper()
+def ifft2(
+ input: TensorLikeType,
+ s: Optional[ShapeType] = None,
+ dim: Optional[DimsType] = (-2, -1),
+ norm: NormType = None,
+) -> TensorLikeType:
+ return torch.fft.ifftn(input, s=s, dim=dim, norm=norm)
+
+
+@register_decomposition(torch.ops.aten.fft_rfft2)
+@out_wrapper()
+def rfft2(
+ input: TensorLikeType,
+ s: Optional[ShapeType] = None,
+ dim: Optional[DimsType] = (-2, -1),
+ norm: NormType = None,
+) -> TensorLikeType:
+ return torch.fft.rfftn(input, s=s, dim=dim, norm=norm)
+
+
+@register_decomposition(torch.ops.aten.fft_irfft2)
+@out_wrapper()
+def irfft2(
+ input: TensorLikeType,
+ s: Optional[ShapeType] = None,
+ dim: Optional[DimsType] = (-2, -1),
+ norm: NormType = None,
+) -> TensorLikeType:
+ return torch.fft.irfftn(input, s=s, dim=dim, norm=norm)
+
+
+@register_decomposition(torch.ops.aten.fft_hfft2)
+@out_wrapper()
+def hfft2(
+ input: TensorLikeType,
+ s: Optional[ShapeType] = None,
+ dim: Optional[DimsType] = (-2, -1),
+ norm: NormType = None,
+) -> TensorLikeType:
+ return torch.fft.hfftn(input, s=s, dim=dim, norm=norm)
+
+
+@register_decomposition(torch.ops.aten.fft_ihfft2)
+@out_wrapper()
+def ihfft2(
+ input: TensorLikeType,
+ s: Optional[ShapeType] = None,
+ dim: Optional[DimsType] = (-2, -1),
+ norm: NormType = None,
+) -> TensorLikeType:
+ return torch.fft.ihfftn(input, s=s, dim=dim, norm=norm)
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index defc82a..c27d8b8 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -21518,6 +21518,36 @@
torch_opinfo_name="fft.ihfftn",
supports_nvfuser=False,
),
+ SpectralFuncPythonRefInfo(
+ "_refs.fft.fft2",
+ torch_opinfo_name="fft.fft2",
+ supports_nvfuser=False,
+ ),
+ SpectralFuncPythonRefInfo(
+ "_refs.fft.ifft2",
+ torch_opinfo_name="fft.ifft2",
+ supports_nvfuser=False,
+ ),
+ SpectralFuncPythonRefInfo(
+ "_refs.fft.rfft2",
+ torch_opinfo_name="fft.rfft2",
+ supports_nvfuser=False,
+ ),
+ SpectralFuncPythonRefInfo(
+ "_refs.fft.irfft2",
+ torch_opinfo_name="fft.irfft2",
+ supports_nvfuser=False,
+ ),
+ SpectralFuncPythonRefInfo(
+ "_refs.fft.hfft2",
+ torch_opinfo_name="fft.hfft2",
+ supports_nvfuser=False,
+ ),
+ SpectralFuncPythonRefInfo(
+ "_refs.fft.ihfft2",
+ torch_opinfo_name="fft.ihfft2",
+ supports_nvfuser=False,
+ ),
]
# Common operator groupings