commit | 998bb9d9438c53e9732e1f474264fe4efff3ca5b | [log] [tgz] |
---|---|---|
author | Qiao Zhang <zhangqiaorjc@google.com> | Tue Jun 01 16:08:25 2021 -0700 |
committer | TensorFlower Gardener <gardener@tensorflow.org> | Tue Jun 01 16:14:49 2021 -0700 |
tree | 4a34dcf819cd123c36a8cc8dd9d76d52bd0a75b4 | |
parent | 1acfde6e42c1007b028c82833c5922f14023d6fe [diff] |
[JAX] Bug fix in TfrtCpuBuffer::logical_on_device_shape. PiperOrigin-RevId: 376939192 Change-Id: I3f624c70cca924a66c7fdc5361cc6e863f26ccf0
diff --git a/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc b/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc index 4e51141..b1714b8 100644 --- a/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc +++ b/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc
@@ -889,6 +889,9 @@ } StatusOr<Shape> TfrtCpuBuffer::logical_on_device_shape() { + if (on_device_shape_.is_static()) { + return on_device_shape_; + } ScopedHold device_buffer(this, ScopedHold::kUsage); { absl::MutexLock lock(&mu_);