Replace `tensorflow::StatusOr::ConsumeValueOrDie` with `StatusOr::value()`.
This is to align `tensorflow::StatusOr` with `absl::StatusOr`.
See https://github.com/abseil/abseil-cpp/blob/master/absl/status/statusor.h#L501-L525 for documentation. In particular:
```
// Otherwise, if the value type supports an efficient move, it can be
// used as follows:
//
// T value = std::move(statusor).value();
//
// The `std::move` on statusor instead of on the whole expression enables
// warnings about possible uses of the statusor object after the move.
```
One benefit of having the `std::move(status_or).value()` compared to `status_or.ConsumeValueOrDie()` is that the compiler will know that `status_or` should not be accessed anymore.
PiperOrigin-RevId: 456079434
diff --git a/tensorflow/compiler/xla/tests/cpu_gpu_fusion_test.cc b/tensorflow/compiler/xla/tests/cpu_gpu_fusion_test.cc
index 532d869..17f3935 100644
--- a/tensorflow/compiler/xla/tests/cpu_gpu_fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/cpu_gpu_fusion_test.cc
@@ -857,26 +857,23 @@
auto x = Mul(param0, param1);
Add(x, param2);
- auto computation = builder.Build().ConsumeValueOrDie();
+ auto computation = builder.Build().value();
// Transfer literals to device.
auto param0_literal =
LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1);
ScopedShapedBuffer buffer0 =
- client->LiteralToShapedBuffer(param0_literal, device_ordinal)
- .ConsumeValueOrDie();
+ client->LiteralToShapedBuffer(param0_literal, device_ordinal).value();
auto param1_literal =
LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1);
ScopedShapedBuffer buffer1 =
- client->LiteralToShapedBuffer(param1_literal, device_ordinal)
- .ConsumeValueOrDie();
+ client->LiteralToShapedBuffer(param1_literal, device_ordinal).value();
auto param2_literal =
LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1);
ScopedShapedBuffer buffer2 =
- client->LiteralToShapedBuffer(param2_literal, device_ordinal)
- .ConsumeValueOrDie();
+ client->LiteralToShapedBuffer(param2_literal, device_ordinal).value();
// Build executable.
auto executables =
@@ -885,7 +882,7 @@
{&buffer0.on_host_shape(), &buffer1.on_host_shape(),
&buffer2.on_host_shape()},
ExecutableBuildOptions())
- .ConsumeValueOrDie();
+ .value();
auto executable = std::move(executables[0]);
se::Stream stream(executors[device_ordinal]);
diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc
index e0f23b0..711ce91 100644
--- a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc
+++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc
@@ -14,6 +14,7 @@
==============================================================================*/
#include <memory>
+#include <utility>
#include <vector>
#include "absl/types/span.h"
@@ -43,10 +44,9 @@
// transferred from the device successfully.
std::unique_ptr<GlobalData> ExecuteAndCheckTransfer(
XlaBuilder* builder, absl::Span<GlobalData* const> arguments) {
- XlaComputation computation = builder->Build().ConsumeValueOrDie();
+ XlaComputation computation = builder->Build().value();
auto global_data =
- client_->Execute(computation, arguments, &execution_options_)
- .ConsumeValueOrDie();
+ client_->Execute(computation, arguments, &execution_options_).value();
TF_CHECK_OK(client_->Transfer(*global_data).status());
return global_data;
}
@@ -63,7 +63,7 @@
EXPECT_TRUE(result_status.ok());
// Try copying the elements back and comparing it
- auto handles = result_status.ConsumeValueOrDie();
+ auto handles = std::move(result_status).value();
Literal literal;
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0]));
LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
@@ -83,8 +83,8 @@
auto result_status2 = client_->DeconstructTuple(*global_data);
EXPECT_TRUE(result_status2.ok());
- auto handles1 = result_status1.ConsumeValueOrDie();
- auto handles2 = result_status2.ConsumeValueOrDie();
+ auto handles1 = std::move(result_status1).value();
+ auto handles2 = std::move(result_status2).value();
Literal literal;
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles1[0]));
@@ -114,7 +114,7 @@
// Verify the returned GlobalDataHandle arrays have repeated elements like the
// tuple does. That is, in the returned vector of handles, handle[0] should be
// the same as handle[3] and handle[1] should be the same as handle[2].
- auto handles = result_status.ConsumeValueOrDie();
+ auto handles = std::move(result_status).value();
Literal literal;
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0]));
@@ -136,7 +136,7 @@
auto result_status = client_->DeconstructTuple(*global_data);
EXPECT_TRUE(result_status.ok());
- auto handles = result_status.ConsumeValueOrDie();
+ auto handles = std::move(result_status).value();
// Deallocate the tuple, then try copying the elements back. The elements
// should not have been deallocated because of reference counting.
@@ -172,14 +172,14 @@
XlaBuilder builder(TestName());
Literal param0_literal = LiteralUtil::CreateR1<float>({3.14f, -100.25f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).value();
auto p = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "param0");
Tuple(&builder, {p});
auto global_data = ExecuteAndCheckTransfer(&builder, {param0_data.get()});
auto result_status = client_->DeconstructTuple(*global_data);
EXPECT_TRUE(result_status.ok());
- auto handles = result_status.ConsumeValueOrDie();
+ auto handles = std::move(result_status).value();
EXPECT_NE(handles[0]->handle().handle(), param0_data->handle().handle());
}
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index 052c414..ca34e6f 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -205,11 +205,11 @@
this->client_
->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>(
{{1.0f, 2.0f, 3.0f, 4.0f}, {-1.0f, -2.0f, -3.0f, -4.0f}}))
- .ConsumeValueOrDie();
+ .value();
auto rhs_handle = this->client_
->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>(
{{1.0f}, {2.0f}, {3.0f}, {4.0f}}))
- .ConsumeValueOrDie();
+ .value();
if (std::is_same<Eigen::half, T>::value) {
this->error_spec_ = ErrorSpec{0.0001, 1e-3};
@@ -230,14 +230,14 @@
{{1.0f, 2.0f}, {3.0f, -4.0f}},
LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(lhs_row_major))))
- .ConsumeValueOrDie();
+ .value();
auto rhs_handle =
client_
->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
{{1.0f, 6.0f}, {7.0f, -4.0f}},
LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(rhs_row_major))))
- .ConsumeValueOrDie();
+ .value();
XlaBuilder builder(TestName());
auto prim_type = primitive_util::NativeToPrimitiveType<T>();
Dot(Parameter(&builder, 0, ShapeUtil::MakeShape(prim_type, {2, 2}), "lhs"),
@@ -325,7 +325,7 @@
*dot_lhs_data, LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(param.dot_lhs_row_major)));
std::unique_ptr<GlobalData> dot_lhs_handle =
- client_->TransferToServer(dot_lhs_lit).ConsumeValueOrDie();
+ client_->TransferToServer(dot_lhs_lit).value();
std::unique_ptr<Array2D<NativeT>> dot_rhs_data =
MakeLinspaceArray2D<NativeT>(0.0, 1.0, param.k, param.n);
@@ -334,7 +334,7 @@
Literal dot_rhs_lit =
LiteralUtil::CreateR2FromArray2DWithLayout(*dot_rhs_data, rhs_layout);
std::unique_ptr<GlobalData> dot_rhs_handle =
- client_->TransferToServer(dot_rhs_lit).ConsumeValueOrDie();
+ client_->TransferToServer(dot_rhs_lit).value();
std::unique_ptr<Array2D<NativeT>> addend_data;
Literal addend_lit;
@@ -345,7 +345,7 @@
addend_lit = LiteralUtil::CreateR2FromArray2DWithLayout(
*addend_data, LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(param.addend_row_major)));
- addend_handle = client_->TransferToServer(addend_lit).ConsumeValueOrDie();
+ addend_handle = client_->TransferToServer(addend_lit).value();
}
XlaBuilder builder(TestName());
@@ -515,14 +515,14 @@
{{1.0f, 2.0f, 3.0f}, {3.0f, -4.0f, -1.0f}},
LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(lhs_row_major))))
- .ConsumeValueOrDie();
+ .value();
auto rhs_handle =
client_
->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
{{1.0f, 6.0f}, {2.0f, 3.0f}, {7.0f, -4.0f}},
LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(rhs_row_major))))
- .ConsumeValueOrDie();
+ .value();
XlaBuilder builder(TestName());
auto prim_type = primitive_util::NativeToPrimitiveType<T>();
@@ -547,13 +547,13 @@
client_
->TransferToServer(LiteralUtil::CreateR2WithLayout<complex64>(
{{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0})))
- .ConsumeValueOrDie();
+ .value();
auto rhs_handle =
client_
->TransferToServer(LiteralUtil::CreateR2WithLayout<complex64>(
{{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}},
LayoutUtil::MakeLayout({1, 0})))
- .ConsumeValueOrDie();
+ .value();
XlaBuilder builder(TestName());
auto prim_type = primitive_util::NativeToPrimitiveType<complex64>();
@@ -625,14 +625,14 @@
{{2000.0f, 200.0f}, {20.0f, 2.0f}}},
{{{3000.0f, 300.0f}, {30.0f, 3.0f}},
{{4000.0f, 400.0f}, {40.0f, 4.0f}}}}))
- .ConsumeValueOrDie();
+ .value();
auto y_data =
this->client_
->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
{{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
{{{11.0f, 22.0f}, {33.0f, 44.0f}},
{{55.0f, 66.0f}, {77.0f, 88.0f}}}}))
- .ConsumeValueOrDie();
+ .value();
if (std::is_same<Eigen::half, T>::value) {
this->error_spec_ = ErrorSpec{0.0001, 1e-3};
@@ -668,13 +668,13 @@
this->client_
->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>(
{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}))
- .ConsumeValueOrDie();
+ .value();
auto y_data =
this->client_
->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>(
{{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}}))
- .ConsumeValueOrDie();
+ .value();
this->template ComputeAndCompareR3<T>(
&builder,
@@ -703,12 +703,12 @@
this->client_
->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>(
{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}))
- .ConsumeValueOrDie();
+ .value();
auto y_data = this->client_
->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>(
{{1.0f, 0.0f}, {0.0f, 1.0f}}))
- .ConsumeValueOrDie();
+ .value();
this->template ComputeAndCompareR2<T>(
&builder,
@@ -735,13 +735,13 @@
auto x_data = this->client_
->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>(
{{1.0f, 0.0f}, {0.0f, 1.0f}}))
- .ConsumeValueOrDie();
+ .value();
auto y_data =
this->client_
->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>(
{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}))
- .ConsumeValueOrDie();
+ .value();
this->template ComputeAndCompareR2<T>(
&builder,
@@ -774,14 +774,14 @@
{{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
{{{9.0f, 10.0f}, {11.0f, 12.0f}},
{{13.0f, 14.0f}, {15.0f, 16.0f}}}}))
- .ConsumeValueOrDie();
+ .value();
auto y_data =
this->client_
->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
{{{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}},
{{{0.0f, 1.0f}, {1.0f, 0.0f}}, {{0.0f, 1.0f}, {1.0f, 0.0f}}}}))
- .ConsumeValueOrDie();
+ .value();
this->template ComputeAndCompareR4<T>(
&builder,
@@ -813,14 +813,14 @@
LiteralUtil::CreateR2FromArray2DWithLayout<T>(
*lhs, LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(row_major))))
- .ConsumeValueOrDie();
+ .value();
auto rhs_handle =
this->client_
->TransferToServer(
LiteralUtil::CreateR2FromArray2DWithLayout<T>(
*rhs, LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(row_major))))
- .ConsumeValueOrDie();
+ .value();
XlaBuilder builder(this->TestName());
auto prim_type = primitive_util::NativeToPrimitiveType<T>();
@@ -1839,12 +1839,11 @@
auto lhs = Reshape(t1, {d0, d2 * d1});
auto rhs = ConstantR2FromArray2D(&builder, const_arr);
Dot(lhs, rhs);
- auto computation = builder.Build().ConsumeValueOrDie();
+ auto computation = builder.Build().value();
auto input_literal = LiteralUtil::CreateR3FromArray3D<float>(input_arr);
ScopedShapedBuffer buffer0 =
- client->LiteralToShapedBuffer(input_literal, device_ordinal)
- .ConsumeValueOrDie();
+ client->LiteralToShapedBuffer(input_literal, device_ordinal).value();
TF_ASSERT_OK_AND_ASSIGN(
auto executables, client->Compile(computation, {&buffer0.on_host_shape()},
diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
index fa111ca..94a7078 100644
--- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
@@ -789,7 +789,7 @@
.transfer_manager()
->AllocateScopedShapedBuffer(start_indices_shape, &allocator,
/*device_ordinal=*/0)
- .ConsumeValueOrDie());
+ .value());
host_shapes[i] = &shaped_buffers[i].on_host_shape();
ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(
stream.get(), start_index_literal, shaped_buffers[i]));
@@ -797,7 +797,7 @@
// Add DynamicSlice op to the computation.
DynamicSlice(input, start_indices, {1, 1, 1, 1});
- auto computation = builder.Build().ConsumeValueOrDie();
+ auto computation = builder.Build().value();
TF_ASSERT_OK_AND_ASSIGN(
auto executables,
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc
index cd9e5e4..27623fe 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc
@@ -357,7 +357,7 @@
::testing::AssertionResult HloTestBase::RunAndCompare(
std::unique_ptr<HloModule> module, const optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor) {
- auto fake_arguments = MakeFakeArguments(module.get()).ConsumeValueOrDie();
+ auto fake_arguments = MakeFakeArguments(module.get()).value();
std::vector<Literal*> fake_argument_ptrs;
absl::c_transform(
@@ -371,8 +371,7 @@
::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
std::unique_ptr<HloModule> module, const optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor) {
- const auto& fake_arguments =
- MakeFakeArguments(module.get()).ConsumeValueOrDie();
+ const auto fake_arguments = MakeFakeArguments(module.get()).value();
std::vector<Literal*> fake_argument_ptrs;
absl::c_transform(
fake_arguments, std::back_inserter(fake_argument_ptrs),
@@ -384,8 +383,7 @@
::testing::AssertionResult HloTestBase::Run(std::unique_ptr<HloModule> module,
bool run_hlo_passes) {
- const auto fake_arguments =
- MakeFakeArguments(module.get()).ConsumeValueOrDie();
+ const auto fake_arguments = MakeFakeArguments(module.get()).value();
const auto change = hlo_verifier_->Run(module.get());
if (!change.ok()) {
return ::testing::AssertionFailure() << change.status();
@@ -407,7 +405,7 @@
<< "Error while parsing HLO text format: "
<< module_or_status.status().ToString();
}
- return RunAndCompare(module_or_status.ConsumeValueOrDie(), error,
+ return RunAndCompare(std::move(module_or_status).value(), error,
reference_preprocessor);
}
@@ -481,7 +479,7 @@
}
}
- auto fake_arguments = MakeFakeArguments(module_0.get()).ConsumeValueOrDie();
+ auto fake_arguments = MakeFakeArguments(module_0.get()).value();
std::vector<Literal*> fake_argument_ptrs;
absl::c_transform(
@@ -508,8 +506,8 @@
<< "Error while parsing HLO text format: "
<< module_1_or_status.status().ToString();
}
- return RunAndCompareTwoModules(module_0_or_status.ConsumeValueOrDie(),
- module_1_or_status.ConsumeValueOrDie(), error);
+ return RunAndCompareTwoModules(std::move(module_0_or_status).value(),
+ std::move(module_1_or_status).value(), error);
}
::testing::AssertionResult HloTestBase::Run(
@@ -523,8 +521,7 @@
}
std::unique_ptr<HloModule> module = std::move(module_or_status.ValueOrDie());
- const auto& fake_arguments =
- MakeFakeArguments(module.get()).ConsumeValueOrDie();
+ const auto fake_arguments = MakeFakeArguments(module.get()).value();
std::vector<Literal*> fake_argument_ptrs;
absl::c_transform(
fake_arguments, std::back_inserter(fake_argument_ptrs),
@@ -573,8 +570,7 @@
}
std::unique_ptr<HloModule> module = std::move(module_or_status.ValueOrDie());
- const auto& fake_arguments =
- MakeFakeArguments(module.get()).ConsumeValueOrDie();
+ const auto fake_arguments = MakeFakeArguments(module.get()).value();
std::vector<Literal*> fake_argument_ptrs;
absl::c_transform(
fake_arguments, std::back_inserter(fake_argument_ptrs),
@@ -623,7 +619,7 @@
std::unique_ptr<HloModule> module =
std::move(module_or_status.ValueOrDie());
- fake_arguments[i] = MakeFakeArguments(module.get()).ConsumeValueOrDie();
+ fake_arguments[i] = MakeFakeArguments(module.get()).value();
if (profiles != nullptr) {
// We have to enable HLO profiling since otherwise currently the
@@ -667,7 +663,7 @@
if (assert_determinism) {
if (!canonical_output.has_value()) {
- canonical_output = output.ConsumeValueOrDie();
+ canonical_output = std::move(output).value();
} else {
if (*canonical_output != output.ValueOrDie()) {
return ::testing::AssertionFailure()
@@ -690,7 +686,7 @@
return ::testing::AssertionFailure()
<< "failed reading hlo module from file";
}
- return RunAndCompare(module_or_status.ConsumeValueOrDie(), error,
+ return RunAndCompare(std::move(module_or_status).value(), error,
reference_preprocessor);
}
@@ -703,7 +699,7 @@
<< "Error while parsing HLO text format: "
<< module_or_status.status().ToString();
}
- return RunAndCompareNoHloPasses(module_or_status.ConsumeValueOrDie(), error,
+ return RunAndCompareNoHloPasses(std::move(module_or_status).value(), error,
reference_preprocessor);
}
@@ -716,7 +712,7 @@
return ::testing::AssertionFailure()
<< "failed reading hlo module from file";
}
- return RunAndCompareNoHloPasses(module_or_status.ConsumeValueOrDie(), error,
+ return RunAndCompareNoHloPasses(std::move(module_or_status).value(), error,
reference_preprocessor);
}
diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc
index 169f2fa..cebb954 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc
+++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc
@@ -152,8 +152,7 @@
LiteralProto literal_proto;
TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), result,
&literal_proto));
- Literal literal =
- Literal::CreateFromProto(literal_proto).ConsumeValueOrDie();
+ Literal literal = Literal::CreateFromProto(literal_proto).value();
if (result.find("expected") != std::string::npos) {
EXPECT_EQ("f32[] 2", literal.ToString());
} else if (result.find("actual") != std::string::npos) {
diff --git a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc
index 827cd29..e79c302 100644
--- a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc
+++ b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc
@@ -76,7 +76,7 @@
auto x = ConstantR1<float>(&builder, {0.0f, 1.0f, 2.0f});
auto y = ConstantR1<float>(&builder, {2.0f, 3.0f, 4.0f});
Add(x, y);
- auto computation = builder.Build().ConsumeValueOrDie();
+ auto computation = builder.Build().value();
TestAllocator* allocator = GetOrCreateAllocator(local_client_->platform());
for (int d = 0; d < local_client_->device_count(); ++d) {
diff --git a/tensorflow/compiler/xla/tests/replay_test.cc b/tensorflow/compiler/xla/tests/replay_test.cc
index 5b8e870..ec100ac 100644
--- a/tensorflow/compiler/xla/tests/replay_test.cc
+++ b/tensorflow/compiler/xla/tests/replay_test.cc
@@ -40,20 +40,19 @@
XlaBuilder builder(TestName());
auto two = ConstantR0<int32_t>(&builder, 2);
Add(two, two);
- XlaComputation computation = builder.Build().ConsumeValueOrDie();
+ XlaComputation computation = builder.Build().value();
// Serialize it out.
- std::unique_ptr<HloSnapshot> module =
- computation.Snapshot().ConsumeValueOrDie();
+ std::unique_ptr<HloSnapshot> module = computation.Snapshot().value();
// Replay it.
- XlaComputation replayed = client_->LoadSnapshot(*module).ConsumeValueOrDie();
+ XlaComputation replayed = client_->LoadSnapshot(*module).value();
// Check signature is the same.
std::unique_ptr<ProgramShape> original_shape =
- client_->GetComputationShape(computation).ConsumeValueOrDie();
+ client_->GetComputationShape(computation).value();
std::unique_ptr<ProgramShape> replayed_shape =
- client_->GetComputationShape(replayed).ConsumeValueOrDie();
+ client_->GetComputationShape(replayed).value();
ASSERT_TRUE(protobuf_util::ProtobufEquals(original_shape->ToProto(),
replayed_shape->ToProto()));
@@ -61,7 +60,7 @@
Literal literal =
client_
->ExecuteAndTransfer(replayed, /*arguments=*/{}, &execution_options_)
- .ConsumeValueOrDie();
+ .value();
// Expect 4.
LiteralTestUtil::ExpectR0Equal<int32_t>(4, literal);
@@ -73,36 +72,33 @@
auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(S32, {}), "x");
auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(S32, {}), "y");
Add(x, y);
- XlaComputation computation = builder.Build().ConsumeValueOrDie();
+ XlaComputation computation = builder.Build().value();
// Serialize it out.
- std::unique_ptr<HloSnapshot> module =
- computation.Snapshot().ConsumeValueOrDie();
+ std::unique_ptr<HloSnapshot> module = computation.Snapshot().value();
// Replay it.
- XlaComputation replayed = client_->LoadSnapshot(*module).ConsumeValueOrDie();
+ XlaComputation replayed = client_->LoadSnapshot(*module).value();
// Check signature is the same.
std::unique_ptr<ProgramShape> original_shape =
- client_->GetComputationShape(computation).ConsumeValueOrDie();
+ client_->GetComputationShape(computation).value();
std::unique_ptr<ProgramShape> replayed_shape =
- client_->GetComputationShape(replayed).ConsumeValueOrDie();
+ client_->GetComputationShape(replayed).value();
ASSERT_TRUE(protobuf_util::ProtobufEquals(original_shape->ToProto(),
replayed_shape->ToProto()));
// Run it.
std::unique_ptr<GlobalData> x_data =
- client_->TransferToServer(LiteralUtil::CreateR0<int32_t>(2))
- .ConsumeValueOrDie();
+ client_->TransferToServer(LiteralUtil::CreateR0<int32_t>(2)).value();
std::unique_ptr<GlobalData> y_data =
- client_->TransferToServer(LiteralUtil::CreateR0<int32_t>(3))
- .ConsumeValueOrDie();
+ client_->TransferToServer(LiteralUtil::CreateR0<int32_t>(3)).value();
Literal literal =
client_
->ExecuteAndTransfer(replayed,
/*arguments=*/{x_data.get(), y_data.get()},
&execution_options_)
- .ConsumeValueOrDie();
+ .value();
// Expect 5.
LiteralTestUtil::ExpectR0Equal<int32_t>(5, literal);
@@ -114,26 +110,25 @@
auto input =
Parameter(&plus_two_builder, 0, ShapeUtil::MakeShape(S32, {}), "input");
Add(input, ConstantR0<int32_t>(&plus_two_builder, 2));
- XlaComputation plus_two = plus_two_builder.Build().ConsumeValueOrDie();
+ XlaComputation plus_two = plus_two_builder.Build().value();
XlaBuilder mapper_builder(TestName());
auto original = ConstantR1<int32_t>(&mapper_builder, {1, 2, 3});
Map(&mapper_builder, {original}, plus_two, {0});
- XlaComputation computation = mapper_builder.Build().ConsumeValueOrDie();
+ XlaComputation computation = mapper_builder.Build().value();
// Serialize it out.
- std::unique_ptr<HloSnapshot> module =
- computation.Snapshot().ConsumeValueOrDie();
+ std::unique_ptr<HloSnapshot> module = computation.Snapshot().value();
// Replay it.
- XlaComputation replayed = client_->LoadSnapshot(*module).ConsumeValueOrDie();
+ XlaComputation replayed = client_->LoadSnapshot(*module).value();
// Check signature is the same.
std::unique_ptr<ProgramShape> original_shape =
- client_->GetComputationShape(computation).ConsumeValueOrDie();
+ client_->GetComputationShape(computation).value();
std::unique_ptr<ProgramShape> replayed_shape =
- client_->GetComputationShape(replayed).ConsumeValueOrDie();
+ client_->GetComputationShape(replayed).value();
ASSERT_TRUE(protobuf_util::ProtobufEquals(original_shape->ToProto(),
replayed_shape->ToProto()));
@@ -141,7 +136,7 @@
Literal literal =
client_
->ExecuteAndTransfer(replayed, /*arguments=*/{}, &execution_options_)
- .ConsumeValueOrDie();
+ .value();
// Expect result.
LiteralTestUtil::ExpectR1Equal<int32_t>({3, 4, 5}, literal);