| #include <ATen/ATen.h> |
| #include <ATen/Config.h> |
| #include <ATen/NativeFunctions.h> |
| #include <ATen/detail/CUDAHooksInterface.h> |
| #include <ATen/native/SpectralOpsUtils.h> |
| #include <ATen/native/TensorIterator.h> |
| |
| #include <algorithm> |
| #include <vector> |
| #include <cmath> |
| |
| namespace at { namespace native { |
| |
| namespace { |
| |
| // Promote inputs to FFT functions |
| // * Integers are promoted to the default floating type |
| // * If require_complex=True, all types are promoted to complex |
| // * Half-precision dtypes are rejected with an error for now, leaving room for future support |
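| // For instance: promote_type_fft(kLong, /*require_complex=*/false) yields the |
| // default floating dtype (typically kFloat), and |
| // promote_type_fft(kDouble, /*require_complex=*/true) yields kComplexDouble. |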
| ScalarType promote_type_fft(ScalarType type, bool require_complex) { |
| if (at::isComplexType(type)) { |
| return type; |
| } |
| // Promote integral to default float type |
| if (!at::isFloatingType(type)) { |
| type = c10::typeMetaToScalarType(c10::get_default_dtype()); |
| } |
| |
| TORCH_CHECK(type == kFloat || type == kDouble, "Unsupported dtype ", type); |
| |
| if (!require_complex) { |
| return type; |
| } |
| |
| // Promote to complex |
| switch (type) { |
| case kFloat: return kComplexFloat; |
| case kDouble: return kComplexDouble; |
| default: TORCH_INTERNAL_ASSERT(false, "Unhandled dtype"); |
| } |
| } |
| |
| // Promote a tensor's dtype according to promote_type_fft |
| Tensor promote_tensor_fft(const Tensor& t, bool require_complex=false) { |
| auto cur_type = t.scalar_type(); |
| auto new_type = promote_type_fft(cur_type, require_complex); |
| return (cur_type == new_type) ? t : t.to(new_type); |
| } |
| |
| // Convert NumPy compatible normalization mode string to enum values |
| // NOTE: NumPy's normalization modes have direction-specific meanings. For example, |
| // "forward" translates to `by_n` for a forward transform and `none` for backward. |
| fft_norm_mode norm_from_string(c10::optional<std::string> norm, bool forward) { |
| if (!norm || *norm == "backward") { |
| return forward ? fft_norm_mode::none : fft_norm_mode::by_n; |
| } |
| |
| if (*norm == "forward") { |
| return forward ? fft_norm_mode::by_n : fft_norm_mode::none; |
| } |
| |
| if (*norm == "ortho") { |
| return fft_norm_mode::by_root_n; |
| } |
| |
| TORCH_CHECK(false, "Invalid normalization mode: \"", *norm, "\""); |
| } |
| |
| // Fixes the shape of x such that x.size(dims[i]) == sizes[i], |
| // either by zero-padding, or by slicing x starting from 0. |
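| // e.g. for x of size 5 along dim 0, sizes = {8} zero-pads the tail up to 8, |
| // while sizes = {3} slices down to x[0:3]. |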
| Tensor resize_fft_input(Tensor x, IntArrayRef dims, IntArrayRef sizes) { |
| TORCH_INTERNAL_ASSERT(dims.size() == sizes.size()); |
| bool must_copy = false; |
| auto x_sizes = x.sizes(); |
| DimVector pad_amount(x_sizes.size() * 2); |
| for (int64_t i = 0; i < dims.size(); ++i) { |
| if (sizes[i] == -1) { |
| continue; |
| } |
| |
| if (x_sizes[dims[i]] < sizes[i]) { |
| must_copy = true; |
| auto pad_idx = pad_amount.size() - 2 * dims[i] - 1; |
| pad_amount[pad_idx] = sizes[i] - x_sizes[dims[i]]; |
| } |
| |
| if (x_sizes[dims[i]] > sizes[i]) { |
| x = x.slice(dims[i], 0, sizes[i]); |
| } |
| } |
| |
| // Only call pad if necessary since pad copies the entire tensor |
| return must_copy ? at::constant_pad_nd(x, pad_amount) : x; |
| } |
| |
| // Complex to real FFT |
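| // Shape note: a complex input of length n/2 + 1 along `dim` yields a real |
| // output of length n; when n_opt is unset, n defaults to 2*(input.size(dim) - 1). |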
| Tensor fft_c2r(c10::string_view function_name, |
| Tensor out, Tensor input, c10::optional<int64_t> n_opt, |
| int64_t unwrapped_dim, c10::optional<std::string> norm_str, |
| bool forward) { |
| TORCH_CHECK(!out.defined() || out.is_floating_point(), function_name, |
| " expects a floating point output tensor, but got ", out.scalar_type()); |
| input = promote_tensor_fft(input, /*require_complex=*/true); |
| const auto input_dim = input.dim(); |
| const auto dim = maybe_wrap_dim(unwrapped_dim, input_dim); |
| const auto n = n_opt.value_or(2*(input.sizes()[dim] - 1)); |
| TORCH_CHECK(n >= 1, "Invalid number of data points (", n, ") specified"); |
| if (n_opt) { |
| input = resize_fft_input(input, dim, n/2 + 1); |
| } |
| const auto norm = norm_from_string(norm_str, forward); |
| if (forward) { |
| // FIXME: _fft does not support complex_output=false with inverse=false |
| input = at::conj(input); |
| } |
| if (out.defined()) { |
| return at::_fft_c2r_out(out, input, dim, static_cast<int64_t>(norm), n); |
| } else { |
| return at::_fft_c2r(input, dim, static_cast<int64_t>(norm), n); |
| } |
| } |
| |
| // Real to complex FFT |
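| // Shape note: with onesided=true only the non-redundant half of the spectrum |
| // is returned, i.e. n/2 + 1 complex values for a length-n real input. |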
| Tensor fft_r2c(c10::string_view function_name, |
| Tensor out, Tensor input, c10::optional<int64_t> n_opt, |
| int64_t unwrapped_dim, c10::optional<std::string> norm_str, |
| bool forward, bool onesided) { |
| TORCH_CHECK(!input.is_complex(), function_name, |
| " expects a real input tensor, but got ", input.scalar_type()); |
| TORCH_CHECK(!out.defined() || out.is_complex(), function_name, |
| " expects a complex output tensor, but got ", out.scalar_type()); |
| input = promote_tensor_fft(input); |
| const auto input_dim = input.dim(); |
| const auto dim = maybe_wrap_dim(unwrapped_dim, input_dim); |
| const auto n = n_opt.value_or(input.sizes()[dim]); |
| TORCH_CHECK(n >= 1, "Invalid number of data points (", n, ") specified"); |
| if (n_opt) { |
| input = resize_fft_input(input, dim, n); |
| } |
| |
| const auto norm = norm_from_string(norm_str, forward); |
| |
| Tensor ret; |
| if (out.defined() && forward) { |
| ret = at::_fft_r2c_out(out, input, dim, static_cast<int64_t>(norm), onesided); |
| } else { |
| ret = at::_fft_r2c(input, dim, static_cast<int64_t>(norm), onesided); |
| } |
| |
| if (!forward) { |
| // FIXME: _fft_r2c doesn't support native r2c IFFT |
| return out.defined() ? at::conj_out(out, ret) : at::conj(ret); |
| } else { |
| return ret; |
| } |
| } |
| |
| // Complex to complex FFT |
| Tensor fft_c2c(c10::string_view function_name, |
| Tensor out, Tensor input, c10::optional<int64_t> n_opt, |
| int64_t unwrapped_dim, c10::optional<std::string> norm_str, |
| bool forward) { |
| TORCH_CHECK(input.is_complex(), function_name, |
| " expects a complex input tensor, but got ", input.scalar_type()); |
| const auto input_dim = input.dim(); |
| const auto dim = maybe_wrap_dim(unwrapped_dim, input_dim); |
| const auto n = n_opt.value_or(input.sizes()[dim]); |
| TORCH_CHECK(n >= 1, "Invalid number of data points (", n, ") specified"); |
| if (n_opt) { |
| input = resize_fft_input(input, dim, n); |
| } |
| const auto norm = norm_from_string(norm_str, forward); |
| if (out.defined()) { |
| TORCH_CHECK(out.is_complex(), function_name, |
| " expects a complex output tensor, but got ", out.scalar_type()); |
| return at::_fft_c2c_out(out, input, dim, static_cast<int64_t>(norm), forward); |
| } else { |
| return at::_fft_c2c(input, dim, static_cast<int64_t>(norm), forward); |
| } |
| } |
| |
| // Dimensions to transform, and the signal shape in those dimensions |
| struct ShapeAndDims { |
| DimVector shape, dim; |
| }; |
| |
| // Pre-process n-dimensional fft's `s` and `dim` arguments. |
| // Wraps dimensions and applies defaulting behavior. |
| // Also checks transform dims are unique and transform shape is non-empty. |
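| // Illustrative example: for an input of shape (4, 5, 6), shape = {8, -1} and |
| // no dim returns dim = {1, 2} (the last two dims) and shape = {8, 6}, since |
| // -1 keeps the input size. |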
| ShapeAndDims canonicalize_fft_shape_and_dim_args( |
| Tensor input, c10::optional<IntArrayRef> shape, c10::optional<IntArrayRef> dim) { |
| const int64_t input_dim = input.dim(); |
| const IntArrayRef input_sizes = input.sizes(); |
| ShapeAndDims ret; |
| |
| if (dim) { |
| ret.dim.resize(dim->size()); |
| std::copy(dim->begin(), dim->end(), ret.dim.begin()); |
| maybe_wrap_dims(ret.dim, input_dim); |
| |
| // Check dims are unique |
| DimVector copy = ret.dim; |
| std::sort(copy.begin(), copy.end()); |
| auto duplicate = std::adjacent_find(copy.begin(), copy.end()); |
| TORCH_CHECK(duplicate == copy.end(), "FFT dims must be unique"); |
| } |
| |
| if (shape) { |
| // Has shape, may have dim |
| TORCH_CHECK(!dim || dim->size() == shape->size(), |
| "When given, dim and shape arguments must have the same length"); |
| TORCH_CHECK(shape->size() <= input_dim, |
| "Got shape with ", shape->size(), " values but input tensor " |
| "only has ", input_dim, " dimensions."); |
| const int64_t transform_ndim = shape->size(); |
| // If shape is given, dims defaults to the last shape.size() dimensions |
| if (!dim) { |
| ret.dim.resize(transform_ndim); |
| std::iota(ret.dim.begin(), ret.dim.end(), input_dim - transform_ndim); |
| } |
| |
| // Translate shape of -1 to the default length |
| ret.shape.resize(transform_ndim); |
| for (int64_t i = 0; i < transform_ndim; ++i) { |
| const auto n = (*shape)[i]; |
| ret.shape[i] = n == -1 ? input_sizes[ret.dim[i]] : n; |
| } |
| } else if (!dim) { |
| // No shape, no dim |
| ret.dim.resize(input_dim); |
| std::iota(ret.dim.begin(), ret.dim.end(), int64_t{0}); |
| ret.shape.resize(input_dim); |
| std::copy(input_sizes.begin(), input_sizes.end(), ret.shape.begin()); |
| } else { |
| // No shape, has dim |
| ret.shape.resize(ret.dim.size()); |
| for (int64_t i = 0; i < ret.dim.size(); ++i) { |
| ret.shape[i] = input_sizes[ret.dim[i]]; |
| } |
| } |
| |
| for (int64_t i = 0; i < ret.shape.size(); ++i) { |
| TORCH_CHECK(ret.shape[i] > 0, |
| "Invalid number of data points (", ret.shape[i], ") specified"); |
| } |
| |
| return ret; |
| } |
| |
| // Complex to complex n-dimensional fft |
| Tensor fftn_c2c( |
| c10::string_view function_name, |
| Tensor out, const Tensor& input, IntArrayRef shape, |
| IntArrayRef dim, c10::optional<std::string> norm_str, bool forward) { |
| TORCH_CHECK(input.is_complex(), function_name, " expects a complex input tensor, but got ", input.scalar_type()); |
| Tensor x = resize_fft_input(input, dim, shape); |
| const auto norm = norm_from_string(norm_str, forward); |
| if (out.defined()) { |
| TORCH_CHECK(out.is_complex(), function_name, " expects a complex output tensor, but got ", out.scalar_type()); |
| return at::_fft_c2c_out(out, x, dim, static_cast<int64_t>(norm), forward); |
| } else { |
| return at::_fft_c2c(x, dim, static_cast<int64_t>(norm), forward); |
| } |
| } |
| |
| } // namespace (anonymous) |
| |
| // torch.fft.fft, analogous to NumPy's numpy.fft.fft |
| Tensor fft_fft(const Tensor& self, c10::optional<int64_t> n, int64_t dim, |
| c10::optional<std::string> norm) { |
| return self.is_complex() ? |
| fft_c2c("fft", {}, self, n, dim, norm, /*forward=*/true) : |
| fft_r2c("fft", {}, self, n, dim, norm, /*forward=*/true, /*onesided=*/false); |
| } |
| |
| Tensor& fft_fft_out(const Tensor& self, c10::optional<int64_t> n, |
| int64_t dim, c10::optional<std::string> norm, Tensor& out) { |
| if (self.is_complex()) { |
| fft_c2c("fft", out, self, n, dim, norm, /*forward=*/true); |
| } else { |
| fft_r2c("fft", out, self, n, dim, norm, /*forward=*/true, /*onesided=*/false); |
| } |
| return out; |
| } |
| |
| Tensor fft_ifft(const Tensor& self, c10::optional<int64_t> n, int64_t dim, |
| c10::optional<std::string> norm) { |
| return self.is_complex() ? |
| fft_c2c("ifft", {}, self, n, dim, norm, /*forward=*/false) : |
| fft_r2c("ifft", {}, self, n, dim, norm, /*forward=*/false, /*onesided=*/false); |
| } |
| |
| Tensor& fft_ifft_out(const Tensor& self, c10::optional<int64_t> n, |
| int64_t dim, c10::optional<std::string> norm, Tensor& out) { |
| if (self.is_complex()) { |
| fft_c2c("ifft", out, self, n, dim, norm, /*forward=*/false); |
| } else { |
| fft_r2c("ifft", out, self, n, dim, norm, /*forward=*/false, /*onesided=*/false); |
| } |
| return out; |
| } |
| |
| Tensor fft_rfft(const Tensor& self, c10::optional<int64_t> n, int64_t dim, |
| c10::optional<std::string> norm) { |
| return fft_r2c("rfft", {}, self, n, dim, norm, /*forward=*/true, /*onesided=*/true); |
| } |
| |
| Tensor& fft_rfft_out(const Tensor& self, c10::optional<int64_t> n, |
| int64_t dim, c10::optional<std::string> norm, Tensor& out) { |
| fft_r2c("rfft", out, self, n, dim, norm, /*forward=*/true, /*onesided=*/true); |
| return out; |
| } |
| |
| Tensor fft_irfft(const Tensor& self, c10::optional<int64_t> n, int64_t dim, |
| c10::optional<std::string> norm) { |
| return fft_c2r("irfft", {}, self, n, dim, norm, /*forward=*/false); |
| } |
| |
| Tensor& fft_irfft_out(const Tensor& self, c10::optional<int64_t> n, |
| int64_t dim, c10::optional<std::string> norm, Tensor& out) { |
| fft_c2r("irfft", out, self, n, dim, norm, /*forward=*/false); |
| return out; |
| } |
| |
| Tensor fft_hfft(const Tensor& self, c10::optional<int64_t> n, int64_t dim, |
| c10::optional<std::string> norm) { |
| return fft_c2r("hfft", {}, self, n, dim, norm, /*forward=*/true); |
| } |
| |
| Tensor& fft_hfft_out(const Tensor& self, c10::optional<int64_t> n, |
| int64_t dim, c10::optional<std::string> norm, Tensor& out) { |
| fft_c2r("hfft", out, self, n, dim, norm, /*forward=*/true); |
| return out; |
| } |
| |
| Tensor fft_ihfft(const Tensor& self, c10::optional<int64_t> n, int64_t dim, |
| c10::optional<std::string> norm) { |
| return fft_r2c("ihfft", {}, self, n, dim, norm, /*forward=*/false, /*onesided=*/true); |
| } |
| |
| Tensor& fft_ihfft_out(const Tensor& self, c10::optional<int64_t> n, |
| int64_t dim, c10::optional<std::string> norm, Tensor& out) { |
| fft_r2c("ihfft", out, self, n, dim, norm, /*forward=*/false, /*onesided=*/true); |
| return out; |
| } |
| |
| Tensor fft_fftn(const Tensor& self, c10::optional<IntArrayRef> s, |
| c10::optional<IntArrayRef> dim, |
| c10::optional<std::string> norm) { |
| auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim); |
| // TODO: For real input, perform rfftn then mirror with conjugate symmetry |
| Tensor input = promote_tensor_fft(self, /*require_complex=*/true); |
| return fftn_c2c("fftn", {}, input, desc.shape, desc.dim, norm, /*forward=*/true); |
| } |
| |
| Tensor& fft_fftn_out(const Tensor& self, |
| c10::optional<IntArrayRef> s, |
| c10::optional<IntArrayRef> dim, |
| c10::optional<std::string> norm, Tensor& out) { |
| auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim); |
| // TODO: For real input, perform rfftn then mirror with conjugate symmetry |
| Tensor input = promote_tensor_fft(self, /*require_complex=*/true); |
| fftn_c2c("fftn", out, input, desc.shape, desc.dim, norm, /*forward=*/true); |
| return out; |
| } |
| |
| Tensor fft_ifftn(const Tensor& self, c10::optional<IntArrayRef> s, |
| c10::optional<IntArrayRef> dim, |
| c10::optional<std::string> norm) { |
| auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim); |
| Tensor input = promote_tensor_fft(self, /*require_complex=*/true); |
| return fftn_c2c("ifftn", {}, input, desc.shape, desc.dim, norm, /*forward=*/false); |
| } |
| |
| Tensor& fft_ifftn_out(const Tensor& self, |
| c10::optional<IntArrayRef> s, |
| c10::optional<IntArrayRef> dim, |
| c10::optional<std::string> norm, Tensor& out) { |
| auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim); |
| Tensor input = promote_tensor_fft(self, /*require_complex=*/true); |
| fftn_c2c("ifftn", out, input, desc.shape, desc.dim, norm, /*forward=*/false); |
| return out; |
| } |
| |
| static Tensor fft_rfftn_impl(Tensor out, const Tensor& self, |
| c10::optional<IntArrayRef> s, |
| c10::optional<IntArrayRef> dim, |
| const c10::optional<std::string>& norm_str) { |
| TORCH_CHECK(!self.is_complex(), "rfftn expects a real-valued input tensor, but got ", self.scalar_type()); |
| auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim); |
| TORCH_CHECK(desc.shape.size() > 0, "rfftn must transform at least one axis"); |
| Tensor input = promote_tensor_fft(self, /*require_complex=*/false); |
| Tensor x = resize_fft_input(input, desc.dim, desc.shape); |
| const auto norm = norm_from_string(norm_str, /*forward=*/true); |
| if (out.defined()) { |
| TORCH_CHECK(out.is_complex(), "rfftn expects a complex-valued output tensor, but got ", out.scalar_type()); |
| return at::_fft_r2c_out(out, x, desc.dim, static_cast<int64_t>(norm), /*onesided=*/true); |
| } else { |
| return at::_fft_r2c(x, desc.dim, static_cast<int64_t>(norm), /*onesided=*/true); |
| } |
| } |
| |
| Tensor fft_rfftn(const Tensor& self, c10::optional<IntArrayRef> s, |
| c10::optional<IntArrayRef> dim, |
| c10::optional<std::string> norm_str) { |
| return fft_rfftn_impl({}, self, s, dim, norm_str); |
| } |
| |
| Tensor& fft_rfftn_out(const Tensor& self, |
| c10::optional<IntArrayRef> s, |
| c10::optional<IntArrayRef> dim, |
| c10::optional<std::string> norm_str, Tensor& out) { |
| fft_rfftn_impl(out, self, s, dim, norm_str); |
| return out; |
| } |
| |
| static Tensor fft_irfftn_impl(Tensor out, const Tensor& self, |
| c10::optional<IntArrayRef> s, |
| c10::optional<IntArrayRef> dim, |
| const c10::optional<std::string>& norm_str) { |
| auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim); |
| TORCH_CHECK(desc.shape.size() > 0, "irfftn must transform at least one axis"); |
| |
| const auto last_dim_size = [&] { |
| // Fix up default shape handling in the last dimension. |
| if (!s.has_value() || (s->back() == -1)) { |
| const auto last_dim = desc.dim.back(); |
| return 2 * (self.sizes()[last_dim] - 1); |
| } |
| return desc.shape.back(); |
| }(); |
| desc.shape.back() = last_dim_size / 2 + 1; |
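| // e.g. a default transform whose last dim has complex size m targets |
| // last_dim_size = 2*(m - 1) real outputs, so the input is resized to the |
| // onesided length last_dim_size/2 + 1 before the c2r transform. |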
| |
| Tensor input = promote_tensor_fft(self, /*require_complex=*/true); |
| Tensor x = resize_fft_input(input, desc.dim, desc.shape); |
| const auto norm = norm_from_string(norm_str, /*forward=*/false); |
| if (out.defined()) { |
| TORCH_CHECK(out.is_floating_point(), "irfftn expects a floating point output tensor, but got ", out.scalar_type()); |
| return at::_fft_c2r_out(out, x, desc.dim, static_cast<int64_t>(norm), last_dim_size); |
| } else { |
| return at::_fft_c2r(x, desc.dim, static_cast<int64_t>(norm), last_dim_size); |
| } |
| } |
| |
| Tensor fft_irfftn(const Tensor& self, |
| c10::optional<IntArrayRef> s, |
| c10::optional<IntArrayRef> dim, |
| c10::optional<std::string> norm_str) { |
| return fft_irfftn_impl({}, self, s, dim, norm_str); |
| } |
| |
| Tensor& fft_irfftn_out(const Tensor& self, |
| c10::optional<IntArrayRef> s, |
| c10::optional<IntArrayRef> dim, |
| c10::optional<std::string> norm_str, Tensor& out) { |
| fft_irfftn_impl(out, self, s, dim, norm_str); |
| return out; |
| } |
| |
| Tensor fft_fft2(const Tensor& self, c10::optional<IntArrayRef> s, |
| IntArrayRef dim, c10::optional<std::string> norm) { |
| return native::fft_fftn(self, s, dim, std::move(norm)); |
| } |
| |
| Tensor& fft_fft2_out(const Tensor& self, c10::optional<IntArrayRef> s, |
| IntArrayRef dim, c10::optional<std::string> norm, Tensor& out) { |
| return native::fft_fftn_out(self, s, dim, std::move(norm), out); |
| } |
| |
| Tensor fft_ifft2(const Tensor& self, c10::optional<IntArrayRef> s, |
| IntArrayRef dim, c10::optional<std::string> norm) { |
| return native::fft_ifftn(self, s, dim, std::move(norm)); |
| } |
| |
| Tensor& fft_ifft2_out(const Tensor& self, c10::optional<IntArrayRef> s, |
| IntArrayRef dim, c10::optional<std::string> norm, Tensor& out) { |
| return native::fft_ifftn_out(self, s, dim, std::move(norm), out); |
| } |
| |
| Tensor fft_rfft2(const Tensor& self, c10::optional<IntArrayRef> s, |
| IntArrayRef dim, c10::optional<std::string> norm) { |
| return native::fft_rfftn(self, s, dim, std::move(norm)); |
| } |
| |
| Tensor& fft_rfft2_out(const Tensor& self, c10::optional<IntArrayRef> s, |
| IntArrayRef dim, c10::optional<std::string> norm, Tensor& out) { |
| return native::fft_rfftn_out(self, s, dim, std::move(norm), out); |
| } |
| |
| Tensor fft_irfft2(const Tensor& self, c10::optional<IntArrayRef> s, |
| IntArrayRef dim, c10::optional<std::string> norm) { |
| return native::fft_irfftn(self, s, dim, std::move(norm)); |
| } |
| |
| Tensor& fft_irfft2_out(const Tensor& self, c10::optional<IntArrayRef> s, |
| IntArrayRef dim, c10::optional<std::string> norm, Tensor& out) { |
| return native::fft_irfftn_out(self, s, dim, std::move(norm), out); |
| } |
| |
| Tensor& fft_fftfreq_out(int64_t n, double d, Tensor& out) { |
| ScalarType dtype = out.scalar_type(); |
| TORCH_CHECK(at::isFloatingType(dtype) || at::isComplexType(dtype), |
| "fftfreq requires a floating point or complex dtype"); |
| // TODO: arange doesn't have complex support |
| at::arange_out(out, n); |
| // Fill the trailing elements with the negative frequencies [-(n/2), 0); |
| // the slice must run to the end of the tensor, not to index 0 |
| auto right_slice = out.slice(0, (n + 1) / 2); |
| at::arange_out(right_slice, -(n/2), 0, 1); |
| return out.mul_(1.0 / (n * d)); // Slightly faster than div_(n*d) |
| } |
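| // Illustrative example: fftfreq(5, 1.0) produces [0, 1, 2, -2, -1] / 5, |
| // positive frequencies first, then the negative ones. |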
| |
| Tensor fft_fftfreq(int64_t n, double d, |
| c10::optional<ScalarType> dtype, |
| c10::optional<Layout> layout, |
| c10::optional<Device> device, |
| c10::optional<bool> pin_memory) { |
| // See [Note: hacky wrapper removal for TensorOptions] |
| TensorOptions options = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory); |
| |
| auto out = at::empty({n}, options); |
| return native::fft_fftfreq_out(n, d, out); |
| } |
| |
| Tensor& fft_rfftfreq_out(int64_t n, double d, Tensor& out) { |
| ScalarType dtype = out.scalar_type(); |
| TORCH_CHECK(at::isFloatingType(dtype) || at::isComplexType(dtype), |
| "rfftfreq requires a floating point or complex dtype"); |
| // TODO: arange doesn't have complex support |
| native::arange_out(n/2 + 1, out); |
| return out.mul_(1.0 / (n * d)); // Slightly faster than div_(n*d) |
| } |
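| // Illustrative example: rfftfreq(5, 1.0) produces [0, 1, 2] / 5, the |
| // frequencies of the corresponding onesided rfft output. |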
| |
| Tensor fft_rfftfreq(int64_t n, double d, |
| c10::optional<ScalarType> dtype, |
| c10::optional<Layout> layout, |
| c10::optional<Device> device, |
| c10::optional<bool> pin_memory) { |
| // See [Note: hacky wrapper removal for TensorOptions] |
| TensorOptions options = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory); |
| |
| auto out = at::empty({n/2 + 1}, options); |
| return native::fft_rfftfreq_out(n, d, out); |
| } |
| |
| // If a dim array is given, wraps each entry according to self.dim(). |
| // Otherwise returns a vector of all dims. |
| DimVector default_alldims(const Tensor& self, c10::optional<IntArrayRef> dim_opt) { |
| DimVector dim; |
| if (dim_opt) { |
| IntArrayRef dim_unwrapped = *dim_opt; |
| dim.resize(dim_unwrapped.size()); |
| for (int64_t i = 0; i < dim.size(); ++i) { |
| dim[i] = maybe_wrap_dim(dim_unwrapped[i], self.dim()); |
| } |
| } else { |
| dim.resize(self.dim()); |
| std::iota(dim.begin(), dim.end(), 0); |
| } |
| return dim; |
| } |
| |
| Tensor fft_fftshift(const Tensor& x, c10::optional<IntArrayRef> dim_opt) { |
| auto dim = default_alldims(x, dim_opt); |
| |
| IntArrayRef x_sizes = x.sizes(); |
| DimVector shift(dim.size()); |
| for (int64_t i = 0; i < dim.size(); ++i) { |
| shift[i] = x_sizes[dim[i]] / 2; |
| } |
| |
| return at::roll(x, shift, dim); |
| } |
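| // Illustrative example: fftshift on [0, 1, 2, -2, -1] rolls by 5/2 = 2 and |
| // gives [-2, -1, 0, 1, 2], i.e. natural frequency order. |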
| |
| Tensor fft_ifftshift(const Tensor& x, c10::optional<IntArrayRef> dim_opt) { |
| auto dim = default_alldims(x, dim_opt); |
| |
| IntArrayRef x_sizes = x.sizes(); |
| DimVector shift(dim.size()); |
| for (int64_t i = 0; i < dim.size(); ++i) { |
| shift[i] = (x_sizes[dim[i]] + 1) / 2; |
| } |
| |
| return at::roll(x, shift, dim); |
| } |
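| // Illustrative example: ifftshift on [-2, -1, 0, 1, 2] rolls by (5 + 1)/2 = 3 |
| // and restores [0, 1, 2, -2, -1], undoing fftshift for odd lengths too. |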
| |
| |
| // We call the following methods via CUDA hooks because they are really only |
| // valid when CUDA is available. See native/cuda/CuFFTPlanCache.h for more details. |
| int64_t _cufft_get_plan_cache_max_size(int64_t device_index) { |
| return detail::getCUDAHooks().cuFFTGetPlanCacheMaxSize(device_index); |
| } |
| |
| void _cufft_set_plan_cache_max_size(int64_t device_index, int64_t max_size) { |
| detail::getCUDAHooks().cuFFTSetPlanCacheMaxSize(device_index, max_size); |
| } |
| |
| int64_t _cufft_get_plan_cache_size(int64_t device_index) { |
| return detail::getCUDAHooks().cuFFTGetPlanCacheSize(device_index); |
| } |
| |
| void _cufft_clear_plan_cache(int64_t device_index) { |
| detail::getCUDAHooks().cuFFTClearPlanCache(device_index); |
| } |
| |
| template <typename Stream, typename T> |
| static Stream& write_opt(Stream& SS, const optional<T>& value) { |
| if (value) { |
| SS << *value; |
| } else { |
| SS << "None"; |
| } |
| return SS; |
| } |
| |
| /* Short-time Fourier Transform, for signal analysis. |
| * |
| * This is modeled after librosa but with support for complex time-domain |
| * signals and complex windows. |
| * |
| * NOTE: librosa's center and pad_mode arguments are currently only implemented |
| * in Python because they rely on torch.nn.functional.pad, which is Python-only. |
| */ |
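| /* Output layout (informal note): a 1D real input of length len with |
| * onesided=true produces shape (n_fft/2 + 1, n_frames), where |
| * n_frames = 1 + (len - n_fft) / hop_length; 2D inputs gain a leading batch |
| * dim, and return_complex=false appends a trailing real/imag dim of size 2. */ |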
| Tensor stft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop_lengthOpt, |
| const optional<int64_t> win_lengthOpt, const c10::optional<Tensor>& window_opt, |
| const bool normalized, const optional<bool> onesidedOpt, |
| const optional<bool> return_complexOpt) { |
| // See [Note: hacky wrapper removal for optional tensor] |
| c10::MaybeOwned<Tensor> window_maybe_owned = at::borrow_from_optional_tensor(window_opt); |
| const Tensor& window = *window_maybe_owned; |
| |
| #define REPR(SS) \ |
| SS << "stft(" << self.toString() << self.sizes() << ", n_fft=" << n_fft \ |
| << ", hop_length=" << hop_length << ", win_length=" << win_length \ |
| << ", window="; \ |
| if (window.defined()) { \ |
| SS << window.toString() << "{" << window.sizes() << "}"; \ |
| } else { \ |
| SS << "None"; \ |
| } \ |
| SS << ", normalized=" << normalized << ", onesided="; \ |
| write_opt(SS, onesidedOpt) << ", return_complex="; \ |
| write_opt(SS, return_complexOpt) << ") " |
| |
| TORCH_CHECK(!window.defined() || window.device() == self.device(), |
| "stft input and window must be on the same device but got self on ", |
| self.device(), " and window on ", window.device()) |
| |
| // Initialize hop_length and win_length to their defaults |
| auto hop_length = hop_lengthOpt.value_or(n_fft >> 2); |
| auto win_length = win_lengthOpt.value_or(n_fft); |
| const bool return_complex = return_complexOpt.value_or( |
| self.is_complex() || (window.defined() && window.is_complex())); |
| if (!return_complex) { |
| if (!return_complexOpt.has_value()) { |
| TORCH_WARN_ONCE( |
| "stft will soon require the return_complex parameter be given for real inputs, " |
| "and will further require that return_complex=True in a future PyTorch release." |
| ); |
| } |
| |
| // TORCH_WARN_ONCE( |
| // "stft with return_complex=False is deprecated. In a future pytorch " |
| // "release, stft will return complex tensors for all inputs, and " |
| // "return_complex=False will raise an error.\n" |
| // "Note: you can still call torch.view_as_real on the complex output to " |
| // "recover the old return format."); |
| } |
| |
| if (!at::isFloatingType(self.scalar_type()) && !at::isComplexType(self.scalar_type())) { |
| std::ostringstream ss; |
| REPR(ss) << ": expected a tensor of floating point or complex values"; |
| AT_ERROR(ss.str()); |
| } |
| if (self.dim() > 2 || self.dim() < 1) { |
| std::ostringstream ss; |
| REPR(ss) << ": expected a 1D or 2D tensor"; |
| AT_ERROR(ss.str()); |
| } |
| Tensor input = self; |
| if (self.dim() == 1) { |
| input = input.unsqueeze(0); |
| } |
| int64_t batch = input.size(0); |
| int64_t len = input.size(1); |
| if (n_fft <= 0 || n_fft > len) { |
| std::ostringstream ss; |
| REPR(ss) << ": expected 0 < n_fft < " << len |
| << ", but got n_fft=" << win_length; |
| AT_ERROR(ss.str()); |
| } |
| if (hop_length <= 0) { |
| std::ostringstream ss; |
| REPR(ss) << ": expected hop_length > 0, but got hop_length=" << hop_length; |
| AT_ERROR(ss.str()); |
| } |
| if (win_length <= 0 || win_length > n_fft) { |
| std::ostringstream ss; |
| REPR(ss) << ": expected 0 < win_length <= n_fft, but got win_length=" |
| << win_length; |
| AT_ERROR(ss.str()); |
| } |
| if (window.defined() && (window.dim() != 1 || window.size(0) != win_length)) { |
| std::ostringstream ss; |
| REPR(ss) << ": expected a 1D window tensor of size equal to win_length=" |
| << win_length << ", but got window with size " << window.sizes(); |
| AT_ERROR(ss.str()); |
| } |
| #undef REPR |
| auto window_ = window; |
| if (win_length < n_fft) { |
| // pad center |
| auto left = (n_fft - win_length) / 2; |
| if (window.defined()) { |
| window_ = at::zeros({n_fft}, window.options()); |
| window_.narrow(0, left, win_length).copy_(window); |
| } else { |
| window_ = at::zeros({n_fft}, self.options()); |
| window_.narrow(0, left, win_length).fill_(1); |
| } |
| } |
| int64_t n_frames = 1 + (len - n_fft) / hop_length; |
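| // e.g. len = 16, n_fft = 8, hop_length = 2 gives n_frames = 1 + (16 - 8)/2 = 5 |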
| // time2col: gather the overlapping frames as a strided (batch, n_frames, n_fft) view |
| input = input.as_strided( |
| {batch, n_frames, n_fft}, |
| {input.stride(0), hop_length * input.stride(1), input.stride(1)} |
| ); |
| if (window_.defined()) { |
| input = input.mul(window_); |
| } |
| |
| // FFT and transpose to get (batch x fft_size x num_frames) |
| const bool complex_fft = input.is_complex(); |
| const auto onesided = onesidedOpt.value_or(!complex_fft); |
| |
| const fft_norm_mode norm = normalized ? fft_norm_mode::by_root_n : fft_norm_mode::none; |
| Tensor out; |
| if (complex_fft) { |
| TORCH_CHECK(!onesided, "Cannot have onesided output if window or input is complex"); |
| out = at::_fft_c2c(input, input.dim() - 1, static_cast<int64_t>(norm), /*forward=*/true); |
| } else { |
| out = at::_fft_r2c(input, input.dim() - 1, static_cast<int64_t>(norm), onesided); |
| } |
| out.transpose_(1, 2); |
| |
| if (self.dim() == 1) { |
| out.squeeze_(0); |
| } |
| |
| if (return_complex) { |
| return out; |
| } else { |
| return at::view_as_real(out); |
| } |
| } |
| |
| // Create complex tensor from the old style of real tensor with size=(..., 2) |
| // This is to support istft in the transition to requiring complex input. |
| // NOTE: This may return a view of the input tensor, or might clone if necessary |
| static Tensor as_complex(const Tensor& self) { |
| const bool can_view_as_complex = [&]{ |
| auto strides = self.strides(); |
| for (int64_t i = 0; i + 1 < strides.size(); ++i) { |
| if (strides[i] % 2 != 0) { |
| return false; |
| } |
| } |
| return strides.back() == 1 && self.storage_offset() % 2 == 0; |
| }(); |
| return at::view_as_complex(can_view_as_complex ? self : self.clone(MemoryFormat::Contiguous)); |
| } |
| |
| /* Inverse Short-time Fourier Transform |
| * |
| * This is modeled after librosa but with support for complex time-domain |
| * signals and complex windows. |
| */ |
| Tensor istft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop_lengthOpt, |
| const optional<int64_t> win_lengthOpt, const c10::optional<Tensor>& window_opt, |
| const bool center, const bool normalized, const c10::optional<bool> onesidedOpt, |
| const optional<int64_t> lengthOpt, const bool return_complex) { |
| // See [Note: hacky wrapper removal for optional tensor] |
| c10::MaybeOwned<Tensor> window_maybe_owned = at::borrow_from_optional_tensor(window_opt); |
| const Tensor& window = *window_maybe_owned; |
| |
| #define REPR(SS) \ |
| SS << "istft(" << self.toString() << self.sizes() << ", n_fft=" << n_fft \ |
| << ", hop_length=" << hop_length << ", win_length=" << win_length \ |
| << ", window="; \ |
| if (window.defined()) { \ |
| SS << window.toString() << "{" << window.sizes() << "}"; \ |
| } else { \ |
| SS << "None"; \ |
| } \ |
| SS << ", center=" << center << ", normalized=" << normalized << ", onesided="; \ |
| write_opt(SS, onesidedOpt) << ", length="; \ |
| write_opt(SS, lengthOpt) << ", return_complex=" << return_complex << ") " |
| |
| TORCH_CHECK(!window.defined() || window.device() == self.device(), |
| "istft input and window must be on the same device but got self on ", |
| self.device(), " and window on ", window.device()) |
| |
| // Initialize hop_length and win_length to their defaults |
| const auto hop_length = hop_lengthOpt.value_or(n_fft >> 2); |
| const auto win_length = win_lengthOpt.value_or(n_fft); |
| |
| if (!self.is_complex()) { |
| TORCH_WARN_ONCE( |
| "istft will require a complex-valued input tensor in a future PyTorch release. " |
| "Matching the output from stft with return_complex=True. "); |
| } |
| Tensor input = self.is_complex() ? at::view_as_real(self) : self; |
| const auto input_dim = input.dim(); |
| const auto n_frames = input.size(-2); |
| const auto fft_size = input.size(-3); |
| |
| const auto expected_output_signal_len = n_fft + hop_length * (n_frames - 1); |
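| // e.g. n_fft = 8, hop_length = 2, n_frames = 5 gives 8 + 2*4 = 16 samples, |
| // matching the stft framing arithmetic |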
| |
| const auto options = at::device(input.device()).dtype(input.dtype()); |
| if (input.numel() == 0) { |
| std::ostringstream ss; |
| REPR(ss) << ": input tensor cannot be empty."; |
| AT_ERROR(ss.str()); |
| } |
| if (input_dim != 3 && input_dim != 4) { |
| std::ostringstream ss; |
| REPR(ss) << ": expected a tensor with 3 or 4 dimensions, but got " << input_dim; |
| AT_ERROR(ss.str()); |
| } |
| if (input.size(-1) != 2) { |
| std::ostringstream ss; |
| REPR(ss) << ": expected the last dimension to be 2 (corresponding to real and imaginary parts), but got " << self.size(-1); |
| AT_ERROR(ss.str()); |
| } |
| |
| const bool onesided = onesidedOpt.value_or(fft_size != n_fft); |
| if (onesided) { |
| if (n_fft / 2 + 1 != fft_size) { |
| std::ostringstream ss; |
| REPR(ss) << ": expected the frequency dimension (3rd to the last) of the input tensor to match n_fft / 2 + 1 when onsided=True, but got " << fft_size; |
| AT_ERROR(ss.str()); |
| } |
| } else { |
| if (n_fft != fft_size) { |
| std::ostringstream ss; |
| REPR(ss) << ": expected the frequency dimension (3rd to the last) of the input tensor to match n_fft when onsided=False, but got " << fft_size; |
| AT_ERROR(ss.str()); |
| } |
| } |
| |
| if (!(0 < hop_length && hop_length <= win_length)) { |
| std::ostringstream ss; |
| REPR(ss) << ": expected 0 < hop_length <= win_length"; |
| AT_ERROR(ss.str()); |
| } |
| |
| if (!(0 < win_length && win_length <= n_fft)) { |
| std::ostringstream ss; |
| REPR(ss) << ": expected 0 < win_length <= n_fft"; |
| AT_ERROR(ss.str()); |
| } |
| if (window.defined()) { |
| if (window.dim() != 1 || window.size(0) != win_length) { |
| std::ostringstream ss; |
| REPR(ss) << ": Invalid window shape. window has to be 1D and length of `win_length`"; |
| AT_ERROR(ss.str()); |
| } |
| } |
| |
| Tensor window_tmp = window.defined() ? window : at::ones({win_length,}, options); |
| if (win_length != n_fft) { |
| // center window by padding zeros on right and left side |
| int64_t left = (n_fft - win_length) / 2; |
| window_tmp = at::constant_pad_nd(window_tmp, {left, n_fft - win_length - left}, 0); |
| TORCH_INTERNAL_ASSERT(window_tmp.size(0) == n_fft); |
| } |
| |
| if (input_dim == 3) { |
| input = input.unsqueeze(0); |
| } |
| |
| input = as_complex(input.transpose(1, 2)); // complex, size: (channel, n_frames, fft_size) |
| |
| const fft_norm_mode norm = normalized ? fft_norm_mode::by_root_n : fft_norm_mode::by_n; |
| if (return_complex) { |
| TORCH_CHECK(!onesided, "Cannot have onesided output if window or input is complex"); |
| input = at::_fft_c2c(input, input.dim() - 1, static_cast<int64_t>(norm), /*forward=*/false); // size: (channel, n_frames, n_fft) |
| } else { |
| TORCH_CHECK(!window.defined() || !window.is_complex(), |
| "Complex windows are incompatible with return_complex=False"); |
| if (!onesided) { |
| input = input.slice(-1, 0, n_fft / 2 + 1); |
| } |
| input = at::_fft_c2r(input, input.dim() - 1, static_cast<int64_t>(norm), n_fft); // size: (channel, n_frames, n_fft) |
| } |
| TORCH_INTERNAL_ASSERT(input.size(2) == n_fft); |
| |
| Tensor y_tmp = input * window_tmp.view({1, 1, n_fft}); // size: (channel, n_frames, n_fft) |
| y_tmp = y_tmp.transpose(1, 2); // size: (channel, n_fft, n_frames) |
| |
| Tensor y = at::col2im(y_tmp, |
| /*output_size*/ {1, (n_frames - 1) * hop_length + n_fft}, |
| /*kernel_size*/ {1, n_fft}, |
| /*dilation*/ {1, 1}, |
| /*padding*/ {0, 0}, |
| /*stride*/ {1, hop_length} |
| ).squeeze(2); |
| window_tmp = window_tmp.pow(2).view({n_fft, 1}).repeat({1, n_frames}).unsqueeze(0); // size: (1, n_fft, n_frames) |
| Tensor window_envelop = at::col2im(window_tmp, |
| /*output_size*/ {1, (n_frames - 1) * hop_length + n_fft}, |
| /*kernel_size*/ {1, n_fft}, |
| /*dilation*/ {1, 1}, |
| /*padding*/ {0, 0}, |
| /*stride*/ {1, hop_length} |
| ).squeeze(2); // size: (1, 1, expected_output_signal_len) |
| |
| TORCH_INTERNAL_ASSERT(expected_output_signal_len == y.size(2)); |
| TORCH_INTERNAL_ASSERT(expected_output_signal_len == window_envelop.size(2)); |
| |
| // We need to trim the front padding away if centered |
| const auto start = center ? n_fft / 2 : 0; |
| const auto end = lengthOpt.has_value() ? start + lengthOpt.value() : -n_fft / 2; |
| |
| y = y.slice(2, start, end, 1); |
| window_envelop = window_envelop.slice(2, start, end, 1); |
| const auto window_envelop_lowest = window_envelop.abs().min().item().toDouble(); |
| if (window_envelop_lowest < 1e-11) { |
| std::ostringstream ss; |
| REPR(ss) << "window overlap add min: " << window_envelop_lowest; |
| AT_ERROR(ss.str()); |
| } |
| |
| y = (y / window_envelop).squeeze(1); // size: (channel, expected_output_signal_len) |
| if (input_dim == 3) { |
| y = y.squeeze(0); |
| } |
| return y; |
| |
| #undef REPR |
| } |
| |
| Tensor stft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop_lengthOpt, |
| const optional<int64_t> win_lengthOpt, const Tensor& window, |
| const bool normalized, const optional<bool> onesidedOpt) { |
| return at::native::stft( |
| self, n_fft, hop_lengthOpt, win_lengthOpt, window, normalized, onesidedOpt, |
| /*return_complex=*/c10::nullopt); |
| } |
| |
| Tensor istft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop_lengthOpt, |
| const optional<int64_t> win_lengthOpt, const Tensor& window, |
| const bool center, const bool normalized, const optional<bool> onesidedOpt, |
| const optional<int64_t> lengthOpt) { |
| return at::native::istft( |
| self, n_fft, hop_lengthOpt, win_lengthOpt, window, center, normalized, |
| onesidedOpt, lengthOpt, /*return_complex=*/false); |
| } |
| |
| void _fft_fill_with_conjugate_symmetry_(const Tensor& input, IntArrayRef dim_) { |
| const auto input_sizes = input.sizes(); |
| const auto input_strides = input.strides(); |
| TORCH_CHECK(dim_.size() > 0); |
| DimVector dim(dim_.begin(), dim_.end()); |
| at::maybe_wrap_dims(dim, input_strides.size()); |
| |
| if (input.numel() == 0 || input_sizes[dim.back()] <= 2) { |
| return; // No elements need writing |
| } |
| |
| // Small dimensions may be treated as batch dims since they don't get mirrored |
| dim.erase( |
| std::remove_if(dim.begin(), dim.end(), [&](int64_t dim) { |
| return (input_sizes[dim] <= 2); |
| }), |
| dim.end()); |
| |
| // Use TensorIterator to coalesce batch dimensions |
| // NOTE: Can't use TensorIterator loops because we need negative strides |
| auto iter = TensorIteratorConfig() |
| .add_output(input) |
| .add_input(input) |
| .resize_outputs(false) |
| .declare_static_shape(input_sizes, dim) |
| .build(); |
| |
| const auto iter_strides = iter.strides(0); |
| const auto iter_sizes = iter.shape(); |
| const auto ndim = iter_strides.size() + dim.size(); |
| DimVector in_strides(ndim), signal_half_sizes(ndim); |
| // Take coalesced batch dimensions from TensorIterator |
| std::copy(iter_strides.begin(), iter_strides.end(), in_strides.begin()); |
| std::copy(iter_sizes.begin(), iter_sizes.end(), signal_half_sizes.begin()); |
| |
| // Take transformed dimensions directly from the input |
| const auto element_size = iter.element_size(0); |
| for (int64_t i = 0; i < dim.size(); ++i) { |
| // Convert to byte strides to match TensorIterator |
| in_strides[iter_strides.size() + i] = input_strides[dim[i]] * element_size; |
| signal_half_sizes[iter_strides.size() + i] = input_sizes[dim[i]]; |
| } |
| |
| // For the last dimension, use negative strides to perform the mirroring |
| signal_half_sizes.back() = (input_sizes[dim.back()] - 1) / 2; |
| auto out_strides = in_strides; |
| out_strides.back() *= -1; |
| |
| auto* data_ptr = static_cast<char*>(input.data_ptr()); |
| const auto* in_data = data_ptr + input_strides[dim.back()] * element_size; |
| auto* out_data = data_ptr + ( |
| input_strides[dim.back()] * (input_sizes[dim.back()] - 1) * element_size); |
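| // in_data starts at element 1 of the transformed dim (the first nonzero |
| // frequency) and out_data at its last element; combined with the negated |
| // stride above, the kernel fills the redundant half with values conjugated |
| // from the stored half. |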
| |
| // Reorder dimensions by stride to maximize data locality |
| DimVector dim_permute(ndim); |
| std::iota(dim_permute.begin(), dim_permute.end(), 0); |
| std::sort(dim_permute.begin(), dim_permute.end(), |
| [&](auto dim1, auto dim2) { |
| return in_strides[dim1] < in_strides[dim2]; |
| }); |
| |
| DimVector temp(ndim); |
| auto apply_permutation = [&] (DimVector & vec) { |
| // Do permuted index copy into a temporary, then copy back |
| for (int64_t i = 0; i < ndim; ++i) { |
| temp[i] = vec[dim_permute[i]]; |
| } |
| vec = temp; |
| }; |
| apply_permutation(in_strides); |
| apply_permutation(out_strides); |
| apply_permutation(signal_half_sizes); |
| |
| // Locate the transform dims, all but the last, in the new permuted order. |
| // These are the dimensions that need explicit Hermitian mirroring |
| DimVector mirror_dims; |
| mirror_dims.reserve(dim.size() - 1); |
| for (int64_t i = 0; i < ndim; ++i) { |
| if (dim_permute[i] >= iter_strides.size() && // Not a batch dimension |
| dim_permute[i] != ndim - 1) { // Not the last dim, which is mirrored separately with negative strides |
| mirror_dims.push_back(i); |
| } |
| } |
| TORCH_INTERNAL_ASSERT(mirror_dims.size() == dim.size() - 1); |
| |
| // Dispatch to CPU or CUDA kernel to do the actual conjugate mirroring |
| fft_fill_with_conjugate_symmetry_stub( |
| input.device().type(), input.scalar_type(), |
| mirror_dims, signal_half_sizes, in_strides, in_data, out_strides, out_data); |
| } |
| |
| DEFINE_DISPATCH(fft_fill_with_conjugate_symmetry_stub); |
| |
| }} // at::native |