Fix `get_replicated_var_handle` when used with a device assignment.

PiperOrigin-RevId: 292946075
Change-Id: I6a927e0dcf8697e6921c3c06545a98aa2a5dc9b3
diff --git a/tensorflow/python/tpu/tpu.py b/tensorflow/python/tpu/tpu.py
index b607d18..3db5a31 100644
--- a/tensorflow/python/tpu/tpu.py
+++ b/tensorflow/python/tpu/tpu.py
@@ -312,19 +312,24 @@
       return handle
 
     if device_assignment is not None:
+      # Find a variable copy for each replica in the device assignment.
+      # Note that the order of devices for replicas for the variable and the
+      # device assignment might not match.
       job_name = pydev.DeviceSpec.from_string(vars_[0].device).job
-
-      tpu_devices = set()
+      devices_to_vars = {v.device: v for v in vars_}
+      replicated_vars = []
       for replica_id in range(device_assignment.num_replicas):
         for logical_core in range(device_assignment.num_cores_per_replica):
-          tpu_devices.add(
-              device_util.canonicalize(
-                  device_assignment.tpu_device(
-                      replica=replica_id,
-                      logical_core=logical_core,
-                      job=job_name)))
-
-      replicated_vars = [v for v in vars_ if v.device in tpu_devices]
+          device = device_util.canonicalize(
+              device_assignment.tpu_device(
+                  replica=replica_id, logical_core=logical_core, job=job_name))
+          if device in devices_to_vars:
+            replicated_vars.append(devices_to_vars[device])
+            break
+        else:
+          raise ValueError(
+              "Failed to find a variable on any device in replica {} for "
+              "current device assignment".format(replica_id))
     else:
       replicated_vars = vars_