Add dtype to torch.*_window; Add dtype.is_floating_point (#6158)
diff --git a/torch/csrc/Dtype.cpp b/torch/csrc/Dtype.cpp
index a2dd6fa..241682b 100644
--- a/torch/csrc/Dtype.cpp
+++ b/torch/csrc/Dtype.cpp
@@ -35,10 +35,20 @@
}
}
+PyObject *THPDtype_is_floating_point(THPDtype *self)
+{
+ if (at::isFloatingType(self->scalar_type)) {
+ Py_RETURN_TRUE;
+ } else {
+ Py_RETURN_FALSE;
+ }
+}
+
typedef PyObject *(*getter)(PyObject *, void *);
static struct PyGetSetDef THPDtype_properties[] = {
{"is_cuda", (getter)THPDtype_is_cuda, nullptr, nullptr, nullptr},
+ {"is_floating_point", (getter)THPDtype_is_floating_point, nullptr, nullptr, nullptr},
{nullptr}
};
diff --git a/torch/functional.py b/torch/functional.py
index ca4687d..f55d44c 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -157,7 +157,7 @@
return P, L, U
-def hann_window(window_length, periodic=True):
+def hann_window(window_length, periodic=True, dtype=torch.float32):
r"""Hann window function.
This method computes the Hann window function:
@@ -184,16 +184,20 @@
window_length (int): the size of returned window
periodic (bool, optional): If True, returns a window to be used as periodic
function. If False, return a symmetric window.
+ dtype (torch.dtype, optional): the desired type of returned window.
+ Default: `torch.float32`
Returns:
Tensor: A 1-D tensor of size :math:`(\text{window_length})` containing the window
"""
+ if not dtype.is_floating_point:
+ raise ValueError("dtype must be a floating point type, but got dtype={}".format(dtype))
if window_length <= 0:
raise ValueError('window_length must be positive')
- return hamming_window(window_length, periodic=periodic, alpha=0.5, beta=0.5)
+ return hamming_window(window_length, periodic=periodic, alpha=0.5, beta=0.5, dtype=dtype)
-def hamming_window(window_length, periodic=True, alpha=0.54, beta=0.46):
+def hamming_window(window_length, periodic=True, alpha=0.54, beta=0.46, dtype=torch.float32):
r"""Hamming window function.
This method computes the Hamming window function:
@@ -222,23 +226,28 @@
window_length (int): the size of returned window
periodic (bool, optional): If True, returns a window to be used as periodic
function. If False, return a symmetric window.
+ dtype (torch.dtype, optional): the desired type of returned window.
+ Default: `torch.float32`
Returns:
Tensor: A 1-D tensor of size :math:`(window\_length)` containing the window
"""
+ if not dtype.is_floating_point:
+ raise ValueError("dtype must be a floating point type, but got dtype={}".format(dtype))
if window_length <= 0:
raise ValueError('window_length must be positive')
if window_length == 1:
- return torch.ones(window_length)
+ return torch.ones(window_length, dtype=dtype)
window_length += int(periodic)
- window = torch.arange(window_length).mul_(math.pi * 2 / (window_length - 1)).cos_().mul_(-beta).add_(alpha)
+ window = torch.arange(window_length, dtype=dtype)
+ window = window.mul_(math.pi * 2 / (window_length - 1)).cos_().mul_(-beta).add_(alpha)
if periodic:
return window[:-1]
else:
return window
-def bartlett_window(window_length, periodic=True):
+def bartlett_window(window_length, periodic=True, dtype=torch.float32):
r"""Bartlett window function.
This method computes the Bartlett window function:
@@ -267,16 +276,20 @@
window_length (int): the size of returned window
periodic (bool, optional): If True, returns a window to be used as periodic
function. If False, return a symmetric window.
+ dtype (torch.dtype, optional): the desired type of returned window.
+ Default: `torch.float32`
Returns:
Tensor: A 1-D tensor of size :math:`(window\_length)` containing the window
"""
+ if not dtype.is_floating_point:
+ raise ValueError("dtype must be a floating point type, but got dtype={}".format(dtype))
if window_length <= 0:
raise ValueError('window_length must be positive')
if window_length == 1:
- return torch.ones(window_length)
+ return torch.ones(window_length, dtype=dtype)
window_length += int(periodic)
- window = torch.arange(window_length).mul_(2.0 / (window_length - 1))
+ window = torch.arange(window_length, dtype=dtype).mul_(2.0 / (window_length - 1))
first_half_size = ((window_length - 1) >> 1) + 1
window.narrow(0, first_half_size, window_length - first_half_size).mul_(-1).add_(2)
if periodic: