Back out "Revert D34524207: [pytorch][PR] remove _s_where" (#73579)
Summary:
Original commit changeset: 87b1220d851c
Original Phabricator Diff: D34524207 (https://github.com/pytorch/pytorch/commit/4eb248256801103b08726bf5d85496641cebcdbb)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73579
Test Plan:
OSS tests
Tested with canary: https://www.internalfb.com/intern/ads/canary/441912928798660873
Reviewed By: ezyang
Differential Revision: D34688237
Pulled By: ngimel
fbshipit-source-id: 32f3a0046053ef52e95ab45a26bfc1de17e7e061
(cherry picked from commit d1c0acbe3e0ff884c429072923a468ee1d3d447d)
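For context: this change removes the internal `aten::_s_where` op and folds its TensorIterator-based implementation directly into `at::native::where`, moves the derivative formula and the complex-autograd allow-list entry from `_s_where` to `where.self`, adds `aten::_s_where` to the forward/backward compatibility allow-list, drops the bespoke same-device check in `where()` (mixed devices are now reported by TensorIterator), and replaces the CUDA-only `test_where_invalid_device` test with an `error_inputs_where` OpInfo entry.

A minimal sketch of the user-facing contract after this change (illustrative only, not part of the commit; the error-message text is taken from the checks exercised in the diff below):

    import torch

    cond = torch.tensor([True, False])
    x = torch.tensor([1.0, 2.0])
    y = torch.tensor([10.0, 20.0])

    # Boolean condition: the normal path; broadcasting is handled by TensorIterator.
    torch.where(cond, x, y)  # tensor([ 1., 20.])

    # uint8 condition: still accepted, but emits a one-time deprecation warning
    # and is converted to bool internally.
    torch.where(cond.to(torch.uint8), x, y)

    # Mismatched value dtypes hit the check carried over from _s_where:
    #   RuntimeError: expected scalar type ... but found ...
    # torch.where(cond, x, y.double())

    # Mixed devices are now rejected by TensorIterator's common-device check:
    #   RuntimeError: Expected all tensors to be on the same device, ...
    # torch.where(cond, x.cuda(), y)  # if CUDA is available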
diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp
index 0114deb..f5cff20 100644
--- a/aten/src/ATen/native/TensorCompare.cpp
+++ b/aten/src/ATen/native/TensorCompare.cpp
@@ -324,20 +324,25 @@
}
Tensor where(const Tensor& condition, const Tensor& self, const Tensor& other) {
- TORCH_CHECK(condition.device() == self.device() && self.device() == other.device(),
- "Expected condition, x and y to be on the same device, but condition is on ",
- condition.device(), " and x and y are on ", self.device(), " and ", other.device(),
- " respectively");
+ TORCH_CHECK(self.dtype() == other.dtype(), "expected scalar type ", self.dtype(), " but found ", other.dtype());
if (condition.scalar_type() == ScalarType::Byte) {
TORCH_WARN_ONCE("where received a uint8 condition tensor. This behavior is deprecated and will be removed in a future version of PyTorch. Use a boolean condition instead.");
-} else {
+ } else {
TORCH_CHECK(condition.scalar_type() == ScalarType::Bool, "where expected condition to be a boolean tensor, but got a tensor with dtype ", condition.scalar_type());
-}
+ }
+ Tensor cond_bool = condition.scalar_type() == ScalarType::Byte ? condition.to(ScalarType::Bool) : condition;
+ Tensor ret = at::empty({0}, self.options());
+ auto iter = at::TensorIteratorConfig()
+ .check_all_same_dtype(false)
+ .add_output(ret)
+ .add_input(cond_bool)
+ .add_input(self)
+ .add_input(other)
+ .build();
+ where_kernel(iter.device_type(), iter);
+ return ret;
- c10::MaybeOwned<Tensor> b_condition, b_self, b_other;
- std::tie(b_condition, b_self, b_other) = expand_outplace(condition, self, other, "where");
- return at::_s_where(*b_condition, *b_self, *b_other);
}
Tensor where(const Tensor& condition, const Scalar& self, const Tensor& other) {
@@ -359,22 +364,6 @@
return condition.nonzero_numpy();
}
-Tensor _s_where(const Tensor& condition, const Tensor& self, const Tensor& other) {
- TORCH_CHECK(self.dtype() == other.dtype(), "expected scalar type ", self.dtype(), " but found ", other.dtype());
- Tensor ret = at::empty(self.sizes(), self.options());
- //
- Tensor cond_bool = condition.scalar_type() == ScalarType::Byte ? condition.to(ScalarType::Bool) : condition;
- auto iter = at::TensorIteratorConfig()
- .check_all_same_dtype(false)
- .add_output(ret)
- .add_input(cond_bool)
- .add_input(self)
- .add_input(other)
- .build();
- where_kernel(iter.device_type(), iter);
- return ret;
-}
-
std::tuple<Tensor, Tensor> mode(const Tensor& self, int64_t dim, bool keepdim) {
Tensor values = at::empty({0}, self.options());
Tensor indices = at::empty({0}, self.options().dtype(kLong));
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index baad10f..c408ea9 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -4772,12 +4772,11 @@
device_check: NoCheck
device_guard: False
-# we define both of these because 'where' does the broadcast and '_s_where' doesn't;
-# this allows us to implicitly calculate the broadcast derivative, while only dealing with the
-# _s_where derivative.
- func: where.self(Tensor condition, Tensor self, Tensor other) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
+ dispatch:
+ CPU, CUDA: where
- func: where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor
variants: function
@@ -4792,11 +4791,6 @@
device_check: NoCheck # TensorIterator
variants: function
-- func: _s_where(Tensor condition, Tensor self, Tensor other) -> Tensor
- variants: function
- dispatch:
- CPU, CUDA: _s_where
-
- func: norm_except_dim(Tensor v, int pow=2, int dim=0) -> Tensor
variants: function
diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py
index b7dc0d5..0860bfc 100644
--- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py
+++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py
@@ -110,6 +110,7 @@
("aten::grid_sampler_3d_backward", datetime.date(9999, 1, 1)),
("aten::_transform_bias_rescale_qkv", datetime.date(9999, 1, 1)),
("aten::_scatter_reduce.two", datetime.date(9999, 1, 1)),
+ ("aten::_s_where", datetime.date(2022, 9, 30)),
]
ALLOW_LIST_COMPILED = [
diff --git a/test/test_torch.py b/test/test_torch.py
index 73ef010..24af923 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -116,19 +116,6 @@
class TestTorchDeviceType(TestCase):
exact_dtype = True
- # FIXME: Port this to ErrorInputs on where
- @onlyCUDA
- @dtypes(torch.float32)
- def test_where_invalid_device(self, device, dtype):
- for devices in [('cpu', device, device), (device, 'cpu', 'cpu'),
- (device, 'cpu', device), ('cpu', device, 'cpu')]:
- condition = make_tensor(16, device=devices[0], dtype=torch.float32)
- x = make_tensor(16, device=devices[1], dtype=torch.float32)
- y = make_tensor(16, device=devices[2], dtype=torch.float32)
- with self.assertRaisesRegex(RuntimeError,
- "Expected condition, x and y to be on the same device"):
- torch.where(condition, x, y)
-
# TODO: move all tensor creation to common ops
def _rand_shape(self, dim, min_size, max_size):
shape = []
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index fd0b773..54aac6e 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -1643,7 +1643,7 @@
self: at::view_as_real(grad.contiguous().resolve_conj()) # [gx, gy]
result: at::view_as_complex(self_t)
-- name: _s_where(Tensor condition, Tensor self, Tensor other) -> Tensor
+- name: where.self(Tensor condition, Tensor self, Tensor other) -> Tensor
condition: non_differentiable
self: where(condition, grad, zeros_like(grad))
other: where(condition, zeros_like(grad), grad)
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index 4b63414..7f4c44f 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -91,7 +91,7 @@
'triu', 'chunk', 'zero_', 'eq_', 'ne_', 'add', '__radd__', 'sum',
'_conj', 'sin', 'cos', 'mul', 'sinc', 'sinh', 'cosh', '__rmul__',
'sgn', 'asin', 'acos', 'sub', 'div', 'cat', 'view_as_complex', 'index_put',
- 'neg', 'complex', 'select', '_s_where', 'as_strided', 'slice', 'constant_pad_nd',
+ 'neg', 'complex', 'select', 'where', 'as_strided', 'slice', 'constant_pad_nd',
'unbind', 'split', 'split_with_sizes', 'unsafe_split', 'split_with_sizes_backward',
'dot', 'vdot', 'cholesky', 'triangular_solve', 'mm', '_unsafe_view', 'mv', 'outer',
'bmm', 'diagonal', 'alias', 'atan', 'log', 'log10', 'log1p', 'log2', 'reciprocal',
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 67e2aa4..f1dc434 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -7096,6 +7096,16 @@
args=(make_bool_mask(mask_shape), make_arg(other_shape)),
broadcasts_input=broadcasts_input)
+def error_inputs_where(op_info, device, **kwargs):
+ shape = (S,)
+ err_msg = "Expected all tensors to be on the same device"
+ for devices in product(('cpu', device), repeat=3):
+ if len(set(devices)) == 2:
+ si = SampleInput(make_tensor(shape, device=devices[0], dtype=torch.float32),
+ args=(make_tensor(shape, dtype=torch.bool, device=devices[1]),
+ make_tensor(shape, device=devices[2], dtype=torch.float32)))
+ yield ErrorInput(si, error_type=RuntimeError, error_regex=err_msg)
+
def sample_inputs_nonzero(op_info, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
@@ -14539,9 +14549,12 @@
op=lambda self, condition, other: torch.where(condition, self, other),
ref=lambda self, condition, other: np.where(condition, self, other),
sample_inputs_func=sample_inputs_where,
+ error_inputs_func=error_inputs_where,
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
+ decorators=(
+ DecorateInfo(onlyCUDA, "TestCommon", 'test_errors'),),
skips=(
# test does not work with passing lambda for op
# AssertionError: False is not true :