Use correct variable _device attribute in Keras optimizer_v2.
PiperOrigin-RevId: 302069884
Change-Id: I32ff43f146c6f60d462d2713908c3cf258ace3de
diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
index 2a4d4cf..d9f090c 100644
--- a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
@@ -699,8 +699,10 @@
def _prepare(self, var_list):
keys = set()
for var in var_list:
- var_devices = (getattr(var, "devices", None) or # Distributed
- [var.device]) # Regular
+ if isinstance(var, ds_values.DistributedValues):
+ var_devices = var._devices # pylint: disable=protected-access
+ else:
+ var_devices = [var.device]
var_dtype = var.dtype.base_dtype
for var_device in var_devices:
keys.add((var_device, var_dtype))