core: refactor load-balancing config handling (#5397)

The LoadBalancingConfig message, which looks like
```json
{
  "policy_name" : {
    "config_key1" : "config_value1",
    "config_key2" : "config_value2"
   }
}
```
appears multiple times. It gets super tedious and confusing to handle, because both the whole config and the value (in the above example is `{ "config_key1" : "config_value1" }`) are just `Map<String, Object>`, and each user needs to do the following validation:
 1. The whole config must have exactly one key
 2. The value must be a map

Here I define `LbConfig` that holds the policy name and the config value, and a method in `ServiceConfigUtil` that converts the parsed JSON format into `LbConfig`.

There is also multiple cases where you need to handle a list of configs (top-level balancing policy, child and fallback policies in xds, grpclb child policies). I also made another helper method in `ServiceConfigUtil` to convert them into `List<LbConfig>`.

Found and fixed a bug in the xds code, where the top-level balancer should pass the config value (excluding the policy name), not the whole config to the child balancers. Search for "supported_1_option" in the diff to see it in the tests.
diff --git a/core/src/main/java/io/grpc/internal/AutoConfiguredLoadBalancerFactory.java b/core/src/main/java/io/grpc/internal/AutoConfiguredLoadBalancerFactory.java
index 7db6267..aac60bc 100644
--- a/core/src/main/java/io/grpc/internal/AutoConfiguredLoadBalancerFactory.java
+++ b/core/src/main/java/io/grpc/internal/AutoConfiguredLoadBalancerFactory.java
@@ -32,12 +32,12 @@
 import io.grpc.LoadBalancerProvider;
 import io.grpc.LoadBalancerRegistry;
 import io.grpc.Status;
+import io.grpc.internal.ServiceConfigUtil.LbConfig;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Map;
-import java.util.Map.Entry;
 import java.util.logging.Logger;
 import javax.annotation.Nullable;
 
@@ -215,6 +215,8 @@
       }
 
       if (haveBalancerAddress) {
+        // This is a special case where the existence of balancer address in the resolved address
+        // selects "grpclb" policy regardless of the service config.
         LoadBalancerProvider grpclbProvider = registry.getProvider("grpclb");
         if (grpclbProvider == null) {
           if (backendAddrs.isEmpty()) {
@@ -238,21 +240,16 @@
       }
       roundRobinDueToGrpclbDepMissing = false;
 
-      List<Map<String, Object>> lbConfigs = null;
+      List<LbConfig> lbConfigs = null;
       if (config != null) {
-        lbConfigs = ServiceConfigUtil.getLoadBalancingConfigsFromServiceConfig(config);
+        List<Map<String, Object>> rawLbConfigs =
+            ServiceConfigUtil.getLoadBalancingConfigsFromServiceConfig(config);
+        lbConfigs = ServiceConfigUtil.unwrapLoadBalancingConfigList(rawLbConfigs);
       }
       if (lbConfigs != null && !lbConfigs.isEmpty()) {
         LinkedHashSet<String> policiesTried = new LinkedHashSet<>();
-        for (Map<String, Object> lbConfig : lbConfigs) {
-          if (lbConfig.size() != 1) {
-            throw new PolicyException(
-                "There are " + lbConfig.size()
-                + " load-balancing configs in a list item. Exactly one is expected. Config="
-                + lbConfig);
-          }
-          Entry<String, Object> entry = lbConfig.entrySet().iterator().next();
-          String policy = entry.getKey();
+        for (LbConfig lbConfig : lbConfigs) {
+          String policy = lbConfig.getPolicyName();
           LoadBalancerProvider provider = registry.getProvider(policy);
           if (provider != null) {
             if (!policiesTried.isEmpty()) {
@@ -260,7 +257,7 @@
                   ChannelLogLevel.DEBUG,
                   "{0} specified by Service Config are not available", policiesTried);
             }
-            return new PolicySelection(provider, servers, (Map) entry.getValue());
+            return new PolicySelection(provider, servers, lbConfig.getRawConfigValue());
           }
           policiesTried.add(policy);
         }
@@ -297,13 +294,12 @@
     final List<EquivalentAddressGroup> serverList;
     @Nullable final Map<String, Object> config;
 
-    @SuppressWarnings("unchecked")
     PolicySelection(
         LoadBalancerProvider provider, List<EquivalentAddressGroup> serverList,
-        @Nullable Map<?, ?> config) {
+        @Nullable Map<String, Object> config) {
       this.provider = checkNotNull(provider, "provider");
       this.serverList = Collections.unmodifiableList(checkNotNull(serverList, "serverList"));
-      this.config = (Map<String, Object>) config;
+      this.config = config;
     }
   }
 
diff --git a/core/src/main/java/io/grpc/internal/ServiceConfigUtil.java b/core/src/main/java/io/grpc/internal/ServiceConfigUtil.java
index 3fb37d2..61bd2e9 100644
--- a/core/src/main/java/io/grpc/internal/ServiceConfigUtil.java
+++ b/core/src/main/java/io/grpc/internal/ServiceConfigUtil.java
@@ -21,6 +21,8 @@
 import static com.google.common.math.LongMath.checkedAdd;
 
 import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.MoreObjects;
+import com.google.common.base.Objects;
 import io.grpc.internal.RetriableStream.Throttle;
 import java.text.ParseException;
 import java.util.ArrayList;
@@ -345,35 +347,75 @@
   }
 
   /**
-   * Extracts the loadbalancing policy name from loadbalancer config.
+   * Unwrap a LoadBalancingConfig JSON object into a {@link LbConfig}.  The input is a JSON object
+   * (map) with exactly one entry, where the key is the policy name and the value is a config object
+   * for that policy.
    */
-  public static String getBalancerPolicyNameFromLoadBalancingConfig(Map<String, Object> lbConfig) {
-    return lbConfig.entrySet().iterator().next().getKey();
+  @SuppressWarnings("unchecked")
+  public static LbConfig unwrapLoadBalancingConfig(Object lbConfig) {
+    Map<String, Object> map;
+    try {
+      map = (Map<String, Object>) lbConfig;
+    } catch (ClassCastException e) {
+      ClassCastException ex = new ClassCastException("Invalid type. Config=" + lbConfig);
+      ex.initCause(e);
+      throw ex;
+    }
+    if (map.size() != 1) {
+      throw new RuntimeException(
+          "There are " + map.size() + " fields in a LoadBalancingConfig object. Exactly one"
+          + " is expected. Config=" + lbConfig);
+    }
+    Map.Entry<String, Object> entry = map.entrySet().iterator().next();
+    Map<String, Object> configValue;
+    try {
+      configValue = (Map<String, Object>) entry.getValue();
+    } catch (ClassCastException e) {
+      ClassCastException ex =
+          new ClassCastException("Invalid value type.  value=" + entry.getValue());
+      ex.initCause(e);
+      throw ex;
+    }
+    return new LbConfig(entry.getKey(), configValue);
+  }
+
+  /**
+   * Given a JSON list of LoadBalancingConfigs, and convert it into a list of LbConfig.
+   */
+  @SuppressWarnings("unchecked")
+  public static List<LbConfig> unwrapLoadBalancingConfigList(Object listObject) {
+    List<?> list;
+    try {
+      list = (List<?>) listObject;
+    } catch (ClassCastException e) {
+      ClassCastException ex = new ClassCastException("List expected, but is " + listObject);
+      ex.initCause(e);
+      throw ex;
+    }
+    ArrayList<LbConfig> result = new ArrayList<>();
+    for (Object rawChildPolicy : list) {
+      result.add(unwrapLoadBalancingConfig(rawChildPolicy));
+    }
+    return Collections.unmodifiableList(result);
   }
 
   /**
    * Extracts the loadbalancer name from xds loadbalancer config.
    */
-  @SuppressWarnings("unchecked")
-  public static String getBalancerNameFromXdsConfig(
-      Map<String, Object> xdsConfig) {
-    Object entry = xdsConfig.entrySet().iterator().next().getValue();
-    return getString((Map<String, Object>) entry, XDS_CONFIG_BALANCER_NAME_KEY);
+  public static String getBalancerNameFromXdsConfig(LbConfig xdsConfig) {
+    Map<String, Object> map = xdsConfig.getRawConfigValue();
+    return getString(map, XDS_CONFIG_BALANCER_NAME_KEY);
   }
 
   /**
    * Extracts list of child policies from xds loadbalancer config.
    */
-  @SuppressWarnings("unchecked")
   @Nullable
-  public static List<Map<String, Object>> getChildPolicyFromXdsConfig(
-      Map<String, Object> xdsConfig) {
-    Object rawEntry = xdsConfig.entrySet().iterator().next().getValue();
-    if (rawEntry instanceof Map) {
-      Map<String, Object> entry = (Map<String, Object>) rawEntry;
-      if (entry.containsKey(XDS_CONFIG_CHILD_POLICY_KEY)) {
-        return (List<Map<String, Object>>) (List<?>) getList(entry, XDS_CONFIG_CHILD_POLICY_KEY);
-      }
+  public static List<LbConfig> getChildPolicyFromXdsConfig(LbConfig xdsConfig) {
+    Map<String, Object> map = xdsConfig.getRawConfigValue();
+    Object rawChildPolicies = map.get(XDS_CONFIG_CHILD_POLICY_KEY);
+    if (rawChildPolicies != null) {
+      return unwrapLoadBalancingConfigList(rawChildPolicies);
     }
     return null;
   }
@@ -381,16 +423,12 @@
   /**
    * Extracts list of fallback policies from xds loadbalancer config.
    */
-  @SuppressWarnings("unchecked")
   @Nullable
-  public static List<Map<String, Object>> getFallbackPolicyFromXdsConfig(
-      Map<String, Object> lbConfig) {
-    Object rawEntry = lbConfig.entrySet().iterator().next().getValue();
-    if (rawEntry instanceof Map) {
-      Map<String, Object> entry = (Map<String, Object>) rawEntry;
-      if (entry.containsKey(XDS_CONFIG_FALLBACK_POLICY_KEY)) {
-        return (List<Map<String, Object>>) (List<?>) getList(entry, XDS_CONFIG_FALLBACK_POLICY_KEY);
-      }
+  public static List<LbConfig> getFallbackPolicyFromXdsConfig(LbConfig xdsConfig) {
+    Map<String, Object> map = xdsConfig.getRawConfigValue();
+    Object rawFallbackPolicies = map.get(XDS_CONFIG_FALLBACK_POLICY_KEY);
+    if (rawFallbackPolicies != null) {
+      return unwrapLoadBalancingConfigList(rawFallbackPolicies);
     }
     return null;
   }
@@ -642,4 +680,49 @@
     // we did over/under flow, if the sign is negative we should return MAX otherwise MIN
     return Long.MAX_VALUE + ((naiveSum >>> (Long.SIZE - 1)) ^ 1);
   }
+
+  /**
+   * A LoadBalancingConfig that includes the policy name (the key) and its raw config value (parsed
+   * JSON).
+   */
+  public static final class LbConfig {
+    private final String policyName;
+    private final Map<String, Object> rawConfigValue;
+
+    public LbConfig(String policyName, Map<String, Object> rawConfigValue) {
+      this.policyName = checkNotNull(policyName, "policyName");
+      this.rawConfigValue = checkNotNull(rawConfigValue, "rawConfigValue");
+    }
+
+    public String getPolicyName() {
+      return policyName;
+    }
+
+    public Map<String, Object> getRawConfigValue() {
+      return rawConfigValue;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (o instanceof LbConfig) {
+        LbConfig other = (LbConfig) o;
+        return policyName.equals(other.policyName)
+            && rawConfigValue.equals(other.rawConfigValue);
+      }
+      return false;
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hashCode(policyName, rawConfigValue);
+    }
+
+    @Override
+    public String toString() {
+      return MoreObjects.toStringHelper(this)
+          .add("policyName", policyName)
+          .add("rawConfigValue", rawConfigValue)
+          .toString();
+    }
+  }
 }
diff --git a/core/src/test/java/io/grpc/internal/ServiceConfigUtilTest.java b/core/src/test/java/io/grpc/internal/ServiceConfigUtilTest.java
index 413c868..c8757f8 100644
--- a/core/src/test/java/io/grpc/internal/ServiceConfigUtilTest.java
+++ b/core/src/test/java/io/grpc/internal/ServiceConfigUtilTest.java
@@ -18,9 +18,10 @@
 
 import static com.google.common.truth.Truth.assertThat;
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
 
+import io.grpc.internal.ServiceConfigUtil.LbConfig;
 import java.util.List;
-import java.util.Map;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
@@ -30,17 +31,6 @@
  */
 @RunWith(JUnit4.class)
 public class ServiceConfigUtilTest {
-  @SuppressWarnings("unchecked")
-  @Test
-  public void getBalancerPolicyNameFromLoadBalancingConfig() throws Exception {
-    String lbConfig = "{\"lbPolicy1\" : { \"key\" : \"val\" }}";
-    assertEquals(
-        "lbPolicy1",
-        ServiceConfigUtil.getBalancerPolicyNameFromLoadBalancingConfig(
-            (Map<String, Object>) JsonParser.parse(lbConfig)));
-  }
-
-  @SuppressWarnings("unchecked")
   @Test
   public void getBalancerNameFromXdsConfig() throws Exception {
     String lbConfig = "{\"xds_experimental\" : { "
@@ -51,10 +41,9 @@
     assertEquals(
         "dns:///balancer.example.com:8080",
         ServiceConfigUtil.getBalancerNameFromXdsConfig(
-            (Map<String, Object>) JsonParser.parse(lbConfig)));
+            ServiceConfigUtil.unwrapLoadBalancingConfig(JsonParser.parse(lbConfig))));
   }
 
-  @SuppressWarnings("unchecked")
   @Test
   public void getChildPolicyFromXdsConfig() throws Exception {
     String lbConfig = "{\"xds_experimental\" : { "
@@ -62,18 +51,17 @@
         + "\"childPolicy\" : [{\"round_robin\" : {}}, {\"lbPolicy2\" : {\"key\" : \"val\"}}],"
         + "\"fallbackPolicy\" : [{\"lbPolicy3\" : {\"key\" : \"val\"}}, {\"lbPolicy4\" : {}}]"
         + "}}";
-    Map<String, Object> expectedChildPolicy1 = (Map<String, Object>) JsonParser.parse(
-        "{\"round_robin\" : {}}");
-    Map<String, Object> expectedChildPolicy2 = (Map<String, Object>) JsonParser.parse(
-        "{\"lbPolicy2\" : {\"key\" : \"val\"}}");
+    LbConfig expectedChildPolicy1 = ServiceConfigUtil.unwrapLoadBalancingConfig(
+        JsonParser.parse("{\"round_robin\" : {}}"));
+    LbConfig expectedChildPolicy2 = ServiceConfigUtil.unwrapLoadBalancingConfig(
+        JsonParser.parse("{\"lbPolicy2\" : {\"key\" : \"val\"}}"));
 
-    List<Map<String, Object>> childPolicies = ServiceConfigUtil.getChildPolicyFromXdsConfig(
-        (Map<String, Object>) JsonParser.parse(lbConfig));
+    List<LbConfig> childPolicies = ServiceConfigUtil.getChildPolicyFromXdsConfig(
+        ServiceConfigUtil.unwrapLoadBalancingConfig(JsonParser.parse(lbConfig)));
 
     assertThat(childPolicies).containsExactly(expectedChildPolicy1, expectedChildPolicy2);
   }
 
-  @SuppressWarnings("unchecked")
   @Test
   public void getChildPolicyFromXdsConfig_null() throws Exception {
     String lbConfig = "{\"xds_experimental\" : { "
@@ -81,13 +69,12 @@
         + "\"fallbackPolicy\" : [{\"lbPolicy3\" : {\"key\" : \"val\"}}, {\"lbPolicy4\" : {}}]"
         + "}}";
 
-    List<Map<String, Object>> childPolicies = ServiceConfigUtil.getChildPolicyFromXdsConfig(
-        (Map<String, Object>) JsonParser.parse(lbConfig));
+    List<LbConfig> childPolicies = ServiceConfigUtil.getChildPolicyFromXdsConfig(
+        ServiceConfigUtil.unwrapLoadBalancingConfig(JsonParser.parse(lbConfig)));
 
     assertThat(childPolicies).isNull();
   }
 
-  @SuppressWarnings("unchecked")
   @Test
   public void getFallbackPolicyFromXdsConfig() throws Exception {
     String lbConfig = "{\"xds_experimental\" : { "
@@ -95,18 +82,17 @@
         + "\"childPolicy\" : [{\"round_robin\" : {}}, {\"lbPolicy2\" : {\"key\" : \"val\"}}],"
         + "\"fallbackPolicy\" : [{\"lbPolicy3\" : {\"key\" : \"val\"}}, {\"lbPolicy4\" : {}}]"
         + "}}";
-    Map<String, Object> expectedFallbackPolicy1 = (Map<String, Object>) JsonParser.parse(
-        "{\"lbPolicy3\" : {\"key\" : \"val\"}}");
-    Map<String, Object> expectedFallbackPolicy2 = (Map<String, Object>) JsonParser.parse(
-        "{\"lbPolicy4\" : {}}");
+    LbConfig expectedFallbackPolicy1 = ServiceConfigUtil.unwrapLoadBalancingConfig(
+        JsonParser.parse("{\"lbPolicy3\" : {\"key\" : \"val\"}}"));
+    LbConfig expectedFallbackPolicy2 = ServiceConfigUtil.unwrapLoadBalancingConfig(
+        JsonParser.parse("{\"lbPolicy4\" : {}}"));
 
-    List<Map<String, Object>> childPolicies = ServiceConfigUtil.getFallbackPolicyFromXdsConfig(
-        (Map<String, Object>) JsonParser.parse(lbConfig));
+    List<LbConfig> childPolicies = ServiceConfigUtil.getFallbackPolicyFromXdsConfig(
+        ServiceConfigUtil.unwrapLoadBalancingConfig(JsonParser.parse(lbConfig)));
 
     assertThat(childPolicies).containsExactly(expectedFallbackPolicy1, expectedFallbackPolicy2);
   }
 
-  @SuppressWarnings("unchecked")
   @Test
   public void getFallbackPolicyFromXdsConfig_null() throws Exception {
     String lbConfig = "{\"xds_experimental\" : { "
@@ -114,9 +100,127 @@
         + "\"childPolicy\" : [{\"round_robin\" : {}}, {\"lbPolicy2\" : {\"key\" : \"val\"}}]"
         + "}}";
 
-    List<Map<String, Object>> fallbackPolicies = ServiceConfigUtil.getFallbackPolicyFromXdsConfig(
-        (Map<String, Object>) JsonParser.parse(lbConfig));
+    List<LbConfig> fallbackPolicies = ServiceConfigUtil.getFallbackPolicyFromXdsConfig(
+        ServiceConfigUtil.unwrapLoadBalancingConfig(JsonParser.parse(lbConfig)));
 
     assertThat(fallbackPolicies).isNull();
   }
+
+  @Test
+  public void unwrapLoadBalancingConfig() throws Exception {
+    String lbConfig = "{\"xds_experimental\" : { "
+        + "\"balancerName\" : \"dns:///balancer.example.com:8080\","
+        + "\"childPolicy\" : [{\"round_robin\" : {}}, {\"lbPolicy2\" : {\"key\" : \"val\"}}]"
+        + "}}";
+
+    LbConfig config = ServiceConfigUtil.unwrapLoadBalancingConfig(JsonParser.parse(lbConfig));
+    assertThat(config.getPolicyName()).isEqualTo("xds_experimental");
+    assertThat(config.getRawConfigValue()).isEqualTo(JsonParser.parse(
+            "{\"balancerName\" : \"dns:///balancer.example.com:8080\","
+            + "\"childPolicy\" : [{\"round_robin\" : {}}, {\"lbPolicy2\" : {\"key\" : \"val\"}}]"
+            + "}"));
+  }
+
+  @Test
+  public void unwrapLoadBalancingConfig_failOnTooManyFields() throws Exception {
+    // A LoadBalancingConfig should not have more than one field.
+    String lbConfig = "{\"xds_experimental\" : { "
+        + "\"balancerName\" : \"dns:///balancer.example.com:8080\","
+        + "\"childPolicy\" : [{\"round_robin\" : {}}, {\"lbPolicy2\" : {\"key\" : \"val\"}}]"
+        + "},"
+        + "\"grpclb\" : {} }";
+    try {
+      ServiceConfigUtil.unwrapLoadBalancingConfig(JsonParser.parse(lbConfig));
+      fail("Should throw");
+    } catch (Exception e) {
+      assertThat(e).hasMessageThat().contains("There are 2 fields");
+    }
+  }
+
+  @Test
+  public void unwrapLoadBalancingConfig_failOnEmptyObject() throws Exception {
+    // A LoadBalancingConfig should not exactly one field.
+    String lbConfig = "{}";
+    try {
+      ServiceConfigUtil.unwrapLoadBalancingConfig(JsonParser.parse(lbConfig));
+      fail("Should throw");
+    } catch (Exception e) {
+      assertThat(e).hasMessageThat().contains("There are 0 fields");
+    }
+  }
+
+  @Test
+  public void unwrapLoadBalancingConfig_failOnList() throws Exception {
+    // A LoadBalancingConfig must be a JSON dictionary (map)
+    String lbConfig = "[ { \"xds\" : {} } ]";
+    try {
+      ServiceConfigUtil.unwrapLoadBalancingConfig(JsonParser.parse(lbConfig));
+      fail("Should throw");
+    } catch (Exception e) {
+      assertThat(e).hasMessageThat().contains("Invalid type");
+    }
+  }
+
+  @Test
+  public void unwrapLoadBalancingConfig_failOnString() throws Exception {
+    // A LoadBalancingConfig must be a JSON dictionary (map)
+    String lbConfig = "\"xds\"";
+    try {
+      ServiceConfigUtil.unwrapLoadBalancingConfig(JsonParser.parse(lbConfig));
+      fail("Should throw");
+    } catch (Exception e) {
+      assertThat(e).hasMessageThat().contains("Invalid type");
+    }
+  }
+
+  @Test
+  public void unwrapLoadBalancingConfig_failWhenConfigIsString() throws Exception {
+    // The value of the config should be a JSON dictionary (map)
+    String lbConfig = "{ \"xds\" : \"I thought I was a config.\" }";
+    try {
+      ServiceConfigUtil.unwrapLoadBalancingConfig(JsonParser.parse(lbConfig));
+      fail("Should throw");
+    } catch (Exception e) {
+      assertThat(e).hasMessageThat().contains("Invalid value type");
+    }
+  }
+
+  @Test
+  public void unwrapLoadBalancingConfigList() throws Exception {
+    String lbConfig = "[ "
+        + "{\"xds_experimental\" : {\"balancerName\" : \"dns:///balancer.example.com:8080\"} },"
+        + "{\"grpclb\" : {} } ]";
+    List<LbConfig> configs =
+        ServiceConfigUtil.unwrapLoadBalancingConfigList(JsonParser.parse(lbConfig));
+    assertThat(configs).containsExactly(
+        ServiceConfigUtil.unwrapLoadBalancingConfig(JsonParser.parse(
+                "{\"xds_experimental\" : "
+                + "{\"balancerName\" : \"dns:///balancer.example.com:8080\"} }")),
+        ServiceConfigUtil.unwrapLoadBalancingConfig(JsonParser.parse(
+                "{\"grpclb\" : {} }"))).inOrder();
+  }
+
+  @Test
+  public void unwrapLoadBalancingConfigList_failOnObject() throws Exception {
+    String notAList = "{}";
+    try {
+      ServiceConfigUtil.unwrapLoadBalancingConfigList(JsonParser.parse(notAList));
+      fail("Should throw");
+    } catch (Exception e) {
+      assertThat(e).hasMessageThat().contains("List expected");
+    }
+  }
+
+  @Test
+  public void unwrapLoadBalancingConfigList_failOnMalformedConfig() throws Exception {
+    String lbConfig = "[ "
+        + "{\"xds_experimental\" : \"I thought I was a config\" },"
+        + "{\"grpclb\" : {} } ]";
+    try {
+      ServiceConfigUtil.unwrapLoadBalancingConfigList(JsonParser.parse(lbConfig));
+      fail("Should throw");
+    } catch (Exception e) {
+      assertThat(e).hasMessageThat().contains("Invalid value type");
+    }
+  }
 }
diff --git a/xds/src/main/java/io/grpc/xds/XdsLbState.java b/xds/src/main/java/io/grpc/xds/XdsLbState.java
index b6feb78..0098acd 100644
--- a/xds/src/main/java/io/grpc/xds/XdsLbState.java
+++ b/xds/src/main/java/io/grpc/xds/XdsLbState.java
@@ -26,10 +26,10 @@
 import io.grpc.LoadBalancer.Subchannel;
 import io.grpc.ManagedChannel;
 import io.grpc.Status;
+import io.grpc.internal.ServiceConfigUtil.LbConfig;
 import io.grpc.xds.XdsComms.AdsStreamCallback;
 import java.net.SocketAddress;
 import java.util.List;
-import java.util.Map;
 import java.util.concurrent.atomic.AtomicReference;
 import javax.annotation.Nullable;
 
@@ -54,7 +54,7 @@
   final String balancerName;
 
   @Nullable
-  final Map<String, Object> childPolicy;
+  final LbConfig childPolicy;
 
   private final SubchannelStore subchannelStore;
   private final Helper helper;
@@ -66,7 +66,7 @@
 
   XdsLbState(
       String balancerName,
-      @Nullable Map<String, Object> childPolicy,
+      @Nullable LbConfig childPolicy,
       @Nullable XdsComms xdsComms,
       Helper helper,
       SubchannelStore subchannelStore,
diff --git a/xds/src/main/java/io/grpc/xds/XdsLoadBalancer.java b/xds/src/main/java/io/grpc/xds/XdsLoadBalancer.java
index 46a0b13..9a8e648 100644
--- a/xds/src/main/java/io/grpc/xds/XdsLoadBalancer.java
+++ b/xds/src/main/java/io/grpc/xds/XdsLoadBalancer.java
@@ -19,7 +19,6 @@
 import static com.google.common.base.Preconditions.checkNotNull;
 import static io.grpc.ConnectivityState.IDLE;
 import static io.grpc.ConnectivityState.SHUTDOWN;
-import static io.grpc.internal.ServiceConfigUtil.getBalancerPolicyNameFromLoadBalancingConfig;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.ImmutableList;
@@ -33,6 +32,7 @@
 import io.grpc.Status;
 import io.grpc.SynchronizationContext.ScheduledHandle;
 import io.grpc.internal.ServiceConfigUtil;
+import io.grpc.internal.ServiceConfigUtil.LbConfig;
 import io.grpc.xds.XdsComms.AdsStreamCallback;
 import io.grpc.xds.XdsLbState.SubchannelStore;
 import java.util.List;
@@ -51,8 +51,8 @@
   static final Attributes.Key<AtomicReference<ConnectivityStateInfo>> STATE_INFO =
       Attributes.Key.create("io.grpc.xds.XdsLoadBalancer.stateInfo");
 
-  private static final ImmutableMap<String, Object> DEFAULT_FALLBACK_POLICY =
-      ImmutableMap.of("round_robin", (Object) ImmutableMap.<String, Object>of());
+  private static final LbConfig DEFAULT_FALLBACK_POLICY =
+      new LbConfig("round_robin", ImmutableMap.<String, Object>of());
 
   private final SubchannelStore subchannelStore;
   private final Helper helper;
@@ -77,7 +77,7 @@
   @Nullable
   private XdsLbState xdsLbState;
 
-  private Map<String, Object> fallbackPolicy;
+  private LbConfig fallbackPolicy;
 
   XdsLoadBalancer(Helper helper, LoadBalancerRegistry lbRegistry, SubchannelStore subchannelStore) {
     this.helper = checkNotNull(helper, "helper");
@@ -89,8 +89,9 @@
   @Override
   public void handleResolvedAddressGroups(
       List<EquivalentAddressGroup> servers, Attributes attributes) {
-    Map<String, Object> newLbConfig = checkNotNull(
+    Map<String, Object> newRawLbConfig = checkNotNull(
         attributes.get(ATTR_LOAD_BALANCING_CONFIG), "ATTR_LOAD_BALANCING_CONFIG not available");
+    LbConfig newLbConfig = ServiceConfigUtil.unwrapLoadBalancingConfig(newRawLbConfig);
     fallbackPolicy = selectFallbackPolicy(newLbConfig, lbRegistry);
     fallbackManager.updateFallbackServers(servers, attributes, fallbackPolicy);
     fallbackManager.maybeStartFallbackTimer();
@@ -98,9 +99,9 @@
     xdsLbState.handleResolvedAddressGroups(servers, attributes);
   }
 
-  private void handleNewConfig(Map<String, Object> newLbConfig) {
+  private void handleNewConfig(LbConfig newLbConfig) {
     String newBalancerName = ServiceConfigUtil.getBalancerNameFromXdsConfig(newLbConfig);
-    Map<String, Object> childPolicy = selectChildPolicy(newLbConfig, lbRegistry);
+    LbConfig childPolicy = selectChildPolicy(newLbConfig, lbRegistry);
     XdsComms xdsComms = null;
     if (xdsLbState != null) { // may release and re-use/shutdown xdsComms from current xdsLbState
       if (!newBalancerName.equals(xdsLbState.balancerName)) {
@@ -130,43 +131,37 @@
   }
 
   @Nullable
-  private static String getPolicyNameOrNull(@Nullable Map<String, Object> config) {
+  private static String getPolicyNameOrNull(@Nullable LbConfig config) {
     if (config == null) {
       return null;
     }
-    return getBalancerPolicyNameFromLoadBalancingConfig(config);
+    return config.getPolicyName();
   }
 
   @Nullable
   @VisibleForTesting
-  static Map<String, Object> selectChildPolicy(
-      Map<String, Object> lbConfig, LoadBalancerRegistry lbRegistry) {
-    List<Map<String, Object>> childConfigs =
-        ServiceConfigUtil.getChildPolicyFromXdsConfig(lbConfig);
+  static LbConfig selectChildPolicy(LbConfig lbConfig, LoadBalancerRegistry lbRegistry) {
+    List<LbConfig> childConfigs = ServiceConfigUtil.getChildPolicyFromXdsConfig(lbConfig);
     return selectSupportedLbPolicy(childConfigs, lbRegistry);
   }
 
   @VisibleForTesting
-  static Map<String, Object> selectFallbackPolicy(
-      Map<String, Object> lbConfig, LoadBalancerRegistry lbRegistry) {
-    List<Map<String, Object>> fallbackConfigs =
-        ServiceConfigUtil.getFallbackPolicyFromXdsConfig(lbConfig);
-    Map<String, Object> fallbackPolicy = selectSupportedLbPolicy(fallbackConfigs, lbRegistry);
+  static LbConfig selectFallbackPolicy(LbConfig lbConfig, LoadBalancerRegistry lbRegistry) {
+    List<LbConfig> fallbackConfigs = ServiceConfigUtil.getFallbackPolicyFromXdsConfig(lbConfig);
+    LbConfig fallbackPolicy = selectSupportedLbPolicy(fallbackConfigs, lbRegistry);
     return fallbackPolicy == null ? DEFAULT_FALLBACK_POLICY : fallbackPolicy;
   }
 
   @Nullable
-  private static Map<String, Object> selectSupportedLbPolicy(
-      List<Map<String, Object>> lbConfigs, LoadBalancerRegistry lbRegistry) {
+  private static LbConfig selectSupportedLbPolicy(
+      @Nullable List<LbConfig> lbConfigs, LoadBalancerRegistry lbRegistry) {
     if (lbConfigs == null) {
       return null;
     }
-    for (Object lbConfig : lbConfigs) {
-      @SuppressWarnings("unchecked")
-      Map<String, Object> candidate = (Map<String, Object>) lbConfig;
-      String lbPolicy = ServiceConfigUtil.getBalancerPolicyNameFromLoadBalancingConfig(candidate);
+    for (LbConfig lbConfig : lbConfigs) {
+      String lbPolicy = lbConfig.getPolicyName();
       if (lbRegistry.getProvider(lbPolicy) != null) {
-        return candidate;
+        return lbConfig;
       }
     }
     return null;
@@ -239,7 +234,7 @@
     private final SubchannelStore subchannelStore;
     private final LoadBalancerRegistry lbRegistry;
 
-    private Map<String, Object> fallbackPolicy;
+    private LbConfig fallbackPolicy;
 
     // read-only for outer class
     private LoadBalancer fallbackBalancer;
@@ -281,9 +276,7 @@
 
       helper.getChannelLogger().log(
           ChannelLogLevel.INFO, "Using fallback policy");
-      String fallbackPolicyName = ServiceConfigUtil.getBalancerPolicyNameFromLoadBalancingConfig(
-          fallbackPolicy);
-      fallbackBalancer = lbRegistry.getProvider(fallbackPolicyName)
+      fallbackBalancer = lbRegistry.getProvider(fallbackPolicy.getPolicyName())
           .newLoadBalancer(helper);
       fallbackBalancer.handleResolvedAddressGroups(fallbackServers, fallbackAttributes);
       // TODO: maybe update picker
@@ -291,20 +284,16 @@
 
     void updateFallbackServers(
         List<EquivalentAddressGroup> servers, Attributes attributes,
-        Map<String, Object> fallbackPolicy) {
+        LbConfig fallbackPolicy) {
       this.fallbackServers = servers;
       this.fallbackAttributes = Attributes.newBuilder()
           .setAll(attributes)
-          .set(ATTR_LOAD_BALANCING_CONFIG, fallbackPolicy)
+          .set(ATTR_LOAD_BALANCING_CONFIG, fallbackPolicy.getRawConfigValue())
           .build();
-      Map<String, Object> currentFallbackPolicy = this.fallbackPolicy;
+      LbConfig currentFallbackPolicy = this.fallbackPolicy;
       this.fallbackPolicy = fallbackPolicy;
       if (fallbackBalancer != null) {
-        String currentPolicyName =
-            ServiceConfigUtil.getBalancerPolicyNameFromLoadBalancingConfig(currentFallbackPolicy);
-        String newPolicyName =
-            ServiceConfigUtil.getBalancerPolicyNameFromLoadBalancingConfig(fallbackPolicy);
-        if (newPolicyName.equals(currentPolicyName)) {
+        if (fallbackPolicy.getPolicyName().equals(currentFallbackPolicy.getPolicyName())) {
           fallbackBalancer.handleResolvedAddressGroups(fallbackServers, fallbackAttributes);
         } else {
           fallbackBalancer.shutdown();
diff --git a/xds/src/test/java/io/grpc/xds/FallbackManagerTest.java b/xds/src/test/java/io/grpc/xds/FallbackManagerTest.java
index 5520b0d..f0456b5 100644
--- a/xds/src/test/java/io/grpc/xds/FallbackManagerTest.java
+++ b/xds/src/test/java/io/grpc/xds/FallbackManagerTest.java
@@ -32,12 +32,12 @@
 import io.grpc.LoadBalancerRegistry;
 import io.grpc.SynchronizationContext;
 import io.grpc.internal.FakeClock;
+import io.grpc.internal.ServiceConfigUtil.LbConfig;
 import io.grpc.xds.XdsLbState.SubchannelStoreImpl;
 import io.grpc.xds.XdsLoadBalancer.FallbackManager;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
-import java.util.Map;
 import java.util.concurrent.TimeUnit;
 import org.junit.After;
 import org.junit.Before;
@@ -97,7 +97,7 @@
   private ChannelLogger channelLogger;
 
   private FallbackManager fallbackManager;
-  private Map<String, Object> fallbackPolicy;
+  private LbConfig fallbackPolicy;
 
   @Before
   public void setUp() {
@@ -106,8 +106,7 @@
     doReturn(fakeClock.getScheduledExecutorService()).when(helper).getScheduledExecutorService();
     doReturn(channelLogger).when(helper).getChannelLogger();
     fallbackManager = new FallbackManager(helper, new SubchannelStoreImpl(), lbRegistry);
-    fallbackPolicy = new HashMap<>();
-    fallbackPolicy.put("test_policy", new HashMap<>());
+    fallbackPolicy = new LbConfig("test_policy", new HashMap<String, Object>());
     lbRegistry.register(fakeLbProvider);
   }
 
@@ -131,7 +130,7 @@
     verify(fakeLb).handleResolvedAddressGroups(
         same(eags),
         eq(Attributes.newBuilder()
-            .set(LoadBalancer.ATTR_LOAD_BALANCING_CONFIG, fallbackPolicy)
+            .set(LoadBalancer.ATTR_LOAD_BALANCING_CONFIG, fallbackPolicy.getRawConfigValue())
             .build()));
   }
 
diff --git a/xds/src/test/java/io/grpc/xds/XdsLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/XdsLoadBalancerTest.java
index cfb2494..756ea8d 100644
--- a/xds/src/test/java/io/grpc/xds/XdsLoadBalancerTest.java
+++ b/xds/src/test/java/io/grpc/xds/XdsLoadBalancerTest.java
@@ -52,13 +52,14 @@
 import io.grpc.inprocess.InProcessServerBuilder;
 import io.grpc.internal.FakeClock;
 import io.grpc.internal.JsonParser;
+import io.grpc.internal.ServiceConfigUtil;
+import io.grpc.internal.ServiceConfigUtil.LbConfig;
 import io.grpc.internal.testing.StreamRecorder;
 import io.grpc.stub.StreamObserver;
 import io.grpc.testing.GrpcCleanupRule;
 import io.grpc.xds.XdsLbState.SubchannelStore;
 import io.grpc.xds.XdsLbState.SubchannelStoreImpl;
 import java.util.Collections;
-import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.TimeUnit;
@@ -254,13 +255,13 @@
         + "{\"supported_2\" : {\"key\" : \"val\"}}],"
         + "\"fallbackPolicy\" : [{\"lbPolicy3\" : {\"key\" : \"val\"}}, {\"lbPolicy4\" : {}}]"
         + "}}";
-    @SuppressWarnings("unchecked")
-    Map<String, Object> expectedChildPolicy = (Map<String, Object>) JsonParser.parse(
-        "{\"supported_1\" : {\"key\" : \"val\"}}");
+    LbConfig expectedChildPolicy =
+        ServiceConfigUtil.unwrapLoadBalancingConfig(
+            JsonParser.parse("{\"supported_1\" : {\"key\" : \"val\"}}"));
 
-    @SuppressWarnings("unchecked")
-    Map<String, Object> childPolicy = XdsLoadBalancer
-        .selectChildPolicy((Map<String, Object>) JsonParser.parse(lbConfigRaw), lbRegistry);
+    LbConfig childPolicy = XdsLoadBalancer
+        .selectChildPolicy(
+            ServiceConfigUtil.unwrapLoadBalancingConfig(JsonParser.parse(lbConfigRaw)), lbRegistry);
 
     assertEquals(expectedChildPolicy, childPolicy);
   }
@@ -273,13 +274,11 @@
         + "\"fallbackPolicy\" : [{\"unsupported\" : {}}, {\"supported_1\" : {\"key\" : \"val\"}},"
         + "{\"supported_2\" : {\"key\" : \"val\"}}]"
         + "}}";
-    @SuppressWarnings("unchecked")
-    Map<String, Object> expectedFallbackPolicy = (Map<String, Object>) JsonParser.parse(
-        "{\"supported_1\" : {\"key\" : \"val\"}}");
+    LbConfig expectedFallbackPolicy = ServiceConfigUtil.unwrapLoadBalancingConfig(
+        JsonParser.parse("{\"supported_1\" : {\"key\" : \"val\"}}"));
 
-    @SuppressWarnings("unchecked")
-    Map<String, Object> fallbackPolicy = XdsLoadBalancer
-        .selectFallbackPolicy((Map<String, Object>) JsonParser.parse(lbConfigRaw), lbRegistry);
+    LbConfig fallbackPolicy = XdsLoadBalancer.selectFallbackPolicy(
+        ServiceConfigUtil.unwrapLoadBalancingConfig(JsonParser.parse(lbConfigRaw)), lbRegistry);
 
     assertEquals(expectedFallbackPolicy, fallbackPolicy);
   }
@@ -290,13 +289,11 @@
         + "\"balancerName\" : \"dns:///balancer.example.com:8080\","
         + "\"childPolicy\" : [{\"lbPolicy3\" : {\"key\" : \"val\"}}, {\"lbPolicy4\" : {}}]"
         + "}}";
-    @SuppressWarnings("unchecked")
-    Map<String, Object> expectedFallbackPolicy = (Map<String, Object>) JsonParser.parse(
-        "{\"round_robin\" : {}}");
+    LbConfig expectedFallbackPolicy = ServiceConfigUtil.unwrapLoadBalancingConfig(
+        JsonParser.parse("{\"round_robin\" : {}}"));
 
-    @SuppressWarnings("unchecked")
-    Map<String, Object> fallbackPolicy = XdsLoadBalancer
-        .selectFallbackPolicy((Map<String, Object>) JsonParser.parse(lbConfigRaw), lbRegistry);
+    LbConfig fallbackPolicy = XdsLoadBalancer.selectFallbackPolicy(
+        ServiceConfigUtil.unwrapLoadBalancingConfig(JsonParser.parse(lbConfigRaw)), lbRegistry);
 
     assertEquals(expectedFallbackPolicy, fallbackPolicy);
   }
@@ -508,7 +505,7 @@
     verify(fakeBalancer1).handleResolvedAddressGroups(
         Matchers.<List<EquivalentAddressGroup>>any(), captor.capture());
     assertThat(captor.getValue().get(ATTR_LOAD_BALANCING_CONFIG))
-        .containsExactly("supported_1", new HashMap<String, Object>());
+        .containsExactly("supported_1_option", "yes");
   }
 
   @Test
@@ -534,7 +531,7 @@
     verify(fakeBalancer1).handleResolvedAddressGroups(
         Matchers.<List<EquivalentAddressGroup>>any(), captor.capture());
     assertThat(captor.getValue().get(ATTR_LOAD_BALANCING_CONFIG))
-        .containsExactly("supported_1", new HashMap<String, Object>());
+        .containsExactly("supported_1_option", "yes");
 
     assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(1);
     assertThat(fakeClock.getPendingTasks()).isEmpty();
@@ -582,13 +579,13 @@
     verify(fakeBalancer1).handleResolvedAddressGroups(
         Matchers.<List<EquivalentAddressGroup>>any(), captor.capture());
     assertThat(captor.getValue().get(ATTR_LOAD_BALANCING_CONFIG))
-        .containsExactly("supported_1", new HashMap<String, Object>());
+        .containsExactly("supported_1_option", "yes");
   }
 
   private static Attributes standardModeWithFallback1Attributes() throws Exception {
     String lbConfigRaw = "{\"xds_experimental\" : { "
         + "\"balancerName\" : \"dns:///balancer.example.com:8080\","
-        + "\"fallbackPolicy\" : [{\"supported_1\" : {}}]"
+        + "\"fallbackPolicy\" : [{\"supported_1\" : { \"supported_1_option\" : \"yes\"}}]"
         + "}}";
     @SuppressWarnings("unchecked")
     Map<String, Object> lbConfig = (Map<String, Object>) JsonParser.parse(lbConfigRaw);