rls: Support multiple returned targets from RLS Server (#9374)

* rls: Support multiple returned targets from RLS Server
Pick the first target that is not in TRANSIENT_FAILURE state.  If none, use the first target.
Also initialize all targets returned from RLS so DataCache will contain a list of child policy wrappers.

Fixes #9236
diff --git a/rls/src/main/java/io/grpc/rls/CachingRlsLbClient.java b/rls/src/main/java/io/grpc/rls/CachingRlsLbClient.java
index 4135b14..fbd2e46 100644
--- a/rls/src/main/java/io/grpc/rls/CachingRlsLbClient.java
+++ b/rls/src/main/java/io/grpc/rls/CachingRlsLbClient.java
@@ -60,6 +60,7 @@
 import java.net.URI;
 import java.net.URISyntaxException;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeUnit;
@@ -372,6 +373,15 @@
       return dataCacheEntry.getChildPolicyWrapper();
     }
 
+    @VisibleForTesting
+    @Nullable
+    ChildPolicyWrapper getChildPolicyWrapper(String target) {
+      if (!hasData()) {
+        return null;
+      }
+      return dataCacheEntry.getChildPolicyWrapper(target);
+    }
+
     @Nullable
     String getHeaderData() {
       if (!hasData()) {
@@ -509,16 +519,16 @@
     private final RouteLookupResponse response;
     private final long expireTime;
     private final long staleTime;
-    private final ChildPolicyWrapper childPolicyWrapper;
+    private final List<ChildPolicyWrapper> childPolicyWrappers;
 
     // GuardedBy CachingRlsLbClient.lock
     DataCacheEntry(RouteLookupRequest request, final RouteLookupResponse response) {
       super(request);
       this.response = checkNotNull(response, "response");
-      // TODO(creamsoup) fallback to other targets if first one is not available
-      childPolicyWrapper =
+      checkState(!response.targets().isEmpty(), "No targets returned by RLS");
+      childPolicyWrappers =
           refCountedChildPolicyWrapperFactory
-              .createOrGet(response.targets().get(0));
+              .createOrGet(response.targets());
       long now = ticker.read();
       expireTime = now + maxAgeNanos;
       staleTime = now + staleAgeNanos;
@@ -563,9 +573,25 @@
       }
     }
 
+    @VisibleForTesting
+    ChildPolicyWrapper getChildPolicyWrapper(String target) {
+      for (ChildPolicyWrapper childPolicyWrapper : childPolicyWrappers) {
+        if (childPolicyWrapper.getTarget().equals(target)) {
+          return childPolicyWrapper;
+        }
+      }
+
+      throw new RuntimeException("Target not found:" + target);
+    }
+
     @Nullable
     ChildPolicyWrapper getChildPolicyWrapper() {
-      return childPolicyWrapper;
+      for (ChildPolicyWrapper childPolicyWrapper : childPolicyWrappers) {
+        if (childPolicyWrapper.getState() != ConnectivityState.TRANSIENT_FAILURE) {
+          return childPolicyWrapper;
+        }
+      }
+      return childPolicyWrappers.get(0);
     }
 
     String getHeaderData() {
@@ -591,7 +617,9 @@
     @Override
     void cleanup() {
       synchronized (lock) {
-        refCountedChildPolicyWrapperFactory.release(childPolicyWrapper);
+        for (ChildPolicyWrapper policyWrapper : childPolicyWrappers) {
+          refCountedChildPolicyWrapperFactory.release(policyWrapper);
+        }
       }
     }
 
@@ -602,7 +630,7 @@
           .add("response", response)
           .add("expireTime", expireTime)
           .add("staleTime", staleTime)
-          .add("childPolicyWrapper", childPolicyWrapper)
+          .add("childPolicyWrappers", childPolicyWrappers)
           .toString();
     }
   }
@@ -898,7 +926,8 @@
       boolean hasFallback = defaultTarget != null && !defaultTarget.isEmpty();
       if (response.hasData()) {
         ChildPolicyWrapper childPolicyWrapper = response.getChildPolicyWrapper();
-        SubchannelPicker picker = childPolicyWrapper.getPicker();
+        SubchannelPicker picker =
+            (childPolicyWrapper != null) ? childPolicyWrapper.getPicker() : null;
         if (picker == null) {
           return PickResult.withNoResult();
         }
diff --git a/rls/src/main/java/io/grpc/rls/LbPolicyConfiguration.java b/rls/src/main/java/io/grpc/rls/LbPolicyConfiguration.java
index f40cde9..0abef01 100644
--- a/rls/src/main/java/io/grpc/rls/LbPolicyConfiguration.java
+++ b/rls/src/main/java/io/grpc/rls/LbPolicyConfiguration.java
@@ -248,6 +248,15 @@
     }
 
     // GuardedBy CachingRlsLbClient.lock
