Support non-constant start/end for linspace.

PiperOrigin-RevId: 276777257
Change-Id: Iba2506aa522add21f2cb975fd1710fc84dc81e7e
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index c8143c0..e306dc9 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -1011,6 +1011,7 @@
         "//tensorflow/python:framework",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:platform_test",
+        "@absl_py//absl/testing:parameterized",
     ],
 )
 
diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py
index 200851e..d81a06a 100644
--- a/tensorflow/compiler/tests/ternary_ops_test.py
+++ b/tensorflow/compiler/tests/ternary_ops_test.py
@@ -18,6 +18,7 @@
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.compiler.tests import xla_test
@@ -28,7 +29,7 @@
 from tensorflow.python.platform import googletest
 
 
-class TernaryOpsTest(xla_test.XLATestCase):
+class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase):
 
   def _testTernary(self, op, a, b, c, expected):
     with self.session() as session:
@@ -39,20 +40,24 @@
         output = op(pa, pb, pc)
       result = session.run(output, {pa: a, pb: b, pc: c})
       self.assertAllClose(result, expected, rtol=1e-3)
+      return result
 
-  def testLinspace(self):
-    self._testTernary(
+  @parameterized.parameters(
+      {'start': 1, 'end': 2, 'num': 1},
+      {'start': 1, 'end': 4, 'num': 3},
+      {'start': 0, 'end': 41, 'num': 42})
+  def testLinspace(self, start, end, num):
+    expected = np.linspace(start, end, num, dtype=np.float32)
+    result = self._testTernary(
         math_ops.linspace,
-        np.float32(1),
-        np.float32(2),
-        np.int32(1),
-        expected=np.array([1], dtype=np.float32))
-    self._testTernary(
-        math_ops.linspace,
-        np.float32(1),
-        np.float32(4),
-        np.int32(3),
-        expected=np.array([1, 2.5, 4], dtype=np.float32))
+        np.float32(start),
+        np.float32(end),
+        np.int32(num),
+        expected)
+    # According to linspace spec, start has to be the first element and end has
+    # to be last element.
+    self.assertEqual(result[-1], expected[-1])
+    self.assertEqual(result[0], expected[0])
 
   def testRange(self):
     self._testTernary(
diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc
index 1cbb142..2eef554 100644
--- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc
@@ -18,6 +18,7 @@
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/primitive_util.h"
@@ -153,67 +154,27 @@
                 errors::InvalidArgument("num must be a scalar, not shape ",
                                         num_in_shape.DebugString()));
 
-    DataType type = ctx->input_type(0);
-
     int64 num;
     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar("num", &num));
     OP_REQUIRES(ctx, num > 0,
                 errors::InvalidArgument("Requires num > 0: ", num));
-    Tensor out_constant(type, TensorShape({num}));
-
-    xla::Literal start_literal;
-    OP_REQUIRES_OK(ctx, ctx->ConstantInput("start", &start_literal));
-    xla::Literal stop_literal;
-    OP_REQUIRES_OK(ctx, ctx->ConstantInput("stop", &stop_literal));
-
-    switch (type) {
-      case DT_FLOAT: {
-        float start = start_literal.GetFirstElement<float>();
-        float stop = stop_literal.GetFirstElement<float>();
-        auto flat = out_constant.flat<float>();
-        if (num == 1) {
-          flat(0) = start;
-        } else {
-          const float step = (stop - start) / (num - 1);
-          for (int64 i = 0; i < num - 1; ++i) {
-            flat(i) = start + step * i;
-          }
-          // The last value in the sequence must be equal to stop.
-          flat(num - 1) = stop;
-        }
-        break;
-      }
-      case DT_DOUBLE: {
-        double start = start_literal.GetFirstElement<double>();
-        double stop = stop_literal.GetFirstElement<double>();
-        auto flat = out_constant.flat<double>();
-        if (num == 1) {
-          flat(0) = start;
-        } else {
-          const double step = (stop - start) / (num - 1);
-          for (int64 i = 0; i < num - 1; ++i) {
-            flat(i) = start + step * i;
-          }
-          // The last value in the sequence must be equal to stop.
-          flat(num - 1) = stop;
-        }
-        break;
-      }
-
-      default:
-        ctx->SetStatus(errors::InvalidArgument("Invalid argument type ",
-                                               DataTypeString(type)));
-        return;
+    xla::XlaOp start = ctx->Input("start");
+    xla::XlaOp stop = ctx->Input("stop");
+    xla::XlaOp iota = xla::Iota(ctx->builder(), ctx->output_xla_type(0), num) /
+                      xla::ScalarLike(start, (num > 1 ? num - 1 : num));
+    xla::XlaOp result = iota * stop - (iota * start - start);
+    if (num > 1) {
+      // According to linspace spec, start has to be the first element and end
+      // has to be last element.
+      xla::XlaOp mask = xla::Iota(ctx->builder(), xla::S64, num);
+      xla::XlaOp eq = xla::Eq(mask, xla::ScalarLike(mask, num - 1));
+      result = xla::Select(eq, stop, result);
     }
-    ctx->SetConstantOutput(0, out_constant);
+    ctx->SetOutput(0, result);
   }
 };
 
-REGISTER_XLA_OP(Name("LinSpace")
-                    .CompileTimeConstantInput("start")
-                    .CompileTimeConstantInput("stop")
-                    .CompileTimeConstantInput("num"),
-                LinSpaceOp);
+REGISTER_XLA_OP(Name("LinSpace").CompileTimeConstantInput("num"), LinSpaceOp);
 
 }  // namespace
 }  // namespace tensorflow