Fix `get_replicated_var_handle` when used with a device assignment.
PiperOrigin-RevId: 292946075
Change-Id: I6a927e0dcf8697e6921c3c06545a98aa2a5dc9b3
diff --git a/tensorflow/python/tpu/tpu.py b/tensorflow/python/tpu/tpu.py
index b607d18..3db5a31 100644
--- a/tensorflow/python/tpu/tpu.py
+++ b/tensorflow/python/tpu/tpu.py
@@ -312,19 +312,24 @@
return handle
if device_assignment is not None:
+ # Find a variable copy for each replica in the device assignment.
+ # Note that the per-replica device order of the variable might not match
+ # that of the device assignment.
job_name = pydev.DeviceSpec.from_string(vars_[0].device).job
-
- tpu_devices = set()
+ devices_to_vars = {v.device: v for v in vars_}
+ replicated_vars = []
for replica_id in range(device_assignment.num_replicas):
for logical_core in range(device_assignment.num_cores_per_replica):
- tpu_devices.add(
- device_util.canonicalize(
- device_assignment.tpu_device(
- replica=replica_id,
- logical_core=logical_core,
- job=job_name)))
-
- replicated_vars = [v for v in vars_ if v.device in tpu_devices]
+ device = device_util.canonicalize(
+ device_assignment.tpu_device(
+ replica=replica_id, logical_core=logical_core, job=job_name))
+ if device in devices_to_vars:
+ replicated_vars.append(devices_to_vars[device])
+ break
+ else:
+ raise ValueError(
+ "Failed to find a variable on any device in replica {} for "
+ "current device assignment".format(replica_id))
else:
replicated_vars = vars_