+    List<ChildPolicyWrapper> createOrGet(List<String> targets) {
+      List<ChildPolicyWrapper> retVal = new ArrayList<>();
+      for (String target : targets) {
+        retVal.add(createOrGet(target));
+      }
+      return retVal;
+    }
+
+    // GuardedBy CachingRlsLbClient.lock
     void release(ChildPolicyWrapper childPolicyWrapper) {
       checkNotNull(childPolicyWrapper, "childPolicyWrapper");
       String target = childPolicyWrapper.getTarget();
@@ -312,6 +321,10 @@
       return helper;
     }
 
+    public ConnectivityState getState() {
+      return state;
+    }
+
     void refreshState() {
       helper.getSynchronizationContext().execute(
           new Runnable() {
diff --git a/rls/src/test/java/io/grpc/rls/CachingRlsLbClientTest.java b/rls/src/test/java/io/grpc/rls/CachingRlsLbClientTest.java
index c443a3e..69aa27a 100644
--- a/rls/src/test/java/io/grpc/rls/CachingRlsLbClientTest.java
+++ b/rls/src/test/java/io/grpc/rls/CachingRlsLbClientTest.java
@@ -21,6 +21,7 @@
 import static com.google.common.truth.Truth.assertThat;
 import static com.google.common.truth.Truth.assertWithMessage;
 import static io.grpc.rls.CachingRlsLbClient.RLS_DATA_KEY;
+import static org.junit.Assert.assertSame;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.inOrder;
@@ -78,8 +79,10 @@
 import java.io.IOException;
 import java.lang.Thread.UncaughtExceptionHandler;
 import java.net.SocketAddress;
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashSet;
+import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.ExecutionException;
@@ -423,6 +426,72 @@
     assertThat(resp2.getChildPolicyWrapper()).isEqualTo(resp.getChildPolicyWrapper());
   }
 
+  @Test
+  public void get_childPolicyWrapper_multiTarget() throws Exception {
+    setUpRlsLbClient();
+    RouteLookupRequest routeLookupRequest = RouteLookupRequest.create(ImmutableMap.of(
+        "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "bar"));
+    rlsServerImpl.setLookupTable(
+        ImmutableMap.of(
+            routeLookupRequest,
+            RouteLookupResponse.create(
+                ImmutableList.of("target1", "target2", "target3"),
+                "header")));
+
+    CachedRouteLookupResponse resp = getInSyncContext(routeLookupRequest);
+    assertThat(resp.isPending()).isTrue();
+    fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS);
+
+    resp = getInSyncContext(routeLookupRequest);
+    assertThat(resp.hasData()).isTrue();
+    List<ChildPolicyWrapper> policyWrappers = new ArrayList<>();
+
+    for (int i = 1; i <= 3; i++) {
+      String target = "target" + i;
+      policyWrappers.add(resp.getChildPolicyWrapper(target));
+    }
+
+    // Set to states: null, READY, null
+    setState(policyWrappers.get(1), ConnectivityState.READY);
+    ChildPolicyWrapper childPolicy = resp.getChildPolicyWrapper();
+    assertSame(policyWrappers.get(0), childPolicy);
+
+    // Set to states: null, CONNECTING, null
+    setState(policyWrappers.get(1), ConnectivityState.CONNECTING);
+    childPolicy = resp.getChildPolicyWrapper();
+    assertSame(policyWrappers.get(0), childPolicy);
+
+    // Set to states: null, CONNECTING, READY
+    setState(policyWrappers.get(2), ConnectivityState.READY);
+    childPolicy = resp.getChildPolicyWrapper();
+    assertSame(policyWrappers.get(0), childPolicy);
+
+    // Set to states: READY, CONNECTING, READY
+    setState(policyWrappers.get(0), ConnectivityState.READY);
+    childPolicy = resp.getChildPolicyWrapper();
+    assertSame(policyWrappers.get(0), childPolicy);
+
+    // Set to states: TRANSIENT_FAILURE, CONNECTING, READY
+    setState(policyWrappers.get(0), ConnectivityState.TRANSIENT_FAILURE);
+    childPolicy = resp.getChildPolicyWrapper();
+    assertSame(policyWrappers.get(1), childPolicy);
+
+    // Set to states: TRANSIENT_FAILURE, TRANSIENT_FAILURE, TRANSIENT_FAILURE
+    setState(policyWrappers.get(1), ConnectivityState.TRANSIENT_FAILURE);
+    setState(policyWrappers.get(2), ConnectivityState.TRANSIENT_FAILURE);
+    childPolicy = resp.getChildPolicyWrapper();
+    assertSame(policyWrappers.get(0), childPolicy);
+
+    // Set to states: TRANSIENT_FAILURE, TRANSIENT_FAILURE, READY
+    setState(policyWrappers.get(2), ConnectivityState.READY);
+    childPolicy = resp.getChildPolicyWrapper();
+    assertSame(policyWrappers.get(2), childPolicy);
+  }
+
+  private void setState(ChildPolicyWrapper policyWrapper, ConnectivityState newState) {
+    policyWrapper.getHelper().updateBalancingState(newState, policyWrapper.getPicker());
+  }
+
   private static RouteLookupConfig getRouteLookupConfig() {
     return RouteLookupConfig.builder()
         .grpcKeybuilders(ImmutableList.of(