Enable polymorphic output shape for collective all-gather (the output shape may differ between executions of the same op)
diff --git a/tensorflow/core/kernels/collective_ops.cc b/tensorflow/core/kernels/collective_ops.cc
index 49bdbbf..9637294 100644
--- a/tensorflow/core/kernels/collective_ops.cc
+++ b/tensorflow/core/kernels/collective_ops.cc
@@ -102,16 +102,7 @@
auto output_shape = c->input(0).shape();
output_shape.set_dim(
0, output_shape.dim_size(0) * col_params_.group.group_size);
- if (col_params_.instance.shape.num_elements() == 0) {
- col_params_.instance.shape = output_shape;
- } else {
- OP_REQUIRES_ASYNC(
- c, col_params_.instance.shape == output_shape,
- errors::Internal("Inconsistent output shapes, got ",
- output_shape.DebugString(), ", but expected is ",
- col_params_.instance.shape.DebugString(), "."),
- done);
- }
+ col_params_.instance.shape = output_shape;
// Allocate output on the first pass through this function. This must be
// done immediately, while we're still in the executor thread. Otherwise
diff --git a/tensorflow/python/ops/collective_ops_test.py b/tensorflow/python/ops/collective_ops_test.py
index 0e9b549..a29419d 100644
--- a/tensorflow/python/ops/collective_ops_test.py
+++ b/tensorflow/python/ops/collective_ops_test.py
@@ -33,6 +33,7 @@
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
+from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
@@ -346,6 +347,32 @@
'Shape mismatch'):
sess.run([c0, c1], options=run_options)
+ @test_util.run_deprecated_v1
+ def testCollectiveGatherPolymophicShape(self):  # all_gather on inputs whose length differs between runs
+ t0 = [0, 1, 2, 3, 4, 5, 6, 7]
+ t1 = [10, 11, 12, 13, 14, 15, 16, 17]
+ t01 = [0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13, 14, 15, 16, 17]  # t0 followed by t1
+ t01_ = [1, 2, 3, 4, 5, 6, 7, 11, 12, 13, 14, 15, 16, 17]  # t0[1:] followed by t1[1:]
+ group_size = 2  # also used as the CPU device count in the session config
+ group_key = 1
+ instance_key = 123
+ with self.session(
+ config=config_pb2.ConfigProto(device_count={'CPU': group_size})) as sess:
+ with ops.device('/CPU:0'):
+ in0 = array_ops.placeholder(dtype=dtypes.int32, shape=[None])  # length left unknown at graph-build time
+ c0 = collective_ops.all_gather(in0, group_size, group_key, instance_key)
+ with ops.device('/CPU:1'):
+ in1 = array_ops.placeholder(dtype=dtypes.int32, shape=[None])  # same keys as c0, so both ops share group/instance
+ c1 = collective_ops.all_gather(in1, group_size, group_key, instance_key)
+
+ results = sess.run([c0, c1], feed_dict={in0: t0, in1: t1})  # first run: 8 elements fed per input
+ self.assertAllClose(results[0], t01, rtol=1e-5, atol=1e-5)
+ self.assertAllClose(results[1], t01, rtol=1e-5, atol=1e-5)  # both devices are expected to hold the same gathered value
+
+ results_ = sess.run([c0, c1], feed_dict={in0: t0[1:], in1: t1[1:]})  # second run, same ops: 7 elements fed per input
+ self.assertAllClose(results_[0], t01_, rtol=1e-5, atol=1e-5)
+ self.assertAllClose(results_[1], t01_, rtol=1e-5, atol=1e-5)
+
@test_util.run_v2_only
def testCollectiveGroupSizeMismatch(self):
cpus = config.list_physical_devices('CPU')