Remove unnecessary down_casts.
PiperOrigin-RevId: 363471136
Change-Id: I30e3a082d93aa7c5b61fb62634439cd33237242d
diff --git a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc
index 79c16b4..077e381 100644
--- a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc
+++ b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc
@@ -1050,8 +1050,7 @@
}
int64 PjRtStreamExecutorBuffer::OnDeviceSizeInBytes() const {
- return tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
- ->client()
+ return client_->client()
->backend()
.transfer_manager()
->GetByteSizeRequirement(on_device_shape_);
@@ -1094,9 +1093,7 @@
// the final set of usage events.
events = device_buffer->LockUseAndTransferUsageEvents();
}
- LocalDeviceState* local_device_state =
- tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
- ->local_device_state();
+ LocalDeviceState* local_device_state = device_->local_device_state();
if (wait_for_operations_to_complete) {
// Block the host until all usage events have completed. Usage events
// dominate definition events, so this also waits for the buffer to be
@@ -1243,9 +1240,7 @@
on_ready(InvalidArgument("ToLiteral called on empty tuple"));
return;
}
- LocalDeviceState* local_device =
- tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
- ->local_device_state();
+ LocalDeviceState* local_device = device_->local_device_state();
se::Stream* stream = local_device->GetDeviceToHostStream();
ScopedHold device_buffer(this, ScopedHold::kUsage);
{
@@ -1268,12 +1263,8 @@
on_ready(event_or.status());
return;
}
- tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
- ->client()
- ->backend()
- .transfer_manager()
- ->TransferLiteralFromDevice(stream, shaped_buffer, literal,
- std::move(on_ready));
+ client_->client()->backend().transfer_manager()->TransferLiteralFromDevice(
+ stream, shaped_buffer, literal, std::move(on_ready));
auto usage_event = std::make_shared<BufferSequencingEvent>();
local_device->event_pool().ThenRecordEvent(stream, event_or.ValueOrDie());
@@ -1361,9 +1352,8 @@
// StallStreamOnError only makes sure the destination device is ok, so
// make sure that the src buffer remains valid until after any transfers
// have completed.
- tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
- ->local_device_state()
- ->ThenRelease(transfer_stream, std::move(src_device_buffer));
+ device_->local_device_state()->ThenRelease(transfer_stream,
+ std::move(src_device_buffer));
}
return copy_event_or.status();
}
@@ -1399,11 +1389,8 @@
tensorflow::down_cast<PjRtStreamExecutorDevice*>(dst_device)
->GetLocalDeviceState());
LocalDeviceState* transfer_local_device =
- tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
- ->EnqueueD2DTransfersOnSrcStream()
- ? tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
- ->local_device_state()
- : dst_local_device;
+ client_->EnqueueD2DTransfersOnSrcStream() ? device_->local_device_state()
+ : dst_local_device;
CHECK_EQ(dst_local_device->allocation_model(),
transfer_local_device->allocation_model());
@@ -1441,9 +1428,7 @@
// alternative is to ensure, before freeing the buffer, that the compute
// stream is synchronized past the transfer, but it seems better to hold onto
// the buffer too long than to stall the compute stream.
- RecordUsage(std::move(src_device_buffer),
- tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
- ->local_device_state(),
+ RecordUsage(std::move(src_device_buffer), device_->local_device_state(),
transfer_local_device, event, transfer_stream,
/*prefer_to_retain_reference=*/true);
@@ -1452,8 +1437,7 @@
Status PjRtStreamExecutorBuffer::CopyToRemoteDevice(
absl::string_view serialized_descriptor) {
- return tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
- ->CopyToRemoteDevice(this, serialized_descriptor);
+ return client_->CopyToRemoteDevice(this, serialized_descriptor);
}
Status PjRtStreamExecutorBuffer::BlockHostUntilReady() {
@@ -1468,9 +1452,7 @@
}
device_buffer = device_buffer_;
}
- LocalDeviceState* local_device_state =
- tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
- ->local_device_state();
+ LocalDeviceState* local_device_state = device_->local_device_state();
std::unique_ptr<se::Stream> stream;
for (auto& event : device_buffer->definition_events()) {
if (!event->IsComplete()) {
@@ -1530,18 +1512,14 @@
// Makes a tuple from the arguments to an execution.
StatusOr<TupleHandle> MakeTupleHelper(
- PjRtClient* client, LocalDeviceState* local_device,
+ PjRtStreamExecutorClient* client, LocalDeviceState* local_device,
bool strict_shape_checking, const Shape& tupled_parameter_shape,
absl::Span<PjRtBuffer* const> py_buffers,
absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,
int device_ordinal) {
- se::DeviceMemoryAllocator* allocator =
- tensorflow::down_cast<PjRtStreamExecutorClient*>(client)->allocator();
+ se::DeviceMemoryAllocator* allocator = client->allocator();
TransferManager* transfer_manager =
- tensorflow::down_cast<PjRtStreamExecutorClient*>(client)
- ->client()
- ->backend()
- .transfer_manager();
+ client->client()->backend().transfer_manager();
if (tupled_parameter_shape.tuple_shapes_size() != py_buffers.size()) {
return InvalidArgument("Executable expected %lld parameters but got %lld",
@@ -1634,10 +1612,7 @@
std::move(addressable_device_logical_ids)),
addressable_devices_(std::move(addressable_devices)) {
TransferManager* transfer_manager =
- tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
- ->client()
- ->backend()
- .transfer_manager();
+ client_->client()->backend().transfer_manager();
executables_.reserve(executables.size());
for (auto& executable : executables) {
const auto& computation_layout =
@@ -1713,10 +1688,7 @@
std::vector<ExecutionInput> execution_inputs;
LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
TransferManager* transfer_manager =
- tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
- ->client()
- ->backend()
- .transfer_manager();
+ client_->client()->backend().transfer_manager();
// Lift tuple_handle outside the conditional so that the event it returns is
// not destroyed until after the loop below that waits on events.
absl::optional<TupleHandle> tuple_handle;
@@ -1748,10 +1720,8 @@
execution_input.MutableBuffers()->begin();
ShapeTree<MaybeOwningDeviceMemory>::iterator iterator_end =
execution_input.MutableBuffers()->end();
- device_buffers[i].AddToInput(
- &input_iterator, iterator_end, &execution_input,
- tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
- ->allocator());
+ device_buffers[i].AddToInput(&input_iterator, iterator_end,
+ &execution_input, client_->allocator());
CHECK(input_iterator == iterator_end);
}
}