enable polymorphic output shape for collective all gather
diff --git a/tensorflow/core/kernels/collective_ops.cc b/tensorflow/core/kernels/collective_ops.cc
index 49bdbbf..9637294 100644
--- a/tensorflow/core/kernels/collective_ops.cc
+++ b/tensorflow/core/kernels/collective_ops.cc
@@ -102,16 +102,7 @@
     auto output_shape = c->input(0).shape();
     output_shape.set_dim(
         0, output_shape.dim_size(0) * col_params_.group.group_size);
-    if (col_params_.instance.shape.num_elements() == 0) {
-      col_params_.instance.shape = output_shape;
-    } else {
-      OP_REQUIRES_ASYNC(
-          c, col_params_.instance.shape == output_shape,
-          errors::Internal("Inconsistent output shapes, got ",
-                           output_shape.DebugString(), ", but expected is ",
-                           col_params_.instance.shape.DebugString(), "."),
-          done);
-    }
+    col_params_.instance.shape = output_shape;
 
     // Allocate output on the first pass through this function.  This must be
     // done immediately, while we're still in the executor thread.  Otherwise
diff --git a/tensorflow/python/ops/collective_ops_test.py b/tensorflow/python/ops/collective_ops_test.py
index 0e9b549..a29419d 100644
--- a/tensorflow/python/ops/collective_ops_test.py
+++ b/tensorflow/python/ops/collective_ops_test.py
@@ -33,6 +33,7 @@
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variables
+from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import test
 from tensorflow.python.platform import tf_logging as logging
 
@@ -346,6 +347,32 @@
                                    'Shape mismatch'):
         sess.run([c0, c1], options=run_options)
 
+  @test_util.run_deprecated_v1
+  def testCollectiveGatherPolymophicShape(self):
+      t0 = [0, 1, 2, 3, 4, 5, 6, 7]
+      t1 = [10, 11, 12, 13, 14, 15, 16, 17]
+      t01 = [0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13, 14, 15, 16, 17]
+      t01_ = [1, 2, 3, 4, 5, 6, 7, 11, 12, 13, 14, 15, 16, 17]
+      group_size = 2
+      group_key = 1
+      instance_key = 123
+      with self.session(
+              config=config_pb2.ConfigProto(device_count={'CPU': group_size})) as sess:
+          with ops.device('/CPU:0'):
+              in0 = array_ops.placeholder(dtype=dtypes.int32, shape=[None])
+              c0 = collective_ops.all_gather(in0, group_size, group_key, instance_key)
+          with ops.device('/CPU:1'):
+              in1 = array_ops.placeholder(dtype=dtypes.int32, shape=[None])
+              c1 = collective_ops.all_gather(in1, group_size, group_key, instance_key)
+
+          results = sess.run([c0, c1], feed_dict={in0: t0, in1: t1})
+          self.assertAllClose(results[0], t01, rtol=1e-5, atol=1e-5)
+          self.assertAllClose(results[1], t01, rtol=1e-5, atol=1e-5)
+
+          results_ = sess.run([c0, c1], feed_dict={in0: t0[1:], in1: t1[1:]})
+          self.assertAllClose(results_[0], t01_, rtol=1e-5, atol=1e-5)
+          self.assertAllClose(results_[1], t01_, rtol=1e-5, atol=1e-5)
+
   @test_util.run_v2_only
   def testCollectiveGroupSizeMismatch(self):
     cpus = config.list_physical_devices('CPU')