Remove unnecessary down_casts.

PjRtStreamExecutorBuffer already stores client_ and device_ with their
derived types (PjRtStreamExecutorClient* and PjRtStreamExecutorDevice*),
so the tensorflow::down_cast calls on them are no-ops. MakeTupleHelper is
changed to take a PjRtStreamExecutorClient* directly, which removes its
casts as well.

PiperOrigin-RevId: 363471136
Change-Id: I30e3a082d93aa7c5b61fb62634439cd33237242d
diff --git a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc
index 79c16b4..077e381 100644
--- a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc
+++ b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc
@@ -1050,8 +1050,7 @@
 }
 
 int64 PjRtStreamExecutorBuffer::OnDeviceSizeInBytes() const {
-  return tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
-      ->client()
+  return client_->client()
       ->backend()
       .transfer_manager()
       ->GetByteSizeRequirement(on_device_shape_);
@@ -1094,9 +1093,7 @@
     // the final set of usage events.
     events = device_buffer->LockUseAndTransferUsageEvents();
   }
-  LocalDeviceState* local_device_state =
-      tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
-          ->local_device_state();
+  LocalDeviceState* local_device_state = device_->local_device_state();
   if (wait_for_operations_to_complete) {
     // Block the host until all usage events have completed. Usage events
     // dominate definition events, so this also waits for the buffer to be
@@ -1243,9 +1240,7 @@
     on_ready(InvalidArgument("ToLiteral called on empty tuple"));
     return;
   }
-  LocalDeviceState* local_device =
-      tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
-          ->local_device_state();
+  LocalDeviceState* local_device = device_->local_device_state();
   se::Stream* stream = local_device->GetDeviceToHostStream();
   ScopedHold device_buffer(this, ScopedHold::kUsage);
   {
@@ -1268,12 +1263,8 @@
     on_ready(event_or.status());
     return;
   }
-  tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
-      ->client()
-      ->backend()
-      .transfer_manager()
-      ->TransferLiteralFromDevice(stream, shaped_buffer, literal,
-                                  std::move(on_ready));
+  client_->client()->backend().transfer_manager()->TransferLiteralFromDevice(
+      stream, shaped_buffer, literal, std::move(on_ready));
 
   auto usage_event = std::make_shared<BufferSequencingEvent>();
   local_device->event_pool().ThenRecordEvent(stream, event_or.ValueOrDie());
@@ -1361,9 +1352,8 @@
       // StallStreamOnError only makes sure the destination device is ok, so
       // make sure that the src buffer remains valid until after any transfers
       // have completed.
-      tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
-          ->local_device_state()
-          ->ThenRelease(transfer_stream, std::move(src_device_buffer));
+      device_->local_device_state()->ThenRelease(transfer_stream,
+                                                 std::move(src_device_buffer));
     }
     return copy_event_or.status();
   }
@@ -1399,11 +1389,8 @@
       tensorflow::down_cast<PjRtStreamExecutorDevice*>(dst_device)
           ->GetLocalDeviceState());
   LocalDeviceState* transfer_local_device =
-      tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
-              ->EnqueueD2DTransfersOnSrcStream()
-          ? tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
-                ->local_device_state()
-          : dst_local_device;
+      client_->EnqueueD2DTransfersOnSrcStream() ? device_->local_device_state()
+                                                : dst_local_device;
   CHECK_EQ(dst_local_device->allocation_model(),
            transfer_local_device->allocation_model());
 
@@ -1441,9 +1428,7 @@
   // alternative is to ensure, before freeing the buffer, that the compute
   // stream is synchronized past the transfer, but it seems better to hold onto
   // the buffer too long than to stall the compute stream.
-  RecordUsage(std::move(src_device_buffer),
-              tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
-                  ->local_device_state(),
+  RecordUsage(std::move(src_device_buffer), device_->local_device_state(),
               transfer_local_device, event, transfer_stream,
               /*prefer_to_retain_reference=*/true);
 
@@ -1452,8 +1437,7 @@
 
 Status PjRtStreamExecutorBuffer::CopyToRemoteDevice(
     absl::string_view serialized_descriptor) {
-  return tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
-      ->CopyToRemoteDevice(this, serialized_descriptor);
+  return client_->CopyToRemoteDevice(this, serialized_descriptor);
 }
 
 Status PjRtStreamExecutorBuffer::BlockHostUntilReady() {
@@ -1468,9 +1452,7 @@
     }
     device_buffer = device_buffer_;
   }
-  LocalDeviceState* local_device_state =
-      tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
-          ->local_device_state();
+  LocalDeviceState* local_device_state = device_->local_device_state();
   std::unique_ptr<se::Stream> stream;
   for (auto& event : device_buffer->definition_events()) {
     if (!event->IsComplete()) {
@@ -1530,18 +1512,14 @@
 
 // Makes a tuple from the arguments to an execution.
 StatusOr<TupleHandle> MakeTupleHelper(
-    PjRtClient* client, LocalDeviceState* local_device,
+    PjRtStreamExecutorClient* client, LocalDeviceState* local_device,
     bool strict_shape_checking, const Shape& tupled_parameter_shape,
     absl::Span<PjRtBuffer* const> py_buffers,
     absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,
     int device_ordinal) {
-  se::DeviceMemoryAllocator* allocator =
-      tensorflow::down_cast<PjRtStreamExecutorClient*>(client)->allocator();
+  se::DeviceMemoryAllocator* allocator = client->allocator();
   TransferManager* transfer_manager =
-      tensorflow::down_cast<PjRtStreamExecutorClient*>(client)
-          ->client()
-          ->backend()
-          .transfer_manager();
+      client->client()->backend().transfer_manager();
 
   if (tupled_parameter_shape.tuple_shapes_size() != py_buffers.size()) {
     return InvalidArgument("Executable expected %lld parameters but got %lld",
@@ -1634,10 +1612,7 @@
           std::move(addressable_device_logical_ids)),
       addressable_devices_(std::move(addressable_devices)) {
   TransferManager* transfer_manager =
-      tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
-          ->client()
-          ->backend()
-          .transfer_manager();
+      client_->client()->backend().transfer_manager();
   executables_.reserve(executables.size());
   for (auto& executable : executables) {
     const auto& computation_layout =
@@ -1713,10 +1688,7 @@
   std::vector<ExecutionInput> execution_inputs;
   LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
   TransferManager* transfer_manager =
-      tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
-          ->client()
-          ->backend()
-          .transfer_manager();
+      client_->client()->backend().transfer_manager();
   // Lift tuple_handle outside the conditional so that the event it returns is
   // not destroyed until after the loop below that waits on events.
   absl::optional<TupleHandle> tuple_handle;
@@ -1748,10 +1720,8 @@
           execution_input.MutableBuffers()->begin();
       ShapeTree<MaybeOwningDeviceMemory>::iterator iterator_end =
           execution_input.MutableBuffers()->end();
-      device_buffers[i].AddToInput(
-          &input_iterator, iterator_end, &execution_input,
-          tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
-              ->allocator());
+      device_buffers[i].AddToInput(&input_iterator, iterator_end,
+                                   &execution_input, client_->allocator());
       CHECK(input_iterator == iterator_end);
     }
   }
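
The casts deleted above all follow one pattern: down_cast applied to a pointer
whose static type already matches the target type. Below is a minimal,
self-contained sketch of that pattern, using simplified stand-in types rather
than the real PJRT classes; down_cast_sketch is a hypothetical illustration of
tensorflow::down_cast's checked-static_cast behavior, not its actual
implementation.

#include <cassert>
#include <iostream>

// Hypothetical stand-in for tensorflow::down_cast: a static_cast that is
// (conceptually) verified with dynamic_cast in debug builds.
template <typename To, typename From>
To down_cast_sketch(From* f) {
  assert(f == nullptr || dynamic_cast<To>(f) != nullptr);
  return static_cast<To>(f);
}

struct PjRtClientLike {  // stand-in for PjRtClient
  virtual ~PjRtClientLike() = default;
};

struct StreamExecutorClientLike : PjRtClientLike {
  // Stand-in for PjRtStreamExecutorClient; illustrative derived-only accessor.
  int allocator() const { return 42; }
};

struct BufferLike {  // stand-in for PjRtStreamExecutorBuffer
  // The member is already declared with the derived static type, so no cast
  // is needed to reach derived-only methods such as allocator().
  StreamExecutorClientLike* client_;

  int AllocatorOld() const {
    // Old pattern: down_cast from a pointer whose static type is already the
    // target type, so the cast is a no-op.
    return down_cast_sketch<StreamExecutorClientLike*>(client_)->allocator();
  }

  int AllocatorNew() const {
    return client_->allocator();  // the pattern the patch switches to
  }
};

int main() {
  StreamExecutorClientLike client;
  BufferLike buffer{&client};
  std::cout << buffer.AllocatorOld() << " == " << buffer.AllocatorNew() << "\n";
  return 0;
}

With the member already typed as the derived class, the cast adds only noise
(plus a debug-mode dynamic_cast check), so dropping it shortens the call
chains without changing behavior.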