Add `ShardedVariable` class.

PiperOrigin-RevId: 272745815
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 185495b..60c6ae8 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -1,6 +1,5 @@
 load("//tensorflow:tensorflow.bzl", "py_test", "tf_py_test")
 load("//tensorflow:tensorflow.bzl", "cuda_py_test")
-load("//tensorflow/compiler/tests:build_defs.bzl", "tf_xla_py_test")
 load("//tensorflow/core/platform:default/distribute.bzl", "distribute_py_test")
 
 package(
@@ -132,6 +131,7 @@
         ":distribute_lib",
         ":mirrored_strategy",
         ":one_device_strategy",
+        ":sharded_variable",
         "//tensorflow/python/distribute/experimental",
     ],
 )
@@ -779,6 +779,32 @@
 )
 
 py_library(
+    name = "sharded_variable",
+    srcs = ["sharded_variable.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/python:tensor_shape",
+        "//tensorflow/python:variables",
+        "//tensorflow/python/training/saving:saveable_object_util",
+        "//tensorflow/python/training/tracking:base",
+    ],
+)
+
+tf_py_test(
+    name = "sharded_variable_test",
+    size = "small",
+    srcs = ["sharded_variable_test.py"],
+    additional_deps = [
+        ":sharded_variable",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:variables",
+        "//tensorflow/python/compat:v2_compat",
+        "//tensorflow/python/training/tracking:util",
+    ],
+)
+
+py_library(
     name = "strategy_test_lib",
     srcs = ["strategy_test_lib.py"],
     srcs_version = "PY2AND3",
diff --git a/tensorflow/python/distribute/sharded_variable.py b/tensorflow/python/distribute/sharded_variable.py
new file mode 100644
index 0000000..9886e42
--- /dev/null
+++ b/tensorflow/python/distribute/sharded_variable.py
@@ -0,0 +1,139 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""ShardedVariable class."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import variables as variables_lib
+from tensorflow.python.training.saving import saveable_object_util
+from tensorflow.python.training.tracking import base as trackable
+
+
+class ShardedVariable(trackable.Trackable):
+  """A container for `Variables` that should be treated as shards.
+
+  Variables that are too large to fit on a single device (e.g., large
+  embeddings)
+  may need to be sharded over multiple devices. This class maintains a list of
+  smaller variables that can be independently stored on separate devices (eg,
+  multiple parameter servers), and saves and restores those variables as if they
+  were a single larger variable.
+
+  Objects of this class can be saved with a given number of shards and then
+  restored from a checkpoint into a different number of shards.
+
+  Sharding is only supported along the first dimension.
+  """
+
+  def __init__(self, variables, name='ShardedVariable'):
+    """Treats `variables` as shards of a larger Variable.
+
+
+    Example:
+
+    ```
+    variables = [
+      tf.Variable(..., shape=(10, 100), dtype=tf.float32),
+      tf.Variable(..., shape=(15, 100), dtype=tf.float32),
+      tf.Variable(..., shape=(5, 100), dtype=tf.float32)
+    ]
+    sharded_variable = ShardedVariable(variables)
+    assert sharded_variable.shape.as_list() == [30, 100]
+    ```
+
+    Args:
+      variables: A list of `ResourceVariable`s that comprise this sharded
+        variable. Variables should not be shared between different
+        `ShardedVariable` objects.
+      name: String. Name of this container. Defaults to "ShardedVariable".
+    """
+    super(ShardedVariable, self).__init__()
+    self._variables = variables
+    self._name = name
+
+    first_var = variables[0]
+
+    if any(not isinstance(v, variables_lib.Variable) for v in variables):
+      raise ValueError(
+          'Expected a list of `Variable`s, found: {}'.format(variables))
+
+    dtypes = {v.dtype for v in variables}
+    if len(dtypes) > 1:
+      raise ValueError(
+          'All `Variable`s must have the same dtype, found: {}'.format(
+              [v.dtype for v in variables]))
+    self._dtype = first_var.dtype
+
+    # All variables must have the same shape for axes > 0.
+    higher_dim_shapes = {tuple(v.shape.as_list()[1:]) for v in variables}
+    if len(higher_dim_shapes) > 1:
+      raise ValueError(
+          'All `Variables`s must have the same shapes except for the first '
+          'axis, found {}'.format([v.shape for v in variables]))
+    first_dim = sum(int(v.shape[0]) for v in variables)
+    self._shape = tensor_shape.TensorShape([first_dim] + first_var.shape[1:])
+
+    save_slice_info = [v._get_save_slice_info() for v in variables]  # pylint: disable=protected-access
+    if any(slice_info is not None for slice_info in save_slice_info):
+      raise ValueError('`SaveSliceInfo` should not be set for `Variable`s. '
+                       '`ShardedVariable` will infer `SaveSliceInfo` according '
+                       'to the order of the `Variable`s in the list passed to '
+                       'the constructor. Found {}'.format(save_slice_info))
+
+  @property
+  def variables(self):
+    """The list of `Variable`s that make up the shards of this object."""
+    return self._variables
+
+  @property
+  def name(self):
+    """The name of this object. Used for checkpointing."""
+    return self._name
+
+  @property
+  def dtype(self):
+    """The dtype of all `Variable`s in this object."""
+    return self._dtype
+
+  @property
+  def shape(self):
+    """The overall shape, combining all shards along axis `0`."""
+    return self._shape
+
+  def _gather_saveables_for_checkpoint(self):
+    """Return a `Saveable` for each shard. See `Trackable`."""
+
+    def _saveable_factory(name=self.name):
+      """Creates `SaveableObject`s for this `ShardedVariable`."""
+      saveables = []
+      dims = len(self._variables[0].shape)
+      var_offset = [0 for _ in range(dims)]
+      for v in self._variables:
+        save_slice_info = variables_lib.Variable.SaveSliceInfo(
+            full_name=self.name,
+            full_shape=self.shape.as_list(),
+            var_offset=copy.copy(var_offset),
+            var_shape=v.shape.as_list())
+        saveables.append(
+            saveable_object_util.ResourceVariableSaveable(
+                v, save_slice_info.spec, name))  # pylint: disable=protected-access
+        var_offset[0] += int(v.shape[0])
+      return saveables
+
+    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}
diff --git a/tensorflow/python/distribute/sharded_variable_test.py b/tensorflow/python/distribute/sharded_variable_test.py
new file mode 100644
index 0000000..7110a9f
--- /dev/null
+++ b/tensorflow/python/distribute/sharded_variable_test.py
@@ -0,0 +1,146 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for ShardedVariable."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.python.compat import v2_compat
+from tensorflow.python.distribute import sharded_variable
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variables as variables_lib
+from tensorflow.python.platform import test
+from tensorflow.python.training.tracking import util
+
+
+class ShardedVariableTest(test.TestCase):
+
+  def test_sharded_variable_simple(self):
+    v0 = variables_lib.Variable([0])
+    v1 = variables_lib.Variable([1])
+    s = sharded_variable.ShardedVariable([v0, v1], name='s')
+    self.assertEqual(s.variables[0], v0)
+    self.assertEqual(s.variables[1], v1)
+    self.assertEqual(s.shape.as_list(), [2])
+    self.assertEqual(s.dtype, v0.dtype)
+    self.assertEqual(s.name, 's')
+
+  def test_save_restore(self):
+    fname = os.path.join(self.get_temp_dir(), 'checkpoint')
+    variables = [
+        variables_lib.Variable([0]),
+        variables_lib.Variable([1]),
+        variables_lib.Variable([2]),
+        variables_lib.Variable([3])
+    ]
+    s = sharded_variable.ShardedVariable(variables, name='s')
+
+    cp = util.Checkpoint(s=s)
+    self.assertEqual(self.evaluate(cp.s.variables[0]), [0])
+    cp.write(fname)
+
+    self.evaluate(cp.s.variables[0].assign([4]))
+    self.assertEqual(self.evaluate(cp.s.variables[0]), [4])
+
+    cp.restore(fname)
+    # Tests that the original weights are restored.
+    self.assertEqual(self.evaluate(cp.s.variables[0]), [0])
+
+  def test_save_restore_different_partitions(self):
+    fname = os.path.join(self.get_temp_dir(), 'checkpoint')
+    variables = [
+        variables_lib.Variable([0]),
+        variables_lib.Variable([1]),
+        variables_lib.Variable([2]),
+        variables_lib.Variable([3])
+    ]
+    s = sharded_variable.ShardedVariable(variables, name='s')
+
+    cp = util.Checkpoint(s=s)
+    cp.write(fname)
+
+    variables2 = [variables_lib.Variable([0, 0, 0, 0])]
+    s2 = sharded_variable.ShardedVariable(variables2, name='s')
+
+    # Restore from 4 partitions into 1.
+    cp2 = util.Checkpoint(s=s2)
+    cp2.restore(fname)
+    self.assertAllEqual(self.evaluate(cp2.s.variables[0]), [0, 1, 2, 3])
+
+    self.evaluate(cp2.s.variables[0].assign([5, 10, 15, 20]))
+    cp2.write(fname)
+
+    # Restore 1 partition into 4.
+    cp.restore(fname)
+    self.assertEqual(self.evaluate(cp.s.variables[0]), [5])
+    self.assertEqual(self.evaluate(cp.s.variables[1]), [10])
+    self.assertEqual(self.evaluate(cp.s.variables[2]), [15])
+    self.assertEqual(self.evaluate(cp.s.variables[3]), [20])
+
+  def test_save_restore_4_to_2_partitions(self):
+    fname = os.path.join(self.get_temp_dir(), 'checkpoint')
+    variables = [
+        variables_lib.Variable([0]),
+        variables_lib.Variable([1]),
+        variables_lib.Variable([2]),
+        variables_lib.Variable([3])
+    ]
+    s = sharded_variable.ShardedVariable(variables, name='s')
+    cp = util.Checkpoint(s=s)
+    cp.write(fname)
+
+    variables2 = [
+        variables_lib.Variable([0, 0]),
+        variables_lib.Variable([0, 0])
+    ]
+    s2 = sharded_variable.ShardedVariable(variables2, name='s')
+    cp2 = util.Checkpoint(s=s2)
+    cp2.restore(fname)
+    # Assert that weights from the 4 partitions were loaded here.
+    self.assertLen(cp2.s.variables, 2)
+    self.assertAllEqual(self.evaluate(cp2.s.variables[0]), [0, 1])
+    self.assertAllEqual(self.evaluate(cp2.s.variables[1]), [2, 3])
+
+  def test_validation_errors(self):
+    with self.assertRaisesRegexp(ValueError, 'Expected a list of '):
+      sharded_variable.ShardedVariable(
+          [variables_lib.Variable([0]), 'not-a-variable'])
+
+    with self.assertRaisesRegexp(ValueError, 'must have the same dtype'):
+      sharded_variable.ShardedVariable([
+          variables_lib.Variable([0], dtype='int64'),
+          variables_lib.Variable([1], dtype='int32')
+      ])
+
+    with self.assertRaisesRegexp(ValueError, 'the same shapes except'):
+      sharded_variable.ShardedVariable([
+          variables_lib.Variable(array_ops.ones((5, 10))),
+          variables_lib.Variable(array_ops.ones((5, 20)))
+      ])
+
+    with self.assertRaisesRegexp(ValueError, '`SaveSliceInfo` should not'):
+      v = variables_lib.Variable([0])
+      v._set_save_slice_info(
+          variables_lib.Variable.SaveSliceInfo(
+              full_name='s', full_shape=[2], var_offset=[0], var_shape=[1]))
+      sharded_variable.ShardedVariable([v])
+
+
+if __name__ == '__main__':
+  v2_compat.enable_v2_behavior()
+  test.main()
diff --git a/tensorflow/python/training/saving/saveable_object_util.py b/tensorflow/python/training/saving/saveable_object_util.py
index 099fcf0..f4c5ee7 100644
--- a/tensorflow/python/training/saving/saveable_object_util.py
+++ b/tensorflow/python/training/saving/saveable_object_util.py
@@ -28,6 +28,7 @@
 from tensorflow.python.ops import variables
 from tensorflow.python.training.saving import saveable_object
 from tensorflow.python.training.tracking import base as trackable
+from tensorflow.python.util import nest
 from tensorflow.python.util import object_identity
 
 
@@ -147,6 +148,9 @@
     slice_name = None
     # pylint: disable=protected-access
     for variable in op:
+      if isinstance(variable, saveable_object.SaveableObject):
+        yield variable
+        continue
       if not isinstance(variable, variables.Variable):
         raise ValueError("Slices must all be Variables: %s" % variable)
       if not variable._save_slice_info:
@@ -210,7 +214,7 @@
   """Create a dictionary of names to operation lists.
 
   Args:
-    op_list: A list, tuple, or set of Variables or SaveableObjects.
+    op_list: A (nested) list, tuple, or set of Variables or SaveableObjects.
     convert_variable_to_tensor: Whether or not to convert single Variables
       with no slice info into Tensors.
 
@@ -226,6 +230,8 @@
   if not isinstance(op_list, (list, tuple, set)):
     raise TypeError("Variables to save should be passed in a dict or a "
                     "list: %s" % op_list)
+  # List casting is necessary to support sets.
+  op_list = nest.flatten(list(op_list))
   # When ResourceVariables are converted to Tensors, read ops are added to the
   # graph. Sorting the op_list ensures that the resulting graph is always
   # constructed in a deterministic way: