[TF/XLA Bridge] [NFC] Extract blocks of code dealing with allocate_xla_tensors_ mode into separate functions
This somewhat simplifies the PopulateOutputs function.
PiperOrigin-RevId: 273981105
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index 11c5a98..76af21b 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -318,6 +318,58 @@
output_allocator);
}
+static Status SetBufferForTensorUnderAllocateXlaTensors(
+ const xla::HloInputOutputAliasConfig& input_output_alias, int output_num,
+ OpKernelContext* ctx, int i, tensorflow::TensorShape shape,
+ xla::ScopedShapedBuffer* output,
+ std::shared_ptr<se::Event> definition_event, se::Stream* stream,
+ bool use_multiple_streams) {
+ if (MustAliasOutput(input_output_alias, output_num)) {
+ return errors::Unimplemented(
+ "Aliasing is not yet supported for allocate_xla_tensors_.");
+ }
+ Tensor* output_tensor;
+ TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor));
+ XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
+ if (xla_tensor) {
+ xla_tensor->set_shaped_buffer(output->TakeSubTree({output_num}));
+ if (use_multiple_streams) {
+ xla_tensor->ResetDefinitionEvent(definition_event, stream);
+ }
+ } else {
+ // xla_tensor wasn't valid, which must mean this is a zero-element
+ // tensor.
+ CHECK_EQ(output_tensor->TotalBytes(), 0);
+ }
+ return Status::OK();
+}
+
+static Status SetBufferForResourceVarTensorUnderAllocateXlaTensors(
+ const xla::HloInputOutputAliasConfig& input_output_alias, int output_num,
+ OpKernelContext* ctx, int i, const XlaCompiler::ResourceUpdate& write,
+ xla::ScopedShapedBuffer* output,
+ std::shared_ptr<se::Event> definition_event,
+ absl::Span<const VariableInfo> variable_infos, se::Stream* stream,
+ bool use_multiple_streams) {
+ if (MustAliasOutput(input_output_alias, output_num)) {
+ return errors::Unimplemented(
+ "Aliasing is not yet supported for allocate_xla_tensors_.");
+ }
+ Tensor output_tensor;
+ TF_RETURN_IF_ERROR(
+ ctx->allocate_temp(write.type, write.shape, &output_tensor));
+ if (write.shape.num_elements() > 0) {
+ XlaTensor* xla_tensor = XlaTensor::FromTensor(&output_tensor);
+ CHECK(xla_tensor);
+ xla_tensor->set_shaped_buffer(output->TakeSubTree({output_num}));
+ if (use_multiple_streams) {
+ xla_tensor->ResetDefinitionEvent(definition_event, stream);
+ }
+ }
+ *variable_infos[i].var()->tensor() = output_tensor;
+ return Status::OK();
+}
+
Status XlaComputationLaunchContext::PopulateOutputs(
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
ScopedShapedBuffer output, int missing_ctx_input_prefix,
@@ -421,23 +473,9 @@
<< "Expected output buffer to be aliased, but it is not nil.";
}
if (allocate_xla_tensors_) {
- if (MustAliasOutput(input_output_alias, output_num)) {
- return errors::Unimplemented(
- "Aliasing is not yet supported for allocate_xla_tensors_.");
- }
- Tensor* output_tensor;
- TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor));
- XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
- if (xla_tensor) {
- xla_tensor->set_shaped_buffer(output.TakeSubTree({output_num}));
- if (use_multiple_streams_) {
- xla_tensor->ResetDefinitionEvent(definition_event, stream);
- }
- } else {
- // xla_tensor wasn't valid, which must mean this is a zero-element
- // tensor.
- CHECK_EQ(output_tensor->TotalBytes(), 0);
- }
+ TF_RETURN_IF_ERROR(SetBufferForTensorUnderAllocateXlaTensors(
+ input_output_alias, output_num, ctx, i, shape, &output,
+ definition_event, stream, use_multiple_streams_));
} else {
se::DeviceMemoryBase buffer = output.buffer({output_num});
Tensor output_tensor = GetOrCreateTensorForOutput(
@@ -491,22 +529,9 @@
}
if (allocate_xla_tensors_) {
- if (MustAliasOutput(input_output_alias, output_num)) {
- return errors::Unimplemented(
- "Aliasing is not yet supported for allocate_xla_tensors_.");
- }
- Tensor output_tensor;
- TF_RETURN_IF_ERROR(
- ctx->allocate_temp(write.type, write.shape, &output_tensor));
- if (write.shape.num_elements() > 0) {
- XlaTensor* xla_tensor = XlaTensor::FromTensor(&output_tensor);
- CHECK(xla_tensor);
- xla_tensor->set_shaped_buffer(output.TakeSubTree({output_num}));
- if (use_multiple_streams_) {
- xla_tensor->ResetDefinitionEvent(definition_event, stream);
- }
- }
- *variable_infos[i].var()->tensor() = output_tensor;
+ TF_RETURN_IF_ERROR(SetBufferForResourceVarTensorUnderAllocateXlaTensors(
+ input_output_alias, output_num, ctx, i, write, &output,
+ definition_event, variable_infos, stream, use_multiple_streams_));
} else {
se::DeviceMemoryBase buffer = output.buffer({output_num});
output.set_buffer(se::OwningDeviceMemory(), {output_num});