[JAX] Bug fix in TfrtCpuBuffer::logical_on_device_shape.

PiperOrigin-RevId: 376939192
Change-Id: I3f624c70cca924a66c7fdc5361cc6e863f26ccf0
diff --git a/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc b/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc
index 4e51141..b1714b8 100644
--- a/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc
+++ b/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc
@@ -889,6 +889,9 @@
 }
 
 StatusOr<Shape> TfrtCpuBuffer::logical_on_device_shape() {
+  if (on_device_shape_.is_static()) {
+    return on_device_shape_;
+  }
   ScopedHold device_buffer(this, ScopedHold::kUsage);
   {
     absl::MutexLock lock(&mu_);