Remove Context dependency from Tensor class (#14269)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14269
Removes reference to Context proper and instead adds a bool argument for async copy (the same as `copy_`)
For CopyFrom - I haven't tweaked all callsites yet. Instead I rely on a terrible hack that pointer to context is implicitly converted to bool when passed, haha :) It's not a good code and I propose to fix it in a follow up diff (maybe using clangr tooling).
Reviewed By: ezyang
Differential Revision: D13117981
fbshipit-source-id: 7cb1dc2ba6a4c50ac26614f45ab8318ea96e3138
diff --git a/aten/src/ATen/core/TensorImpl.h b/aten/src/ATen/core/TensorImpl.h
index 59edcaf..a4ebc36 100644
--- a/aten/src/ATen/core/TensorImpl.h
+++ b/aten/src/ATen/core/TensorImpl.h
@@ -917,9 +917,9 @@
* a tensor on CPU and then CopyFrom a CUDA tensor, that will to a
* CUDA-to-CPU transfer).
*
- * If the function is invoked without `context` the copy would be synchronous
+ * 'async' parameter triggers async copy for CUDA tensors
*/
- void CopyFrom(const TensorImpl& src, at::BaseContext* context = nullptr) {
+ void CopyFrom(const TensorImpl& src, bool async = false) {
AT_ASSERT(!is_variable());
AT_ASSERTM(
src.is_contiguous(),
@@ -978,7 +978,7 @@
src.device(),
new_data,
device(),
- context != nullptr);
+ async);
}
}
}
@@ -991,8 +991,10 @@
* elements, in which case this tensors' capacity is grown at a factor of
* growthPct. This ensures that Extend runs on an amortized O(1) time
* complexity.
+ *
+ * This op is auto-asynchronous if the underlying device (CUDA) supports it.
*/
- void Extend(int64_t num, float growthPct, at::BaseContext* context) {
+ void Extend(int64_t num, float growthPct) {
AT_ASSERT(sizes_.size() >= 1u);
AT_ASSERTM(num >= 0, "`num` must be non-negative for Extend");
AT_ASSERTM(
@@ -1022,8 +1024,6 @@
auto oldDims = sizes_;
Resize(newCapacity);
auto* newData = raw_mutable_data(data_type_);
- AT_ASSERTM(
- context != nullptr, "Context must be provided to Extend the tensor");
if (data_type_.copy()) {
AT_ASSERTM(
device_type() == ::at::DeviceType::CPU,
diff --git a/caffe2/core/tensor.cc b/caffe2/core/tensor.cc
index 8030ffe..79e751c 100644
--- a/caffe2/core/tensor.cc
+++ b/caffe2/core/tensor.cc
@@ -159,7 +159,7 @@
Tensor* t,
at::TensorOptions options,
const Tensor& src,
- BaseContext* context) {
+ bool async) {
auto device_type = options.device().type();
CAFFE_ENFORCE(t != nullptr, "Target tensor ptr is null.");
if (!*t || device_type != t->GetDeviceType()) {
@@ -172,7 +172,7 @@
t->dtype(),
" to: ",
src.dtype());
- t->CopyFrom(src, context);
+ t->CopyFrom(src, async);
}
namespace {
diff --git a/caffe2/core/tensor.h b/caffe2/core/tensor.h
index 50ee348..4015422 100644
--- a/caffe2/core/tensor.h
+++ b/caffe2/core/tensor.h
@@ -97,23 +97,22 @@
return impl_.get()->GetDevice();
}
- void CopyFrom(const Tensor& src, BaseContext* context = nullptr) const {
- impl_.get()->CopyFrom(*src.impl_.get(), context);
+ void CopyFrom(const Tensor& src, bool async = false) const {
+ impl_.get()->CopyFrom(*src.impl_.get(), async);
}
/**
* @brief Extend the outer-most dimension of this tensor
* to dimension of `num`.
*/
- void ExtendTo(int64_t num, float growthPct, BaseContext* context) const {
+ void ExtendTo(int64_t num, float growthPct) const {
CAFFE_ENFORCE_GE_WITH_CALLER(impl_->dim(), 1);
CAFFE_ENFORCE_GE_WITH_CALLER(growthPct, 0);
- CAFFE_ENFORCE(context != nullptr, "Context must be provided.");
- Extend(num - impl_->size(0), growthPct, context);
+ Extend(num - impl_->size(0), growthPct);
}
- void Extend(int64_t num, float growthPct, BaseContext* context) const {
- impl_.get()->Extend(num, growthPct, context);
+ void Extend(int64_t num, float growthPct) const {
+ impl_.get()->Extend(num, growthPct);
}
/**
@@ -451,7 +450,7 @@
Tensor* t,
at::TensorOptions options,
const Tensor& src,
- BaseContext* context = nullptr);
+ bool async = false);
CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(12, Tensor)
diff --git a/caffe2/experiments/operators/tt_pad_op.h b/caffe2/experiments/operators/tt_pad_op.h
index 57e0d4e..e25159d 100644
--- a/caffe2/experiments/operators/tt_pad_op.h
+++ b/caffe2/experiments/operators/tt_pad_op.h
@@ -52,7 +52,7 @@
int64_t padded_dim0 = (X_dim0 / scale_ + 1) * scale_;
auto dim0_diff = padded_dim0 - X_dim0;
// set growthPct to the upper bound percentage: (100 * scale_ / X_dim0)
- X_pad->Extend(dim0_diff, 100 * scale_ / X_dim0, &context_);
+ X_pad->Extend(dim0_diff, 100 * scale_ / X_dim0);
auto* X_pad_data = X_pad->template mutable_data<T>();
int64_t X_size = X_dim0 * X_dim1;
diff --git a/caffe2/mobile/contrib/ios/mpscnn/mpscnn.mm b/caffe2/mobile/contrib/ios/mpscnn/mpscnn.mm
index 4e8cb62..aa6a547 100644
--- a/caffe2/mobile/contrib/ios/mpscnn/mpscnn.mm
+++ b/caffe2/mobile/contrib/ios/mpscnn/mpscnn.mm
@@ -2302,8 +2302,8 @@
int csz = im_i_boxes.rows();
int cur_start_idx = out_rois->dim(0);
- out_rois->Extend(csz, 50, &context_);
- out_rois_probs->Extend(csz, 50, &context_);
+ out_rois->Extend(csz, 50);
+ out_rois_probs->Extend(csz, 50);
// write rois
Eigen::Map<ERArrXXf> cur_rois(
diff --git a/caffe2/operators/box_with_nms_limit_op.cc b/caffe2/operators/box_with_nms_limit_op.cc
index 2b83e19..18646b4 100644
--- a/caffe2/operators/box_with_nms_limit_op.cc
+++ b/caffe2/operators/box_with_nms_limit_op.cc
@@ -167,9 +167,9 @@
// Write results
int cur_start_idx = out_scores->size(0);
- out_scores->Extend(total_keep_count, 50, &context_);
- out_boxes->Extend(total_keep_count, 50, &context_);
- out_classes->Extend(total_keep_count, 50, &context_);
+ out_scores->Extend(total_keep_count, 50);
+ out_boxes->Extend(total_keep_count, 50);
+ out_classes->Extend(total_keep_count, 50);
int cur_out_idx = 0;
for (int j = 1; j < num_classes; j++) {
@@ -202,7 +202,7 @@
}
if (out_keeps) {
- out_keeps->Extend(total_keep_count, 50, &context_);
+ out_keeps->Extend(total_keep_count, 50);
Eigen::Map<EArrXi> out_keeps_arr(
out_keeps->template mutable_data<int>() + cur_start_idx,
diff --git a/caffe2/operators/dataset_ops.cc b/caffe2/operators/dataset_ops.cc
index 3a074de..4a8efc4 100644
--- a/caffe2/operators/dataset_ops.cc
+++ b/caffe2/operators/dataset_ops.cc
@@ -776,7 +776,7 @@
CAFFE_ENFORCE(a.sizes()[i] == b.sizes()[i]);
}
auto oldSize = c->numel();
- c->Extend(b.sizes()[0], kDatasetGrowthPct, &context_);
+ c->Extend(b.sizes()[0], kDatasetGrowthPct);
auto* dst = (char*)c->raw_mutable_data() + oldSize * b.dtype().itemsize();
context_.CopyItemsSameDevice(b.dtype(), b.numel(), b.raw_data(), dst);
return true;
@@ -826,7 +826,7 @@
continue;
}
auto oldSize = c->numel();
- c->Extend(b.sizes()[0], kDatasetGrowthPct, &context_);
+ c->Extend(b.sizes()[0], kDatasetGrowthPct);
auto* dst = (char*)c->raw_mutable_data() + oldSize * b.dtype().itemsize();
context_.CopyItemsSameDevice(b.dtype(), b.numel(), b.raw_data(), dst);
}
diff --git a/caffe2/operators/expand_squeeze_dims_op.h b/caffe2/operators/expand_squeeze_dims_op.h
index f9c8798..89ff9c0 100644
--- a/caffe2/operators/expand_squeeze_dims_op.h
+++ b/caffe2/operators/expand_squeeze_dims_op.h
@@ -26,7 +26,7 @@
bool RunOnDevice() override {
auto& input = Input(0);
auto* output = Output(0);
- output->CopyFrom(input, &context_);
+ output->CopyFrom(input, true /*async*/);
if (dims_.empty()) {
return true;
}
@@ -70,7 +70,7 @@
bool RunOnDevice() override {
auto& input = Input(0);
auto* output = Output(0);
- output->CopyFrom(input, &context_);
+ output->CopyFrom(input, true /*async*/);
CAFFE_ENFORCE_GT(
input.dim(),
diff --git a/caffe2/operators/generate_proposals_op.cc b/caffe2/operators/generate_proposals_op.cc
index 0646f27..ade6bfe 100644
--- a/caffe2/operators/generate_proposals_op.cc
+++ b/caffe2/operators/generate_proposals_op.cc
@@ -284,8 +284,8 @@
for (int i = 0; i < num_images; i++) {
roi_counts += im_boxes[i].rows();
}
- out_rois->Extend(roi_counts, 50, &context_);
- out_rois_probs->Extend(roi_counts, 50, &context_);
+ out_rois->Extend(roi_counts, 50);
+ out_rois_probs->Extend(roi_counts, 50);
float* out_rois_ptr = out_rois->template mutable_data<float>();
float* out_rois_probs_ptr = out_rois_probs->template mutable_data<float>();
for (int i = 0; i < num_images; i++) {
diff --git a/caffe2/operators/last_n_window_collector.cc b/caffe2/operators/last_n_window_collector.cc
index b98d028..2b0695f 100644
--- a/caffe2/operators/last_n_window_collector.cc
+++ b/caffe2/operators/last_n_window_collector.cc
@@ -71,7 +71,7 @@
if (num_entries == 0) {
if (!output_initialized) {
// Get both shape and meta
- output->CopyFrom(input, &context_);
+ output->CopyFrom(input, true /*async*/);
}
return true;
}
@@ -83,7 +83,7 @@
// output_num is >= output_batch_size
if (output_num > output_batch_size) {
- output->ExtendTo(output_num, 50, &context_);
+ output->ExtendTo(output_num, 50);
}
auto* output_data =
diff --git a/caffe2/operators/mean_op.h b/caffe2/operators/mean_op.h
index 413a0f3..0a5d072 100644
--- a/caffe2/operators/mean_op.h
+++ b/caffe2/operators/mean_op.h
@@ -23,7 +23,7 @@
auto* output = Output(0);
output->ResizeLike(input0);
- output->CopyFrom(input0, &context_);
+ output->CopyFrom(input0, true /*async*/);
if (InputSize() == 1) {
return true;
@@ -102,7 +102,7 @@
for (int i = 1; i < num_inputs; i++) {
auto* cur_dX = Output(i);
cur_dX->ResizeLike(dY);
- cur_dX->CopyFrom(*dX0, &context_);
+ cur_dX->CopyFrom(*dX0, true /*async*/);
}
return true;
diff --git a/caffe2/operators/onnx_while_op.h b/caffe2/operators/onnx_while_op.h
index 4614b57..eeb45bb 100644
--- a/caffe2/operators/onnx_while_op.h
+++ b/caffe2/operators/onnx_while_op.h
@@ -171,7 +171,7 @@
scan_outputs_sizes[i],
"Size of scan output changed across iterations");
dims.insert(dims.begin(), itr);
- scan_output_target->Extend(1, 100, &context_);
+ scan_output_target->Extend(1, 100);
int64_t timestep_size = 1;
for (const int64_t t : scan_outputs_sizes[i]) {
diff --git a/caffe2/operators/reservoir_sampling.cc b/caffe2/operators/reservoir_sampling.cc
index 287a77d..285dbba 100644
--- a/caffe2/operators/reservoir_sampling.cc
+++ b/caffe2/operators/reservoir_sampling.cc
@@ -103,9 +103,9 @@
auto output_num =
std::min<size_t>(numToCollect_, output_batch_size + num_to_copy);
// output_num is >= output_batch_size
- output->ExtendTo(output_num, 50, &context_);
+ output->ExtendTo(output_num, 50);
if (pos_to_object) {
- pos_to_object->ExtendTo(output_num, 50, &context_);
+ pos_to_object->ExtendTo(output_num, 50);
}
auto* output_data =
diff --git a/caffe2/operators/rmac_regions_op.cc b/caffe2/operators/rmac_regions_op.cc
index 458afde..ab04906 100644
--- a/caffe2/operators/rmac_regions_op.cc
+++ b/caffe2/operators/rmac_regions_op.cc
@@ -58,7 +58,7 @@
(l + Hd - 1 > 0) ? ((H - region_size) / (1.0 * (l + Hd - 1))) : 0;
int cur_rows = output->dim32(0);
- output->Extend((l + Wd) * (l + Hd), 50, &context_);
+ output->Extend((l + Wd) * (l + Hd), 50);
auto* outputData = output->template mutable_data<float>() + cur_rows * 5;
for (int i = 0; i < l + Wd; ++i) {
@@ -87,7 +87,7 @@
// Replicate regions for all items in batch
int num_rois = output->dim32(0);
- output->Extend((batch_size - 1) * num_rois, 50, &context_);
+ output->Extend((batch_size - 1) * num_rois, 50);
auto* outputData = output->template mutable_data<float>();
for (int b = 1; b < batch_size; ++b) {
// Copy all rois
diff --git a/caffe2/operators/sequence_ops.h b/caffe2/operators/sequence_ops.h
index b521aa9..0b41da8 100644
--- a/caffe2/operators/sequence_ops.h
+++ b/caffe2/operators/sequence_ops.h
@@ -120,9 +120,9 @@
bool RunOnDevice() override {
if (startPaddingWidth_ == 0 && endPaddingWidth_ == 0) {
- Output(0)->CopyFrom(Input(0), &context_);
+ Output(0)->CopyFrom(Input(0), true /*async*/);
if (OutputSize() == 2) {
- Output(1)->CopyFrom(Input(1), &context_);
+ Output(1)->CopyFrom(Input(1), true /*async*/);
}
return true;
}
@@ -160,9 +160,9 @@
bool RunOnDevice() override {
if (startPaddingWidth_ == 0 && endPaddingWidth_ == 0) {
- Output(0)->CopyFrom(Input(0), &context_);
+ Output(0)->CopyFrom(Input(0), true /*async*/);
if (OutputSize() == 2) {
- Output(1)->CopyFrom(Input(1), &context_);
+ Output(1)->CopyFrom(Input(1), true /*async*/);
}
return true;
}
diff --git a/caffe2/operators/slice_op.cu b/caffe2/operators/slice_op.cu
index 8ddb204..7d888f2 100644
--- a/caffe2/operators/slice_op.cu
+++ b/caffe2/operators/slice_op.cu
@@ -123,9 +123,9 @@
}
if (dim == -1) {
if (!backward) {
- output->CopyFrom(data, context);
+ output->CopyFrom(data, true /*async*/);
} else {
- gdata->CopyFrom(*go, context);
+ gdata->CopyFrom(*go, true /*async*/);
}
return true;
}
diff --git a/caffe2/operators/slice_op.h b/caffe2/operators/slice_op.h
index 2e07beb..eb9193f 100644
--- a/caffe2/operators/slice_op.h
+++ b/caffe2/operators/slice_op.h
@@ -85,9 +85,9 @@
}
if (dim == -1) {
if (!backward) {
- output->CopyFrom(data, context);
+ output->CopyFrom(data, true /*async*/);
} else {
- gdata->CopyFrom(*go, context);
+ gdata->CopyFrom(*go, true /*async*/);
}
return true;
}
diff --git a/caffe2/operators/stop_gradient.h b/caffe2/operators/stop_gradient.h
index e05cd11..68bbad6 100644
--- a/caffe2/operators/stop_gradient.h
+++ b/caffe2/operators/stop_gradient.h
@@ -14,7 +14,7 @@
const auto& in = Input(0);
auto* out = Output(0);
if (out != &in) {
- out->CopyFrom(in, &context_);
+ out->CopyFrom(in, true /*async*/);
}
return true;
}
diff --git a/caffe2/operators/utility_ops.cu b/caffe2/operators/utility_ops.cu
index 868e849..0d9bb32 100644
--- a/caffe2/operators/utility_ops.cu
+++ b/caffe2/operators/utility_ops.cu
@@ -130,7 +130,7 @@
// This op should act as an identity matrix if we don't find any NaNs/infs.
// Copy over the data if we are not doing this in-place.
if (&X != Y) {
- Y->CopyFrom(X, &context_);
+ Y->CopyFrom(X, true /*async*/);
}
return true;
}
diff --git a/caffe2/operators/utility_ops.h b/caffe2/operators/utility_ops.h
index 09f39d8..0e03bfb 100644
--- a/caffe2/operators/utility_ops.h
+++ b/caffe2/operators/utility_ops.h
@@ -196,7 +196,7 @@
// allow the output to be copied from the input
if (&input != output) {
output->ResizeLike(input);
- output->CopyFrom(input, &context_);
+ output->CopyFrom(input, true /*async*/);
}
return true;
}
@@ -257,7 +257,7 @@
auto& input0 = Input(0);
auto* output = Output(0);
if (InputSize() == 1) {
- output->CopyFrom(input0, &context_);
+ output->CopyFrom(input0, true /*async*/);
return true;
}
output->ResizeLike(input0);
diff --git a/caffe2/queue/queue_ops.h b/caffe2/queue/queue_ops.h
index d5681e5..1050049 100644
--- a/caffe2/queue/queue_ops.h
+++ b/caffe2/queue/queue_ops.h
@@ -160,7 +160,7 @@
size,
" total columns");
- out->Extend(in.sizes()[0], kTensorGrowthPct, &context_);
+ out->Extend(in.sizes()[0], kTensorGrowthPct);
auto* dst =
(char*)out->raw_mutable_data() + oldSize * in.dtype().itemsize();
context_.template CopyItems<Context, Context>(
diff --git a/caffe2/video/video_input_op.h b/caffe2/video/video_input_op.h
index b2dffc2..a7855f6 100644
--- a/caffe2/video/video_input_op.h
+++ b/caffe2/video/video_input_op.h
@@ -808,14 +808,17 @@
// prefetch function as well.
if (!std::is_same<Context, CPUContext>::value) {
if (get_rgb_) {
- prefetched_clip_rgb_on_device_.CopyFrom(prefetched_clip_rgb_, &context_);
+ prefetched_clip_rgb_on_device_.CopyFrom(
+ prefetched_clip_rgb_, true /*async*/);
}
if (get_optical_flow_) {
- prefetched_clip_of_on_device_.CopyFrom(prefetched_clip_of_, &context_);
+ prefetched_clip_of_on_device_.CopyFrom(
+ prefetched_clip_of_, true /*async*/);
}
- prefetched_label_on_device_.CopyFrom(prefetched_label_, &context_);
+ prefetched_label_on_device_.CopyFrom(prefetched_label_, true /*async*/);
if (get_video_id_) {
- prefetched_video_id_on_device_.CopyFrom(prefetched_video_id_, &context_);
+ prefetched_video_id_on_device_.CopyFrom(
+ prefetched_video_id_, true /*async*/);
}
}
return true;
@@ -828,34 +831,34 @@
auto* clip_rgb_output =
OperatorBase::Output<Tensor>(index++, Context::GetDeviceType());
if (std::is_same<Context, CPUContext>::value) {
- clip_rgb_output->CopyFrom(prefetched_clip_rgb_, &context_);
+ clip_rgb_output->CopyFrom(prefetched_clip_rgb_, true /*async*/);
} else {
- clip_rgb_output->CopyFrom(prefetched_clip_rgb_on_device_, &context_);
+ clip_rgb_output->CopyFrom(prefetched_clip_rgb_on_device_, true /*async*/);
}
}
if (get_optical_flow_) {
auto* clip_of_output =
OperatorBase::Output<Tensor>(index++, Context::GetDeviceType());
if (std::is_same<Context, CPUContext>::value) {
- clip_of_output->CopyFrom(prefetched_clip_of_, &context_);
+ clip_of_output->CopyFrom(prefetched_clip_of_, true /*async*/);
} else {
- clip_of_output->CopyFrom(prefetched_clip_of_on_device_, &context_);
+ clip_of_output->CopyFrom(prefetched_clip_of_on_device_, true /*async*/);
}
}
auto* label_output =
OperatorBase::Output<Tensor>(index++, Context::GetDeviceType());
if (std::is_same<Context, CPUContext>::value) {
- label_output->CopyFrom(prefetched_label_, &context_);
+ label_output->CopyFrom(prefetched_label_, true /*async*/);
} else {
- label_output->CopyFrom(prefetched_label_on_device_, &context_);
+ label_output->CopyFrom(prefetched_label_on_device_, true /*async*/);
}
if (get_video_id_) {
auto* video_id_output =
OperatorBase::Output<Tensor>(index, Context::GetDeviceType());
if (std::is_same<Context, CPUContext>::value) {
- video_id_output->CopyFrom(prefetched_video_id_, &context_);
+ video_id_output->CopyFrom(prefetched_video_id_, true /*async*/);
} else {
- video_id_output->CopyFrom(prefetched_video_id_on_device_, &context_);
+ video_id_output->CopyFrom(prefetched_video_id_on_device_, true /*async*/);
}
}
return true;