Support non-constant start/end for linspace.
PiperOrigin-RevId: 276777257
Change-Id: Iba2506aa522add21f2cb975fd1710fc84dc81e7e
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index c8143c0..e306dc9 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -1011,6 +1011,7 @@
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
+ "@absl_py//absl/testing:parameterized",
],
)
diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py
index 200851e..d81a06a 100644
--- a/tensorflow/compiler/tests/ternary_ops_test.py
+++ b/tensorflow/compiler/tests/ternary_ops_test.py
@@ -18,6 +18,7 @@
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
import numpy as np
from tensorflow.compiler.tests import xla_test
@@ -28,7 +29,7 @@
from tensorflow.python.platform import googletest
-class TernaryOpsTest(xla_test.XLATestCase):
+class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase):
def _testTernary(self, op, a, b, c, expected):
with self.session() as session:
@@ -39,20 +40,24 @@
output = op(pa, pb, pc)
result = session.run(output, {pa: a, pb: b, pc: c})
self.assertAllClose(result, expected, rtol=1e-3)
+ return result
- def testLinspace(self):
- self._testTernary(
+ @parameterized.parameters(
+ {'start': 1, 'end': 2, 'num': 1},
+ {'start': 1, 'end': 4, 'num': 3},
+ {'start': 0, 'end': 41, 'num': 42})
+ def testLinspace(self, start, end, num):
+ expected = np.linspace(start, end, num, dtype=np.float32)
+ result = self._testTernary(
math_ops.linspace,
- np.float32(1),
- np.float32(2),
- np.int32(1),
- expected=np.array([1], dtype=np.float32))
- self._testTernary(
- math_ops.linspace,
- np.float32(1),
- np.float32(4),
- np.int32(3),
- expected=np.array([1, 2.5, 4], dtype=np.float32))
+ np.float32(start),
+ np.float32(end),
+ np.int32(num),
+ expected)
+ # According to linspace spec, start has to be the first element and end has
+ # to be last element.
+ self.assertEqual(result[-1], expected[-1])
+ self.assertEqual(result[0], expected[0])
def testRange(self):
self._testTernary(
diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc
index 1cbb142..2eef554 100644
--- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc
@@ -18,6 +18,7 @@
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
@@ -153,67 +154,27 @@
errors::InvalidArgument("num must be a scalar, not shape ",
num_in_shape.DebugString()));
- DataType type = ctx->input_type(0);
-
int64 num;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar("num", &num));
OP_REQUIRES(ctx, num > 0,
errors::InvalidArgument("Requires num > 0: ", num));
- Tensor out_constant(type, TensorShape({num}));
-
- xla::Literal start_literal;
- OP_REQUIRES_OK(ctx, ctx->ConstantInput("start", &start_literal));
- xla::Literal stop_literal;
- OP_REQUIRES_OK(ctx, ctx->ConstantInput("stop", &stop_literal));
-
- switch (type) {
- case DT_FLOAT: {
- float start = start_literal.GetFirstElement<float>();
- float stop = stop_literal.GetFirstElement<float>();
- auto flat = out_constant.flat<float>();
- if (num == 1) {
- flat(0) = start;
- } else {
- const float step = (stop - start) / (num - 1);
- for (int64 i = 0; i < num - 1; ++i) {
- flat(i) = start + step * i;
- }
- // The last value in the sequence must be equal to stop.
- flat(num - 1) = stop;
- }
- break;
- }
- case DT_DOUBLE: {
- double start = start_literal.GetFirstElement<double>();
- double stop = stop_literal.GetFirstElement<double>();
- auto flat = out_constant.flat<double>();
- if (num == 1) {
- flat(0) = start;
- } else {
- const double step = (stop - start) / (num - 1);
- for (int64 i = 0; i < num - 1; ++i) {
- flat(i) = start + step * i;
- }
- // The last value in the sequence must be equal to stop.
- flat(num - 1) = stop;
- }
- break;
- }
-
- default:
- ctx->SetStatus(errors::InvalidArgument("Invalid argument type ",
- DataTypeString(type)));
- return;
+ xla::XlaOp start = ctx->Input("start");
+ xla::XlaOp stop = ctx->Input("stop");
+ xla::XlaOp iota = xla::Iota(ctx->builder(), ctx->output_xla_type(0), num) /
+ xla::ScalarLike(start, (num > 1 ? num - 1 : num));
+ xla::XlaOp result = iota * stop - (iota * start - start);
+ if (num > 1) {
+ // According to linspace spec, start has to be the first element and end
+ // has to be last element.
+ xla::XlaOp mask = xla::Iota(ctx->builder(), xla::S64, num);
+ xla::XlaOp eq = xla::Eq(mask, xla::ScalarLike(mask, num - 1));
+ result = xla::Select(eq, stop, result);
}
- ctx->SetConstantOutput(0, out_constant);
+ ctx->SetOutput(0, result);
}
};
-REGISTER_XLA_OP(Name("LinSpace")
- .CompileTimeConstantInput("start")
- .CompileTimeConstantInput("stop")
- .CompileTimeConstantInput("num"),
- LinSpaceOp);
+REGISTER_XLA_OP(Name("LinSpace").CompileTimeConstantInput("num"), LinSpaceOp);
} // namespace
} // namespace tensorflow