Delete redundant device/dtype in TensorIterator add_input/add_output (#39798)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39798
add_input's device/dtype are 100% redundant, as compute_types will
always (internally) assert that this dtype matches the expected dtype.
add_output's device/dtype is redundant UNLESS you have an undefined
tensor (in which case it seems to be an indication of what the output type
should be). The one add_output case I killed can never be exercised, see:
```
import torch
x = torch.randn(3, 4)
mask = x.ge(0.5)
torch.masked_select(x.cuda(), mask.cuda(), out=torch.zeros((0), dtype=torch.int64, device='cuda'))
```
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Test Plan: Imported from OSS
Differential Revision: D21981742
Pulled By: ezyang
fbshipit-source-id: a042d1b9fce0ad58b833856ffe32001787551e59
diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
index 2f1749c..d11b8dd 100644
--- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp
+++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
@@ -246,7 +246,7 @@
static TensorIterator make_index_out_iterator(const AdvancedIndex& info, Tensor& result) {
auto iter = TensorIterator();
iter.check_all_same_dtype(false);
- iter.add_output(result, info.src.device(), info.src.scalar_type());
+ iter.add_output(result);
iter.add_input(info.src);
for (auto& index : info.indices) {
iter.add_input(index);
diff --git a/aten/src/ATen/native/TensorIterator.h b/aten/src/ATen/native/TensorIterator.h
index 51b0e79..9be6f3d 100644
--- a/aten/src/ATen/native/TensorIterator.h
+++ b/aten/src/ATen/native/TensorIterator.h
@@ -312,10 +312,6 @@
operands_.emplace_back(input);
}
- void add_input(const Tensor& input, Device device, ScalarType dtype) {
- operands_.emplace_back(input, device, dtype);
- }
-
void set_check_mem_overlap(bool check_mem_overlap) {
config_check_mem_overlap_ = check_mem_overlap;
}
diff --git a/aten/src/ATen/native/UnfoldBackward.h b/aten/src/ATen/native/UnfoldBackward.h
index 0f78d41..e8b7a03 100644
--- a/aten/src/ATen/native/UnfoldBackward.h
+++ b/aten/src/ATen/native/UnfoldBackward.h
@@ -95,7 +95,7 @@
iter.check_all_same_dtype(false);
iter.dont_resize_outputs();
iter.add_output(grad_out_restrided);
- iter.add_input(grad_in_restrided, grad_in.device(), grad_in.scalar_type());
+ iter.add_input(grad_in_restrided);
iter.add_input(idx_dim_restrided);
iter.build();
@@ -166,7 +166,7 @@
iter.check_all_same_dtype(false);
iter.dont_resize_outputs();
iter.add_output(grad_out_restrided);
- iter.add_input(grad_in, grad_in.device(), grad_in.scalar_type());
+ iter.add_input(grad_in);
iter.add_input(idx_dim_restrided);
iter.add_input(idx_last_dim_restrided);
iter.build();
diff --git a/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp b/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp
index 22fa39d..93b8f3a 100644
--- a/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp
+++ b/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp
@@ -66,7 +66,7 @@
iter.dont_resize_outputs();
iter.declare_static_shape(index.sizes(), /*squash_dim=*/dim);
iter.add_output(self);
- iter.add_input(src, src.device(), src.scalar_type());
+ iter.add_input(src);
iter.add_input(index);
iter.build();
diff --git a/aten/src/ATen/native/cuda/ScatterGatherKernel.cu b/aten/src/ATen/native/cuda/ScatterGatherKernel.cu
index 2732c61..43d5d0a 100644
--- a/aten/src/ATen/native/cuda/ScatterGatherKernel.cu
+++ b/aten/src/ATen/native/cuda/ScatterGatherKernel.cu
@@ -144,7 +144,7 @@
iter.check_all_same_dtype(false);
iter.dont_resize_outputs();
iter.add_output(self_restrided);
- iter.add_input(src_restrided, src.device(), src.scalar_type());
+ iter.add_input(src_restrided);
iter.add_input(index);
iter.build();