Merge pull request #47075 from pedro-r-marques:rt-bce
PiperOrigin-RevId: 358064569
Change-Id: I278260acc2b6dd1fedcba263136234d04ec8e66f
diff --git a/tensorflow/opensource_only/ISSUES.md b/ISSUES.md
similarity index 100%
rename from tensorflow/opensource_only/ISSUES.md
rename to ISSUES.md
diff --git a/README.md b/README.md
index 8801f83..fb3eddc 100644
--- a/README.md
+++ b/README.md
@@ -155,6 +155,7 @@
* [DeepLearning.AI TensorFlow Developer Professional Certificate](https://www.coursera.org/specializations/tensorflow-in-practice)
* [TensorFlow: Data and Deployment from Coursera](https://www.coursera.org/specializations/tensorflow-data-and-deployment)
* [Getting Started with TensorFlow 2 from Coursera](https://www.coursera.org/learn/getting-started-with-tensor-flow2)
+* [TensorFlow: Advanced Techniques from Coursera](https://www.coursera.org/specializations/tensorflow-advanced-techniques)
* [Intro to TensorFlow for A.I, M.L, and D.L from Coursera](https://www.coursera.org/learn/introduction-tensorflow)
* [Intro to TensorFlow for Deep Learning from Udacity](https://www.udacity.com/course/intro-to-tensorflow-for-deep-learning--ud187)
* [Introduction to TensorFlow Lite from Udacity](https://www.udacity.com/course/intro-to-tensorflow-lite--ud190)
diff --git a/RELEASE.md b/RELEASE.md
index 8da2b27..9566460 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -81,6 +81,8 @@
* Removed deprecated `Interpreter::UseNNAPI(bool)` C++ API.
* Use `NnApiDelegate()` and related delegate configuration methods
directly.
+ * Replaced the model cache key for models computation algorithm with
+ one guaranteed to be stable across runs.
* 16 bits quantization
* Added int16x8 support for ABS, REDUCE_MAX and REDUCE_MIN operators.
* Additional tests and fixes for ADD and SUB operators.
@@ -98,6 +100,9 @@
function for a given signaturedef.
* Add int8 support for `ReshapeV2`.
* Add experimental support for optimization with sparsity.
+ * Add nominal support for unsigned 32-bit integer tensor types. Note that
+ very few TFLite kernels support this type natively, so its use in mobile
+ ML authoring is generally discouraged.
* TF Core:
* Corrected higher-order gradients of control flow constructs (`tf.cond`,
`tf.while_loop`, and compositions like `tf.foldl`) computed with
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 7927460..078a254 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -1168,16 +1168,17 @@
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
- TFE_Op* matmul = TFE_NewOp(ctx, "MatMulFunction", status);
- CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
- TFE_OpAddInput(matmul, m, status);
- CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* retval[1] = {nullptr};
int num_retvals = 1;
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
+ TFE_Op* matmul = TFE_NewOp(ctx, "MatMulFunction", status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_OpAddInput(matmul, m, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Execute(matmul, &retval[0], &num_retvals, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteOp(matmul);
}
if (async) {
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
@@ -1249,16 +1250,15 @@
TFE_TensorHandle* var_handle = TestVariable(ctx, 5.0);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
- TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
- CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
- TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
- TFE_OpAddInput(op, var_handle, status);
- CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-
int num_retvals = 1;
TFE_TensorHandle* h = nullptr;
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
+ TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
+ TFE_OpAddInput(op, var_handle, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Execute(op, &h, &num_retvals, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
CHECK_EQ(1, num_retvals);
@@ -1267,11 +1267,9 @@
CHECK_EQ(0, TFE_TensorHandleNumDims(h, status));
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
h = nullptr;
- TFE_OpAddInput(op, var_handle, status);
- CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteOp(op);
}
tensorflow::testing::StopTiming();
- TFE_DeleteOp(op);
TFE_DeleteTensorHandle(var_handle);
TFE_DeleteContext(ctx);
diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc
index a340b9d..be81fa8 100644
--- a/tensorflow/compiler/jit/build_xla_ops_pass.cc
+++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc
@@ -309,9 +309,13 @@
}
}
- ops::PartitionedCall call(
- root.WithOpName("partitioned_call"), args, n->output_types(), func,
- ops::PartitionedCall::Attrs{}.ConfigProto(config_string));
+ // In theory we can use PartitionedCall if the XLA cluster does not have any
+ // stateful operations. However, for now we choose to be conservative since
+ // we don't have any evidence that choosing a stateless partitioned call helps
+ // for performance.
+ ops::StatefulPartitionedCall call(
+ root.WithOpName("stateful_partitioned_call"), args, n->output_types(),
+ func, ops::StatefulPartitionedCall::Attrs{}.ConfigProto(config_string));
for (const Edge* e : n->in_edges()) {
if (e->IsControlEdge()) {
diff --git a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc
index 160ea83..869d869 100644
--- a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc
+++ b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc
@@ -194,7 +194,7 @@
auto xla_run =
NodeWith(Op("_XlaRun"), Inputs(Out(1, predicated_compilation_key)));
auto tf_call =
- NodeWith(Op("PartitionedCall"),
+ NodeWith(Op("StatefulPartitionedCall"),
CtrlDeps(NodeWith(Op("Identity"),
Inputs(Out(0, predicated_compilation_key)))));
auto merge = NodeWith(Op("_XlaMerge"), Inputs(Out(tf_call), Out(xla_run)));
@@ -252,9 +252,10 @@
TF_ASSERT_OK(BuildXlaOps(root, fdef_lib, &graph));
Node* sink_node = graph->sink_node();
- EXPECT_THAT(sink_node, NodeWith(CtrlDeps(NodeWith(Op("_XlaRun")),
- NodeWith(Op("PartitionedCall")),
- NodeWith(Op("NoOp")))));
+ EXPECT_THAT(sink_node,
+ NodeWith(CtrlDeps(NodeWith(Op("_XlaRun")),
+ NodeWith(Op("StatefulPartitionedCall")),
+ NodeWith(Op("NoOp")))));
}
#ifdef GOOGLE_CUDA
@@ -298,15 +299,15 @@
std::unique_ptr<Graph> graph;
TF_ASSERT_OK(BuildXlaOps(root, fdef_lib, &graph));
- Node* partitioned_call_op = nullptr;
+ Node* stateful_partitioned_call_op = nullptr;
for (Node* n : graph->op_nodes()) {
- if (n->type_string() == "PartitionedCall") {
- ASSERT_EQ(partitioned_call_op, nullptr);
- partitioned_call_op = n;
+ if (n->type_string() == "StatefulPartitionedCall") {
+ ASSERT_EQ(stateful_partitioned_call_op, nullptr);
+ stateful_partitioned_call_op = n;
}
}
- ASSERT_NE(partitioned_call_op, nullptr);
+ ASSERT_NE(stateful_partitioned_call_op, nullptr);
auto xla_compile = NodeWith(Op("_XlaCompile"));
auto switch_on_compilation_pred =
NodeWith(Op("Switch"), Inputs(Out(0, xla_compile), Out(1, xla_compile)));
@@ -315,7 +316,7 @@
// Check that we pipe int32 inputs through an IdentityN to avoid extra D2H
// copies.
EXPECT_THAT(
- partitioned_call_op,
+ stateful_partitioned_call_op,
NodeWith(Inputs(Out(NodeWith(Op("IdentityN"), CtrlDeps(ctrl_dep))))));
}
#endif
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
index 4a5c79c..9e209f3 100644
--- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
@@ -43,7 +43,7 @@
for (Node* n : graph->nodes()) {
string name;
// Only consider nodes being compiled.
- if (!GetNodeAttr(n->attrs(), kXlaClusterIdAttr, &name).ok()) continue;
+ if (!TryGetNodeAttr(n->attrs(), kXlaClusterIdAttr, &name)) continue;
// Early return for any node with a device that is not a CPU or GPU.
DeviceNameUtils::ParsedName parsed;
if (DeviceNameUtils::ParseFullName(n->requested_device(), &parsed)) {
@@ -58,8 +58,8 @@
// Checks if a graph node is marked to be a guaranteed constant.
bool is_guaranteed_constant(const Node& n) {
bool guaranteed_constant = false;
- if (!GetNodeAttr(n.attrs(), "_is_guaranteed_constant", &guaranteed_constant)
- .ok()) {
+ if (!TryGetNodeAttr(n.attrs(), "_is_guaranteed_constant",
+ &guaranteed_constant)) {
return false;
}
return guaranteed_constant;
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc
index 61ff6bc..a7d4e05 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.cc
+++ b/tensorflow/compiler/jit/xla_compilation_cache.cc
@@ -286,7 +286,9 @@
const ConfigProto* config = ctx->function_library()->config_proto();
// TODO(b/171039585): Support tf.VarIsInitializedOp using MLIR.
bool use_mlir = config &&
- GetMlirBridgeRolloutPolicy(*graph, *config) ==
+ GetMlirBridgeRolloutPolicy(
+ *graph, *config, /*uses_uninitialized_resource_args=*/
+ AnyUninitializedResourceArg(args)) ==
MlirBridgeRolloutPolicy::kEnabledByUser &&
node_def.op() != "VarIsInitializedOp";
if (!use_mlir) {
diff --git a/tensorflow/compiler/jit/xla_kernel_creator.cc b/tensorflow/compiler/jit/xla_kernel_creator.cc
index 602c2d2..054ac99 100644
--- a/tensorflow/compiler/jit/xla_kernel_creator.cc
+++ b/tensorflow/compiler/jit/xla_kernel_creator.cc
@@ -103,8 +103,15 @@
if (flr->config_proto()) {
config_proto = *flr->config_proto();
}
- MlirBridgeRolloutPolicy policy =
- GetMlirBridgeRolloutPolicy(*fbody->graph, config_proto);
+ // There is no easy way to check if we have uninitialized resource args here
+ // so we assume there are uninitialized resource args. This means that we
+ // might run the compilability checker in cases where we don't need to (when
+ // MLIR bridge is run later). Note that this is just temporary until
+ // b/171732021 gets fixed.
+ // We should also revisit if this check provides any value, otherwise we
+ // should remove it.
+ MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy(
+ *fbody->graph, config_proto, /*uses_uninitialized_resource_args=*/true);
if (policy != MlirBridgeRolloutPolicy::kEnabledByUser) {
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map;
if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) {
diff --git a/tensorflow/compiler/mlir/hlo/BUILD b/tensorflow/compiler/mlir/hlo/BUILD
index 20967e6..43147f4 100644
--- a/tensorflow/compiler/mlir/hlo/BUILD
+++ b/tensorflow/compiler/mlir/hlo/BUILD
@@ -585,6 +585,7 @@
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ComplexDialect",
"@llvm-project//mlir:IR",
+ "@llvm-project//mlir:MathDialect",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:StandardOps",
],
@@ -656,6 +657,7 @@
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps",
+ "@llvm-project//mlir:MathDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:StandardOps",
@@ -874,6 +876,7 @@
deps = [
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
+ "@llvm-project//mlir:MathDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
@@ -969,6 +972,7 @@
":chlo_legalize_to_hlo_inc_gen",
":hlo",
":map_chlo_to_hlo_op",
+ "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Shape",
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td
index e17c0d9..13d0f08 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td
@@ -178,6 +178,15 @@
}];
}
+def HLOClient_BroadcastPolygammaOp : HLOClient_BroadcastBinaryElementwiseOp<
+ "broadcast_polygamma", [NoSideEffect, SameOperandsAndResultElementType]> {
+ let summary = "Polygamma function (with optional broadcasting)";
+
+ let description = [{
+ Returns `Polygamma(operand, operand)` element-wise.
+ }];
+}
+
def HLOClient_BroadcastPowOp : HLOClient_BroadcastBinaryElementwiseOp<
"broadcast_power",
[NoSideEffect, SameOperandsAndResultElementType]> {
@@ -339,10 +348,9 @@
// not part of the HLO compiler instructions as modelled by the MHLO dialect.
//===----------------------------------------------------------------------===//
-def HLOClient_ZetaOp : HLOClient_Op<"zeta",
- [NoSideEffect, SameOperandsAndResultType]> {
+def HLOClient_ZetaOp : HLOClient_Op<"zeta", [NoSideEffect,
+ SameOperandsAndResultType]> {
let summary = "Hurwitz zeta function";
-
let description = [{
Returns `Zeta(operand, operand)` element-wise.
@@ -351,15 +359,26 @@
$$
}];
- let arguments = (ins
- HLO_FpTensor:$x,
- HLO_FpTensor:$q
- );
-
- let results = (outs HLO_FpTensor);
+ let arguments = (ins HLO_FpTensor:$x, HLO_FpTensor:$q);
+ let results = (outs HLO_FpTensor:$result);
let assemblyFormat = [{
- $x `,` $q attr-dict `:` `(` type($x) `,` type($q) `)` `->` type(results)
+ $x `,` $q attr-dict `:` type($x) `,` type($q) `->` type(results)
+ }];
+}
+
+def HLOClient_PolygammaOp : HLOClient_Op<"polygamma", [NoSideEffect,
+ SameOperandsAndResultType]> {
+ let summary = "Polygamma function";
+ let description = [{
+ Returns `Polygamma(operand, operand)` element-wise.
+ }];
+
+ let arguments = (ins HLO_FpTensor:$n, HLO_FpTensor:$x);
+ let results = (outs HLO_FpTensor:$result);
+
+ let assemblyFormat = [{
+ $n `,` $x attr-dict `:` type($n) `,` type($x) `->` type(results)
}];
}
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h
index 2b1b07c8d..d9e637d 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h
@@ -72,6 +72,7 @@
POPULATE_BCAST(BroadcastMinOp, mhlo::MinOp);
POPULATE_BCAST(BroadcastMulOp, mhlo::MulOp);
POPULATE_BCAST(BroadcastOrOp, mhlo::OrOp);
+ POPULATE_BCAST(BroadcastPolygammaOp, PolygammaOp);
POPULATE_BCAST(BroadcastPowOp, mhlo::PowOp);
POPULATE_BCAST(BroadcastRemOp, mhlo::RemOp);
POPULATE_BCAST(BroadcastShiftLeftOp, mhlo::ShiftLeftOp);
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h
index bbe156d..d726c7f 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h
@@ -24,6 +24,7 @@
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -173,7 +174,7 @@
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
- return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::Atan2Op>{}(
+ return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::Atan2Op>{}(
loc, result_types, args, b);
}
@@ -246,7 +247,7 @@
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
- return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::ExpOp>{}(
+ return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::ExpOp>{}(
loc, result_types, args, b);
}
@@ -310,14 +311,38 @@
// No conversion is needed for the same width floats
return args.front();
}
+ if (targetType.isInteger(/*width=*/1)) {
+ // When casting to bool, we need to compare whether the value is equal to
+ // zero.
+ if (sourceType.isSignlessInteger()) {
+ Value zero_intval = b->create<::mlir::ConstantIntOp>(
+ loc, 0, sourceType.cast<IntegerType>().getWidth());
+ if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
+ zero_intval = b->create<::mlir::SplatOp>(loc, vec_type, zero_intval);
+ }
+ return b->create<mlir::CmpIOp>(loc, CmpIPredicate::ne, args.front(),
+ zero_intval);
+ } else if (sourceType.isa<FloatType>()) {
+ Value zero = b->create<ConstantOp>(loc, b->getFloatAttr(sourceType, 0.0));
+ if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
+ zero = b->create<::mlir::SplatOp>(loc, vec_type, zero);
+ }
+ return b->create<mlir::CmpFOp>(loc, CmpFPredicate::UNE, args.front(),
+ zero);
+ }
+ }
if (sourceType.isSignlessInteger() && targetType.isSignlessInteger()) {
IntegerType src = sourceType.cast<IntegerType>();
IntegerType res = targetType.cast<IntegerType>();
if (src.getWidth() > res.getWidth()) {
return b->create<mlir::TruncateIOp>(loc, result_types, args, mlir::None);
- } else if (src.getWidth() < res.getWidth()) {
+ } else if (src.getWidth() == 1) {
+ // Special case boolean values, so they get casted to `1` instead of `-1`.
return b->create<mlir::ZeroExtendIOp>(loc, result_types, args,
mlir::None);
+ } else if (src.getWidth() < res.getWidth()) {
+ return b->create<mlir::SignExtendIOp>(loc, result_types, args,
+ mlir::None);
}
// No conversion is needed for the same width integers
return args.front();
@@ -358,7 +383,7 @@
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
- return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::CosOp>{}(
+ return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::CosOp>{}(
loc, result_types, args, b);
}
@@ -367,7 +392,7 @@
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
- return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::SinOp>{}(
+ return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::SinOp>{}(
loc, result_types, args, b);
}
@@ -434,7 +459,7 @@
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
- return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::LogOp>{}(
+ return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::LogOp>{}(
loc, result_types, args, b);
}
@@ -463,7 +488,7 @@
Value one = b->create<ConstantOp>(loc, b->getFloatAttr(ty, 1.0));
Value x = args.front();
Value neg_x = b->create<NegFOp>(loc, x);
- Value exp_neg_x = b->create<::mlir::ExpOp>(loc, neg_x);
+ Value exp_neg_x = b->create<::mlir::math::ExpOp>(loc, neg_x);
Value one_add_exp_neg_x = b->create<AddFOp>(loc, one, exp_neg_x);
return b->create<DivFOp>(loc, one, one_add_exp_neg_x);
}
@@ -473,7 +498,7 @@
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
- return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::Log1pOp>{}(
+ return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::Log1pOp>{}(
loc, result_types, args, b);
}
@@ -579,7 +604,7 @@
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
- return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::RsqrtOp>{}(
+ return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::RsqrtOp>{}(
loc, result_types, args, b);
}
@@ -593,8 +618,8 @@
// Floating point can use std::powf
auto result_type = result_types.front();
if (result_type.isa<::mlir::FloatType>())
- return MapLhloOpToStdScalarOpImpl<::mlir::PowFOp>{}(loc, result_types, args,
- b);
+ return MapLhloOpToStdScalarOpImpl<::mlir::math::PowFOp>{}(loc, result_types,
+ args, b);
assert(result_type.isa<::mlir::IntegerType>() &&
"only float and integer `pow` is supported right now");
@@ -746,7 +771,7 @@
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
- return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::SqrtOp>{}(
+ return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::SqrtOp>{}(
loc, result_types, args, b);
}
@@ -766,7 +791,7 @@
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
- return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::TanhOp>{}(
+ return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::TanhOp>{}(
loc, result_types, args, b);
}
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc
index aa65344..31a1ee3 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc
@@ -318,6 +318,7 @@
BROADCAST_BINARY_OP_DEFS(BroadcastMinOp);
BROADCAST_BINARY_OP_DEFS(BroadcastMulOp);
BROADCAST_BINARY_OP_DEFS(BroadcastOrOp);
+BROADCAST_BINARY_OP_DEFS(BroadcastPolygammaOp);
BROADCAST_BINARY_OP_DEFS(BroadcastPowOp);
BROADCAST_BINARY_OP_DEFS(BroadcastRemOp);
BROADCAST_BINARY_OP_DEFS(BroadcastShiftLeftOp);
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
index a6ca692..c6ad9e0 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -740,6 +740,12 @@
static LogicalResult Verify(BroadcastInDimOp op) {
auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();
+ if (!operandType) {
+ // The following verification checks all depend on knowing the rank of
+ // the operand. Bail out now if we don't know the rank of the operand.
+ return success();
+ }
+
auto operandRank = operandType.getRank();
if (!op.broadcast_dimensions()) {
if (operandRank == 0) {
@@ -783,13 +789,15 @@
dimIndex, resultRank));
}
- auto dimSize = operandType.getDimSize(i);
- auto resultDimSize = resultType.getDimSize(dimIndex);
- if (dimSize != 1 && dimSize != resultDimSize) {
- return op.emitOpError(
- llvm::formatv("size of operand dimension {0} ({1}) is not equal to "
- "1 or size of result dimension {2} ({3})",
- i, dimSize, dimIndex, resultDimSize));
+ if (!operandType.isDynamicDim(i)) {
+ auto dimSize = operandType.getDimSize(i);
+ auto resultDimSize = resultType.getDimSize(dimIndex);
+ if (dimSize != 1 && dimSize != resultDimSize) {
+ return op.emitOpError(
+ llvm::formatv("size of operand dimension {0} ({1}) is not equal to "
+ "1 or size of result dimension {2} ({3})",
+ i, dimSize, dimIndex, resultDimSize));
+ }
}
}
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/CMakeLists.txt
index 50866b7..d200be6 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/CMakeLists.txt
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/CMakeLists.txt
@@ -94,6 +94,7 @@
LmhloDialect
MLIRIR
MLIRPass
+ MLIRMath
)
add_mlir_library(MhloToStandard
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc
index 808b0af..fabf8ef 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc
@@ -22,6 +22,7 @@
#include <numeric>
#include <vector>
+#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
@@ -96,7 +97,8 @@
// argument and derive the final approximation for all |x| >= 1.
// This implementation is based on Cephes.
Value MaterializeErfcApproximationF64ForMagnituteGEOne(
- ConversionPatternRewriter &rewriter, Location loc, Value x) {
+ ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
+ Value x = args.front();
assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
"expect f64 element type");
const double kMaxlog = 7.09782712893383996843E2;
@@ -179,7 +181,8 @@
// Precondition is |x| <= 1. Use erfc approximation, otherwise.
// This implementation is based on Cephes.
Value MaterializeErfApproximationF64ForMagnituteLEOne(
- ConversionPatternRewriter &rewriter, Location loc, Value x) {
+ ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
+ Value x = args.front();
assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
"expect f64 element type");
const std::vector<double> kErfTCoefficients{
@@ -204,7 +207,8 @@
// This implementation is based on Cephes.
Value MaterializeErfApproximationF64(ConversionPatternRewriter &rewriter,
- Location loc, Value x) {
+ Location loc, ValueRange args) {
+ Value x = args.front();
assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
"expect f64 element type");
@@ -230,7 +234,8 @@
}
Value MaterializeErfcApproximationF64(ConversionPatternRewriter &rewriter,
- Location loc, Value x) {
+ Location loc, ValueRange args) {
+ Value x = args.front();
assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
"expect f64 element type");
@@ -261,7 +266,8 @@
// argument and derive the final approximation for all |x| >= 1.
// This implementation is based on Cephes.
Value MaterializeErfcApproximationF32ForMagnitudeGEOne(
- ConversionPatternRewriter &rewriter, Location loc, Value x) {
+ ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
+ Value x = args.front();
assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
"expect f32 element type");
const double kMaxlog = 88.72283905206835;
@@ -325,7 +331,8 @@
// Precondition is |x| <= 1. Use erfc approximation, otherwise.
// This implementation is based on Cephes.
Value MaterializeErfApproximationF32ForMagnitudeLEOne(
- ConversionPatternRewriter &rewriter, Location loc, Value x) {
+ ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
+ Value x = args.front();
assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
"expect f32 element type");
const std::vector<float> kErfTCoefficients{
@@ -344,8 +351,9 @@
// This is the same approximation as used in Eigen.
Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter,
- Location loc, Value operand) {
- assert(operand.getType().cast<ShapedType>().getElementType().isF32() &&
+ Location loc, ValueRange args) {
+ Value x = args.front();
+ assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
"expect f32 element type");
const std::vector<float> kAlpha{
-2.72614225801306e-10f, 2.77068142495902e-08f, -2.10102402082508e-06f,
@@ -358,10 +366,9 @@
};
// Clamp argument between -4 and 4.
- Value lb = chlo::getConstantLike(rewriter, loc, -4.0, operand);
- Value ub = chlo::getConstantLike(rewriter, loc, 4.0, operand);
- Value x =
- rewriter.create<mhlo::ClampOp>(loc, operand.getType(), lb, operand, ub);
+ Value lb = chlo::getConstantLike(rewriter, loc, -4.0, x);
+ Value ub = chlo::getConstantLike(rewriter, loc, 4.0, x);
+ x = rewriter.create<mhlo::ClampOp>(loc, x.getType(), lb, x, ub);
Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x);
// Materialize polynomial approximation for x in [-4, 4] as
@@ -375,7 +382,8 @@
}
Value MaterializeErfcApproximationF32(ConversionPatternRewriter &rewriter,
- Location loc, Value x) {
+ Location loc, ValueRange args) {
+ Value x = args.front();
assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
"expect f32 element type");
@@ -401,18 +409,30 @@
}
Value MaterializeWithUpcast(ConversionPatternRewriter &rewriter, Location loc,
- Value arg, FloatType min_precision_ty,
+ ValueRange args, FloatType min_precision_ty,
Value callback(ConversionPatternRewriter &,
- Location, Value)) {
- auto original_ty = getElementTypeOrSelf(arg.getType()).cast<FloatType>();
+ Location, ValueRange)) {
+ auto original_ty =
+ getElementTypeOrSelf(args.front().getType()).cast<FloatType>();
bool needs_upcast = original_ty.getWidth() < min_precision_ty.getWidth();
+
+ // Upcast arguments if necessary.
+ llvm::SmallVector<Value, 2> casted_args;
if (needs_upcast) {
- arg = rewriter.create<mhlo::ConvertOp>(loc, arg, min_precision_ty);
+ for (Value a : args) {
+ casted_args.push_back(
+ rewriter.create<mhlo::ConvertOp>(loc, a, min_precision_ty));
+ }
+ args = casted_args;
}
- Value result = callback(rewriter, loc, arg);
+
+ Value result = callback(rewriter, loc, args);
+
+ // Cast back if necessary.
if (needs_upcast) {
result = rewriter.create<mhlo::ConvertOp>(loc, result, original_ty);
}
+
return result;
}
@@ -434,9 +454,9 @@
return success();
}
- rewriter.replaceOp(
- op, MaterializeWithUpcast(rewriter, loc, x, rewriter.getF32Type(),
- &MaterializeErfApproximationF32));
+ rewriter.replaceOp(op, MaterializeWithUpcast(
+ rewriter, loc, operands, rewriter.getF32Type(),
+ &MaterializeErfApproximationF32));
return success();
}
};
@@ -459,9 +479,9 @@
return success();
}
- rewriter.replaceOp(
- op, MaterializeWithUpcast(rewriter, loc, x, rewriter.getF32Type(),
- &MaterializeErfcApproximationF32));
+ rewriter.replaceOp(op, MaterializeWithUpcast(
+ rewriter, loc, operands, rewriter.getF32Type(),
+ &MaterializeErfcApproximationF32));
return success();
}
};
@@ -491,12 +511,13 @@
// a(z) = kBaseLanczosCoeff
// + sum(k = 1, n, kLanczosCoefficients[i] / (z + k))
Value MaterializeLgamma(ConversionPatternRewriter &rewriter, Location loc,
- Value x) {
+ ValueRange args) {
// If the input is less than 0.5 use Euler's reflection formula.
// gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
// Let z be
// z = -x if x < 1/2
// z = x - 1 otheriwse
+ Value x = args.front();
const StringAttr kLT = rewriter.getStringAttr(
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
Value half = getConstantLike(rewriter, loc, 0.5, x);
@@ -635,12 +656,13 @@
// + sum(k = 1, n, kLanczosCoefficients[i] / (z + k))
// a'(z) = - sum(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k))
Value MaterializeDigamma(ConversionPatternRewriter &rewriter, Location loc,
- Value x) {
+ ValueRange args) {
// If the input is less than 0.5 use Euler's reflection formula.
// digamma(x) = digamma(1 - x) - pi * cot(pi * x)
// Let z be
// z = -x if x < 1/2
// z = x - 1 otheriwse
+ Value x = args.front();
const StringAttr kLT = rewriter.getStringAttr(
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
Value half = getConstantLike(rewriter, loc, 0.5, x);
@@ -739,36 +761,11 @@
digamma);
}
-struct ConvertLgammaOp : public OpConversionPattern<LgammaOp> {
- using OpConversionPattern<LgammaOp>::OpConversionPattern;
- LogicalResult matchAndRewrite(
- LgammaOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- LgammaOp::Adaptor transformed(operands);
- FloatType min_precision_ty = rewriter.getF32Type();
- rewriter.replaceOp(
- op, MaterializeWithUpcast(rewriter, op.getLoc(), transformed.operand(),
- min_precision_ty, &MaterializeLgamma));
- return success();
- }
-};
-
-struct ConvertDigammaOp : public OpConversionPattern<DigammaOp> {
- using OpConversionPattern<DigammaOp>::OpConversionPattern;
- LogicalResult matchAndRewrite(
- DigammaOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- DigammaOp::Adaptor transformed(operands);
- FloatType min_precision_ty = rewriter.getF32Type();
- rewriter.replaceOp(
- op, MaterializeWithUpcast(rewriter, op.getLoc(), transformed.operand(),
- min_precision_ty, &MaterializeDigamma));
- return success();
- }
-};
-
-Value MaterializeZetaComputation(ConversionPatternRewriter &rewriter,
- Location loc, Value x, Value q) {
+Value MaterializeZeta(ConversionPatternRewriter &rewriter, Location loc,
+ ValueRange args) {
+ assert(args.size() == 2);
+ Value x = args[0];
+ Value q = args[1];
static const std::array<double, 12> kZetaCoeffs{
-7.1661652561756670113e18,
1.8152105401943546773e17,
@@ -897,34 +894,101 @@
return output;
}
+Value MaterializePolygamma(ConversionPatternRewriter &rewriter, Location loc,
+ ValueRange args) {
+ PolygammaOp::Adaptor transformed(args);
+ Value n = transformed.n();
+ Value x = transformed.x();
+
+ // Handle integer n > 0.
+ Value one = getConstantLike(rewriter, loc, 1.0, x);
+ Value two = getConstantLike(rewriter, loc, 2.0, x);
+ Value sign = rewriter.create<mhlo::SubOp>(
+ loc,
+ rewriter.create<mhlo::MulOp>(loc, two,
+ rewriter.create<mhlo::RemOp>(loc, n, two)),
+ one);
+ Value n_plus_one = rewriter.create<mhlo::AddOp>(loc, n, one);
+ Value exp_lgamma_np1 = rewriter.create<mhlo::ExpOp>(
+ loc, rewriter.create<chlo::LgammaOp>(loc, n_plus_one));
+ Value zeta = rewriter.create<chlo::ZetaOp>(loc, n_plus_one, x);
+ Value result = rewriter.create<mhlo::MulOp>(
+ loc, rewriter.create<mhlo::MulOp>(loc, sign, exp_lgamma_np1), zeta);
+
+ // Handle n = 0.
+ const StringAttr kEQ = rewriter.getStringAttr(
+ mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::EQ));
+ Value zero = getConstantLike(rewriter, loc, 0.0, x);
+ Value n_eq_zero = rewriter.create<mhlo::CompareOp>(loc, n, zero, kEQ);
+ result = rewriter.create<mhlo::SelectOp>(
+ loc, n_eq_zero, rewriter.create<chlo::DigammaOp>(loc, x), result);
+
+ // Check that n is a natural number.
+ const StringAttr kNE = rewriter.getStringAttr(
+ mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::NE));
+ Value non_int = rewriter.create<mhlo::CompareOp>(
+ loc, n, rewriter.create<mhlo::FloorOp>(loc, n), kNE);
+ const StringAttr kLT = rewriter.getStringAttr(
+ mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
+ Value negative = rewriter.create<mhlo::CompareOp>(loc, n, zero, kLT);
+ Value non_natural = rewriter.create<mhlo::OrOp>(loc, non_int, negative);
+ return rewriter.create<mhlo::SelectOp>(
+ loc, non_natural,
+ getConstantLike(rewriter, loc, std::numeric_limits<double>::quiet_NaN(),
+ x),
+ result);
+}
+
+struct ConvertLgammaOp : public OpConversionPattern<LgammaOp> {
+ using OpConversionPattern<LgammaOp>::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ LgammaOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ FloatType min_precision_ty = rewriter.getF32Type();
+ rewriter.replaceOp(
+ op, MaterializeWithUpcast(rewriter, op.getLoc(), operands,
+ min_precision_ty, &MaterializeLgamma));
+ return success();
+ }
+};
+
+struct ConvertDigammaOp : public OpConversionPattern<DigammaOp> {
+ using OpConversionPattern<DigammaOp>::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ DigammaOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ FloatType min_precision_ty = rewriter.getF32Type();
+ rewriter.replaceOp(
+ op, MaterializeWithUpcast(rewriter, op.getLoc(), operands,
+ min_precision_ty, &MaterializeDigamma));
+ return success();
+ }
+};
+
+struct ConvertPolygammaOp : public OpConversionPattern<PolygammaOp> {
+ using OpConversionPattern<PolygammaOp>::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ PolygammaOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ FloatType min_precision_ty = rewriter.getF32Type();
+ rewriter.replaceOp(
+ op, MaterializeWithUpcast(rewriter, loc, operands, min_precision_ty,
+ &MaterializePolygamma));
+ return success();
+ }
+};
+
struct ConvertZetaOp : public OpConversionPattern<ZetaOp> {
using OpConversionPattern<ZetaOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
ZetaOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- ZetaOpAdaptor adaptor(operands);
Location loc = op.getLoc();
-
- // Zeta is only defined on tensors of float elements and statically
- // verified that both have the same type. So it suffices to look at one
- // here.
- auto elm_type = adaptor.x().getType().cast<ShapedType>().getElementType();
-
- bool needs_upcast = elm_type.isF16() || elm_type.isBF16();
-
- Value x = adaptor.x();
- Value q = adaptor.q();
-
- if (needs_upcast) {
- x = rewriter.create<mhlo::ConvertOp>(loc, x, rewriter.getF32Type());
- q = rewriter.create<mhlo::ConvertOp>(loc, q, rewriter.getF32Type());
- }
- Value result = MaterializeZetaComputation(rewriter, loc, x, q);
- if (needs_upcast) {
- result = rewriter.create<mhlo::ConvertOp>(loc, result, elm_type);
- }
- rewriter.replaceOp(op, {result});
-
+ FloatType min_precision_ty = rewriter.getF32Type();
+ rewriter.replaceOp(
+ op, MaterializeWithUpcast(rewriter, loc, operands, min_precision_ty,
+ &MaterializeZeta));
return success();
}
};
@@ -1090,6 +1154,7 @@
ConvertErfOp,
ConvertErfcOp,
ConvertLgammaOp,
+ ConvertPolygammaOp,
ConvertZetaOp>(context);
// clang-format on
}
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc
index cae718e..f9ee38b 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc
@@ -55,7 +55,7 @@
if (broadcast_only_) {
chlo::PopulateChloBroadcastingPatterns(&getContext(),
&conversionPatterns);
- conversionTarget.addLegalOp<chlo::ZetaOp>();
+ conversionTarget.addLegalOp<chlo::ZetaOp, chlo::PolygammaOp>();
} else {
chlo::PopulateLegalizeChloToHloPatterns(&getContext(),
&conversionPatterns);
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
index c5c9edd..6cbb68c 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
@@ -26,6 +26,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -804,9 +805,8 @@
loc, collapsed_type, args[0], collapsing_map);
Value reshape_buffer = rewriter.create<linalg::ReshapeOp>(
loc, result_type, collapsed_op, expanding_map);
- rewriter.replaceOpWithNewOp<linalg::CopyOp>(
- reshape_op, reshape_buffer, args[1], /*inputPermutation =*/nullptr,
- /*outputPermutation =*/nullptr);
+ rewriter.replaceOpWithNewOp<linalg::CopyOp>(reshape_op, reshape_buffer,
+ args[1]);
} else {
auto collapsed_type = RankedTensorType::get({total_elems}, elem_type);
Value collapsed_op = rewriter.create<linalg::TensorReshapeOp>(
@@ -820,9 +820,8 @@
if (isLHLO) {
Value reshape_buffer = rewriter.create<linalg::ReshapeOp>(
reshape_op.getLoc(), result_type, args[0], reassociation_map);
- rewriter.replaceOpWithNewOp<linalg::CopyOp>(
- reshape_op, reshape_buffer, args[1], /*inputPermutation =*/nullptr,
- /*outputPermutation =*/nullptr);
+ rewriter.replaceOpWithNewOp<linalg::CopyOp>(reshape_op, reshape_buffer,
+ args[1]);
} else {
rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
reshape_op, result_type, args[0], reassociation_map);
@@ -1456,14 +1455,15 @@
struct LhloLegalizeToLinalgPass
: public PassWrapper<LhloLegalizeToLinalgPass, FunctionPass> {
void getDependentDialects(DialectRegistry& registry) const override {
- registry.insert<AffineDialect, linalg::LinalgDialect>();
+ registry.insert<AffineDialect, linalg::LinalgDialect, math::MathDialect>();
}
void runOnFunction() override {
OwningRewritePatternList patterns;
ConversionTarget target(getContext());
target.addLegalDialect<complex::ComplexDialect, linalg::LinalgDialect,
- StandardOpsDialect, AffineDialect>();
+ math::MathDialect, StandardOpsDialect,
+ AffineDialect>();
auto func = getFunction();
populateLHLOToLinalgConversionPattern(func.getContext(), &patterns);
@@ -1477,15 +1477,15 @@
: public PassWrapper<HloLegalizeToLinalgPass, FunctionPass> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<linalg::LinalgDialect, scf::SCFDialect,
- complex::ComplexDialect>();
+ complex::ComplexDialect, math::MathDialect>();
}
void runOnFunction() override {
OwningRewritePatternList patterns;
ConversionTarget target(getContext());
target.addLegalDialect<complex::ComplexDialect, linalg::LinalgDialect,
- StandardOpsDialect, tensor::TensorDialect,
- scf::SCFDialect>();
+ math::MathDialect, StandardOpsDialect,
+ tensor::TensorDialect, scf::SCFDialect>();
auto func = getFunction();
mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc
index 9c1dc9c..64be60f 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc
@@ -18,6 +18,7 @@
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
+#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
@@ -74,10 +75,10 @@
};
class ApproximateTanhLowering
- : public ApproximateOnExtendedF32Lowering<TanhOp> {
+ : public ApproximateOnExtendedF32Lowering<math::TanhOp> {
public:
explicit ApproximateTanhLowering(MLIRContext *ctx)
- : ApproximateOnExtendedF32Lowering<TanhOp>(ctx) {}
+ : ApproximateOnExtendedF32Lowering<math::TanhOp>(ctx) {}
// Emits the fast tanh approximation that is also used by XLA.
Value emitApproximation(ValueRange args, Location loc,
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
index 3b52288..7c47b6f 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
@@ -230,46 +230,54 @@
// pattern will handle the lowering.
if (!lhs_type || !rhs_type) return failure();
- // If lhs is scalar
+ Value shape_of_lhs = rewriter.create<shape::ShapeOfOp>(loc, lhs);
+ Value shape_of_rhs = rewriter.create<shape::ShapeOfOp>(loc, rhs);
+
+ // If lhs has exactly one element
auto if_op = rewriter.create<scf::IfOp>(
- loc, result_type, IsScalarTensor(rewriter, op, lhs), true);
+ loc, result_type, IsSingleElementShape(rewriter, op, shape_of_lhs),
+ true);
OpBuilder if_lhs_scalar_builder =
if_op.getThenBodyBuilder(rewriter.getListener());
- Value reshaped_lhs = if_lhs_scalar_builder.create<tensor::CastOp>(
+ Value reshaped_lhs = if_lhs_scalar_builder.create<mhlo::ReshapeOp>(
loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs);
Value if_lhs_scalar_result = if_lhs_scalar_builder.create<ChloOpTy>(
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{reshaped_lhs, rhs},
op.getAttrs());
- if_lhs_scalar_builder.create<scf::YieldOp>(loc, if_lhs_scalar_result);
+ Value extended_if_lhs_scalar_result =
+ extendToBroadcastShape(if_lhs_scalar_builder, loc, if_lhs_scalar_result,
+ shape_of_lhs, shape_of_rhs);
+ if_lhs_scalar_builder.create<scf::YieldOp>(loc,
+ extended_if_lhs_scalar_result);
- // If lhs is NOT scalar
+ // If lhs does not have exactly one element
//
- // See if rhs is scalar
+ // See if rhs has exactly one element
OpBuilder else_lhs_scalar_builder =
if_op.getElseBodyBuilder(rewriter.getListener());
auto if_rhs_scalar_op = else_lhs_scalar_builder.create<scf::IfOp>(
- loc, result_type, IsScalarTensor(else_lhs_scalar_builder, op, rhs),
- true);
+ loc, result_type,
+ IsSingleElementShape(else_lhs_scalar_builder, op, shape_of_rhs), true);
else_lhs_scalar_builder.create<scf::YieldOp>(loc,
if_rhs_scalar_op.getResult(0));
OpBuilder if_rhs_scalar_builder =
if_rhs_scalar_op.getThenBodyBuilder(rewriter.getListener());
- Value reshaped_rhs = if_rhs_scalar_builder.create<tensor::CastOp>(
- loc, RankedTensorType::get({}, lhs_type.getElementType()), rhs);
+ Value reshaped_rhs = if_rhs_scalar_builder.create<mhlo::ReshapeOp>(
+ loc, RankedTensorType::get({}, rhs_type.getElementType()), rhs);
Value if_rhs_scalar_result = if_rhs_scalar_builder.create<ChloOpTy>(
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{lhs, reshaped_rhs},
op.getAttrs());
- if_rhs_scalar_builder.create<scf::YieldOp>(loc, if_rhs_scalar_result);
+ Value extended_if_rhs_scalar_result =
+ extendToBroadcastShape(if_rhs_scalar_builder, loc, if_rhs_scalar_result,
+ shape_of_lhs, shape_of_rhs);
+ if_rhs_scalar_builder.create<scf::YieldOp>(loc,
+ extended_if_rhs_scalar_result);
- // If NEITHER shape is scalar
+ // If NEITHER shape has exactly one element
//
// See if shapes are equal.
OpBuilder else_no_scalars_builder =
if_rhs_scalar_op.getElseBodyBuilder(rewriter.getListener());
- Value shape_of_lhs =
- else_no_scalars_builder.create<shape::ShapeOfOp>(loc, lhs);
- Value shape_of_rhs =
- else_no_scalars_builder.create<shape::ShapeOfOp>(loc, rhs);
Value equal_shapes = else_no_scalars_builder.create<shape::ShapeEqOp>(
loc, shape_of_lhs, shape_of_rhs);
@@ -284,7 +292,7 @@
Adaptor::CreateOp(op, result_type, lhs, rhs, if_eq_shapes_builder);
if_eq_shapes_builder.create<scf::YieldOp>(loc, non_broadcast_op);
- // If shapes are not scalar, nor equal
+ // If shapes do not have exactly one element, nor are equal
//
// See if values are of a rank that we support.
OpBuilder if_neq_shapes_builder =
@@ -297,16 +305,17 @@
}
private:
- // Returns the dynamic result of checking the given value is a scalar tensor.
- Value IsScalarTensor(OpBuilder &rewriter, ChloOpTy op, Value tensor) const {
+ // Returns the dynamic result of checking the given value is effectively a
+ // scalar shape (i.e. the number of elements is 1).
+ Value IsSingleElementShape(OpBuilder &rewriter, ChloOpTy op,
+ Value shape_of_tensor) const {
auto loc = op.getLoc();
- Value shape_of_tensor = rewriter.create<shape::ShapeOfOp>(loc, tensor);
- Value rank_tensor = rewriter.create<shape::RankOp>(
- loc, rewriter.getIndexType(), shape_of_tensor);
+ Value num_elements =
+ rewriter.create<shape::NumElementsOp>(loc, shape_of_tensor);
return rewriter.create<CmpIOp>(loc, rewriter.getI1Type(), CmpIPredicate::eq,
- rank_tensor,
- rewriter.create<ConstantIndexOp>(loc, 0));
+ num_elements,
+ rewriter.create<ConstantIndexOp>(loc, 1));
}
Value GreaterRankIsN(OpBuilder &builder, Location loc, Value actual_rank,
@@ -326,6 +335,36 @@
greater_rank_is_n, true);
}
+ Value extendToBroadcastShape(OpBuilder &builder, Location loc, Value value,
+ Value shape_of_lhs, Value shape_of_rhs) const {
+ auto unknown_rank_extent_tensor_type = RankedTensorType::get(
+ {RankedTensorType::kDynamicSize}, builder.getIndexType());
+ Value broadcast_shape =
+ builder.create<shape::BroadcastOp>(loc, unknown_rank_extent_tensor_type,
+ shape_of_lhs, shape_of_rhs, nullptr);
+ return builder.create<mhlo::DynamicReshapeOp>(loc, value.getType(), value,
+ broadcast_shape);
+ }
+
+ Value createBroadcastToKnownRank(OpBuilder &builder, ChloOpTy op, Value value,
+ int targeted_rank) const {
+ auto loc = op.getLoc();
+ Value shape = builder.create<shape::ShapeOfOp>(loc, value);
+ SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
+ auto unknown_rank_extent_tensor_type = RankedTensorType::get(
+ {RankedTensorType::kDynamicSize}, builder.getIndexType());
+ auto known_rank_extent_tensor_type =
+ RankedTensorType::get({targeted_rank}, builder.getIndexType());
+ Value ranked_shape_val = builder.create<shape::ConstShapeOp>(
+ loc, known_rank_extent_tensor_type,
+ mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type,
+ ranked_shape));
+ Value extended_value = builder.create<shape::BroadcastOp>(
+ loc, unknown_rank_extent_tensor_type, shape, ranked_shape_val, nullptr);
+ return builder.create<tensor::CastOp>(loc, known_rank_extent_tensor_type,
+ extended_value);
+ }
+
// Create the if statement and code for a broadcasting op with a result of a
// given rank.
void createRankSpecializedBroadcastAndOp(OpBuilder &if_builder, ChloOpTy op,
@@ -333,32 +372,16 @@
int targeted_rank) const {
auto loc = op.getLoc();
- // Handle shape broadcasting and inferrence.
- Value lhs_shape = if_builder.create<shape::ShapeOfOp>(loc, lhs);
- Value rhs_shape = if_builder.create<shape::ShapeOfOp>(loc, rhs);
- SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
- auto unknown_rank_extent_tensor_type = RankedTensorType::get(
- {RankedTensorType::kDynamicSize}, if_builder.getIndexType());
- auto known_rank_extent_tensor_type =
- RankedTensorType::get({targeted_rank}, if_builder.getIndexType());
+ // Handle shape broadcasting and inference.
+ Value extended_lhs_casted =
+ createBroadcastToKnownRank(if_builder, op, lhs, targeted_rank);
+ Value extended_rhs_casted =
+ createBroadcastToKnownRank(if_builder, op, rhs, targeted_rank);
+ auto dynamic_dimensions = llvm::SmallVector<int64_t, 6>(
+ targeted_rank, RankedTensorType::kDynamicSize);
auto reshaped_type = RankedTensorType::get(
- llvm::SmallVector<int64_t, 6>(targeted_rank,
- RankedTensorType::kDynamicSize),
+ dynamic_dimensions,
lhs.getType().template dyn_cast<TensorType>().getElementType());
- Value ranked_shape_val = if_builder.create<shape::ConstShapeOp>(
- loc, known_rank_extent_tensor_type,
- mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type,
- ranked_shape));
- Value extended_lhs = if_builder.create<shape::BroadcastOp>(
- loc, unknown_rank_extent_tensor_type, lhs_shape, ranked_shape_val,
- nullptr);
- Value extended_lhs_casted = if_builder.create<tensor::CastOp>(
- loc, known_rank_extent_tensor_type, extended_lhs);
- Value extended_rhs = if_builder.create<shape::BroadcastOp>(
- loc, unknown_rank_extent_tensor_type, rhs_shape, ranked_shape_val,
- nullptr);
- Value extended_rhs_casted = if_builder.create<tensor::CastOp>(
- loc, known_rank_extent_tensor_type, extended_rhs);
// 1. Reshape operands to the given rank (with the same number of elements)
// 2. Compute the ranked-broadcasted ChloOp (which will assert that the ops
@@ -372,10 +395,8 @@
.getType()
.template dyn_cast<TensorType>()
.getElementType();
- auto result_type = RankedTensorType::get(
- llvm::SmallVector<int64_t, 6>(targeted_rank,
- RankedTensorType::kDynamicSize),
- result_element_type);
+ auto result_type =
+ RankedTensorType::get(dynamic_dimensions, result_element_type);
Value result = if_builder.create<ChloOpTy>(
loc, ArrayRef<Type>{result_type},
ArrayRef<Value>{reshaped_lhs, reshaped_rhs}, op.getAttrs());
diff --git a/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_hlo_broadcasts.mlir b/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_hlo_broadcasts.mlir
index 57c20c8..efe96cf 100644
--- a/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_hlo_broadcasts.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_hlo_broadcasts.mlir
@@ -247,3 +247,14 @@
: (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
+
+// -----
+// CHECK-LABEL: @PolygammaWithoutBroadcast
+// CHECK-SAME: (%[[LHS:.*]]: tensor<4xf32>, %[[RHS:.*]]: tensor<4xf32>)
+func @PolygammaWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>)
+ -> tensor<4xf32> {
+ // CHECK: chlo.polygamma %[[LHS]], %[[RHS]]
+ %0 = chlo.broadcast_polygamma %arg0, %arg1
+ : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+ return %0 : tensor<4xf32>
+}
diff --git a/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_mhlo.mlir b/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_mhlo.mlir
index fd2aad2..44693f9 100644
--- a/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_mhlo.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_mhlo.mlir
@@ -1110,179 +1110,950 @@
// CHECK-SAME: %[[VAL_0:.*]]: tensor<f16>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<f16>) -> tensor<f16> {
func @zeta_f16(%arg0: tensor<f16>, %arg1: tensor<f16>) -> tensor<f16> {
- %0 = chlo.zeta %arg0, %arg1 : (tensor<f16>, tensor<f16>) -> tensor<f16>
-// CHECK: %[[VAL_2:.*]] = "mhlo.convert"(%[[VAL_0]]) : (tensor<f16>) -> tensor<f32>
-// CHECK: %[[VAL_3:.*]] = "mhlo.convert"(%[[VAL_1]]) : (tensor<f16>) -> tensor<f32>
-// CHECK: %[[VAL_4:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
-// CHECK: %[[VAL_5:.*]] = "mhlo.negate"(%[[VAL_2]]) : (tensor<f32>) -> tensor<f32>
-// CHECK: %[[VAL_6:.*]] = mhlo.power %[[VAL_3]], %[[VAL_5]] : tensor<f32>
-// CHECK: %[[VAL_7:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
-// CHECK: %[[VAL_8:.*]] = mhlo.add %[[VAL_3]], %[[VAL_7]] : tensor<f32>
-// CHECK: %[[VAL_9:.*]] = mhlo.power %[[VAL_8]], %[[VAL_5]] : tensor<f32>
-// CHECK: %[[VAL_10:.*]] = mhlo.add %[[VAL_6]], %[[VAL_9]] : tensor<f32>
-// CHECK: %[[VAL_11:.*]] = mhlo.add %[[VAL_8]], %[[VAL_7]] : tensor<f32>
-// CHECK: %[[VAL_12:.*]] = mhlo.power %[[VAL_11]], %[[VAL_5]] : tensor<f32>
-// CHECK: %[[VAL_13:.*]] = mhlo.add %[[VAL_10]], %[[VAL_12]] : tensor<f32>
-// CHECK: %[[VAL_14:.*]] = mhlo.add %[[VAL_11]], %[[VAL_7]] : tensor<f32>
-// CHECK: %[[VAL_15:.*]] = mhlo.power %[[VAL_14]], %[[VAL_5]] : tensor<f32>
-// CHECK: %[[VAL_16:.*]] = mhlo.add %[[VAL_13]], %[[VAL_15]] : tensor<f32>
-// CHECK: %[[VAL_17:.*]] = mhlo.add %[[VAL_14]], %[[VAL_7]] : tensor<f32>
-// CHECK: %[[VAL_18:.*]] = mhlo.power %[[VAL_17]], %[[VAL_5]] : tensor<f32>
-// CHECK: %[[VAL_19:.*]] = mhlo.add %[[VAL_16]], %[[VAL_18]] : tensor<f32>
-// CHECK: %[[VAL_20:.*]] = mhlo.add %[[VAL_17]], %[[VAL_7]] : tensor<f32>
-// CHECK: %[[VAL_21:.*]] = mhlo.power %[[VAL_20]], %[[VAL_5]] : tensor<f32>
-// CHECK: %[[VAL_22:.*]] = mhlo.add %[[VAL_19]], %[[VAL_21]] : tensor<f32>
-// CHECK: %[[VAL_23:.*]] = mhlo.add %[[VAL_20]], %[[VAL_7]] : tensor<f32>
-// CHECK: %[[VAL_24:.*]] = mhlo.power %[[VAL_23]], %[[VAL_5]] : tensor<f32>
-// CHECK: %[[VAL_25:.*]] = mhlo.add %[[VAL_22]], %[[VAL_24]] : tensor<f32>
-// CHECK: %[[VAL_26:.*]] = mhlo.add %[[VAL_23]], %[[VAL_7]] : tensor<f32>
-// CHECK: %[[VAL_27:.*]] = mhlo.power %[[VAL_26]], %[[VAL_5]] : tensor<f32>
-// CHECK: %[[VAL_28:.*]] = mhlo.add %[[VAL_25]], %[[VAL_27]] : tensor<f32>
-// CHECK: %[[VAL_29:.*]] = mhlo.add %[[VAL_26]], %[[VAL_7]] : tensor<f32>
-// CHECK: %[[VAL_30:.*]] = mhlo.power %[[VAL_29]], %[[VAL_5]] : tensor<f32>
-// CHECK: %[[VAL_31:.*]] = mhlo.add %[[VAL_28]], %[[VAL_30]] : tensor<f32>
-// CHECK: %[[VAL_32:.*]] = mhlo.add %[[VAL_29]], %[[VAL_7]] : tensor<f32>
-// CHECK: %[[VAL_33:.*]] = mhlo.power %[[VAL_32]], %[[VAL_5]] : tensor<f32>
-// CHECK: %[[VAL_34:.*]] = mhlo.add %[[VAL_31]], %[[VAL_33]] : tensor<f32>
-// CHECK: %[[VAL_35:.*]] = mhlo.add %[[VAL_32]], %[[VAL_7]] : tensor<f32>
-// CHECK: %[[VAL_36:.*]] = mhlo.power %[[VAL_35]], %[[VAL_5]] : tensor<f32>
-// CHECK: %[[VAL_37:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
-// CHECK: %[[VAL_38:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_37]] : tensor<f32>
-// CHECK: %[[VAL_39:.*]] = mhlo.multiply %[[VAL_36]], %[[VAL_35]] : tensor<f32>
-// CHECK: %[[VAL_40:.*]] = mhlo.divide %[[VAL_39]], %[[VAL_38]] : tensor<f32>
-// CHECK: %[[VAL_41:.*]] = mhlo.add %[[VAL_34]], %[[VAL_40]] : tensor<f32>
-// CHECK: %[[VAL_42:.*]] = mhlo.multiply %[[VAL_35]], %[[VAL_35]] : tensor<f32>
-// CHECK: %[[VAL_43:.*]] = mhlo.divide %[[VAL_7]], %[[VAL_42]] : tensor<f32>
-// CHECK: %[[VAL_44:.*]] = mhlo.constant dense<2.200000e+01> : tensor<f32>
-// CHECK: %[[VAL_45:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_44]] : tensor<f32>
-// CHECK: %[[VAL_46:.*]] = mhlo.constant dense<2.100000e+01> : tensor<f32>
-// CHECK: %[[VAL_47:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_46]] : tensor<f32>
-// CHECK: %[[VAL_48:.*]] = mhlo.multiply %[[VAL_45]], %[[VAL_47]] : tensor<f32>
-// CHECK: %[[VAL_49:.*]] = mhlo.constant dense<-1.39544646E-19> : tensor<f32>
-// CHECK: %[[VAL_50:.*]] = mhlo.add %[[VAL_4]], %[[VAL_49]] : tensor<f32>
-// CHECK: %[[VAL_51:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_50]] : tensor<f32>
-// CHECK: %[[VAL_52:.*]] = mhlo.multiply %[[VAL_48]], %[[VAL_51]] : tensor<f32>
-// CHECK: %[[VAL_53:.*]] = mhlo.constant dense<2.000000e+01> : tensor<f32>
-// CHECK: %[[VAL_54:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_53]] : tensor<f32>
-// CHECK: %[[VAL_55:.*]] = mhlo.constant dense<1.900000e+01> : tensor<f32>
-// CHECK: %[[VAL_56:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_55]] : tensor<f32>
-// CHECK: %[[VAL_57:.*]] = mhlo.multiply %[[VAL_54]], %[[VAL_56]] : tensor<f32>
-// CHECK: %[[VAL_58:.*]] = mhlo.constant dense<5.50900303E-18> : tensor<f32>
-// CHECK: %[[VAL_59:.*]] = mhlo.add %[[VAL_52]], %[[VAL_58]] : tensor<f32>
-// CHECK: %[[VAL_60:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_59]] : tensor<f32>
-// CHECK: %[[VAL_61:.*]] = mhlo.multiply %[[VAL_57]], %[[VAL_60]] : tensor<f32>
-// CHECK: %[[VAL_62:.*]] = mhlo.constant dense<1.800000e+01> : tensor<f32>
-// CHECK: %[[VAL_63:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_62]] : tensor<f32>
-// CHECK: %[[VAL_64:.*]] = mhlo.constant dense<1.700000e+01> : tensor<f32>
-// CHECK: %[[VAL_65:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_64]] : tensor<f32>
-// CHECK: %[[VAL_66:.*]] = mhlo.multiply %[[VAL_63]], %[[VAL_65]] : tensor<f32>
-// CHECK: %[[VAL_67:.*]] = mhlo.constant dense<-2.17486866E-16> : tensor<f32>
-// CHECK: %[[VAL_68:.*]] = mhlo.add %[[VAL_61]], %[[VAL_67]] : tensor<f32>
-// CHECK: %[[VAL_69:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_68]] : tensor<f32>
-// CHECK: %[[VAL_70:.*]] = mhlo.multiply %[[VAL_66]], %[[VAL_69]] : tensor<f32>
-// CHECK: %[[VAL_71:.*]] = mhlo.constant dense<1.600000e+01> : tensor<f32>
-// CHECK: %[[VAL_72:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_71]] : tensor<f32>
-// CHECK: %[[VAL_73:.*]] = mhlo.constant dense<1.500000e+01> : tensor<f32>
-// CHECK: %[[VAL_74:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_73]] : tensor<f32>
-// CHECK: %[[VAL_75:.*]] = mhlo.multiply %[[VAL_72]], %[[VAL_74]] : tensor<f32>
-// CHECK: %[[VAL_76:.*]] = mhlo.constant dense<8.58606213E-15> : tensor<f32>
-// CHECK: %[[VAL_77:.*]] = mhlo.add %[[VAL_70]], %[[VAL_76]] : tensor<f32>
-// CHECK: %[[VAL_78:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_77]] : tensor<f32>
-// CHECK: %[[VAL_79:.*]] = mhlo.multiply %[[VAL_75]], %[[VAL_78]] : tensor<f32>
-// CHECK: %[[VAL_80:.*]] = mhlo.constant dense<1.400000e+01> : tensor<f32>
-// CHECK: %[[VAL_81:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_80]] : tensor<f32>
-// CHECK: %[[VAL_82:.*]] = mhlo.constant dense<1.300000e+01> : tensor<f32>
-// CHECK: %[[VAL_83:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_82]] : tensor<f32>
-// CHECK: %[[VAL_84:.*]] = mhlo.multiply %[[VAL_81]], %[[VAL_83]] : tensor<f32>
-// CHECK: %[[VAL_85:.*]] = mhlo.constant dense<-3.3896803E-13> : tensor<f32>
-// CHECK: %[[VAL_86:.*]] = mhlo.add %[[VAL_79]], %[[VAL_85]] : tensor<f32>
-// CHECK: %[[VAL_87:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_86]] : tensor<f32>
-// CHECK: %[[VAL_88:.*]] = mhlo.multiply %[[VAL_84]], %[[VAL_87]] : tensor<f32>
-// CHECK: %[[VAL_89:.*]] = mhlo.constant dense<1.200000e+01> : tensor<f32>
-// CHECK: %[[VAL_90:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_89]] : tensor<f32>
-// CHECK: %[[VAL_91:.*]] = mhlo.constant dense<1.100000e+01> : tensor<f32>
-// CHECK: %[[VAL_92:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_91]] : tensor<f32>
-// CHECK: %[[VAL_93:.*]] = mhlo.multiply %[[VAL_90]], %[[VAL_92]] : tensor<f32>
-// CHECK: %[[VAL_94:.*]] = mhlo.constant dense<1.33825364E-11> : tensor<f32>
-// CHECK: %[[VAL_95:.*]] = mhlo.add %[[VAL_88]], %[[VAL_94]] : tensor<f32>
-// CHECK: %[[VAL_96:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_95]] : tensor<f32>
-// CHECK: %[[VAL_97:.*]] = mhlo.multiply %[[VAL_93]], %[[VAL_96]] : tensor<f32>
-// CHECK: %[[VAL_98:.*]] = mhlo.constant dense<1.000000e+01> : tensor<f32>
-// CHECK: %[[VAL_99:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_98]] : tensor<f32>
-// CHECK: %[[VAL_100:.*]] = mhlo.constant dense<9.000000e+00> : tensor<f32>
-// CHECK: %[[VAL_101:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_100]] : tensor<f32>
-// CHECK: %[[VAL_102:.*]] = mhlo.multiply %[[VAL_99]], %[[VAL_101]] : tensor<f32>
-// CHECK: %[[VAL_103:.*]] = mhlo.constant dense<-5.28419031E-10> : tensor<f32>
-// CHECK: %[[VAL_104:.*]] = mhlo.add %[[VAL_97]], %[[VAL_103]] : tensor<f32>
-// CHECK: %[[VAL_105:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_104]] : tensor<f32>
-// CHECK: %[[VAL_106:.*]] = mhlo.multiply %[[VAL_102]], %[[VAL_105]] : tensor<f32>
-// CHECK: %[[VAL_107:.*]] = mhlo.constant dense<8.000000e+00> : tensor<f32>
-// CHECK: %[[VAL_108:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_107]] : tensor<f32>
-// CHECK: %[[VAL_109:.*]] = mhlo.constant dense<7.000000e+00> : tensor<f32>
-// CHECK: %[[VAL_110:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_109]] : tensor<f32>
-// CHECK: %[[VAL_111:.*]] = mhlo.multiply %[[VAL_108]], %[[VAL_110]] : tensor<f32>
-// CHECK: %[[VAL_112:.*]] = mhlo.constant dense<2.08767563E-8> : tensor<f32>
-// CHECK: %[[VAL_113:.*]] = mhlo.add %[[VAL_106]], %[[VAL_112]] : tensor<f32>
-// CHECK: %[[VAL_114:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_113]] : tensor<f32>
-// CHECK: %[[VAL_115:.*]] = mhlo.multiply %[[VAL_111]], %[[VAL_114]] : tensor<f32>
-// CHECK: %[[VAL_116:.*]] = mhlo.constant dense<6.000000e+00> : tensor<f32>
-// CHECK: %[[VAL_117:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_116]] : tensor<f32>
-// CHECK: %[[VAL_118:.*]] = mhlo.constant dense<5.000000e+00> : tensor<f32>
-// CHECK: %[[VAL_119:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_118]] : tensor<f32>
-// CHECK: %[[VAL_120:.*]] = mhlo.multiply %[[VAL_117]], %[[VAL_119]] : tensor<f32>
-// CHECK: %[[VAL_121:.*]] = mhlo.constant dense<-8.26719599E-7> : tensor<f32>
-// CHECK: %[[VAL_122:.*]] = mhlo.add %[[VAL_115]], %[[VAL_121]] : tensor<f32>
-// CHECK: %[[VAL_123:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_122]] : tensor<f32>
-// CHECK: %[[VAL_124:.*]] = mhlo.multiply %[[VAL_120]], %[[VAL_123]] : tensor<f32>
-// CHECK: %[[VAL_125:.*]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
-// CHECK: %[[VAL_126:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_125]] : tensor<f32>
-// CHECK: %[[VAL_127:.*]] = mhlo.constant dense<3.000000e+00> : tensor<f32>
-// CHECK: %[[VAL_128:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_127]] : tensor<f32>
-// CHECK: %[[VAL_129:.*]] = mhlo.multiply %[[VAL_126]], %[[VAL_128]] : tensor<f32>
-// CHECK: %[[VAL_130:.*]] = mhlo.constant dense<3.30687835E-5> : tensor<f32>
-// CHECK: %[[VAL_131:.*]] = mhlo.add %[[VAL_124]], %[[VAL_130]] : tensor<f32>
-// CHECK: %[[VAL_132:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_131]] : tensor<f32>
-// CHECK: %[[VAL_133:.*]] = mhlo.multiply %[[VAL_129]], %[[VAL_132]] : tensor<f32>
-// CHECK: %[[VAL_134:.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32>
-// CHECK: %[[VAL_135:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_134]] : tensor<f32>
-// CHECK: %[[VAL_136:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
-// CHECK: %[[VAL_137:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_136]] : tensor<f32>
-// CHECK: %[[VAL_138:.*]] = mhlo.multiply %[[VAL_135]], %[[VAL_137]] : tensor<f32>
-// CHECK: %[[VAL_139:.*]] = mhlo.constant dense<-0.00138888892> : tensor<f32>
-// CHECK: %[[VAL_140:.*]] = mhlo.add %[[VAL_133]], %[[VAL_139]] : tensor<f32>
-// CHECK: %[[VAL_141:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_140]] : tensor<f32>
-// CHECK: %[[VAL_142:.*]] = mhlo.multiply %[[VAL_138]], %[[VAL_141]] : tensor<f32>
-// CHECK: %[[VAL_143:.*]] = mhlo.constant dense<5.000000e-01> : tensor<f32>
-// CHECK: %[[VAL_144:.*]] = mhlo.divide %[[VAL_2]], %[[VAL_35]] : tensor<f32>
-// CHECK: %[[VAL_145:.*]] = mhlo.constant dense<0.0833333358> : tensor<f32>
-// CHECK: %[[VAL_146:.*]] = mhlo.add %[[VAL_145]], %[[VAL_142]] : tensor<f32>
-// CHECK: %[[VAL_147:.*]] = mhlo.multiply %[[VAL_144]], %[[VAL_146]] : tensor<f32>
-// CHECK: %[[VAL_148:.*]] = mhlo.add %[[VAL_143]], %[[VAL_147]] : tensor<f32>
-// CHECK: %[[VAL_149:.*]] = mhlo.multiply %[[VAL_36]], %[[VAL_148]] : tensor<f32>
-// CHECK: %[[VAL_150:.*]] = mhlo.add %[[VAL_41]], %[[VAL_149]] : tensor<f32>
-// CHECK: %[[VAL_151:.*]] = "mhlo.abs"(%[[VAL_36]]) : (tensor<f32>) -> tensor<f32>
-// CHECK: %[[VAL_152:.*]] = "mhlo.abs"(%[[VAL_34]]) : (tensor<f32>) -> tensor<f32>
-// CHECK: %[[VAL_153:.*]] = mhlo.constant dense<1.401300e-45> : tensor<f32>
-// CHECK: %[[VAL_154:.*]] = mhlo.multiply %[[VAL_152]], %[[VAL_153]] : tensor<f32>
-// CHECK: %[[VAL_155:.*]] = "mhlo.compare"(%[[VAL_151]], %[[VAL_154]]) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
-// CHECK: %[[VAL_156:.*]] = "mhlo.select"(%[[VAL_155]], %[[VAL_34]], %[[VAL_150]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
-// CHECK: %[[VAL_157:.*]] = mhlo.constant dense<0x7F800000> : tensor<f32>
-// CHECK: %[[VAL_158:.*]] = "mhlo.compare"(%[[VAL_2]], %[[VAL_37]]) {comparison_direction = "EQ"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
-// CHECK: %[[VAL_159:.*]] = "mhlo.select"(%[[VAL_158]], %[[VAL_157]], %[[VAL_156]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
-// CHECK: %[[VAL_160:.*]] = mhlo.constant dense<0x7FC00000> : tensor<f32>
-// CHECK: %[[VAL_161:.*]] = "mhlo.compare"(%[[VAL_2]], %[[VAL_37]]) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
-// CHECK: %[[VAL_162:.*]] = "mhlo.select"(%[[VAL_161]], %[[VAL_160]], %[[VAL_159]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
-// CHECK: %[[VAL_163:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
-// CHECK: %[[VAL_164:.*]] = "mhlo.compare"(%[[VAL_3]], %[[VAL_163]]) {comparison_direction = "LE"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
-// CHECK: %[[VAL_165:.*]] = "mhlo.floor"(%[[VAL_2]]) : (tensor<f32>) -> tensor<f32>
-// CHECK: %[[VAL_166:.*]] = "mhlo.compare"(%[[VAL_2]], %[[VAL_165]]) {comparison_direction = "NE"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
-// CHECK: %[[VAL_167:.*]] = mhlo.and %[[VAL_164]], %[[VAL_166]] : tensor<i1>
-// CHECK: %[[VAL_169:.*]] = "mhlo.floor"(%[[VAL_3]]) : (tensor<f32>) -> tensor<f32>
-// CHECK: %[[VAL_170:.*]] = "mhlo.compare"(%[[VAL_3]], %[[VAL_169]]) {comparison_direction = "EQ"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
-// CHECK: %[[VAL_171:.*]] = mhlo.and %[[VAL_164]], %[[VAL_170]] : tensor<i1>
-// CHECK: %[[VAL_172:.*]] = "mhlo.select"(%[[VAL_171]], %[[VAL_157]], %[[VAL_162]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
-// CHECK: %[[VAL_173:.*]] = "mhlo.select"(%[[VAL_167]], %[[VAL_160]], %[[VAL_172]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
-// CHECK: %[[VAL_174:.*]] = "mhlo.convert"(%[[VAL_173]]) : (tensor<f32>) -> tensor<f16>
+ // CHECK: %[[VAL_2:.*]] = "mhlo.convert"(%[[VAL_0]]) : (tensor<f16>) -> tensor<f32>
+ // CHECK: %[[VAL_3:.*]] = "mhlo.convert"(%[[VAL_1]]) : (tensor<f16>) -> tensor<f32>
+ // CHECK: %[[VAL_4:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
+ // CHECK: %[[VAL_5:.*]] = "mhlo.negate"(%[[VAL_2]]) : (tensor<f32>) -> tensor<f32>
+ // CHECK: %[[VAL_6:.*]] = mhlo.power %[[VAL_3]], %[[VAL_5]] : tensor<f32>
+ // CHECK: %[[VAL_7:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
+ // CHECK: %[[VAL_8:.*]] = mhlo.add %[[VAL_3]], %[[VAL_7]] : tensor<f32>
+ // CHECK: %[[VAL_9:.*]] = mhlo.power %[[VAL_8]], %[[VAL_5]] : tensor<f32>
+ // CHECK: %[[VAL_10:.*]] = mhlo.add %[[VAL_6]], %[[VAL_9]] : tensor<f32>
+ // CHECK: %[[VAL_11:.*]] = mhlo.add %[[VAL_8]], %[[VAL_7]] : tensor<f32>
+ // CHECK: %[[VAL_12:.*]] = mhlo.power %[[VAL_11]], %[[VAL_5]] : tensor<f32>
+ // CHECK: %[[VAL_13:.*]] = mhlo.add %[[VAL_10]], %[[VAL_12]] : tensor<f32>
+ // CHECK: %[[VAL_14:.*]] = mhlo.add %[[VAL_11]], %[[VAL_7]] : tensor<f32>
+ // CHECK: %[[VAL_15:.*]] = mhlo.power %[[VAL_14]], %[[VAL_5]] : tensor<f32>
+ // CHECK: %[[VAL_16:.*]] = mhlo.add %[[VAL_13]], %[[VAL_15]] : tensor<f32>
+ // CHECK: %[[VAL_17:.*]] = mhlo.add %[[VAL_14]], %[[VAL_7]] : tensor<f32>
+ // CHECK: %[[VAL_18:.*]] = mhlo.power %[[VAL_17]], %[[VAL_5]] : tensor<f32>
+ // CHECK: %[[VAL_19:.*]] = mhlo.add %[[VAL_16]], %[[VAL_18]] : tensor<f32>
+ // CHECK: %[[VAL_20:.*]] = mhlo.add %[[VAL_17]], %[[VAL_7]] : tensor<f32>
+ // CHECK: %[[VAL_21:.*]] = mhlo.power %[[VAL_20]], %[[VAL_5]] : tensor<f32>
+ // CHECK: %[[VAL_22:.*]] = mhlo.add %[[VAL_19]], %[[VAL_21]] : tensor<f32>
+ // CHECK: %[[VAL_23:.*]] = mhlo.add %[[VAL_20]], %[[VAL_7]] : tensor<f32>
+ // CHECK: %[[VAL_24:.*]] = mhlo.power %[[VAL_23]], %[[VAL_5]] : tensor<f32>
+ // CHECK: %[[VAL_25:.*]] = mhlo.add %[[VAL_22]], %[[VAL_24]] : tensor<f32>
+ // CHECK: %[[VAL_26:.*]] = mhlo.add %[[VAL_23]], %[[VAL_7]] : tensor<f32>
+ // CHECK: %[[VAL_27:.*]] = mhlo.power %[[VAL_26]], %[[VAL_5]] : tensor<f32>
+ // CHECK: %[[VAL_28:.*]] = mhlo.add %[[VAL_25]], %[[VAL_27]] : tensor<f32>
+ // CHECK: %[[VAL_29:.*]] = mhlo.add %[[VAL_26]], %[[VAL_7]] : tensor<f32>
+ // CHECK: %[[VAL_30:.*]] = mhlo.power %[[VAL_29]], %[[VAL_5]] : tensor<f32>
+ // CHECK: %[[VAL_31:.*]] = mhlo.add %[[VAL_28]], %[[VAL_30]] : tensor<f32>
+ // CHECK: %[[VAL_32:.*]] = mhlo.add %[[VAL_29]], %[[VAL_7]] : tensor<f32>
+ // CHECK: %[[VAL_33:.*]] = mhlo.power %[[VAL_32]], %[[VAL_5]] : tensor<f32>
+ // CHECK: %[[VAL_34:.*]] = mhlo.add %[[VAL_31]], %[[VAL_33]] : tensor<f32>
+ // CHECK: %[[VAL_35:.*]] = mhlo.add %[[VAL_32]], %[[VAL_7]] : tensor<f32>
+ // CHECK: %[[VAL_36:.*]] = mhlo.power %[[VAL_35]], %[[VAL_5]] : tensor<f32>
+ // CHECK: %[[VAL_37:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
+ // CHECK: %[[VAL_38:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_37]] : tensor<f32>
+ // CHECK: %[[VAL_39:.*]] = mhlo.multiply %[[VAL_36]], %[[VAL_35]] : tensor<f32>
+ // CHECK: %[[VAL_40:.*]] = mhlo.divide %[[VAL_39]], %[[VAL_38]] : tensor<f32>
+ // CHECK: %[[VAL_41:.*]] = mhlo.add %[[VAL_34]], %[[VAL_40]] : tensor<f32>
+ // CHECK: %[[VAL_42:.*]] = mhlo.multiply %[[VAL_35]], %[[VAL_35]] : tensor<f32>
+ // CHECK: %[[VAL_43:.*]] = mhlo.divide %[[VAL_7]], %[[VAL_42]] : tensor<f32>
+ // CHECK: %[[VAL_44:.*]] = mhlo.constant dense<2.200000e+01> : tensor<f32>
+ // CHECK: %[[VAL_45:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_44]] : tensor<f32>
+ // CHECK: %[[VAL_46:.*]] = mhlo.constant dense<2.100000e+01> : tensor<f32>
+ // CHECK: %[[VAL_47:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_46]] : tensor<f32>
+ // CHECK: %[[VAL_48:.*]] = mhlo.multiply %[[VAL_45]], %[[VAL_47]] : tensor<f32>
+ // CHECK: %[[VAL_49:.*]] = mhlo.constant dense<-1.39544646E-19> : tensor<f32>
+ // CHECK: %[[VAL_50:.*]] = mhlo.add %[[VAL_4]], %[[VAL_49]] : tensor<f32>
+ // CHECK: %[[VAL_51:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_50]] : tensor<f32>
+ // CHECK: %[[VAL_52:.*]] = mhlo.multiply %[[VAL_48]], %[[VAL_51]] : tensor<f32>
+ // CHECK: %[[VAL_53:.*]] = mhlo.constant dense<2.000000e+01> : tensor<f32>
+ // CHECK: %[[VAL_54:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_53]] : tensor<f32>
+ // CHECK: %[[VAL_55:.*]] = mhlo.constant dense<1.900000e+01> : tensor<f32>
+ // CHECK: %[[VAL_56:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_55]] : tensor<f32>
+ // CHECK: %[[VAL_57:.*]] = mhlo.multiply %[[VAL_54]], %[[VAL_56]] : tensor<f32>
+ // CHECK: %[[VAL_58:.*]] = mhlo.constant dense<5.50900303E-18> : tensor<f32>
+ // CHECK: %[[VAL_59:.*]] = mhlo.add %[[VAL_52]], %[[VAL_58]] : tensor<f32>
+ // CHECK: %[[VAL_60:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_59]] : tensor<f32>
+ // CHECK: %[[VAL_61:.*]] = mhlo.multiply %[[VAL_57]], %[[VAL_60]] : tensor<f32>
+ // CHECK: %[[VAL_62:.*]] = mhlo.constant dense<1.800000e+01> : tensor<f32>
+ // CHECK: %[[VAL_63:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_62]] : tensor<f32>
+ // CHECK: %[[VAL_64:.*]] = mhlo.constant dense<1.700000e+01> : tensor<f32>
+ // CHECK: %[[VAL_65:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_64]] : tensor<f32>
+ // CHECK: %[[VAL_66:.*]] = mhlo.multiply %[[VAL_63]], %[[VAL_65]] : tensor<f32>
+ // CHECK: %[[VAL_67:.*]] = mhlo.constant dense<-2.17486866E-16> : tensor<f32>
+ // CHECK: %[[VAL_68:.*]] = mhlo.add %[[VAL_61]], %[[VAL_67]] : tensor<f32>
+ // CHECK: %[[VAL_69:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_68]] : tensor<f32>
+ // CHECK: %[[VAL_70:.*]] = mhlo.multiply %[[VAL_66]], %[[VAL_69]] : tensor<f32>
+ // CHECK: %[[VAL_71:.*]] = mhlo.constant dense<1.600000e+01> : tensor<f32>
+ // CHECK: %[[VAL_72:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_71]] : tensor<f32>
+ // CHECK: %[[VAL_73:.*]] = mhlo.constant dense<1.500000e+01> : tensor<f32>
+ // CHECK: %[[VAL_74:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_73]] : tensor<f32>
+ // CHECK: %[[VAL_75:.*]] = mhlo.multiply %[[VAL_72]], %[[VAL_74]] : tensor<f32>
+ // CHECK: %[[VAL_76:.*]] = mhlo.constant dense<8.58606213E-15> : tensor<f32>
+ // CHECK: %[[VAL_77:.*]] = mhlo.add %[[VAL_70]], %[[VAL_76]] : tensor<f32>
+ // CHECK: %[[VAL_78:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_77]] : tensor<f32>
+ // CHECK: %[[VAL_79:.*]] = mhlo.multiply %[[VAL_75]], %[[VAL_78]] : tensor<f32>
+ // CHECK: %[[VAL_80:.*]] = mhlo.constant dense<1.400000e+01> : tensor<f32>
+ // CHECK: %[[VAL_81:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_80]] : tensor<f32>
+ // CHECK: %[[VAL_82:.*]] = mhlo.constant dense<1.300000e+01> : tensor<f32>
+ // CHECK: %[[VAL_83:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_82]] : tensor<f32>
+ // CHECK: %[[VAL_84:.*]] = mhlo.multiply %[[VAL_81]], %[[VAL_83]] : tensor<f32>
+ // CHECK: %[[VAL_85:.*]] = mhlo.constant dense<-3.3896803E-13> : tensor<f32>
+ // CHECK: %[[VAL_86:.*]] = mhlo.add %[[VAL_79]], %[[VAL_85]] : tensor<f32>
+ // CHECK: %[[VAL_87:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_86]] : tensor<f32>
+ // CHECK: %[[VAL_88:.*]] = mhlo.multiply %[[VAL_84]], %[[VAL_87]] : tensor<f32>
+ // CHECK: %[[VAL_89:.*]] = mhlo.constant dense<1.200000e+01> : tensor<f32>
+ // CHECK: %[[VAL_90:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_89]] : tensor<f32>
+ // CHECK: %[[VAL_91:.*]] = mhlo.constant dense<1.100000e+01> : tensor<f32>
+ // CHECK: %[[VAL_92:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_91]] : tensor<f32>
+ // CHECK: %[[VAL_93:.*]] = mhlo.multiply %[[VAL_90]], %[[VAL_92]] : tensor<f32>
+ // CHECK: %[[VAL_94:.*]] = mhlo.constant dense<1.33825364E-11> : tensor<f32>
+ // CHECK: %[[VAL_95:.*]] = mhlo.add %[[VAL_88]], %[[VAL_94]] : tensor<f32>
+ // CHECK: %[[VAL_96:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_95]] : tensor<f32>
+ // CHECK: %[[VAL_97:.*]] = mhlo.multiply %[[VAL_93]], %[[VAL_96]] : tensor<f32>
+ // CHECK: %[[VAL_98:.*]] = mhlo.constant dense<1.000000e+01> : tensor<f32>
+ // CHECK: %[[VAL_99:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_98]] : tensor<f32>
+ // CHECK: %[[VAL_100:.*]] = mhlo.constant dense<9.000000e+00> : tensor<f32>
+ // CHECK: %[[VAL_101:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_100]] : tensor<f32>
+ // CHECK: %[[VAL_102:.*]] = mhlo.multiply %[[VAL_99]], %[[VAL_101]] : tensor<f32>
+ // CHECK: %[[VAL_103:.*]] = mhlo.constant dense<-5.28419031E-10> : tensor<f32>
+ // CHECK: %[[VAL_104:.*]] = mhlo.add %[[VAL_97]], %[[VAL_103]] : tensor<f32>
+ // CHECK: %[[VAL_105:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_104]] : tensor<f32>
+ // CHECK: %[[VAL_106:.*]] = mhlo.multiply %[[VAL_102]], %[[VAL_105]] : tensor<f32>
+ // CHECK: %[[VAL_107:.*]] = mhlo.constant dense<8.000000e+00> : tensor<f32>
+ // CHECK: %[[VAL_108:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_107]] : tensor<f32>
+ // CHECK: %[[VAL_109:.*]] = mhlo.constant dense<7.000000e+00> : tensor<f32>
+ // CHECK: %[[VAL_110:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_109]] : tensor<f32>
+ // CHECK: %[[VAL_111:.*]] = mhlo.multiply %[[VAL_108]], %[[VAL_110]] : tensor<f32>
+ // CHECK: %[[VAL_112:.*]] = mhlo.constant dense<2.08767563E-8> : tensor<f32>
+ // CHECK: %[[VAL_113:.*]] = mhlo.add %[[VAL_106]], %[[VAL_112]] : tensor<f32>
+ // CHECK: %[[VAL_114:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_113]] : tensor<f32>
+ // CHECK: %[[VAL_115:.*]] = mhlo.multiply %[[VAL_111]], %[[VAL_114]] : tensor<f32>
+ // CHECK: %[[VAL_116:.*]] = mhlo.constant dense<6.000000e+00> : tensor<f32>
+ // CHECK: %[[VAL_117:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_116]] : tensor<f32>
+ // CHECK: %[[VAL_118:.*]] = mhlo.constant dense<5.000000e+00> : tensor<f32>
+ // CHECK: %[[VAL_119:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_118]] : tensor<f32>
+ // CHECK: %[[VAL_120:.*]] = mhlo.multiply %[[VAL_117]], %[[VAL_119]] : tensor<f32>
+ // CHECK: %[[VAL_121:.*]] = mhlo.constant dense<-8.26719599E-7> : tensor<f32>
+ // CHECK: %[[VAL_122:.*]] = mhlo.add %[[VAL_115]], %[[VAL_121]] : tensor<f32>
+ // CHECK: %[[VAL_123:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_122]] : tensor<f32>
+ // CHECK: %[[VAL_124:.*]] = mhlo.multiply %[[VAL_120]], %[[VAL_123]] : tensor<f32>
+ // CHECK: %[[VAL_125:.*]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
+ // CHECK: %[[VAL_126:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_125]] : tensor<f32>
+ // CHECK: %[[VAL_127:.*]] = mhlo.constant dense<3.000000e+00> : tensor<f32>
+ // CHECK: %[[VAL_128:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_127]] : tensor<f32>
+ // CHECK: %[[VAL_129:.*]] = mhlo.multiply %[[VAL_126]], %[[VAL_128]] : tensor<f32>
+ // CHECK: %[[VAL_130:.*]] = mhlo.constant dense<3.30687835E-5> : tensor<f32>
+ // CHECK: %[[VAL_131:.*]] = mhlo.add %[[VAL_124]], %[[VAL_130]] : tensor<f32>
+ // CHECK: %[[VAL_132:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_131]] : tensor<f32>
+ // CHECK: %[[VAL_133:.*]] = mhlo.multiply %[[VAL_129]], %[[VAL_132]] : tensor<f32>
+ // CHECK: %[[VAL_134:.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32>
+ // CHECK: %[[VAL_135:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_134]] : tensor<f32>
+ // CHECK: %[[VAL_136:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
+ // CHECK: %[[VAL_137:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_136]] : tensor<f32>
+ // CHECK: %[[VAL_138:.*]] = mhlo.multiply %[[VAL_135]], %[[VAL_137]] : tensor<f32>
+ // CHECK: %[[VAL_139:.*]] = mhlo.constant dense<-0.00138888892> : tensor<f32>
+ // CHECK: %[[VAL_140:.*]] = mhlo.add %[[VAL_133]], %[[VAL_139]] : tensor<f32>
+ // CHECK: %[[VAL_141:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_140]] : tensor<f32>
+ // CHECK: %[[VAL_142:.*]] = mhlo.multiply %[[VAL_138]], %[[VAL_141]] : tensor<f32>
+ // CHECK: %[[VAL_143:.*]] = mhlo.constant dense<5.000000e-01> : tensor<f32>
+ // CHECK: %[[VAL_144:.*]] = mhlo.divide %[[VAL_2]], %[[VAL_35]] : tensor<f32>
+ // CHECK: %[[VAL_145:.*]] = mhlo.constant dense<0.0833333358> : tensor<f32>
+ // CHECK: %[[VAL_146:.*]] = mhlo.add %[[VAL_145]], %[[VAL_142]] : tensor<f32>
+ // CHECK: %[[VAL_147:.*]] = mhlo.multiply %[[VAL_144]], %[[VAL_146]] : tensor<f32>
+ // CHECK: %[[VAL_148:.*]] = mhlo.add %[[VAL_143]], %[[VAL_147]] : tensor<f32>
+ // CHECK: %[[VAL_149:.*]] = mhlo.multiply %[[VAL_36]], %[[VAL_148]] : tensor<f32>
+ // CHECK: %[[VAL_150:.*]] = mhlo.add %[[VAL_41]], %[[VAL_149]] : tensor<f32>
+ // CHECK: %[[VAL_151:.*]] = "mhlo.abs"(%[[VAL_36]]) : (tensor<f32>) -> tensor<f32>
+ // CHECK: %[[VAL_152:.*]] = "mhlo.abs"(%[[VAL_34]]) : (tensor<f32>) -> tensor<f32>
+ // CHECK: %[[VAL_153:.*]] = mhlo.constant dense<1.401300e-45> : tensor<f32>
+ // CHECK: %[[VAL_154:.*]] = mhlo.multiply %[[VAL_152]], %[[VAL_153]] : tensor<f32>
+ // CHECK: %[[VAL_155:.*]] = "mhlo.compare"(%[[VAL_151]], %[[VAL_154]]) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+ // CHECK: %[[VAL_156:.*]] = "mhlo.select"(%[[VAL_155]], %[[VAL_34]], %[[VAL_150]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+ // CHECK: %[[VAL_157:.*]] = mhlo.constant dense<0x7F800000> : tensor<f32>
+ // CHECK: %[[VAL_158:.*]] = "mhlo.compare"(%[[VAL_2]], %[[VAL_37]]) {comparison_direction = "EQ"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+ // CHECK: %[[VAL_159:.*]] = "mhlo.select"(%[[VAL_158]], %[[VAL_157]], %[[VAL_156]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+ // CHECK: %[[VAL_160:.*]] = mhlo.constant dense<0x7FC00000> : tensor<f32>
+ // CHECK: %[[VAL_161:.*]] = "mhlo.compare"(%[[VAL_2]], %[[VAL_37]]) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+ // CHECK: %[[VAL_162:.*]] = "mhlo.select"(%[[VAL_161]], %[[VAL_160]], %[[VAL_159]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+ // CHECK: %[[VAL_163:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
+ // CHECK: %[[VAL_164:.*]] = "mhlo.compare"(%[[VAL_3]], %[[VAL_163]]) {comparison_direction = "LE"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+ // CHECK: %[[VAL_165:.*]] = "mhlo.floor"(%[[VAL_2]]) : (tensor<f32>) -> tensor<f32>
+ // CHECK: %[[VAL_166:.*]] = "mhlo.compare"(%[[VAL_2]], %[[VAL_165]]) {comparison_direction = "NE"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+ // CHECK: %[[VAL_167:.*]] = mhlo.and %[[VAL_164]], %[[VAL_166]] : tensor<i1>
+ // CHECK: %[[VAL_169:.*]] = "mhlo.floor"(%[[VAL_3]]) : (tensor<f32>) -> tensor<f32>
+ // CHECK: %[[VAL_170:.*]] = "mhlo.compare"(%[[VAL_3]], %[[VAL_169]]) {comparison_direction = "EQ"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+ // CHECK: %[[VAL_171:.*]] = mhlo.and %[[VAL_164]], %[[VAL_170]] : tensor<i1>
+ // CHECK: %[[VAL_172:.*]] = "mhlo.select"(%[[VAL_171]], %[[VAL_157]], %[[VAL_162]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+ // CHECK: %[[VAL_173:.*]] = "mhlo.select"(%[[VAL_167]], %[[VAL_160]], %[[VAL_172]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+ // CHECK: %[[VAL_174:.*]] = "mhlo.convert"(%[[VAL_173]]) : (tensor<f32>) -> tensor<f16>
+ // CHECK: return %[[VAL_174]] : tensor<f16>
+ %0 = chlo.zeta %arg0, %arg1 : tensor<f16>, tensor<f16> -> tensor<f16>
return %0 : tensor<f16>
-// CHECK: return %[[VAL_174]] : tensor<f16>
+}
+
+// CHECK: @polygamma_f32
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<f32>, %[[ARG1:.*]]: tensor<f32>)
+func @polygamma_f32(%lhs : tensor<f32>, %rhs : tensor<f32>) -> tensor<f32> {
+ // CHECK: %[[TMP_0:.*]] = mhlo.constant dense<1.000000e+00>
+ // CHECK: %[[TMP_1:.*]] = mhlo.constant dense<2.000000e+00>
+ // CHECK: %[[TMP_2:.*]] = mhlo.remainder %[[ARG0]], %[[TMP_1]]
+ // CHECK: %[[TMP_3:.*]] = mhlo.multiply %[[TMP_1]], %[[TMP_2]]
+ // CHECK: %[[TMP_4:.*]] = mhlo.subtract %[[TMP_3]], %[[TMP_0]]
+ // CHECK: %[[TMP_5:.*]] = mhlo.add %[[ARG0]], %[[TMP_0]]
+ // CHECK: %[[TMP_6:.*]] = mhlo.constant dense<5.000000e-01>
+ // CHECK: %[[TMP_7:.*]] = "mhlo.compare"(%[[TMP_5]], %[[TMP_6]]) {comparison_direction = "LT"}
+ // CHECK: %[[TMP_8:.*]] = "mhlo.negate"(%[[TMP_5]])
+ // CHECK: %[[TMP_9:.*]] = mhlo.constant dense<1.000000e+00>
+ // CHECK: %[[TMP_10:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_9]]
+ // CHECK: %[[TMP_11:.*]] = "mhlo.select"(%[[TMP_7]], %[[TMP_8]], %[[TMP_10]])
+ // CHECK: %[[TMP_12:.*]] = mhlo.constant dense<1.000000e+00>
+ // CHECK: %[[TMP_13:.*]] = mhlo.constant dense<676.520386>
+ // CHECK: %[[TMP_14:.*]] = mhlo.constant dense<1.000000e+00>
+ // CHECK: %[[TMP_15:.*]] = mhlo.add %[[TMP_11]], %[[TMP_14]]
+ // CHECK: %[[TMP_16:.*]] = mhlo.divide %[[TMP_13]], %[[TMP_15]]
+ // CHECK: %[[TMP_17:.*]] = mhlo.add %[[TMP_12]], %[[TMP_16]]
+ // CHECK: %[[TMP_18:.*]] = mhlo.constant dense<-1259.13916>
+ // CHECK: %[[TMP_19:.*]] = mhlo.constant dense<2.000000e+00>
+ // CHECK: %[[TMP_20:.*]] = mhlo.add %[[TMP_11]], %[[TMP_19]]
+ // CHECK: %[[TMP_21:.*]] = mhlo.divide %[[TMP_18]], %[[TMP_20]]
+ // CHECK: %[[TMP_22:.*]] = mhlo.add %[[TMP_17]], %[[TMP_21]]
+ // CHECK: %[[TMP_23:.*]] = mhlo.constant dense<771.323425>
+ // CHECK: %[[TMP_24:.*]] = mhlo.constant dense<3.000000e+00>
+ // CHECK: %[[TMP_25:.*]] = mhlo.add %[[TMP_11]], %[[TMP_24]]
+ // CHECK: %[[TMP_26:.*]] = mhlo.divide %[[TMP_23]], %[[TMP_25]]
+ // CHECK: %[[TMP_27:.*]] = mhlo.add %[[TMP_22]], %[[TMP_26]]
+ // CHECK: %[[TMP_28:.*]] = mhlo.constant dense<-176.615036>
+ // CHECK: %[[TMP_29:.*]] = mhlo.constant dense<4.000000e+00>
+ // CHECK: %[[TMP_30:.*]] = mhlo.add %[[TMP_11]], %[[TMP_29]]
+ // CHECK: %[[TMP_31:.*]] = mhlo.divide %[[TMP_28]], %[[TMP_30]]
+ // CHECK: %[[TMP_32:.*]] = mhlo.add %[[TMP_27]], %[[TMP_31]]
+ // CHECK: %[[TMP_33:.*]] = mhlo.constant dense<12.5073433>
+ // CHECK: %[[TMP_34:.*]] = mhlo.constant dense<5.000000e+00>
+ // CHECK: %[[TMP_35:.*]] = mhlo.add %[[TMP_11]], %[[TMP_34]]
+ // CHECK: %[[TMP_36:.*]] = mhlo.divide %[[TMP_33]], %[[TMP_35]]
+ // CHECK: %[[TMP_37:.*]] = mhlo.add %[[TMP_32]], %[[TMP_36]]
+ // CHECK: %[[TMP_38:.*]] = mhlo.constant dense<-0.138571098>
+ // CHECK: %[[TMP_39:.*]] = mhlo.constant dense<6.000000e+00>
+ // CHECK: %[[TMP_40:.*]] = mhlo.add %[[TMP_11]], %[[TMP_39]]
+ // CHECK: %[[TMP_41:.*]] = mhlo.divide %[[TMP_38]], %[[TMP_40]]
+ // CHECK: %[[TMP_42:.*]] = mhlo.add %[[TMP_37]], %[[TMP_41]]
+ // CHECK: %[[TMP_43:.*]] = mhlo.constant dense<9.98436917E-6>
+ // CHECK: %[[TMP_44:.*]] = mhlo.constant dense<7.000000e+00>
+ // CHECK: %[[TMP_45:.*]] = mhlo.add %[[TMP_11]], %[[TMP_44]]
+ // CHECK: %[[TMP_46:.*]] = mhlo.divide %[[TMP_43]], %[[TMP_45]]
+ // CHECK: %[[TMP_47:.*]] = mhlo.add %[[TMP_42]], %[[TMP_46]]
+ // CHECK: %[[TMP_48:.*]] = mhlo.constant dense<1.50563267E-7>
+ // CHECK: %[[TMP_49:.*]] = mhlo.constant dense<8.000000e+00>
+ // CHECK: %[[TMP_50:.*]] = mhlo.add %[[TMP_11]], %[[TMP_49]]
+ // CHECK: %[[TMP_51:.*]] = mhlo.divide %[[TMP_48]], %[[TMP_50]]
+ // CHECK: %[[TMP_52:.*]] = mhlo.add %[[TMP_47]], %[[TMP_51]]
+ // CHECK: %[[TMP_53:.*]] = mhlo.constant dense<7.500000e+00>
+ // CHECK: %[[TMP_54:.*]] = mhlo.add %[[TMP_53]], %[[TMP_11]]
+ // CHECK: %[[TMP_55:.*]] = mhlo.constant dense<2.01490307>
+ // CHECK: %[[TMP_56:.*]] = mhlo.divide %[[TMP_11]], %[[TMP_53]]
+ // CHECK: %[[TMP_57:.*]] = "mhlo.log_plus_one"(%[[TMP_56]])
+ // CHECK: %[[TMP_58:.*]] = mhlo.add %[[TMP_55]], %[[TMP_57]]
+ // CHECK: %[[TMP_59:.*]] = mhlo.divide %[[TMP_54]], %[[TMP_58]]
+ // CHECK: %[[TMP_60:.*]] = mhlo.add %[[TMP_11]], %[[TMP_6]]
+ // CHECK: %[[TMP_61:.*]] = mhlo.subtract %[[TMP_60]], %[[TMP_59]]
+ // CHECK: %[[TMP_62:.*]] = mhlo.multiply %[[TMP_61]], %[[TMP_58]]
+ // CHECK: %[[TMP_63:.*]] = "mhlo.log"(%[[TMP_52]])
+ // CHECK: %[[TMP_64:.*]] = mhlo.constant dense<0.918938517>
+ // CHECK: %[[TMP_65:.*]] = mhlo.add %[[TMP_64]], %[[TMP_62]]
+ // CHECK: %[[TMP_66:.*]] = mhlo.add %[[TMP_65]], %[[TMP_63]]
+ // CHECK: %[[TMP_67:.*]] = "mhlo.abs"(%[[TMP_5]])
+ // CHECK: %[[TMP_68:.*]] = "mhlo.floor"(%[[TMP_67]])
+ // CHECK: %[[TMP_69:.*]] = mhlo.subtract %[[TMP_67]], %[[TMP_68]]
+ // CHECK: %[[TMP_70:.*]] = "mhlo.compare"(%[[TMP_6]], %[[TMP_69]]) {comparison_direction = "LT"}
+ // CHECK: %[[TMP_71:.*]] = mhlo.subtract %[[TMP_9]], %[[TMP_69]]
+ // CHECK: %[[TMP_72:.*]] = "mhlo.select"(%[[TMP_70]], %[[TMP_71]], %[[TMP_69]])
+ // CHECK: %[[TMP_73:.*]] = mhlo.constant dense<3.14159274>
+ // CHECK: %[[TMP_74:.*]] = mhlo.multiply %[[TMP_73]], %[[TMP_72]]
+ // CHECK: %[[TMP_75:.*]] = "mhlo.sine"(%[[TMP_74]])
+ // CHECK: %[[TMP_76:.*]] = "mhlo.log"(%[[TMP_75]])
+ // CHECK: %[[TMP_77:.*]] = mhlo.constant dense<1.14472985>
+ // CHECK: %[[TMP_78:.*]] = mhlo.subtract %[[TMP_77]], %[[TMP_76]]
+ // CHECK: %[[TMP_79:.*]] = mhlo.subtract %[[TMP_78]], %[[TMP_66]]
+ // CHECK: %[[TMP_80:.*]] = "mhlo.is_finite"(%[[TMP_76]])
+ // CHECK: %[[TMP_81:.*]] = "mhlo.negate"(%[[TMP_76]])
+ // CHECK: %[[TMP_82:.*]] = "mhlo.select"(%[[TMP_80]], %[[TMP_79]], %[[TMP_81]])
+ // CHECK: %[[TMP_83:.*]] = "mhlo.select"(%[[TMP_7]], %[[TMP_82]], %[[TMP_66]])
+ // CHECK: %[[TMP_84:.*]] = "mhlo.abs"(%[[TMP_5]])
+ // CHECK: %[[TMP_85:.*]] = mhlo.constant dense<0x7F800000>
+ // CHECK: %[[TMP_86:.*]] = "mhlo.compare"(%[[TMP_84]], %[[TMP_85]]) {comparison_direction = "EQ"}
+ // CHECK: %[[TMP_87:.*]] = mhlo.constant dense<0x7F800000>
+ // CHECK: %[[TMP_88:.*]] = "mhlo.select"(%[[TMP_86]], %[[TMP_87]], %[[TMP_83]])
+ // CHECK: %[[TMP_89:.*]] = "mhlo.exponential"(%[[TMP_88]])
+ // CHECK: %[[TMP_90:.*]] = mhlo.constant dense<0.000000e+00>
+ // CHECK: %[[TMP_91:.*]] = "mhlo.negate"(%[[TMP_5]])
+ // CHECK: %[[TMP_92:.*]] = mhlo.power %[[ARG1]], %[[TMP_91]]
+ // CHECK: %[[TMP_93:.*]] = mhlo.constant dense<1.000000e+00>
+ // CHECK: %[[TMP_94:.*]] = mhlo.add %[[ARG1]], %[[TMP_93]]
+ // CHECK: %[[TMP_95:.*]] = mhlo.power %[[TMP_94]], %[[TMP_91]]
+ // CHECK: %[[TMP_96:.*]] = mhlo.add %[[TMP_92]], %[[TMP_95]]
+ // CHECK: %[[TMP_97:.*]] = mhlo.add %[[TMP_94]], %[[TMP_93]]
+ // CHECK: %[[TMP_98:.*]] = mhlo.power %[[TMP_97]], %[[TMP_91]]
+ // CHECK: %[[TMP_99:.*]] = mhlo.add %[[TMP_96]], %[[TMP_98]]
+ // CHECK: %[[TMP_100:.*]] = mhlo.add %[[TMP_97]], %[[TMP_93]]
+ // CHECK: %[[TMP_101:.*]] = mhlo.power %[[TMP_100]], %[[TMP_91]]
+ // CHECK: %[[TMP_102:.*]] = mhlo.add %[[TMP_99]], %[[TMP_101]]
+ // CHECK: %[[TMP_103:.*]] = mhlo.add %[[TMP_100]], %[[TMP_93]]
+ // CHECK: %[[TMP_104:.*]] = mhlo.power %[[TMP_103]], %[[TMP_91]]
+ // CHECK: %[[TMP_105:.*]] = mhlo.add %[[TMP_102]], %[[TMP_104]]
+ // CHECK: %[[TMP_106:.*]] = mhlo.add %[[TMP_103]], %[[TMP_93]]
+ // CHECK: %[[TMP_107:.*]] = mhlo.power %[[TMP_106]], %[[TMP_91]]
+ // CHECK: %[[TMP_108:.*]] = mhlo.add %[[TMP_105]], %[[TMP_107]]
+ // CHECK: %[[TMP_109:.*]] = mhlo.add %[[TMP_106]], %[[TMP_93]]
+ // CHECK: %[[TMP_110:.*]] = mhlo.power %[[TMP_109]], %[[TMP_91]]
+ // CHECK: %[[TMP_111:.*]] = mhlo.add %[[TMP_108]], %[[TMP_110]]
+ // CHECK: %[[TMP_112:.*]] = mhlo.add %[[TMP_109]], %[[TMP_93]]
+ // CHECK: %[[TMP_113:.*]] = mhlo.power %[[TMP_112]], %[[TMP_91]]
+ // CHECK: %[[TMP_114:.*]] = mhlo.add %[[TMP_111]], %[[TMP_113]]
+ // CHECK: %[[TMP_115:.*]] = mhlo.add %[[TMP_112]], %[[TMP_93]]
+ // CHECK: %[[TMP_116:.*]] = mhlo.power %[[TMP_115]], %[[TMP_91]]
+ // CHECK: %[[TMP_117:.*]] = mhlo.add %[[TMP_114]], %[[TMP_116]]
+ // CHECK: %[[TMP_118:.*]] = mhlo.add %[[TMP_115]], %[[TMP_93]]
+ // CHECK: %[[TMP_119:.*]] = mhlo.power %[[TMP_118]], %[[TMP_91]]
+ // CHECK: %[[TMP_120:.*]] = mhlo.add %[[TMP_117]], %[[TMP_119]]
+ // CHECK: %[[TMP_121:.*]] = mhlo.add %[[TMP_118]], %[[TMP_93]]
+ // CHECK: %[[TMP_122:.*]] = mhlo.power %[[TMP_121]], %[[TMP_91]]
+ // CHECK: %[[TMP_123:.*]] = mhlo.constant dense<1.000000e+00>
+ // CHECK: %[[TMP_124:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_123]]
+ // CHECK: %[[TMP_125:.*]] = mhlo.multiply %[[TMP_122]], %[[TMP_121]]
+ // CHECK: %[[TMP_126:.*]] = mhlo.divide %[[TMP_125]], %[[TMP_124]]
+ // CHECK: %[[TMP_127:.*]] = mhlo.add %[[TMP_120]], %[[TMP_126]]
+ // CHECK: %[[TMP_128:.*]] = mhlo.multiply %[[TMP_121]], %[[TMP_121]]
+ // CHECK: %[[TMP_129:.*]] = mhlo.divide %[[TMP_93]], %[[TMP_128]]
+ // CHECK: %[[TMP_130:.*]] = mhlo.constant dense<2.200000e+01>
+ // CHECK: %[[TMP_131:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_130]]
+ // CHECK: %[[TMP_132:.*]] = mhlo.constant dense<2.100000e+01>
+ // CHECK: %[[TMP_133:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_132]]
+ // CHECK: %[[TMP_134:.*]] = mhlo.multiply %[[TMP_131]], %[[TMP_133]]
+ // CHECK: %[[TMP_135:.*]] = mhlo.constant dense<-1.39544646E-19>
+ // CHECK: %[[TMP_136:.*]] = mhlo.add %[[TMP_90]], %[[TMP_135]]
+ // CHECK: %[[TMP_137:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_136]]
+ // CHECK: %[[TMP_138:.*]] = mhlo.multiply %[[TMP_134]], %[[TMP_137]]
+ // CHECK: %[[TMP_139:.*]] = mhlo.constant dense<2.000000e+01>
+ // CHECK: %[[TMP_140:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_139]]
+ // CHECK: %[[TMP_141:.*]] = mhlo.constant dense<1.900000e+01>
+ // CHECK: %[[TMP_142:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_141]]
+ // CHECK: %[[TMP_143:.*]] = mhlo.multiply %[[TMP_140]], %[[TMP_142]]
+ // CHECK: %[[TMP_144:.*]] = mhlo.constant dense<5.50900303E-18>
+ // CHECK: %[[TMP_145:.*]] = mhlo.add %[[TMP_138]], %[[TMP_144]]
+ // CHECK: %[[TMP_146:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_145]]
+ // CHECK: %[[TMP_147:.*]] = mhlo.multiply %[[TMP_143]], %[[TMP_146]]
+ // CHECK: %[[TMP_148:.*]] = mhlo.constant dense<1.800000e+01>
+ // CHECK: %[[TMP_149:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_148]]
+ // CHECK: %[[TMP_150:.*]] = mhlo.constant dense<1.700000e+01>
+ // CHECK: %[[TMP_151:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_150]]
+ // CHECK: %[[TMP_152:.*]] = mhlo.multiply %[[TMP_149]], %[[TMP_151]]
+ // CHECK: %[[TMP_153:.*]] = mhlo.constant dense<-2.17486866E-16>
+ // CHECK: %[[TMP_154:.*]] = mhlo.add %[[TMP_147]], %[[TMP_153]]
+ // CHECK: %[[TMP_155:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_154]]
+ // CHECK: %[[TMP_156:.*]] = mhlo.multiply %[[TMP_152]], %[[TMP_155]]
+ // CHECK: %[[TMP_157:.*]] = mhlo.constant dense<1.600000e+01>
+ // CHECK: %[[TMP_158:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_157]]
+ // CHECK: %[[TMP_159:.*]] = mhlo.constant dense<1.500000e+01>
+ // CHECK: %[[TMP_160:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_159]]
+ // CHECK: %[[TMP_161:.*]] = mhlo.multiply %[[TMP_158]], %[[TMP_160]]
+ // CHECK: %[[TMP_162:.*]] = mhlo.constant dense<8.58606213E-15>
+ // CHECK: %[[TMP_163:.*]] = mhlo.add %[[TMP_156]], %[[TMP_162]]
+ // CHECK: %[[TMP_164:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_163]]
+ // CHECK: %[[TMP_165:.*]] = mhlo.multiply %[[TMP_161]], %[[TMP_164]]
+ // CHECK: %[[TMP_166:.*]] = mhlo.constant dense<1.400000e+01>
+ // CHECK: %[[TMP_167:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_166]]
+ // CHECK: %[[TMP_168:.*]] = mhlo.constant dense<1.300000e+01>
+ // CHECK: %[[TMP_169:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_168]]
+ // CHECK: %[[TMP_170:.*]] = mhlo.multiply %[[TMP_167]], %[[TMP_169]]
+ // CHECK: %[[TMP_171:.*]] = mhlo.constant dense<-3.3896803E-13>
+ // CHECK: %[[TMP_172:.*]] = mhlo.add %[[TMP_165]], %[[TMP_171]]
+ // CHECK: %[[TMP_173:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_172]]
+ // CHECK: %[[TMP_174:.*]] = mhlo.multiply %[[TMP_170]], %[[TMP_173]]
+ // CHECK: %[[TMP_175:.*]] = mhlo.constant dense<1.200000e+01>
+ // CHECK: %[[TMP_176:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_175]]
+ // CHECK: %[[TMP_177:.*]] = mhlo.constant dense<1.100000e+01>
+ // CHECK: %[[TMP_178:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_177]]
+ // CHECK: %[[TMP_179:.*]] = mhlo.multiply %[[TMP_176]], %[[TMP_178]]
+ // CHECK: %[[TMP_180:.*]] = mhlo.constant dense<1.33825364E-11>
+ // CHECK: %[[TMP_181:.*]] = mhlo.add %[[TMP_174]], %[[TMP_180]]
+ // CHECK: %[[TMP_182:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_181]]
+ // CHECK: %[[TMP_183:.*]] = mhlo.multiply %[[TMP_179]], %[[TMP_182]]
+ // CHECK: %[[TMP_184:.*]] = mhlo.constant dense<1.000000e+01>
+ // CHECK: %[[TMP_185:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_184]]
+ // CHECK: %[[TMP_186:.*]] = mhlo.constant dense<9.000000e+00>
+ // CHECK: %[[TMP_187:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_186]]
+ // CHECK: %[[TMP_188:.*]] = mhlo.multiply %[[TMP_185]], %[[TMP_187]]
+ // CHECK: %[[TMP_189:.*]] = mhlo.constant dense<-5.28419031E-10>
+ // CHECK: %[[TMP_190:.*]] = mhlo.add %[[TMP_183]], %[[TMP_189]]
+ // CHECK: %[[TMP_191:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_190]]
+ // CHECK: %[[TMP_192:.*]] = mhlo.multiply %[[TMP_188]], %[[TMP_191]]
+ // CHECK: %[[TMP_193:.*]] = mhlo.constant dense<8.000000e+00>
+ // CHECK: %[[TMP_194:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_193]]
+ // CHECK: %[[TMP_195:.*]] = mhlo.constant dense<7.000000e+00>
+ // CHECK: %[[TMP_196:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_195]]
+ // CHECK: %[[TMP_197:.*]] = mhlo.multiply %[[TMP_194]], %[[TMP_196]]
+ // CHECK: %[[TMP_198:.*]] = mhlo.constant dense<2.08767563E-8>
+ // CHECK: %[[TMP_199:.*]] = mhlo.add %[[TMP_192]], %[[TMP_198]]
+ // CHECK: %[[TMP_200:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_199]]
+ // CHECK: %[[TMP_201:.*]] = mhlo.multiply %[[TMP_197]], %[[TMP_200]]
+ // CHECK: %[[TMP_202:.*]] = mhlo.constant dense<6.000000e+00>
+ // CHECK: %[[TMP_203:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_202]]
+ // CHECK: %[[TMP_204:.*]] = mhlo.constant dense<5.000000e+00>
+ // CHECK: %[[TMP_205:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_204]]
+ // CHECK: %[[TMP_206:.*]] = mhlo.multiply %[[TMP_203]], %[[TMP_205]]
+ // CHECK: %[[TMP_207:.*]] = mhlo.constant dense<-8.26719599E-7>
+ // CHECK: %[[TMP_208:.*]] = mhlo.add %[[TMP_201]], %[[TMP_207]]
+ // CHECK: %[[TMP_209:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_208]]
+ // CHECK: %[[TMP_210:.*]] = mhlo.multiply %[[TMP_206]], %[[TMP_209]]
+ // CHECK: %[[TMP_211:.*]] = mhlo.constant dense<4.000000e+00>
+ // CHECK: %[[TMP_212:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_211]]
+ // CHECK: %[[TMP_213:.*]] = mhlo.constant dense<3.000000e+00>
+ // CHECK: %[[TMP_214:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_213]]
+ // CHECK: %[[TMP_215:.*]] = mhlo.multiply %[[TMP_212]], %[[TMP_214]]
+ // CHECK: %[[TMP_216:.*]] = mhlo.constant dense<3.30687835E-5>
+ // CHECK: %[[TMP_217:.*]] = mhlo.add %[[TMP_210]], %[[TMP_216]]
+ // CHECK: %[[TMP_218:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_217]]
+ // CHECK: %[[TMP_219:.*]] = mhlo.multiply %[[TMP_215]], %[[TMP_218]]
+ // CHECK: %[[TMP_220:.*]] = mhlo.constant dense<2.000000e+00>
+ // CHECK: %[[TMP_221:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_220]]
+ // CHECK: %[[TMP_222:.*]] = mhlo.constant dense<1.000000e+00>
+ // CHECK: %[[TMP_223:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_222]]
+ // CHECK: %[[TMP_224:.*]] = mhlo.multiply %[[TMP_221]], %[[TMP_223]]
+ // CHECK: %[[TMP_225:.*]] = mhlo.constant dense<-0.00138888892>
+ // CHECK: %[[TMP_226:.*]] = mhlo.add %[[TMP_219]], %[[TMP_225]]
+ // CHECK: %[[TMP_227:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_226]]
+ // CHECK: %[[TMP_228:.*]] = mhlo.multiply %[[TMP_224]], %[[TMP_227]]
+ // CHECK: %[[TMP_229:.*]] = mhlo.constant dense<5.000000e-01>
+ // CHECK: %[[TMP_230:.*]] = mhlo.divide %[[TMP_5]], %[[TMP_121]]
+ // CHECK: %[[TMP_231:.*]] = mhlo.constant dense<0.0833333358>
+ // CHECK: %[[TMP_232:.*]] = mhlo.add %[[TMP_231]], %[[TMP_228]]
+ // CHECK: %[[TMP_233:.*]] = mhlo.multiply %[[TMP_230]], %[[TMP_232]]
+ // CHECK: %[[TMP_234:.*]] = mhlo.add %[[TMP_229]], %[[TMP_233]]
+ // CHECK: %[[TMP_235:.*]] = mhlo.multiply %[[TMP_122]], %[[TMP_234]]
+ // CHECK: %[[TMP_236:.*]] = mhlo.add %[[TMP_127]], %[[TMP_235]]
+ // CHECK: %[[TMP_237:.*]] = "mhlo.abs"(%[[TMP_122]])
+ // CHECK: %[[TMP_238:.*]] = "mhlo.abs"(%[[TMP_120]])
+ // CHECK: %[[TMP_239:.*]] = mhlo.constant dense<1.401300e-45>
+ // CHECK: %[[TMP_240:.*]] = mhlo.multiply %[[TMP_238]], %[[TMP_239]]
+ // CHECK: %[[TMP_241:.*]] = "mhlo.compare"(%[[TMP_237]], %[[TMP_240]]) {comparison_direction = "LT"}
+ // CHECK: %[[TMP_242:.*]] = "mhlo.select"(%[[TMP_241]], %[[TMP_120]], %[[TMP_236]])
+ // CHECK: %[[TMP_243:.*]] = mhlo.constant dense<0x7F800000>
+ // CHECK: %[[TMP_244:.*]] = "mhlo.compare"(%[[TMP_5]], %[[TMP_123]]) {comparison_direction = "EQ"}
+ // CHECK: %[[TMP_245:.*]] = "mhlo.select"(%[[TMP_244]], %[[TMP_243]], %[[TMP_242]])
+ // CHECK: %[[TMP_246:.*]] = mhlo.constant dense<0x7FC00000>
+ // CHECK: %[[TMP_247:.*]] = "mhlo.compare"(%[[TMP_5]], %[[TMP_123]]) {comparison_direction = "LT"}
+ // CHECK: %[[TMP_248:.*]] = "mhlo.select"(%[[TMP_247]], %[[TMP_246]], %[[TMP_245]])
+ // CHECK: %[[TMP_249:.*]] = mhlo.constant dense<0.000000e+00>
+ // CHECK: %[[TMP_250:.*]] = "mhlo.compare"(%[[ARG1]], %[[TMP_249]]) {comparison_direction = "LE"}
+ // CHECK: %[[TMP_251:.*]] = "mhlo.floor"(%[[TMP_5]])
+ // CHECK: %[[TMP_252:.*]] = "mhlo.compare"(%[[TMP_5]], %[[TMP_251]]) {comparison_direction = "NE"}
+ // CHECK: %[[TMP_253:.*]] = mhlo.and %[[TMP_250]], %[[TMP_252]]
+ // CHECK: %[[TMP_254:.*]] = "mhlo.floor"(%[[ARG1]])
+ // CHECK: %[[TMP_255:.*]] = "mhlo.compare"(%[[ARG1]], %[[TMP_254]]) {comparison_direction = "EQ"}
+ // CHECK: %[[TMP_256:.*]] = mhlo.and %[[TMP_250]], %[[TMP_255]]
+ // CHECK: %[[TMP_257:.*]] = "mhlo.select"(%[[TMP_256]], %[[TMP_243]], %[[TMP_248]])
+ // CHECK: %[[TMP_258:.*]] = "mhlo.select"(%[[TMP_253]], %[[TMP_246]], %[[TMP_257]])
+ // CHECK: %[[TMP_259:.*]] = mhlo.multiply %[[TMP_4]], %[[TMP_89]]
+ // CHECK: %[[TMP_260:.*]] = mhlo.multiply %[[TMP_259]], %[[TMP_258]]
+ // CHECK: %[[TMP_261:.*]] = mhlo.constant dense<0.000000e+00>
+ // CHECK: %[[TMP_262:.*]] = "mhlo.compare"(%[[ARG0]], %[[TMP_261]]) {comparison_direction = "EQ"}
+ // CHECK: %[[TMP_263:.*]] = mhlo.constant dense<5.000000e-01>
+ // CHECK: %[[TMP_264:.*]] = "mhlo.compare"(%[[ARG1]], %[[TMP_263]]) {comparison_direction = "LT"}
+ // CHECK: %[[TMP_265:.*]] = "mhlo.negate"(%[[ARG1]])
+ // CHECK: %[[TMP_266:.*]] = mhlo.constant dense<1.000000e+00>
+ // CHECK: %[[TMP_267:.*]] = mhlo.subtract %[[ARG1]], %[[TMP_266]]
+ // CHECK: %[[TMP_268:.*]] = "mhlo.select"(%[[TMP_264]], %[[TMP_265]], %[[TMP_267]])
+ // CHECK: %[[TMP_269:.*]] = mhlo.constant dense<0.000000e+00>
+ // CHECK: %[[TMP_270:.*]] = mhlo.constant dense<1.000000e+00>
+ // CHECK: %[[TMP_271:.*]] = mhlo.constant dense<676.520386>
+ // CHECK: %[[TMP_272:.*]] = mhlo.constant dense<1.000000e+00>
+ // CHECK: %[[TMP_273:.*]] = mhlo.add %[[TMP_268]], %[[TMP_272]]
+ // CHECK: %[[TMP_274:.*]] = mhlo.multiply %[[TMP_273]], %[[TMP_273]]
+ // CHECK: %[[TMP_275:.*]] = mhlo.divide %[[TMP_271]], %[[TMP_274]]
+ // CHECK: %[[TMP_276:.*]] = mhlo.subtract %[[TMP_269]], %[[TMP_275]]
+ // CHECK: %[[TMP_277:.*]] = mhlo.divide %[[TMP_271]], %[[TMP_273]]
+ // CHECK: %[[TMP_278:.*]] = mhlo.add %[[TMP_270]], %[[TMP_277]]
+ // CHECK: %[[TMP_279:.*]] = mhlo.constant dense<-1259.13916>
+ // CHECK: %[[TMP_280:.*]] = mhlo.constant dense<2.000000e+00>
+ // CHECK: %[[TMP_281:.*]] = mhlo.add %[[TMP_268]], %[[TMP_280]]
+ // CHECK: %[[TMP_282:.*]] = mhlo.multiply %[[TMP_281]], %[[TMP_281]]
+ // CHECK: %[[TMP_283:.*]] = mhlo.divide %[[TMP_279]], %[[TMP_282]]
+ // CHECK: %[[TMP_284:.*]] = mhlo.subtract %[[TMP_276]], %[[TMP_283]]
+ // CHECK: %[[TMP_285:.*]] = mhlo.divide %[[TMP_279]], %[[TMP_281]]
+ // CHECK: %[[TMP_286:.*]] = mhlo.add %[[TMP_278]], %[[TMP_285]]
+ // CHECK: %[[TMP_287:.*]] = mhlo.constant dense<771.323425>
+ // CHECK: %[[TMP_288:.*]] = mhlo.constant dense<3.000000e+00>
+ // CHECK: %[[TMP_289:.*]] = mhlo.add %[[TMP_268]], %[[TMP_288]]
+ // CHECK: %[[TMP_290:.*]] = mhlo.multiply %[[TMP_289]], %[[TMP_289]]
+ // CHECK: %[[TMP_291:.*]] = mhlo.divide %[[TMP_287]], %[[TMP_290]]
+ // CHECK: %[[TMP_292:.*]] = mhlo.subtract %[[TMP_284]], %[[TMP_291]]
+ // CHECK: %[[TMP_293:.*]] = mhlo.divide %[[TMP_287]], %[[TMP_289]]
+ // CHECK: %[[TMP_294:.*]] = mhlo.add %[[TMP_286]], %[[TMP_293]]
+ // CHECK: %[[TMP_295:.*]] = mhlo.constant dense<-176.615036>
+ // CHECK: %[[TMP_296:.*]] = mhlo.constant dense<4.000000e+00>
+ // CHECK: %[[TMP_297:.*]] = mhlo.add %[[TMP_268]], %[[TMP_296]]
+ // CHECK: %[[TMP_298:.*]] = mhlo.multiply %[[TMP_297]], %[[TMP_297]]
+ // CHECK: %[[TMP_299:.*]] = mhlo.divide %[[TMP_295]], %[[TMP_298]]
+ // CHECK: %[[TMP_300:.*]] = mhlo.subtract %[[TMP_292]], %[[TMP_299]]
+ // CHECK: %[[TMP_301:.*]] = mhlo.divide %[[TMP_295]], %[[TMP_297]]
+ // CHECK: %[[TMP_302:.*]] = mhlo.add %[[TMP_294]], %[[TMP_301]]
+ // CHECK: %[[TMP_303:.*]] = mhlo.constant dense<12.5073433>
+ // CHECK: %[[TMP_304:.*]] = mhlo.constant dense<5.000000e+00>
+ // CHECK: %[[TMP_305:.*]] = mhlo.add %[[TMP_268]], %[[TMP_304]]
+ // CHECK: %[[TMP_306:.*]] = mhlo.multiply %[[TMP_305]], %[[TMP_305]]
+ // CHECK: %[[TMP_307:.*]] = mhlo.divide %[[TMP_303]], %[[TMP_306]]
+ // CHECK: %[[TMP_308:.*]] = mhlo.subtract %[[TMP_300]], %[[TMP_307]]
+ // CHECK: %[[TMP_309:.*]] = mhlo.divide %[[TMP_303]], %[[TMP_305]]
+ // CHECK: %[[TMP_310:.*]] = mhlo.add %[[TMP_302]], %[[TMP_309]]
+ // CHECK: %[[TMP_311:.*]] = mhlo.constant dense<-0.138571098>
+ // CHECK: %[[TMP_312:.*]] = mhlo.constant dense<6.000000e+00>
+ // CHECK: %[[TMP_313:.*]] = mhlo.add %[[TMP_268]], %[[TMP_312]]
+ // CHECK: %[[TMP_314:.*]] = mhlo.multiply %[[TMP_313]], %[[TMP_313]]
+ // CHECK: %[[TMP_315:.*]] = mhlo.divide %[[TMP_311]], %[[TMP_314]]
+ // CHECK: %[[TMP_316:.*]] = mhlo.subtract %[[TMP_308]], %[[TMP_315]]
+ // CHECK: %[[TMP_317:.*]] = mhlo.divide %[[TMP_311]], %[[TMP_313]]
+ // CHECK: %[[TMP_318:.*]] = mhlo.add %[[TMP_310]], %[[TMP_317]]
+ // CHECK: %[[TMP_319:.*]] = mhlo.constant dense<9.98436917E-6>
+ // CHECK: %[[TMP_320:.*]] = mhlo.constant dense<7.000000e+00>
+ // CHECK: %[[TMP_321:.*]] = mhlo.add %[[TMP_268]], %[[TMP_320]]
+ // CHECK: %[[TMP_322:.*]] = mhlo.multiply %[[TMP_321]], %[[TMP_321]]
+ // CHECK: %[[TMP_323:.*]] = mhlo.divide %[[TMP_319]], %[[TMP_322]]
+ // CHECK: %[[TMP_324:.*]] = mhlo.subtract %[[TMP_316]], %[[TMP_323]]
+ // CHECK: %[[TMP_325:.*]] = mhlo.divide %[[TMP_319]], %[[TMP_321]]
+ // CHECK: %[[TMP_326:.*]] = mhlo.add %[[TMP_318]], %[[TMP_325]]
+ // CHECK: %[[TMP_327:.*]] = mhlo.constant dense<1.50563267E-7>
+ // CHECK: %[[TMP_328:.*]] = mhlo.constant dense<8.000000e+00>
+ // CHECK: %[[TMP_329:.*]] = mhlo.add %[[TMP_268]], %[[TMP_328]]
+ // CHECK: %[[TMP_330:.*]] = mhlo.multiply %[[TMP_329]], %[[TMP_329]]
+ // CHECK: %[[TMP_331:.*]] = mhlo.divide %[[TMP_327]], %[[TMP_330]]
+ // CHECK: %[[TMP_332:.*]] = mhlo.subtract %[[TMP_324]], %[[TMP_331]]
+ // CHECK: %[[TMP_333:.*]] = mhlo.divide %[[TMP_327]], %[[TMP_329]]
+ // CHECK: %[[TMP_334:.*]] = mhlo.add %[[TMP_326]], %[[TMP_333]]
+ // CHECK: %[[TMP_335:.*]] = mhlo.constant dense<7.500000e+00>
+ // CHECK: %[[TMP_336:.*]] = mhlo.add %[[TMP_335]], %[[TMP_268]]
+ // CHECK: %[[TMP_337:.*]] = mhlo.constant dense<2.01490307>
+ // CHECK: %[[TMP_338:.*]] = mhlo.divide %[[TMP_268]], %[[TMP_335]]
+ // CHECK: %[[TMP_339:.*]] = "mhlo.log_plus_one"(%[[TMP_338]])
+ // CHECK: %[[TMP_340:.*]] = mhlo.add %[[TMP_337]], %[[TMP_339]]
+ // CHECK: %[[TMP_341:.*]] = mhlo.divide %[[TMP_332]], %[[TMP_334]]
+ // CHECK: %[[TMP_342:.*]] = mhlo.constant dense<7.000000e+00>
+ // CHECK: %[[TMP_343:.*]] = mhlo.divide %[[TMP_342]], %[[TMP_336]]
+ // CHECK: %[[TMP_344:.*]] = mhlo.add %[[TMP_340]], %[[TMP_341]]
+ // CHECK: %[[TMP_345:.*]] = mhlo.subtract %[[TMP_344]], %[[TMP_343]]
+ // CHECK: %[[TMP_346:.*]] = mhlo.constant dense<5.000000e-01>
+ // CHECK: %[[TMP_347:.*]] = mhlo.add %[[ARG1]], %[[TMP_346]]
+ // CHECK: %[[TMP_348:.*]] = "mhlo.floor"(%[[TMP_347]])
+ // CHECK: %[[TMP_349:.*]] = "mhlo.abs"(%[[TMP_348]])
+ // CHECK: %[[TMP_350:.*]] = mhlo.add %[[ARG1]], %[[TMP_349]]
+ // CHECK: %[[TMP_351:.*]] = mhlo.constant dense<3.14159274>
+ // CHECK: %[[TMP_352:.*]] = mhlo.multiply %[[TMP_351]], %[[TMP_350]]
+ // CHECK: %[[TMP_353:.*]] = "mhlo.cosine"(%[[TMP_352]])
+ // CHECK: %[[TMP_354:.*]] = "mhlo.sine"(%[[TMP_352]])
+ // CHECK: %[[TMP_355:.*]] = mhlo.multiply %[[TMP_351]], %[[TMP_353]]
+ // CHECK: %[[TMP_356:.*]] = mhlo.divide %[[TMP_355]], %[[TMP_354]]
+ // CHECK: %[[TMP_357:.*]] = mhlo.subtract %[[TMP_345]], %[[TMP_356]]
+ // CHECK: %[[TMP_358:.*]] = "mhlo.select"(%[[TMP_264]], %[[TMP_357]], %[[TMP_345]])
+ // CHECK: %[[TMP_359:.*]] = "mhlo.compare"(%[[ARG1]], %[[TMP_269]]) {comparison_direction = "LE"}
+ // CHECK: %[[TMP_360:.*]] = "mhlo.floor"(%[[ARG1]])
+ // CHECK: %[[TMP_361:.*]] = "mhlo.compare"(%[[ARG1]], %[[TMP_360]]) {comparison_direction = "EQ"}
+ // CHECK: %[[TMP_362:.*]] = mhlo.and %[[TMP_359]], %[[TMP_361]]
+ // CHECK: %[[TMP_363:.*]] = mhlo.constant dense<0x7FC00000>
+ // CHECK: %[[TMP_364:.*]] = "mhlo.select"(%[[TMP_362]], %[[TMP_363]], %[[TMP_358]])
+ // CHECK: %[[TMP_365:.*]] = "mhlo.select"(%[[TMP_262]], %[[TMP_364]], %[[TMP_260]])
+ // CHECK: %[[TMP_366:.*]] = "mhlo.floor"(%[[ARG0]])
+ // CHECK: %[[TMP_367:.*]] = "mhlo.compare"(%[[ARG0]], %[[TMP_366]]) {comparison_direction = "NE"}
+ // CHECK: %[[TMP_368:.*]] = "mhlo.compare"(%[[ARG0]], %[[TMP_261]]) {comparison_direction = "LT"}
+ // CHECK: %[[TMP_369:.*]] = mhlo.or %[[TMP_367]], %[[TMP_368]]
+ // CHECK: %[[TMP_370:.*]] = mhlo.constant dense<0x7FC00000>
+ // CHECK: %[[TMP_371:.*]] = "mhlo.select"(%[[TMP_369]], %[[TMP_370]], %[[TMP_365]])
+ // CHECK: return %[[TMP_371]]
+ %1 = chlo.polygamma %lhs, %rhs : tensor<f32>, tensor<f32> -> tensor<f32>
+ return %1 : tensor<f32>
+}
+
+// CHECK: @polygamma_f64
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<f64>, %[[ARG1:.*]]: tensor<f64>)
+func @polygamma_f64(%lhs : tensor<f64>, %rhs : tensor<f64>) -> tensor<f64> {
+ // CHECK: %[[TMP_0:.*]] = mhlo.constant dense<1.000000e+00>
+ // CHECK: %[[TMP_1:.*]] = mhlo.constant dense<2.000000e+00>
+ // CHECK: %[[TMP_2:.*]] = mhlo.remainder %[[ARG0]], %[[TMP_1]]
+ // CHECK: %[[TMP_3:.*]] = mhlo.multiply %[[TMP_1]], %[[TMP_2]]
+ // CHECK: %[[TMP_4:.*]] = mhlo.subtract %[[TMP_3]], %[[TMP_0]]
+ // CHECK: %[[TMP_5:.*]] = mhlo.add %[[ARG0]], %[[TMP_0]]
+ // CHECK: %[[TMP_6:.*]] = mhlo.constant dense<5.000000e-01>
+ // CHECK: %[[TMP_7:.*]] = "mhlo.compare"(%[[TMP_5]], %[[TMP_6]]) {comparison_direction = "LT"}
+ // CHECK: %[[TMP_8:.*]] = "mhlo.negate"(%[[TMP_5]])
+ // CHECK: %[[TMP_9:.*]] = mhlo.constant dense<1.000000e+00>
+ // CHECK: %[[TMP_10:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_9]]
+ // CHECK: %[[TMP_11:.*]] = "mhlo.select"(%[[TMP_7]], %[[TMP_8]], %[[TMP_10]])
+ // CHECK: %[[TMP_12:.*]] = mhlo.constant dense<0.99999999999980993>
+ // CHECK: %[[TMP_13:.*]] = mhlo.constant dense<676.5203681218851>
+ // CHECK: %[[TMP_14:.*]] = mhlo.constant dense<1.000000e+00>
+ // CHECK: %[[TMP_15:.*]] = mhlo.add %[[TMP_11]], %[[TMP_14]]
+ // CHECK: %[[TMP_16:.*]] = mhlo.divide %[[TMP_13]], %[[TMP_15]]
+ // CHECK: %[[TMP_17:.*]] = mhlo.add %[[TMP_12]], %[[TMP_16]]
+ // CHECK: %[[TMP_18:.*]] = mhlo.constant dense<-1259.1392167224028>
+ // CHECK: %[[TMP_19:.*]] = mhlo.constant dense<2.000000e+00>
+ // CHECK: %[[TMP_20:.*]] = mhlo.add %[[TMP_11]], %[[TMP_19]]
+ // CHECK: %[[TMP_21:.*]] = mhlo.divide %[[TMP_18]], %[[TMP_20]]
+ // CHECK: %[[TMP_22:.*]] = mhlo.add %[[TMP_17]], %[[TMP_21]]
+ // CHECK: %[[TMP_23:.*]] = mhlo.constant dense<771.32342877765313>
+ // CHECK: %[[TMP_24:.*]] = mhlo.constant dense<3.000000e+00>
+ // CHECK: %[[TMP_25:.*]] = mhlo.add %[[TMP_11]], %[[TMP_24]]
+ // CHECK: %[[TMP_26:.*]] = mhlo.divide %[[TMP_23]], %[[TMP_25]]
+ // CHECK: %[[TMP_27:.*]] = mhlo.add %[[TMP_22]], %[[TMP_26]]
+ // CHECK: %[[TMP_28:.*]] = mhlo.constant dense<-176.61502916214059>
+ // CHECK: %[[TMP_29:.*]] = mhlo.constant dense<4.000000e+00>
+ // CHECK: %[[TMP_30:.*]] = mhlo.add %[[TMP_11]], %[[TMP_29]]
+ // CHECK: %[[TMP_31:.*]] = mhlo.divide %[[TMP_28]], %[[TMP_30]]
+ // CHECK: %[[TMP_32:.*]] = mhlo.add %[[TMP_27]], %[[TMP_31]]
+ // CHECK: %[[TMP_33:.*]] = mhlo.constant dense<12.507343278686905>
+ // CHECK: %[[TMP_34:.*]] = mhlo.constant dense<5.000000e+00>
+ // CHECK: %[[TMP_35:.*]] = mhlo.add %[[TMP_11]], %[[TMP_34]]
+ // CHECK: %[[TMP_36:.*]] = mhlo.divide %[[TMP_33]], %[[TMP_35]]
+ // CHECK: %[[TMP_37:.*]] = mhlo.add %[[TMP_32]], %[[TMP_36]]
+ // CHECK: %[[TMP_38:.*]] = mhlo.constant dense<-0.13857109526572012>
+ // CHECK: %[[TMP_39:.*]] = mhlo.constant dense<6.000000e+00>
+ // CHECK: %[[TMP_40:.*]] = mhlo.add %[[TMP_11]], %[[TMP_39]]
+ // CHECK: %[[TMP_41:.*]] = mhlo.divide %[[TMP_38]], %[[TMP_40]]
+ // CHECK: %[[TMP_42:.*]] = mhlo.add %[[TMP_37]], %[[TMP_41]]
+ // CHECK: %[[TMP_43:.*]] = mhlo.constant dense<9.9843695780195716E-6>
+ // CHECK: %[[TMP_44:.*]] = mhlo.constant dense<7.000000e+00>
+ // CHECK: %[[TMP_45:.*]] = mhlo.add %[[TMP_11]], %[[TMP_44]]
+ // CHECK: %[[TMP_46:.*]] = mhlo.divide %[[TMP_43]], %[[TMP_45]]
+ // CHECK: %[[TMP_47:.*]] = mhlo.add %[[TMP_42]], %[[TMP_46]]
+ // CHECK: %[[TMP_48:.*]] = mhlo.constant dense<1.5056327351493116E-7>
+ // CHECK: %[[TMP_49:.*]] = mhlo.constant dense<8.000000e+00>
+ // CHECK: %[[TMP_50:.*]] = mhlo.add %[[TMP_11]], %[[TMP_49]]
+ // CHECK: %[[TMP_51:.*]] = mhlo.divide %[[TMP_48]], %[[TMP_50]]
+ // CHECK: %[[TMP_52:.*]] = mhlo.add %[[TMP_47]], %[[TMP_51]]
+ // CHECK: %[[TMP_53:.*]] = mhlo.constant dense<7.500000e+00>
+ // CHECK: %[[TMP_54:.*]] = mhlo.add %[[TMP_53]], %[[TMP_11]]
+ // CHECK: %[[TMP_55:.*]] = mhlo.constant dense<2.0149030205422647>
+ // CHECK: %[[TMP_56:.*]] = mhlo.divide %[[TMP_11]], %[[TMP_53]]
+ // CHECK: %[[TMP_57:.*]] = "mhlo.log_plus_one"(%[[TMP_56]])
+ // CHECK: %[[TMP_58:.*]] = mhlo.add %[[TMP_55]], %[[TMP_57]]
+ // CHECK: %[[TMP_59:.*]] = mhlo.divide %[[TMP_54]], %[[TMP_58]]
+ // CHECK: %[[TMP_60:.*]] = mhlo.add %[[TMP_11]], %[[TMP_6]]
+ // CHECK: %[[TMP_61:.*]] = mhlo.subtract %[[TMP_60]], %[[TMP_59]]
+ // CHECK: %[[TMP_62:.*]] = mhlo.multiply %[[TMP_61]], %[[TMP_58]]
+ // CHECK: %[[TMP_63:.*]] = "mhlo.log"(%[[TMP_52]])
+ // CHECK: %[[TMP_64:.*]] = mhlo.constant dense<0.91893853320467266>
+ // CHECK: %[[TMP_65:.*]] = mhlo.add %[[TMP_64]], %[[TMP_62]]
+ // CHECK: %[[TMP_66:.*]] = mhlo.add %[[TMP_65]], %[[TMP_63]]
+ // CHECK: %[[TMP_67:.*]] = "mhlo.abs"(%[[TMP_5]])
+ // CHECK: %[[TMP_68:.*]] = "mhlo.floor"(%[[TMP_67]])
+ // CHECK: %[[TMP_69:.*]] = mhlo.subtract %[[TMP_67]], %[[TMP_68]]
+ // CHECK: %[[TMP_70:.*]] = "mhlo.compare"(%[[TMP_6]], %[[TMP_69]]) {comparison_direction = "LT"}
+ // CHECK: %[[TMP_71:.*]] = mhlo.subtract %[[TMP_9]], %[[TMP_69]]
+ // CHECK: %[[TMP_72:.*]] = "mhlo.select"(%[[TMP_70]], %[[TMP_71]], %[[TMP_69]])
+ // CHECK: %[[TMP_73:.*]] = mhlo.constant dense<3.1415926535897931>
+ // CHECK: %[[TMP_74:.*]] = mhlo.multiply %[[TMP_73]], %[[TMP_72]]
+ // CHECK: %[[TMP_75:.*]] = "mhlo.sine"(%[[TMP_74]])
+ // CHECK: %[[TMP_76:.*]] = "mhlo.log"(%[[TMP_75]])
+ // CHECK: %[[TMP_77:.*]] = mhlo.constant dense<1.1447298858494002>
+ // CHECK: %[[TMP_78:.*]] = mhlo.subtract %[[TMP_77]], %[[TMP_76]]
+ // CHECK: %[[TMP_79:.*]] = mhlo.subtract %[[TMP_78]], %[[TMP_66]]
+ // CHECK: %[[TMP_80:.*]] = "mhlo.is_finite"(%[[TMP_76]])
+ // CHECK: %[[TMP_81:.*]] = "mhlo.negate"(%[[TMP_76]])
+ // CHECK: %[[TMP_82:.*]] = "mhlo.select"(%[[TMP_80]], %[[TMP_79]], %[[TMP_81]])
+ // CHECK: %[[TMP_83:.*]] = "mhlo.select"(%[[TMP_7]], %[[TMP_82]], %[[TMP_66]])
+ // CHECK: %[[TMP_84:.*]] = "mhlo.abs"(%[[TMP_5]])
+ // CHECK: %[[TMP_85:.*]] = mhlo.constant dense<0x7FF0000000000000>
+ // CHECK: %[[TMP_86:.*]] = "mhlo.compare"(%[[TMP_84]], %[[TMP_85]]) {comparison_direction = "EQ"}
+ // CHECK: %[[TMP_87:.*]] = mhlo.constant dense<0x7FF0000000000000>
+ // CHECK: %[[TMP_88:.*]] = "mhlo.select"(%[[TMP_86]], %[[TMP_87]], %[[TMP_83]])
+ // CHECK: %[[TMP_89:.*]] = "mhlo.exponential"(%[[TMP_88]])
+ // CHECK: %[[TMP_90:.*]] = mhlo.constant dense<0.000000e+00>
+ // CHECK: %[[TMP_91:.*]] = "mhlo.negate"(%[[TMP_5]])
+ // CHECK: %[[TMP_92:.*]] = mhlo.power %[[ARG1]], %[[TMP_91]]
+ // CHECK: %[[TMP_93:.*]] = mhlo.constant dense<1.000000e+00>
+ // CHECK: %[[TMP_94:.*]] = mhlo.add %[[ARG1]], %[[TMP_93]]
+ // CHECK: %[[TMP_95:.*]] = mhlo.power %[[TMP_94]], %[[TMP_91]]
+ // CHECK: %[[TMP_96:.*]] = mhlo.add %[[TMP_92]], %[[TMP_95]]
+ // CHECK: %[[TMP_97:.*]] = mhlo.add %[[TMP_94]], %[[TMP_93]]
+ // CHECK: %[[TMP_98:.*]] = mhlo.power %[[TMP_97]], %[[TMP_91]]
+ // CHECK: %[[TMP_99:.*]] = mhlo.add %[[TMP_96]], %[[TMP_98]]
+ // CHECK: %[[TMP_100:.*]] = mhlo.add %[[TMP_97]], %[[TMP_93]]
+ // CHECK: %[[TMP_101:.*]] = mhlo.power %[[TMP_100]], %[[TMP_91]]
+ // CHECK: %[[TMP_102:.*]] = mhlo.add %[[TMP_99]], %[[TMP_101]]
+ // CHECK: %[[TMP_103:.*]] = mhlo.add %[[TMP_100]], %[[TMP_93]]
+ // CHECK: %[[TMP_104:.*]] = mhlo.power %[[TMP_103]], %[[TMP_91]]
+ // CHECK: %[[TMP_105:.*]] = mhlo.add %[[TMP_102]], %[[TMP_104]]
+ // CHECK: %[[TMP_106:.*]] = mhlo.add %[[TMP_103]], %[[TMP_93]]
+ // CHECK: %[[TMP_107:.*]] = mhlo.power %[[TMP_106]], %[[TMP_91]]
+ // CHECK: %[[TMP_108:.*]] = mhlo.add %[[TMP_105]], %[[TMP_107]]
+ // CHECK: %[[TMP_109:.*]] = mhlo.add %[[TMP_106]], %[[TMP_93]]
+ // CHECK: %[[TMP_110:.*]] = mhlo.power %[[TMP_109]], %[[TMP_91]]
+ // CHECK: %[[TMP_111:.*]] = mhlo.add %[[TMP_108]], %[[TMP_110]]
+ // CHECK: %[[TMP_112:.*]] = mhlo.add %[[TMP_109]], %[[TMP_93]]
+ // CHECK: %[[TMP_113:.*]] = mhlo.power %[[TMP_112]], %[[TMP_91]]
+ // CHECK: %[[TMP_114:.*]] = mhlo.add %[[TMP_111]], %[[TMP_113]]
+ // CHECK: %[[TMP_115:.*]] = mhlo.add %[[TMP_112]], %[[TMP_93]]
+ // CHECK: %[[TMP_116:.*]] = mhlo.power %[[TMP_115]], %[[TMP_91]]
+ // CHECK: %[[TMP_117:.*]] = mhlo.add %[[TMP_114]], %[[TMP_116]]
+ // CHECK: %[[TMP_118:.*]] = mhlo.add %[[TMP_115]], %[[TMP_93]]
+ // CHECK: %[[TMP_119:.*]] = mhlo.power %[[TMP_118]], %[[TMP_91]]
+ // CHECK: %[[TMP_120:.*]] = mhlo.add %[[TMP_117]], %[[TMP_119]]
+ // CHECK: %[[TMP_121:.*]] = mhlo.add %[[TMP_118]], %[[TMP_93]]
+ // CHECK: %[[TMP_122:.*]] = mhlo.power %[[TMP_121]], %[[TMP_91]]
+ // CHECK: %[[TMP_123:.*]] = mhlo.constant dense<1.000000e+00>
+ // CHECK: %[[TMP_124:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_123]]
+ // CHECK: %[[TMP_125:.*]] = mhlo.multiply %[[TMP_122]], %[[TMP_121]]
+ // CHECK: %[[TMP_126:.*]] = mhlo.divide %[[TMP_125]], %[[TMP_124]]
+ // CHECK: %[[TMP_127:.*]] = mhlo.add %[[TMP_120]], %[[TMP_126]]
+ // CHECK: %[[TMP_128:.*]] = mhlo.multiply %[[TMP_121]], %[[TMP_121]]
+ // CHECK: %[[TMP_129:.*]] = mhlo.divide %[[TMP_93]], %[[TMP_128]]
+ // CHECK: %[[TMP_130:.*]] = mhlo.constant dense<2.200000e+01>
+ // CHECK: %[[TMP_131:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_130]]
+ // CHECK: %[[TMP_132:.*]] = mhlo.constant dense<2.100000e+01>
+ // CHECK: %[[TMP_133:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_132]]
+ // CHECK: %[[TMP_134:.*]] = mhlo.multiply %[[TMP_131]], %[[TMP_133]]
+ // CHECK: %[[TMP_135:.*]] = mhlo.constant dense<-1.3954464685812522E-19>
+ // CHECK: %[[TMP_136:.*]] = mhlo.add %[[TMP_90]], %[[TMP_135]]
+ // CHECK: %[[TMP_137:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_136]]
+ // CHECK: %[[TMP_138:.*]] = mhlo.multiply %[[TMP_134]], %[[TMP_137]]
+ // CHECK: %[[TMP_139:.*]] = mhlo.constant dense<2.000000e+01>
+ // CHECK: %[[TMP_140:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_139]]
+ // CHECK: %[[TMP_141:.*]] = mhlo.constant dense<1.900000e+01>
+ // CHECK: %[[TMP_142:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_141]]
+ // CHECK: %[[TMP_143:.*]] = mhlo.multiply %[[TMP_140]], %[[TMP_142]]
+ // CHECK: %[[TMP_144:.*]] = mhlo.constant dense<5.5090028283602295E-18>
+ // CHECK: %[[TMP_145:.*]] = mhlo.add %[[TMP_138]], %[[TMP_144]]
+ // CHECK: %[[TMP_146:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_145]]
+ // CHECK: %[[TMP_147:.*]] = mhlo.multiply %[[TMP_143]], %[[TMP_146]]
+ // CHECK: %[[TMP_148:.*]] = mhlo.constant dense<1.800000e+01>
+ // CHECK: %[[TMP_149:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_148]]
+ // CHECK: %[[TMP_150:.*]] = mhlo.constant dense<1.700000e+01>
+ // CHECK: %[[TMP_151:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_150]]
+ // CHECK: %[[TMP_152:.*]] = mhlo.multiply %[[TMP_149]], %[[TMP_151]]
+ // CHECK: %[[TMP_153:.*]] = mhlo.constant dense<-2.1748686985580617E-16>
+ // CHECK: %[[TMP_154:.*]] = mhlo.add %[[TMP_147]], %[[TMP_153]]
+ // CHECK: %[[TMP_155:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_154]]
+ // CHECK: %[[TMP_156:.*]] = mhlo.multiply %[[TMP_152]], %[[TMP_155]]
+ // CHECK: %[[TMP_157:.*]] = mhlo.constant dense<1.600000e+01>
+ // CHECK: %[[TMP_158:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_157]]
+ // CHECK: %[[TMP_159:.*]] = mhlo.constant dense<1.500000e+01>
+ // CHECK: %[[TMP_160:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_159]]
+ // CHECK: %[[TMP_161:.*]] = mhlo.multiply %[[TMP_158]], %[[TMP_160]]
+ // CHECK: %[[TMP_162:.*]] = mhlo.constant dense<8.5860620562778452E-15>
+ // CHECK: %[[TMP_163:.*]] = mhlo.add %[[TMP_156]], %[[TMP_162]]
+ // CHECK: %[[TMP_164:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_163]]
+ // CHECK: %[[TMP_165:.*]] = mhlo.multiply %[[TMP_161]], %[[TMP_164]]
+ // CHECK: %[[TMP_166:.*]] = mhlo.constant dense<1.400000e+01>
+ // CHECK: %[[TMP_167:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_166]]
+ // CHECK: %[[TMP_168:.*]] = mhlo.constant dense<1.300000e+01>
+ // CHECK: %[[TMP_169:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_168]]
+ // CHECK: %[[TMP_170:.*]] = mhlo.multiply %[[TMP_167]], %[[TMP_169]]
+ // CHECK: %[[TMP_171:.*]] = mhlo.constant dense<-3.3896802963225832E-13>
+ // CHECK: %[[TMP_172:.*]] = mhlo.add %[[TMP_165]], %[[TMP_171]]
+ // CHECK: %[[TMP_173:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_172]]
+ // CHECK: %[[TMP_174:.*]] = mhlo.multiply %[[TMP_170]], %[[TMP_173]]
+ // CHECK: %[[TMP_175:.*]] = mhlo.constant dense<1.200000e+01>
+ // CHECK: %[[TMP_176:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_175]]
+ // CHECK: %[[TMP_177:.*]] = mhlo.constant dense<1.100000e+01>
+ // CHECK: %[[TMP_178:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_177]]
+ // CHECK: %[[TMP_179:.*]] = mhlo.multiply %[[TMP_176]], %[[TMP_178]]
+ // CHECK: %[[TMP_180:.*]] = mhlo.constant dense<1.3382536530684679E-11>
+ // CHECK: %[[TMP_181:.*]] = mhlo.add %[[TMP_174]], %[[TMP_180]]
+ // CHECK: %[[TMP_182:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_181]]
+ // CHECK: %[[TMP_183:.*]] = mhlo.multiply %[[TMP_179]], %[[TMP_182]]
+ // CHECK: %[[TMP_184:.*]] = mhlo.constant dense<1.000000e+01>
+ // CHECK: %[[TMP_185:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_184]]
+ // CHECK: %[[TMP_186:.*]] = mhlo.constant dense<9.000000e+00>
+ // CHECK: %[[TMP_187:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_186]]
+ // CHECK: %[[TMP_188:.*]] = mhlo.multiply %[[TMP_185]], %[[TMP_187]]
+ // CHECK: %[[TMP_189:.*]] = mhlo.constant dense<-5.2841901386874932E-10>
+ // CHECK: %[[TMP_190:.*]] = mhlo.add %[[TMP_183]], %[[TMP_189]]
+ // CHECK: %[[TMP_191:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_190]]
+ // CHECK: %[[TMP_192:.*]] = mhlo.multiply %[[TMP_188]], %[[TMP_191]]
+ // CHECK: %[[TMP_193:.*]] = mhlo.constant dense<8.000000e+00>
+ // CHECK: %[[TMP_194:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_193]]
+ // CHECK: %[[TMP_195:.*]] = mhlo.constant dense<7.000000e+00>
+ // CHECK: %[[TMP_196:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_195]]
+ // CHECK: %[[TMP_197:.*]] = mhlo.multiply %[[TMP_194]], %[[TMP_196]]
+ // CHECK: %[[TMP_198:.*]] = mhlo.constant dense<2.08767569878681E-8>
+ // CHECK: %[[TMP_199:.*]] = mhlo.add %[[TMP_192]], %[[TMP_198]]
+ // CHECK: %[[TMP_200:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_199]]
+ // CHECK: %[[TMP_201:.*]] = mhlo.multiply %[[TMP_197]], %[[TMP_200]]
+ // CHECK: %[[TMP_202:.*]] = mhlo.constant dense<6.000000e+00>
+ // CHECK: %[[TMP_203:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_202]]
+ // CHECK: %[[TMP_204:.*]] = mhlo.constant dense<5.000000e+00>
+ // CHECK: %[[TMP_205:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_204]]
+ // CHECK: %[[TMP_206:.*]] = mhlo.multiply %[[TMP_203]], %[[TMP_205]]
+ // CHECK: %[[TMP_207:.*]] = mhlo.constant dense<-8.2671957671957675E-7>
+ // CHECK: %[[TMP_208:.*]] = mhlo.add %[[TMP_201]], %[[TMP_207]]
+ // CHECK: %[[TMP_209:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_208]]
+ // CHECK: %[[TMP_210:.*]] = mhlo.multiply %[[TMP_206]], %[[TMP_209]]
+ // CHECK: %[[TMP_211:.*]] = mhlo.constant dense<4.000000e+00>
+ // CHECK: %[[TMP_212:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_211]]
+ // CHECK: %[[TMP_213:.*]] = mhlo.constant dense<3.000000e+00>
+ // CHECK: %[[TMP_214:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_213]]
+ // CHECK: %[[TMP_215:.*]] = mhlo.multiply %[[TMP_212]], %[[TMP_214]]
+ // CHECK: %[[TMP_216:.*]] = mhlo.constant dense<3.3068783068783071E-5>
+ // CHECK: %[[TMP_217:.*]] = mhlo.add %[[TMP_210]], %[[TMP_216]]
+ // CHECK: %[[TMP_218:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_217]]
+ // CHECK: %[[TMP_219:.*]] = mhlo.multiply %[[TMP_215]], %[[TMP_218]]
+ // CHECK: %[[TMP_220:.*]] = mhlo.constant dense<2.000000e+00>
+ // CHECK: %[[TMP_221:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_220]]
+ // CHECK: %[[TMP_222:.*]] = mhlo.constant dense<1.000000e+00>
+ // CHECK: %[[TMP_223:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_222]]
+ // CHECK: %[[TMP_224:.*]] = mhlo.multiply %[[TMP_221]], %[[TMP_223]]
+ // CHECK: %[[TMP_225:.*]] = mhlo.constant dense<-0.0013888888888888889>
+ // CHECK: %[[TMP_226:.*]] = mhlo.add %[[TMP_219]], %[[TMP_225]]
+ // CHECK: %[[TMP_227:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_226]]
+ // CHECK: %[[TMP_228:.*]] = mhlo.multiply %[[TMP_224]], %[[TMP_227]]
+ // CHECK: %[[TMP_229:.*]] = mhlo.constant dense<5.000000e-01>
+ // CHECK: %[[TMP_230:.*]] = mhlo.divide %[[TMP_5]], %[[TMP_121]]
+ // CHECK: %[[TMP_231:.*]] = mhlo.constant dense<0.083333333333333329>
+ // CHECK: %[[TMP_232:.*]] = mhlo.add %[[TMP_231]], %[[TMP_228]]
+ // CHECK: %[[TMP_233:.*]] = mhlo.multiply %[[TMP_230]], %[[TMP_232]]
+ // CHECK: %[[TMP_234:.*]] = mhlo.add %[[TMP_229]], %[[TMP_233]]
+ // CHECK: %[[TMP_235:.*]] = mhlo.multiply %[[TMP_122]], %[[TMP_234]]
+ // CHECK: %[[TMP_236:.*]] = mhlo.add %[[TMP_127]], %[[TMP_235]]
+ // CHECK: %[[TMP_237:.*]] = "mhlo.abs"(%[[TMP_122]])
+ // CHECK: %[[TMP_238:.*]] = "mhlo.abs"(%[[TMP_120]])
+ // CHECK: %[[TMP_239:.*]] = mhlo.constant dense<4.940660e-324>
+ // CHECK: %[[TMP_240:.*]] = mhlo.multiply %[[TMP_238]], %[[TMP_239]]
+ // CHECK: %[[TMP_241:.*]] = "mhlo.compare"(%[[TMP_237]], %[[TMP_240]]) {comparison_direction = "LT"}
+ // CHECK: %[[TMP_242:.*]] = "mhlo.select"(%[[TMP_241]], %[[TMP_120]], %[[TMP_236]])
+ // CHECK: %[[TMP_243:.*]] = mhlo.constant dense<0x7FF0000000000000>
+ // CHECK: %[[TMP_244:.*]] = "mhlo.compare"(%[[TMP_5]], %[[TMP_123]]) {comparison_direction = "EQ"}
+ // CHECK: %[[TMP_245:.*]] = "mhlo.select"(%[[TMP_244]], %[[TMP_243]], %[[TMP_242]])
+ // CHECK: %[[TMP_246:.*]] = mhlo.constant dense<0x7FF8000000000000>
+ // CHECK: %[[TMP_247:.*]] = "mhlo.compare"(%[[TMP_5]], %[[TMP_123]]) {comparison_direction = "LT"}
+ // CHECK: %[[TMP_248:.*]] = "mhlo.select"(%[[TMP_247]], %[[TMP_246]], %[[TMP_245]])
+ // CHECK: %[[TMP_249:.*]] = mhlo.constant dense<0.000000e+00>
+ // CHECK: %[[TMP_250:.*]] = "mhlo.compare"(%[[ARG1]], %[[TMP_249]]) {comparison_direction = "LE"}
+ // CHECK: %[[TMP_251:.*]] = "mhlo.floor"(%[[TMP_5]])
+ // CHECK: %[[TMP_252:.*]] = "mhlo.compare"(%[[TMP_5]], %[[TMP_251]]) {comparison_direction = "NE"}
+ // CHECK: %[[TMP_253:.*]] = mhlo.and %[[TMP_250]], %[[TMP_252]]
+ // CHECK: %[[TMP_254:.*]] = "mhlo.floor"(%[[ARG1]])
+ // CHECK: %[[TMP_255:.*]] = "mhlo.compare"(%[[ARG1]], %[[TMP_254]]) {comparison_direction = "EQ"}
+ // CHECK: %[[TMP_256:.*]] = mhlo.and %[[TMP_250]], %[[TMP_255]]
+ // CHECK: %[[TMP_257:.*]] = "mhlo.select"(%[[TMP_256]], %[[TMP_243]], %[[TMP_248]])
+ // CHECK: %[[TMP_258:.*]] = "mhlo.select"(%[[TMP_253]], %[[TMP_246]], %[[TMP_257]])
+ // CHECK: %[[TMP_259:.*]] = mhlo.multiply %[[TMP_4]], %[[TMP_89]]
+ // CHECK: %[[TMP_260:.*]] = mhlo.multiply %[[TMP_259]], %[[TMP_258]]
+ // CHECK: %[[TMP_261:.*]] = mhlo.constant dense<0.000000e+00>
+ // CHECK: %[[TMP_262:.*]] = "mhlo.compare"(%[[ARG0]], %[[TMP_261]]) {comparison_direction = "EQ"}
+ // CHECK: %[[TMP_263:.*]] = mhlo.constant dense<5.000000e-01>
+ // CHECK: %[[TMP_264:.*]] = "mhlo.compare"(%[[ARG1]], %[[TMP_263]]) {comparison_direction = "LT"}
+ // CHECK: %[[TMP_265:.*]] = "mhlo.negate"(%[[ARG1]])
+ // CHECK: %[[TMP_266:.*]] = mhlo.constant dense<1.000000e+00>
+ // CHECK: %[[TMP_267:.*]] = mhlo.subtract %[[ARG1]], %[[TMP_266]]
+ // CHECK: %[[TMP_268:.*]] = "mhlo.select"(%[[TMP_264]], %[[TMP_265]], %[[TMP_267]])
+ // CHECK: %[[TMP_269:.*]] = mhlo.constant dense<0.000000e+00>
+ // CHECK: %[[TMP_270:.*]] = mhlo.constant dense<0.99999999999980993>
+ // CHECK: %[[TMP_271:.*]] = mhlo.constant dense<676.5203681218851>
+ // CHECK: %[[TMP_272:.*]] = mhlo.constant dense<1.000000e+00>
+ // CHECK: %[[TMP_273:.*]] = mhlo.add %[[TMP_268]], %[[TMP_272]]
+ // CHECK: %[[TMP_274:.*]] = mhlo.multiply %[[TMP_273]], %[[TMP_273]]
+ // CHECK: %[[TMP_275:.*]] = mhlo.divide %[[TMP_271]], %[[TMP_274]]
+ // CHECK: %[[TMP_276:.*]] = mhlo.subtract %[[TMP_269]], %[[TMP_275]]
+ // CHECK: %[[TMP_277:.*]] = mhlo.divide %[[TMP_271]], %[[TMP_273]]
+ // CHECK: %[[TMP_278:.*]] = mhlo.add %[[TMP_270]], %[[TMP_277]]
+ // CHECK: %[[TMP_279:.*]] = mhlo.constant dense<-1259.1392167224028>
+ // CHECK: %[[TMP_280:.*]] = mhlo.constant dense<2.000000e+00>
+ // CHECK: %[[TMP_281:.*]] = mhlo.add %[[TMP_268]], %[[TMP_280]]
+ // CHECK: %[[TMP_282:.*]] = mhlo.multiply %[[TMP_281]], %[[TMP_281]]
+ // CHECK: %[[TMP_283:.*]] = mhlo.divide %[[TMP_279]], %[[TMP_282]]
+ // CHECK: %[[TMP_284:.*]] = mhlo.subtract %[[TMP_276]], %[[TMP_283]]
+ // CHECK: %[[TMP_285:.*]] = mhlo.divide %[[TMP_279]], %[[TMP_281]]
+ // CHECK: %[[TMP_286:.*]] = mhlo.add %[[TMP_278]], %[[TMP_285]]
+ // CHECK: %[[TMP_287:.*]] = mhlo.constant dense<771.32342877765313>
+ // CHECK: %[[TMP_288:.*]] = mhlo.constant dense<3.000000e+00>
+ // CHECK: %[[TMP_289:.*]] = mhlo.add %[[TMP_268]], %[[TMP_288]]
+ // CHECK: %[[TMP_290:.*]] = mhlo.multiply %[[TMP_289]], %[[TMP_289]]
+ // CHECK: %[[TMP_291:.*]] = mhlo.divide %[[TMP_287]], %[[TMP_290]]
+ // CHECK: %[[TMP_292:.*]] = mhlo.subtract %[[TMP_284]], %[[TMP_291]]
+ // CHECK: %[[TMP_293:.*]] = mhlo.divide %[[TMP_287]], %[[TMP_289]]
+ // CHECK: %[[TMP_294:.*]] = mhlo.add %[[TMP_286]], %[[TMP_293]]
+ // CHECK: %[[TMP_295:.*]] = mhlo.constant dense<-176.61502916214059>
+ // CHECK: %[[TMP_296:.*]] = mhlo.constant dense<4.000000e+00>
+ // CHECK: %[[TMP_297:.*]] = mhlo.add %[[TMP_268]], %[[TMP_296]]
+ // CHECK: %[[TMP_298:.*]] = mhlo.multiply %[[TMP_297]], %[[TMP_297]]
+ // CHECK: %[[TMP_299:.*]] = mhlo.divide %[[TMP_295]], %[[TMP_298]]
+ // CHECK: %[[TMP_300:.*]] = mhlo.subtract %[[TMP_292]], %[[TMP_299]]
+ // CHECK: %[[TMP_301:.*]] = mhlo.divide %[[TMP_295]], %[[TMP_297]]
+ // CHECK: %[[TMP_302:.*]] = mhlo.add %[[TMP_294]], %[[TMP_301]]
+ // CHECK: %[[TMP_303:.*]] = mhlo.constant dense<12.507343278686905>
+ // CHECK: %[[TMP_304:.*]] = mhlo.constant dense<5.000000e+00>
+ // CHECK: %[[TMP_305:.*]] = mhlo.add %[[TMP_268]], %[[TMP_304]]
+ // CHECK: %[[TMP_306:.*]] = mhlo.multiply %[[TMP_305]], %[[TMP_305]]
+ // CHECK: %[[TMP_307:.*]] = mhlo.divide %[[TMP_303]], %[[TMP_306]]
+ // CHECK: %[[TMP_308:.*]] = mhlo.subtract %[[TMP_300]], %[[TMP_307]]
+ // CHECK: %[[TMP_309:.*]] = mhlo.divide %[[TMP_303]], %[[TMP_305]]
+ // CHECK: %[[TMP_310:.*]] = mhlo.add %[[TMP_302]], %[[TMP_309]]
+ // CHECK: %[[TMP_311:.*]] = mhlo.constant dense<-0.13857109526572012>
+ // CHECK: %[[TMP_312:.*]] = mhlo.constant dense<6.000000e+00>
+ // CHECK: %[[TMP_313:.*]] = mhlo.add %[[TMP_268]], %[[TMP_312]]
+ // CHECK: %[[TMP_314:.*]] = mhlo.multiply %[[TMP_313]], %[[TMP_313]]
+ // CHECK: %[[TMP_315:.*]] = mhlo.divide %[[TMP_311]], %[[TMP_314]]
+ // CHECK: %[[TMP_316:.*]] = mhlo.subtract %[[TMP_308]], %[[TMP_315]]
+ // CHECK: %[[TMP_317:.*]] = mhlo.divide %[[TMP_311]], %[[TMP_313]]
+ // CHECK: %[[TMP_318:.*]] = mhlo.add %[[TMP_310]], %[[TMP_317]]
+ // CHECK: %[[TMP_319:.*]] = mhlo.constant dense<9.9843695780195716E-6>
+ // CHECK: %[[TMP_320:.*]] = mhlo.constant dense<7.000000e+00>
+ // CHECK: %[[TMP_321:.*]] = mhlo.add %[[TMP_268]], %[[TMP_320]]
+ // CHECK: %[[TMP_322:.*]] = mhlo.multiply %[[TMP_321]], %[[TMP_321]]
+ // CHECK: %[[TMP_323:.*]] = mhlo.divide %[[TMP_319]], %[[TMP_322]]
+ // CHECK: %[[TMP_324:.*]] = mhlo.subtract %[[TMP_316]], %[[TMP_323]]
+ // CHECK: %[[TMP_325:.*]] = mhlo.divide %[[TMP_319]], %[[TMP_321]]
+ // CHECK: %[[TMP_326:.*]] = mhlo.add %[[TMP_318]], %[[TMP_325]]
+ // CHECK: %[[TMP_327:.*]] = mhlo.constant dense<1.5056327351493116E-7>
+ // CHECK: %[[TMP_328:.*]] = mhlo.constant dense<8.000000e+00>
+ // CHECK: %[[TMP_329:.*]] = mhlo.add %[[TMP_268]], %[[TMP_328]]
+ // CHECK: %[[TMP_330:.*]] = mhlo.multiply %[[TMP_329]], %[[TMP_329]]
+ // CHECK: %[[TMP_331:.*]] = mhlo.divide %[[TMP_327]], %[[TMP_330]]
+ // CHECK: %[[TMP_332:.*]] = mhlo.subtract %[[TMP_324]], %[[TMP_331]]
+ // CHECK: %[[TMP_333:.*]] = mhlo.divide %[[TMP_327]], %[[TMP_329]]
+ // CHECK: %[[TMP_334:.*]] = mhlo.add %[[TMP_326]], %[[TMP_333]]
+ // CHECK: %[[TMP_335:.*]] = mhlo.constant dense<7.500000e+00>
+ // CHECK: %[[TMP_336:.*]] = mhlo.add %[[TMP_335]], %[[TMP_268]]
+ // CHECK: %[[TMP_337:.*]] = mhlo.constant dense<2.0149030205422647>
+ // CHECK: %[[TMP_338:.*]] = mhlo.divide %[[TMP_268]], %[[TMP_335]]
+ // CHECK: %[[TMP_339:.*]] = "mhlo.log_plus_one"(%[[TMP_338]])
+ // CHECK: %[[TMP_340:.*]] = mhlo.add %[[TMP_337]], %[[TMP_339]]
+ // CHECK: %[[TMP_341:.*]] = mhlo.divide %[[TMP_332]], %[[TMP_334]]
+ // CHECK: %[[TMP_342:.*]] = mhlo.constant dense<7.000000e+00>
+ // CHECK: %[[TMP_343:.*]] = mhlo.divide %[[TMP_342]], %[[TMP_336]]
+ // CHECK: %[[TMP_344:.*]] = mhlo.add %[[TMP_340]], %[[TMP_341]]
+ // CHECK: %[[TMP_345:.*]] = mhlo.subtract %[[TMP_344]], %[[TMP_343]]
+ // CHECK: %[[TMP_346:.*]] = mhlo.constant dense<5.000000e-01>
+ // CHECK: %[[TMP_347:.*]] = mhlo.add %[[ARG1]], %[[TMP_346]]
+ // CHECK: %[[TMP_348:.*]] = "mhlo.floor"(%[[TMP_347]])
+ // CHECK: %[[TMP_349:.*]] = "mhlo.abs"(%[[TMP_348]])
+ // CHECK: %[[TMP_350:.*]] = mhlo.add %[[ARG1]], %[[TMP_349]]
+ // CHECK: %[[TMP_351:.*]] = mhlo.constant dense<3.1415926535897931>
+ // CHECK: %[[TMP_352:.*]] = mhlo.multiply %[[TMP_351]], %[[TMP_350]]
+ // CHECK: %[[TMP_353:.*]] = "mhlo.cosine"(%[[TMP_352]])
+ // CHECK: %[[TMP_354:.*]] = "mhlo.sine"(%[[TMP_352]])
+ // CHECK: %[[TMP_355:.*]] = mhlo.multiply %[[TMP_351]], %[[TMP_353]]
+ // CHECK: %[[TMP_356:.*]] = mhlo.divide %[[TMP_355]], %[[TMP_354]]
+ // CHECK: %[[TMP_357:.*]] = mhlo.subtract %[[TMP_345]], %[[TMP_356]]
+ // CHECK: %[[TMP_358:.*]] = "mhlo.select"(%[[TMP_264]], %[[TMP_357]], %[[TMP_345]])
+ // CHECK: %[[TMP_359:.*]] = "mhlo.compare"(%[[ARG1]], %[[TMP_269]]) {comparison_direction = "LE"}
+ // CHECK: %[[TMP_360:.*]] = "mhlo.floor"(%[[ARG1]])
+ // CHECK: %[[TMP_361:.*]] = "mhlo.compare"(%[[ARG1]], %[[TMP_360]]) {comparison_direction = "EQ"}
+ // CHECK: %[[TMP_362:.*]] = mhlo.and %[[TMP_359]], %[[TMP_361]]
+ // CHECK: %[[TMP_363:.*]] = mhlo.constant dense<0x7FF8000000000000>
+ // CHECK: %[[TMP_364:.*]] = "mhlo.select"(%[[TMP_362]], %[[TMP_363]], %[[TMP_358]])
+ // CHECK: %[[TMP_365:.*]] = "mhlo.select"(%[[TMP_262]], %[[TMP_364]], %[[TMP_260]])
+ // CHECK: %[[TMP_366:.*]] = "mhlo.floor"(%[[ARG0]])
+ // CHECK: %[[TMP_367:.*]] = "mhlo.compare"(%[[ARG0]], %[[TMP_366]]) {comparison_direction = "NE"}
+ // CHECK: %[[TMP_368:.*]] = "mhlo.compare"(%[[ARG0]], %[[TMP_261]]) {comparison_direction = "LT"}
+ // CHECK: %[[TMP_369:.*]] = mhlo.or %[[TMP_367]], %[[TMP_368]]
+ // CHECK: %[[TMP_370:.*]] = mhlo.constant dense<0x7FF8000000000000>
+ // CHECK: %[[TMP_371:.*]] = "mhlo.select"(%[[TMP_369]], %[[TMP_370]], %[[TMP_365]])
+ // CHECK: return %[[TMP_371]]
+ %1 = chlo.polygamma %lhs, %rhs : tensor<f64>, tensor<f64> -> tensor<f64>
+ return %1 : tensor<f64>
+}
+
+// CHECK-LABEL: @polygamma_f16
+// CHECK-SAME: (%[[LHS:.*]]: tensor<f16>, %[[RHS:.*]]: tensor<f16>)
+func @polygamma_f16(%lhs : tensor<f16>, %rhs : tensor<f16>) -> tensor<f16> {
+ // CHECK: "mhlo.convert"(%[[LHS]]) : (tensor<f16>) -> tensor<f32>
+ // CHECK: "mhlo.convert"(%[[RHS]]) : (tensor<f16>) -> tensor<f32>
+ // CHECK: %[[RES:.*]] = "mhlo.convert"(%{{.*}}) : (tensor<f32>) -> tensor<f16>
+ // CHECK: return %[[RES]]
+ %1 = chlo.polygamma %lhs, %rhs : tensor<f16>, tensor<f16> -> tensor<f16>
+ return %1 : tensor<f16>
}
diff --git a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir
index 34dcfbf..545a781 100644
--- a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir
@@ -182,7 +182,7 @@
// CHECK: ^bb0(%[[ARG:.*]]: f32, %{{.*}}: f32):
// CHECK: %[[C1:.*]] = constant 1.{{.*}}e+00
// CHECK: %[[NEG_ARG:.*]] = negf %[[ARG]]
- // CHECK: %[[EXP_NEG_ARG:.*]] = exp %[[NEG_ARG]]
+ // CHECK: %[[EXP_NEG_ARG:.*]] = math.exp %[[NEG_ARG]]
// CHECK: %[[ONE_ADD_EXP_NEG_ARG:.*]] = addf %[[C1]], %[[EXP_NEG_ARG]]
// CHECK: %[[RESULT:.*]] = divf %[[C1]], %[[ONE_ADD_EXP_NEG_ARG]]
// CHECK: linalg.yield %[[RESULT]]
@@ -662,6 +662,19 @@
// -----
+// CHECK-LABEL: func @convert_i1_to_i32
+func @convert_i1_to_i32(%input: tensor<2x2xi1>) -> tensor<2x2xi32> {
+ %result = "mhlo.convert"(%input) : (tensor<2x2xi1>) -> tensor<2x2xi32>
+ return %result : tensor<2x2xi32>
+}
+// CHECK: linalg.init_tensor
+// CHECK: linalg.generic
+// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i1, %{{.*}}: i32):
+// CHECK-NEXT: %[[RESULT:.*]] = zexti %[[OPERAND_IN]] : i1 to i32
+// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
+
+// -----
+
// CHECK-LABEL: func @convert_i32_to_f32
func @convert_i32_to_f32(%input: tensor<2x2xi32>) -> tensor<2x2xf32> {
%result = "mhlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xf32>
@@ -683,7 +696,7 @@
// CHECK: linalg.init_tensor
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i16, %{{.*}}: i32):
-// CHECK-NEXT: %[[RESULT:.*]] = zexti %[[OPERAND_IN]] : i16 to i32
+// CHECK-NEXT: %[[RESULT:.*]] = sexti %[[OPERAND_IN]] : i16 to i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
// -----
@@ -727,6 +740,34 @@
// -----
+// CHECK-LABEL: func @convert_i32_to_i1
+func @convert_i32_to_i1(%input: tensor<2x2xi32>) -> tensor<2x2xi1> {
+ %result = "mhlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xi1>
+ return %result : tensor<2x2xi1>
+}
+// CHECK: linalg.init_tensor
+// CHECK: linalg.generic
+// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32, %{{.*}}: i1):
+// CHECK-NEXT: %[[ZERO:.*]] = constant 0 : i32
+// CHECK-NEXT: %[[RESULT:.*]] = cmpi ne, %[[OPERAND_IN]], %[[ZERO]] : i32
+// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
+
+// -----
+
+// CHECK-LABEL: func @convert_f32_to_i1
+func @convert_f32_to_i1(%input: tensor<2x2xf32>) -> tensor<2x2xi1> {
+ %result = "mhlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi1>
+ return %result : tensor<2x2xi1>
+}
+// CHECK: linalg.init_tensor
+// CHECK: linalg.generic
+// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %{{.*}}: i1):
+// CHECK-NEXT: %[[ZERO:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[RESULT:.*]] = cmpf une, %[[OPERAND_IN]], %[[ZERO]] : f32
+// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
+
+// -----
+
// CHECK-LABEL: func @convert_f32_to_i32
func @convert_f32_to_i32(%input: tensor<2x2xf32>) -> tensor<2x2xi32> {
%result = "mhlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi32>
@@ -834,7 +875,7 @@
// CHECK: ^{{[a-z0-9_]*}}
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: f32
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: f32
- // CHECK: %[[RESULT:[a-zA-Z0-9_]*]] = powf %[[ARG0]], %[[ARG1]]
+ // CHECK: %[[RESULT:[a-zA-Z0-9_]*]] = math.powf %[[ARG0]], %[[ARG1]]
// CHECK: linalg.yield %[[RESULT]]
%0 = "mhlo.power"(%lhs, %rhs) : (tensor<2x2xf32>,
tensor<2x2xf32>) -> tensor<2x2xf32>
diff --git a/tensorflow/compiler/mlir/hlo/tests/hlo-transform-unranked.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-transform-unranked.mlir
index d18df73..a074763 100644
--- a/tensorflow/compiler/mlir/hlo/tests/hlo-transform-unranked.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/hlo-transform-unranked.mlir
@@ -158,32 +158,34 @@
// CHECK-SAME: %[[LHS:.*]]: tensor<*xf32>,
// CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-NEXT: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
-// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
-// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
-// CHECK-NEXT: %[[LHS_IS_SCALAR:.*]] = cmpi eq, %[[LHS_RANK]], %[[C0]] : index
+// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
+// CHECK-NEXT: %[[NUM_LHS:.*]] = shape.num_elements %[[LHS_SHAPE]] : tensor<?xindex> -> index
+// CHECK-NEXT: %[[C1:.*]] = constant 1 : index
+// CHECK-NEXT: %[[LHS_IS_SCALAR:.*]] = cmpi eq, %[[NUM_LHS]], %[[C1]] : index
// Handle scalar LHS case
// CHECK-NEXT: %[[VAL_8:.*]] = scf.if %[[LHS_IS_SCALAR]] -> (tensor<*xf32>) {
-// CHECK-NEXT: %[[SCALAR_LHS:.*]] = tensor.cast %[[LHS]] : tensor<*xf32> to tensor<f32>
-// CHECK-NEXT: %[[RHS_SHAPE_1:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
-// CHECK-NEXT: %[[NUM_RHS:.*]] = shape.num_elements %[[RHS_SHAPE_1]] : tensor<?xindex> -> index
+// CHECK-NEXT: %[[SCALAR_LHS:.*]] = "mhlo.reshape"(%[[LHS]]) : (tensor<*xf32>) -> tensor<f32>
+// CHECK-NEXT: %[[NUM_RHS:.*]] = shape.num_elements %[[RHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[NUM_TENS_RHS:.*]] = tensor.from_elements %[[NUM_RHS]] : tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[NUM_TENS_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[LHS_SCALAR_RESULT:.*]] = chlo.broadcast_add %[[SCALAR_LHS]], %[[RESHAPED_RHS]] : (tensor<f32>, tensor<?xf32>) -> tensor<?xf32>
-// CHECK-NEXT: %[[RESHAPED_LHS_SCALAR_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[LHS_SCALAR_RESULT]], %[[RHS_SHAPE_1]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
-// CHECK-NEXT: scf.yield %[[RESHAPED_LHS_SCALAR_RESULT]] : tensor<*xf32>
+// CHECK-NEXT: %[[RESHAPED_LHS_SCALAR_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[LHS_SCALAR_RESULT]], %[[RHS_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK-NEXT: %[[SHAPE_BROADCAST_LHS:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
+// CHECK-NEXT: %[[RESHAPED_EXTENDED_LHS_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RESHAPED_LHS_SCALAR_RESULT]], %[[SHAPE_BROADCAST_LHS]]) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK-NEXT: scf.yield %[[RESHAPED_EXTENDED_LHS_RESULT]] : tensor<*xf32>
// CHECK-NEXT: } else {
-// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
-// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
-// CHECK-NEXT: %[[RHS_IS_SCALAR:.*]] = cmpi eq, %[[RHS_RANK]], %[[C0]] : index
+// CHECK-NEXT: %[[NUM_RHS:.*]] = shape.num_elements %[[RHS_SHAPE]] : tensor<?xindex> -> index
+// CHECK-NEXT: %[[RHS_IS_SCALAR:.*]] = cmpi eq, %[[NUM_RHS]], %[[C1]] : index
// Handle scalar RHS case
// CHECK-NEXT: %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) {
-// CHECK-NEXT: %[[SCALAR_RHS:.*]] = tensor.cast %[[RHS]] : tensor<*xf32> to tensor<f32>
-// CHECK-NEXT: %[[NUM_LHS:.*]] = shape.num_elements %[[LHS_SHAPE]] : tensor<?xindex> -> index
+// CHECK-NEXT: %[[SCALAR_RHS:.*]] = "mhlo.reshape"(%[[RHS]]) : (tensor<*xf32>) -> tensor<f32>
// CHECK-NEXT: %[[NUM_TENS_LHS:.*]] = tensor.from_elements %[[NUM_LHS]] : tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[NUM_TENS_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[RHS_SCALAR_RESULT:.*]] = chlo.broadcast_add %[[RESHAPED_LHS]], %[[SCALAR_RHS]] : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS_SCALAR_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RHS_SCALAR_RESULT:.*]], %[[LHS_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
-// CHECK-NEXT: scf.yield %[[RESHAPED_RHS_SCALAR_RESULT]] : tensor<*xf32>
+// CHECK-NEXT: %[[SHAPE_BROADCAST_RHS:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
+// CHECK-NEXT: %[[RESHAPED_EXTENDED_RHS_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RESHAPED_RHS_SCALAR_RESULT]], %[[SHAPE_BROADCAST_RHS]]) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK-NEXT: scf.yield %[[RESHAPED_EXTENDED_RHS_RESULT]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>
// Handle equal shapes case
@@ -197,10 +199,11 @@
// CHECK-NEXT: %[[RESHAPED_SAME_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLATTENED_RESULT]], %[[ANY_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESHAPED_SAME_RESULT]] : tensor<*xf32>
// CHECK-NEXT: } else {
+// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
+// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[LHS_RANK_GREATER:.*]] = cmpi sgt, %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
// Handle rank 1 specialization
-// CHECK-NEXT: %[[C1:.*]] = constant 1 : index
// CHECK-NEXT: %[[GREATEST_RANK_IS_1:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C1]] : index
// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1]
diff --git a/tensorflow/compiler/mlir/hlo/tests/legalize-trigonometric-to-approximation.mlir b/tensorflow/compiler/mlir/hlo/tests/legalize-trigonometric-to-approximation.mlir
index 959b8c2..7178c6a 100644
--- a/tensorflow/compiler/mlir/hlo/tests/legalize-trigonometric-to-approximation.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/legalize-trigonometric-to-approximation.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-hlo-opt --mhlo-legalize-trigonometric-to-approximation --split-input-file %s | FileCheck %s
func @tanh_f64(%arg0 : f64) -> f64 {
- %res = tanh %arg0 : f64
+ %res = math.tanh %arg0 : f64
return %res : f64
}
@@ -11,7 +11,7 @@
// -----
func @tanh_f32(%arg0 : f32) -> f32 {
- %res = tanh %arg0 : f32
+ %res = math.tanh %arg0 : f32
return %res : f32
}
@@ -66,7 +66,7 @@
// -----
func @tanh_f16(%arg0 : f16) -> f16 {
- %res = tanh %arg0 : f16
+ %res = math.tanh %arg0 : f16
return %res : f16
}
@@ -125,7 +125,7 @@
// CHECK-LABEL: @atan2_f64
func @atan2_f64(%arg0 : f64, %arg1 : f64) -> f64 {
// CHECK: atan2
- %res = atan2 %arg0, %arg1 : f64
+ %res = math.atan2 %arg0, %arg1 : f64
return %res : f64
}
@@ -134,6 +134,6 @@
// CHECK-LABEL: @atan_f64
func @atan_f64(%arg : f64) -> f64 {
// CHECK: atan
- %res = atan %arg : f64
+ %res = math.atan %arg : f64
return %res : f64
}
diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo-fuse-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo-fuse-linalg.mlir
index 54aceaf..a46668c 100644
--- a/tensorflow/compiler/mlir/hlo/tests/lhlo-fuse-linalg.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/lhlo-fuse-linalg.mlir
@@ -92,7 +92,7 @@
ins(%1 : memref<100x10xf32>)
outs(%arg2 : memref<100x10xf32>) {
^bb0(%arg3: f32, %arg4: f32): // no predecessors
- %2 = exp %arg3 : f32
+ %2 = math.exp %arg3 : f32
linalg.yield %2 : f32
}
dealloc %1 : memref<100x10xf32>
diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir
index e31369c..f0e0234 100644
--- a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir
@@ -10,7 +10,7 @@
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %[[RESULT_OUT:.*]]: f32):
-// CHECK-NEXT: %[[RESULT:.*]] = powf %[[LHS_IN]], %[[RHS_IN]] : f32
+// CHECK-NEXT: %[[RESULT:.*]] = math.powf %[[LHS_IN]], %[[RHS_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
@@ -115,7 +115,7 @@
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]):
-// CHECK-NEXT: %[[RESULT:.*]] = exp %[[OPERAND_IN]] : f32
+// CHECK-NEXT: %[[RESULT:.*]] = math.exp %[[OPERAND_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
@@ -127,7 +127,7 @@
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]):
-// CHECK-NEXT: %[[RESULT:.*]] = log %[[OPERAND_IN]] : f32
+// CHECK-NEXT: %[[RESULT:.*]] = math.log %[[OPERAND_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
@@ -419,6 +419,18 @@
// -----
+// CHECK-LABEL: func @convert_i1_to_i32
+func @convert_i1_to_i32(%input: memref<2x2xi1>, %result: memref<2x2xi32>) {
+ "lmhlo.convert"(%input, %result) : (memref<2x2xi1>, memref<2x2xi32>) -> ()
+ return
+}
+// CHECK: linalg.generic
+// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i1, %[[RESULT_OUT:.*]]: i32):
+// CHECK-NEXT: %[[RESULT:.*]] = zexti %[[OPERAND_IN]] : i1 to i32
+// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
+
+// -----
+
// CHECK-LABEL: func @convert_i32_to_f32
func @convert_i32_to_f32(%input: memref<2x2xi32>, %result: memref<2x2xf32>) {
"lmhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xf32>) -> ()
@@ -439,7 +451,7 @@
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i16, %[[RESULT_OUT:.*]]: i32):
-// CHECK-NEXT: %[[RESULT:.*]] = zexti %[[OPERAND_IN]] : i16 to i32
+// CHECK-NEXT: %[[RESULT:.*]] = sexti %[[OPERAND_IN]] : i16 to i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
// -----
@@ -502,6 +514,34 @@
// -----
+// CHECK-LABEL: func @convert_i32_to_i1
+func @convert_i32_to_i1(%input: memref<2x2xi32>, %result: memref<2x2xi1>) {
+ "lmhlo.convert"(%input, %result)
+ : (memref<2x2xi32>, memref<2x2xi1>) -> ()
+ return
+}
+// CHECK: linalg.generic
+// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32, %[[RESULT_OUT:.*]]: i1):
+// CHECK-NEXT: %[[ZERO:.*]] = constant 0 : i32
+// CHECK-NEXT: %[[RESULT:.*]] = cmpi ne, %[[OPERAND_IN]], %[[ZERO]] : i32
+// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
+
+// -----
+
+// CHECK-LABEL: func @convert_f32_to_i1
+func @convert_f32_to_i1(%input: memref<2x2xf32>, %result: memref<2x2xi1>) {
+ "lmhlo.convert"(%input, %result)
+ : (memref<2x2xf32>, memref<2x2xi1>) -> ()
+ return
+}
+// CHECK: linalg.generic
+// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]: i1):
+// CHECK-NEXT: %[[ZERO:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[RESULT:.*]] = cmpf une, %[[OPERAND_IN]], %[[ZERO]] : f32
+// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
+
+// -----
+
// CHECK-LABEL: func @convert_f32_to_i32
func @convert_f32_to_i32(%input: memref<2x2xf32>, %result: memref<2x2xi32>) {
"lmhlo.convert"(%input, %result)
@@ -522,7 +562,7 @@
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]):
-// CHECK-NEXT: %[[RESULT:.*]] = cos %[[OPERAND_IN]] : f32
+// CHECK-NEXT: %[[RESULT:.*]] = math.cos %[[OPERAND_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
@@ -536,7 +576,7 @@
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]):
-// CHECK-NEXT: %[[RESULT:.*]] = sin %[[OPERAND_IN]] : f32
+// CHECK-NEXT: %[[RESULT:.*]] = math.sin %[[OPERAND_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
@@ -612,7 +652,7 @@
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]):
-// CHECK-NEXT: %[[RESULT:.*]] = rsqrt %[[OPERAND_IN]] : f32
+// CHECK-NEXT: %[[RESULT:.*]] = math.rsqrt %[[OPERAND_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
@@ -676,7 +716,7 @@
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]):
-// CHECK-NEXT: %[[RESULT:.*]] = sqrt %[[OPERAND_IN]] : f32
+// CHECK-NEXT: %[[RESULT:.*]] = math.sqrt %[[OPERAND_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
@@ -688,7 +728,7 @@
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]):
-// CHECK-NEXT: %[[RESULT:.*]] = tanh %[[OPERAND_IN]] : f32
+// CHECK-NEXT: %[[RESULT:.*]] = math.tanh %[[OPERAND_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
diff --git a/tensorflow/compiler/mlir/hlo/tests/ops.mlir b/tensorflow/compiler/mlir/hlo/tests/ops.mlir
index 5651f04..93c4a76 100644
--- a/tensorflow/compiler/mlir/hlo/tests/ops.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/ops.mlir
@@ -180,6 +180,30 @@
// -----
+// Regression test for b/180052624, where this was improperly marked as an
+// invalid mhlo.broadcast_in_dim op.
+// CHECK-LABEL: func @broadcast_in_dim_dynamic_shaped_operand
+func @broadcast_in_dim_dynamic_shaped_operand(%arg0 : tensor<?xf32>) -> tensor<2xf32> {
+ %0 = "mhlo.broadcast_in_dim"(%arg0) {
+ broadcast_dimensions = dense<0> : tensor<1xi64>
+ } : (tensor<?xf32>) -> tensor<2xf32>
+ return %0 : tensor<2xf32>
+}
+
+// -----
+
+// Regression test for b/180052624, where this crashed verification given the
+// unranked operand.
+// CHECK-LABEL: func @broadcast_in_dim_unranked_operand
+func @broadcast_in_dim_unranked_operand(%arg0 : tensor<*xf32>) -> tensor<2xf32> {
+ %0 = "mhlo.broadcast_in_dim"(%arg0) {
+ broadcast_dimensions = dense<0> : tensor<1xi64>
+ } : (tensor<*xf32>) -> tensor<2xf32>
+ return %0 : tensor<2xf32>
+}
+
+// -----
+
func @case_mismatch_num_args(%index: tensor<i32>, %operand_1: tensor<f32>, %operand_2: tensor<f32>, %operand_3: tensor<f32>) -> tensor<f32> {
// expected-error@+1 {{expects branch regions to have single argument, but found 2 for branch 1}}
%0 = "mhlo.case"(%index, %operand_1, %operand_2, %operand_3) ( {
diff --git a/tensorflow/compiler/mlir/init_mlir.cc b/tensorflow/compiler/mlir/init_mlir.cc
index fac9f51..1b415bd 100644
--- a/tensorflow/compiler/mlir/init_mlir.cc
+++ b/tensorflow/compiler/mlir/init_mlir.cc
@@ -17,6 +17,17 @@
#include "tensorflow/core/platform/init_main.h"
+static llvm::cl::extrahelp FlagSplittingHelp(R"(
+The command line parsing is split between the two flag parsing libraries used by
+TensorFlow and LLVM:
+ * Flags before the first '--' are parsed by tensorflow::InitMain while those
+ post are parsed by LLVM's command line parser.
+ * If there is no separator, then no flags are parsed by InitMain and only
+ LLVM command line parser used.
+The above help options reported are for LLVM's parser, run with `--help --` for
+TensorFlow's help.
+)");
+
namespace tensorflow {
InitMlir::InitMlir(int *argc, char ***argv) : init_llvm_(*argc, *argv) {
diff --git a/tensorflow/compiler/mlir/init_mlir.h b/tensorflow/compiler/mlir/init_mlir.h
index 91020c1..81855d6 100644
--- a/tensorflow/compiler/mlir/init_mlir.h
+++ b/tensorflow/compiler/mlir/init_mlir.h
@@ -21,12 +21,11 @@
namespace tensorflow {
-// Initializer to perform both InitLLVM and TF"s InitMain initialization.
+// Initializer to perform both InitLLVM and TF's InitMain initialization.
// InitMain also performs flag parsing and '--' is used to separate flags passed
// to it: Flags before the first '--' are parsed by InitMain and argc and argv
// progressed to the flags post. If there is no separator, then no flags are
// parsed by InitMain and argc/argv left unadjusted.
-// TODO(jpienaar): The way help flag is handled could be improved.
class InitMlir {
public:
InitMlir(int *argc, char ***argv);
diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD
index a760ca1..06fa87f 100644
--- a/tensorflow/compiler/mlir/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/BUILD
@@ -5,6 +5,7 @@
load("//tensorflow:tensorflow.bzl", "get_compatible_with_cloud")
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_native_cc_binary")
+load("//tensorflow/lite:special_rules.bzl", "internal_visibility_allowlist")
load(
"//third_party/mlir:tblgen.bzl",
"gentbl",
@@ -35,7 +36,9 @@
srcs = [
"ir/tfl_op_interfaces.td",
"ir/tfl_ops.td",
+ "ir/tfl_structs.td",
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
+ "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:SideEffectTdFiles",
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
@@ -56,6 +59,22 @@
"ir/tfl_ops.cc.inc",
),
(
+ "-gen-dialect-doc",
+ "g3doc/tfl_ops.md",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "ir/tfl_ops.td",
+ td_srcs = [
+ ":tensorflow_lite_ops_td_files",
+ ],
+)
+
+gentbl(
+ name = "tensorflow_lite_structs_inc_gen",
+ compatible_with = get_compatible_with_cloud(),
+ tbl_outs = [
+ (
"-gen-struct-attr-decls",
"ir/tfl_structs.h.inc",
),
@@ -63,13 +82,9 @@
"-gen-struct-attr-defs",
"ir/tfl_structs.cc.inc",
),
- (
- "-gen-dialect-doc",
- "g3doc/tfl_ops.md",
- ),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
- td_file = "ir/tfl_ops.td",
+ td_file = "ir/tfl_structs.td",
td_srcs = [
":tensorflow_lite_ops_td_files",
],
@@ -226,19 +241,23 @@
"ir/tfl_ops.h.inc",
"ir/tfl_ops_interface.cc.inc",
"ir/tfl_ops_interface.h.inc",
+ "ir/tfl_structs.cc.inc",
"runtime_verifiers.inc",
"utils/attribute_utils.cc",
],
hdrs = [
"ir/tfl_ops.h",
+ "ir/tfl_structs.h.inc",
"transforms/passes.h",
"utils/attribute_utils.h",
"@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h",
],
deps = [
":tensorflow_lite_ops_inc_gen",
+ ":tensorflow_lite_structs_inc_gen",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/tensorflow",
+ "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/lite/schema:schema_fbs",
"//third_party/eigen3",
@@ -597,10 +616,7 @@
tblgen = "//tensorflow/compiler/mlir/lite/quantization:op_quant_spec_getters_gen",
td_file = "ir/tfl_ops.td",
td_srcs = [
- "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
- "@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
- "//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
- "ir/tfl_op_interfaces.td",
+ ":tensorflow_lite_ops_td_files",
],
)
@@ -1013,3 +1029,22 @@
"@llvm-project//llvm:Support",
],
)
+
+# Python Library to check TensorFlow op compatibility.
+py_library(
+ name = "tensorflow_lite_compatibility_tbl_generated",
+ srcs = [
+ "tensorflow_lite_compatibility_tbl_generated.py",
+ ],
+ srcs_version = "PY3",
+ visibility = internal_visibility_allowlist(),
+)
+
+genrule(
+ name = "tfl_compatibility_tbl_gen",
+ srcs = ["transforms/generated_legalize_tf.inc"],
+ outs = ["tensorflow_lite_compatibility_tbl_generated.py"],
+ cmd = "$(location generate_tfl_compatibility_tbl.sh) $(location transforms/generated_legalize_tf.inc) > \"$@\"",
+ compatible_with = get_compatible_with_cloud(),
+ tools = ["generate_tfl_compatibility_tbl.sh"],
+)
diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc
index 250ff1f..4ab877b 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc
@@ -165,7 +165,8 @@
case 16:
return tflite::TensorType_INT16;
case 32:
- return tflite::TensorType_INT32;
+ return itype.isUnsigned() ? tflite::TensorType_UINT32
+ : tflite::TensorType_INT32;
case 64:
return itype.isUnsigned() ? tflite::TensorType_UINT64
: tflite::TensorType_INT64;
@@ -182,11 +183,9 @@
type.dyn_cast<mlir::quant::CalibratedQuantizedType>()) {
return GetTFLiteType(q_calibrated_type.getExpressedType());
} else if (type.isa<mlir::TF::ResourceType>()) {
- // Treat tf.resource values as integer values in flatbuffer.
- // TODO(b/146131919): Maybe need to have a detailed design for supporting
- // other resource types beyonds hash table resources and resource
- // variables.
- return tflite::TensorType_INT32;
+ return tflite::TensorType_RESOURCE;
+ } else if (type.isa<mlir::TF::VariantType>()) {
+ return tflite::TensorType_VARIANT;
}
// TFLite export fills FLOAT32 for unknown data types. Returning an error
// for now for safety and this could be revisited when required.
@@ -1712,6 +1711,9 @@
const std::unordered_set<std::string>& select_user_tf_ops,
const std::unordered_set<std::string>& tags,
OpOrArgNameMapper* op_or_arg_name_mapper) {
+ OpOrArgLocNameMapper default_op_or_arg_name_mapper;
+ if (!op_or_arg_name_mapper)
+ op_or_arg_name_mapper = &default_op_or_arg_name_mapper;
if (!UpdateEntryFunction(module)) return llvm::None;
if (!IsValidTFLiteMlirModule(module)) return llvm::None;
Translator translator(module, emit_builtin_tflite_ops, emit_select_tf_ops,
@@ -1944,69 +1946,23 @@
} // namespace
-// Translates the given MLIR module in the TFLite dialect to TFLite FlatBuffer
-// format. Returns false on success.
-//
+namespace tflite {
// TODO(hinsu): Support all valid MLIR modules in TFLite dialect by supporting
// the following:
//
// * Quantization
// * Ops with variable tensors
//
-bool tflite::MlirToFlatBufferTranslateFunction(
- ModuleOp module, std::string* serialized_flatbuffer,
- bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
- OpOrArgNameMapper* op_or_arg_name_mapper) {
- return MlirToFlatBufferTranslateFunction(
- module, serialized_flatbuffer, emit_builtin_tflite_ops,
- emit_select_tf_ops, emit_custom_ops, /*saved_model_tags=*/{},
- op_or_arg_name_mapper);
-}
-
-bool tflite::MlirToFlatBufferTranslateFunction(
- ModuleOp module, std::string* serialized_flatbuffer,
- bool emit_builtin_tflite_ops, bool emit_select_tf_ops,
- bool emit_custom_ops) {
- OpOrArgLocNameMapper op_or_arg_name_mapper;
- return MlirToFlatBufferTranslateFunction(
- module, serialized_flatbuffer, emit_builtin_tflite_ops,
- emit_select_tf_ops, emit_custom_ops, /*saved_model_tags=*/{},
- &op_or_arg_name_mapper);
-}
-
-bool tflite::MlirToFlatBufferTranslateFunction(
- mlir::ModuleOp module, std::string* serialized_flatbuffer,
- bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
- const std::unordered_set<std::string>& saved_model_tags) {
- OpOrArgLocNameMapper op_or_arg_name_mapper;
- return MlirToFlatBufferTranslateFunction(
- module, serialized_flatbuffer, emit_builtin_tflite_ops,
- emit_select_tf_ops, emit_custom_ops, saved_model_tags,
- &op_or_arg_name_mapper);
-}
-
-bool tflite::MlirToFlatBufferTranslateFunction(
- mlir::ModuleOp module, std::string* serialized_flatbuffer,
- bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
- const std::unordered_set<std::string>& saved_model_tags,
- OpOrArgNameMapper* op_or_arg_name_mapper) {
- std::unordered_set<std::string> select_user_tf_ops;
- return MlirToFlatBufferTranslateFunction(
- module, serialized_flatbuffer, emit_builtin_tflite_ops,
- emit_select_tf_ops, emit_custom_ops, select_user_tf_ops, saved_model_tags,
- op_or_arg_name_mapper);
-}
-
-bool tflite::MlirToFlatBufferTranslateFunction(
- ModuleOp module, std::string* serialized_flatbuffer,
- bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
- const std::unordered_set<std::string>& select_user_tf_ops,
- const std::unordered_set<std::string>& saved_model_tags,
- tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper) {
+bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module,
+ const FlatbufferExportOptions& options,
+ std::string* serialized_flatbuffer) {
auto maybe_translated = Translator::Translate(
- module, emit_builtin_tflite_ops, emit_select_tf_ops, emit_custom_ops,
- select_user_tf_ops, saved_model_tags, op_or_arg_name_mapper);
- if (!maybe_translated) return true;
+ module, options.emit_builtin_tflite_ops, options.emit_select_tf_ops,
+ options.emit_custom_ops, options.select_user_tf_ops,
+ options.saved_model_tags, options.op_or_arg_name_mapper);
+ if (!maybe_translated) return false;
*serialized_flatbuffer = std::move(*maybe_translated);
- return false;
+ return true;
}
+
+} // namespace tflite
diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.h b/tensorflow/compiler/mlir/lite/flatbuffer_export.h
index c47bffb..73b7166 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_export.h
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.h
@@ -23,43 +23,26 @@
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
namespace tflite {
+// Options for exporting to Flatbuffer.
+struct FlatbufferExportOptions {
+ bool emit_builtin_tflite_ops = false;
+ bool emit_select_tf_ops = false;
+ bool emit_custom_ops = false;
+ // When exporting from SavedModel, this will have the requested tags.
+ std::unordered_set<std::string> saved_model_tags;
+ // TF custom op passed by the user.
+ std::unordered_set<std::string> select_user_tf_ops;
+ // OpOrArgNameMapper to convert location of the op to name in flatbuffer.
+ // If not set, a default mapper will be used.
+ tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper = nullptr;
+};
// Translates the given MLIR `module` into a FlatBuffer and stores the
-// serialized flatbuffer into the string. This uses OpOrArgLocNameMapper to
-// convert location of the op to name in flatbuffer. Returns true if translation
-// fails, otherwise returns false.
+// serialized flatbuffer into the string.
+// Returns true on successful exporting, false otherwise.
bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module,
- std::string* serialized_flatbuffer,
- bool emit_builtin_tflite_ops,
- bool emit_select_tf_ops,
- bool emit_custom_ops);
-
-// Same as above but takes SavedModel tags of the model.
-bool MlirToFlatBufferTranslateFunction(
- mlir::ModuleOp module, std::string* serialized_flatbuffer,
- bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
- const std::unordered_set<std::string>& saved_model_tags);
-
-// Same as the above but with a custom op name mapper.
-bool MlirToFlatBufferTranslateFunction(
- mlir::ModuleOp module, std::string* serialized_flatbuffer,
- bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
- tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper);
-
-// Same as above but takes SavedModel tags of the model.
-bool MlirToFlatBufferTranslateFunction(
- mlir::ModuleOp module, std::string* serialized_flatbuffer,
- bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
- const std::unordered_set<std::string>& saved_model_tags,
- tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper);
-
-// Same as the above but with a list of allowed user's defined ops.
-bool MlirToFlatBufferTranslateFunction(
- mlir::ModuleOp module, std::string* serialized_flatbuffer,
- bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
- const std::unordered_set<std::string>& select_user_tf_ops,
- const std::unordered_set<std::string>& saved_model_tags,
- tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper);
+ const FlatbufferExportOptions& options,
+ std::string* serialized_flatbuffer);
} // namespace tflite
#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_
diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc
index 901199e..1c3788a 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc
@@ -158,9 +158,13 @@
op_or_arg_name_mapper =
std::make_unique<tensorflow::OpOrArgLocNameMapper>();
}
- if (tflite::MlirToFlatBufferTranslateFunction(
- module, &serialized_flatbuffer, emit_builtin_tflite_ops,
- emit_select_tf_ops, emit_custom_ops, op_or_arg_name_mapper.get()))
+ tflite::FlatbufferExportOptions options;
+ options.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
+ options.emit_custom_ops = emit_custom_ops;
+ options.emit_select_tf_ops = emit_select_tf_ops;
+ options.op_or_arg_name_mapper = op_or_arg_name_mapper.get();
+ if (!tflite::MlirToFlatBufferTranslateFunction(module, options,
+ &serialized_flatbuffer))
return mlir::failure();
output << serialized_flatbuffer;
diff --git a/tensorflow/compiler/mlir/lite/generate_tfl_compatibility_tbl.sh b/tensorflow/compiler/mlir/lite/generate_tfl_compatibility_tbl.sh
new file mode 100755
index 0000000..a6a3455
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/generate_tfl_compatibility_tbl.sh
@@ -0,0 +1,39 @@
+#!/bin/bash
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+cat <<EOF
+TENSORFLOW_COMPATIBLE_OPS = (
+EOF
+
+# TODO(b/178456916): Leverage existing op compat definitions/specs in the
+# MLIR conversion pipeline in a better way.
+# TODO(b/180352158): Validate generated TF op names.
+grep 'patterns.insert<Legalize' $1 | awk -F'<Legalize|>' '{printf " \"%s\",\n", $2}'
+
+cat <<EOF
+ # Rules at tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
+ "Assert",
+ "ConcatV2",
+ "MatMul",
+ "MatrixDiagV2",
+ "MatrixDiagV3",
+ "Pack",
+ "Split",
+ "SplitV",
+ "Unpack",
+ "RandomUniform",
+)
+EOF
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td b/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td
index e14178d..a8ed682 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td
@@ -20,6 +20,31 @@
include "mlir/IR/OpBase.td"
+def TFL_Dialect : Dialect {
+ let name = "tfl";
+
+ let description = [{
+ The TensorFlow Lite dialect.
+
+ This dialect maps to TensorFlow Lite operations.
+
+ Invariants:
+
+ * All values are of Tensor type (in particular, scalars are
+ represented using zero-dimensional tensors);
+ }];
+
+ let cppNamespace = "::mlir::TFL";
+}
+
+// Attributes used for encoding sparse tensors.
+// Please find detailed explanation of these parameters in the TFLite schema.
+def TFL_DT_Dense : StrEnumAttrCase<"DENSE", 0>;
+def TFL_DT_SparseCSR : StrEnumAttrCase<"SPARSE_CSR", 1>;
+
+def TFL_DimensionTypeAttr : StrEnumAttr<
+ "DimensionType", "dimension type", [TFL_DT_Dense, TFL_DT_SparseCSR]>;
+
//===----------------------------------------------------------------------===//
// TFL op interface for stateful operands.
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
index f8fb0d4..e55fee4 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
@@ -147,6 +147,41 @@
return element_type.isInteger(64) && !element_type.isUnsignedInteger();
}
+// Return true if the value is a splat tensor constant zero.
+bool EqualsZero(Value value) {
+ DenseElementsAttr constant;
+ if (!matchPattern(value, m_Constant(&constant)) || !constant.isSplat()) {
+ return false;
+ }
+
+ Type element_type = value.getType().cast<ShapedType>().getElementType();
+ if (element_type.isa<FloatType>()) {
+ return constant.getSplatValue<APFloat>().isZero();
+ } else {
+ return false;
+ }
+}
+
+// Replaces the bias operand with a "none" type value if the bias value is
+// constant zero.
+// `ConcreteOpType` must be an concrete MLIR op class that has an optional
+// bias operand named 'bias'.
+template <typename ConcreteOpType>
+struct RemoveOptionalZeroBias : public OpRewritePattern<ConcreteOpType> {
+ using OpRewritePattern<ConcreteOpType>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ConcreteOpType op,
+ PatternRewriter &rewriter) const override {
+ if (EqualsZero(op.bias())) {
+ auto none_value = rewriter.create<mlir::ConstantOp>(
+ rewriter.getUnknownLoc(), rewriter.getUnitAttr());
+ op.biasMutable().assign(none_value);
+ }
+
+ return success();
+ }
+};
+
// Return true if the given Add operation has the CPU kernel supported shapes.
bool VerifyAddOpShapeConstraints(AddOp op) {
auto element_type = getElementTypeOrSelf(op.output().getType());
@@ -799,6 +834,131 @@
return mlir::success();
}
+LogicalResult FullyConnectedOp::fold(ArrayRef<Attribute> operands,
+ SmallVectorImpl<OpFoldResult> &results) {
+ assert(operands.size() == 3);
+
+ // Folding not implemented with any activation function or any weight type
+ // besides the default.
+ if (fused_activation_function() != "NONE") return failure();
+ if (weights_format() != "DEFAULT") return failure();
+
+ // Bias tensor is optional.
+ const bool has_bias = !(!bias() || bias().getType().isa<NoneType>());
+
+ // Get the tensors.
+ DenseElementsAttr input_tensor, weights_tensor, bias_tensor;
+ if (!matchPattern(input(), m_Constant(&input_tensor)) ||
+ !matchPattern(filter(), m_Constant(&weights_tensor)) ||
+ (has_bias && !matchPattern(bias(), m_Constant(&bias_tensor)))) {
+ return failure();
+ }
+
+ // Get the tensor types.
+ const auto input_type = input_tensor.getType().cast<ShapedType>();
+ const auto weights_type = weights_tensor.getType().cast<ShapedType>();
+ const auto bias_type =
+ has_bias ? bias_tensor.getType().cast<ShapedType>() : ShapedType{};
+
+ const auto output_type = getType(0).cast<ShapedType>();
+
+ // Folding only implemented for float tensors.
+ if (!input_type.getElementType().isF32() ||
+ !weights_type.getElementType().isF32() ||
+ !output_type.getElementType().isF32() ||
+ (has_bias && !bias_type.getElementType().isF32())) {
+ return failure();
+ }
+
+ // Folding only implemented for static shapes
+ if (!input_type.hasStaticShape() || !weights_type.hasStaticShape() ||
+ (has_bias && !bias_type.hasStaticShape())) {
+ return failure();
+ }
+
+ // Folding only implemented for 1D input, 2D weights and 1D bias
+ if (input_type.getShape().size() != 1 ||
+ weights_type.getShape().size() != 2 ||
+ (has_bias && bias_type.getShape().size() != 1)) {
+ return failure();
+ }
+
+ // Get the sizes
+ const auto input_size = input_type.getNumElements();
+ const auto output_size = output_type.getNumElements();
+
+ // Get iterators to the tensors.
+ const auto input_values_it = input_tensor.getValues<float>().begin();
+ const auto weights_values_ptr = weights_tensor.getValues<float>().begin();
+ auto weights_row_it = weights_values_ptr;
+ // The 'else' case could be nullptr, but the types don't match.
+ auto bias_values_it =
+ has_bias ? bias_tensor.getValues<float>().begin() : input_values_it;
+
+ // Do the actual folding, one output at a time.
+ std::vector<float> result_values;
+ result_values.reserve(output_size);
+
+ for (int i = 0; i < output_size; ++i) {
+ // Dot product with Kahan/Neumaier summation to minimize numeric errors.
+ float sum = has_bias ? *bias_values_it : 0.0f;
+ float compensation = 0.0f;
+ for (int j = 0; j < input_size; ++j) {
+ const float addend = input_values_it[j] * weights_row_it[j];
+ const float new_sum = sum + addend;
+ // DO NOT enable -funsafe-math-optimizations here.
+ // There is a test detecting unsafe optimizations.
+ // Unsafe math optimizations can reorder float formulas, and set the
+ // compensation to constant 0. The formula must be evaluated as written
+ // for the algorithm to work.
+ // (Note: -ffast-math is a superset of -funsafe-math-optimizations.)
+ if (std::abs(sum) >= std::abs(addend)) {
+ compensation += (sum - new_sum) + addend;
+ } else {
+ compensation += (addend - new_sum) + sum;
+ }
+ sum = new_sum;
+ }
+ result_values.push_back(sum + compensation);
+ weights_row_it += input_size;
+ bias_values_it++;
+ }
+
+ // Set result tensor
+ const auto folded =
+ DenseElementsAttr::get(output_type, ArrayRef<float>(result_values));
+ results.assign({folded});
+
+ return success();
+}
+
+void FullyConnectedOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<RemoveOptionalZeroBias<FullyConnectedOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// Conv2DOp
+//===----------------------------------------------------------------------===//
+
+void Conv2DOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ // TODO(b/180121750): Enable the pattern after the integration tests are
+ // fixed.
+ // results.insert<RemoveOptionalZeroBias<Conv2DOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// DepthwiseConv2DO
+//===----------------------------------------------------------------------===//
+
+void DepthwiseConv2DOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ // TODO(b/180121750): Enable the pattern after the integration tests are
+ // fixed.
+ // results.insert<RemoveOptionalZeroBias<DepthwiseConv2DOp>>(context);
+}
+
//===----------------------------------------------------------------------===//
// GatherOp
//===----------------------------------------------------------------------===//
@@ -1902,6 +2062,38 @@
return success();
}
+namespace {
+
+// Replaces the optional bias operands with a "none" type value if the bias
+// values are constant zeros.
+struct RemoveLSTMOpZeroBias : public OpRewritePattern<LSTMOp> {
+ using OpRewritePattern<LSTMOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(LSTMOp op,
+ PatternRewriter &rewriter) const override {
+ if (EqualsZero(op.input_gate_bias())) {
+ auto none_value = rewriter.create<mlir::ConstantOp>(
+ rewriter.getUnknownLoc(), rewriter.getUnitAttr());
+ op.input_gate_biasMutable().assign(none_value);
+ }
+
+ if (EqualsZero(op.projection_bias())) {
+ auto none_value = rewriter.create<mlir::ConstantOp>(
+ rewriter.getUnknownLoc(), rewriter.getUnitAttr());
+ op.projection_biasMutable().assign(none_value);
+ }
+
+ return success();
+ }
+};
+
+} // namespace
+
+void LSTMOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<RemoveLSTMOpZeroBias>(context);
+}
+
//===----------------------------------------------------------------------===//
// UnidirectionalSequenceLSTMOp
//===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
index 6f9021e..852a113 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
@@ -23,25 +23,9 @@
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td"
+include "tensorflow/compiler/mlir/lite/ir/tfl_structs.td"
include "tensorflow/compiler/mlir/lite/quantization/quantization.td"
-def TFL_Dialect : Dialect {
- let name = "tfl";
-
- let description = [{
- The TensorFlow Lite dialect.
-
- This dialect maps to TensorFlow Lite operations.
-
- Invariants:
-
- * All values are of Tensor type (in particular, scalars are
- represented using zero-dimensional tensors);
- }];
-
- let cppNamespace = "::mlir::TFL";
-}
-
//===----------------------------------------------------------------------===//
// TFLite dialect string type - uses the TF string type as implementation
//===----------------------------------------------------------------------===//
@@ -836,33 +820,6 @@
];
}
-// Attributes used for encoding sparse tensors.
-// Please find detailed explanation of these parameters in the TFLite schema.
-def TFL_DT_Dense : StrEnumAttrCase<"DENSE", 0>;
-def TFL_DT_SparseCSR : StrEnumAttrCase<"SPARSE_CSR", 1>;
-
-def TFL_DimensionTypeAttr : StrEnumAttr<
- "DimensionType", "dimension type", [TFL_DT_Dense, TFL_DT_SparseCSR]>;
-
-def DimensionMetadataAttr : StructAttr<"DimensionMetadataAttr", TFL_Dialect, [
- StructFieldAttr<"format", TFL_DimensionTypeAttr>,
- StructFieldAttr<"dense_size", I32Attr>,
- StructFieldAttr<"segments", I32ArrayAttr>,
- StructFieldAttr<"indices", I32ArrayAttr>] > {
- let summary = "Dimension metadata.";
-}
-
-def DimensionMetadataArrayAttr : TypedArrayAttrBase<DimensionMetadataAttr,
- "Array of DimensionMetadata">{}
-
-def SparsityParameterAttr : StructAttr<"SparsityParameterAttr", TFL_Dialect, [
- StructFieldAttr<"traversal_order", I32ArrayAttr>,
- StructFieldAttr<"block_map", I32ArrayAttr>,
- StructFieldAttr<"dim_metadata", DimensionMetadataArrayAttr>]> {
- let summary = "Sparsity parameter.";
- let storageType = [{ TFL::SparsityParameterAttr }];
-}
-
def TFL_SparseConstOp : Op<TFL_Dialect, "pseudo_sparse_const", [
NoSideEffect,
FirstAttrDerivedResultType]> {
@@ -905,6 +862,8 @@
}
def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0> {
+ let hasCanonicalizer = 1;
+
let extraClassDeclaration = [{
// AffineQuantizedOpInterface:
int GetChannelDimIndex() { return 0; }
@@ -973,6 +932,8 @@
I32Attr:$depth_multiplier
);
+ let hasCanonicalizer = 1;
+
let extraClassDeclaration = [{
// AffineQuantizedOpInterface:
int GetChannelDimIndex() { return 3; }
@@ -1020,6 +981,10 @@
let hasOptions = 1;
+ let hasCanonicalizer = 1;
+
+ let hasFolder = 1;
+
let extraClassDeclaration = [{
// AffineQuantizedOpInterface:
int GetChannelDimIndex() { return 0; }
@@ -3488,6 +3453,10 @@
let hasOptions = 1;
}
+// If there is a change in supporting more types in the TFLite cast op kernel,
+// the While loop outline pass should be updated since it inserts cast op(s)
+// after the TF -> TFL legalization pass is done.
+// LINT.IfChange
def TFL_CastOp : TFL_Op<"cast", [
NoSideEffect,
SameOperandsAndResultShape,
@@ -3510,6 +3479,7 @@
let hasFolder = 1;
}
+// LINT.ThenChange(//tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc)
def TFL_MirrorPadOp: TFL_Op<"mirror_pad", [
NoSideEffect, TFL_OperandHasRank<1, 2>]> {
@@ -3930,6 +3900,8 @@
let hasOptions = 1;
+ let hasCanonicalizer = 1;
+
let verifier = [{ return Verify(*this); }];
let extraClassDeclaration = [{
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_structs.td b/tensorflow/compiler/mlir/lite/ir/tfl_structs.td
new file mode 100644
index 0000000..f1afd3c
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_structs.td
@@ -0,0 +1,43 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This is the struct definition file for TensorFlow.
+
+#ifndef TFL_STRUCT
+#define TFL_STRUCT
+
+include "mlir/IR/OpBase.td"
+include "tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td"
+
+def DimensionMetadataAttr : StructAttr<"DimensionMetadataAttr", TFL_Dialect, [
+ StructFieldAttr<"format", TFL_DimensionTypeAttr>,
+ StructFieldAttr<"dense_size", I32Attr>,
+ StructFieldAttr<"segments", I32ArrayAttr>,
+ StructFieldAttr<"indices", I32ArrayAttr>] > {
+ let summary = "Dimension metadata.";
+}
+
+def DimensionMetadataArrayAttr : TypedArrayAttrBase<DimensionMetadataAttr,
+ "Array of DimensionMetadata">{}
+
+def SparsityParameterAttr : StructAttr<"SparsityParameterAttr", TFL_Dialect, [
+ StructFieldAttr<"traversal_order", I32ArrayAttr>,
+ StructFieldAttr<"block_map", I32ArrayAttr>,
+ StructFieldAttr<"dim_metadata", DimensionMetadataArrayAttr>]> {
+ let summary = "Sparsity parameter.";
+ let storageType = [{ TFL::SparsityParameterAttr }];
+}
+
+#endif // TFL_STRUCT
diff --git a/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc b/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc
index d0b1fb9..af9fda2 100644
--- a/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc
+++ b/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc
@@ -79,6 +79,8 @@
switch (tensor.type) {
case kTfLiteInt32:
return TfLiteTypedTensorString<int32_t>(tensor);
+ case kTfLiteUInt32:
+ return TfLiteTypedTensorString<uint32_t>(tensor);
case kTfLiteInt64:
return TfLiteTypedTensorString<int64_t>(tensor);
case kTfLiteFloat32:
@@ -100,10 +102,10 @@
}
// Load the MLIR module.
- mlir::MLIRContext context;
- context.getDialectRegistry()
- .insert<mlir::TF::TensorFlowDialect, mlir::TFL::TensorFlowLiteDialect,
- mlir::StandardOpsDialect>();
+ mlir::DialectRegistry registry;
+ registry.insert<mlir::TF::TensorFlowDialect, mlir::TFL::TensorFlowLiteDialect,
+ mlir::StandardOpsDialect>();
+ mlir::MLIRContext context(registry);
llvm::SourceMgr source_mgr;
source_mgr.AddNewSourceBuffer(std::move(*file_or_err), llvm::SMLoc());
@@ -118,9 +120,12 @@
// Convert to flatbuffer.
std::string serialized_flatbuffer;
- if (tflite::MlirToFlatBufferTranslateFunction(
- module.get(), &serialized_flatbuffer, emit_builtin_tflite_ops,
- emit_select_tf_ops, emit_custom_ops))
+ tflite::FlatbufferExportOptions options;
+ options.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
+ options.emit_custom_ops = emit_custom_ops;
+ options.emit_select_tf_ops = emit_select_tf_ops;
+ if (!tflite::MlirToFlatBufferTranslateFunction(module.get(), options,
+ &serialized_flatbuffer))
return 1;
// Create TFLite interpreter & invoke converted program.
diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc
index 735e8be..213186f 100644
--- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc
+++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc
@@ -119,6 +119,8 @@
return DT_INT16;
case toco::IODataType::INT32:
return DT_INT32;
+ case toco::IODataType::UINT32:
+ return DT_UINT32;
case toco::IODataType::INT64:
return DT_INT64;
case toco::IODataType::UINT64:
diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc
index 139b93c..f8870cb 100644
--- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc
@@ -53,8 +53,9 @@
return kTfLiteError;
}
- MLIRContext context;
- context.getDialectRegistry().insert<mlir::TFL::TensorFlowLiteDialect>();
+ DialectRegistry registry;
+ registry.insert<mlir::TFL::TensorFlowLiteDialect>();
+ MLIRContext context(registry);
StatusScopedDiagnosticHandler statusHandler(&context,
/*propagate=*/true);
@@ -107,9 +108,12 @@
// Export the results to the builder
std::string result;
- if (tflite::MlirToFlatBufferTranslateFunction(
- module.get(), &result, /*emit_builtin_tflite_ops=*/true,
- /*emit_select_tf_ops=*/true, /*emit_custom_ops=*/true)) {
+ tflite::FlatbufferExportOptions options;
+ options.emit_builtin_tflite_ops = true;
+ options.emit_select_tf_ops = true;
+ options.emit_custom_ops = true;
+ if (!tflite::MlirToFlatBufferTranslateFunction(module.get(), options,
+ &result)) {
error_reporter->Report("Failed to export MLIR to flatbuffer.");
return kTfLiteError;
}
diff --git a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc
index eed9529..d3482f7 100644
--- a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc
+++ b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc
@@ -68,9 +68,12 @@
// Export the results to the builder
std::string result;
- if (tflite::MlirToFlatBufferTranslateFunction(
- module.get(), &result, /*emit_builtin_tflite_ops=*/true,
- /*emit_select_tf_ops=*/true, /*emit_custom_ops=*/true)) {
+ tflite::FlatbufferExportOptions options;
+ options.emit_builtin_tflite_ops = true;
+ options.emit_select_tf_ops = true;
+ options.emit_custom_ops = true;
+ if (!tflite::MlirToFlatBufferTranslateFunction(module.get(), options,
+ &result)) {
error_reporter->Report("Failed to export MLIR to flatbuffer.");
return kTfLiteError;
}
diff --git a/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir b/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir
index 02142c3..d656904 100644
--- a/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir
@@ -204,3 +204,47 @@
// CHECK: "tfl.while"
// CHECK: (tensor<i32>, tensor<i32>, tensor<!tf.resource>) -> (tensor<i32>, tensor<i32>, tensor<!tf.resource>)
}
+
+// CHECK-LABEL: @RemoveFcZeroBias
+func @RemoveFcZeroBias(%arg0: tensor<1x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<1x40xf32> {
+ %0 = "tfl.pseudo_const"() {value = dense<0.0> : tensor<40xf32>} : () -> tensor<40xf32>
+ %1 = "tfl.fully_connected"(%arg0, %arg1, %0) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x37xf32>, tensor<40x37xf32>, tensor<40xf32>) -> tensor<1x40xf32>
+// CHECK: "tfl.fully_connected"
+// CHECK-SAME: (tensor<1x37xf32>, tensor<40x37xf32>, none) -> tensor<1x40xf32>
+ return %1 : tensor<1x40xf32>
+}
+
+// CHECK-LABEL: RemoveLstmQuantZeroBias
+func @RemoveLstmQuantZeroBias(
+ %arg0: tensor<1x528xf32>,
+ %arg1: tensor<2048x528xf32>,
+ %arg2: tensor<2048x528xf32>,
+ %arg3: tensor<2048x528xf32>,
+ %arg4: tensor<2048x528xf32>,
+ %arg5: tensor<2048x640xf32>,
+ %arg6: tensor<2048x640xf32>,
+ %arg7: tensor<2048x640xf32>,
+ %arg8: tensor<2048x640xf32>,
+ %arg9: tensor<2048xf32>,
+ %arg10: tensor<2048xf32>,
+ %arg11: tensor<2048xf32>,
+ %arg12: tensor<2048xf32>,
+ %arg13: tensor<640x2048xf32>,
+ %arg14: tensor<640xf32>,
+ %arg15: tensor<2048xf32>,
+ %arg16: tensor<2048xf32>,
+ %arg17: tensor<2048xf32>,
+ %arg18: tensor<2048xf32>,
+ %arg19: tensor<1x640xf32>,
+ %arg20: tensor<1x2048xf32>
+) -> tensor<1x640xf32> {
+ %cst = constant unit
+ %zero = "tfl.pseudo_const"() {value = dense<0.0> : tensor<640xf32>} : () -> tensor<640xf32>
+ %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %zero, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) ({}) {
+ cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.01 : f32
+ } : (tensor<1x528xf32>, tensor<2048x528xf32>, tensor<2048x528xf32>, tensor<2048x528xf32>, tensor<2048x528xf32>, tensor<2048x640xf32>, tensor<2048x640xf32>, tensor<2048x640xf32>, tensor<2048x640xf32>, none, none, none, tensor<2048xf32>, tensor<2048xf32>, tensor<2048xf32>, tensor<2048xf32>, tensor<640x2048xf32>, tensor<640xf32>, tensor<1x640xf32>, tensor<1x2048xf32>, tensor<2048xf32>, tensor<2048xf32>, tensor<2048xf32>, tensor<2048xf32>) -> tensor<1x640xf32>
+ return %0 : tensor<1x640xf32>
+// CHECK: %[[NONE:.+]] = constant unit
+// CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %[[NONE]], %[[NONE]], %[[NONE]], %arg9, %arg10, %arg11, %arg12, %arg13, %[[NONE]], %arg19, %arg20, %arg15, %arg16, %arg17, %arg18)
+}
+
diff --git a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir
index 89992e8..236f783 100644
--- a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir
@@ -726,3 +726,92 @@
// CHECK: %[[CST:.*]] = constant dense<[false, true, true, true]> : tensor<4xi1>
// CHECK: return %[[CST]]
}
+
+// CHECK-LABEL: @ConstantFoldFullyConnectedSmall
+func @ConstantFoldFullyConnectedSmall() -> tensor<3xf32> {
+ %cst_input= constant dense<[2.0, 3.0]> : tensor<2xf32>
+ %cst_weights = constant dense<[[5.0, 7.0], [11.0, 13.0], [17.0, 19.0]]> : tensor<3x2xf32>
+ %cst_bias = constant dense<[23.0, 29.0, 31.0]> : tensor<3xf32>
+
+ %0 = "tfl.fully_connected" (%cst_input, %cst_weights, %cst_bias) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<2xf32>, tensor<3x2xf32>, tensor<3xf32>) -> tensor<3xf32>
+ return %0 : tensor<3xf32>
+
+ // [54, 90, 122]
+ // CHECK: %[[CST:.*]] = constant dense<[5.400000e+01, 9.000000e+01, 1.220000e+02]> : tensor<3xf32>
+ // CHECK: return %[[CST]]
+}
+
+// CHECK-LABEL: @ConstantFoldFullyConnectedLarge
+func @ConstantFoldFullyConnectedLarge() -> tensor<1024xf32> {
+ %cst_input= constant dense<1.0> : tensor<512xf32>
+ %cst_weights = constant dense<2.0> : tensor<1024x512xf32>
+ %cst_bias = constant dense<4.0> : tensor<1024xf32>
+
+ %0 = "tfl.fully_connected" (%cst_input, %cst_weights, %cst_bias) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<512xf32>, tensor<1024x512xf32>, tensor<1024xf32>) -> tensor<1024xf32>
+
+ return %0 : tensor<1024xf32>
+
+ // 1.0 * 2.0 * 512 + 4.0 = 1028.0
+ // CHECK: %[[CST:.*]] = constant dense<1.028000e+03> : tensor<1024xf32>
+ // CHECK: return %[[CST]]
+}
+
+// CHECK-LABEL: @ConstantFoldFullyConnectedNoBias
+func @ConstantFoldFullyConnectedNoBias() -> tensor<1024xf32> {
+ %cst_input= constant dense<1.0> : tensor<512xf32>
+ %cst_weights = constant dense<2.0> : tensor<1024x512xf32>
+ %cst_bias = constant unit
+
+ %0 = "tfl.fully_connected" (%cst_input, %cst_weights, %cst_bias) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<512xf32>, tensor<1024x512xf32>, none) -> tensor<1024xf32>
+
+ return %0 : tensor<1024xf32>
+
+ // 1.0 * 2.0 * 512 = 1024.0
+ // CHECK: %[[CST:.*]] = constant dense<1.024000e+03> : tensor<1024xf32>
+ // CHECK: return %[[CST]]
+}
+
+// CHECK-LABEL: @NoFoldFullyConnectedNonFloat
+func @NoFoldFullyConnectedNonFloat() -> tensor<1024xf32> {
+ %cst_input= constant dense<1.0> : tensor<512xf32>
+ %cst_weights = constant dense<2> : tensor<1024x512xi8>
+ %cst_bias = constant dense<4.0> : tensor<1024xf32>
+
+ %0 = "tfl.fully_connected" (%cst_input, %cst_weights, %cst_bias) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<512xf32>, tensor<1024x512xi8>, tensor<1024xf32>) -> tensor<1024xf32>
+
+ return %0 : tensor<1024xf32>
+ // CHECK: %[[CST:.*]] = constant dense<1.000000e+00> : tensor<512xf32>
+ // CHECK: %[[CST_0:.*]] = constant dense<2> : tensor<1024x512xi8>
+ // CHECK: %[[CST_1:.*]] = constant dense<4.000000e+00> : tensor<1024xf32>
+ // CHECK: %[[VAL:.*]] = "tfl.fully_connected"(%[[CST]], %[[CST_0]], %[[CST_1]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<512xf32>, tensor<1024x512xi8>, tensor<1024xf32>) -> tensor<1024xf32>
+ // CHECK: return %[[VAL]] : tensor<1024xf32>
+}
+
+// CHECK-LABEL: @NoFoldFullyConnectedHighRank
+func @NoFoldFullyConnectedHighRank() -> tensor<2x1024xf32> {
+ %cst_input= constant dense<1.0> : tensor<2x512xf32>
+ %cst_weights = constant dense<2.0> : tensor<1024x512xf32>
+ %cst_bias = constant dense<4.0> : tensor<1024xf32>
+
+ %0 = "tfl.fully_connected" (%cst_input, %cst_weights, %cst_bias) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<2x512xf32>, tensor<1024x512xf32>, tensor<1024xf32>) -> tensor<2x1024xf32>
+
+ return %0 : tensor<2x1024xf32>
+ // CHECK: %[[CST:.*]] = constant dense<1.000000e+00> : tensor<2x512xf32>
+ // CHECK: %[[CST_0:.*]] = constant dense<2.000000e+00> : tensor<1024x512xf32>
+ // CHECK: %[[CST_1:.*]] = constant dense<4.000000e+00> : tensor<1024xf32>
+ // CHECK: %[[VAL:.*]] = "tfl.fully_connected"(%[[CST]], %[[CST_0]], %[[CST_1]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<2x512xf32>, tensor<1024x512xf32>, tensor<1024xf32>) -> tensor<2x1024xf32>
+ // CHECK: return %[[VAL]] : tensor<2x1024xf32>
+}
+
+// CHECK-LABEL: @ConstantFoldFullyConnectedCheckPrecision
+func @ConstantFoldFullyConnectedCheckPrecision() -> tensor<1xf32> {
+ %cst_input= constant dense<1.0> : tensor<4xf32>
+ %cst_weights = constant dense<[[1.0, 1.0e38, 1.0, -1.0e38]]> : tensor<1x4xf32>
+ %cst_bias = constant dense<0.0> : tensor<1xf32>
+
+ %0 = "tfl.fully_connected" (%cst_input, %cst_weights, %cst_bias) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4xf32>, tensor<1x4xf32>, tensor<1xf32>) -> tensor<1xf32>
+
+ return %0 : tensor<1xf32>
+ // CHECK: %[[CST:.*]] = constant dense<2.000000e+00> : tensor<1xf32>
+ // CHECK: return %[[CST]]
+}
diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/hashtable_resource.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/hashtable_resource.mlir
index 2d5852d..4781040 100644
--- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/hashtable_resource.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/hashtable_resource.mlir
@@ -10,7 +10,7 @@
// CHECK: subgraphs: [ {
// CHECK: tensors: [ {
// CHECK: shape: [ ],
-// CHECK: type: INT32,
+// CHECK: type: RESOURCE,
// CHECK: buffer: 1,
// CHECK: name: "tf.HashTableV2",
// CHECK: quantization: {
diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/variant_type_on_func.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/variant_type_on_func.mlir
index 79f7d43..045ddfd 100644
--- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/variant_type_on_func.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/variant_type_on_func.mlir
@@ -1,6 +1,37 @@
-// RUN: not flatbuffer_translate -mlir-to-tflite-flatbuffer %s 2>&1 | FileCheck %s
+// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
-// CHECK: function argument uses variant type. Currently, the variant type is not natively supported in TFLite. Please consider not using the variant type: 'tensor<!tf.variant<tensor<2xi32>>>'
+// CHECK: {
+// CHECK-NEXT: version: 3,
+// CHECK-NEXT: operator_codes: [ ],
+// CHECK-NEXT: subgraphs: [ {
+// CHECK-NEXT: tensors: [ {
+// CHECK-NEXT: shape: [ ],
+// CHECK-NEXT: type: VARIANT,
+// CHECK-NEXT: buffer: 1,
+// CHECK-NEXT: name: "arg0",
+// CHECK-NEXT: quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT: }
+// CHECK-NEXT: } ],
+// CHECK-NEXT: inputs: [ 0 ],
+// CHECK-NEXT: outputs: [ 0 ],
+// CHECK-NEXT: operators: [ ],
+// CHECK-NEXT: name: "main"
+// CHECK-NEXT: } ],
+// CHECK-NEXT: description: "MLIR Converted.",
+// CHECK-NEXT: buffers: [ {
+// CHECK-EMPTY:
+// CHECK-NEXT: }, {
+// CHECK-EMPTY:
+// CHECK-NEXT: }, {
+// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
+// CHECK-NEXT: } ],
+// CHECK-NEXT: metadata: [ {
+// CHECK-NEXT: name: "min_runtime_version",
+// CHECK-NEXT: buffer: 2
+// CHECK-NEXT: } ],
+// CHECK-NEXT: signature_defs: [ ]
+// CHECK-NEXT: }
func @main(%arg0 : tensor<!tf.variant<tensor<2xi32>>>) -> tensor<!tf.variant<tensor<2xi32>>> {
return %arg0 : tensor<!tf.variant<tensor<2xi32>>>
}
diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/variant_type_on_op.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/variant_type_on_op.mlir
index ab4044e..9ac0739 100644
--- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/variant_type_on_op.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/variant_type_on_op.mlir
@@ -1,6 +1,37 @@
-// RUN: not flatbuffer_translate -mlir-to-tflite-flatbuffer %s 2>&1 | FileCheck %s
+// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
-// CHECK: error: operand result uses variant type. Currently, the variant type is not natively supported in TFLite. Please consider not using the variant type: 'tensor<!tf.variant<tensor<2xi32>>>'
+// CHECK: {
+// CHECK-NEXT: version: 3,
+// CHECK-NEXT: operator_codes: [ ],
+// CHECK-NEXT: subgraphs: [ {
+// CHECK-NEXT: tensors: [ {
+// CHECK-NEXT: shape: [ ],
+// CHECK-NEXT: type: VARIANT,
+// CHECK-NEXT: buffer: 1,
+// CHECK-NEXT: name: "tf.Const",
+// CHECK-NEXT: quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT: }
+// CHECK-NEXT: } ],
+// CHECK-NEXT: inputs: [ ],
+// CHECK-NEXT: outputs: [ 0 ],
+// CHECK-NEXT: operators: [ ],
+// CHECK-NEXT: name: "main"
+// CHECK-NEXT: } ],
+// CHECK-NEXT: description: "MLIR Converted.",
+// CHECK-NEXT: buffers: [ {
+// CHECK-EMPTY:
+// CHECK-NEXT: }, {
+// CHECK-NEXT: data: [ 128, 0, 0, 0, 128, 0, 0, 0 ]
+// CHECK-NEXT: }, {
+// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
+// CHECK-NEXT: } ],
+// CHECK-NEXT: metadata: [ {
+// CHECK-NEXT: name: "min_runtime_version",
+// CHECK-NEXT: buffer: 2
+// CHECK-NEXT: } ],
+// CHECK-NEXT: signature_defs: [ ]
+// CHECK-NEXT: }
func @main() -> tensor<!tf.variant<tensor<2xi32>>> {
%0 = "tf.Const"() {device = "", name = "", dtype = "tfdtype$DT_INT32", value = opaque<"tf", "0x746674656E736F722464747970653A2044545F494E5433320A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20320A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3230305C3030305C3030305C3030305C3230305C3030305C3030305C303030220A"> : tensor<!tf.variant>} : () -> tensor<!tf.variant<tensor<2xi32>>>
return %0 : tensor<!tf.variant<tensor<2xi32>>>
diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
index 4c725d0..43bafaa 100644
--- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
@@ -1536,3 +1536,131 @@
// CHECK: %[[ADD:.*]] = tfl.add %[[FC_RESULT]], %[[CST2]] {fused_activation_function = "NONE"} : tensor<1x1x1x1x1xf32>
// CHECK: return %[[ADD]] : tensor<1x1x1x1x1xf32>
}
+
+// CHECK-LABEL: ConvertMul1ToIdentity
+func @ConvertMul1ToIdentity(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> {
+ %cst = constant dense<1.0> : tensor<1x2x3x4xf32>
+ %0 = "tfl.mul"(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
+ return %0 : tensor<1x2x3x4xf32>
+ // CHECK: return %arg0
+}
+
+// CHECK-LABEL: DontConvertMul12ToIdentity
+func @DontConvertMul12ToIdentity(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+ %cst = constant dense<[1.0, 2.0]> : tensor<2xf32>
+ %0 = "tfl.mul"(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
+ return %0 : tensor<2xf32>
+ // CHECK: %cst = constant dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32>
+ // CHECK: %0 = tfl.mul %arg0, %cst {fused_activation_function = "NONE"} : tensor<2xf32>
+ // CHECK: return %0 : tensor<2xf32>
+}
+
+// CHECK-LABEL: DontConvertMul1WithBroadcastToIdentity
+func @DontConvertMul1WithBroadcastToIdentity(%arg0: tensor<2xf32>) -> tensor<2x2xf32> {
+ %cst = constant dense<1.0> : tensor<2x2xf32>
+ %0 = "tfl.mul"(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+ return %0 : tensor<2x2xf32>
+ // CHECK: %cst = constant dense<1.000000e+00> : tensor<2x2xf32>
+ // CHECK: %0 = "tfl.mul"(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+ // CHECK: return %0 : tensor<2x2xf32>
+}
+
+// CHECK-LABEL: ConvertConstSelectToIdentity
+func @ConvertConstSelectToIdentity(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<1x2x3x4xf32>) -> (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) {
+ %cst_true = constant dense<true> : tensor<1x2x3x4xi1>
+ %cst_false = constant dense<false> : tensor<1x2x3x4xi1>
+ %0 = "tfl.select"(%cst_true, %arg0, %arg1) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
+ %1 = "tfl.select_v2"(%cst_true, %arg0, %arg1) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
+ %2 = "tfl.select"(%cst_false, %arg0, %arg1) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
+ %3 = "tfl.select_v2"(%cst_false, %arg0, %arg1) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
+ return %0, %1, %2, %3 : tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>
+ // CHECK: return %arg0, %arg0, %arg1, %arg1
+}
+
+// CHECK-LABEL: DontConvertConstSelectBroadcast
+func @DontConvertConstSelectBroadcast(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2x3xf32> {
+ %cst = constant dense<false> : tensor<2x3xi1>
+ %0 = "tfl.select"(%cst, %arg0, %arg1) : (tensor<2x3xi1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2x3xf32>
+ return %0 : tensor<2x3xf32>
+ // CHECK: %0 = "tfl.select"(%cst, %arg0, %arg1) : (tensor<2x3xi1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2x3xf32>
+ // CHECK: return %0
+}
+
+// CHECK-LABEL: DontConvertConstSelectMixed
+func @DontConvertConstSelectMixed(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
+ %cst = constant dense<[false, true]> : tensor<2xi1>
+ %0 = "tfl.select"(%cst, %arg0, %arg1) : (tensor<2xi1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
+ %1 = "tfl.select_v2"(%cst, %arg0, %arg1) : (tensor<2xi1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
+ return %0, %1 : tensor<2xf32>, tensor<2xf32>
+ // CHECK: %0 = "tfl.select"(%cst, %arg0, %arg1) : (tensor<2xi1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
+ // CHECK: %1 = "tfl.select_v2"(%cst, %arg0, %arg1) : (tensor<2xi1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
+ // CHECK: return %0, %1
+}
+
+// CHECK-LABEL: RemoveSoftmaxBeforeArgmax
+func @RemoveSoftmaxBeforeArgmax(%arg0: tensor<16x1024xf32>) -> tensor<16xi32> {
+ %cst = constant dense<-1> : tensor<1xi32>
+ %0 = "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<16x1024xf32>) -> tensor<16x1024xf32>
+ %1 = "tfl.arg_max"(%0, %cst) : (tensor<16x1024xf32>, tensor<1xi32>) -> tensor<16xi32>
+ return %1 : tensor<16xi32>
+ // CHECK: %[[CST:.*]] = constant dense<-1> : tensor<1xi32>
+ // CHECK: %[[ARG_MAX:.*]] = "tfl.arg_max"(%arg0, %[[CST]]) : (tensor<16x1024xf32>, tensor<1xi32>) -> tensor<16xi32>
+ // CHECK: return %[[ARG_MAX]] : tensor<16xi32>
+}
+
+// CHECK-LABEL: RemoveSoftmaxBeforeArgmin
+func @RemoveSoftmaxBeforeArgmin(%arg0: tensor<16x1024xf32>) -> tensor<16xi32> {
+ %cst = constant dense<-1> : tensor<1xi32>
+ %0 = "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<16x1024xf32>) -> tensor<16x1024xf32>
+ %1 = "tfl.arg_min"(%0, %cst) : (tensor<16x1024xf32>, tensor<1xi32>) -> tensor<16xi32>
+ return %1 : tensor<16xi32>
+ // CHECK: %[[CST:.*]] = constant dense<-1> : tensor<1xi32>
+ // CHECK: %[[ARG_MIN:.*]] = "tfl.arg_min"(%arg0, %[[CST]]) : (tensor<16x1024xf32>, tensor<1xi32>) -> tensor<16xi32>
+ // CHECK: return %[[ARG_MIN]] : tensor<16xi32>
+}
+
+// CHECK-LABEL: RemoveLogSoftmaxBeforeArgmax
+func @RemoveLogSoftmaxBeforeArgmax(%arg0: tensor<16x1024xf32>) -> tensor<16xi32> {
+ %cst = constant dense<-1> : tensor<1xi32>
+ %0 = "tfl.log_softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<16x1024xf32>) -> tensor<16x1024xf32>
+ %1 = "tfl.arg_max"(%0, %cst) : (tensor<16x1024xf32>, tensor<1xi32>) -> tensor<16xi32>
+ return %1 : tensor<16xi32>
+ // CHECK: %[[CST:.*]] = constant dense<-1> : tensor<1xi32>
+ // CHECK: %[[ARG_MAX:.*]] = "tfl.arg_max"(%arg0, %[[CST]]) : (tensor<16x1024xf32>, tensor<1xi32>) -> tensor<16xi32>
+ // CHECK: return %[[ARG_MAX]] : tensor<16xi32>
+}
+
+// CHECK-LABEL: RemoveLogSoftmaxBeforeArgmin
+func @RemoveLogSoftmaxBeforeArgmin(%arg0: tensor<16x1024xf32>) -> tensor<16xi32> {
+ %cst = constant dense<-1> : tensor<1xi32>
+ %0 = "tfl.log_softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<16x1024xf32>) -> tensor<16x1024xf32>
+ %1 = "tfl.arg_min"(%0, %cst) : (tensor<16x1024xf32>, tensor<1xi32>) -> tensor<16xi32>
+ return %1 : tensor<16xi32>
+ // CHECK: %[[CST:.*]] = constant dense<-1> : tensor<1xi32>
+ // CHECK: %[[ARG_MIN:.*]] = "tfl.arg_min"(%arg0, %[[CST]]) : (tensor<16x1024xf32>, tensor<1xi32>) -> tensor<16xi32>
+ // CHECK: return %[[ARG_MIN]] : tensor<16xi32>
+}
+
+// CHECK-LABEL: DontRemoveSoftmaxNegativeBetaBeforeArgmax
+func @DontRemoveSoftmaxNegativeBetaBeforeArgmax(%arg0: tensor<16x1024xf32>) -> tensor<16xi32> {
+ %cst = constant dense<-1> : tensor<1xi32>
+ %0 = "tfl.softmax"(%arg0) {beta = -1.000000e+00 : f32} : (tensor<16x1024xf32>) -> tensor<16x1024xf32>
+ %1 = "tfl.arg_max"(%0, %cst) : (tensor<16x1024xf32>, tensor<1xi32>) -> tensor<16xi32>
+ return %1 : tensor<16xi32>
+ // CHECK: %[[CST:.*]] = constant dense<-1> : tensor<1xi32>
+ // CHECK: %[[SOFTMAX:.*]] = "tfl.softmax"(%arg0) {beta = -1.000000e+00 : f32} : (tensor<16x1024xf32>) -> tensor<16x1024xf32>
+ // CHECK: %[[ARG_MAX:.*]] = "tfl.arg_max"(%[[SOFTMAX]], %[[CST]]) : (tensor<16x1024xf32>, tensor<1xi32>) -> tensor<16xi32>
+ // CHECK: return %[[ARG_MAX]] : tensor<16xi32>
+}
+
+// CHECK-LABEL: DontRemoveSoftmaxNonLastAxisBeforeArgmax
+func @DontRemoveSoftmaxNonLastAxisBeforeArgmax(%arg0: tensor<16x1024xf32>) -> tensor<16xi32> {
+ %cst = constant dense<0> : tensor<1xi32>
+ %0 = "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<16x1024xf32>) -> tensor<16x1024xf32>
+ %1 = "tfl.arg_max"(%0, %cst) : (tensor<16x1024xf32>, tensor<1xi32>) -> tensor<16xi32>
+ return %1 : tensor<16xi32>
+ // CHECK: %[[CST:.*]] = constant dense<0> : tensor<1xi32>
+ // CHECK: %[[SOFTMAX:.*]] = "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<16x1024xf32>) -> tensor<16x1024xf32>
+ // CHECK: %[[ARG_MAX:.*]] = "tfl.arg_max"(%[[SOFTMAX]], %[[CST]]) : (tensor<16x1024xf32>, tensor<1xi32>) -> tensor<16xi32>
+ // CHECK: return %[[ARG_MAX]] : tensor<16xi32>
+}
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
index 9e0a880..06a92f2 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
@@ -66,15 +66,19 @@
^bb0(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>):
// OK
%0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", U = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
- // Unsupported training
- %1:6 = "tf.FusedBatchNormV3"( %0#0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", U = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
+ // Training with non-broadcastable shape
+ %cst = constant dense<0.0> : tensor<4xf32>
+ %1:6 = "tf.FusedBatchNormV3"( %0#0, %cst, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", U = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<4xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
+ // Inference with non-broadcastable shape
+ %2:6 = "tf.FusedBatchNormV3"( %1#0, %cst, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", U = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<4xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
// Use other output
- %2:6 = "tf.FusedBatchNormV3"( %1#0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", U = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
+ %3:6 = "tf.FusedBatchNormV3"( %2#0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", U = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
- return %2, %2#1 : tensor<8x8x8x8xf32>, tensor<8xf32>
+ return %3, %3#1 : tensor<8x8x8x8xf32>, tensor<8xf32>
// CHECK-LABEL: fusedBatchNormV3
// CHECK: %[[CONSTANT:.*]] = constant dense<1.000000e-03>
+// CHECK: %[[CONSTANT1:.*]] = constant dense<0.000000e+00> : tensor<4xf32>
// variance + epsilon
// CHECK: %[[ADD1:.*]] = "tf.Add"(%[[ARG4:.*]], %[[CONSTANT]])
// rsqrt(variance + epsilon)
@@ -90,11 +94,12 @@
// x * scale * rsqrt(variance + epsilon) +
// offset - mean * scale * rsqrt(variance + epsilon)
// CHECK: %[[ADD2:.*]] = "tf.Add"(%[[MUL2]], %[[SUB]])
-
-// CHECK: %[[BATCHNORM1_a:[^,]+]], {{.*}} = "tf.FusedBatchNormV3"(%[[ADD2]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
-// CHECK: "tf.FusedBatchNormV3"(%[[BATCHNORM1_a]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
+// CHECK: %[[BATCHNORM1_a:[^,]+]], {{.*}} = "tf.FusedBatchNormV3"(%[[ADD2]], %[[CONSTANT1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
+// CHECK: %[[BATCHNORM1_b:[^,]+]], {{.*}} = "tf.FusedBatchNormV3"(%[[BATCHNORM1_a]], %[[CONSTANT1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
+// CHECK: "tf.FusedBatchNormV3"(%[[BATCHNORM1_b]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
}
+
func @batchNormWithGlobalNormalization(
%t:tensor<1x10x10x3xf32>, %m:tensor<3xf32>, %v:tensor<3xf32>, %beta:tensor<3xf32>, %gamma:tensor<3xf32>) -> (tensor<1x10x10x3xf32>) {
%0 = "tf.BatchNormWithGlobalNormalization"(%t, %m, %v, %beta, %gamma) {T = "tfdtype$DT_FLOAT", variance_epsilon = 0.001 : f32, scale_after_normalization = false} : (tensor<1x10x10x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<1x10x10x3xf32>)
@@ -779,4 +784,25 @@
// CHECK: "tf.StridedSlice"
}
+func @fused_batch_norm_v3_training(%arg0 : tensor<1x1x6x2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>, %arg4 : tensor<2xf32>) -> tensor<1x1x6x2xf32> {
+ %0, %1, %2, %3, %4, %5 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {data_format = "NHWC", epsilon = 1.000000e-03 : f32, exponential_avg_factor = 1.000000e+00 : f32, is_training = true} : (tensor<1x1x6x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> (tensor<1x1x6x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<*xf32>)
+ return %0 : tensor<1x1x6x2xf32>
+ // CHECK-LABEL: fused_batch_norm_v3_training
+ // CHECK: %[[CST:.*]] = constant dense<[0, 1, 2]> : tensor<3xi64>
+ // CHECK: %[[CST0:.*]] = constant dense<0.166666672> : tensor<1xf32>
+ // CHECK: %[[CST1:.*]] = constant dense<1.000000e-03> : tensor<f32>
+ // CHECK: %[[SUM:.*]] = "tf.Sum"(%arg0, %[[CST]]) {keep_dims = false} : (tensor<1x1x6x2xf32>, tensor<3xi64>) -> tensor<2xf32>
+ // CHECK: %[[MUL:.*]] = "tf.Mul"(%[[SUM]], %[[CST0]]) : (tensor<2xf32>, tensor<1xf32>) -> tensor<2xf32>
+ // CHECK: %[[SQ:.*]] = "tf.SquaredDifference"(%arg0, %[[MUL]]) : (tensor<1x1x6x2xf32>, tensor<2xf32>) -> tensor<1x1x6x2xf32>
+ // CHECK: %[[SUM0:.*]] = "tf.Sum"(%[[SQ]], %[[CST]]) {keep_dims = false} : (tensor<1x1x6x2xf32>, tensor<3xi64>) -> tensor<2xf32>
+ // CHECK: %[[MUL0:.*]] = "tf.Mul"(%[[SUM0]], %[[CST0]]) : (tensor<2xf32>, tensor<1xf32>) -> tensor<2xf32>
+ // CHECK: %[[ADD:.*]] = "tf.Add"(%[[MUL0]], %[[CST1]]) : (tensor<2xf32>, tensor<f32>) -> tensor<2xf32>
+ // CHECK: %[[RSQRT:.*]] = "tf.Rsqrt"(%[[ADD]]) : (tensor<2xf32>) -> tensor<2xf32>
+ // CHECK: %[[MUL1:.*]] = "tf.Mul"(%arg1, %[[RSQRT]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
+ // CHECK: %[[MUL2:.*]] = "tf.Mul"(%arg0, %[[MUL1]]) : (tensor<1x1x6x2xf32>, tensor<2xf32>) -> tensor<1x1x6x2xf32>
+ // CHECK: %[[MUL3:.*]] = "tf.Mul"(%[[MUL]], %[[MUL1]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
+ // CHECK: %[[SUB:.*]] = "tf.Sub"(%arg2, %[[MUL3]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
+ // CHECK: %[[ADD0:.*]] = "tf.Add"(%[[MUL2]], %[[SUB]]) : (tensor<1x1x6x2xf32>, tensor<2xf32>) -> tensor<1x1x6x2xf32>
+ // CHECK: return %[[ADD0]] : tensor<1x1x6x2xf32>
+}
}
diff --git a/tensorflow/compiler/mlir/lite/tests/tfl_while_outline.mlir b/tensorflow/compiler/mlir/lite/tests/tfl_while_outline.mlir
index 0d7c21d..d6ca4c3 100644
--- a/tensorflow/compiler/mlir/lite/tests/tfl_while_outline.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/tfl_while_outline.mlir
@@ -190,3 +190,33 @@
// CHECK: (tensor<i32>, tensor<1xf32>, tensor<i32>) -> (tensor<i32>, tensor<?xf32>, tensor<i32>)
return %0#1 : tensor<?xf32>
}
+
+// -----
+
+func @unsupportedCast(%arg0: tensor<4x4x3xf32>) -> tensor<*xf32> {
+ %cst = constant dense<0.000000e+00> : tensor<4x2xf32>
+ %cst_0 = constant dense<0.000000e+00> : tensor<4x4x3xf64>
+ %cst_1 = constant dense<[1, 0, 2]> : tensor<3xi32>
+ %cst_2 = constant dense<0.000000e+00> : tensor<4x4x2xf32>
+ %cst_3 = constant dense<4> : tensor<i32>
+ %cst_4 = constant dense<0> : tensor<i32>
+ %cst_5 = constant dense<0.000000e+00> : tensor<4x2xf64>
+ %0 = "tfl.transpose"(%arg0, %cst_1) : (tensor<4x4x3xf32>, tensor<3xi32>) -> tensor<4x4x3xf32>
+ %1:6 = "tfl.while"(%cst_4, %cst_4, %cst_2, %cst, %cst_5, %cst_0) ( {
+ ^bb0(%arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<*xf32>, %arg4: tensor<4x2xf32>, %arg5: tensor<4x2xf64>, %arg6: tensor<*xf64>): // no predecessors
+ %5 = "tfl.less"(%arg2, %cst_3) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ %6 = "tfl.less"(%arg1, %cst_3) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ %7 = tfl.logical_and %6, %5 : tensor<i1>
+ "tfl.yield"(%7) : (tensor<i1>) -> ()
+ }, {
+ ^bb0(%arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<*xf32>, %arg4: tensor<4x2xf32>, %arg5: tensor<4x2xf64>, %arg6: tensor<*xf64>): // no predecessors
+ "tfl.yield"(%arg1, %arg2, %arg3, %arg4, %arg5, %cst_0) : (tensor<i32>, tensor<i32>, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf64>, tensor<4x4x3xf64>) -> ()
+ }) {is_stateless = true} : (tensor<i32>, tensor<i32>, tensor<4x4x2xf32>, tensor<4x2xf32>, tensor<4x2xf64>, tensor<4x4x3xf64>) -> (tensor<i32>, tensor<i32>, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf64>, tensor<*xf32>)
+ return %1#2 : tensor<*xf32>
+}
+
+// CHECK-LABEL: func @unsupportedCast(
+
+// CHECK-LABEL: func private @tfl.while_body(
+// CHECK-SAME: %arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<*xf32>, %arg3: tensor<4x2xf32>, %arg4: tensor<4x2xf64>, %arg5: tensor<*xf64>) -> (tensor<i32>, tensor<i32>, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf64>, tensor<*xf64>)
+// CHECK: [[VAL:%.*]] = "tf.Cast"
diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc
index e131266..a9a192b 100644
--- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc
+++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc
@@ -172,20 +172,29 @@
// Write MLIR TFLite dialect into FlatBuffer
OpOrArgLocNameMapper op_or_arg_name_mapper;
if (!quant_specs.RunWeightQuantization()) {
- if (tflite::MlirToFlatBufferTranslateFunction(
- module, result, emit_builtin_tflite_ops, emit_select_tf_ops,
- emit_custom_ops, select_user_tf_ops, saved_model_tags,
- &op_or_arg_name_mapper)) {
+ tflite::FlatbufferExportOptions options;
+ options.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
+ options.emit_select_tf_ops = emit_select_tf_ops;
+ options.select_user_tf_ops = select_user_tf_ops;
+ options.emit_custom_ops = emit_custom_ops;
+ options.saved_model_tags = saved_model_tags;
+ options.op_or_arg_name_mapper = &op_or_arg_name_mapper;
+ if (!tflite::MlirToFlatBufferTranslateFunction(module, options, result)) {
return statusHandler.ConsumeStatus();
}
} else {
// Post-training weight quantization path. Once MLIR has support for this,
// we can remove this else statement.
std::string pre_quantized_result;
- if (tflite::MlirToFlatBufferTranslateFunction(
- module, &pre_quantized_result, emit_builtin_tflite_ops,
- emit_select_tf_ops, emit_custom_ops, select_user_tf_ops,
- saved_model_tags, &op_or_arg_name_mapper)) {
+ tflite::FlatbufferExportOptions options;
+ options.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
+ options.emit_select_tf_ops = emit_select_tf_ops;
+ options.select_user_tf_ops = select_user_tf_ops;
+ options.emit_custom_ops = emit_custom_ops;
+ options.saved_model_tags = saved_model_tags;
+ options.op_or_arg_name_mapper = &op_or_arg_name_mapper;
+ if (!tflite::MlirToFlatBufferTranslateFunction(module, options,
+ &pre_quantized_result)) {
return statusHandler.ConsumeStatus();
}
flatbuffers::FlatBufferBuilder q_builder(/*initial_size=*/10240);
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
index c12b0dd..cce5801 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
@@ -367,8 +367,6 @@
def IsTailOfShape : Constraint<CPred<
"TFL::IsTailOfShape($0.getType(), $1.getType())">>;
-def HaveSameType : Constraint<CPred<"$0.getType(), $1.getType()">>;
-
// Pattern for skipping Tile if it is mainly for broadcasting and the
// Op is already supporting broadcasting.
multiclass FuseTileBroadcastIntoFollowingBinary<dag BinaryOp> {
@@ -668,3 +666,69 @@
(AxesIsLastDimension $axes, $max_input),
(HasOneUse $sub),
(HasOneUse $max)]>;
+
+def HaveSameType : Constraint<CPred<"($0.getType() == $1.getType())">>;
+
+class AllElementsAreF32<string val> : Constraint<CPred<
+ "($0.isa<DenseElementsAttr>() && "
+ "$0.cast<DenseElementsAttr>().getType().cast<ShapedType>().getElementType().isF32() && "
+ "std::all_of($0.cast<DenseElementsAttr>().getValues<float>().begin(), "
+ "$0.cast<DenseElementsAttr>().getValues<float>().end(), "
+ "[](float v){ return v == " #val# ";}))">>;
+
+// Optimize X*1 to X
+def OptimizeMul1ToIdentity : Pat<
+ (TFL_MulOp $input,
+ (ConstantOp $constant),
+ TFL_AF_None),
+ (replaceWithValue $input),
+ [(HaveSameType $input, $constant),
+ (AllElementsAreF32<"1.0f"> $constant)]>;
+
+class AllElementsAreBool<string val> : Constraint<CPred<
+ "($0.isa<DenseElementsAttr>() && "
+ "$0.cast<DenseElementsAttr>().getType().cast<ShapedType>().getElementType().isInteger(1) && "
+ "std::all_of($0.cast<DenseElementsAttr>().getValues<bool>().begin(), "
+ "$0.cast<DenseElementsAttr>().getValues<bool>().end(), "
+ "[](bool v){ return v == " #val# ";}))">>;
+
+// Remove select operators when the result is known in advance.
+foreach SelectOp = [TFL_SelectOp, TFL_SelectV2Op] in {
+ // select(true_tensor, A, B) -> A
+ def Optimize#SelectOp#True : Pat<
+ (SelectOp (ConstantOp $constant),
+ $input1,
+ $input2),
+ (replaceWithValue $input1),
+ [(HaveSameType $input1, $input2),
+ (IsTailOfShape $input1, $constant),
+ (IsTailOfShape $constant, $input1),
+ (AllElementsAreBool<"true"> $constant)]>;
+ // select(false_tensor, A, B) -> B
+ def Optimize#SelectOp#False : Pat<
+ (SelectOp (ConstantOp $constant),
+ $input1,
+ $input2),
+ (replaceWithValue $input2),
+ [(HaveSameType $input1, $input2),
+ (IsTailOfShape $input1, $constant),
+ (IsTailOfShape $constant, $input1),
+ (AllElementsAreBool<"false"> $constant)]>;
+}
+
+// Remove (log-)softmax before arg-minmax as (log-)softmax is monotonic.
+foreach ArgMinMaxOp = [TFL_ArgMinOp, TFL_ArgMaxOp] in {
+ def RemoveSoftmaxOpBefore#ArgMinMaxOp : Pat<
+ (ArgMinMaxOp (TFL_SoftmaxOp:$softmax $logits, TFL_FloatNonNegative:$beta),
+ (ConstantOp:$const_axes I32ElementsAttr:$axes)),
+ (ArgMinMaxOp $logits, $const_axes),
+ [(HasOneUse $softmax),
+ (AxesIsLastDimension $axes, $logits)]>;
+
+ def RemoveLogSoftmaxOpBefore#ArgMinMaxOp : Pat<
+ (ArgMinMaxOp (TFL_LogSoftmaxOp:$log_softmax $logits),
+ (ConstantOp:$const_axes I32ElementsAttr:$axes)),
+ (ArgMinMaxOp $logits, $const_axes),
+ [(HasOneUse $log_softmax),
+ (AxesIsLastDimension $axes, $logits)]>;
+}
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
index 5271eba..0be7051 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
@@ -926,6 +926,55 @@
// [(HasNoUseOf:$root__1), (HasNoUseOf:$root__2),
// (HasNoUseOf:$root__3), (HasNoUseOf:$root__4),
// (HasNoUseOf:$root__5), (AreBroadcastableTypes $multiplier, $x)]>;
+//
+// When is_training is set to true, the given variance and mean are not used.
+// In above calculation, they are replaced by new values. These new mean and
+// variance are calculated as following:
+// rest_size = shape(x)[0] * shape(x)[1] * shape(x)[2]
+// new_mean = sum(x, axis=[0, 1, 2]) / rest_size
+// new_variance = sum(squared_difference(x, new_mean), axis=[0, 1, 2])
+// / rest_size
+//
+// The DDR rule for the is_training equals true case is as following:
+// def : Pattern<
+// (TF_FusedBatchNormV3Op:$root
+// $x, $scale, $offset, $mean, $variance,
+// F32Attr:$epsilon, $exponential_avg_factor,
+// $data_format, FalseBoolAttr:$is_training),
+// [(TF_AddOp
+// (TF_MulOp
+// $x,
+// (TF_MulOp:$multiplier
+// $scale,
+// (TF_RsqrtOp
+// (TF_AddOp
+// (TF_DivOp:$new_variance
+// (TF_SumOp
+// (TF_SquaredDifferenceOp $x, $new_mean),
+// (TF_ConstOp [0,1,2])),
+// $rest_size),
+// (TF_ConstOp $epsilon))))),
+// (TF_SubOp
+// $offset,
+// (TF_MulOp
+// (TF_DivOp:$new_mean
+// (TF_SumOp $x, (TF_ConstOp [0,1,2])),
+// (TF_ProdOp:$rest_size
+// (TF_SliceOp
+// (TF_ShapeOp $x),
+// (TF_ConstOp 0),
+// (TF_ConstOp 3)))),
+// $multiplier))),
+// // We already guaranteed that the last five results have no use so it does
+// // not matter what value we provide here for replacement.
+// /*batch_mean=*/(replaceWithValue $x),
+// /*batch_variance=*/(replaceWithValue $x),
+// /*reserve_space_1=*/(replaceWithValue $x),
+// /*reserve_space_2=*/(replaceWithValue $x),
+// /*reserve_space_3=*/(replaceWithValue $x)],
+// [(HasNoUseOf:$root__1), (HasNoUseOf:$root__2),
+// (HasNoUseOf:$root__3), (HasNoUseOf:$root__4),
+// (HasNoUseOf:$root__5), (AreBroadcastableTypes $multiplier, $x)]>;
struct FusedBatchNormV3Pat : public ::mlir::RewritePattern {
explicit FusedBatchNormV3Pat(::mlir::MLIRContext *context)
@@ -940,7 +989,6 @@
// Variables for capturing values and attributes used for creating ops
Operation::operand_range mean(fused_batch_norm->getOperands());
::mlir::FloatAttr exponential_avg_factor;
- ::mlir::StringAttr data_format;
::mlir::TF::FusedBatchNormV3Op root;
Operation::operand_range offset(fused_batch_norm->getOperands());
Operation::operand_range x(fused_batch_norm->getOperands());
@@ -959,6 +1007,9 @@
mean = fused_batch_norm_op.getODSOperands(3);
variance = fused_batch_norm_op.getODSOperands(4);
+ ::mlir::Value mean_value = (*mean.begin());
+ ::mlir::Value variance_value = (*variance.begin());
+
if (!TFTypeIsFloat32Tensor(fused_batch_norm_op.x())) return failure();
{
@@ -984,25 +1035,9 @@
exponential_avg_factor =
rewriter.getFloatAttr(rewriter.getF32Type(), 1.0f);
}
- {
- data_format =
- fused_batch_norm_op->getAttrOfType<::mlir::StringAttr>("data_format");
- if (!data_format) data_format = rewriter.getStringAttr("NHWC");
- }
- {
- is_training =
- fused_batch_norm_op->getAttrOfType<::mlir::BoolAttr>("is_training");
- if (!is_training) is_training = rewriter.getBoolAttr(true);
-
- if (!((!is_training.getValue()))) {
- return rewriter.notifyMatchFailure(
- fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
- diag << "op 'tf.FusedBatchNormV3' attribute 'is_training' failed "
- "to "
- "satisfy constraint: FalseBoolAttr";
- });
- }
- }
+ if (!TFDataFormatIsNHWC(fused_batch_norm_op) &&
+ !TFDataFormatIsNDHWC(fused_batch_norm_op))
+ return failure();
if (!(((*root.getODSResults(1).begin()).use_empty()))) {
return rewriter.notifyMatchFailure(
@@ -1038,8 +1073,140 @@
diag << "entities '' failed to satisfy constraint: has no use";
});
}
- // Rewrite
+
+ is_training =
+ fused_batch_norm_op->getAttrOfType<::mlir::BoolAttr>("is_training");
auto odsLoc = rewriter.getFusedLoc({fused_batch_norm->getLoc()});
+
+ // We need to make sure input and output shapes are compatible.
+ {
+ int64_t last_dim = -1;
+ auto is_last_dim_compatible = [](const Value &v, int64_t &last_dim) {
+ auto v_type = v.getType().dyn_cast_or_null<RankedTensorType>();
+ if (!v_type) return true;
+ int64_t v_last_dim = v_type.getDimSize(v_type.getRank() - 1);
+ if (v_last_dim == -1) return true;
+ if (last_dim != -1 && v_last_dim != last_dim) return false;
+ last_dim = v_last_dim;
+ return true;
+ };
+
+ if (!is_last_dim_compatible(*x.begin(), last_dim) ||
+ !is_last_dim_compatible(*scale.begin(), last_dim) ||
+ !is_last_dim_compatible(*offset.begin(), last_dim)) {
+ return rewriter.notifyMatchFailure(
+ fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
+ diag << "Shapes of scale and offset should be 1D and "
+ "compatible with x";
+ });
+ }
+
+ if (!is_training.getValue()) {
+ if (!is_last_dim_compatible(mean_value, last_dim) ||
+ !is_last_dim_compatible(variance_value, last_dim)) {
+ return rewriter.notifyMatchFailure(
+ fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
+ diag << "Shapes of mean and variance should be 1D and "
+ "compatible with x";
+ });
+ }
+ }
+
+ // Check if output shape and input shape are compatible.
+ auto x_type = (*x.begin()).getType();
+ auto y_type = (*root.getODSResults(0).begin()).getType();
+ if (!OpTrait::util::getBroadcastedType(x_type, y_type)) {
+ return rewriter.notifyMatchFailure(
+ fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
+ diag << "Shapes of x and the first output should be compatible";
+ });
+ }
+ }
+
+ // For training, mean and variance is calculated from input values.
+ if (is_training.getValue()) {
+ auto input_type = fused_batch_norm_op.x()
+ .getType()
+ .dyn_cast_or_null<RankedTensorType>();
+ if (!input_type || input_type.getRank() != 4 ||
+ !input_type.hasStaticShape()) {
+ return rewriter.notifyMatchFailure(
+ fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
+ diag << "op 'tf.FusedBatchNormV3' that has 'is_training' equals "
+ "True is only supported with static input shape";
+ });
+ }
+
+ ::mlir::TF::ConstOp reduce_dim_op;
+ {
+ auto reduce_dim_type =
+ ::mlir::RankedTensorType::get({3}, rewriter.getIntegerType(64));
+ ::mlir::SmallVector<int64_t, 3> reduce_dim_values = {0, 1, 2};
+ reduce_dim_op = rewriter.create<TF::ConstOp>(
+ odsLoc, ::mlir::DenseIntElementsAttr::get(reduce_dim_type,
+ reduce_dim_values));
+ }
+
+ ::mlir::TF::ConstOp rest_size_inv_op;
+ {
+ int64_t rest_size = input_type.getDimSize(0) *
+ input_type.getDimSize(1) * input_type.getDimSize(2);
+ auto rest_size_inv_type =
+ ::mlir::RankedTensorType::get({1}, rewriter.getF32Type());
+ auto rest_size_inv_attr = ::mlir::DenseFPElementsAttr::get(
+ rest_size_inv_type, {1.0f / rest_size});
+ rest_size_inv_op =
+ rewriter.create<::mlir::TF::ConstOp>(odsLoc, rest_size_inv_attr);
+ }
+
+ ::mlir::TF::SumOp sum_op_1;
+ {
+ ::mlir::Value x_value = (*x.begin());
+ sum_op_1 = rewriter.create<TF::SumOp>(
+ odsLoc, x_value, reduce_dim_op,
+ /*keep_dims=*/rewriter.getBoolAttr(false));
+ }
+
+ ::mlir::TF::MulOp mul_op_1;
+ {
+ ::mlir::Value tblgen_value_0 = (*sum_op_1.getODSResults(0).begin());
+ ::mlir::Value tblgen_value_1 =
+ (*rest_size_inv_op.getODSResults(0).begin());
+ mul_op_1 = rewriter.create<::mlir::TF::MulOp>(odsLoc, tblgen_value_0,
+ tblgen_value_1);
+ }
+
+ ::mlir::TF::SquaredDifferenceOp square_diff_op;
+ {
+ ::mlir::Value tblgen_value_0 = (*x.begin());
+ ::mlir::Value tblgen_value_1 = (*mul_op_1.getODSResults(0).begin());
+ // If x has shape of [b, h, w, c], the result of mul_op_1 will have
+ // shape of [c]. Therefore, their shapes are always compatible.
+ square_diff_op = rewriter.create<::mlir::TF::SquaredDifferenceOp>(
+ odsLoc, tblgen_value_0, tblgen_value_1);
+ }
+
+ ::mlir::TF::SumOp sum_op_2;
+ {
+ ::mlir::Value input_value = (*square_diff_op.getODSResults(0).begin());
+ sum_op_2 = rewriter.create<TF::SumOp>(
+ odsLoc, input_value, reduce_dim_op,
+ /*keep_dims=*/rewriter.getBoolAttr(false));
+ }
+
+ ::mlir::TF::MulOp mul_op_2;
+ {
+ ::mlir::Value tblgen_value_0 = (*sum_op_2.getODSResults(0).begin());
+ ::mlir::Value tblgen_value_1 =
+ (*rest_size_inv_op.getODSResults(0).begin());
+ mul_op_2 = rewriter.create<::mlir::TF::MulOp>(odsLoc, tblgen_value_0,
+ tblgen_value_1);
+ }
+
+ mean_value = (*mul_op_1.getODSResults(0).begin());
+ variance_value = (*mul_op_2.getODSResults(0).begin());
+ } // End is_training equals true if.
+
::llvm::SmallVector<::mlir::Value, 4> replace_values;
::mlir::TF::ConstOp epsilon_const_op;
{
@@ -1049,17 +1216,12 @@
}
::mlir::TF::AddOp add_op_1;
{
- ::mlir::Value tblgen_value_0 = (*variance.begin());
- ::mlir::Value tblgen_value_1 =
+ ::mlir::Value epsilon_value =
(*epsilon_const_op.getODSResults(0).begin());
+ // Multiplying with a constant, no need to check broadcastibility.
add_op_1 = rewriter.create<::mlir::TF::AddOp>(odsLoc,
- /*x=*/tblgen_value_0,
- /*y=*/tblgen_value_1);
- // We need to make sure the Add operands are broadcastable.
- if (mlir::failed(mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(
- add_op_1))) {
- return failure();
- }
+ /*x=*/variance_value,
+ /*y=*/epsilon_value);
}
::mlir::TF::RsqrtOp rsqrt_op;
{
@@ -1073,14 +1235,9 @@
{
::mlir::Value tblgen_value_0 = (*scale.begin());
::mlir::Value tblgen_value_1 = (*rsqrt_op.getODSResults(0).begin());
- // We need to make sure the Add operands are broadcastable.
multiplier = rewriter.create<::mlir::TF::MulOp>(odsLoc,
/*x=*/tblgen_value_0,
/*y=*/tblgen_value_1);
- if (mlir::failed(mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(
- multiplier))) {
- return failure();
- }
}
::mlir::TF::MulOp mul_op_1;
{
@@ -1089,23 +1246,13 @@
mul_op_1 = rewriter.create<::mlir::TF::MulOp>(odsLoc,
/*x=*/tblgen_value_0,
/*y=*/tblgen_value_1);
- // We need to make sure the Mul operands are broadcastable.
- if (mlir::failed(mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(
- mul_op_1))) {
- return failure();
- }
}
::mlir::TF::MulOp mul_op_2;
{
- ::mlir::Value tblgen_value_0 = (*mean.begin());
- ::mlir::Value tblgen_value_1 = (*multiplier.getODSResults(0).begin());
+ ::mlir::Value multiplier_value = (*multiplier.getODSResults(0).begin());
mul_op_2 = rewriter.create<::mlir::TF::MulOp>(odsLoc,
- /*x=*/tblgen_value_0,
- /*y=*/tblgen_value_1);
- if (mlir::failed(mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(
- mul_op_2))) {
- return failure();
- }
+ /*x=*/mean_value,
+ /*y=*/multiplier_value);
}
::mlir::TF::SubOp sub_op;
{
@@ -1114,10 +1261,6 @@
sub_op = rewriter.create<::mlir::TF::SubOp>(odsLoc,
/*x=*/tblgen_value_0,
/*y=*/tblgen_value_1);
- if (failed(
- mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(sub_op))) {
- return failure();
- }
}
::mlir::TF::AddOp add_op_2;
{
@@ -1131,11 +1274,6 @@
}
add_op_2 = rewriter.create<::mlir::TF::AddOp>(
odsLoc, tblgen_types, tblgen_values, tblgen_attrs);
- // We need to make sure the Add operands are broadcastable.
- if (mlir::failed(mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(
- add_op_2))) {
- return failure();
- }
}
for (auto v :
::llvm::SmallVector<::mlir::Value, 4>{add_op_2.getODSResults(0)}) {
diff --git a/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc b/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc
index 83d4ac3..f136232 100644
--- a/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc
@@ -30,6 +30,7 @@
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
namespace mlir {
namespace TFL {
@@ -72,6 +73,29 @@
return just_call(while_op.body()) && just_call(while_op.cond());
}
+bool IsCompatibleTypeWithTFLCastOp(Type type) {
+ auto elemType = getElementTypeOrSelf(type);
+ // F32 and BF16 types are allowed.
+ if (elemType.isBF16() || elemType.isF32()) return true;
+
+ // I1, I16, I32, I64 types are allowed.
+ if (elemType.isInteger(1) || elemType.isInteger(16) ||
+ elemType.isInteger(32) || elemType.isInteger(64))
+ return true;
+
+ // Complex<F<32>> is allowed.
+ if (elemType.isa<ComplexType>() &&
+ elemType.cast<ComplexType>().getElementType().isF32())
+ return true;
+
+ // QUINT8 and UI8 are allowed.
+ if (elemType.isa<TF::Quint8Type>() ||
+ (elemType.isInteger(8) && elemType.cast<IntegerType>().isUnsigned()))
+ return true;
+
+ return false;
+}
+
void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
OpBuilder builder(&getContext());
// Collect external values used.
@@ -171,8 +195,14 @@
if (value.getType() == type) {
args.push_back(value);
} else {
- auto cast = b.create<CastOp>(yield_op->getLoc(), type, value);
- args.push_back(cast);
+ if (IsCompatibleTypeWithTFLCastOp(value.getType()) &&
+ IsCompatibleTypeWithTFLCastOp(type)) {
+ auto cast = b.create<CastOp>(yield_op->getLoc(), type, value);
+ args.push_back(cast);
+ } else {
+ auto cast = b.create<TF::CastOp>(yield_op->getLoc(), type, value);
+ args.push_back(cast);
+ }
}
}
args.append(new_args.begin(), new_args.end());
diff --git a/tensorflow/compiler/mlir/lite/utils/convert_type.cc b/tensorflow/compiler/mlir/lite/utils/convert_type.cc
index 811796b..733a5a3 100644
--- a/tensorflow/compiler/mlir/lite/utils/convert_type.cc
+++ b/tensorflow/compiler/mlir/lite/utils/convert_type.cc
@@ -41,6 +41,8 @@
return builder.getF64Type();
case tflite::TensorType_INT32:
return builder.getIntegerType(32);
+ case tflite::TensorType_UINT32:
+ return builder.getIntegerType(32, /*isSigned=*/false);
case tflite::TensorType_UINT8:
return builder.getIntegerType(8, /*isSigned=*/false);
case tflite::TensorType_INT64:
@@ -86,6 +88,8 @@
return tensorflow::DT_INT16;
case tflite::TensorType_INT32:
return tensorflow::DT_INT32;
+ case tflite::TensorType_UINT32:
+ return tensorflow::DT_UINT32;
case tflite::TensorType_INT64:
return tensorflow::DT_INT64;
case tflite::TensorType_STRING:
@@ -121,6 +125,8 @@
return tflite::TensorType_INT16;
case tensorflow::DT_INT32:
return tflite::TensorType_INT32;
+ case tensorflow::DT_UINT32:
+ return tflite::TensorType_UINT32;
case tensorflow::DT_INT64:
return tflite::TensorType_INT64;
case tensorflow::DT_UINT64:
diff --git a/tensorflow/compiler/mlir/mlir_bridge_rollout_policy.cc b/tensorflow/compiler/mlir/mlir_bridge_rollout_policy.cc
index ac3e59e..a3239aa 100644
--- a/tensorflow/compiler/mlir/mlir_bridge_rollout_policy.cc
+++ b/tensorflow/compiler/mlir/mlir_bridge_rollout_policy.cc
@@ -52,7 +52,7 @@
MlirBridgeRolloutPolicy GetMlirBridgeRolloutPolicy(
const tensorflow::Graph& graph, absl::optional<ConfigProto> config_proto,
- bool record_stats) {
+ bool uses_uninitialized_resource_args, bool record_stats) {
switch (GetUserRequest(config_proto)) {
case ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED:
return MlirBridgeRolloutPolicy::kEnabledByUser;
diff --git a/tensorflow/compiler/mlir/mlir_bridge_rollout_policy.h b/tensorflow/compiler/mlir/mlir_bridge_rollout_policy.h
index f029ad8..4fff6b7 100644
--- a/tensorflow/compiler/mlir/mlir_bridge_rollout_policy.h
+++ b/tensorflow/compiler/mlir/mlir_bridge_rollout_policy.h
@@ -51,7 +51,7 @@
MlirBridgeRolloutPolicy GetMlirBridgeRolloutPolicy(
const tensorflow::Graph& graph,
absl::optional<tensorflow::ConfigProto> config_proto,
- bool record_stats = false);
+ bool uses_uninitialized_resource_args, bool record_stats = false);
} // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc
index 9cb4428..0723471 100644
--- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc
+++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc
@@ -166,7 +166,11 @@
// TODO(b/176852151): Remove this after dark launch completed.
// Capture stats relevant to graph properties used in dark launch.
- GetMlirBridgeRolloutPolicy(**graph, config_proto, /*record_stats=*/true);
+ // We set `uses_uninitialized_resource_args` to false here because function
+ // optimization is not affected by uninitialized resource args.
+ GetMlirBridgeRolloutPolicy(**graph, config_proto,
+ /*uses_uninitialized_resource_args=*/false,
+ /*record_stats=*/true);
if (overall_state == MlirOptimizationPassState::Disabled) {
LOG_FIRST_N(INFO, 1) << "None of the MLIR Optimization Passes are enabled "
@@ -182,8 +186,9 @@
<< ", Total: " << registry_->passes().size();
GraphDebugInfo debug_info;
- mlir::MLIRContext context;
- RegisterDialects(context.getDialectRegistry());
+ mlir::DialectRegistry registry;
+ RegisterDialects(registry);
+ mlir::MLIRContext context(registry);
GraphImportConfig import_config;
import_config.graph_as_function = true;
import_config.control_outputs = *control_ret_node_names;
@@ -342,8 +347,9 @@
<< " passes)";
GraphDebugInfo debug_info;
- mlir::MLIRContext context;
- RegisterDialects(context.getDialectRegistry());
+ mlir::DialectRegistry registry;
+ RegisterDialects(registry);
+ mlir::MLIRContext context(registry);
GraphImportConfig import_config;
import_config.upgrade_legacy = true;
// Restrict functionalization to TPU nodes to avoid problems in v1 session
diff --git a/tensorflow/compiler/mlir/python/mlir.cc b/tensorflow/compiler/mlir/python/mlir.cc
index ffd6bb0..c0d80c6 100644
--- a/tensorflow/compiler/mlir/python/mlir.cc
+++ b/tensorflow/compiler/mlir/python/mlir.cc
@@ -238,8 +238,9 @@
const std::string &pass_pipeline,
bool show_debug_info,
TF_Status *status) {
- mlir::MLIRContext context;
- mlir::RegisterAllTensorFlowDialects(context.getDialectRegistry());
+ mlir::DialectRegistry registry;
+ mlir::RegisterAllTensorFlowDialects(registry);
+ mlir::MLIRContext context(registry);
mlir::OwningModuleRef module;
{
mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc
index 6cd49cf..b83f7f9 100644
--- a/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc
+++ b/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc
@@ -30,17 +30,20 @@
PYBIND11_MODULE(mlir_wrapper, m) {
m.def("preloadTensorFlowDialects", [](mlir::MLIRContext &context) {
- mlir::RegisterAllTensorFlowDialects(context.getDialectRegistry());
- context.getDialectRegistry().loadAll(&context);
+ mlir::DialectRegistry registry;
+ mlir::RegisterAllTensorFlowDialects(registry);
+ context.appendDialectRegistry(registry);
+ context.loadAllAvailableDialects();
});
m.def("verify", [](std::string input) {
llvm::SourceMgr SM = llvm::SourceMgr();
SM.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(input),
llvm::SMLoc());
- mlir::MLIRContext ctx;
- mlir::RegisterAllTensorFlowDialects(ctx.getDialectRegistry());
- ctx.getDialectRegistry().loadAll(&ctx);
+ mlir::DialectRegistry registry;
+ mlir::RegisterAllTensorFlowDialects(registry);
+ mlir::MLIRContext ctx(registry);
+ ctx.loadAllAvailableDialects();
auto module = mlir::parseSourceFile(SM, &ctx);
if (!module) {
return false;
diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index 0796016..028288a 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -698,6 +698,7 @@
":decompose_resource_ops_inc_gen",
":tensorflow",
":tensorflow_types",
+ "//tensorflow/core:framework",
"@llvm-project//mlir:IR",
],
)
diff --git a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc
index c13dd8a..5405386 100644
--- a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc
+++ b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc
@@ -77,8 +77,10 @@
namespace {
void RegisterDialects(mlir::MLIRContext& ctx) {
- mlir::RegisterAllTensorFlowDialects(ctx.getDialectRegistry());
- ctx.getDialectRegistry().loadAll(&ctx);
+ mlir::DialectRegistry registry;
+ mlir::RegisterAllTensorFlowDialects(registry);
+ ctx.appendDialectRegistry(registry);
+ ctx.loadAllAvailableDialects();
}
Status ConvertDataTypeToTensor(tensorflow::DataType dtype, Builder builder,
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index 9b3aab5..60f727d 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -8389,7 +8389,7 @@
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
-def TF_MirrorPadOp : TF_Op<"MirrorPad", [NoSideEffect]> {
+def TF_MirrorPadOp : TF_Op<"MirrorPad", [NoSideEffect, TF_OperandHasRank<1, 2>]> {
let summary = "Pads a tensor with mirrored values.";
let description = [{
@@ -8436,7 +8436,7 @@
TF_DerivedOperandTypeAttr Tpaddings = TF_DerivedOperandTypeAttr<1>;
}
-def TF_MirrorPadGradOp : TF_Op<"MirrorPadGrad", [NoSideEffect]> {
+def TF_MirrorPadGradOp : TF_Op<"MirrorPadGrad", [NoSideEffect, TF_OperandHasRank<1, 2>]> {
let summary = [{
Gradient op for `MirrorPad` op. This op folds a mirror-padded tensor.
}];
@@ -9315,7 +9315,7 @@
let hasFolder = 1;
}
-def TF_PadOp : TF_Op<"Pad", [NoSideEffect, TF_FoldOperandsTransposeInterface]> {
+def TF_PadOp : TF_Op<"Pad", [NoSideEffect, TF_FoldOperandsTransposeInterface, TF_OperandHasRank<1, 2>]> {
let summary = "Pads a tensor with zeros.";
let description = [{
@@ -9363,7 +9363,7 @@
}];
}
-def TF_PadV2Op : TF_Op<"PadV2", [NoSideEffect]> {
+def TF_PadV2Op : TF_Op<"PadV2", [NoSideEffect, TF_OperandHasRank<1, 2>]> {
let summary = "Pads a tensor.";
let description = [{
@@ -11823,6 +11823,44 @@
TF_DerivedOperandTypeAttr Index = TF_DerivedOperandTypeAttr<1>;
}
+def TF_RestoreOp : TF_Op<"Restore", []> {
+ let summary = "Restores a tensor from checkpoint files.";
+
+ let description = [{
+Reads a tensor stored in one or several files. If there are several files (for
+instance because a tensor was saved as slices), `file_pattern` may contain
+wildcard symbols (`*` and `?`) in the filename portion only, not in the
+directory portion.
+
+If a `file_pattern` matches several files, `preferred_shard` can be used to hint
+in which file the requested tensor is likely to be found. This op will first
+open the file at index `preferred_shard` in the list of matching files and try
+to restore tensors from that file. Only if some tensors or tensor slices are
+not found in that first file, then the Op opens all the files. Setting
+`preferred_shard` to match the value passed as the `shard` input
+of a matching `Save` Op may speed up Restore. This attribute only affects
+performance, not correctness. The default value -1 means files are processed in
+order.
+
+See also `RestoreSlice`.
+ }];
+
+ let arguments = (ins
+ Arg<TF_StrTensor, [{Must have a single element. The pattern of the files from
+which we read the tensor.}]>:$file_pattern,
+ Arg<TF_StrTensor, [{Must have a single element. The name of the tensor to be
+restored.}]>:$tensor_name,
+
+ DefaultValuedAttr<I64Attr, "-1">:$preferred_shard
+ );
+
+ let results = (outs
+ Res<TF_Tensor, [{The restored tensor.}]>:$tensor
+ );
+
+ TF_DerivedResultTypeAttr dt = TF_DerivedResultTypeAttr<0>;
+}
+
def TF_RestoreV2Op : TF_Op<"RestoreV2", []> {
let summary = "Restores tensors from a V2 checkpoint.";
@@ -12608,6 +12646,27 @@
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
+def TF_RngReadAndSkipOp : TF_Op<"RngReadAndSkip", []> {
+ let summary = "Advance the counter of a counter-based RNG.";
+
+ let description = [{
+The state of the RNG after
+`rng_read_and_skip(n)` will be the same as that after `uniform([n])`
+(or any other distribution). The actual increment added to the
+counter is an unspecified implementation choice.
+ }];
+
+ let arguments = (ins
+ TF_ResourceTensor:$resource,
+ TF_Int32Tensor:$alg,
+ TF_Uint64Tensor:$delta
+ );
+
+ let results = (outs
+ TF_Int64Tensor:$value
+ );
+}
+
def TF_RollOp : TF_Op<"Roll", [NoSideEffect]> {
let summary = "Rolls the elements of a tensor along an axis.";
@@ -12716,6 +12775,69 @@
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
+def TF_SaveOp : TF_Op<"Save", []> {
+ let summary = "Saves the input tensors to disk.";
+
+ let description = [{
+The size of `tensor_names` must match the number of tensors in `data`. `data[i]`
+is written to `filename` with name `tensor_names[i]`.
+
+See also `SaveSlices`.
+ }];
+
+ let arguments = (ins
+ Arg<TF_StrTensor, [{Must have a single element. The name of the file to which we write
+the tensor.}]>:$filename,
+ Arg<TF_StrTensor, [{Shape `[N]`. The names of the tensors to be saved.}]>:$tensor_names,
+ Arg<Variadic<TF_Tensor>, [{`N` tensors to save.}]>:$data
+ );
+
+ let results = (outs);
+
+ TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<2>;
+}
+
+def TF_SaveSlicesOp : TF_Op<"SaveSlices", []> {
+ let summary = "Saves input tensors slices to disk.";
+
+ let description = [{
+This is like `Save` except that tensors can be listed in the saved file as being
+a slice of a larger tensor. `shapes_and_slices` specifies the shape of the
+larger tensor and the slice that this tensor covers. `shapes_and_slices` must
+have as many elements as `tensor_names`.
+
+Elements of the `shapes_and_slices` input must either be:
+
+* The empty string, in which case the corresponding tensor is
+ saved normally.
+* A string of the form `dim0 dim1 ... dimN-1 slice-spec` where the
+ `dimI` are the dimensions of the larger tensor and `slice-spec`
+ specifies what part is covered by the tensor to save.
+
+`slice-spec` itself is a `:`-separated list: `slice0:slice1:...:sliceN-1`
+where each `sliceI` is either:
+
+* The string `-` meaning that the slice covers all indices of this dimension
+* `start,length` where `start` and `length` are integers. In that
+ case the slice covers `length` indices starting at `start`.
+
+See also `Save`.
+ }];
+
+ let arguments = (ins
+ Arg<TF_StrTensor, [{Must have a single element. The name of the file to which we write the
+tensor.}]>:$filename,
+ Arg<TF_StrTensor, [{Shape `[N]`. The names of the tensors to be saved.}]>:$tensor_names,
+ Arg<TF_StrTensor, [{Shape `[N]`. The shapes and slice specifications to use when
+saving the tensors.}]>:$shapes_and_slices,
+ Arg<Variadic<TF_Tensor>, [{`N` tensors to save.}]>:$data
+ );
+
+ let results = (outs);
+
+ TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<3>;
+}
+
def TF_SaveV2Op : TF_Op<"SaveV2", []> {
let summary = "Saves tensors in V2 checkpoint format.";
@@ -14910,6 +15032,32 @@
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
}
+def TF_StatelessRandomUniformFullIntV2Op : TF_Op<"StatelessRandomUniformFullIntV2", [NoSideEffect]> {
+ let summary = [{
+Outputs deterministic pseudorandom random integers from a uniform distribution.
+ }];
+
+ let description = [{
+The generated values are uniform integers covering the whole range of `dtype`.
+
+The outputs are a deterministic function of `shape`, `key`, `counter` and `alg`.
+ }];
+
+ let arguments = (ins
+ Arg<TF_I32OrI64Tensor, [{The shape of the output tensor.}]>:$shape,
+ Arg<TF_Uint64Tensor, [{Key for the counter-based RNG algorithm (shape uint64[1]).}]>:$key,
+ Arg<TF_Uint64Tensor, [{Initial counter for the counter-based RNG algorithm (shape uint64[2] or uint64[1] depending on the algorithm). If a larger vector is given, only the needed portion on the left (i.e. [:N]) will be used.}]>:$counter,
+ Arg<TF_Int32Tensor, [{The RNG algorithm (shape int32[]).}]>:$alg
+ );
+
+ let results = (outs
+ Res<TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>, [{Random values with specified shape.}]>:$output
+ );
+
+ TF_DerivedOperandTypeAttr Tshape = TF_DerivedOperandTypeAttr<0>;
+ TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
+}
+
def TF_StatelessRandomUniformIntOp : TF_Op<"StatelessRandomUniformInt", [NoSideEffect, TF_NoConstantFold]> {
let summary = [{
Outputs deterministic pseudorandom random integers from a uniform distribution.
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
index f613020..99de2cb 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
@@ -394,7 +394,12 @@
0DTensorOf<[I1]>:$cond,
// Used to map StatelessIf and If op defined in TensorFlow to a common op.
- BoolAttr:$is_stateless
+ BoolAttr:$is_stateless,
+ // Used to maintain function name when round-tripping
+ // between functional and regional control flow. This can be removed if
+ // the runtime does not require globally unique then/else branch function names.
+ OptionalAttr<StrAttr>:$_then_func_name,
+ OptionalAttr<StrAttr>:$_else_func_name
);
let results = (outs
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
index 3bb7861..587fb89 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
@@ -1628,43 +1628,47 @@
return failure();
}
- // For operands having dynamic shape.
+ // Output always have `num_dims` rank. All dimensions are initialized to
+ // dynamic size and can be partially inferred.
SmallVector<int64_t, 4> return_shape(num_dims, ShapedType::kDynamicSize);
- if (!input_ty.hasStaticShape() || !filter_ty.hasStaticShape()) {
- inferredReturnTypes.assign(
- {RankedTensorType::get(return_shape, input_ty.getElementType())});
- return success();
- }
-
- // Checks the size of each of the output dimension.
- for (auto i : llvm::seq<int>(0, num_spatial_dims)) {
- const int64_t dim = GetTensorSpatialDimIndex(num_dims, format, i);
- int64_t stride = get_int(strides[dim]);
- tensorflow::int64 expected_output_size;
- tensorflow::int64 pad_low;
- tensorflow::int64 pad_high;
- // Retrieve padding, if defined explicitly.
- if (padding == tensorflow::Padding::EXPLICIT) {
- pad_low = get_int(explicit_padding[2 * dim]);
- pad_high = get_int(explicit_padding[2 * dim + 1]);
- }
- // Calculate the expected_output_size.
- tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2(
- input_ty.getDimSize(dim), filter_ty.getDimSize(i),
- get_int(dilations[dim]), stride, padding, &expected_output_size,
- &pad_low, &pad_high);
- // Return failure if expected_output_size could not be calculated.
- if (!status.ok()) return failure();
- return_shape[dim] = expected_output_size;
- }
-
- // The remaining dimensions can be obtained using utilities from
+ // Output batch and channel dimension can be obtained using utilities from
// tensorflow/core/util/tensor_format.h.
- return_shape[GetTensorBatchDimIndex(num_dims, format)] =
- input_ty.getShape()[GetTensorBatchDimIndex(num_dims, format)];
- return_shape[GetTensorFeatureDimIndex(num_dims, format)] =
- filter_ty.getShape()[GetFilterTensorOutputChannelsDimIndex(
- num_dims, tensorflow::FORMAT_HWIO)];
+ if (input_ty.hasRank()) {
+ return_shape[GetTensorBatchDimIndex(num_dims, format)] =
+ input_ty.getDimSize(GetTensorBatchDimIndex(num_dims, format));
+ }
+ if (filter_ty.hasRank()) {
+ return_shape[GetTensorFeatureDimIndex(num_dims, format)] =
+ filter_ty.getDimSize(GetFilterTensorOutputChannelsDimIndex(
+ num_dims, tensorflow::FORMAT_HWIO));
+ }
+ // Spatial dimensions can be inferred only when both input and filter are
+ // ranked because we need to get their spatial dimensions.
+ if (input_ty.hasRank() && filter_ty.hasRank()) {
+ // Checks the size of each of the output spatial dimensions.
+ for (auto i : llvm::seq<int>(0, num_spatial_dims)) {
+ const int64_t dim = GetTensorSpatialDimIndex(num_dims, format, i);
+ int64_t stride = get_int(strides[dim]);
+ tensorflow::int64 expected_output_size;
+ tensorflow::int64 pad_low;
+ tensorflow::int64 pad_high;
+ // Retrieve padding, if defined explicitly.
+ if (padding == tensorflow::Padding::EXPLICIT) {
+ pad_low = get_int(explicit_padding[2 * dim]);
+ pad_high = get_int(explicit_padding[2 * dim + 1]);
+ }
+ // Skip if input or filter size is dynamic.
+ if (input_ty.isDynamicDim(dim) || filter_ty.isDynamicDim(i)) continue;
+ // Calculate the expected_output_size.
+ tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2(
+ input_ty.getDimSize(dim), filter_ty.getDimSize(i),
+ get_int(dilations[dim]), stride, padding, &expected_output_size,
+ &pad_low, &pad_high);
+ // Return failure if expected_output_size could not be calculated.
+ if (!status.ok()) return failure();
+ return_shape[dim] = expected_output_size;
+ }
+ }
inferredReturnTypes.assign(
{RankedTensorType::get(return_shape, input_ty.getElementType())});
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
index eba0646..bed926c 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
@@ -68,6 +68,7 @@
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/tensor_format.h"
@@ -1332,10 +1333,10 @@
// The rest of the dimension sizes can be calculated when block_shape and
// paddings arguments are constant.
- ElementsAttr block_shape_attr;
- ElementsAttr paddings_attr;
- if (matchPattern(block_shape_val, m_Constant(&block_shape_attr)) &&
- matchPattern(paddings_val, m_Constant(&paddings_attr))) {
+ DenseIntElementsAttr block_shape_attr;
+ DenseIntElementsAttr paddings_attr;
+ if (GetValueAsConstant(block_shape_val, block_shape_attr) &&
+ GetValueAsConstant(paddings_val, paddings_attr)) {
int64_t return_batch = input_shape[0];
for (uint64_t i = 0; i < block_rank; ++i) {
// Propagate dynamic dimension.
@@ -1347,10 +1348,10 @@
continue;
}
int64_t paddings_sum =
- paddings_attr.getValue({i, 0}).cast<IntegerAttr>().getInt() +
- paddings_attr.getValue({i, 1}).cast<IntegerAttr>().getInt();
+ paddings_attr.getValue<APInt>({i, 0}).getSExtValue() +
+ paddings_attr.getValue<APInt>({i, 1}).getSExtValue();
int64_t block_shape_i =
- block_shape_attr.getValue({i}).cast<IntegerAttr>().getInt();
+ block_shape_attr.getValue<APInt>({i}).getSExtValue();
return_batch *= block_shape_i;
return_shape[1 + i] = (paddings_sum + input_shape[i + 1]) / block_shape_i;
}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir
index 44a3989..fccca8b 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir
@@ -568,3 +568,17 @@
return
}
+
+// -----
+
+// Test that tf.RngReadAndSkip op is decomposed.
+// CHECK-LABEL: func @decompose_rng_read_and_skip_op
+func @decompose_rng_read_and_skip_op(%resource: tensor<!tf.resource<tensor<3xi64>>>) -> tensor<3xi64> {
+ // We rely on the TensorFlow StatefulRandomOpsTest to check it is lowered
+ // correctly.
+ // CHECK-NOT: tf.RngReadAndSkip
+ %alg = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+ %delta = "tf.Const"() {value = dense<10> : tensor<ui64>} : () -> tensor<ui64>
+ %0 = "tf.RngReadAndSkip"(%resource, %alg, %delta) : (tensor<!tf.resource<tensor<3xi64>>>, tensor<i32>, tensor<ui64>) -> tensor<3xi64>
+ return %0 : tensor<3xi64>
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/fold-broadcast.mlir b/tensorflow/compiler/mlir/tensorflow/tests/fold-broadcast.mlir
index b4738ed..187a335 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/fold-broadcast.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/fold-broadcast.mlir
@@ -72,3 +72,33 @@
// CHECK: %[[V0:.*]] = "tf.Add"(%arg0, %arg1) : (tensor<7xf32>, tensor<5x1xf32>) -> tensor<5x7xf32>
// CHECK: %[[V0]] : tensor<5x7xf32>
}
+
+// CHECK-LABEL: @broadcast_batch_matmul_v2_rhs
+func @broadcast_batch_matmul_v2_rhs(%arg0: tensor<17x17x17xf32>, %arg1: tensor<17x24xf32>) -> tensor<17x17x24xf32> {
+ %cst = constant dense<[17, 17, 24]> : tensor<3xi64>
+ %0 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<17x24xf32>, tensor<3xi64>) -> tensor<17x17x24xf32>
+ %1 = "tf.BatchMatMulV2"(%arg0, %0) {adj_x = false, adj_y = false} : (tensor<17x17x17xf32>, tensor<17x17x24xf32>) -> tensor<17x17x24xf32>
+ return %1 : tensor<17x17x24xf32>
+ // CHECK: %[[V0:.*]] = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<17x17x17xf32>, tensor<17x24xf32>) -> tensor<17x17x24xf32>
+ // CHECK: %[[V0]] : tensor<17x17x24xf32>
+}
+
+// CHECK-LABEL: @broadcast_batch_matmul_v2_lhs
+func @broadcast_batch_matmul_v2_lhs(%arg0: tensor<17x17xf32>, %arg1: tensor<17x17x24xf32>) -> tensor<17x17x24xf32> {
+ %cst = constant dense<[17, 17, 17]> : tensor<3xi64>
+ %0 = "tf.BroadcastTo"(%arg0, %cst) : (tensor<17x17xf32>, tensor<3xi64>) -> tensor<17x17x17xf32>
+ %1 = "tf.BatchMatMulV2"(%0, %arg1) {adj_x = false, adj_y = false} : (tensor<17x17x17xf32>, tensor<17x17x24xf32>) -> tensor<17x17x24xf32>
+ return %1 : tensor<17x17x24xf32>
+ // CHECK: %[[V0:.*]] = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<17x17xf32>, tensor<17x17x24xf32>) -> tensor<17x17x24xf32>
+ // CHECK: %[[V0]] : tensor<17x17x24xf32>
+}
+
+// CHECK-LABEL: @broadcast_batch_matmul_v2_failed
+func @broadcast_batch_matmul_v2_failed(%arg0: tensor<17x17x1xf32>, %arg1: tensor<17x17x24xf32>) -> tensor<17x17x24xf32> {
+ %cst = constant dense<[17, 17, 17]> : tensor<3xi64>
+ %0 = "tf.BroadcastTo"(%arg0, %cst) : (tensor<17x17x1xf32>, tensor<3xi64>) -> tensor<17x17x17xf32>
+ %1 = "tf.BatchMatMulV2"(%0, %arg1) {adj_x = false, adj_y = false} : (tensor<17x17x17xf32>, tensor<17x17x24xf32>) -> tensor<17x17x24xf32>
+ return %1 : tensor<17x17x24xf32>
+ // CHECK: %[[V0:.*]] = "tf.BroadcastTo"
+ // CHECK: "tf.BatchMatMulV2"(%[[V0]], %arg1)
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-regions.mlir b/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-regions.mlir
index 7918dd9..69dc04e 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-regions.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-regions.mlir
@@ -20,6 +20,8 @@
// CHECK: "tf.Yield"([[Result1]])
// CHECK: _attr0 = 10
// CHECK-SAME: _attr1 = true
+ // CHECK-SAME: _else_func_name = "testIf1Else"
+ // CHECK-SAME: _then_func_name = "testIf1Then"
// CHECK-NOT: attr2 =
// CHECK-NOT: else_branch
// CHECK-SAME: is_stateless = false
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir
index 1733049..e0caf6f 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir
@@ -1597,6 +1597,21 @@
return %0 : tensor<3x5x1x4xf32>
}
+// CHECK-LABEL: func @convert_dot_general_repeated(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x1024xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<1024x1024xf32>) -> tensor<1x1x1024xf32> {
+// CHECK: %[[VAL_2:.*]] = constant dense<[1, 1024]> : tensor<2xi64>
+// CHECK: %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_0]], %[[VAL_2]]) : {{.*}} -> tensor<1x1024xf32>
+// CHECK: %[[VAL_4:.*]] = "tf.BatchMatMulV2"(%[[VAL_3]], %[[VAL_1]]) {adj_x = false, adj_y = false} : {{.*}} -> tensor<1x1024xf32>
+// CHECK: %[[VAL_5:.*]] = constant dense<[1, 1, 1024]> : tensor<3xi64>
+// CHECK: %[[VAL_6:.*]] = "tf.Reshape"(%[[VAL_4]], %[[VAL_5]]) : {{.*}} -> tensor<1x1x1024xf32>
+// CHECK: return %[[VAL_6]] : tensor<1x1x1024xf32>
+// CHECK: }
+func @convert_dot_general_repeated(%arg0: tensor<1x1x1024xf32>, %arg1: tensor<1024x1024xf32>) -> tensor<1x1x1024xf32> {
+ %0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<0> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x1x1024xf32>, tensor<1024x1024xf32>) -> tensor<1x1x1024xf32>
+ return %0 : tensor<1x1x1024xf32>
+}
+
// CHECK-LABEL: func @convert_conv2d(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x8x8x207xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/merge_control_flow.mlir b/tensorflow/compiler/mlir/tensorflow/tests/merge_control_flow.mlir
index 7ed34b5..d9321f9 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/merge_control_flow.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/merge_control_flow.mlir
@@ -69,6 +69,9 @@
func @same_predicate_no_returns_merged() {
// CHECK: tf_device.cluster
// CHECK: "tf.IfRegion"
+ // CHECK: _else_func_name = "elseFunc1"
+ // CHECK-SAME: _then_func_name = "thenFunc1"
+
// CHECK-NOT: "tf.IfRegion"
"tf_device.cluster"() ( {
%0 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
@@ -77,13 +80,13 @@
"tf.Yield"() : () -> ()
}, {
"tf.Yield"() : () -> ()
- }) {is_stateless = true} : (tensor<i1>) -> ()
+ }) {is_stateless = true, _else_func_name = "elseFunc1", _then_func_name = "thenFunc1"} : (tensor<i1>) -> ()
"tf.IfRegion"(%0) ( {
%2 = "tf.B"() : () -> (tensor<f32>)
"tf.Yield"() : () -> ()
}, {
"tf.Yield"() : () -> ()
- }) {is_stateless = true} : (tensor<i1>) -> ()
+ }) {is_stateless = true, _else_func_name = "elseFunc2", _then_func_name = "thenFunc2"} : (tensor<i1>) -> ()
tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> ()
return
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/region-control-flow-to-functional.mlir b/tensorflow/compiler/mlir/tensorflow/tests/region-control-flow-to-functional.mlir
index 6a5d702..bf34550 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/region-control-flow-to-functional.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/region-control-flow-to-functional.mlir
@@ -1,11 +1,34 @@
// RUN: tf-opt %s -tf-region-control-flow-to-functional -split-input-file | FileCheck %s
// Simple IfRegion
+// CHECK: func private @test_else_name(%arg0: tensor<*xf32>) -> tensor<*xf32>
+// CHECK-NEXT: "tf.Neg"
+// CHECK: func private @test_then_name(%arg0: tensor<*xf32>) -> tensor<*xf32>
+// CHECK-NEXT: "tf.Abs"
+func @testSimple(%arg0: tensor<i1>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
+ // CHECK: "tf.If"
+ // CHECK-SAME: _attr0 = false
+ // CHECK-NOT: attr1
+ // CHECK-SAME: else_branch = @test_else_name
+ // CHECK-SAME: then_branch = @test_then_name
+ %0 = "tf.IfRegion"(%arg0) ({
+ %1 = "tf.Abs"(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
+ "tf.Yield"(%1) : (tensor<*xf32>) -> ()
+ }, {
+ %2 = "tf.Neg"(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
+ "tf.Yield"(%2) : (tensor<*xf32>) -> ()
+ }) {is_stateless = true, _attr0 = false, attr1 = "hello", _then_func_name = "test_then_name", _else_func_name = "test_else_name"} : (tensor<i1>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
+
+// -----
+
+// Simple IfRegion with empty branch names
// CHECK: func private @tf.IfRegion_else(%arg0: tensor<*xf32>) -> tensor<*xf32>
// CHECK-NEXT: "tf.Neg"
// CHECK: func private @tf.IfRegion_then(%arg0: tensor<*xf32>) -> tensor<*xf32>
// CHECK-NEXT: "tf.Abs"
-func @testSimple(%arg0: tensor<i1>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
+func @testSimpleEmptyBranchNames(%arg0: tensor<i1>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: "tf.If"
// CHECK-SAME: _attr0 = false
// CHECK-NOT: attr1
@@ -17,7 +40,7 @@
}, {
%2 = "tf.Neg"(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
"tf.Yield"(%2) : (tensor<*xf32>) -> ()
- }) {is_stateless = true, _attr0 = false, attr1 = "hello"} : (tensor<i1>) -> tensor<*xf32>
+ }) {is_stateless = true, _attr0 = false, attr1 = "hello", _then_func_name = "", _else_func_name = ""} : (tensor<i1>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
index d29d0ef..98fe02e 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
@@ -1045,4 +1045,105 @@
// CHECK-SAME: tensor<i32>
return %arg0 : tensor<*xi32>
}
-}
+
+ // Test conv2d inferReturnTypes can infer some information when input or
+ // filter does not have fully static shape.
+
+ // CHECK-LABEL: func @conv2d_unranked_input_and_filter
+ func @conv2d_unranked_input_and_filter(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
+ // CHECK: "tf.Conv2D"
+ // CHECK-SAME: -> tensor<?x?x?x?xf32>
+ %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+ }
+
+ // CHECK-LABEL: func @conv2d_unranked_filter
+ func @conv2d_unranked_filter(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
+ // CHECK: "tf.Conv2D"
+ // CHECK-SAME: -> tensor<256x?x?x?xf32>
+ %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<256x32x32x3xf32>, tensor<*xf32>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+ }
+
+ // CHECK-LABEL: func @conv2d_unranked_filter_and_dynamic_batch
+ func @conv2d_unranked_filter_and_dynamic_batch(%arg0: tensor<?x32x32x3xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
+ // CHECK: "tf.Conv2D"
+ // CHECK-SAME: -> tensor<?x?x?x?xf32>
+ %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<?x32x32x3xf32>, tensor<*xf32>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+ }
+
+ // CHECK-LABEL: func @conv2d_unranked_input
+ func @conv2d_unranked_input(%arg0: tensor<*xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<*xf32> {
+ // CHECK: "tf.Conv2D"
+ // CHECK-SAME: -> tensor<?x?x?x16xf32>
+ %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<3x3x3x16xf32>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+ }
+
+ // CHECK-LABEL: func @conv2d_unranked_input_and_dynamic_channel
+ func @conv2d_unranked_input_and_dynamic_channel(%arg0: tensor<*xf32>, %arg1: tensor<3x3x3x?xf32>) -> tensor<*xf32> {
+ // CHECK: "tf.Conv2D"
+ // CHECK-SAME: -> tensor<?x?x?x?xf32>
+ %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<3x3x3x?xf32>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+ }
+
+ // CHECK-LABEL: func @conv2d_dynamic_batch
+ func @conv2d_dynamic_batch(%arg0: tensor<?x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<*xf32> {
+ // CHECK: "tf.Conv2D"
+ // CHECK-SAME: -> tensor<?x32x32x16xf32>
+ %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<?x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+ }
+
+ // CHECK-LABEL: func @conv2d_dynamic_channel
+ func @conv2d_dynamic_channel(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x?xf32>) -> tensor<*xf32> {
+ // CHECK: "tf.Conv2D"
+ // CHECK-SAME: -> tensor<256x32x32x?xf32>
+ %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x?xf32>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+ }
+
+ // CHECK-LABEL: func @conv2d_fully_dynamic_spatial_dim
+ func @conv2d_fully_dynamic_spatial_dim(%arg0: tensor<256x?x?x3xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<*xf32> {
+ // CHECK: "tf.Conv2D"
+ // CHECK-SAME: -> tensor<256x?x?x16xf32>
+ %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<256x?x?x3xf32>, tensor<3x3x3x16xf32>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+ }
+
+ // CHECK-LABEL: func @conv2d_partially_dynamic_spatial_dim
+ func @conv2d_partially_dynamic_spatial_dim(%arg0: tensor<256x?x32x3xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<*xf32> {
+ // CHECK: "tf.Conv2D"
+ // CHECK-SAME: -> tensor<256x?x32x16xf32>
+ %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<256x?x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+ }
+
+ // CHECK-LABEL: func @conv2d_dynamic_batch_and_partially_dynamic_spatial_dim
+ func @conv2d_dynamic_batch_and_partially_dynamic_spatial_dim(%arg0: tensor<?x?x32x3xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<*xf32> {
+ // CHECK: "tf.Conv2D"
+ // CHECK-SAME: -> tensor<?x?x32x16xf32>
+ %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<?x?x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+ }
+
+ // CHECK-LABEL: func @conv2d_dynamic_batch_and_fully_dynamic_spatial_dim
+ func @conv2d_dynamic_batch_and_fully_dynamic_spatial_dim(%arg0: tensor<?x?x?x3xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<*xf32> {
+ // CHECK: "tf.Conv2D"
+ // CHECK-SAME: -> tensor<?x?x?x16xf32>
+ %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<?x?x?x3xf32>, tensor<3x3x3x16xf32>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+ }
+
+ // CHECK-LABEL: check_walking_identity
+ func @check_walking_identity(%arg0 : tensor<1x192x256x128xf32>) {
+ %0 = "tf.Const"() {value = dense<2> : tensor<2xi32>} : () -> tensor<2xi32>
+ %1 = "tf.Const"() {value = dense<2> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
+ %2 = "tf.Identity"(%1) {device = ""} : (tensor<2x2xi32>) -> tensor<2x2xi32>
+ // CHECK: SpaceToBatchND{{.*}}-> tensor<4x98x130x128xf32>
+ %3 = "tf.SpaceToBatchND"(%arg0, %0, %2) {device = ""} : (tensor<1x192x256x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?x128xf32>
+ return
+ }
+}
\ No newline at end of file
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
index e226da5..f0dcfd3 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
@@ -308,6 +308,34 @@
// -----
+func @testPadRank1Paddings(%input: tensor<2xi64>) -> tensor<3xi64> {
+ %paddings = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
+ // expected-error @+1 {{failed to verify that operand 1 is 2-D}}
+ %0 = "tf.Pad"(%input, %paddings) : (tensor<2xi64>, tensor<2xi64>) -> tensor<3xi64>
+ return %0 : tensor<3xi64>
+}
+
+// -----
+
+func @testPadV2Rank1Paddings(%input: tensor<2xi64>) -> tensor<3xi64> {
+ %constant = "tf.Const"() {value = dense<1> : tensor<i64>} : () -> tensor<i64>
+ %paddings = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
+ // expected-error @+1 {{failed to verify that operand 1 is 2-D}}
+ %0 = "tf.PadV2"(%input, %paddings, %constant) : (tensor<2xi64>, tensor<2xi64>, tensor<i64>) -> tensor<3xi64>
+ return %0 : tensor<3xi64>
+}
+
+// -----
+
+func @testMirrorPadRank1Paddings(%input: tensor<2xi64>) -> tensor<3xi64> {
+ %paddings = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
+ // expected-error @+1 {{failed to verify that operand 1 is 2-D}}
+ %0 = "tf.MirrorPad"(%input, %paddings) { mode = "SYMMETRIC" }: (tensor<2xi64>, tensor<2xi64>) -> tensor<3xi64>
+ return %0 : tensor<3xi64>
+}
+
+// -----
+
// CHECK-LABEL: func @testReshape(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<10000xf32>, %arg3: tensor<*xi32>)
func @testReshape(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<10000xf32>, %arg3: tensor<*xi32>) -> (tensor<100x100xf32>, tensor<*xf32>, tensor<100x100xf32>, tensor<100x100xf32>, tensor<*xf32>, tensor<*xf32>) {
%shape1 = constant dense<100> : tensor<2xi32>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/control_flow_duplicate_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/control_flow_duplicate_v1.py
index ab786ac..67563fa 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/control_flow_duplicate_v1.py
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/control_flow_duplicate_v1.py
@@ -27,16 +27,19 @@
# CHECK: func {{.*}} tf_saved_model.exported_names = ["key_1"]
# CHECK: "tf.If"
-# CHECK-SAME: else_branch = @[[else:[a-zA-Z_0-9]+]]
-# CHECK-SAME: then_branch = @[[then:[a-zA-Z_0-9]+]]
+# CHECK-SAME: else_branch = @[[else_1:"key_1/[a-zA-Z_0-9]+"]]
+# CHECK-SAME: then_branch = @[[then_1:"key_1/[a-zA-Z_0-9]+"]]
# CHECK: func {{.*}} tf_saved_model.exported_names = ["key_2"]
# CHECK: "tf.If"
-# CHECK-SAME: else_branch = @[[else]]
-# CHECK-SAME: then_branch = @[[then]]
+# CHECK-SAME: else_branch = @[[else_2:"key_2/[a-zA-Z_0-9]+"]]
+# CHECK-SAME: then_branch = @[[then_2:"key_2/[a-zA-Z_0-9]+"]]
-# CHECK: func private @[[else]](
-# CHECK: func private @[[then]](
+# CHECK: func private @[[else_1]](
+# CHECK: func private @[[then_1]](
+
+# CHECK: func private @[[else_2]](
+# CHECK: func private @[[then_2]](
def Test():
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir
index 55bb3d1..164b523 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir
@@ -572,6 +572,8 @@
// CHECK: "tf.D"(%[[ARG_RECV_OUTPUT]]#0, %[[ARG_RECV_OUTPUT]]#1)
// CHECK-NOT: "tf._XlaSendFromHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK: "tf.Yield"() : () -> ()
+ // CHECK: _else_func_name = "test_else_name"
+ // CHECK-SAME _then_func_name = "test_then_name"
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
@@ -595,7 +597,7 @@
"tf.Yield"() : () -> ()
}, {
"tf.Yield"() : () -> ()
- }) { is_stateless = false} : (tensor<i1>) -> ()
+ }) { is_stateless = false, _then_func_name = "test_then_name", _else_func_name = "test_else_name"} : (tensor<i1>) -> ()
%5 = "tf.E"() : () -> tensor<?xi32>
tf_device.return %5 : tensor<?xi32>
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc
index 6fe06f9..8f0acc6 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc
@@ -59,7 +59,10 @@
llvm::function_ref<void(OpPassManager &pm)> pipeline_builder) {
PassManager bridge(module.getContext());
::tensorflow::applyTensorflowAndCLOptions(bridge);
- if (enable_logging) EnableLogging(&bridge);
+ if (enable_logging || VLOG_IS_ON(1)) {
+ tensorflow::DumpMlirOpToFile("tpu_bridge_before", module);
+ if (VLOG_IS_ON(2)) EnableLogging(&bridge);
+ }
// Populate a passmanager with the list of passes that implement the bridge.
pipeline_builder(bridge);
@@ -72,6 +75,8 @@
mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
LogicalResult result = bridge.run(module);
(void)result;
+ if (enable_logging || VLOG_IS_ON(1))
+ tensorflow::DumpMlirOpToFile("tpu_bridge_after", module);
return diag_handler.ConsumeStatus();
}
} // namespace
@@ -163,7 +168,10 @@
bool enable_logging,
bool enable_inliner) {
PassManager bridge(module.getContext());
- if (enable_logging) EnableLogging(&bridge);
+ if (enable_logging || VLOG_IS_ON(1)) {
+ tensorflow::DumpMlirOpToFile("standard_pipeline_before", module);
+ if (VLOG_IS_ON(2)) EnableLogging(&bridge);
+ }
StandardPipelineOptions pipeline_options;
pipeline_options.enable_inliner.setValue(enable_inliner);
@@ -171,6 +179,8 @@
mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
LogicalResult result = bridge.run(module);
(void)result;
+ if (enable_logging || VLOG_IS_ON(1))
+ tensorflow::DumpMlirOpToFile("standard_pipeline_after", module);
return diag_handler.ConsumeStatus();
}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.cc b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.cc
index 7701d96..4e0f3eb 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.cc
@@ -18,6 +18,7 @@
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
+#include "tensorflow/core/framework/rng_alg.h"
namespace mlir {
namespace TF {
@@ -68,11 +69,147 @@
.front();
}
+// Decompose tf.RngReadAndSkip.
+//
+// For Philox, the resource variable holds a tensor<3xi64> with the state:
+// [counter_lo, counter_hi, key]
+//
+// RngReadAndSkip increments the 128 bit counter value by 256 * delta and
+// returns the original state value.
+//
+// For Threefry, the resource variable holds a tensor<2xi64> with the state:
+// [counter, key]
+//
+// RngReadAndSkip increments the 64 bit counter value by 256 * delta and
+// returns a tensor<3xi64> value [counter, key, 0].
+class DecomposeRngReadAndSkipOp : public RewritePattern {
+ public:
+ explicit DecomposeRngReadAndSkipOp(MLIRContext *context)
+ : RewritePattern(RngReadAndSkipOp::getOperationName(),
+ {
+ AddV2Op::getOperationName(),
+ AssignVariableOp::getOperationName(),
+ CastOp::getOperationName(),
+ ConstOp::getOperationName(),
+ LessOp::getOperationName(),
+ MulOp::getOperationName(),
+ PadOp::getOperationName(),
+ PackOp::getOperationName(),
+ ReadVariableOp::getOperationName(),
+ SelectV2Op::getOperationName(),
+ UnpackOp::getOperationName(),
+ },
+ 1, context) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ auto rng_op = cast<RngReadAndSkipOp>(op);
+
+ DenseIntElementsAttr alg_constant;
+ if (!matchPattern(rng_op.alg(), m_Constant(&alg_constant))) {
+ return rewriter.notifyMatchFailure(
+ op, "unable to determine algorithm statically");
+ }
+
+ if (alg_constant.getNumElements() != 1) {
+ return rewriter.notifyMatchFailure(op, "expected alg to be a scalar");
+ }
+
+ uint64_t alg_value = ((*alg_constant.int_value_begin()).getZExtValue());
+ tensorflow::Algorithm alg;
+ if (tensorflow::RNG_ALG_PHILOX == alg_value) {
+ alg = tensorflow::RNG_ALG_PHILOX;
+ } else if (tensorflow::RNG_ALG_THREEFRY == alg_value) {
+ alg = tensorflow::RNG_ALG_THREEFRY;
+ } else {
+ return rewriter.notifyMatchFailure(op, "unsupported alg");
+ }
+
+ Type state_element_type = rewriter.getI64Type();
+ RankedTensorType op_type = RankedTensorType::get(
+ {tensorflow::RNG_MAX_COUNTER_SIZE + tensorflow::RNG_KEY_SIZE},
+ state_element_type);
+ if (op_type != rng_op.getType()) {
+ return rewriter.notifyMatchFailure(op, "unexpected op type");
+ }
+
+ if (!HasResourceSubtype(rng_op.resource())) {
+ return rewriter.notifyMatchFailure(op, "missing resource subtype");
+ }
+
+ int counter_size = tensorflow::GetCounterSize(alg);
+ int state_size = counter_size + tensorflow::RNG_KEY_SIZE;
+ RankedTensorType res_type =
+ RankedTensorType::get({state_size}, state_element_type);
+ if (res_type != GetResourceSubtype(rng_op.resource())) {
+ return rewriter.notifyMatchFailure(op, "unexpected resource subtype");
+ }
+
+ Location loc = op->getLoc();
+
+ // Read the state value from the resource.
+ Value state =
+ rewriter.create<ReadVariableOp>(loc, res_type, rng_op.resource());
+
+ // Extract the key and counter from the state.
+ RankedTensorType word_type = RankedTensorType::get({}, state_element_type);
+ auto unpacked = rewriter.create<UnpackOp>(
+ loc, SmallVector<Type, 4>(state_size, word_type), state, 0);
+ Value key = unpacked.getResult(counter_size);
+
+ SmallVector<Value, 4> counter;
+ for (int i = 0; i < counter_size; ++i) {
+ counter.push_back(unpacked.getResult(i));
+ }
+
+ // Set the increment to 256 * delta.
+ Type u64 = rewriter.getIntegerType(64, /*isSigned=*/false);
+ RankedTensorType u64_scalar = RankedTensorType::get({}, u64);
+ Value step_size = rewriter.create<ConstOp>(loc, GetScalarOfType(u64, 256));
+ Value increment =
+ rewriter.create<MulOp>(loc, u64_scalar, step_size, rng_op.delta());
+
+ // Increment the counter.
+ SmallVector<Value, 4> pack_args;
+ RankedTensorType word_u64_type = RankedTensorType::get({}, u64);
+ Value zero_u64 = rewriter.create<ConstOp>(loc, GetScalarOfType(u64, 0));
+ Value one_u64 = rewriter.create<ConstOp>(loc, GetScalarOfType(u64, 1));
+ for (int i = 0; i < counter_size; ++i) {
+ Value word = counter[i];
+ Value word_u64 = rewriter.create<CastOp>(loc, word_u64_type, word);
+ Value new_word_u64 = rewriter.create<AddV2Op>(loc, word_u64, increment);
+ Value new_word = rewriter.create<CastOp>(loc, word_type, new_word_u64);
+ pack_args.push_back(new_word);
+
+ Value overflow = rewriter.create<LessOp>(loc, new_word_u64, word_u64);
+ increment = rewriter.create<SelectV2Op>(loc, overflow, one_u64, zero_u64);
+ }
+
+ // Save the new state value to the resource.
+ pack_args.push_back(key);
+ Value new_state = rewriter.create<PackOp>(loc, res_type, pack_args);
+ rewriter.create<AssignVariableOp>(loc, rng_op.resource(), new_state);
+
+ // Pad the original state as necessary to fill the output shape.
+ int pad = tensorflow::RNG_MAX_COUNTER_SIZE - counter_size;
+ Type i64 = rewriter.getI64Type();
+ RankedTensorType paddings_ty = RankedTensorType::get({1, 2}, i64);
+ std::vector<int64_t> paddings_values = {0, pad};
+ Value paddings = rewriter.create<ConstOp>(
+ loc, DenseIntElementsAttr::get(paddings_ty, paddings_values));
+ Value output = rewriter.create<PadOp>(loc, op_type, state, paddings);
+
+ rewriter.replaceOp(op, output);
+ return success();
+ }
+};
+
#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_decompose_resource_ops.inc"
} // namespace
void PopulateDecomposeResourceOpsPatterns(MLIRContext *context,
OwningRewritePatternList *patterns) {
+ patterns->insert<DecomposeRngReadAndSkipOp>(context);
populateWithGenerated(context, *patterns);
}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc b/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc
index cba31b1..8d777c2 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc
@@ -44,7 +44,14 @@
template <typename Op>
LogicalResult RewriteEqOp(Operation* op, PatternRewriter& rewriter) const;
- LogicalResult RewriteOp(Operation* op, PatternRewriter& rewriter) const;
+ LogicalResult RewriteOp(
+ Operation* op, PatternRewriter& rewriter,
+ const std::function<bool(ArrayRef<int64_t>, ArrayRef<int64_t>,
+ SmallVectorImpl<int64_t>&)>&
+ get_broadcasted_shape) const;
+
+ LogicalResult RewriteBatchMatMulV2Op(Operation* op,
+ PatternRewriter& rewriter) const;
};
class BroadcastFoldPass : public PassWrapper<BroadcastFoldPass, FunctionPass> {
@@ -55,26 +62,78 @@
LogicalResult ConvertResultsBroadcastableShapeOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
if (op->hasTrait<OpTrait::ResultsBroadcastableShape>())
- return RewriteOp(op, rewriter);
+ return RewriteOp(op, rewriter, OpTrait::util::getBroadcastedShape);
// tf.Equal and tf.NotEqual ops only satisfy ResultsBroadcastableShape when
// incompatible_shape_error is `true` (what is also checked by the verifier).
if (succeeded(RewriteEqOp<TF::EqualOp>(op, rewriter))) return success();
if (succeeded(RewriteEqOp<TF::NotEqualOp>(op, rewriter))) return success();
+ if (succeeded(RewriteBatchMatMulV2Op(op, rewriter))) return success();
return failure();
}
+LogicalResult ConvertResultsBroadcastableShapeOp::RewriteBatchMatMulV2Op(
+ Operation* op, PatternRewriter& rewriter) const {
+ auto matmul_op = llvm::dyn_cast<TF::BatchMatMulV2Op>(op);
+ if (!matmul_op) return failure();
+
+ // Gets the broadcasted output shape for tf.BatchMatMulV2Op. `shape_x` is the
+ // shape of op's first/left-hand-side operand and `shape_y` is the shape of
+ // op's second/right-hand-side operand.
+ const auto get_broadcasted_shape =
+ [&](ArrayRef<int64_t> shape_x, ArrayRef<int64_t> shape_y,
+ SmallVectorImpl<int64_t>& result_shape) {
+ if (shape_x.size() < 2 || shape_y.size() < 2) {
+ return false;
+ }
+
+ // Checks outer dimensions (i.e., the dimensions higher than 2D) are
+ // broadcastable. If true, then get the broadcasted shape for outer
+ // dimension.
+ if (!OpTrait::util::getBroadcastedShape(
+ shape_x.drop_back(2), shape_y.drop_back(2), result_shape)) {
+ return false;
+ }
+
+ const int x_row =
+ matmul_op.adj_x() ? shape_x.back() : *(shape_x.rbegin() + 1);
+ const int x_col =
+ !matmul_op.adj_x() ? shape_x.back() : *(shape_x.rbegin() + 1);
+
+ const int y_row =
+ matmul_op.adj_y() ? shape_y.back() : *(shape_y.rbegin() + 1);
+ const int y_col =
+ !matmul_op.adj_y() ? shape_y.back() : *(shape_y.rbegin() + 1);
+
+ // Checks that matrix multiply can perform a valid contraction.
+ if (x_col != y_row) {
+ result_shape.clear();
+ return false;
+ }
+
+ result_shape.push_back(x_row);
+ result_shape.push_back(y_col);
+ return true;
+ };
+
+ return RewriteOp(op, rewriter, get_broadcasted_shape);
+}
+
template <typename Op>
LogicalResult ConvertResultsBroadcastableShapeOp::RewriteEqOp(
Operation* op, PatternRewriter& rewriter) const {
auto eq_op = llvm::dyn_cast_or_null<Op>(op);
- if (eq_op && eq_op.incompatible_shape_error()) return RewriteOp(op, rewriter);
+ if (eq_op && eq_op.incompatible_shape_error())
+ return RewriteOp(op, rewriter, OpTrait::util::getBroadcastedShape);
return failure();
}
LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp(
- Operation* op, PatternRewriter& rewriter) const {
+ Operation* op, PatternRewriter& rewriter,
+ const std::function<bool(ArrayRef<int64_t>, ArrayRef<int64_t>,
+ SmallVectorImpl<int64_t>&)>& get_broadcasted_shape)
+ const {
if (op->getNumOperands() != 2 || op->getResultTypes().size() != 1)
return failure();
@@ -102,12 +161,16 @@
.dyn_cast_or_null<RankedTensorType>();
if (!argument_type || !argument_type.hasStaticShape()) continue;
+ // Get the unbroadcasted shapes in the operand order.
+ std::array<llvm::ArrayRef<int64_t>, 2> operand_shapes;
+ operand_shapes[i] = broadcast_arg_type.getShape();
+ operand_shapes[1 - i] = argument_type.getShape();
+
// Check that the input of the broadcast and the other operand is broadcast
// compatible.
llvm::SmallVector<int64_t, 4> broadcasted_shape;
- if (!OpTrait::util::getBroadcastedShape(broadcast_arg_type.getShape(),
- argument_type.getShape(),
- broadcasted_shape))
+ if (!get_broadcasted_shape(operand_shapes[0], operand_shapes[1],
+ broadcasted_shape))
continue;
// Check that an implicit broadcast between the operand of the broadcast and
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc
index 400e942..12da234 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc
@@ -95,8 +95,11 @@
// Transform a functional IfOp to a region based IfRegionOp.
LogicalResult ConvertIfOp(IfOp if_op) {
Value cond = ConvertConditionToBoolean(if_op, if_op.cond());
- auto if_region = OpBuilder(if_op).create<TF::IfRegionOp>(
- if_op.getLoc(), if_op.getResultTypes(), cond, if_op.is_stateless());
+ OpBuilder builder(if_op);
+ auto if_region = builder.create<TF::IfRegionOp>(
+ if_op.getLoc(), if_op.getResultTypes(), cond, if_op.is_stateless(),
+ builder.getStringAttr(if_op.then_function().getName()),
+ builder.getStringAttr(if_op.else_function().getName()));
CopyDeviceAndUnderscoredAttributes(if_op, if_region);
CreateCall(if_op, if_op.then_function(),
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc
index 9ffdb8a..371908c 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc
@@ -24,7 +24,6 @@
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
@@ -382,12 +381,12 @@
}
// A struct to hold axes and sizes for a set of dimensions.
-struct DimensionSetVector {
- llvm::ArrayRef<int64_t> AxesArray() const { return axes.getArrayRef(); }
- llvm::ArrayRef<int64_t> SizesArray() const { return sizes.getArrayRef(); }
+struct DimensionVector {
+ llvm::ArrayRef<int64_t> AxesArray() const { return axes; }
+ llvm::ArrayRef<int64_t> SizesArray() const { return sizes; }
- llvm::SmallSetVector<int64_t, 4> axes;
- llvm::SmallSetVector<int64_t, 4> sizes;
+ llvm::SmallVector<int64_t, 4> axes;
+ llvm::SmallVector<int64_t, 4> sizes;
};
// A struct to hold information about dimensions of dot_general operands.
@@ -397,34 +396,32 @@
DenseIntElementsAttr contracting_dimensions) {
const int rank = type.getRank();
for (const int dim : batch_dimensions.getValues<int64_t>()) {
- batch_dimensions_.axes.insert(dim);
- batch_dimensions_.sizes.insert(type.getDimSize(dim));
+ batch_dimensions_.axes.push_back(dim);
+ batch_dimensions_.sizes.push_back(type.getDimSize(dim));
}
for (const int dim : contracting_dimensions.getValues<int64_t>()) {
- contracting_dimensions_.axes.insert(dim);
- contracting_dimensions_.sizes.insert(type.getDimSize(dim));
+ contracting_dimensions_.axes.push_back(dim);
+ contracting_dimensions_.sizes.push_back(type.getDimSize(dim));
}
for (int dim = 0; dim < rank; ++dim) {
- if (contracting_dimensions_.axes.count(dim) > 0 ||
- batch_dimensions_.axes.count(dim) > 0) {
+ if (llvm::count(contracting_dimensions_.axes, dim) > 0 ||
+ llvm::count(batch_dimensions_.axes, dim) > 0) {
continue;
}
- out_dimensions_.axes.insert(dim);
- out_dimensions_.sizes.insert(type.getDimSize(dim));
+ out_dimensions_.axes.push_back(dim);
+ out_dimensions_.sizes.push_back(type.getDimSize(dim));
}
}
- const DimensionSetVector &batch_dimensions() const {
- return batch_dimensions_;
- }
- const DimensionSetVector &contracting_dimensions() const {
+ const DimensionVector &batch_dimensions() const { return batch_dimensions_; }
+ const DimensionVector &contracting_dimensions() const {
return contracting_dimensions_;
}
// Out dimensions are any dimensions that are neither batch nor contracting
// dimensions, hence will be propagated to output shape.
- const DimensionSetVector &out_dimensions() const { return out_dimensions_; }
+ const DimensionVector &out_dimensions() const { return out_dimensions_; }
// Returns the total dimension size after flattening all contracting
// dimensions.
@@ -442,11 +439,11 @@
}
private:
- DimensionSetVector batch_dimensions_;
- DimensionSetVector contracting_dimensions_;
+ DimensionVector batch_dimensions_;
+ DimensionVector contracting_dimensions_;
// Out dimensions are any dimensions that are neither batch nor contracting
// dimensions, hence will be propagated to output shape.
- DimensionSetVector out_dimensions_;
+ DimensionVector out_dimensions_;
};
Value ConvertDot(PatternRewriter &rewriter, Value lhs, Value rhs,
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/merge_control_flow.cc b/tensorflow/compiler/mlir/tensorflow/transforms/merge_control_flow.cc
index 4a2b6b8..b9a61a1 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/merge_control_flow.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/merge_control_flow.cc
@@ -217,7 +217,8 @@
auto new_if_op = builder.create<TF::IfRegionOp>(
destination.getLoc(), merged_return_types, destination.cond(),
- destination.is_stateless() && source.is_stateless());
+ destination.is_stateless() && source.is_stateless(),
+ destination._then_func_nameAttr(), destination._else_func_nameAttr());
new_if_op.then_branch().push_back(new Block);
new_if_op.else_branch().push_back(new Block);
// Replace internal usages of merged if ops.
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc b/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc
index 90ba1e4..a84ceb2 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc
@@ -46,6 +46,9 @@
namespace {
+constexpr char kElseFuncNameAttr[] = "_else_func_name";
+constexpr char kThenFuncNameAttr[] = "_then_func_name";
+
struct RegionControlFlowToFunctional
: public TF::RegionControlFlowToFunctionalPassBase<
RegionControlFlowToFunctional> {
@@ -307,11 +310,25 @@
// Create 2 new functions with the input signature matching this order,
// and outline the `then` and `else` regions by moving the bodies of these
// regions into these functions. Replace tf.yield with a regular return.
- then_name = GetName(if_region, "_then");
+ if (if_region->hasAttrOfType<StringAttr>(kThenFuncNameAttr) &&
+ !if_region._then_func_nameAttr().getValue().empty()) {
+ then_name =
+ mapper.GetUniqueName(if_region._then_func_nameAttr().getValue())
+ .str();
+ } else {
+ then_name = GetName(if_region, "_then");
+ }
ExtractSingleBlockRegion(if_region.then_branch(), then_name, extern_values,
worklist, /*extern_values_passthrough=*/false);
- else_name = GetName(if_region, "_else");
+ if (if_region->hasAttrOfType<StringAttr>(kElseFuncNameAttr) &&
+ !if_region._else_func_nameAttr().getValue().empty()) {
+ else_name =
+ mapper.GetUniqueName(if_region._else_func_nameAttr().getValue())
+ .str();
+ } else {
+ else_name = GetName(if_region, "_else");
+ }
ExtractSingleBlockRegion(if_region.else_branch(), else_name, extern_values,
worklist, /*extern_values_passthrough=*/false);
}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc
index 96b821a..56e799a 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc
@@ -122,24 +122,28 @@
/*device_ordinal=*/builder.getI64IntegerAttr(0));
}
-// Creates a IfRegionOp with `predicate` and then/else region with yield op and
-// an empty block.
-TF::IfRegionOp CloneEmptyIfWithPredicate(Value predicate, bool is_stateless,
- Location loc, OpBuilder& builder) {
+// Clones an IfRegionOp 'if_region' and attributes and creates then/else regions
+// with yield op and an empty block.
+TF::IfRegionOp CloneEmptyIfWithPredicate(TF::IfRegionOp if_region,
+ OpBuilder& builder) {
auto host_side_if = builder.create<TF::IfRegionOp>(
- loc, llvm::SmallVector<Type, 4>{}, predicate, is_stateless);
+ if_region.getLoc(), llvm::SmallVector<Type, 4>{}, if_region.cond(),
+ if_region.is_stateless(), if_region._then_func_nameAttr(),
+ if_region._else_func_nameAttr());
// Create empty then branch region.
auto& then_branch = host_side_if.then_branch();
then_branch.push_back(new Block);
builder.setInsertionPointToEnd(&then_branch.front());
- builder.create<TF::YieldOp>(loc, /*operands=*/ArrayRef<Value>{});
+ builder.create<TF::YieldOp>(if_region.getLoc(),
+ /*operands=*/ArrayRef<Value>{});
// Create empty else branch region.
auto& else_branch = host_side_if.else_branch();
else_branch.push_back(new Block);
builder.setInsertionPointToEnd(&else_branch.front());
- builder.create<TF::YieldOp>(loc, /*operands=*/ArrayRef<Value>{});
+ builder.create<TF::YieldOp>(if_region.getLoc(),
+ /*operands=*/ArrayRef<Value>{});
return host_side_if;
}
// Creates a WhileRegionOp cond and body regions with yield op and
@@ -357,8 +361,7 @@
if (auto if_op = llvm::dyn_cast<TF::IfRegionOp>(op)) {
if (!HasOutsideCompilationNested(op)) return;
OpBuilder builder(if_op);
- auto host_if = CloneEmptyIfWithPredicate(
- if_op.cond(), if_op.is_stateless(), if_op.getLoc(), builder);
+ auto host_if = CloneEmptyIfWithPredicate(if_op, builder);
MoveOpsToHost(tpu_cluster, &if_op.then_branch().front(),
host_if.then_branch().front().getTerminator(),
compilation_key, device_ordinal, communication_key_index);
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
index a8ece6a..2d0654b 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
@@ -146,7 +146,8 @@
// Load dialects involved in the conversion
mlir::DialectRegistry registry;
mlir::RegisterAllTensorFlowDialects(registry);
- registry.loadAll(&context);
+ context.appendDialectRegistry(registry);
+ context.loadAllAvailableDialects();
}
// This class is used to generate new MLIR function name strings that are both
@@ -1479,6 +1480,13 @@
const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
const absl::InlinedVector<Node*, 4>& control_ret_nodes) {
+ // Store the arg/return attributes as a list rather than uniqueuing during
+ // construction.
+ llvm::SmallVector<mlir::NamedAttrList, 4> arg_attrs;
+ arg_attrs.resize(func.getNumArguments());
+ llvm::SmallVector<mlir::NamedAttrList, 4> ret_attrs;
+ ret_attrs.resize(func.getNumResults());
+
auto set_attributes_on_func = [&](Node* node, int64_t index, bool is_arg) {
for (const auto& node_attr : node->attrs()) {
const auto& key = node_attr.first;
@@ -1494,9 +1502,10 @@
ConvertAttributeValue(node_attr.second));
std::string dialect_attribute = "tf." + key;
if (is_arg) {
- func.setArgAttr(index, dialect_attribute, converted_attr);
+ arg_attrs[index].set(dialect_attribute, converted_attr);
} else {
func.setResultAttr(index, dialect_attribute, converted_attr);
+ ret_attrs[index].set(dialect_attribute, converted_attr);
}
}
return Status::OK();
@@ -1527,9 +1536,8 @@
control_use.getOwner()->eraseOperand(control_use.getOperandNumber());
if (!arg_node.node->requested_device().empty())
- func.setArgAttr(
- i, "tf.device",
- builder_.getStringAttr(arg_node.node->requested_device()));
+ arg_attrs[i].set("tf.device", builder_.getStringAttr(
+ arg_node.node->requested_device()));
if (arg_node.node->IsArg()) {
TF_RETURN_IF_ERROR(
@@ -1546,9 +1554,8 @@
auto* inst = node_values_[ret.node->id()];
if (ret.node->IsRetval()) {
if (!ret.node->requested_device().empty())
- func.setResultAttr(
- ret_and_idx.index(), "tf.device",
- builder_.getStringAttr(ret.node->requested_device()));
+ ret_attrs[ret_and_idx.index()].set(
+ "tf.device", builder_.getStringAttr(ret.node->requested_device()));
TF_RETURN_IF_ERROR(set_attributes_on_func(ret.node, ret_and_idx.index(),
/*is_arg=*/false));
// Lookup the instruction inside the island
@@ -1585,6 +1592,16 @@
builder_.setInsertionPointToEnd(bb);
builder_.create<mlir::ReturnOp>(mlir::UnknownLoc::get(context_),
graph_op.getResults());
+
+ func.setAllArgAttrs(
+ llvm::to_vector<4>(llvm::map_range(arg_attrs, [&](NamedAttrList& list) {
+ return list.getDictionary(context_);
+ })));
+ func.setAllResultAttrs(
+ llvm::to_vector<4>(llvm::map_range(ret_attrs, [&](NamedAttrList& list) {
+ return list.getDictionary(context_);
+ })));
+
return Status::OK();
}
@@ -3486,7 +3503,8 @@
// Moves the functions in `sub_module` to `module_` and skips the duplicate
// functions.
- void MoveConvertedFunctionsToModule(mlir::ModuleOp sub_module);
+ Status MoveConvertedFunctionsToModule(absl::string_view name,
+ mlir::ModuleOp sub_module);
GraphImportConfig::InputArrays ParseInputArrays(
llvm::ArrayRef<std::pair<std::string, TensorInfo>> inputs);
@@ -3526,14 +3544,34 @@
return results;
}
-void SavedModelSignatureDefImporterLite::MoveConvertedFunctionsToModule(
- mlir::ModuleOp sub_module) {
- // Iterate through all functions and insert the ones that do not already exist
- // in `module_`.
+Status SavedModelSignatureDefImporterLite::MoveConvertedFunctionsToModule(
+ absl::string_view name, mlir::ModuleOp sub_module) {
+ mlir::Builder builder(sub_module.getContext());
+ mlir::SymbolTable sub_module_symbol_table(sub_module);
+
+ // Prefix private functions with the unique signature name, so that it cannot
+ // collide with private functions used in the other signatures.
for (auto func : sub_module.getOps<mlir::FuncOp>()) {
- if (symbol_table_.lookup(func.getName())) continue;
+ if (mlir::tf_saved_model::IsExported(func)) continue;
+
+ std::string new_sym_name = absl::StrCat(name, "/", func.sym_name().str());
+ if (mlir::failed(sub_module_symbol_table.replaceAllSymbolUses(
+ func, new_sym_name, sub_module)))
+ return tensorflow::errors::InvalidArgument(absl::StrCat(
+ "SavedModelSignatureDefImporterLite: failed to assign a unique "
+ "name to the private function used in a signature: ",
+ func.sym_name().str()));
+
+ mlir::SymbolTable::setSymbolName(func, new_sym_name);
+ }
+
+ // Copy all functions used by this signature to the final MLIR module.
+ for (auto func : sub_module.getOps<mlir::FuncOp>()) {
+ DCHECK(symbol_table_.lookup(func.sym_name()) == nullptr);
symbol_table_.insert(func.clone());
}
+
+ return Status::OK();
}
Status SavedModelSignatureDefImporterLite::ConvertInitializer(
@@ -3574,9 +3612,7 @@
"__tf_saved_model_session_initializer_", target_node_name)}));
// Move the converted functions to top level MLIR module.
- MoveConvertedFunctionsToModule(*sub_module);
-
- return Status::OK();
+ return MoveConvertedFunctionsToModule(target_node_name, *sub_module);
}
StatusOr<mlir::OwningModuleRef>
@@ -3646,9 +3682,7 @@
}
// Move the converted functions to top level MLIR module.
- MoveConvertedFunctionsToModule(*sub_module);
-
- return Status::OK();
+ return MoveConvertedFunctionsToModule(sig_def_key, *sub_module);
}
GraphImportConfig::InputArrays
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h
index bd81cae..e028069 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h
+++ b/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h
@@ -47,6 +47,23 @@
});
}
+// Forward declare these passthrough ops.
+// TODO(jpienaar): Remove these and use trait instead.
+class IdentityOp;
+class IdentityNOp;
+
+// Returns if a value corresponds to a constant, returns the matched constant
+// as an attribute.
+template <typename AttrT>
+bool GetValueAsConstant(Value val, AttrT &attr) {
+ while (auto result = val.dyn_cast<OpResult>()) {
+ Operation *op = result.getOwner();
+ if (!isa<IdentityOp>(op) && !isa<IdentityNOp>(op)) break;
+ val = op->getOperand(result.getResultNumber());
+ }
+ return matchPattern(val, m_Constant(&attr));
+}
+
} // namespace TF
} // namespace mlir
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
index cc17993..59c647e 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
@@ -345,7 +345,9 @@
CreateConvertMlirToXlaHloPipeline(tf2xla, device_type,
custom_legalization_passes);
- if (VLOG_IS_ON(1)) {
+ if (VLOG_IS_ON(1))
+ tensorflow::DumpMlirOpToFile("legalize_hlo_before", module_op);
+ if (VLOG_IS_ON(2)) {
// Print the whole module after each pass which requires disabling
// multi-threading as well.
module_op.getContext()->disableMultithreading();
@@ -364,7 +366,7 @@
}
if (VLOG_IS_ON(1))
- tensorflow::DumpMlirOpToFile("mlir_compile_legalize_hlo", module_op);
+ tensorflow::DumpMlirOpToFile("legalize_hlo_after", module_op);
return Status::OK();
}
@@ -406,8 +408,8 @@
// Use arg_shapes to improve the mlir type information of `main` in module_op.
TF_RETURN_IF_ERROR(RefineShapes(arg_shapes, module_op));
- if (VLOG_IS_ON(1))
- tensorflow::DumpMlirOpToFile("mlir_compile_shape_refiner", module_op);
+ if (VLOG_IS_ON(2))
+ tensorflow::DumpMlirOpToFile("compile_mlir_shape_refiner", module_op);
if (!*shape_representation_fn)
*shape_representation_fn = IdentityShapeRepresentationFn();
@@ -422,8 +424,8 @@
llvm::StringRef device_type,
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
custom_legalization_passes) {
- if (VLOG_IS_ON(1))
- tensorflow::DumpMlirOpToFile("mlir_compile_before_build_hlo_tf", module_op);
+ if (VLOG_IS_ON(2))
+ tensorflow::DumpMlirOpToFile("build_hlo_tf_before", module_op);
XlaHelpers::ShapeRepresentationFn shape_representation_fn;
TF_RETURN_IF_ERROR(
@@ -434,8 +436,8 @@
returns, device_type,
custom_legalization_passes));
- if (VLOG_IS_ON(1))
- tensorflow::DumpMlirOpToFile("mlir_compile_after_build_hlo_tf", module_op);
+ if (VLOG_IS_ON(2))
+ tensorflow::DumpMlirOpToFile("build_hlo_tf_after", module_op);
return Status::OK();
}
@@ -491,8 +493,9 @@
XlaCompilationResult* compilation_result,
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
custom_legalization_passes) {
- mlir::MLIRContext mlir_context;
- RegisterDialects(mlir_context.getDialectRegistry());
+ mlir::DialectRegistry mlir_registry;
+ RegisterDialects(mlir_registry);
+ mlir::MLIRContext mlir_context(mlir_registry);
mlir::OwningModuleRef mlir_module;
TF_RETURN_IF_ERROR(
@@ -599,8 +602,12 @@
mlir::TF::StandardPipelineOptions tf_options;
mlir::TF::CreateTFStandardPipeline(pm, tf_options);
+ if (VLOG_IS_ON(1))
+ tensorflow::DumpMlirOpToFile("compile_graph_setup_before", module_op);
mlir::StatusScopedDiagnosticHandler diag_handler(module_op.getContext());
if (failed(pm.run(module_op))) return diag_handler.ConsumeStatus();
+ if (VLOG_IS_ON(1))
+ tensorflow::DumpMlirOpToFile("compile_graph_setup_after", module_op);
return Status::OK();
}
@@ -646,7 +653,9 @@
const GraphDebugInfo& debug_info,
mlir::MLIRContext* context,
mlir::OwningModuleRef* module) {
- RegisterDialects(context->getDialectRegistry());
+ mlir::DialectRegistry registry;
+ RegisterDialects(registry);
+ context->appendDialectRegistry(registry);
GraphImportConfig config;
config.graph_as_function = true;
config.control_outputs = control_rets;
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc
index ac6bc63..336f3a4 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc
@@ -371,7 +371,9 @@
}
auto str_attr = attr.cast<mlir::StringAttr>();
- RegisterMlirInputDialects(context->getDialectRegistry());
+ mlir::DialectRegistry registry;
+ RegisterMlirInputDialects(registry);
+ context->appendDialectRegistry(registry);
mlir::OwningModuleRef module_ref;
auto status =
DeserializeMlirModule(str_attr.getValue().str(), context, &module_ref);
diff --git a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc
index a5edc9f..4006c31 100644
--- a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc
+++ b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc
@@ -92,7 +92,7 @@
StringPiece tfr_raw_text, mlir::MLIRContext* mlir_ctx) {
mlir_ctx->allowUnregisteredDialects(/*allow=*/true);
// Load dialects involved in the conversion
- mlir::DialectRegistry& registry = mlir_ctx->getDialectRegistry();
+ mlir::DialectRegistry registry;
// clang-format off
registry.insert<mlir::StandardOpsDialect,
mlir::scf::SCFDialect,
@@ -102,7 +102,8 @@
mlir::tf_executor::TensorFlowExecutorDialect,
mlir::TFR::TFRDialect>();
// clang-format on
- registry.loadAll(mlir_ctx);
+ mlir_ctx->appendDialectRegistry(registry);
+ mlir_ctx->loadAllAvailableDialects();
// Load the TFR functions in a mlir::ModuleOp
auto memory_buffer = llvm::MemoryBuffer::getMemBuffer(
diff --git a/tensorflow/compiler/mlir/tfr/python/tfr_wrapper.cc b/tensorflow/compiler/mlir/tfr/python/tfr_wrapper.cc
index 59ef8c1..6c6bcf0 100644
--- a/tensorflow/compiler/mlir/tfr/python/tfr_wrapper.cc
+++ b/tensorflow/compiler/mlir/tfr/python/tfr_wrapper.cc
@@ -33,12 +33,12 @@
PYBIND11_MODULE(tfr_wrapper, m) {
m.def("verify", [](std::string input) {
- mlir::MLIRContext ctx;
- auto& registry = ctx.getDialectRegistry();
+ mlir::DialectRegistry registry;
registry.insert<mlir::scf::SCFDialect, mlir::TF::TensorFlowDialect,
mlir::StandardOpsDialect, mlir::shape::ShapeDialect,
mlir::TFR::TFRDialect>();
- ctx.getDialectRegistry().loadAll(&ctx);
+ mlir::MLIRContext ctx(registry);
+ ctx.loadAllAvailableDialects();
llvm::SourceMgr source_mgr = llvm::SourceMgr();
source_mgr.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(input),
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD
index 01ded63..87871d8 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD
@@ -32,6 +32,7 @@
name = "kernel_creator",
srcs = ["kernel_creator.cc"],
hdrs = ["kernel_creator.h"],
+ compatible_with = get_compatible_with_cloud(),
copts = if_cuda_is_configured(["-DGOOGLE_CUDA=1"]) + if_rocm_is_configured(["-DTENSORFLOW_USE_ROCM=1"]),
deps = [
"//tensorflow/compiler/mlir/hlo",
@@ -70,8 +71,11 @@
"@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:NVVMDialect",
+ "@llvm-project//mlir:NVVMToLLVMIRTranslation",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:ROCDLDialect",
+ "@llvm-project//mlir:ROCDLToLLVMIRTranslation",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:SCFToGPUPass",
"@llvm-project//mlir:SCFToStandard",
@@ -82,6 +86,7 @@
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:StandardOpsTransforms",
"@llvm-project//mlir:Support",
+ "@llvm-project//mlir:TargetLLVMIR",
"@llvm-project//mlir:Transforms",
],
)
@@ -113,6 +118,7 @@
"@llvm-project//mlir:ExecutionEngineUtils",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:TargetLLVMIR",
+ "@llvm-project//mlir:LLVMIRModuleTranslation",
] + if_llvm_system_z_available([
"@llvm-project//llvm:SystemZCodeGen", # fixdeps: keep
]) + if_llvm_aarch64_available([
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc
index e01ac78..9a12fa5 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc
@@ -118,6 +118,22 @@
}
}
+//===----------------------------------------------------------------------===//
+// MinimumBroadcastShapesOp
+//===----------------------------------------------------------------------===//
+template <>
+LogicalResult Verify<MinimumBroadcastShapesOp>(MinimumBroadcastShapesOp op) {
+ // Check that the number of operands matches the number of outputs.
+ unsigned result_shapes_count = op.results().size();
+ unsigned operand_shapes_count = op.shapes().size();
+ if (operand_shapes_count != result_shapes_count) {
+ return op.emitOpError()
+ << "number of operand shapes " << operand_shapes_count
+ << " does not match number of result shapes " << result_shapes_count;
+ }
+ return success();
+}
+
} // namespace tf_framework
} // namespace kernel_gen
} // namespace mlir
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td
index 67a4c75..2b0bd68 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td
@@ -168,4 +168,47 @@
let assemblyFormat = "$ctx `,` $error_code `,` $msg attr-dict";
}
+//===----------------------------------------------------------------------===//
+// MinimumBroadcastShapesOp
+//===----------------------------------------------------------------------===//
+def TFFramework_MinimumBroadcastShapesOp :
+ TFFramework_Op<"minimum_broadcast_shapes", [NoSideEffect]> {
+ let summary = "Minimizes the rank of two or more shapes to be broadcasted";
+ let description = [{
+ Given two or more 1D tensors representing shapes, returns one 1D tensor for
+ each operand, where operand `i` corresponds to output `i`.
+
+ The returned tensors have the property that they specify a shape which is a
+ reshape of the corresponding input shape, and the broadcasted output shape
+ (using shape::BroadcastOp) of the returned shapes is a reshape of the
+ broadcasted output shape of the input shapes. Among all possibilities with
+ this property, the one is chosen which minimizes the rank of each returned
+ shape.
+
+ The general idea of this op is that it can be used for ops which have a
+ broadcasting semantic to operate on shapes with a possibly smaller rank
+ while preserving equivalence of the computed values. After computing the
+ result of the op using reshaped operands, the result can be reshaped to the
+ result that would have been originally computed.
+
+ Here is an example with two input shapes:
+
+ ```mlir
+ tf_framework.minimum_broadcast_shapes [1, 2, 3, 1, 2, 1],
+ [1, 1, 1, 2, 3] -> [6, 2, 1], [2, 3]
+ ```
+
+ The broadcasted output shape of the operands is [1, 2, 3, 1, 2, 3], the
+ broadcasted output shape of the outputs is [6, 2, 3]. These two shapes are
+ reshapes of each other, and also each output is a reshape of the
+ corresponding input.
+ }];
+
+ let arguments = (ins Variadic<1DTensorOf<[Index]>>:$shapes);
+ let results = (outs Variadic<1DTensorOf<[Index]>>:$results);
+
+ let assemblyFormat = "$shapes attr-dict `:` type($shapes) `->` type($results)";
+
+}
+
#endif // TF_FRAMEWORK_OPS
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc
index 4784203..a96cca5 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc
@@ -32,6 +32,7 @@
#include "mlir/Dialect/GPU/Passes.h" // from @llvm-project
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
#include "mlir/Dialect/LLVMIR/NVVMDialect.h" // from @llvm-project
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" // from @llvm-project
#include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project
#include "mlir/Dialect/Linalg/Transforms/Transforms.h" // from @llvm-project
#include "mlir/Dialect/SCF/Passes.h" // from @llvm-project
@@ -45,6 +46,9 @@
#include "mlir/Parser.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
+#include "mlir/Target/LLVMIR.h" // from @llvm-project
+#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" // from @llvm-project
+#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" // from @llvm-project
#include "mlir/Transforms/Bufferize.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
@@ -66,6 +70,7 @@
namespace kernel_gen {
namespace {
+using mlir::Value;
using mlir::scf::ParallelOp;
using tensorflow::Status;
using xla::InternalError;
@@ -73,6 +78,43 @@
constexpr llvm::StringRef kGpuBinaryAttrName = "gpu.binary";
+/// Check if the size of the allocation is less than the given size. The
+/// transformation is only applied to small buffers since large buffers could
+/// exceed the stack space.
+bool IsSmallAlloc(Value alloc) {
+ constexpr unsigned kMaximumSizeInBytes = 64;
+ constexpr unsigned kBitwidthOfIndexType = 64;
+ constexpr unsigned kMaxRankOfAllocatedMemRef = 1;
+
+ auto type = alloc.getType().dyn_cast<mlir::ShapedType>();
+ if (!type || !alloc.getDefiningOp<mlir::AllocOp>()) return false;
+ if (!type.hasStaticShape()) {
+ // Check if the dynamic shape dimension of the alloc is produced by RankOp
+ // or SelectOp(_, RankOp, RankOp).
+ // If this is the case, it is likely to be small. Furthermore, the dimension
+ // is limited to the maximum rank of the allocated memref to avoid large
+ // values by multiplying several small values.
+ if (type.getRank() <= kMaxRankOfAllocatedMemRef) {
+ for (Value alloc_arg : alloc.getDefiningOp()->getOperands()) {
+ if (auto select = alloc_arg.getDefiningOp<mlir::SelectOp>()) {
+ if (!select.true_value().getDefiningOp<mlir::RankOp>() ||
+ !select.false_value().getDefiningOp<mlir::RankOp>())
+ return false;
+ } else if (!alloc_arg.getDefiningOp<mlir::RankOp>()) {
+ return false;
+ }
+ }
+ return true;
+ }
+ return false;
+ }
+ // For index types, use the provided size, as the type does not know.
+ unsigned int bitwidth = type.getElementType().isIndex()
+ ? kBitwidthOfIndexType
+ : type.getElementTypeBitWidth();
+ return type.getNumElements() * bitwidth <= kMaximumSizeInBytes * 8;
+}
+
// TODO(herhut): Remove this once leftover tensor_to_memref are handled in core.
struct RemoveUnusedTensorToMemrefOperations
: public mlir::PassWrapper<RemoveUnusedTensorToMemrefOperations,
@@ -152,7 +194,8 @@
};
Status LowerTFtoLoops(mlir::ModuleOp module, llvm::ArrayRef<int64_t> tile_sizes,
- llvm::ArrayRef<int64_t> unroll_factors) {
+ llvm::ArrayRef<int64_t> unroll_factors,
+ bool cpu_codegen) {
mlir::PassManager pm(module.getContext());
applyTensorflowAndCLOptions(pm);
@@ -197,10 +240,14 @@
// recognized as such.
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
- // Collapse and tile parallel loops.
- pm.addNestedPass<mlir::FuncOp>(std::make_unique<CollapseParallelLoopsTo1D>());
- pm.addNestedPass<mlir::FuncOp>(
- std::make_unique<TileLoops>(tile_sizes, unroll_factors));
+ if (!cpu_codegen) {
+ // Collapse and tile parallel loops. Collapsing shouldn't provide benefits
+ // to CPU and tiling is handled by vectorization.
+ pm.addNestedPass<mlir::FuncOp>(
+ std::make_unique<CollapseParallelLoopsTo1D>());
+ pm.addNestedPass<mlir::FuncOp>(
+ std::make_unique<TileLoops>(tile_sizes, unroll_factors));
+ }
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
if (failed(pm.run(module))) {
@@ -220,10 +267,6 @@
mlir::kernel_gen::transforms::CreateMapParallelLoopsPass());
}
- // Now lower the shape computations, bufferize all remaining ops and insert
- // deallocs.
- pm.addNestedPass<mlir::FuncOp>(::mlir::createBufferHoistingPass());
- pm.addNestedPass<mlir::FuncOp>(mlir::createCopyRemovalPass());
// Expand memref_reshape to its ranked form so that we can propagate
// scalars and avoid allocation.
pm.addNestedPass<mlir::FuncOp>(mlir::createStdExpandOpsPass());
@@ -245,13 +288,20 @@
// Longer term, this should be handled by proper device placement.
pm.addPass(mlir::kernel_gen::tf_framework::
CreateEmbedTFFrameworkFunctionAndAllocPass());
+ // Now lower the shape computations, bufferize all remaining ops and insert
+ // deallocs.
pm.addPass(mlir::kernel_gen::transforms::CreateFinalBufferizePass());
- pm.addNestedPass<mlir::FuncOp>(mlir::createPromoteBuffersToStackPass(64));
+ // TODO(herhut): Enable once no-longer broken.
+ // This depends on https://bugs.llvm.org/show_bug.cgi?id=49142 being fixed.
+ // pm.addNestedPass<mlir::FuncOp>(::mlir::createBufferHoistingPass());
+ pm.addNestedPass<mlir::FuncOp>(mlir::createPromoteBuffersToStackPass(
+ [](Value alloc) { return IsSmallAlloc(alloc); }));
// TODO(herhut): Depends on https://bugs.llvm.org/show_bug.cgi?id=48385.
// We also cannot properly free temporaries until
// https://llvm.discourse.group/t/remove-tight-coupling-of-the-bufferdeallocation-pass-to-std-and-linalg-operations/2162
// is resolved.
// pm.addNestedPass<mlir::FuncOp>(::mlir::createBufferDeallocationPass());
+ // pm.addNestedPass<mlir::FuncOp>(mlir::createCopyRemovalPass());
// Apply the mapping and go to GPU. We cannot do this earlier due to missing
// interfaces on the GPU dialect.
// TODO(b/174830459): Move up once implemented.
@@ -346,6 +396,7 @@
bool enable_ftz) {
mlir::PassManager pm(module.getContext());
applyTensorflowAndCLOptions(pm);
+ mlir::registerLLVMDialectTranslation(*module->getContext());
auto& kernel_pm = pm.nest<mlir::gpu::GPUModuleOp>();
// Remove debug information to ensure we do not create debug PTX.
@@ -381,12 +432,20 @@
llvm::ArrayRef<int64_t> tile_sizes, llvm::ArrayRef<int64_t> unroll_factors,
bool embed_memref_prints, bool generate_fatbin, bool print_ptx,
bool enable_ftz, bool cpu_codegen) {
- auto& registry = context.getDialectRegistry();
+ mlir::DialectRegistry registry;
mlir::RegisterAllTensorFlowDialects(registry);
registry.insert<mlir::chlo::HloClientDialect, mlir::mhlo::MhloDialect>();
+ registry.insert<mlir::NVVM::NVVMDialect, mlir::ROCDL::ROCDLDialect>();
+ registry.addDialectInterface<mlir::NVVM::NVVMDialect,
+ mlir::NVVMDialectLLVMIRTranslationInterface>();
+ registry.addDialectInterface<mlir::ROCDL::ROCDLDialect,
+ mlir::ROCDLDialectLLVMIRTranslationInterface>();
+ mlir::registerLLVMDialectTranslation(registry);
+ context.appendDialectRegistry(registry);
mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context);
- TF_RETURN_IF_ERROR(LowerTFtoLoops(module.get(), tile_sizes, unroll_factors));
+ TF_RETURN_IF_ERROR(
+ LowerTFtoLoops(module.get(), tile_sizes, unroll_factors, cpu_codegen));
TF_RETURN_IF_ERROR(
LowerLoopsToGPUorCPU(module.get(), embed_memref_prints, cpu_codegen));
if (!cpu_codegen) {
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/invalid.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/invalid.mlir
index 1d3d5e4..d1e7082 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/invalid.mlir
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/invalid.mlir
@@ -5,3 +5,12 @@
%buf = tf_framework.alloc(%ctx, %size) : memref<?x10x?xi8>
return
}
+
+// -----
+
+func @minimum_broadcast_shapes(%lhs: tensor<?xindex>, %rhs: tensor<?xindex>) {
+ // expected-error @+1{{number of operand shapes 2 does not match number of result shapes 1}}
+ %0 = tf_framework.minimum_broadcast_shapes %lhs, %rhs :
+ tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
+ return
+}
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/ops.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/ops.mlir
index 7ba69dc..6e67195 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/ops.mlir
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/ops.mlir
@@ -46,3 +46,11 @@
tf_framework.null_context : !tf_framework.op_kernel_context
return
}
+
+// CHECK-LABEL: func @minimum_broadcast_shapes
+func @minimum_broadcast_shapes(%lhs: tensor<?xindex>, %rhs: tensor<?xindex>)
+ -> (tensor<?xindex>, tensor<?xindex>) {
+ %0, %1 = tf_framework.minimum_broadcast_shapes %lhs, %rhs :
+ tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>, tensor<?xindex>
+ return %0, %1 : tensor<?xindex>, tensor<?xindex>
+}
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc
index 61f4efe..10751f5 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc
@@ -35,6 +35,7 @@
#include "mlir/ExecutionEngine/OptUtils.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Target/LLVMIR.h" // from @llvm-project
+#include "mlir/Target/LLVMIR/Export.h" // from @llvm-project
#include "tensorflow/compiler/mlir/init_mlir.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h"
#include "tensorflow/compiler/xla/util.h"
@@ -72,6 +73,7 @@
xla::StatusOr<std::string> EmitToBinary(mlir::ModuleOp module) {
// Translate the module.
llvm::LLVMContext llvm_context;
+ mlir::registerLLVMDialectTranslation(*module->getContext());
std::unique_ptr<llvm::Module> llvm_module =
mlir::translateModuleToLLVMIR(module, llvm_context);
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD
index 21b4324..6b12053 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD
@@ -1,4 +1,3 @@
-load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load("//third_party/mlir:tblgen.bzl", "gentbl")
load(
"//tensorflow/core/platform/default:cuda_build_defs.bzl",
@@ -19,6 +18,7 @@
name = "tf_framework_legalize_to_llvm",
srcs = ["tf_framework_legalize_to_llvm.cc"],
hdrs = ["rewriters.h"],
+ compatible_with = get_compatible_with_cloud(),
deps = [
"//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops",
"@llvm-project//llvm:Support",
@@ -35,6 +35,7 @@
name = "bufferize",
srcs = ["bufferize.cc"],
hdrs = ["rewriters.h"],
+ compatible_with = get_compatible_with_cloud(),
deps = [
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@@ -49,6 +50,7 @@
name = "embed_tf_framework",
srcs = ["embed_tf_framework.cc"],
hdrs = ["rewriters.h"],
+ compatible_with = get_compatible_with_cloud(),
deps = [
"//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops",
"@llvm-project//mlir:IR",
@@ -87,6 +89,7 @@
"tf_kernel_to_llvm_pass.cc",
],
hdrs = ["passes.h"],
+ compatible_with = get_compatible_with_cloud(),
copts = if_cuda_is_configured(["-DGOOGLE_CUDA=1"]) + if_rocm_is_configured(["-DTENSORFLOW_USE_ROCM=1"]),
deps = [
":bufferize",
@@ -118,11 +121,12 @@
"@llvm-project//mlir:Shape",
"@llvm-project//mlir:ShapeToStandard",
"@llvm-project//mlir:ShapeTransforms",
+ "@llvm-project//mlir:MathDialect",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:StandardOpsTransforms",
"@llvm-project//mlir:Support",
- "@llvm-project//mlir:TargetNVVMIR",
- "@llvm-project//mlir:TargetROCDLIR",
+ "@llvm-project//mlir:TargetLLVMIR",
+ "@llvm-project//mlir:LLVMIRModuleTranslation",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TensorTransforms",
"@llvm-project//mlir:Transforms",
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc
index 45ba5e0..22e2066 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc
@@ -24,6 +24,7 @@
#include "mlir/Dialect/Complex/IR/Complex.h" // from @llvm-project
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" // from @llvm-project
#include "mlir/Dialect/Linalg/Transforms/Transforms.h" // from @llvm-project
+#include "mlir/Dialect/Math/IR/Math.h" // from @llvm-project
#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project
#include "mlir/Dialect/SCF/Transforms.h" // from @llvm-project
#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
@@ -113,7 +114,8 @@
auto& context = getContext();
ConversionTarget target(context);
target.addLegalDialect<complex::ComplexDialect, lmhlo::LmhloDialect,
- StandardOpsDialect, tensor::TensorDialect>();
+ StandardOpsDialect, tensor::TensorDialect,
+ math::MathDialect>();
target.addIllegalDialect<mhlo::MhloDialect>();
CustomBufferizeTypeConverter converter;
@@ -161,10 +163,11 @@
void runOnOperation() override {
auto& context = getContext();
ConversionTarget target(context);
- target.addLegalDialect<
- complex::ComplexDialect, scf::SCFDialect, StandardOpsDialect,
- tensor::TensorDialect, tf_framework::TFFrameworkDialect, AffineDialect,
- shape::ShapeDialect, lmhlo::LmhloDialect, linalg::LinalgDialect>();
+ target.addLegalDialect<complex::ComplexDialect, scf::SCFDialect,
+ StandardOpsDialect, tensor::TensorDialect,
+ tf_framework::TFFrameworkDialect, AffineDialect,
+ shape::ShapeDialect, lmhlo::LmhloDialect,
+ linalg::LinalgDialect, math::MathDialect>();
target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp>();
target.addIllegalDialect<mhlo::MhloDialect>();
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc
index cb81002..184f34f 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc
@@ -15,8 +15,8 @@
#include "llvm/Transforms/Utils/Cloning.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
-#include "mlir/Target/NVVMIR.h" // from @llvm-project
-#include "mlir/Target/ROCDLIR.h" // from @llvm-project
+#include "mlir/Target/LLVMIR.h" // from @llvm-project
+#include "mlir/Target/LLVMIR/Export.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
@@ -90,9 +90,9 @@
}
llvm::LLVMContext llvmContext;
+ auto llvmModule = mlir::translateModuleToLLVMIR(gpu_module, llvmContext);
#if TENSORFLOW_USE_ROCM
- auto llvmModule = mlir::translateModuleToROCDLIR(gpu_module, llvmContext);
if (!llvmModule) {
return InternalError("Could not translate MLIR module to ROCDL IR");
}
@@ -143,7 +143,6 @@
return tensorflow::se::BundleGpuAsm(images, tensorflow::RocmRoot());
#elif GOOGLE_CUDA
- auto llvmModule = mlir::translateModuleToNVVMIR(gpu_module, llvmContext);
if (!llvmModule) {
return InternalError("Could not translate MLIR module to NVVM");
}
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc
index 7743b03..25088d5 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc
@@ -17,6 +17,7 @@
// structured control flow and descriptors.
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" // from @llvm-project
+#include "mlir/Dialect/Math/IR/Math.h" // from @llvm-project
#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project
#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
#include "mlir/Dialect/Shape/Transforms/Passes.h" // from @llvm-project
@@ -50,6 +51,7 @@
target.addIllegalDialect<shape::ShapeDialect>();
target.addLegalDialect<scf::SCFDialect>();
target.addLegalDialect<StandardOpsDialect>();
+ target.addLegalDialect<math::MathDialect>();
target.addLegalDialect<tensor::TensorDialect>();
// Don't mark the primary Cstr/Assuming ops as illegal, so they can be
// lowered at a later time to assertions.
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_kernel_to_llvm_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_kernel_to_llvm_pass.cc
index 4a764af..7eeaf43 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_kernel_to_llvm_pass.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_kernel_to_llvm_pass.cc
@@ -24,6 +24,7 @@
#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
#include "mlir/Dialect/LLVMIR/LLVMTypes.h" // from @llvm-project
+#include "mlir/Dialect/Math/IR/Math.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/Transforms/Passes.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
@@ -263,9 +264,9 @@
// Set target.
ConversionTarget target(*ctx);
target.addLegalDialect<LLVM::LLVMDialect>();
- target
- .addIllegalDialect<StandardOpsDialect, complex::ComplexDialect,
- gpu::GPUDialect, tf_framework::TFFrameworkDialect>();
+ target.addIllegalDialect<StandardOpsDialect, complex::ComplexDialect,
+ gpu::GPUDialect, tf_framework::TFFrameworkDialect,
+ math::MathDialect>();
target.addIllegalOp<LLVM::DialectCastOp>();
// Mark modules as legal.
target.addLegalOp<ModuleOp, ModuleTerminatorOp, gpu::GPUModuleOp>();
diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir
index a8cea22..535f1dd 100644
--- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir
+++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir
@@ -81,6 +81,13 @@
return %0 : tensor<13x21x3xf32>
}
+// CHECK-LABEL: test_gather
+// CHECK: tosa.gather
+func @test_gather(%arg0: tensor<100x25xf32>, %arg1: tensor<1x20xi32>) -> tensor<20x25x3xf32> {
+ %0 = "tfl.gather"(%arg0, %arg1) {axis = 0 : i32} : (tensor<100x25xf32>, tensor<1x20xi32>) -> tensor<20x25x3xf32>
+ return %0 : tensor<20x25x3xf32>
+}
+
// -----
// CHECK-LABEL: test_sub
diff --git a/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td b/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td
index 4314591..30cfc59 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td
+++ b/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td
@@ -49,6 +49,6 @@
def : Pat<(TFL_GatherOp $params,
$indices,
$axis),
- (Tosa_GatherOp $params,
- $indices,
+ (Tosa_GatherOp $indices,
+ $params,
$axis)>;
diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
index 9ff44c0..90d74ef 100644
--- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
+++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
@@ -134,7 +134,8 @@
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
- output_operand_aliasing) {
+ output_operand_aliasing,
+ const Literal* literal) {
if (operand_shapes_with_layout.has_value())
return Unimplemented(
"CustomCall doesn't support operands shapes with layout");
@@ -142,6 +143,8 @@
shape, builder_));
TF_RET_CHECK(output_operand_aliasing.empty())
<< "MLIR CustomCallOp does not support output_operand_aliasing yet";
+ TF_RET_CHECK(literal == nullptr)
+ << "MLIR CustomCallOp does not support literal yet";
auto op = builder_.create<mlir::mhlo::CustomCallOp>(
loc_, ty, GetValues(operands), builder_.getStringAttr(call_target_name),
/*has_side_effect=*/builder_.getBoolAttr(has_side_effect),
diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
index cc95b58..2935089 100644
--- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
+++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
@@ -137,7 +137,8 @@
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
- output_operand_aliasing) override;
+ output_operand_aliasing,
+ const Literal* literal) override;
StatusOr<XlaOp> ReduceInternal(
const Shape& shape, absl::Span<const XlaOp> all_operands,
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir
index 5a07d93..0eeec56 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir
@@ -9,7 +9,7 @@
// CHECK-SAME: ([[LHS:%.*]]: tensor<1x4x2xf32>, [[RHS:%.*]]: tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
// CHECK: [[LHSSHAPE:%.*]] = shape.shape_of [[LHS]] : tensor<1x4x2xf32>
// CHECK: [[RHSSHAPE:%.*]] = shape.shape_of [[RHS]] : tensor<3x2x4xf32>
-// CHECK: [[CM2:%.*]] = constant -2 : i32
+// CHECK: [[CM2:%.*]] = constant -2 : index
// CHECK: [[LHSHEAD:%.*]], [[LHSTAIL:%.*]] = "shape.split_at"([[LHSSHAPE]], [[CM2]])
// CHECK: [[RHSHEAD:%.*]], [[RHSTAIL:%.*]] = "shape.split_at"([[RHSSHAPE]], [[CM2]])
// CHECK: [[BCASTHEAD:%.*]] = shape.broadcast [[LHSHEAD]], [[RHSHEAD]]
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index a3165ba..f0e66c6 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -2774,7 +2774,7 @@
Value lhs_shape = rewriter->create<shape::ShapeOfOp>(loc, lhs);
Value rhs_shape = rewriter->create<shape::ShapeOfOp>(loc, rhs);
Value const_neg2 =
- rewriter->create<ConstantOp>(loc, rewriter->getI32IntegerAttr(-2));
+ rewriter->create<ConstantOp>(loc, rewriter->getIndexAttr(-2));
auto lhs_splitted =
rewriter->create<shape::SplitAtOp>(loc, lhs_shape, const_neg2);
auto rhs_splitted =
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
index ae0c18a..d34f40c 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
@@ -241,6 +241,7 @@
TypeID::get<TF::StatelessRandomNormalV2Op>(),
TypeID::get<TF::StatelessRandomUniformOp>(),
TypeID::get<TF::StatelessRandomUniformFullIntOp>(),
+ TypeID::get<TF::StatelessRandomUniformFullIntV2Op>(),
TypeID::get<TF::StatelessRandomUniformV2Op>(),
TypeID::get<TF::StatelessRandomUniformIntOp>(),
TypeID::get<TF::StatelessRandomUniformIntV2Op>(),
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index b97dd79..0ad1e7f 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -1277,6 +1277,7 @@
name = "stateful_random_ops_test",
size = "medium",
srcs = ["stateful_random_ops_test.py"],
+ enable_mlir_bridge = True,
python_version = "PY3",
shard_count = 10,
tags = [
diff --git a/tensorflow/compiler/tests/stateful_random_ops_test.py b/tensorflow/compiler/tests/stateful_random_ops_test.py
index 239b99d..f310b21 100644
--- a/tensorflow/compiler/tests/stateful_random_ops_test.py
+++ b/tensorflow/compiler/tests/stateful_random_ops_test.py
@@ -31,6 +31,7 @@
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.kernel_tests.random import util as \
random_test_util
from tensorflow.python.ops import gen_stateful_random_ops
@@ -76,6 +77,7 @@
gen.uniform_full_int(shape=(3,))
@parameterized.parameters(ALGS)
+ @test_util.disable_mlir_bridge("TODO(b/180412889): Crashes with MLIR bridge.")
def testDefun(self, alg):
"""Test for defun."""
with ops.device(xla_device_name()):
@@ -248,6 +250,36 @@
shape=shape, dtype=dtype))
self.assertAllEqual(cpu, xla)
+ def testXLAEqualsCPUAroundCounterOverflow(self):
+ """Tests XLA and CPU kernels generate the same integers in overflow case.
+
+ Specifically this tests the case where the counter is incremented past
+ what can fit within 64 bits of the 128 bit Philox counter.
+ """
+ dtype = dtypes.uint64
+ seed = 2**64 - 10
+ shape = [315, 49]
+ if compat.forward_compatible(2020, 10, 25):
+ with ops.device("/device:CPU:0"):
+ cpu_gen = random.Generator.from_seed(
+ seed=seed, alg=random.RNG_ALG_PHILOX)
+ with ops.device(xla_device_name()):
+ xla_gen = random.Generator.from_seed(
+ seed=seed, alg=random.RNG_ALG_PHILOX)
+ # Repeat multiple times to make sure that the state after
+ # number-generation are the same between CPU and XLA.
+ for _ in range(5):
+ with ops.device("/device:CPU:0"):
+ # Test both number-generation and skip
+ cpu = cpu_gen.uniform_full_int(shape=shape, dtype=dtype)
+ cpu_gen.skip(100)
+ with ops.device(xla_device_name()):
+ xla = xla_gen.uniform_full_int(shape=shape, dtype=dtype)
+ xla_gen.skip(100)
+ self.assertAllEqual(cpu, xla)
+ self.assertAllEqual(cpu_gen.state, xla_gen.state)
+ self.assertAllEqual(cpu, xla)
+
def _testRngIsNotConstant(self, rng, dtype):
# Tests that 'rng' does not always return the same value.
# The random-number generator, if working correctly, should produce the
@@ -352,6 +384,8 @@
mean_atol=2e-3, median_atol=4e-3,
variance_rtol=1e-2 if dtype == dtypes.bfloat16 else 5e-3)
+ @test_util.disable_mlir_bridge(
+ "b/180412086: MLIR bridge gives wrong error messages.")
def testErrors(self):
"""Tests that proper errors are raised.
"""
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 3132959..5d55175 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -558,6 +558,7 @@
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:framework",
"@com_google_absl//absl/types:span",
+ "@llvm-project//llvm:Support",
],
alwayslink = 1,
)
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index cec7b9a..56eb5b3 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -275,8 +275,10 @@
srcs = ["if_while_utils.cc"],
hdrs = ["if_while_utils.h"],
deps = [
+ "//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/core:lib",
],
)
diff --git a/tensorflow/compiler/tf2xla/kernels/arg_op.cc b/tensorflow/compiler/tf2xla/kernels/arg_op.cc
index 22e5586..065feb0 100644
--- a/tensorflow/compiler/tf2xla/kernels/arg_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/arg_op.cc
@@ -13,12 +13,14 @@
limitations under the License.
==============================================================================*/
+#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -60,6 +62,17 @@
errors::InvalidArgument("Invalid/missing argument expression"));
if (ctx->expected_output_dtype(0) == DT_VARIANT) {
ctx->SetTensorListOutput(0, arg.handle());
+ } else if (arg.value_bound().has_value()) {
+ // The argument has a bound attached to it, call SetBound op on the
+ // argument.
+ xla::XlaBuilder* builder = ctx->builder();
+ auto input_op = arg.AsXlaOp(builder);
+ xla::Literal bound = HostTensorToLiteral(*arg.value_bound()).ValueOrDie();
+ ctx->SetOutput(
+ 0, xla::CustomCall(builder, "SetBound", {input_op},
+ builder->GetShape(input_op).ValueOrDie(), "",
+ false, {}, &bound));
+ return;
} else {
ctx->SetOutputExpression(0, arg);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/if_while_utils.cc b/tensorflow/compiler/tf2xla/kernels/if_while_utils.cc
index 82d8eb8..8ee5197 100644
--- a/tensorflow/compiler/tf2xla/kernels/if_while_utils.cc
+++ b/tensorflow/compiler/tf2xla/kernels/if_while_utils.cc
@@ -16,6 +16,8 @@
#include "tensorflow/compiler/tf2xla/kernels/if_while_utils.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
+#include "tensorflow/compiler/tf2xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
namespace tensorflow {
@@ -38,11 +40,28 @@
xla::StatusOr<absl::optional<Tensor>> maybe_constant =
expression.ResolveConstant(ctx->compiler()->client());
if (maybe_constant.ok() && maybe_constant.ValueOrDie().has_value()) {
- arg->kind = XlaCompiler::Argument::kConstant;
- arg->type = expression.dtype();
- arg->constant_value = std::move(maybe_constant.ValueOrDie().value());
- arg->shape = expression.GetShape().ValueOrDie();
- resolved_constant_idxs.push_back(i);
+ xla::StatusOr<Tensor> values_are_dynamic =
+ expression.ResolveDynamism(ctx->compiler()->client());
+ bool all_values_are_static = false;
+ if (!values_are_dynamic.ok()) {
+ // Conservatiely assume all values are dynamic.
+ all_values_are_static = true;
+ } else {
+ xla::Literal literal =
+ HostTensorToLiteral(values_are_dynamic.ValueOrDie()).ValueOrDie();
+ all_values_are_static = literal.IsAll(0);
+ }
+
+ if (all_values_are_static) {
+ arg->kind = XlaCompiler::Argument::kConstant;
+ arg->type = expression.dtype();
+ arg->constant_value = std::move(maybe_constant.ValueOrDie().value());
+ arg->shape = expression.GetShape().ValueOrDie();
+ resolved_constant_idxs.push_back(i);
+ } else {
+ arg->value_bound.emplace(
+ std::move(maybe_constant.ValueOrDie().value()));
+ }
}
}
}
diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
index 99f5010..22ade14 100644
--- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
@@ -25,6 +25,7 @@
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
@@ -97,10 +98,11 @@
bound_shape.DebugString()));
int64 bound;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar("bound", &bound));
-
- xla::XlaOp result = xla::CustomCall(
- ctx->builder(), "SetBound", {ctx->Input("input")},
- ctx->InputXlaShape("input").ValueOrDie(), absl::StrFormat("%d", bound));
+ xla::Literal bound_literal = xla::LiteralUtil::CreateR0<int32>(bound);
+ xla::XlaOp result =
+ xla::CustomCall(ctx->builder(), "SetBound", {ctx->Input("input")},
+ ctx->InputXlaShape("input").ValueOrDie(), "", false, {},
+ &bound_literal);
ctx->SetOutput(0, result);
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
index d5e7577..0da48c3 100644
--- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
@@ -68,11 +68,12 @@
errors::InvalidArgument(
"Strides have to be one when inputs are not constant."));
}
- // Infer static output shape, reconsile unknown dimension with input dim
+ // Infer static output shape, reconcile unknown dimension with input dim
// size.
for (int64 i = 0; i < partial_final_shape.dims(); ++i) {
if (partial_final_shape.dim_size(i) == -1) {
- // Use input shape shape_spec.
+ // Use input shape to update unknown dimension of partial shape -- if a
+ // dimension is unknown, we use input shape as bound.
partial_final_shape.set_dim(
i,
input_shape.dim_size(shape_spec.output_to_processing_mapping[i]));
@@ -87,7 +88,8 @@
", output shape must be a compile-time constant"));
for (int64 i = 0; i < partial_processing_shape.dims(); ++i) {
if (partial_processing_shape.dim_size(i) == -1) {
- // Use input shape shape_spec.
+ // Use input shape to update unknown dimension of partial shape -- if a
+ // dimension is unknown, we use input shape as bound.
partial_processing_shape.set_dim(i, input_shape.dim_size(i));
}
}
@@ -408,15 +410,23 @@
"one when inputs are not constant."));
}
- auto zero = XlaHelpers::Zero(ctx->builder(), ctx->expected_output_dtype(0));
- zero = xla::Broadcast(zero, input_shape.dim_sizes());
xla::XlaOp grad = ctx->Input(4);
xla::Shape grad_shape = ctx->InputXlaShape(4).ValueOrDie();
- // Undo any new/shrink axes.
VLOG(1) << "xla grad shape" << grad_shape;
+ VLOG(1) << "xla final_shape" << final_shape;
VLOG(1) << "input_shape" << input_shape.DebugString();
- std::vector<xla::XlaOp> begins(processing_shape.dims(),
- xla::Zero(ctx->builder(), xla::S32));
+ auto input_sizes = input_shape.dim_sizes();
+ // For unknown output dim the bound of the output shape is input. Pad and
+ // double the size of input shape to leave enough buffer to avoid OOB
+ // dynamic update slice.
+ auto input_sizes_padded = input_shape.dim_sizes();
+ bool need_padding = false;
+ for (int64 i = 0; i < processing_shape.dims(); ++i) {
+ if (processing_shape.dim_size(i) == -1) {
+ input_sizes_padded[i] *= 2;
+ need_padding = true;
+ }
+ }
for (int64 i = 0; i < grad_shape.rank(); ++i) {
// Use grad shape, which is known, to update unknown processing shape.
// Grad shape is the output of the ValidateStridedSliceOp function in
@@ -425,26 +435,44 @@
processing_shape.set_dim(shape_spec.output_to_processing_mapping[i],
grad_shape.dimensions(i));
}
-
- // Similarly, use output_to_sparse_mapping to find out corresponding
- // begin dim of the output, as indices for dynamic update slice.
- int64 begin_dim = shape_spec.output_to_sparse_mapping[i];
- if (begin_dim != -1) {
- auto begin_index =
- xla::Slice(ctx->Input(1), {begin_dim}, {begin_dim + 1}, {1});
- auto begin_index_scalar = xla::Reshape(
- xla::ShapeUtil::MakeScalarShape(xla::S32), begin_index);
- begins[shape_spec.output_to_sparse_mapping[i]] = begin_index_scalar;
- }
}
- VLOG(1) << "processing_shape" << processing_shape.DebugString();
- TensorShape full_processing_shape;
- OP_REQUIRES(ctx, processing_shape.AsTensorShape(&full_processing_shape),
- errors::InvalidArgument(
- "Processing shape ", processing_shape.DebugString(),
- " can't be fully inferred from grad shape"));
- grad = xla::Reshape(grad, full_processing_shape.dim_sizes());
+
+ std::vector<xla::XlaOp> begins;
+ begins.reserve(processing_shape.dims());
+ for (int64 i = 0; i < input_shape.dims(); ++i) {
+ bool begin_mask = (1 << i) & shape_spec.begin_dense_mask;
+ // Similarly, use processing_to_sparse_mapping to find out corresponding
+ // begin dim of the gradient, as indices for dynamic update slice.
+ int64 begin_dim = shape_spec.processing_to_sparse_mapping[i];
+ xla::XlaOp begin_index;
+ auto zero = xla::Zero(ctx->builder(), ctx->InputXlaType("begin"));
+ if (begin_mask) {
+ begin_index = zero;
+ } else {
+ xla::XlaOp dim_size = xla::Slice(ctx->Input(0), {i}, {i + 1}, {1});
+ dim_size = xla::Reshape(dim_size, {});
+ begin_index =
+ xla::Slice(ctx->Input(1), {begin_dim}, {begin_dim + 1}, {1});
+ begin_index = xla::Reshape(begin_index, {});
+ auto index_negative = xla::Lt(begin_index, zero);
+ auto wrapped_index = xla::Add(dim_size, begin_index);
+ // Wrap negative indices around.
+ begin_index = xla::Select(index_negative, wrapped_index, begin_index);
+ }
+ begins.push_back(begin_index);
+ }
+ auto zero = XlaHelpers::Zero(ctx->builder(), ctx->expected_output_dtype(0));
+
+ zero = xla::Broadcast(zero, input_sizes_padded);
+ grad = xla::Reshape(grad, processing_shape.dim_sizes());
grad = xla::DynamicUpdateSlice(zero, grad, begins);
+ if (need_padding) {
+ // We padded the input shape to avoid OOB when DUS. Now slice out the
+ // padding in the final result.
+ std::vector<int64> strides(input_shape.dims(), 1);
+ std::vector<int64> start_indices(input_shape.dims(), 0);
+ grad = xla::Slice(grad, start_indices, input_sizes, strides);
+ }
ctx->SetOutput(0, grad);
}
void Compile(XlaOpKernelContext* ctx) override {
diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc
index 273718e..f58001f 100644
--- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc
+++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc
@@ -36,15 +36,32 @@
namespace {
-// Checks if the module has any TPU devices in its device list.
-bool HasTPUDevice(mlir::ModuleOp op) {
- mlir::TF::RuntimeDevices devices;
- if (failed(GetDevicesFromOp(op.getOperation(), &devices))) return false;
+constexpr char kTPUReplicateAttr[] = "_tpu_replicate";
- for (const auto& device : devices.device_names()) {
- if (device.has_type && device.type == "TPU") return true;
- }
- return false;
+bool HasTPUDevice(mlir::ModuleOp module) {
+ mlir::TF::RuntimeDevices devices;
+ if (failed(GetDevicesFromOp(module.getOperation(), &devices))) return false;
+ return absl::c_any_of(
+ devices.device_names(),
+ [](const tensorflow::DeviceNameUtils::ParsedName& device) {
+ return device.has_type && device.type == "TPU";
+ });
+}
+
+bool HasTPUOp(mlir::ModuleOp module) {
+ auto walk_result = module.walk([&](mlir::Operation* op) {
+ auto replicate_attr =
+ op->getAttrOfType<mlir::StringAttr>(kTPUReplicateAttr);
+ if (replicate_attr) return mlir::WalkResult::interrupt();
+ return mlir::WalkResult::advance();
+ });
+ return walk_result.wasInterrupted();
+}
+
+// Checks that the module has both - TPU devices in its device list and contains
+// TPU ops (identifed by `_tpu_replicate` attribute on ops).
+bool HasTPUDevicesAndOps(mlir::ModuleOp module) {
+ return HasTPUDevice(module) && HasTPUOp(module);
}
bool HasTPUDevice(const DeviceSet& device_set) {
@@ -74,8 +91,10 @@
return MlirOptimizationPassState::Disabled;
}
- MlirBridgeRolloutPolicy policy =
- GetMlirBridgeRolloutPolicy(graph, config_proto);
+ // We set `uses_uninitialized_resource_args` to false here because the first
+ // phase of the bridge is not affected by uninitialized resource args.
+ MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy(
+ graph, config_proto, /*uses_uninitialized_resource_args=*/false);
switch (policy) {
case MlirBridgeRolloutPolicy::kEnabledByUser:
return MlirOptimizationPassState::Enabled;
@@ -104,9 +123,9 @@
return Status::OK();
}
- // Skip MLIR TPU Bridge if no TPU devices found.
- if (!HasTPUDevice(module)) {
- VLOG(1) << "Skipping MLIR TPU Bridge, no TPU devices found";
+ // Skip MLIR TPU Bridge if no TPU devices or TPU ops found.
+ if (!HasTPUDevicesAndOps(module)) {
+ VLOG(1) << "Skipping MLIR TPU Bridge, no TPU devices or TPU ops found";
return Status::OK();
}
@@ -127,8 +146,10 @@
// Do not run the bridge if it's enabled by the graph analysis,
// only run if it's enabled by the user explicitly.
- MlirBridgeRolloutPolicy policy =
- GetMlirBridgeRolloutPolicy(graph, config_proto);
+ // We set `uses_uninitialized_resource_args` to false here because the first
+ // phase of the bridge is not affected by uninitialized resource args.
+ MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy(
+ graph, config_proto, /*uses_uninitialized_resource_args=*/false);
return policy == MlirBridgeRolloutPolicy::kEnabledByUser;
}
@@ -146,9 +167,10 @@
return Status::OK();
}
- // Skip MLIR TPU Bridge if no TPU devices found.
- if (!HasTPUDevice(module)) {
- VLOG(1) << "Skipping MLIR TPU Bridge V1 Compat, no TPU devices found";
+ // Skip MLIR TPU Bridge if no TPU devices or TPU ops found.
+ if (!HasTPUDevicesAndOps(module)) {
+ VLOG(1) << "Skipping MLIR TPU Bridge V1 Compat, no TPU devices or TPU ops "
+ "found";
return Status::OK();
}
diff --git a/tensorflow/compiler/tf2xla/xla_argument.cc b/tensorflow/compiler/tf2xla/xla_argument.cc
index fe31025..8b91dd3 100644
--- a/tensorflow/compiler/tf2xla/xla_argument.cc
+++ b/tensorflow/compiler/tf2xla/xla_argument.cc
@@ -15,6 +15,8 @@
#include "tensorflow/compiler/tf2xla/xla_argument.h"
+#include "llvm/ADT/STLExtras.h"
+
namespace tensorflow {
bool XlaArgument::operator==(const XlaArgument& other) const {
@@ -50,4 +52,10 @@
return constant_value.tensor_data() == other.constant_value.tensor_data();
}
+bool AnyUninitializedResourceArg(absl::Span<const XlaArgument> args) {
+ return llvm::any_of(args, [](const XlaArgument& arg) {
+ return arg.kind == XlaArgument::kResource && arg.type == DT_INVALID;
+ });
+}
+
} // end namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_argument.h b/tensorflow/compiler/tf2xla/xla_argument.h
index c304c47..87509c0 100644
--- a/tensorflow/compiler/tf2xla/xla_argument.h
+++ b/tensorflow/compiler/tf2xla/xla_argument.h
@@ -16,6 +16,8 @@
#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_ARGUMENT_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_ARGUMENT_H_
+#include <optional>
+
#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
#include "tensorflow/compiler/tf2xla/xla_resource.h"
@@ -75,6 +77,9 @@
// host-memory tensor.
Tensor constant_value;
+ // The upper bounds of the value.
+ absl::optional<Tensor> value_bound;
+
// The name of this argument, used for debugging.
string name;
@@ -119,6 +124,9 @@
string ShapeHumanString() const;
};
+// Returns true if any of `args` is an uninitialized resource variable.
+bool AnyUninitializedResourceArg(absl::Span<const XlaArgument> args);
+
} // end namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_ARGUMENT_H_
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index c04d1e0..7c41e77 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -804,8 +804,9 @@
}
VLOG(1) << "====================================================";
- MlirBridgeRolloutPolicy policy =
- GetMlirBridgeRolloutPolicy(*graph, config_proto);
+ MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy(
+ *graph, config_proto,
+ /*uses_uninitialized_resource_args=*/AnyUninitializedResourceArg(args));
if (policy == MlirBridgeRolloutPolicy::kEnabledByUser) {
VLOG(1) << "Using MLIR bridge";
GraphDebugInfo debug_info;
@@ -1158,6 +1159,10 @@
xla::Reshape(arg_handles[i], arg.DimensionSizes()), arg.type);
} else {
arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type);
+ if (arg.value_bound) {
+ // Propagate upper bound to arg_expression.
+ arg_expression.set_value_bound(arg.value_bound.value());
+ }
}
break;
case XlaCompiler::Argument::kTensorList: {
diff --git a/tensorflow/compiler/tf2xla/xla_expression.cc b/tensorflow/compiler/tf2xla/xla_expression.cc
index 40b154b..498b3f8 100644
--- a/tensorflow/compiler/tf2xla/xla_expression.cc
+++ b/tensorflow/compiler/tf2xla/xla_expression.cc
@@ -170,7 +170,9 @@
TF_ASSIGN_OR_RETURN(bool is_constant,
handle().builder()->IsConstant(handle()));
- if (!is_constant) return {absl::nullopt};
+ if (!is_constant) {
+ return {absl::nullopt};
+ }
if (!client)
return errors::InvalidArgument("client is required to resolve constant");
diff --git a/tensorflow/compiler/tf2xla/xla_expression.h b/tensorflow/compiler/tf2xla/xla_expression.h
index fd6b311..408afef 100644
--- a/tensorflow/compiler/tf2xla/xla_expression.h
+++ b/tensorflow/compiler/tf2xla/xla_expression.h
@@ -94,6 +94,13 @@
return constant_value_;
}
+ // Set the bound of the expression.
+ void set_value_bound(Tensor tensor) {
+ value_bound_.emplace(std::move(tensor));
+ }
+
+ // Return the bound of the expression, if available.
+ absl::optional<Tensor> value_bound() const { return value_bound_; }
XlaResource* resource() const { return resource_; }
// Returns a human-readable summary of the expression.
@@ -138,6 +145,9 @@
// The value of the constant, if available.
absl::optional<Tensor> constant_value_;
+ // The bound of the expression, if available.
+ absl::optional<Tensor> value_bound_;
+
// The resource, if kind_ == kResource. Not owned.
XlaResource* resource_ = nullptr;
};
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc
index 36db01b..d5cc955 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc
@@ -397,9 +397,8 @@
const std::unordered_set<string>* compile_time_constant_inputs;
- if (GetNodeAttr(node_def, kXlaCompileTimeConstantInputsAttr,
- &compile_time_constant_inputs_vect_from_attr)
- .ok()) {
+ if (TryGetNodeAttr(node_def, kXlaCompileTimeConstantInputsAttr,
+ &compile_time_constant_inputs_vect_from_attr)) {
absl::c_copy(compile_time_constant_inputs_vect_from_attr,
std::inserter(compile_time_constant_inputs_from_attr,
compile_time_constant_inputs_from_attr.end()));
diff --git a/tensorflow/compiler/xla/array.h b/tensorflow/compiler/xla/array.h
index a85d551..8ca0332 100644
--- a/tensorflow/compiler/xla/array.h
+++ b/tensorflow/compiler/xla/array.h
@@ -561,6 +561,7 @@
index *= sizes_[i];
index += indexes[i];
}
+ DCHECK_LT(index, this->num_elements());
return index;
}
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 35cd1c2..b8bfb7e 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -1882,7 +1882,8 @@
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
- output_operand_aliasing) {
+ output_operand_aliasing,
+ const Literal* literal) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
if (absl::StartsWith(call_target_name, "$")) {
return InvalidArgument(
@@ -1915,7 +1916,7 @@
}
return CustomCallInternal(call_target_name, operands, shape, opaque,
operand_shapes_with_layout, has_side_effect,
- output_operand_aliasing);
+ output_operand_aliasing, literal);
});
}
@@ -1925,7 +1926,8 @@
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
- output_operand_aliasing) {
+ output_operand_aliasing,
+ const Literal* literal) {
HloInstructionProto instr;
*instr.mutable_shape() = shape.ToProto();
instr.set_custom_call_target(call_target_name);
@@ -1936,6 +1938,9 @@
*instr.add_operand_shapes_with_layout() = operand_shape.ToProto();
}
}
+ if (literal != nullptr) {
+ *instr.mutable_literal() = literal->ToProto();
+ }
instr.set_custom_call_has_side_effect(has_side_effect);
for (const auto& pair : output_operand_aliasing) {
auto aliasing = instr.add_custom_call_output_operand_aliasing();
@@ -1956,7 +1961,8 @@
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
- output_operand_aliasing) {
+ output_operand_aliasing,
+ const Literal* literal) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
if (absl::StartsWith(call_target_name, "$")) {
@@ -1968,6 +1974,9 @@
*instr.mutable_shape() = shape.ToProto();
instr.set_custom_call_target(call_target_name);
instr.set_backend_config(opaque);
+ if (literal != nullptr) {
+ *instr.mutable_literal() = literal->ToProto();
+ }
if (operand_shapes_with_layout.has_value()) {
if (!LayoutUtil::HasLayout(shape)) {
return InvalidArgument(
@@ -3786,6 +3795,8 @@
HloOpcodeString(HloOpcode::kGetDimensionSize) ||
InstrIsSetBound(instr_proto)) {
int32 constant_value = -1;
+ HloInstructionProto const_instr;
+
if (instr_proto->opcode() ==
HloOpcodeString(HloOpcode::kGetDimensionSize)) {
// At this point, BuildConstantSubGraph should never encounter a
@@ -3804,18 +3815,14 @@
constant_value =
static_cast<int32>(operand_proto->shape().dimensions(dimension));
}
+ Literal literal = LiteralUtil::CreateR0(constant_value);
+ *const_instr.mutable_literal() = literal.ToProto();
+ *const_instr.mutable_shape() = literal.shape().ToProto();
} else {
- TF_RET_CHECK(
- absl::SimpleAtoi(instr_proto->backend_config(), &constant_value));
+ *const_instr.mutable_literal() = instr_proto->literal();
+ *const_instr.mutable_shape() = instr_proto->shape();
}
-
- Literal literal = LiteralUtil::CreateR0(constant_value);
-
- HloInstructionProto const_instr;
- *const_instr.mutable_shape() = literal.shape().ToProto();
- *const_instr.mutable_literal() = literal.ToProto();
*const_instr.mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);
-
const_instr.set_id(handle);
*const_instr.mutable_name() =
GetFullName(const_instr.opcode(), kNameSeparator, const_instr.id());
@@ -3866,7 +3873,6 @@
}
}
*module->add_computations() = std::move(entry);
-
return std::move(computation);
}
@@ -4459,10 +4465,11 @@
absl::Span<const XlaOp> operands, const Shape& shape, const string& opaque,
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
- output_operand_aliasing) {
+ output_operand_aliasing,
+ const Literal* literal) {
return builder->CustomCall(call_target_name, operands, shape, opaque,
/*operand_shapes_with_layout=*/absl::nullopt,
- has_side_effect, output_operand_aliasing);
+ has_side_effect, output_operand_aliasing, literal);
}
XlaOp CustomCallWithComputation(
@@ -4470,11 +4477,12 @@
absl::Span<const XlaOp> operands, const XlaComputation& computation,
const Shape& shape, const string& opaque, bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
- output_operand_aliasing) {
+ output_operand_aliasing,
+ const Literal* literal) {
return builder->CustomCall(call_target_name, operands, computation, shape,
opaque,
/*operand_shapes_with_layout=*/absl::nullopt,
- has_side_effect, output_operand_aliasing);
+ has_side_effect, output_operand_aliasing, literal);
}
XlaOp CustomCallWithLayout(
@@ -4483,10 +4491,11 @@
absl::Span<const Shape> operand_shapes_with_layout, const string& opaque,
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
- output_operand_aliasing) {
+ output_operand_aliasing,
+ const Literal* literal) {
return builder->CustomCall(call_target_name, operands, shape, opaque,
operand_shapes_with_layout, has_side_effect,
- output_operand_aliasing);
+ output_operand_aliasing, literal);
}
XlaOp Complex(const XlaOp lhs, const XlaOp rhs,
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index cb212e1..cc1806b 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -655,7 +655,8 @@
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
- output_operand_aliasing);
+ output_operand_aliasing,
+ const Literal* literal);
// Internal version of CustomCall without computation that doesn't do op
// specific error handling and expects arguments to be legal. CustomCall
@@ -666,7 +667,8 @@
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
- output_operand_aliasing);
+ output_operand_aliasing,
+ const Literal* literal);
XlaOp CustomCall(
const string& call_target_name, absl::Span<const XlaOp> operands,
@@ -675,7 +677,8 @@
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
- output_operand_aliasing);
+ output_operand_aliasing,
+ const Literal* literal);
XlaOp Reduce(XlaOp operand, XlaOp init_value,
const XlaComputation& computation,
@@ -1214,20 +1217,23 @@
absl::Span<const XlaOp> operands, const Shape& shape,
const string& opaque, bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
- output_operand_aliasing);
+ output_operand_aliasing,
+ const Literal* literal);
friend XlaOp CustomCallWithComputation(
XlaBuilder* builder, const string& call_target_name,
absl::Span<const XlaOp> operands, const XlaComputation& computation,
const Shape& shape, const string& opaque, bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
- output_operand_aliasing);
+ output_operand_aliasing,
+ const Literal* literal);
friend XlaOp CustomCallWithLayout(
XlaBuilder* builder, const string& call_target_name,
absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
absl::Span<const Shape> operand_shapes_with_layout, const string& opaque,
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
- output_operand_aliasing);
+ output_operand_aliasing,
+ const Literal* literal);
friend XlaOp Complex(XlaOp real, XlaOp imag,
absl::Span<const int64> broadcast_dimensions);
friend XlaOp Conj(XlaOp operand);
@@ -2025,7 +2031,8 @@
absl::Span<const XlaOp> operands, const Shape& shape,
const string& opaque = "", bool has_side_effect = false,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
- output_operand_aliasing = {});
+ output_operand_aliasing = {},
+ const Literal* literal = nullptr);
// Overload which constructs a custom call that applies an Xla computation.
XlaOp CustomCallWithComputation(
@@ -2033,7 +2040,8 @@
absl::Span<const XlaOp> operands, const XlaComputation& computation,
const Shape& shape, const string& opaque = "", bool has_side_effect = false,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
- output_operand_aliasing = {});
+ output_operand_aliasing = {},
+ const Literal* literal = nullptr);
// Overload which constructs a custom call with fixed layouts. The operands will
// have the layouts specified by |operand_shapes_with_layout| when provided to
@@ -2046,7 +2054,8 @@
absl::Span<const Shape> operand_shapes_with_layout,
const string& opaque = "", bool has_side_effect = false,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
- output_operand_aliasing = {});
+ output_operand_aliasing = {},
+ const Literal* literal = nullptr);
// The following methods enqueue element-wise binary arithmetic operations
// onto the computation. The shapes of the operands have to match unless one
diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc
index d15f78c..57a2ec1 100644
--- a/tensorflow/compiler/xla/literal.cc
+++ b/tensorflow/compiler/xla/literal.cc
@@ -27,6 +27,7 @@
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
+#include "absl/strings/str_split.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/index_util.h"
@@ -71,6 +72,22 @@
}
}
+string CompactOneline(const string& input) {
+ string result;
+ std::vector<string> v = absl::StrSplit(input, absl::ByAnyChar("\n "));
+ bool first = true;
+ // Concatenate elements in "v" with spaces separating them, but ignoring
+ // empty entries.
+ for (const auto& s : v) {
+ if (s.empty()) {
+ continue;
+ }
+ absl::StrAppend(&result, (first ? "" : " "), s);
+ first = false;
+ }
+ return result;
+}
+
// Since Eigen::half doesn't satisfy the absl::bit_cast contract, we need to be
// able to transparently access the raw 16-bit value contained within.
template <typename T>
@@ -1281,6 +1298,10 @@
return absl::StrJoin(pieces, "");
}
+string LiteralBase::ToStringOneline() const {
+ return CompactOneline(ToString());
+}
+
string LiteralBase::ToStringWithoutShape() const {
std::vector<string> pieces;
CHECK(LayoutUtil::HasLayout(this->shape()));
@@ -1289,6 +1310,10 @@
return absl::StrJoin(pieces, "");
}
+string LiteralBase::ToStringWithoutShapeOneline() const {
+ return CompactOneline(ToStringWithoutShape());
+}
+
string LiteralBase::ToStringWithLayout() const {
std::vector<string> pieces;
CHECK(LayoutUtil::HasLayout(this->shape()));
diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h
index 1ee7161..4147436 100644
--- a/tensorflow/compiler/xla/literal.h
+++ b/tensorflow/compiler/xla/literal.h
@@ -94,10 +94,18 @@
// element Literals.
string ToString() const;
+ // Similar to ToString, but return the result in a compact
+ // one-line form.
+ string ToStringOneline() const;
+
// Returns a string representation of the literal value which does *not*
// include the shape string.
string ToStringWithoutShape() const;
+ // Similar to ToStringWithoutShape, but return the result in a compact
+ // one-line form.
+ string ToStringWithoutShapeOneline() const;
+
// Returns a string representation of the literal value which includes the
// shape string with its layout.does *not* include the shape string.
string ToStringWithLayout() const;
diff --git a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc
index e7ce064..9423658 100644
--- a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc
+++ b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc
@@ -513,7 +513,8 @@
: parent_(other.parent_),
type_(other.type_),
state_(other.state_),
- buffer_or_(std::move(other.buffer_or_)) {
+ status_(std::move(other.status_)),
+ buffer_(std::move(other.buffer_)) {
// Preserve the invariant that status is invalid if buffer == nullptr.
other.SetState(kMoved);
}
@@ -521,16 +522,23 @@
void PjRtStreamExecutorBuffer::ScopedHold::Acquire(
StatusOr<std::shared_ptr<TrackedDeviceBuffer>>&& buffer_or) {
CHECK(!ok());
- buffer_or_ = std::move(buffer_or);
- SetState(buffer_or_.ok() ? kValid : kError);
+ if (buffer_or.ok()) {
+ buffer_ = buffer_or.ValueOrDie();
+ SetState(kValid);
+ } else {
+ status_ = buffer_or.status();
+ buffer_ = nullptr;
+ SetState(kError);
+ }
// Check the invariant holds.
- CHECK(!ok() || buffer_or_.ValueOrDie() != nullptr);
+ CHECK(!ok() || buffer_ != nullptr);
}
PjRtStreamExecutorBuffer::ScopedHold::ForClosure
PjRtStreamExecutorBuffer::ScopedHold::ToClosure() {
CHECK(ok());
- ForClosure for_closure(parent_, type_, state_, std::move(buffer_or_));
+ ForClosure for_closure(parent_, type_, state_, std::move(status_),
+ std::move(buffer_));
SetState(kReleased);
return for_closure;
}
@@ -1274,7 +1282,7 @@
absl::MutexLock lock(&mu_);
ScopedHold hold(this, type);
AcquireHoldLocked(&hold);
- if (type == ScopedHold::kDonation && !hold.status().ok()) {
+ if (type == ScopedHold::kDonation && !hold.ok()) {
donation_semaphore_.Release(1);
}
return hold;
@@ -2029,12 +2037,16 @@
const int partition = addressable_device_logical_ids_[i].partition;
auto& statusor = results[i];
if (!statusor.ok()) {
- return AppendStatus(
- statusor.status(),
- absl::StrFormat("while running replica %d and partition %d of a "
- "replicated computation (other "
- "replicas may have failed as well).",
- replica, partition));
+ if (num_addressable_devices == 1) {
+ return statusor.status();
+ } else {
+ return AppendStatus(
+ statusor.status(),
+ absl::StrFormat("while running replica %d and partition %d of a "
+ "replicated computation (other "
+ "replicas may have failed as well).",
+ replica, partition));
+ }
}
wrapped_results[i] = std::move(statusor.ValueOrDie());
}
diff --git a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h
index 34f819d..d56b19a 100644
--- a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h
+++ b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h
@@ -367,7 +367,7 @@
case kDonated:
return InvalidArgument("Buffer has been donated");
case kError:
- return buffer_or_.status();
+ return status_;
default:
CHECK(false) << "Unexpected state value " << state_;
}
@@ -377,8 +377,8 @@
// Access to the underlying device buffer storage. Requires this->ok().
const std::shared_ptr<TrackedDeviceBuffer>& buffer() const {
CHECK_EQ(state_, kValid);
- CHECK_NE(buffer_or_.ValueOrDie(), nullptr);
- return buffer_or_.ValueOrDie();
+ CHECK_NE(buffer_, nullptr);
+ return buffer_;
}
TrackedDeviceBuffer* operator->() const { return buffer().get(); }
const TrackedDeviceBuffer& operator*() const { return *buffer(); }
@@ -420,9 +420,8 @@
// Helper struct that makes it possible to move a ScopedHold through a
// closure.
- using ForClosure =
- std::tuple<PjRtStreamExecutorBuffer*, Type, State,
- StatusOr<std::shared_ptr<TrackedDeviceBuffer>>>;
+ using ForClosure = std::tuple<PjRtStreamExecutorBuffer*, Type, State,
+ Status, std::shared_ptr<TrackedDeviceBuffer>>;
ScopedHold(PjRtStreamExecutorBuffer* parent, Type type)
: parent_(parent), type_(type), state_(kUninitialized) {}
@@ -430,15 +429,16 @@
: parent_(std::get<0>(closure_helper)),
type_(std::get<1>(closure_helper)),
state_(std::get<2>(closure_helper)),
- buffer_or_(std::get<3>(closure_helper)) {
+ status_(std::get<3>(closure_helper)),
+ buffer_(std::get<4>(closure_helper)) {
// Check the buffer is not in an error state.
- CHECK(buffer_or_.ValueOrDie() != nullptr);
+ CHECK(status_.ok() && buffer_ != nullptr);
}
// Sets buffer state.
void SetState(State state) { state_ = state; }
- // Sets buffer_or_. Called by parent_ to initialize the hold.
+ // Sets buffer_ and status_. Called by parent_ to initialize the hold.
void Acquire(StatusOr<std::shared_ptr<TrackedDeviceBuffer>>&& buffer_or);
// Releases the contents of *this, so *this can subsequently be
// deleted without releasing the parent's hold. Should be passed to the
@@ -450,9 +450,10 @@
const Type type_;
// There is an invariant that if ok() then
- // buffer_or_.ValueOrDie() != nullptr.
+ // buffer_.ValueOrDie() != nullptr.
State state_;
- StatusOr<std::shared_ptr<TrackedDeviceBuffer>> buffer_or_;
+ Status status_;
+ std::shared_ptr<TrackedDeviceBuffer> buffer_;
};
PjRtStreamExecutorBuffer(Shape on_device_shape,
diff --git a/tensorflow/compiler/xla/python/pmap_lib.cc b/tensorflow/compiler/xla/python/pmap_lib.cc
index 2f6f909..f1be1ca 100644
--- a/tensorflow/compiler/xla/python/pmap_lib.cc
+++ b/tensorflow/compiler/xla/python/pmap_lib.cc
@@ -350,7 +350,6 @@
py::class_<Chunked> chunked(pmap_lib, "Chunked");
chunked.def(py::init<std::vector<int>>())
.def_readonly("chunks", &Chunked::chunks)
- .def_readonly("num_chunks", &Chunked::chunks)
.def("__repr__",
[](const Chunked& chuncked) {
return absl::StrCat("Chunked(",
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index ba65a29..44128b90 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -3961,6 +3961,7 @@
deps = [
":algebraic_simplifier",
":computation_layout",
+ ":dynamic_padder",
":hlo",
":hlo_parser",
":layout_assignment",
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 6c84573..b6e233e 100755
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -3195,6 +3195,30 @@
}
}
+ if (Cast<HloCompareInstruction>(compare)->type() ==
+ Comparison::Type::kUnsigned) {
+ // X u< 0 -> false
+ if (compare->comparison_direction() == ComparisonDirection::kLt &&
+ IsAll(rhs, 0)) {
+ return ReplaceInstruction(compare, MakeScalarLike(compare, false));
+ }
+ // X u>= 0 -> true
+ if (compare->comparison_direction() == ComparisonDirection::kGe &&
+ IsAll(rhs, 0)) {
+ return ReplaceInstruction(compare, MakeScalarLike(compare, true));
+ }
+ // 0 u> X -> false
+ if (compare->comparison_direction() == ComparisonDirection::kGt &&
+ IsAll(lhs, 0)) {
+ return ReplaceInstruction(compare, MakeScalarLike(compare, false));
+ }
+ // 0 u<= X -> true
+ if (compare->comparison_direction() == ComparisonDirection::kLe &&
+ IsAll(lhs, 0)) {
+ return ReplaceInstruction(compare, MakeScalarLike(compare, true));
+ }
+ }
+
if (compare->comparison_direction() == ComparisonDirection::kLt &&
lhs->opcode() == HloOpcode::kIota && IsAll(rhs, 0)) {
return ReplaceInstruction(compare, MakeScalarLike(compare, false));
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 88f45e8..0ef62b4 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -6259,6 +6259,121 @@
GmockMatch(m::Broadcast(m::ConstantScalar(false))));
}
+TEST_F(AlgebraicSimplifierTest, CompareLtZero) {
+ const char* kModuleStr = R"(
+ HloModule m
+ test {
+ zero = u32[] constant(0)
+ param = u32[] parameter(0)
+ ROOT compare = pred[] compare(param, zero), direction=LT
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+ ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+ EXPECT_THAT(m->entry_computation()->root_instruction(),
+ GmockMatch(m::ConstantScalar(false)));
+}
+
+TEST_F(AlgebraicSimplifierTest, CompareLeZero) {
+ const char* kModuleStr = R"(
+ HloModule m
+ test {
+ zero = u32[] constant(0)
+ param = u32[] parameter(0)
+ ROOT compare = pred[] compare(param, zero), direction=LE
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+ ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+ EXPECT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(m::Le(m::Parameter(0), m::ConstantEffectiveScalar(0))));
+}
+
+TEST_F(AlgebraicSimplifierTest, CompareGeZero) {
+ const char* kModuleStr = R"(
+ HloModule m
+ test {
+ zero = u32[] constant(0)
+ param = u32[] parameter(0)
+ ROOT compare = pred[] compare(param, zero), direction=GE
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+ ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+ EXPECT_THAT(m->entry_computation()->root_instruction(),
+ GmockMatch(m::ConstantScalar(true)));
+}
+
+TEST_F(AlgebraicSimplifierTest, CompareGtZero) {
+ const char* kModuleStr = R"(
+ HloModule m
+ test {
+ zero = u32[] constant(0)
+ param = u32[] parameter(0)
+ ROOT compare = pred[] compare(param, zero), direction=GT
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+ EXPECT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(m::Gt(m::Parameter(0), m::ConstantEffectiveScalar(0))));
+}
+
+TEST_F(AlgebraicSimplifierTest, CompareZeroGt) {
+ const char* kModuleStr = R"(
+ HloModule m
+ test {
+ zero = u32[] constant(0)
+ param = u32[] parameter(0)
+ ROOT compare = pred[] compare(zero, param), direction=GT
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+ ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+ EXPECT_THAT(m->entry_computation()->root_instruction(),
+ GmockMatch(m::ConstantScalar(false)));
+}
+
+TEST_F(AlgebraicSimplifierTest, CompareZeroGe) {
+ const char* kModuleStr = R"(
+ HloModule m
+ test {
+ zero = u32[] constant(0)
+ param = u32[] parameter(0)
+ ROOT compare = pred[] compare(zero, param), direction=GE
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+ ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+ EXPECT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(m::Ge(m::ConstantEffectiveScalar(0), m::Parameter(0))));
+}
+
+TEST_F(AlgebraicSimplifierTest, CompareZeroLe) {
+ const char* kModuleStr = R"(
+ HloModule m
+ test {
+ zero = u32[] constant(0)
+ param = u32[] parameter(0)
+ ROOT compare = pred[] compare(zero, param), direction=LE
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+ ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+ EXPECT_THAT(m->entry_computation()->root_instruction(),
+ GmockMatch(m::ConstantScalar(true)));
+}
+
+TEST_F(AlgebraicSimplifierTest, CompareZeroLt) {
+ const char* kModuleStr = R"(
+ HloModule m
+ test {
+ zero = u32[] constant(0)
+ param = u32[] parameter(0)
+ ROOT compare = pred[] compare(zero, param), direction=LT
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+ ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+ EXPECT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(m::Lt(m::ConstantEffectiveScalar(0), m::Parameter(0))));
+}
+
TEST_F(AlgebraicSimplifierTest, CompareSame) {
const char* kModuleStr = R"(
HloModule m
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index dda4f55..1ac7983 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -1160,6 +1160,7 @@
"@llvm-project//llvm:Linker",
"@llvm-project//mlir:CFGTransforms",
"@llvm-project//mlir:IR",
+ "@llvm-project//mlir:LLVMIRModuleTranslation",
"@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:TargetLLVMIR",
diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc
index 643de6c..48e372c 100644
--- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc
+++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc
@@ -160,32 +160,42 @@
static std::vector<llvm::VecDesc> VectorFunctionsForTargetLibraryInfoImpl() {
std::vector<llvm::VecDesc> result = {
- {"tanhf", runtime::kTanhV4F32SymbolName, 4},
- {"llvm.tanh.f32", runtime::kTanhV4F32SymbolName, 4},
+ {"tanhf", runtime::kTanhV4F32SymbolName, llvm::ElementCount::getFixed(4)},
+ {"llvm.tanh.f32", runtime::kTanhV4F32SymbolName,
+ llvm::ElementCount::getFixed(4)},
- {"tanhf", runtime::kTanhV8F32SymbolName, 8},
- {"llvm.tanh.f32", runtime::kTanhV8F32SymbolName, 8},
+ {"tanhf", runtime::kTanhV8F32SymbolName, llvm::ElementCount::getFixed(8)},
+ {"llvm.tanh.f32", runtime::kTanhV8F32SymbolName,
+ llvm::ElementCount::getFixed(8)},
- {"tanhf", runtime::kTanhV16F32SymbolName, 16},
- {"llvm.tanh.f32", runtime::kTanhV16F32SymbolName, 16},
+ {"tanhf", runtime::kTanhV16F32SymbolName,
+ llvm::ElementCount::getFixed(16)},
+ {"llvm.tanh.f32", runtime::kTanhV16F32SymbolName,
+ llvm::ElementCount::getFixed(16)},
- {"expf", runtime::kExpV4F32SymbolName, 4},
- {"llvm.exp.f32", runtime::kExpV4F32SymbolName, 4},
+ {"expf", runtime::kExpV4F32SymbolName, llvm::ElementCount::getFixed(4)},
+ {"llvm.exp.f32", runtime::kExpV4F32SymbolName,
+ llvm::ElementCount::getFixed(4)},
- {"expf", runtime::kExpV8F32SymbolName, 8},
- {"llvm.exp.f32", runtime::kExpV8F32SymbolName, 8},
+ {"expf", runtime::kExpV8F32SymbolName, llvm::ElementCount::getFixed(8)},
+ {"llvm.exp.f32", runtime::kExpV8F32SymbolName,
+ llvm::ElementCount::getFixed(8)},
- {"expf", runtime::kExpV16F32SymbolName, 16},
- {"llvm.exp.f32", runtime::kExpV16F32SymbolName, 16},
+ {"expf", runtime::kExpV16F32SymbolName, llvm::ElementCount::getFixed(16)},
+ {"llvm.exp.f32", runtime::kExpV16F32SymbolName,
+ llvm::ElementCount::getFixed(16)},
- {"logf", runtime::kLogV4F32SymbolName, 4},
- {"llvm.log.f32", runtime::kLogV4F32SymbolName, 4},
+ {"logf", runtime::kLogV4F32SymbolName, llvm::ElementCount::getFixed(4)},
+ {"llvm.log.f32", runtime::kLogV4F32SymbolName,
+ llvm::ElementCount::getFixed(4)},
- {"logf", runtime::kLogV8F32SymbolName, 8},
- {"llvm.log.f32", runtime::kLogV8F32SymbolName, 8},
+ {"logf", runtime::kLogV8F32SymbolName, llvm::ElementCount::getFixed(8)},
+ {"llvm.log.f32", runtime::kLogV8F32SymbolName,
+ llvm::ElementCount::getFixed(8)},
- {"logf", runtime::kLogV16F32SymbolName, 16},
- {"llvm.log.f32", runtime::kLogV16F32SymbolName, 16},
+ {"logf", runtime::kLogV16F32SymbolName, llvm::ElementCount::getFixed(16)},
+ {"llvm.log.f32", runtime::kLogV16F32SymbolName,
+ llvm::ElementCount::getFixed(16)},
};
return result;
}
diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
index a4566b1..6b8b156 100644
--- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
@@ -33,9 +33,9 @@
namespace xla {
namespace cpu {
-StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
- llvm::Value* lhs,
- llvm::Value* rhs) {
+StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(
+ PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs,
+ absl::string_view /*name*/) {
string function_name;
bool cast_result_to_fp16 = false;
switch (prim_type) {
diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
index a002df2..3a06466 100644
--- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
@@ -37,7 +37,8 @@
protected:
StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs,
- llvm::Value* rhs) override;
+ llvm::Value* rhs,
+ absl::string_view name) override;
StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
llvm::Value* value) override;
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 608a743..7827f1a 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -1532,7 +1532,7 @@
const llvm_ir::IrArray::Index& output_index,
const ShardedVectorType& accumulator_type, HloInstruction* init_value,
HloInstruction* arg, absl::Span<const int64> dimensions,
- unsigned element_alignment) {
+ llvm::Align element_alignment) {
ShardedVector accumulator;
accumulator.reserve(accumulator_type.size());
for (auto accumulator_shard_type : accumulator_type) {
@@ -1608,7 +1608,7 @@
void IrEmitter::EmitShardedVectorStore(
llvm::Value* store_address, const std::vector<llvm::Value*>& value_to_store,
- const int alignment, const llvm_ir::IrArray& containing_array) {
+ llvm::Align alignment, const llvm_ir::IrArray& containing_array) {
for (int i = 0; i < value_to_store.size(); i++) {
auto store_address_typed =
BitCast(store_address,
@@ -1666,9 +1666,9 @@
bool is_reduction_over_minor_dimension = absl::c_linear_search(
dimensions, LayoutUtil::Minor(arg->shape().layout(), 0));
- unsigned element_alignment = tensorflow::MathUtil::GCD<unsigned>(
+ llvm::Align element_alignment(tensorflow::MathUtil::GCD<unsigned>(
ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()),
- MinimumAlignmentForPrimitiveType(reduce->shape().element_type()));
+ MinimumAlignmentForPrimitiveType(reduce->shape().element_type())));
if (is_reduction_over_minor_dimension) {
// TODO(sanjoy): Implement vectorized reduction over the minor dimension.
@@ -2583,8 +2583,8 @@
const llvm_ir::IrArray& source_array) {
unsigned primitive_type_size =
ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
- unsigned element_alignment = tensorflow::MathUtil::GCD<unsigned>(
- primitive_type_size, MinimumAlignmentForPrimitiveType(primitive_type));
+ llvm::Align element_alignment(tensorflow::MathUtil::GCD<unsigned>(
+ primitive_type_size, MinimumAlignmentForPrimitiveType(primitive_type)));
llvm::Type* primitive_ptr_type = llvm::PointerType::getUnqual(
llvm_ir::PrimitiveTypeToIrType(primitive_type, module_));
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 49490ef..f7762df 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -379,7 +379,7 @@
// "store_address".
void EmitShardedVectorStore(llvm::Value* store_address,
const ShardedVector& value_to_store,
- const int alignment,
+ llvm::Align alignment,
const llvm_ir::IrArray& containing_array);
using ReductionGenerator = std ::function<llvm::Value*(
@@ -399,7 +399,7 @@
const llvm_ir::IrArray::Index& output_index,
const ShardedVectorType& accumulator_type, HloInstruction* init_value,
HloInstruction* arg, absl::Span<const int64> dimensions,
- unsigned element_alignment);
+ llvm::Align element_alignment);
// Tries to emit a fast concatenate operation using memcpy. Returns true if
// successful, and false on failure. On failure, sets "failure_reason" to a
diff --git a/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc b/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc
index ee0ffd6..43a48a1 100644
--- a/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc
@@ -24,6 +24,7 @@
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Target/LLVMIR.h" // from @llvm-project
+#include "mlir/Target/LLVMIR/Export.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc
index 48aa32f..cc903da 100644
--- a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc
+++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc
@@ -219,7 +219,8 @@
pointer = b()->CreateBitCast(pointer, vector_pointer_type(), name());
}
return b()->CreateAlignedLoad(
- pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_), name());
+ pointer, llvm::Align(ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_)),
+ name());
}
llvm::Value* VectorSupportLibrary::LoadScalar(llvm::Value* pointer) {
@@ -227,7 +228,8 @@
pointer = b()->CreateBitCast(pointer, scalar_pointer_type(), name());
}
return b()->CreateAlignedLoad(
- pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_), name());
+ pointer, llvm::Align(ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_)),
+ name());
}
void VectorSupportLibrary::StoreVector(llvm::Value* value,
@@ -236,8 +238,9 @@
if (pointer->getType() != vector_pointer_type()) {
pointer = b()->CreateBitCast(pointer, vector_pointer_type());
}
- b()->CreateAlignedStore(value, pointer,
- ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_));
+ b()->CreateAlignedStore(
+ value, pointer,
+ llvm::Align(ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_)));
}
void VectorSupportLibrary::StoreScalar(llvm::Value* value,
@@ -246,8 +249,9 @@
if (pointer->getType() != scalar_pointer_type()) {
pointer = b()->CreateBitCast(pointer, scalar_pointer_type(), name());
}
- b()->CreateAlignedStore(value, pointer,
- ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_));
+ b()->CreateAlignedStore(
+ value, pointer,
+ llvm::Align(ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_)));
}
llvm::Value* VectorSupportLibrary::LoadBroadcast(llvm::Value* pointer) {
diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
index 2328ad9..c41691e 100644
--- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
+++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
@@ -1141,15 +1141,27 @@
return ForEachOperandDynamicDimension(
hlo,
[&](HloInstruction* /*operand*/, ShapeIndex /*index*/, int64 dimension,
- int64 /*operand_index*/, HloInstruction* dynamic_size) {
+ int64 operand_index, HloInstruction* dynamic_size) {
if (hlo->shape().dimensions(dimension) !=
hlo->operand(0)->shape().dimensions(dimension)) {
return Unimplemented(
- "Dynamic dimension propagation on DynamicSlice where a partial "
- "dimension is selected %s",
+ "Dynamic dimension propagation on DynamicUpdateSlice where a "
+ "partial dimension is selected %s",
hlo->ToString());
}
+ if (operand_index == 1 &&
+ hlo->operand(1)->shape().dimensions(dimension) <
+ hlo->operand(0)->shape().dimensions(dimension)) {
+ // DUS(input=[A], update=[<=B])
+ //
+ // If update dim is smaller than input dim (B < A) , then we are doing
+ // a partial update, no need to set the output dynamic dimension.
+ //
+ // The dynamic shape in `update` doesn't change output dynamic shape.
+ return Status::OK();
+ }
+
parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
return Status::OK();
@@ -1659,8 +1671,10 @@
void DynamicDimensionInference::ReplaceAllDynamicDimensionUsesWith(
HloInstruction* replace, HloInstruction* with) {
- CHECK(Shape::Equal()(replace->shape(), ShapeUtil::MakeScalarShape(S32)));
- CHECK(Shape::Equal()(with->shape(), ShapeUtil::MakeScalarShape(S32)));
+ CHECK(Shape::Equal().IgnoreLayout()(replace->shape(),
+ ShapeUtil::MakeScalarShape(S32)));
+ CHECK(Shape::Equal().IgnoreLayout()(with->shape(),
+ ShapeUtil::MakeScalarShape(S32)));
for (auto& kv : dynamic_mapping_) {
if (kv.second == replace) {
kv.second = with;
diff --git a/tensorflow/compiler/xla/service/dynamic_padder.cc b/tensorflow/compiler/xla/service/dynamic_padder.cc
index ab94695..dfe5a33 100644
--- a/tensorflow/compiler/xla/service/dynamic_padder.cc
+++ b/tensorflow/compiler/xla/service/dynamic_padder.cc
@@ -1282,6 +1282,109 @@
return true;
}
+StatusOr<bool> RewriteDynamicUpdateSlice(
+ HloInstruction* hlo,
+ DynamicDimensionInference* dynamic_dimension_inference) {
+ HloDynamicUpdateSliceInstruction* dus =
+ Cast<HloDynamicUpdateSliceInstruction>(hlo);
+ HloComputation* comp = hlo->parent();
+ // Suppose we have a base area that we want to update:
+ // +------------------------+
+ // | |
+ // | base |
+ // | |
+ // +------------------------+
+ //
+ // A partial update with dynamic padding looks like this:
+ //
+ // +------+-------+
+ // |update|padding|
+ // +------+-------+
+ //
+ // We don't want the padding to overwrite the base area:
+ //
+ // +------------------------+
+ // | +------+-------+
+ // |<-begin->|update|padding| (what we want to avoid)
+ // | +------+-------+
+ // +------------------------+
+ //
+ // Instead we want to keep the base area untouched except for the update
+ // region:
+ //
+ // +------------------------+
+ // | +------+ |
+ // |<-begin->|update| base | (what we want)
+ // | +------+ |
+ // +------------------------+
+ //
+ // We do this by dynamic slicing the base area out first with the same begin
+ // index:
+ //
+ // +--------------+
+ // <-begin-> | base |
+ // +--------------+
+ //
+ // Then replace the update's padding part with base:
+ //
+ // +------+-------+
+ // |update| base |
+ // +------+-------+
+ //
+ // Then do the DUS.
+
+ HloInstruction* update = dus->mutable_operand(1);
+ HloInstruction* base = dus->mutable_operand(0);
+ std::vector<HloInstruction*> dynamic_dims_in_partial_update(
+ update->shape().rank(), nullptr);
+ bool needs_rewrite = false;
+ for (int64 i = 0; i < update->shape().rank(); ++i) {
+ if (update->shape().dimensions(i) < base->shape().dimensions(i)) {
+ HloInstruction* dynamic_dim =
+ dynamic_dimension_inference->GetDynamicSize(update, {}, i);
+
+ if (dynamic_dim != nullptr) {
+ dynamic_dims_in_partial_update[i] = dynamic_dim;
+ needs_rewrite = true;
+ }
+ }
+ }
+
+ if (!needs_rewrite) {
+ return false;
+ }
+ std::vector<HloInstruction*> indices;
+ indices.reserve(dus->operand_count() - 2);
+ for (int64 i = 2; i < dus->operand_count(); ++i) {
+ indices.push_back(dus->mutable_operand(i));
+ }
+ HloInstruction* base_slice =
+ comp->AddInstruction(HloInstruction::CreateDynamicSlice(
+ update->shape(), base, indices, update->shape().dimensions()));
+
+ for (int64 i = 0; i < dynamic_dims_in_partial_update.size(); ++i) {
+ HloInstruction* dynamic_dim = dynamic_dims_in_partial_update[i];
+ if (dynamic_dim != nullptr) {
+ Shape mask_shape_int = ShapeUtil::ChangeElementType(update->shape(), S32);
+ Shape mask_shape_pred =
+ ShapeUtil::ChangeElementType(update->shape(), PRED);
+ // Generate mask using iota and dynamic_dim.
+ HloInstruction* iota =
+ comp->AddInstruction(HloInstruction::CreateIota(mask_shape_int, i));
+ HloInstruction* broadcast_dim = comp->AddInstruction(
+ HloInstruction::CreateBroadcast(mask_shape_int, dynamic_dim, {}));
+ HloInstruction* pred = comp->AddInstruction(HloInstruction::CreateCompare(
+ mask_shape_pred, iota, broadcast_dim, ComparisonDirection::kLt));
+ // Update `update` to include base.
+ update = comp->AddInstruction(HloInstruction::CreateTernary(
+ update->shape(), HloOpcode::kSelect, pred, update, base_slice));
+ }
+ }
+ TF_RETURN_IF_ERROR(dus->ReplaceOperandWith(1, update));
+
+ return true;
+}
+
StatusOr<bool> RewriteDynamicReshape(
HloInstruction* reshape,
DynamicDimensionInference* dynamic_dimension_inference) {
@@ -1732,6 +1835,12 @@
continue;
}
+ if (inst->opcode() == HloOpcode::kDynamicUpdateSlice) {
+ TF_ASSIGN_OR_RETURN(changed, RewriteDynamicUpdateSlice(
+ inst, &dynamic_dimension_inference));
+ continue;
+ }
+
if (inst->opcode() == HloOpcode::kDynamicReshape) {
TF_ASSIGN_OR_RETURN(
changed, RewriteDynamicReshape(inst, &dynamic_dimension_inference));
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 31ca1ab..1c417bc 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -446,7 +446,7 @@
primitive_util::BitWidth(to_type));
}
case HloOpcode::kExp:
- return EmitExp(op->shape().element_type(), operand_value);
+ return EmitExp(op->shape().element_type(), operand_value, "");
case HloOpcode::kExpm1:
return EmitExpm1(op->shape().element_type(), operand_value);
case HloOpcode::kLog:
@@ -528,7 +528,8 @@
// log(a+bi) = log(abs(a+bi)) + i*atan2(b,a)
auto a = EmitExtractReal(operand_value);
auto b = EmitExtractImag(operand_value);
- TF_ASSIGN_OR_RETURN(llvm::Value * angle, EmitAtan2(component_type, b, a));
+ TF_ASSIGN_OR_RETURN(llvm::Value * angle,
+ EmitAtan2(component_type, b, a, ""));
TF_ASSIGN_OR_RETURN(llvm::Value * abs,
EmitComplexAbs(component_type, operand_value));
TF_ASSIGN_OR_RETURN(llvm::Value * log_abs, EmitLog(component_type, abs));
@@ -543,7 +544,8 @@
auto a_plus_one = FAdd(a, one);
auto sum_sq = FAdd(FMul(a_plus_one, a_plus_one), FMul(b, b));
TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq));
- TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a_plus_one));
+ TF_ASSIGN_OR_RETURN(auto angle,
+ EmitAtan2(component_type, b, a_plus_one, ""));
auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5);
return EmitComposeComplex(op, FMul(one_half, log_sum_sq), angle);
}
@@ -566,7 +568,8 @@
case HloOpcode::kExp: {
// e^(a+bi) = e^a*(cos(b)+sin(b)i)
TF_ASSIGN_OR_RETURN(
- auto exp_a, EmitExp(component_type, EmitExtractReal(operand_value)));
+ auto exp_a,
+ EmitExp(component_type, EmitExtractReal(operand_value), ""));
TF_ASSIGN_OR_RETURN(
auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value)));
TF_ASSIGN_OR_RETURN(
@@ -576,7 +579,8 @@
case HloOpcode::kExpm1: {
// e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i
TF_ASSIGN_OR_RETURN(
- auto exp_a, EmitExp(component_type, EmitExtractReal(operand_value)));
+ auto exp_a,
+ EmitExp(component_type, EmitExtractReal(operand_value), ""));
TF_ASSIGN_OR_RETURN(
auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value)));
TF_ASSIGN_OR_RETURN(
@@ -597,7 +601,7 @@
auto a = EmitExtractReal(operand_value);
auto b = EmitExtractImag(operand_value);
auto type = a->getType();
- TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b));
+ TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b, ""));
auto half_exp_b = FMul(llvm::ConstantFP::get(type, 0.5), exp_b);
auto half_exp_neg_b = FDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a));
@@ -619,7 +623,7 @@
auto a = EmitExtractReal(operand_value);
auto b = EmitExtractImag(operand_value);
auto type = a->getType();
- TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b));
+ TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b, ""));
auto half_exp_b = FMul(llvm::ConstantFP::get(type, 0.5), exp_b);
auto half_exp_neg_b = FDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a));
@@ -828,15 +832,15 @@
case HloOpcode::kComplex:
return EmitComposeComplex(op, lhs_value, rhs_value);
case HloOpcode::kAdd:
- return FAdd(lhs_value, rhs_value);
+ return FAdd(lhs_value, rhs_value, op->name());
case HloOpcode::kSubtract:
- return FSub(lhs_value, rhs_value);
+ return FSub(lhs_value, rhs_value, op->name());
case HloOpcode::kMultiply:
- return FMul(lhs_value, rhs_value);
+ return FMul(lhs_value, rhs_value, op->name());
case HloOpcode::kDivide:
- return FDiv(lhs_value, rhs_value);
+ return FDiv(lhs_value, rhs_value, op->name());
case HloOpcode::kRemainder:
- return FRem(lhs_value, rhs_value);
+ return FRem(lhs_value, rhs_value, op->name());
// LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
// comparisons always return false when one of the operands is NaN, whereas
// unordered comparisons return true.
@@ -848,32 +852,34 @@
switch (op->comparison_direction()) {
case ComparisonDirection::kEq:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value,
- rhs_value, b_);
+ rhs_value, b_, op->name());
case ComparisonDirection::kNe:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, lhs_value,
- rhs_value, b_);
+ rhs_value, b_, op->name());
case ComparisonDirection::kLt:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value,
- rhs_value, b_);
+ rhs_value, b_, op->name());
case ComparisonDirection::kGt:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGT, lhs_value,
- rhs_value, b_);
+ rhs_value, b_, op->name());
case ComparisonDirection::kLe:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLE, lhs_value,
- rhs_value, b_);
+ rhs_value, b_, op->name());
case ComparisonDirection::kGe:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGE, lhs_value,
- rhs_value, b_);
+ rhs_value, b_, op->name());
}
}
case HloOpcode::kMaximum:
- return EmitFloatMax(lhs_value, rhs_value);
+ return EmitFloatMax(lhs_value, rhs_value, op->name());
case HloOpcode::kMinimum:
- return EmitFloatMin(lhs_value, rhs_value);
+ return EmitFloatMin(lhs_value, rhs_value, op->name());
case HloOpcode::kPower:
- return EmitPow(op->shape().element_type(), lhs_value, rhs_value);
+ return EmitPow(op->shape().element_type(), lhs_value, rhs_value,
+ op->name());
case HloOpcode::kAtan2:
- return EmitAtan2(op->shape().element_type(), lhs_value, rhs_value);
+ return EmitAtan2(op->shape().element_type(), lhs_value, rhs_value,
+ op->name());
default:
return Unimplemented("binary floating point op '%s'",
HloOpcodeString(op->opcode()));
@@ -901,8 +907,8 @@
llvm::Intrinsic::fabs, {real}, {real->getType()}, b_);
llvm::Value* abs_imag = llvm_ir::EmitCallToIntrinsic(
llvm::Intrinsic::fabs, {imag}, {imag->getType()}, b_);
- llvm::Value* max = EmitFloatMax(abs_real, abs_imag);
- llvm::Value* min = EmitFloatMin(abs_real, abs_imag);
+ llvm::Value* max = EmitFloatMax(abs_real, abs_imag, "");
+ llvm::Value* min = EmitFloatMin(abs_real, abs_imag, "");
llvm::Value* div = FDiv(min, max);
llvm::Value* div_sq = FMul(div, div);
@@ -939,7 +945,7 @@
TF_ASSIGN_OR_RETURN(llvm::Value * sqrt_max, EmitSqrt(prim_type, max));
TF_ASSIGN_OR_RETURN(llvm::Value * pow,
EmitPow(prim_type, one_p_div_sq,
- llvm::ConstantFP::get(max->getType(), .25)));
+ llvm::ConstantFP::get(max->getType(), .25), ""));
llvm::Value* result = FMul(sqrt_max, pow);
// When (min, max) are (0, 0), (inf, inf), or (NaN, ...), `result` is NaN.
// In such cases, we return `min` instead of `result`.
@@ -983,7 +989,7 @@
llvm::Value* a = EmitExtractReal(operand_value);
llvm::Value* b = EmitExtractImag(operand_value);
- TF_ASSIGN_OR_RETURN(llvm::Value * t, EmitAtan2(prim_type, b, a));
+ TF_ASSIGN_OR_RETURN(llvm::Value * t, EmitAtan2(prim_type, b, a, ""));
llvm::Value* c = llvm::ConstantFP::get(type, 0.5);
llvm::Value* angle = FMul(t, c);
@@ -1039,7 +1045,7 @@
llvm::Value* a = EmitExtractReal(operand_value);
llvm::Value* b = EmitExtractImag(operand_value);
- TF_ASSIGN_OR_RETURN(llvm::Value * t, EmitAtan2(prim_type, b, a));
+ TF_ASSIGN_OR_RETURN(llvm::Value * t, EmitAtan2(prim_type, b, a, ""));
llvm::Value* c = llvm::ConstantFP::get(type, -0.5);
llvm::Value* angle = FMul(t, c);
@@ -1116,13 +1122,13 @@
auto half_c = FMul(one_half, c);
TF_ASSIGN_OR_RETURN(auto aa_p_bb_to_half_c,
- EmitPow(component_type, aa_p_bb, half_c));
+ EmitPow(component_type, aa_p_bb, half_c, ""));
auto neg_d = FNeg(d);
- TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a));
+ TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a, ""));
auto neg_d_arg_lhs = FMul(neg_d, arg_lhs);
TF_ASSIGN_OR_RETURN(auto e_to_neg_d_arg_lhs,
- EmitExp(component_type, neg_d_arg_lhs));
+ EmitExp(component_type, neg_d_arg_lhs, ""));
auto coeff = FMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs);
TF_ASSIGN_OR_RETURN(auto ln_aa_p_bb, EmitLog(component_type, aa_p_bb));
auto half_d = FMul(one_half, d);
@@ -1314,13 +1320,15 @@
}
llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value,
- llvm::Value* rhs_value) {
- return llvm_ir::EmitFloatMax(lhs_value, rhs_value, b_, fast_min_max());
+ llvm::Value* rhs_value,
+ absl::string_view name) {
+ return llvm_ir::EmitFloatMax(lhs_value, rhs_value, b_, fast_min_max(), name);
}
llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value,
- llvm::Value* rhs_value) {
- return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_, fast_min_max());
+ llvm::Value* rhs_value,
+ absl::string_view name) {
+ return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_, fast_min_max(), name);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog(PrimitiveType prim_type,
@@ -1404,9 +1412,10 @@
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitExp(PrimitiveType prim_type,
- llvm::Value* value) {
+ llvm::Value* value,
+ absl::string_view name) {
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {value},
- {value->getType()}, b_);
+ {value->getType()}, b_, name);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
@@ -1417,7 +1426,7 @@
auto half = llvm::ConstantFP::get(type, 0.5);
// When the exponent is large, the naive evaluation of e^(x) - 1 is more
// accurate than the Taylor series.
- TF_ASSIGN_OR_RETURN(auto exp_x, EmitExp(prim_type, value));
+ TF_ASSIGN_OR_RETURN(auto exp_x, EmitExp(prim_type, value, ""));
auto for_large_x = FSub(exp_x, one);
// The Taylor series for exp(x) is 1 + x + x^2/2 + x^3/6 + ….
// We want exp(x)-1 which is x + x^2/2 + x^3/6 + ….
@@ -1438,9 +1447,10 @@
StatusOr<llvm::Value*> ElementalIrEmitter::EmitPow(PrimitiveType prim_type,
llvm::Value* lhs,
- llvm::Value* rhs) {
+ llvm::Value* rhs,
+ absl::string_view name) {
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow, {lhs, rhs},
- {lhs->getType()}, b_);
+ {lhs->getType()}, b_, name);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitCbrt(PrimitiveType prim_type,
@@ -1450,15 +1460,15 @@
auto abs_value =
llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_);
TF_ASSIGN_OR_RETURN(llvm::Value * abs_res,
- EmitPow(prim_type, abs_value, third));
+ EmitPow(prim_type, abs_value, third, ""));
auto signed_res = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::copysign,
{abs_res, value}, {type}, b_);
return signed_res;
}
-StatusOr<llvm::Value*> ElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
- llvm::Value* lhs,
- llvm::Value* rhs) {
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitAtan2(
+ PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* /*rhs*/,
+ absl::string_view /*name*/) {
return Unimplemented("atan2");
}
@@ -1728,7 +1738,7 @@
operand_to_generator.at(hlo->operand(2))(index));
PrimitiveType prim_type = hlo->shape().element_type();
if (primitive_util::IsFloatingPointType(prim_type)) {
- return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value));
+ return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value, ""), "");
} else if (primitive_util::IsIntegralType(prim_type)) {
bool is_signed = primitive_util::IsSignedIntegralType(prim_type);
return EmitIntegralMin(
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h
index 7eff80d..675b7e4 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h
@@ -105,10 +105,12 @@
llvm::Value* rhs_value);
virtual llvm::Value* EmitFloatMax(llvm::Value* lhs_value,
- llvm::Value* rhs_value);
+ llvm::Value* rhs_value,
+ absl::string_view name);
virtual llvm::Value* EmitFloatMin(llvm::Value* lhs_value,
- llvm::Value* rhs_value);
+ llvm::Value* rhs_value,
+ absl::string_view name);
llvm::Value* EmitIntegralMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
bool is_signed);
@@ -117,7 +119,8 @@
bool is_signed);
virtual StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type,
- llvm::Value* lhs, llvm::Value* rhs);
+ llvm::Value* lhs, llvm::Value* rhs,
+ absl::string_view name);
virtual StatusOr<llvm::Value*> EmitLog(PrimitiveType prim_type,
llvm::Value* value);
@@ -141,13 +144,15 @@
llvm::Value* value);
virtual StatusOr<llvm::Value*> EmitExp(PrimitiveType prim_type,
- llvm::Value* value);
+ llvm::Value* value,
+ absl::string_view name);
virtual StatusOr<llvm::Value*> EmitExpm1(PrimitiveType prim_type,
llvm::Value* value);
virtual StatusOr<llvm::Value*> EmitPow(PrimitiveType prim_type,
- llvm::Value* lhs, llvm::Value* rhs);
+ llvm::Value* lhs, llvm::Value* rhs,
+ absl::string_view name);
virtual StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
llvm::Value* value);
diff --git a/tensorflow/compiler/xla/service/fusion_queue.h b/tensorflow/compiler/xla/service/fusion_queue.h
index b7350b4..cb8d2b2 100644
--- a/tensorflow/compiler/xla/service/fusion_queue.h
+++ b/tensorflow/compiler/xla/service/fusion_queue.h
@@ -45,6 +45,11 @@
HloInstruction* original_producer,
HloInstruction* original_consumer) {}
+ // A callback passed to the queue implementation when a proposed fusion does
+ // not happen.
+ virtual void NotFusingInstruction(HloInstruction* producer,
+ HloInstruction* consumer) {}
+
// A callback passed to the queue implementation to notify the removal of an
// instruction.
virtual void RemoveInstruction(HloInstruction* instruction) = 0;
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 5a8a117..18f5204 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -218,7 +218,6 @@
":gpu_constants",
":gpu_executable",
":ir_emission_utils",
- ":nccl_all_reduce_thunk",
":thunk",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
@@ -257,9 +256,7 @@
":hlo_to_ir_bindings",
":ir_emission_utils",
":launch_dimensions",
- ":nccl_all_gather_thunk",
- ":nccl_all_reduce_thunk",
- ":nccl_all_to_all_thunk",
+ ":nccl_collective_thunks",
":parallel_loop_emitter",
":target_util",
":thunk",
@@ -436,37 +433,42 @@
actual = if_rocm("@local_config_rocm//rocm:rccl", ":empty"),
)
-# First level of nested select. NCCL requires both if_cuda and if_nccl.
-filegroup(
- name = "nccl_collective_thunk_src",
- srcs = if_nccl(
- ["nccl_collective_thunk.cc"],
- ["nccl_collective_thunk_dummy.cc"],
- ),
-)
-
tf_cuda_library(
- name = "nccl_collective_thunk",
- srcs = if_cuda_or_rocm(
- [":nccl_collective_thunk_src"],
- ["nccl_collective_thunk_dummy.cc"],
- ),
- hdrs = ["nccl_collective_thunk.h"],
+ name = "nccl_collective_thunks",
+ srcs = [
+ "nccl_all_gather_thunk.cc",
+ "nccl_all_reduce_thunk.cc",
+ "nccl_all_to_all_thunk.cc",
+ "nccl_collective_thunk.cc",
+ ],
+ hdrs = [
+ "nccl_all_gather_thunk.h",
+ "nccl_all_reduce_thunk.h",
+ "nccl_all_to_all_thunk.h",
+ "nccl_collective_thunk.h",
+ ],
+ copts = if_cuda_is_configured(["-DGOOGLE_CUDA=1"]) + if_nccl(["-DGOOGLE_XCCL=1"]),
deps = [
+ ":buffer_allocations",
":thunk",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
+ "//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:collective_ops_utils",
"//tensorflow/compiler/xla/service:global_device_id",
"//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_casting_utils",
+ "//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/core:lib",
+ "//tensorflow/compiler/mlir/hlo:lhlo",
"//tensorflow/compiler/mlir/xla:hlo_utils",
"//tensorflow/compiler/mlir/xla:type_to_shape",
"//tensorflow/compiler/mlir/xla:attribute_exporter",
+ "//tensorflow/stream_executor/gpu:gpu_activation_header",
"@llvm-project//mlir:IR",
] + if_cuda([
"//tensorflow/stream_executor/cuda:cuda_activation",
@@ -483,138 +485,6 @@
# First level of nested select. NCCL requires both if_cuda and if_nccl.
filegroup(
- name = "nccl_all_gather_thunk_src",
- srcs = if_nccl(
- ["nccl_all_gather_thunk.cc"],
- ["nccl_all_gather_thunk_dummy.cc"],
- ),
-)
-
-tf_cuda_library(
- name = "nccl_all_gather_thunk",
- srcs = if_cuda_or_rocm(
- [":nccl_all_gather_thunk_src"],
- ["nccl_all_gather_thunk_dummy.cc"],
- ),
- hdrs = ["nccl_all_gather_thunk.h"],
- deps = [
- ":buffer_allocations",
- ":gpu_executable_run_options",
- ":hlo_execution_profiler",
- ":nccl_collective_thunk",
- ":thunk",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/strings:str_format",
- "//tensorflow/compiler/xla/service:buffer_assignment",
- "//tensorflow/compiler/xla/service:collective_ops_utils",
- "//tensorflow/compiler/xla/service:hlo",
- "//tensorflow/compiler/xla/service:hlo_casting_utils",
- "//tensorflow/compiler/xla/service:pattern_matcher",
- "//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla:xla_data_proto_cc",
- "//tensorflow/core:lib",
- "//tensorflow/compiler/mlir/hlo:lhlo",
- "//tensorflow/compiler/mlir/xla:hlo_utils",
- "//tensorflow/compiler/mlir/xla:type_to_shape",
- "//tensorflow/compiler/mlir/xla:attribute_exporter",
- ] + if_nccl([
- ":virtual_nccl",
- ":virtual_nccl_utils",
- ":virtual_rccl",
- ]),
-)
-
-# First level of nested select. NCCL requires both if_cuda and if_nccl.
-filegroup(
- name = "nccl_all_reduce_thunk_src",
- srcs = if_nccl(
- ["nccl_all_reduce_thunk.cc"],
- ["nccl_all_reduce_thunk_dummy.cc"],
- ),
-)
-
-tf_cuda_library(
- name = "nccl_all_reduce_thunk",
- srcs = if_cuda_or_rocm(
- [":nccl_all_reduce_thunk_src"],
- ["nccl_all_reduce_thunk_dummy.cc"],
- ),
- hdrs = ["nccl_all_reduce_thunk.h"],
- deps = [
- ":buffer_allocations",
- ":gpu_executable_run_options",
- ":hlo_execution_profiler",
- ":nccl_collective_thunk",
- ":thunk",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/strings:str_format",
- "//tensorflow/compiler/xla/service:buffer_assignment",
- "//tensorflow/compiler/xla/service:collective_ops_utils",
- "//tensorflow/compiler/xla/service:hlo",
- "//tensorflow/compiler/xla/service:hlo_casting_utils",
- "//tensorflow/compiler/xla/service:pattern_matcher",
- "//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla:xla_data_proto_cc",
- "//tensorflow/core:lib",
- "//tensorflow/compiler/mlir/hlo:lhlo",
- "//tensorflow/compiler/mlir/xla:hlo_utils",
- "//tensorflow/compiler/mlir/xla:type_to_shape",
- "//tensorflow/compiler/mlir/xla:attribute_exporter",
- ] + if_nccl([
- ":virtual_nccl",
- ":virtual_nccl_utils",
- ":virtual_rccl",
- ]),
-)
-
-# First level of nested select. NCCL requires both if_cuda and if_nccl.
-filegroup(
- name = "nccl_all_to_all_thunk_src",
- srcs = if_nccl(
- ["nccl_all_to_all_thunk.cc"],
- ["nccl_all_to_all_thunk_dummy.cc"],
- ),
-)
-
-tf_cuda_library(
- name = "nccl_all_to_all_thunk",
- srcs = if_cuda_or_rocm(
- [":nccl_all_to_all_thunk_src"],
- ["nccl_all_to_all_thunk_dummy.cc"],
- ),
- hdrs = ["nccl_all_to_all_thunk.h"],
- deps = [
- ":buffer_allocations",
- ":gpu_executable_run_options",
- ":hlo_execution_profiler",
- ":nccl_collective_thunk",
- ":thunk",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/strings:str_format",
- "//tensorflow/compiler/xla/service:buffer_assignment",
- "//tensorflow/compiler/xla/service:collective_ops_utils",
- "//tensorflow/compiler/xla/service:hlo",
- "//tensorflow/compiler/xla/service:hlo_casting_utils",
- "//tensorflow/compiler/xla/service:pattern_matcher",
- "//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla:xla_data_proto_cc",
- "//tensorflow/core:lib",
- "//tensorflow/compiler/mlir/hlo:lhlo",
- "//tensorflow/compiler/mlir/xla:hlo_utils",
- "//tensorflow/compiler/mlir/xla:type_to_shape",
- "//tensorflow/compiler/mlir/xla:attribute_exporter",
- ] + if_nccl([
- ":virtual_nccl",
- ":virtual_nccl_utils",
- ":virtual_rccl",
- ]),
-)
-
-# First level of nested select. NCCL requires both if_cuda and if_nccl.
-filegroup(
name = "nccl_test_utils_src",
srcs = if_nccl(
["nccl_test_utils.cc"],
@@ -740,7 +610,7 @@
":hlo_execution_profiler",
":infeed_manager",
":ir_emission_utils",
- ":nccl_all_reduce_thunk", # fixdeps: keep
+ ":nccl_collective_thunks",
":outfeed_manager",
":launch_dimensions",
":stream_assignment",
@@ -1386,7 +1256,7 @@
":ir_emitter",
":launch_dimensions",
":multi_output_fusion",
- ":nccl_all_gather_thunk",
+ ":nccl_collective_thunks",
":reduction_degenerate_dim_remover",
":reduction_dimension_grouper",
":reduction_layout_normalizer",
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
index e72c128..c97f5e2 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
@@ -78,7 +78,8 @@
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitDeviceMathCall(
TargetDeviceFunctionID funcid, absl::Span<llvm::Value* const> operands,
- absl::Span<const PrimitiveType> input_types, PrimitiveType output_type) {
+ absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
+ absl::string_view name) {
// Device functions dont have f16 math functions, so we convert the operands
// to f32 before calling the function and then convert the result back to f16.
bool cast_result_to_fp16 = false;
@@ -109,7 +110,7 @@
const string& munged_callee =
ObtainDeviceFunctionName(funcid, output_type, b());
llvm::Value* result = EmitMathCall(munged_callee, converted_operands,
- converted_input_types, output_type)
+ converted_input_types, output_type, name)
.ValueOrDie();
if (cast_result_to_fp16) {
result = FPCast(result, b()->getHalfTy());
@@ -142,7 +143,8 @@
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitMathCall(
const string& callee_name, absl::Span<llvm::Value* const> operands,
- absl::Span<const PrimitiveType> input_types, PrimitiveType output_type) {
+ absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
+ absl::string_view name) {
// Binary math functions transform are of type [T] -> T.
for (PrimitiveType input_type : input_types) {
if (output_type != input_type) {
@@ -154,7 +156,7 @@
return EmitDeviceFunctionCall(
callee_name, operands, input_types, output_type,
- {llvm::Attribute::ReadNone, llvm::Attribute::NoUnwind}, b());
+ {llvm::Attribute::ReadNone, llvm::Attribute::NoUnwind}, b(), name);
}
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatBinaryOp(
@@ -221,8 +223,8 @@
prim_type);
}
-StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExp(PrimitiveType prim_type,
- llvm::Value* value) {
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExp(
+ PrimitiveType prim_type, llvm::Value* value, absl::string_view /*name*/) {
return EmitDeviceMathCall(TargetDeviceFunctionID::kExp, {value}, {prim_type},
prim_type);
}
@@ -235,9 +237,10 @@
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPow(PrimitiveType prim_type,
llvm::Value* lhs,
- llvm::Value* rhs) {
+ llvm::Value* rhs,
+ absl::string_view name) {
return EmitDeviceMathCall(TargetDeviceFunctionID::kPow, {lhs, rhs},
- {prim_type, prim_type}, prim_type);
+ {prim_type, prim_type}, prim_type, name);
}
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitSqrt(PrimitiveType prim_type,
@@ -252,11 +255,11 @@
{prim_type}, prim_type);
}
-StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
- llvm::Value* lhs,
- llvm::Value* rhs) {
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitAtan2(
+ PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs,
+ absl::string_view name) {
return EmitDeviceMathCall(TargetDeviceFunctionID::kAtan2, {lhs, rhs},
- {prim_type, prim_type}, prim_type);
+ {prim_type, prim_type}, prim_type, name);
}
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
index 0303ea4..06cb4e4 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
@@ -64,8 +64,8 @@
StatusOr<llvm::Value*> EmitCos(PrimitiveType prim_type,
llvm::Value* value) override;
- StatusOr<llvm::Value*> EmitExp(PrimitiveType prim_type,
- llvm::Value* value) override;
+ StatusOr<llvm::Value*> EmitExp(PrimitiveType prim_type, llvm::Value* value,
+ absl::string_view name) override;
StatusOr<llvm::Value*> EmitExpm1(PrimitiveType prim_type,
llvm::Value* value) override;
@@ -77,10 +77,12 @@
llvm::Value* value) override;
StatusOr<llvm::Value*> EmitPow(PrimitiveType prim_type, llvm::Value* lhs,
- llvm::Value* rhs) override;
+ llvm::Value* rhs,
+ absl::string_view name) override;
StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs,
- llvm::Value* rhs) override;
+ llvm::Value* rhs,
+ absl::string_view name) override;
StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
llvm::Value* value) override;
@@ -118,13 +120,15 @@
// return value of the function.
StatusOr<llvm::Value*> EmitDeviceMathCall(
TargetDeviceFunctionID funcid, absl::Span<llvm::Value* const> operands,
- absl::Span<const PrimitiveType> input_types, PrimitiveType output_type);
+ absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
+ absl::string_view name = "");
// Emits IR to call a function of type [T] -> T. Does not munge callee_name.
// Returns the IR value that represents the return value of the function.
StatusOr<llvm::Value*> EmitMathCall(
const string& callee_name, absl::Span<llvm::Value* const> operands,
- absl::Span<const PrimitiveType> input_types, PrimitiveType output_type);
+ absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
+ absl::string_view name = "");
const HloModuleConfig& hlo_module_config_;
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils_test.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils_test.cc
index 9eec224..f064c05 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils_test.cc
@@ -27,8 +27,9 @@
namespace gpu {
TEST(IrEmissionUtilsTest, TestOperandPartitionNoAlias) {
- mlir::MLIRContext context;
- mlir::mhlo::registerAllMhloDialects(context.getDialectRegistry());
+ mlir::DialectRegistry registry;
+ mlir::mhlo::registerAllMhloDialects(registry);
+ mlir::MLIRContext context(registry);
auto module = mlir::parseSourceString(R"(
func @foo(%arg0 : memref<f32>, %arg1 : memref<f32>, %arg2 : memref<f32>) {
@@ -43,8 +44,9 @@
}
TEST(IrEmissionUtilsTest, TestOperandPartitionWithAlias0) {
- mlir::MLIRContext context;
- mlir::mhlo::registerAllMhloDialects(context.getDialectRegistry());
+ mlir::DialectRegistry registry;
+ mlir::mhlo::registerAllMhloDialects(registry);
+ mlir::MLIRContext context(registry);
auto module = mlir::parseSourceString(R"(
func @foo(%arg0 : memref<f32>, %arg1 : memref<f32>, %arg2 : memref<f32>) {
@@ -59,8 +61,9 @@
}
TEST(IrEmissionUtilsTest, TestOperandPartitionWithAlias1) {
- mlir::MLIRContext context;
- mlir::mhlo::registerAllMhloDialects(context.getDialectRegistry());
+ mlir::DialectRegistry registry;
+ mlir::mhlo::registerAllMhloDialects(registry);
+ mlir::MLIRContext context(registry);
auto module = mlir::parseSourceString(R"(
func @foo(%arg0 : memref<f32>, %arg1 : memref<f32>, %arg2 : memref<f32>) {
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index 613696a..69d3faa 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -91,7 +91,8 @@
ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
for (const HloInstruction* operand : hlo->operands()) {
operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
- return GetIrArray(*operand, *hlo).EmitReadArrayElement(index, &b_);
+ return GetIrArray(*operand, *hlo)
+ .EmitReadArrayElement(index, &b_, operand->name());
};
}
return EmitTargetElementLoop(
@@ -688,7 +689,8 @@
fused_emitter->BindGenerator(
fusion->fused_parameter(i),
[this, operand, fusion](llvm_ir::IrArray::Index index) {
- return GetIrArray(*operand, *fusion).EmitReadArrayElement(index, &b_);
+ return GetIrArray(*operand, *fusion)
+ .EmitReadArrayElement(index, &b_, operand->name());
});
}
}
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 29126fb..349aee3 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -1060,11 +1060,14 @@
}
Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
+ TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(custom_call));
+ return EmitCustomCallFromMlir(input);
+}
+
+Status IrEmitterUnnested::EmitCustomCallFromMlir(MlirEmitterInput input) {
using mlir::dyn_cast;
using mlir::isa;
- TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(custom_call));
-
if (auto call = dyn_cast<mlir::lmhlo::CustomCallOp>(input.op)) {
if (call.call_target_name() == "PadToStatic") {
return EmitPadToStaticFromMlir(input);
@@ -1100,7 +1103,7 @@
#endif // GOOGLE_CUDA
return Unimplemented("No registered implementation for custom call to \"%s\"",
- custom_call->custom_call_target());
+ MlirToString(input.op));
}
Status IrEmitterUnnested::EmitConvolutionThunkFromMlir(MlirEmitterInput input) {
@@ -1897,10 +1900,12 @@
GetNestedComputer());
FusedIrEmitter operand_fused_emitter(&operand_elemental_emitter);
for (int i = 0; i < fused_computation->num_parameters(); i++) {
+ auto fused_operand = fused_computation->parameter_instruction(i);
operand_fused_emitter.BindGenerator(
- fused_computation->parameter_instruction(i),
- [this, &ir_arrays, i](llvm_ir::IrArray::Index index) {
- return ir_arrays[i].EmitReadArrayElement(index, &b_);
+ fused_operand, [this, &ir_arrays, i, fused_operand](
+ const llvm_ir::IrArray::Index& index) {
+ return ir_arrays[i].EmitReadArrayElement(
+ index, &b_, fused_operand->name());
});
}
TF_ASSIGN_OR_RETURN(
@@ -1939,10 +1944,12 @@
GetNestedComputer());
FusedIrEmitter scatter_fused_emitter(&scatter_elemental_emitter);
for (int i = 0; i < fused_computation->num_parameters(); i++) {
+ auto fused_operand = fused_computation->parameter_instruction(i);
scatter_fused_emitter.BindGenerator(
- fused_computation->parameter_instruction(i),
- [this, &ir_arrays, i](llvm_ir::IrArray::Index index) {
- return ir_arrays[i].EmitReadArrayElement(index, &b_);
+ fused_operand, [this, &ir_arrays, i, fused_operand](
+ const llvm_ir::IrArray::Index& index) {
+ return ir_arrays[i].EmitReadArrayElement(
+ index, &b_, fused_operand->name());
});
}
@@ -2046,10 +2053,12 @@
/*is_fusion=*/true));
for (int i = 0; i < fused_computation->num_parameters(); i++) {
+ auto fused_operand = fused_computation->parameter_instruction(i);
fused_emitter.BindGenerator(
- fused_computation->parameter_instruction(i),
- [this, &ir_arrays, i](llvm_ir::IrArray::Index index) {
- return ir_arrays[i].EmitReadArrayElement(index, &b_);
+ fused_operand, [this, &ir_arrays, i,
+ fused_operand](const llvm_ir::IrArray::Index& index) {
+ return ir_arrays[i].EmitReadArrayElement(index, &b_,
+ fused_operand->name());
});
}
@@ -2127,7 +2136,10 @@
Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
TF_ASSIGN_OR_RETURN(auto mlir_input, GetMlirEmitterInput(reduce));
+ return EmitReduceFromMlir(mlir_input);
+}
+Status IrEmitterUnnested::EmitReduceFromMlir(MlirEmitterInput mlir_input) {
if (GetHloOutputs(mlir_input.op).size() == 1 &&
IsReductionFromOrToContiguousDimensions(mlir_input.op)) {
return EmitReductionFromOrToContiguousDimensions(mlir_input);
@@ -2836,20 +2848,8 @@
}
Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
- MlirEmitterInput result;
-
- TF_ASSIGN_OR_RETURN(auto sort_op, lhlo_scratch_emitter_->EmitOp(sort));
- result.op = sort_op;
- const auto& buffer_assignment = ir_emitter_context_->buffer_assignment();
- auto& slice = result.extra_slice.emplace();
- TF_ASSIGN_OR_RETURN(slice.buffer_slice,
- buffer_assignment.GetUniqueSlice(sort, {}));
- slice.written = true;
- slice.shape = sort->shape();
-
- result.thunk_info = GetThunkInfo(sort);
-
- return EmitSortFromMlir(result);
+ TF_ASSIGN_OR_RETURN(auto mlir_input, GetMlirEmitterInput(sort));
+ return EmitSortFromMlir(mlir_input);
}
Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput mlir_input) {
@@ -4171,8 +4171,9 @@
};
} else {
auto array = operand_arrays[i];
- gen = [this, array](llvm_ir::IrArray::Index index) {
- return array.EmitReadArrayElement(index, &b_);
+ auto name = fused_computation->parameter_instruction(i)->name();
+ gen = [this, array, name](const llvm_ir::IrArray::Index& index) {
+ return array.EmitReadArrayElement(index, &b_, name);
};
}
fused_emitter.BindGenerator(fused_computation->parameter_instruction(i),
@@ -5627,10 +5628,12 @@
CHECK_LT(fused_computation->num_parameters(), ir_arrays.size());
for (int i = 0; i < fused_computation->num_parameters(); i++) {
auto ir_array = ir_arrays[i];
+ auto fused_operand = fused_computation->parameter_instruction(i);
fused_emitter->BindGenerator(
- fused_computation->parameter_instruction(i),
- [this, ir_array](llvm_ir::IrArray::Index index) {
- return ir_array.EmitReadArrayElement(index, &b_);
+ fused_operand, [this, ir_array,
+ fused_operand](const llvm_ir::IrArray::Index& index) {
+ return ir_array.EmitReadArrayElement(index, &b_,
+ fused_operand->name());
});
}
result_ir_arrays = absl::MakeSpan(ir_arrays).subspan(
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index 3028889..9c6e6f0 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -172,6 +172,7 @@
Status HandleConditional(HloInstruction* conditional) override;
Status HandleConvolution(HloInstruction* convolution) override;
Status HandleCustomCall(HloInstruction* custom_call) override;
+ Status EmitCustomCallFromMlir(MlirEmitterInput input);
Status EmitConvolutionThunkFromMlir(MlirEmitterInput input);
Status EmitGemmThunkFromMlir(MlirEmitterInput input);
Status EmitBatchNormThunkFromMlir(MlirEmitterInput input);
@@ -187,6 +188,7 @@
absl::optional<int> unroll_factor_override = {});
Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
Status HandleReduce(HloInstruction* reduce) override;
+ Status EmitReduceFromMlir(MlirEmitterInput mlir_input);
Status HandleSelectAndScatter(HloInstruction* instruction) override;
Status EmitSelectAndScatterFromMlir(MlirEmitterInput mlir_input);
Status HandleTuple(HloInstruction* tuple) override;
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.cc
index 31f25e8..6abcded 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.cc
@@ -23,13 +23,6 @@
#include <vector>
#include "absl/strings/str_format.h"
-#if GOOGLE_CUDA
-#include "third_party/nccl/nccl.h"
-#elif TENSORFLOW_USE_ROCM
-#include "rocm/include/rccl/rccl.h"
-#endif
-#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/service/gpu/nccl_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/util.h"
@@ -48,7 +41,7 @@
auto operands_are_supported = [hlo]() {
return absl::c_all_of(hlo->operands(), [](HloInstruction* operand) {
return LayoutUtil::IsDenseArray(operand->shape()) &&
- ToNcclDataType(operand->shape().element_type()).ok();
+ IsTypeSupportedByNccl(operand->shape().element_type());
});
};
return (Cast<HloAllGatherInstruction>(hlo)->all_gather_dimension() == 0) &&
@@ -60,7 +53,7 @@
absl::c_all_of(op.operands(), [](mlir::Value operand) {
Shape shape = TypeToShape(operand.getType());
return LayoutUtil::IsDenseArray(shape) &&
- ToNcclDataType(shape.element_type()).ok();
+ IsTypeSupportedByNccl(shape.element_type());
});
return op.all_gather_dimension() == 0 && operands_are_supported;
}
@@ -76,6 +69,7 @@
Status NcclAllGatherThunk::RunNcclCollective(const ExecuteParams& params,
ncclComm_t comm) {
+#if XLA_ENABLE_XCCL
int device_ordinal = params.stream->parent()->device_ordinal();
VLOG(3) << "Performing all-gather from device ordinal: " << device_ordinal;
@@ -109,10 +103,11 @@
VLOG(3) << "Done performing all-gather for ordinal: " << device_ordinal;
return Status::OK();
-}
-
-const NcclCollectiveConfig& NcclAllGatherThunk::config() const {
- return config_.config;
+#else // XLA_ENABLE_XCCL
+ return Unimplemented(
+ "NCCL support is not available: this binary was not built with a CUDA "
+ "compiler, which is necessary to build the NCCL source library.");
+#endif // XLA_ENABLE_XCCL
}
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.h b/tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.h
index e0ef126..24ca205 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.h
@@ -50,7 +50,7 @@
Status RunNcclCollective(const ExecuteParams& params,
ncclComm_t comm) override;
- const NcclCollectiveConfig& config() const override;
+ const NcclCollectiveConfig& config() const override { return config_.config; }
private:
const NcclAllGatherConfig config_;
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk_dummy.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk_dummy.cc
deleted file mode 100644
index d761ef7..0000000
--- a/tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk_dummy.cc
+++ /dev/null
@@ -1,49 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.h"
-#include "tensorflow/compiler/xla/service/hlo_instruction.h"
-
-namespace xla {
-namespace gpu {
-
-NcclAllGatherThunk::NcclAllGatherThunk(
- ThunkInfo thunk_info, mlir::lmhlo::AllGatherOp op, int64 replica_count,
- std::vector<NcclAllGatherThunk::Buffer> buffers)
- : NcclCollectiveThunk(Thunk::kNcclAllGather, thunk_info), config_{} {}
-
-/* static */ bool NcclAllGatherThunk::CanImplement(const HloInstruction* hlo) {
- return false;
-}
-
-/* static */ bool NcclAllGatherThunk::CanImplement(
- mlir::lmhlo::AllGatherOp op) {
- return false;
-}
-
-Status NcclAllGatherThunk::RunNcclCollective(const ExecuteParams&, ncclComm_t) {
- return Unimplemented(
- "NCCL support is not available: this binary was not built with a CUDA "
- "compiler, which is necessary to build the NCCL source library.");
-}
-
-const NcclCollectiveConfig& NcclAllGatherThunk::config() const {
- // This function will never be called.
- const NcclCollectiveConfig* config = nullptr;
- return *config;
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
index 6b45100..5c891ac 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
@@ -23,25 +23,14 @@
#include <vector>
#include "absl/strings/str_format.h"
-#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
-#include "mlir/IR/BuiltinOps.h" // from @llvm-project
-#include "mlir/IR/Value.h" // from @llvm-project
-#include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
-#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
-#include "tensorflow/compiler/xla/service/hlo_computation.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
-#if GOOGLE_CUDA
-#include "third_party/nccl/nccl.h"
-#elif TENSORFLOW_USE_ROCM
-#include "rocm/include/rccl/rccl.h"
-#endif
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
-#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/service/gpu/nccl_utils.h"
+#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
namespace gpu {
@@ -109,12 +98,11 @@
}
/*static*/ bool NcclAllReduceThunk::CanImplement(mlir::lmhlo::AllReduceOp op) {
- if (!op.IsCrossReplica()) return false;
bool operands_are_supported =
absl::c_all_of(op.operands(), [](mlir::Value operand) {
Shape shape = TypeToShape(operand.getType());
return LayoutUtil::IsDenseArray(shape) &&
- ToNcclDataType(shape.element_type()).ok();
+ IsTypeSupportedByNccl(shape.element_type());
});
return operands_are_supported && MatchReductionComputation(op).has_value();
}
@@ -130,6 +118,7 @@
Status NcclAllReduceThunk::RunNcclCollective(const ExecuteParams& params,
ncclComm_t comm) {
+#if XLA_ENABLE_XCCL
int device_ordinal = params.stream->parent()->device_ordinal();
VLOG(3) << "Performing all-reduce from device ordinal: " << device_ordinal;
@@ -165,10 +154,11 @@
VLOG(3) << "Done performing all-reduce for ordinal: " << device_ordinal;
return Status::OK();
-}
-
-const NcclCollectiveConfig& NcclAllReduceThunk::config() const {
- return config_.config;
+#else // XLA_ENABLE_XCCL
+ return Unimplemented(
+ "NCCL support is not available: this binary was not built with a CUDA "
+ "compiler, which is necessary to build the NCCL source library.");
+#endif // XLA_ENABLE_XCCL
}
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h
index b2ef4b3..1f88a62 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h
@@ -48,7 +48,7 @@
Status RunNcclCollective(const ExecuteParams& params,
ncclComm_t comm) override;
- const NcclCollectiveConfig& config() const override;
+ const NcclCollectiveConfig& config() const override { return config_.config; }
private:
const NcclAllReduceConfig config_;
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk_dummy.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk_dummy.cc
deleted file mode 100644
index eb6597a..0000000
--- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk_dummy.cc
+++ /dev/null
@@ -1,46 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
-#include "tensorflow/compiler/xla/service/hlo_instruction.h"
-
-namespace xla {
-namespace gpu {
-
-NcclAllReduceThunk::NcclAllReduceThunk(ThunkInfo thunk_info,
- mlir::lmhlo::AllReduceOp op,
- int64 replica_count,
- std::vector<Buffer> buffers)
- : NcclCollectiveThunk(Thunk::kNcclAllReduce, thunk_info), config_{} {}
-
-/* static */ bool NcclAllReduceThunk::CanImplement(
- mlir::lmhlo::AllReduceOp op) {
- return false;
-}
-
-Status NcclAllReduceThunk::RunNcclCollective(const ExecuteParams&, ncclComm_t) {
- return Unimplemented(
- "NCCL support is not available: this binary was not built with a CUDA "
- "compiler, which is necessary to build the NCCL source library.");
-}
-
-const NcclCollectiveConfig& NcclAllReduceThunk::config() const {
- // This function will never be called.
- const NcclCollectiveConfig* config = nullptr;
- return *config;
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk.cc
index 1d5957c..f0027ef 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk.cc
@@ -23,13 +23,7 @@
#include <vector>
#include "absl/strings/str_format.h"
-#if GOOGLE_CUDA
-#include "third_party/nccl/nccl.h"
-#elif TENSORFLOW_USE_ROCM
-#include "rocm/include/rccl/rccl.h"
-#endif
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/service/gpu/nccl_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -51,7 +45,7 @@
absl::c_all_of(op.operands(), [](mlir::Value operand) {
Shape shape = TypeToShape(operand.getType());
return LayoutUtil::IsDenseArray(shape) &&
- ToNcclDataType(shape.element_type()).ok();
+ IsTypeSupportedByNccl(shape.element_type());
});
return op.split_dimension().getValueOr(0) == 0 && operands_are_supported;
}
@@ -67,6 +61,7 @@
Status NcclAllToAllThunk::RunNcclCollective(const ExecuteParams& params,
ncclComm_t comm) {
+#if XLA_ENABLE_XCCL
int device_ordinal = params.stream->parent()->device_ordinal();
VLOG(3) << "Performing all-to-all from device ordinal: " << device_ordinal;
@@ -139,10 +134,11 @@
VLOG(3) << "Done performing all-to-all for ordinal: " << device_ordinal;
return Status::OK();
-}
-
-const NcclCollectiveConfig& NcclAllToAllThunk::config() const {
- return config_.config;
+#else // XLA_ENABLE_XCCL
+ return Unimplemented(
+ "NCCL support is not available: this binary was not built with a CUDA "
+ "compiler, which is necessary to build the NCCL source library.");
+#endif // XLA_ENABLE_XCCL
}
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk.h b/tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk.h
index 28c552a..f1d66f0 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk.h
@@ -48,7 +48,7 @@
Status RunNcclCollective(const ExecuteParams& params,
ncclComm_t comm) override;
- const NcclCollectiveConfig& config() const override;
+ const NcclCollectiveConfig& config() const override { return config_.config; }
private:
const NcclAllToAllConfig config_;
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk_dummy.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk_dummy.cc
deleted file mode 100644
index 16d418a..0000000
--- a/tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk_dummy.cc
+++ /dev/null
@@ -1,44 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk.h"
-#include "tensorflow/compiler/xla/service/hlo_instruction.h"
-
-namespace xla {
-namespace gpu {
-
-NcclAllToAllThunk::NcclAllToAllThunk(
- ThunkInfo thunk_info, mlir::lmhlo::AllToAllOp op, int64 replica_count,
- std::vector<NcclAllToAllThunk::Buffer> buffers)
- : NcclCollectiveThunk(Thunk::kNcclAllToAll, thunk_info), config_{} {}
-
-/* static */ bool NcclAllToAllThunk::CanImplement(mlir::lmhlo::AllToAllOp op) {
- return false;
-}
-
-Status NcclAllToAllThunk::RunNcclCollective(const ExecuteParams&, ncclComm_t) {
- return Unimplemented(
- "NCCL support is not available: this binary was not built with a CUDA "
- "compiler, which is necessary to build the NCCL source library.");
-}
-
-const NcclCollectiveConfig& NcclAllToAllThunk::config() const {
- // This function will never be called.
- const NcclCollectiveConfig* config = nullptr;
- return *config;
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.cc
index 970458f..2ffb4e9 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.cc
@@ -27,7 +27,6 @@
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/global_device_id.h"
-#include "tensorflow/compiler/xla/service/gpu/nccl_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/stream_executor/gpu/gpu_activation.h"
@@ -54,33 +53,16 @@
NcclCollectiveConfig& NcclCollectiveConfig::operator=(NcclCollectiveConfig&&) =
default;
-NcclCollectiveConfig GetNcclCollectiveConfig(const HloInstruction* hlo,
- int64 replica_count) {
- NcclCollectiveConfig config;
- config.operand_count = hlo->operands().size();
- config.operand_element_type.reserve(config.operand_count);
- for (int i = 0; i < config.operand_count; i++) {
- config.operand_element_type.push_back(
- hlo->operand(i)->shape().element_type());
- }
- config.replica_count = replica_count;
- config.replica_groups = hlo->replica_groups();
-
- if (hlo->channel_id().has_value()) {
- config.collective_op_kind = RendezvousKey::kCrossModule;
- config.op_id = *hlo->channel_id();
- } else {
- config.collective_op_kind = RendezvousKey::kCrossReplica;
- config.op_id = static_cast<int64>(hlo->GetModule()->unique_id());
- }
- return config;
-}
-
/* static */ bool NcclCollectiveThunk::NcclIsEnabled() {
- return true; // Skylark selects this source file if NCCL is enabled.
+#if XLA_ENABLE_XCCL
+ return true;
+#else
+ return false;
+#endif
}
Status NcclCollectiveThunk::ExecuteOnStream(const ExecuteParams& params) {
+#if XLA_ENABLE_XCCL
VLOG(1) << absl::StreamFormat("Starting %s.", ThunkKindToString(kind()));
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(profile_index());
@@ -122,6 +104,29 @@
TF_RETURN_IF_ERROR(RunNcclCollective(params, comm));
return Status::OK();
+#else // XLA_ENABLE_XCCL
+ return Unimplemented(
+ "NCCL support is not available: this binary was not built with a CUDA "
+ "compiler, which is necessary to build the NCCL source library.");
+#endif // XLA_ENABLE_XCCL
+}
+
+bool IsTypeSupportedByNccl(PrimitiveType element_type) {
+ switch (element_type) {
+ case S8:
+ case PRED:
+ case U8:
+ case S32:
+ case U32:
+ case S64:
+ case U64:
+ case F16:
+ case F32:
+ case F64:
+ return true;
+ default:
+ return false;
+ }
}
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h b/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h
index f0304e0..0b17b01 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h
@@ -27,6 +27,29 @@
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"
+// Common place for all collective thunks to source nccl/rccl headers.
+// Also, all the RunNcclCollective() functions for various thunks should
+// use XLA_ENABLE_XCCL to guard use NCCL/RCCL usage (and not use GOOGLE_XCCL).
+#if GOOGLE_XCCL
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#define XLA_ENABLE_XCCL 1
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#endif // GOOGLE_XCCL
+
+#if XLA_ENABLE_XCCL
+#if GOOGLE_CUDA
+#include "third_party/nccl/nccl.h"
+#elif TENSORFLOW_USE_ROCM
+#include "rocm/include/rccl/rccl.h"
+#else
+#error "Neither CUDA nor ROCm enabled but NCCL/RCCL enabled"
+#endif
+
+// Also include this file required by all collective thunks.
+#include "tensorflow/compiler/xla/service/gpu/nccl_utils.h"
+
+#endif // XLA_ENABLE_XCCL
+
struct ncclComm;
using ncclComm_t = ncclComm*;
@@ -50,8 +73,6 @@
int64 op_id;
};
-NcclCollectiveConfig GetNcclCollectiveConfig(const HloInstruction* hlo,
- int64 replica_count);
template <typename OpT>
NcclCollectiveConfig GetNcclCollectiveConfigForMlir(OpT op,
@@ -107,6 +128,10 @@
virtual const NcclCollectiveConfig& config() const = 0;
};
+// Returns if the given data type is supported by NCCL.
+// Note: Keep this in sync with ToNcclDataType().
+bool IsTypeSupportedByNccl(PrimitiveType element_type);
+
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk_dummy.cc b/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk_dummy.cc
deleted file mode 100644
index fc5ea04..0000000
--- a/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk_dummy.cc
+++ /dev/null
@@ -1,46 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h"
-#include "tensorflow/compiler/xla/service/hlo_instruction.h"
-
-namespace xla {
-namespace gpu {
-
-struct NcclClique {};
-
-NcclCollectiveConfig::NcclCollectiveConfig() = default;
-NcclCollectiveConfig::NcclCollectiveConfig(NcclCollectiveConfig &&) = default;
-NcclCollectiveConfig::~NcclCollectiveConfig() = default;
-NcclCollectiveConfig &NcclCollectiveConfig::operator=(NcclCollectiveConfig &&) =
- default;
-
-NcclCollectiveConfig GetNcclCollectiveConfig(const HloInstruction *hlo,
- int64 replica_count) {
- return NcclCollectiveConfig();
-}
-
-/* static */ bool NcclCollectiveThunk::NcclIsEnabled() {
- return false; // Skylark selects this source file if NCCL is disabled.
-}
-
-Status NcclCollectiveThunk::ExecuteOnStream(const ExecuteParams &) {
- return Unimplemented(
- "NCCL support is not available: this binary was not built with a CUDA "
- "compiler, which is necessary to build the NCCL source library.");
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/target_util.cc b/tensorflow/compiler/xla/service/gpu/target_util.cc
index 31b590a..9784199 100644
--- a/tensorflow/compiler/xla/service/gpu/target_util.cc
+++ b/tensorflow/compiler/xla/service/gpu/target_util.cc
@@ -194,7 +194,7 @@
const string& callee_name, absl::Span<llvm::Value* const> operands,
absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
absl::Span<const llvm::Attribute::AttrKind> attributes,
- llvm::IRBuilder<>* b) {
+ llvm::IRBuilder<>* b, absl::string_view name) {
std::vector<llvm::Type*> ir_input_types;
llvm::Module* module = b->GetInsertBlock()->getModule();
for (PrimitiveType input_type : input_types) {
@@ -217,7 +217,7 @@
callee->addFnAttr(attribute);
}
- return b->CreateCall(callee, llvm_ir::AsArrayRef(operands));
+ return b->CreateCall(callee, llvm_ir::AsArrayRef(operands), name.data());
}
llvm::CallInst* EmitCallToTargetIntrinsic(
diff --git a/tensorflow/compiler/xla/service/gpu/target_util.h b/tensorflow/compiler/xla/service/gpu/target_util.h
index 2bdaea7..115609d 100644
--- a/tensorflow/compiler/xla/service/gpu/target_util.h
+++ b/tensorflow/compiler/xla/service/gpu/target_util.h
@@ -69,7 +69,7 @@
const std::string& callee_name, absl::Span<llvm::Value* const> operands,
absl::Span<const PrimitiveType> input_type, PrimitiveType output_type,
absl::Span<const llvm::Attribute::AttrKind> attributes,
- llvm::IRBuilder<>* b);
+ llvm::IRBuilder<>* b, absl::string_view name = "");
// Emits a call to the specified target intrinsic with the given operands.
// Overloaded intrinsics (for example, "minnum") must include a type
diff --git a/tensorflow/compiler/xla/service/gpu/tests/fused_scatter.hlo b/tensorflow/compiler/xla/service/gpu/tests/fused_scatter.hlo
index 9cdc0af..1492d6c 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/fused_scatter.hlo
+++ b/tensorflow/compiler/xla/service/gpu/tests/fused_scatter.hlo
@@ -150,9 +150,9 @@
// CHECK: %[[VAL_136:.*]] = icmp ult i32 %[[VAL_134]], 3
// CHECK: %[[VAL_137:.*]] = and i1 true, %[[VAL_136]]
// CHECK: br i1 %[[VAL_137]], label %[[VAL_138:.*]], label %[[VAL_131]]
-// CHECK: scatter.in_bounds-after2: ; preds = %[[VAL_138]], %[[VAL_129]]
+// CHECK: scatter.in_bounds-after3: ; preds = %[[VAL_138]], %[[VAL_129]]
// CHECK: br label %[[VAL_130]]
-// CHECK: scatter.in_bounds-true1: ; preds = %[[VAL_129]]
+// CHECK: scatter.in_bounds-true2: ; preds = %[[VAL_129]]
// CHECK: %[[VAL_139:.*]] = getelementptr inbounds [3 x [3 x i32]], [3 x [3 x i32]]* %[[VAL_119]], i32 0, i32 %[[VAL_135]], i32 %[[VAL_126]]
// CHECK: %[[VAL_140:.*]] = bitcast [2 x [3 x i32]]* %[[VAL_116]] to i32*
// CHECK: %[[VAL_141:.*]] = getelementptr inbounds i32, i32* %[[VAL_140]], i32 %[[VAL_123]]
diff --git a/tensorflow/compiler/xla/service/gpu/tests/reduce_unnested.hlo b/tensorflow/compiler/xla/service/gpu/tests/reduce_unnested.hlo
index 987831b..ec4e6b7 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/reduce_unnested.hlo
+++ b/tensorflow/compiler/xla/service/gpu/tests/reduce_unnested.hlo
@@ -239,11 +239,11 @@
// CHECK: %[[VAL_141:.*]] = udiv i32 %[[VAL_95]], 32
// CHECK: %[[VAL_142:.*]] = icmp eq i32 %[[VAL_97]], 0
// CHECK: br i1 %[[VAL_142]], label %[[VAL_143:.*]], label %[[VAL_144:.*]]
-// CHECK: intra_warp_reduce_write-after66: ; preds = %[[VAL_143]], %[[VAL_123]]
+// CHECK: intra_warp_reduce_write-after129: ; preds = %[[VAL_143]], %[[VAL_123]]
// CHECK: call void @llvm.nvvm.barrier0()
// CHECK: %[[VAL_145:.*]] = icmp eq i32 %[[VAL_141]], 0
// CHECK: br i1 %[[VAL_145]], label %[[VAL_146:.*]], label %[[VAL_51]]
-// CHECK: inter_warp_reduce-after68: ; preds = %[[VAL_147:.*]], %[[VAL_144]]
+// CHECK: inter_warp_reduce-after131: ; preds = %[[VAL_147:.*]], %[[VAL_144]]
// CHECK: br label %[[VAL_50]]
// CHECK: output_is_full_tile-true: ; preds = %[[VAL_85]]
// CHECK: %[[VAL_148:.*]] = add i32 %[[VAL_72]], %[[VAL_89]]
@@ -459,37 +459,37 @@
// CHECK: %[[VAL_324:.*]] = add i32 %[[VAL_75]], 1
// CHECK: %[[VAL_325:.*]] = icmp ult i32 %[[VAL_323]], %[[VAL_71]]
// CHECK: br i1 %[[VAL_325]], label %[[VAL_326:.*]], label %[[VAL_327:.*]]
-// CHECK: output_x_in_tile-after13: ; preds = %[[VAL_326]], %[[VAL_322]]
+// CHECK: output_x_in_tile-after48: ; preds = %[[VAL_326]], %[[VAL_322]]
// CHECK: %[[VAL_328:.*]] = add i32 64, %[[VAL_74]]
// CHECK: %[[VAL_329:.*]] = add i32 %[[VAL_75]], 64
// CHECK: %[[VAL_330:.*]] = icmp ult i32 %[[VAL_328]], %[[VAL_71]]
// CHECK: br i1 %[[VAL_330]], label %[[VAL_331:.*]], label %[[VAL_332:.*]]
-// CHECK: output_x_in_tile-after16: ; preds = %[[VAL_331]], %[[VAL_327]]
+// CHECK: output_x_in_tile-after55: ; preds = %[[VAL_331]], %[[VAL_327]]
// CHECK: %[[VAL_333:.*]] = add i32 65, %[[VAL_74]]
// CHECK: %[[VAL_334:.*]] = add i32 %[[VAL_75]], 65
// CHECK: %[[VAL_335:.*]] = icmp ult i32 %[[VAL_333]], %[[VAL_71]]
// CHECK: br i1 %[[VAL_335]], label %[[VAL_336:.*]], label %[[VAL_337:.*]]
-// CHECK: output_x_in_tile-after19: ; preds = %[[VAL_336]], %[[VAL_332]]
+// CHECK: output_x_in_tile-after62: ; preds = %[[VAL_336]], %[[VAL_332]]
// CHECK: %[[VAL_338:.*]] = add i32 128, %[[VAL_74]]
// CHECK: %[[VAL_339:.*]] = add i32 %[[VAL_75]], 128
// CHECK: %[[VAL_340:.*]] = icmp ult i32 %[[VAL_338]], %[[VAL_71]]
// CHECK: br i1 %[[VAL_340]], label %[[VAL_341:.*]], label %[[VAL_342:.*]]
-// CHECK: output_x_in_tile-after22: ; preds = %[[VAL_341]], %[[VAL_337]]
+// CHECK: output_x_in_tile-after69: ; preds = %[[VAL_341]], %[[VAL_337]]
// CHECK: %[[VAL_343:.*]] = add i32 129, %[[VAL_74]]
// CHECK: %[[VAL_344:.*]] = add i32 %[[VAL_75]], 129
// CHECK: %[[VAL_345:.*]] = icmp ult i32 %[[VAL_343]], %[[VAL_71]]
// CHECK: br i1 %[[VAL_345]], label %[[VAL_346:.*]], label %[[VAL_347:.*]]
-// CHECK: output_x_in_tile-after25: ; preds = %[[VAL_346]], %[[VAL_342]]
+// CHECK: output_x_in_tile-after76: ; preds = %[[VAL_346]], %[[VAL_342]]
// CHECK: %[[VAL_348:.*]] = add i32 192, %[[VAL_74]]
// CHECK: %[[VAL_349:.*]] = add i32 %[[VAL_75]], 192
// CHECK: %[[VAL_350:.*]] = icmp ult i32 %[[VAL_348]], %[[VAL_71]]
// CHECK: br i1 %[[VAL_350]], label %[[VAL_351:.*]], label %[[VAL_352:.*]]
-// CHECK: output_x_in_tile-after28: ; preds = %[[VAL_351]], %[[VAL_347]]
+// CHECK: output_x_in_tile-after83: ; preds = %[[VAL_351]], %[[VAL_347]]
// CHECK: %[[VAL_353:.*]] = add i32 193, %[[VAL_74]]
// CHECK: %[[VAL_354:.*]] = add i32 %[[VAL_75]], 193
// CHECK: %[[VAL_355:.*]] = icmp ult i32 %[[VAL_353]], %[[VAL_71]]
// CHECK: br i1 %[[VAL_355]], label %[[VAL_356:.*]], label %[[VAL_93]]
-// CHECK: output_x_in_tile-after31: ; preds = %[[VAL_356]], %[[VAL_352]]
+// CHECK: output_x_in_tile-after90: ; preds = %[[VAL_356]], %[[VAL_352]]
// CHECK: br label %[[VAL_81]]
// CHECK: output_x_in_tile-true: ; preds = %[[VAL_92]]
// CHECK: %[[VAL_357:.*]] = mul nuw nsw i32 %[[VAL_319]], 1
@@ -516,7 +516,7 @@
// CHECK: %[[VAL_375:.*]] = getelementptr inbounds float, float* %[[VAL_33]], i32 0
// CHECK: call void @region_2_9(float* %[[VAL_375]], float* %[[VAL_34]], float* %[[VAL_375]])
// CHECK: br label %[[VAL_322]]
-// CHECK: output_x_in_tile-true12: ; preds = %[[VAL_322]]
+// CHECK: output_x_in_tile-true47: ; preds = %[[VAL_322]]
// CHECK: %[[VAL_376:.*]] = mul nuw nsw i32 %[[VAL_324]], 1
// CHECK: %[[VAL_377:.*]] = add nuw nsw i32 0, %[[VAL_376]]
// CHECK: %[[VAL_378:.*]] = mul nuw nsw i32 %[[VAL_317]], 32
@@ -541,7 +541,7 @@
// CHECK: %[[VAL_394:.*]] = getelementptr inbounds float, float* %[[VAL_33]], i32 0
// CHECK: call void @region_2_9(float* %[[VAL_394]], float* %[[VAL_34]], float* %[[VAL_394]])
// CHECK: br label %[[VAL_327]]
-// CHECK: output_x_in_tile-true15: ; preds = %[[VAL_327]]
+// CHECK: output_x_in_tile-true54: ; preds = %[[VAL_327]]
// CHECK: %[[VAL_395:.*]] = mul nuw nsw i32 %[[VAL_329]], 1
// CHECK: %[[VAL_396:.*]] = add nuw nsw i32 0, %[[VAL_395]]
// CHECK: %[[VAL_397:.*]] = mul nuw nsw i32 %[[VAL_317]], 32
@@ -566,7 +566,7 @@
// CHECK: %[[VAL_413:.*]] = getelementptr inbounds float, float* %[[VAL_33]], i32 0
// CHECK: call void @region_2_9(float* %[[VAL_413]], float* %[[VAL_34]], float* %[[VAL_413]])
// CHECK: br label %[[VAL_332]]
-// CHECK: output_x_in_tile-true18: ; preds = %[[VAL_332]]
+// CHECK: output_x_in_tile-true61: ; preds = %[[VAL_332]]
// CHECK: %[[VAL_414:.*]] = mul nuw nsw i32 %[[VAL_334]], 1
// CHECK: %[[VAL_415:.*]] = add nuw nsw i32 0, %[[VAL_414]]
// CHECK: %[[VAL_416:.*]] = mul nuw nsw i32 %[[VAL_317]], 32
@@ -591,7 +591,7 @@
// CHECK: %[[VAL_432:.*]] = getelementptr inbounds float, float* %[[VAL_33]], i32 0
// CHECK: call void @region_2_9(float* %[[VAL_432]], float* %[[VAL_34]], float* %[[VAL_432]])
// CHECK: br label %[[VAL_337]]
-// CHECK: output_x_in_tile-true21: ; preds = %[[VAL_337]]
+// CHECK: output_x_in_tile-true68: ; preds = %[[VAL_337]]
// CHECK: %[[VAL_433:.*]] = mul nuw nsw i32 %[[VAL_339]], 1
// CHECK: %[[VAL_434:.*]] = add nuw nsw i32 0, %[[VAL_433]]
// CHECK: %[[VAL_435:.*]] = mul nuw nsw i32 %[[VAL_317]], 32
@@ -616,7 +616,7 @@
// CHECK: %[[VAL_451:.*]] = getelementptr inbounds float, float* %[[VAL_33]], i32 0
// CHECK: call void @region_2_9(float* %[[VAL_451]], float* %[[VAL_34]], float* %[[VAL_451]])
// CHECK: br label %[[VAL_342]]
-// CHECK: output_x_in_tile-true24: ; preds = %[[VAL_342]]
+// CHECK: output_x_in_tile-true75: ; preds = %[[VAL_342]]
// CHECK: %[[VAL_452:.*]] = mul nuw nsw i32 %[[VAL_344]], 1
// CHECK: %[[VAL_453:.*]] = add nuw nsw i32 0, %[[VAL_452]]
// CHECK: %[[VAL_454:.*]] = mul nuw nsw i32 %[[VAL_317]], 32
@@ -641,7 +641,7 @@
// CHECK: %[[VAL_470:.*]] = getelementptr inbounds float, float* %[[VAL_33]], i32 0
// CHECK: call void @region_2_9(float* %[[VAL_470]], float* %[[VAL_34]], float* %[[VAL_470]])
// CHECK: br label %[[VAL_347]]
-// CHECK: output_x_in_tile-true27: ; preds = %[[VAL_347]]
+// CHECK: output_x_in_tile-true82: ; preds = %[[VAL_347]]
// CHECK: %[[VAL_471:.*]] = mul nuw nsw i32 %[[VAL_349]], 1
// CHECK: %[[VAL_472:.*]] = add nuw nsw i32 0, %[[VAL_471]]
// CHECK: %[[VAL_473:.*]] = mul nuw nsw i32 %[[VAL_317]], 32
@@ -666,7 +666,7 @@
// CHECK: %[[VAL_489:.*]] = getelementptr inbounds float, float* %[[VAL_33]], i32 0
// CHECK: call void @region_2_9(float* %[[VAL_489]], float* %[[VAL_34]], float* %[[VAL_489]])
// CHECK: br label %[[VAL_352]]
-// CHECK: output_x_in_tile-true30: ; preds = %[[VAL_352]]
+// CHECK: output_x_in_tile-true89: ; preds = %[[VAL_352]]
// CHECK: %[[VAL_490:.*]] = mul nuw nsw i32 %[[VAL_354]], 1
// CHECK: %[[VAL_491:.*]] = add nuw nsw i32 0, %[[VAL_490]]
// CHECK: %[[VAL_492:.*]] = mul nuw nsw i32 %[[VAL_317]], 32
@@ -731,13 +731,13 @@
// CHECK: %[[VAL_528:.*]] = load float, float* %[[VAL_513]], align 4
// CHECK: %[[VAL_529:.*]] = atomicrmw fadd float* %[[VAL_105]], float %[[VAL_528]] seq_cst
// CHECK: br label %[[VAL_124]]
-// CHECK: intra_warp_reduce_write-true65: ; preds = %[[VAL_123]]
+// CHECK: intra_warp_reduce_write-true128: ; preds = %[[VAL_123]]
// CHECK: %[[VAL_530:.*]] = getelementptr inbounds [1 x [32 x float]], [1 x [32 x float]] addrspace(3)* @shared_cache_1, i32 0, i32 0, i32 %[[VAL_141]]
// CHECK: %[[VAL_531:.*]] = addrspacecast float addrspace(3)* %[[VAL_530]] to float*
// CHECK: %[[VAL_532:.*]] = load float, float* %[[VAL_130]], align 4
// CHECK: store float %[[VAL_532]], float* %[[VAL_531]], align 4
// CHECK: br label %[[VAL_144]]
-// CHECK: inter_warp_reduce-true67: ; preds = %[[VAL_144]]
+// CHECK: inter_warp_reduce-true130: ; preds = %[[VAL_144]]
// CHECK: %[[VAL_533:.*]] = getelementptr inbounds [1 x [32 x float]], [1 x [32 x float]] addrspace(3)* @shared_cache_1, i32 0, i32 0, i32 %[[VAL_97]]
// CHECK: %[[VAL_534:.*]] = addrspacecast float addrspace(3)* %[[VAL_533]] to float*
// CHECK: store float %[[VAL_55]], float* %[[VAL_15]], align 4
@@ -765,9 +765,9 @@
// CHECK: call void @region_2_9(float* %[[VAL_536]], float* %[[VAL_10]], float* %[[VAL_536]])
// CHECK: %[[VAL_547:.*]] = icmp eq i32 %[[VAL_95]], 0
// CHECK: br i1 %[[VAL_547]], label %[[VAL_548:.*]], label %[[VAL_147]]
-// CHECK: reduction_atomic_update-after81: ; preds = %[[VAL_549:.*]], %[[VAL_146]]
+// CHECK: reduction_atomic_update-after144: ; preds = %[[VAL_549:.*]], %[[VAL_146]]
// CHECK: br label %[[VAL_51]]
-// CHECK: reduction_atomic_update-true80: ; preds = %[[VAL_146]]
+// CHECK: reduction_atomic_update-true143: ; preds = %[[VAL_146]]
// CHECK: %[[VAL_550:.*]] = load float, float* %[[VAL_534]], align 4
// CHECK: %[[VAL_551:.*]] = bitcast float* %[[VAL_129]] to i32*
// CHECK: %[[VAL_552:.*]] = bitcast i32* %[[VAL_8]] to float*
diff --git a/tensorflow/compiler/xla/service/gpu/tests/sorting.hlo b/tensorflow/compiler/xla/service/gpu/tests/sorting.hlo
index 8e4e8bf..2968e6e 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/sorting.hlo
+++ b/tensorflow/compiler/xla/service/gpu/tests/sorting.hlo
@@ -13,49 +13,47 @@
// CHECK: %[[VAL_1:.*]] = alloca i8, align 1
// CHECK: %[[VAL_2:.*]] = getelementptr inbounds i8, i8* %[[VAL_0]], i64 0
// CHECK: %[[VAL_3:.*]] = bitcast i8* %[[VAL_2]] to [2 x [3 x float]]*
-// CHECK: %[[VAL_4:.*]] = getelementptr inbounds i8, i8* %[[VAL_0]], i64 0
-// CHECK: %[[VAL_5:.*]] = bitcast i8* %[[VAL_4]] to [2 x [3 x float]]*
-// CHECK: %[[VAL_6:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6
+// CHECK: %[[VAL_4:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6
+// CHECK: %[[VAL_5:.*]] = zext i32 %[[VAL_4]] to i64
+// CHECK: %[[VAL_6:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7
// CHECK: %[[VAL_7:.*]] = zext i32 %[[VAL_6]] to i64
-// CHECK: %[[VAL_8:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7
-// CHECK: %[[VAL_9:.*]] = zext i32 %[[VAL_8]] to i64
-// CHECK: %[[VAL_10:.*]] = mul nuw nsw i64 %[[VAL_7]], 4
-// CHECK: %[[VAL_11:.*]] = add nuw nsw i64 %[[VAL_10]], %[[VAL_9]]
-// CHECK: %[[VAL_12:.*]] = icmp ult i64 %[[VAL_11]], 4
-// CHECK: call void @llvm.assume(i1 %[[VAL_12]])
-// CHECK: %[[VAL_13:.*]] = udiv i64 %[[VAL_11]], 1
-// CHECK: %[[VAL_14:.*]] = urem i64 %[[VAL_13]], 2
-// CHECK: %[[VAL_15:.*]] = udiv i64 %[[VAL_11]], 2
-// CHECK: %[[VAL_16:.*]] = icmp ult i64 %[[VAL_11]], 4
-// CHECK: br i1 %[[VAL_16]], label %[[VAL_17:.*]], label %[[VAL_18:.*]]
-// CHECK: sort.in_bounds-after: ; preds = %[[VAL_19:.*]], %[[VAL_20:.*]]
+// CHECK: %[[VAL_8:.*]] = mul nuw nsw i64 %[[VAL_5]], 4
+// CHECK: %[[VAL_9:.*]] = add nuw nsw i64 %[[VAL_8]], %[[VAL_7]]
+// CHECK: %[[VAL_10:.*]] = icmp ult i64 %[[VAL_9]], 4
+// CHECK: call void @llvm.assume(i1 %[[VAL_10]])
+// CHECK: %[[VAL_11:.*]] = udiv i64 %[[VAL_9]], 1
+// CHECK: %[[VAL_12:.*]] = urem i64 %[[VAL_11]], 2
+// CHECK: %[[VAL_13:.*]] = udiv i64 %[[VAL_9]], 2
+// CHECK: %[[VAL_14:.*]] = icmp ult i64 %[[VAL_9]], 4
+// CHECK: br i1 %[[VAL_14]], label %[[VAL_15:.*]], label %[[VAL_16:.*]]
+// CHECK: sort.in_bounds-after: ; preds = %[[VAL_17:.*]], %[[VAL_18:.*]]
// CHECK: ret void
-// CHECK: sort.in_bounds-true: ; preds = %[[VAL_20]]
-// CHECK: %[[VAL_21:.*]] = mul i64 %[[VAL_14]], 2
-// CHECK: %[[VAL_22:.*]] = xor i64 %[[VAL_21]], 1
-// CHECK: %[[VAL_23:.*]] = icmp slt i64 %[[VAL_21]], %[[VAL_22]]
-// CHECK: %[[VAL_24:.*]] = icmp slt i64 %[[VAL_22]], 3
-// CHECK: %[[VAL_25:.*]] = and i1 %[[VAL_23]], %[[VAL_24]]
-// CHECK: br i1 %[[VAL_25]], label %[[VAL_26:.*]], label %[[VAL_19]]
-// CHECK: smaller_comparison_index-after: ; preds = %[[VAL_27:.*]], %[[VAL_17]]
-// CHECK: br label %[[VAL_18]]
-// CHECK: smaller_comparison_index-true: ; preds = %[[VAL_17]]
-// CHECK: %[[VAL_28:.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* %[[VAL_3]], i64 0, i64 %[[VAL_15]], i64 %[[VAL_22]]
-// CHECK: %[[VAL_29:.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* %[[VAL_3]], i64 0, i64 %[[VAL_15]], i64 %[[VAL_21]]
-// CHECK: call void @region_0_4(float* %[[VAL_28]], float* %[[VAL_29]], i8* %[[VAL_1]])
-// CHECK: %[[VAL_30:.*]] = load i8, i8* %[[VAL_1]], align 1
-// CHECK: %[[VAL_31:.*]] = icmp ne i8 %[[VAL_30]], 0
-// CHECK: br i1 %[[VAL_31]], label %[[VAL_32:.*]], label %[[VAL_27]]
-// CHECK: is_smaller_than-after: ; preds = %[[VAL_32]], %[[VAL_26]]
-// CHECK: br label %[[VAL_19]]
-// CHECK: is_smaller_than-true: ; preds = %[[VAL_26]]
-// CHECK: %[[VAL_33:.*]] = load float, float* %[[VAL_28]], align 4
-// CHECK: %[[VAL_34:.*]] = load float, float* %[[VAL_29]], align 4
-// CHECK: %[[VAL_35:.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* %[[VAL_3]], i64 0, i64 %[[VAL_15]], i64 %[[VAL_21]]
-// CHECK: store float %[[VAL_33]], float* %[[VAL_35]], align 4
-// CHECK: %[[VAL_36:.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* %[[VAL_3]], i64 0, i64 %[[VAL_15]], i64 %[[VAL_22]]
-// CHECK: store float %[[VAL_34]], float* %[[VAL_36]], align 4
-// CHECK: br label %[[VAL_27]]
+// CHECK: sort.in_bounds-true: ; preds = %[[VAL_18]]
+// CHECK: %[[VAL_19:.*]] = mul i64 %[[VAL_12]], 2
+// CHECK: %[[VAL_20:.*]] = xor i64 %[[VAL_19]], 1
+// CHECK: %[[VAL_21:.*]] = icmp slt i64 %[[VAL_19]], %[[VAL_20]]
+// CHECK: %[[VAL_22:.*]] = icmp slt i64 %[[VAL_20]], 3
+// CHECK: %[[VAL_23:.*]] = and i1 %[[VAL_21]], %[[VAL_22]]
+// CHECK: br i1 %[[VAL_23]], label %[[VAL_24:.*]], label %[[VAL_17]]
+// CHECK: smaller_comparison_index-after: ; preds = %[[VAL_25:.*]], %[[VAL_15]]
+// CHECK: br label %[[VAL_16]]
+// CHECK: smaller_comparison_index-true: ; preds = %[[VAL_15]]
+// CHECK: %[[VAL_26:.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* %[[VAL_3]], i64 0, i64 %[[VAL_13]], i64 %[[VAL_20]]
+// CHECK: %[[VAL_27:.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* %[[VAL_3]], i64 0, i64 %[[VAL_13]], i64 %[[VAL_19]]
+// CHECK: call void @region_0_4(float* %[[VAL_26]], float* %[[VAL_27]], i8* %[[VAL_1]])
+// CHECK: %[[VAL_28:.*]] = load i8, i8* %[[VAL_1]], align 1
+// CHECK: %[[VAL_29:.*]] = icmp ne i8 %[[VAL_28]], 0
+// CHECK: br i1 %[[VAL_29]], label %[[VAL_30:.*]], label %[[VAL_25]]
+// CHECK: is_smaller_than-after: ; preds = %[[VAL_30]], %[[VAL_24]]
+// CHECK: br label %[[VAL_17]]
+// CHECK: is_smaller_than-true: ; preds = %[[VAL_24]]
+// CHECK: %[[VAL_31:.*]] = load float, float* %[[VAL_26]], align 4
+// CHECK: %[[VAL_32:.*]] = load float, float* %[[VAL_27]], align 4
+// CHECK: %[[VAL_33:.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* %[[VAL_3]], i64 0, i64 %[[VAL_13]], i64 %[[VAL_19]]
+// CHECK: store float %[[VAL_31]], float* %[[VAL_33]], align 4
+// CHECK: %[[VAL_34:.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* %[[VAL_3]], i64 0, i64 %[[VAL_13]], i64 %[[VAL_20]]
+// CHECK: store float %[[VAL_32]], float* %[[VAL_34]], align 4
+// CHECK: br label %[[VAL_25]]
// CHECK: }
// CHECK: ; Function Attrs: nounwind readnone
// CHECK: declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() #0
@@ -82,48 +80,46 @@
// CHECK: %[[VAL_1:.*]] = alloca i8, align 1
// CHECK: %[[VAL_2:.*]] = getelementptr inbounds i8, i8* %[[VAL_0]], i64 0
// CHECK: %[[VAL_3:.*]] = bitcast i8* %[[VAL_2]] to [2 x [3 x float]]*
-// CHECK: %[[VAL_4:.*]] = getelementptr inbounds i8, i8* %[[VAL_0]], i64 0
-// CHECK: %[[VAL_5:.*]] = bitcast i8* %[[VAL_4]] to [2 x [3 x float]]*
-// CHECK: %[[VAL_6:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6
+// CHECK: %[[VAL_4:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6
+// CHECK: %[[VAL_5:.*]] = zext i32 %[[VAL_4]] to i64
+// CHECK: %[[VAL_6:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7
// CHECK: %[[VAL_7:.*]] = zext i32 %[[VAL_6]] to i64
-// CHECK: %[[VAL_8:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7
-// CHECK: %[[VAL_9:.*]] = zext i32 %[[VAL_8]] to i64
-// CHECK: %[[VAL_10:.*]] = mul nuw nsw i64 %[[VAL_7]], 4
-// CHECK: %[[VAL_11:.*]] = add nuw nsw i64 %[[VAL_10]], %[[VAL_9]]
-// CHECK: %[[VAL_12:.*]] = icmp ult i64 %[[VAL_11]], 4
-// CHECK: call void @llvm.assume(i1 %[[VAL_12]])
-// CHECK: %[[VAL_13:.*]] = udiv i64 %[[VAL_11]], 1
-// CHECK: %[[VAL_14:.*]] = urem i64 %[[VAL_13]], 2
-// CHECK: %[[VAL_15:.*]] = udiv i64 %[[VAL_11]], 2
-// CHECK: %[[VAL_16:.*]] = icmp ult i64 %[[VAL_11]], 4
-// CHECK: br i1 %[[VAL_16]], label %[[VAL_17:.*]], label %[[VAL_18:.*]]
-// CHECK: sort.in_bounds-after: ; preds = %[[VAL_19:.*]], %[[VAL_20:.*]]
+// CHECK: %[[VAL_8:.*]] = mul nuw nsw i64 %[[VAL_5]], 4
+// CHECK: %[[VAL_9:.*]] = add nuw nsw i64 %[[VAL_8]], %[[VAL_7]]
+// CHECK: %[[VAL_10:.*]] = icmp ult i64 %[[VAL_9]], 4
+// CHECK: call void @llvm.assume(i1 %[[VAL_10]])
+// CHECK: %[[VAL_11:.*]] = udiv i64 %[[VAL_9]], 1
+// CHECK: %[[VAL_12:.*]] = urem i64 %[[VAL_11]], 2
+// CHECK: %[[VAL_13:.*]] = udiv i64 %[[VAL_9]], 2
+// CHECK: %[[VAL_14:.*]] = icmp ult i64 %[[VAL_9]], 4
+// CHECK: br i1 %[[VAL_14]], label %[[VAL_15:.*]], label %[[VAL_16:.*]]
+// CHECK: sort.in_bounds-after: ; preds = %[[VAL_17:.*]], %[[VAL_18:.*]]
// CHECK: ret void
-// CHECK: sort.in_bounds-true: ; preds = %[[VAL_20]]
-// CHECK: %[[VAL_21:.*]] = xor i64 %[[VAL_14]], 3
-// CHECK: %[[VAL_22:.*]] = icmp slt i64 %[[VAL_14]], %[[VAL_21]]
-// CHECK: %[[VAL_23:.*]] = icmp slt i64 %[[VAL_21]], 3
-// CHECK: %[[VAL_24:.*]] = and i1 %[[VAL_22]], %[[VAL_23]]
-// CHECK: br i1 %[[VAL_24]], label %[[VAL_25:.*]], label %[[VAL_19]]
-// CHECK: smaller_comparison_index-after: ; preds = %[[VAL_26:.*]], %[[VAL_17]]
-// CHECK: br label %[[VAL_18]]
-// CHECK: smaller_comparison_index-true: ; preds = %[[VAL_17]]
-// CHECK: %[[VAL_27:.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* %[[VAL_3]], i64 0, i64 %[[VAL_15]], i64 %[[VAL_21]]
-// CHECK: %[[VAL_28:.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* %[[VAL_3]], i64 0, i64 %[[VAL_15]], i64 %[[VAL_14]]
-// CHECK: call void @region_0_4(float* %[[VAL_27]], float* %[[VAL_28]], i8* %[[VAL_1]])
-// CHECK: %[[VAL_29:.*]] = load i8, i8* %[[VAL_1]], align 1
-// CHECK: %[[VAL_30:.*]] = icmp ne i8 %[[VAL_29]], 0
-// CHECK: br i1 %[[VAL_30]], label %[[VAL_31:.*]], label %[[VAL_26]]
-// CHECK: is_smaller_than-after: ; preds = %[[VAL_31]], %[[VAL_25]]
-// CHECK: br label %[[VAL_19]]
-// CHECK: is_smaller_than-true: ; preds = %[[VAL_25]]
-// CHECK: %[[VAL_32:.*]] = load float, float* %[[VAL_27]], align 4
-// CHECK: %[[VAL_33:.*]] = load float, float* %[[VAL_28]], align 4
-// CHECK: %[[VAL_34:.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* %[[VAL_3]], i64 0, i64 %[[VAL_15]], i64 %[[VAL_14]]
-// CHECK: store float %[[VAL_32]], float* %[[VAL_34]], align 4
-// CHECK: %[[VAL_35:.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* %[[VAL_3]], i64 0, i64 %[[VAL_15]], i64 %[[VAL_21]]
-// CHECK: store float %[[VAL_33]], float* %[[VAL_35]], align 4
-// CHECK: br label %[[VAL_26]]
+// CHECK: sort.in_bounds-true: ; preds = %[[VAL_18]]
+// CHECK: %[[VAL_19:.*]] = xor i64 %[[VAL_12]], 3
+// CHECK: %[[VAL_20:.*]] = icmp slt i64 %[[VAL_12]], %[[VAL_19]]
+// CHECK: %[[VAL_21:.*]] = icmp slt i64 %[[VAL_19]], 3
+// CHECK: %[[VAL_22:.*]] = and i1 %[[VAL_20]], %[[VAL_21]]
+// CHECK: br i1 %[[VAL_22]], label %[[VAL_23:.*]], label %[[VAL_17]]
+// CHECK: smaller_comparison_index-after: ; preds = %[[VAL_24:.*]], %[[VAL_15]]
+// CHECK: br label %[[VAL_16]]
+// CHECK: smaller_comparison_index-true: ; preds = %[[VAL_15]]
+// CHECK: %[[VAL_25:.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* %[[VAL_3]], i64 0, i64 %[[VAL_13]], i64 %[[VAL_19]]
+// CHECK: %[[VAL_26:.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* %[[VAL_3]], i64 0, i64 %[[VAL_13]], i64 %[[VAL_12]]
+// CHECK: call void @region_0_4(float* %[[VAL_25]], float* %[[VAL_26]], i8* %[[VAL_1]])
+// CHECK: %[[VAL_27:.*]] = load i8, i8* %[[VAL_1]], align 1
+// CHECK: %[[VAL_28:.*]] = icmp ne i8 %[[VAL_27]], 0
+// CHECK: br i1 %[[VAL_28]], label %[[VAL_29:.*]], label %[[VAL_24]]
+// CHECK: is_smaller_than-after: ; preds = %[[VAL_29]], %[[VAL_23]]
+// CHECK: br label %[[VAL_17]]
+// CHECK: is_smaller_than-true: ; preds = %[[VAL_23]]
+// CHECK: %[[VAL_30:.*]] = load float, float* %[[VAL_25]], align 4
+// CHECK: %[[VAL_31:.*]] = load float, float* %[[VAL_26]], align 4
+// CHECK: %[[VAL_32:.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* %[[VAL_3]], i64 0, i64 %[[VAL_13]], i64 %[[VAL_12]]
+// CHECK: store float %[[VAL_30]], float* %[[VAL_32]], align 4
+// CHECK: %[[VAL_33:.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* %[[VAL_3]], i64 0, i64 %[[VAL_13]], i64 %[[VAL_19]]
+// CHECK: store float %[[VAL_31]], float* %[[VAL_33]], align 4
+// CHECK: br label %[[VAL_24]]
// CHECK: }
// CHECK: define void @sort__2(i8* noalias align 64 dereferenceable(24) %[[VAL_0:.*]]) {
@@ -131,49 +127,47 @@
// CHECK: %[[VAL_1:.*]] = alloca i8, align 1
// CHECK: %[[VAL_2:.*]] = getelementptr inbounds i8, i8* %[[VAL_0]], i64 0
// CHECK: %[[VAL_3:.*]] = bitcast i8* %[[VAL_2]] to [2 x [3 x float]]*
-// CHECK: %[[VAL_4:.*]] = getelementptr inbounds i8, i8* %[[VAL_0]], i64 0
-// CHECK: %[[VAL_5:.*]] = bitcast i8* %[[VAL_4]] to [2 x [3 x float]]*
-// CHECK: %[[VAL_6:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6
+// CHECK: %[[VAL_4:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6
+// CHECK: %[[VAL_5:.*]] = zext i32 %[[VAL_4]] to i64
+// CHECK: %[[VAL_6:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7
// CHECK: %[[VAL_7:.*]] = zext i32 %[[VAL_6]] to i64
-// CHECK: %[[VAL_8:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7
-// CHECK: %[[VAL_9:.*]] = zext i32 %[[VAL_8]] to i64
-// CHECK: %[[VAL_10:.*]] = mul nuw nsw i64 %[[VAL_7]], 4
-// CHECK: %[[VAL_11:.*]] = add nuw nsw i64 %[[VAL_10]], %[[VAL_9]]
-// CHECK: %[[VAL_12:.*]] = icmp ult i64 %[[VAL_11]], 4
-// CHECK: call void @llvm.assume(i1 %[[VAL_12]])
-// CHECK: %[[VAL_13:.*]] = udiv i64 %[[VAL_11]], 1
-// CHECK: %[[VAL_14:.*]] = urem i64 %[[VAL_13]], 2
-// CHECK: %[[VAL_15:.*]] = udiv i64 %[[VAL_11]], 2
-// CHECK: %[[VAL_16:.*]] = icmp ult i64 %[[VAL_11]], 4
-// CHECK: br i1 %[[VAL_16]], label %[[VAL_17:.*]], label %[[VAL_18:.*]]
-// CHECK: sort.in_bounds-after: ; preds = %[[VAL_19:.*]], %[[VAL_20:.*]]
+// CHECK: %[[VAL_8:.*]] = mul nuw nsw i64 %[[VAL_5]], 4
+// CHECK: %[[VAL_9:.*]] = add nuw nsw i64 %[[VAL_8]], %[[VAL_7]]
+// CHECK: %[[VAL_10:.*]] = icmp ult i64 %[[VAL_9]], 4
+// CHECK: call void @llvm.assume(i1 %[[VAL_10]])
+// CHECK: %[[VAL_11:.*]] = udiv i64 %[[VAL_9]], 1
+// CHECK: %[[VAL_12:.*]] = urem i64 %[[VAL_11]], 2
+// CHECK: %[[VAL_13:.*]] = udiv i64 %[[VAL_9]], 2
+// CHECK: %[[VAL_14:.*]] = icmp ult i64 %[[VAL_9]], 4
+// CHECK: br i1 %[[VAL_14]], label %[[VAL_15:.*]], label %[[VAL_16:.*]]
+// CHECK: sort.in_bounds-after: ; preds = %[[VAL_17:.*]], %[[VAL_18:.*]]
// CHECK: ret void
-// CHECK: sort.in_bounds-true: ; preds = %[[VAL_20]]
-// CHECK: %[[VAL_21:.*]] = mul i64 %[[VAL_14]], 2
-// CHECK: %[[VAL_22:.*]] = xor i64 %[[VAL_21]], 1
-// CHECK: %[[VAL_23:.*]] = icmp slt i64 %[[VAL_21]], %[[VAL_22]]
-// CHECK: %[[VAL_24:.*]] = icmp slt i64 %[[VAL_22]], 3
-// CHECK: %[[VAL_25:.*]] = and i1 %[[VAL_23]], %[[VAL_24]]
-// CHECK: br i1 %[[VAL_25]], label %[[VAL_26:.*]], label %[[VAL_19]]
-// CHECK: smaller_comparison_index-after: ; preds = %[[VAL_27:.*]], %[[VAL_17]]
-// CHECK: br label %[[VAL_18]]
-// CHECK: smaller_comparison_index-true: ; preds = %[[VAL_17]]
-// CHECK: %[[VAL_28:.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* %[[VAL_3]], i64 0, i64 %[[VAL_15]], i64 %[[VAL_22]]
-// CHECK: %[[VAL_29:.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* %[[VAL_3]], i64 0, i64 %[[VAL_15]], i64 %[[VAL_21]]
-// CHECK: call void @region_0_4(float* %[[VAL_28]], float* %[[VAL_29]], i8* %[[VAL_1]])
-// CHECK: %[[VAL_30:.*]] = load i8, i8* %[[VAL_1]], align 1
-// CHECK: %[[VAL_31:.*]] = icmp ne i8 %[[VAL_30]], 0
-// CHECK: br i1 %[[VAL_31]], label %[[VAL_32:.*]], label %[[VAL_27]]
-// CHECK: is_smaller_than-after: ; preds = %[[VAL_32]], %[[VAL_26]]
-// CHECK: br label %[[VAL_19]]
-// CHECK: is_smaller_than-true: ; preds = %[[VAL_26]]
-// CHECK: %[[VAL_33:.*]] = load float, float* %[[VAL_28]], align 4
-// CHECK: %[[VAL_34:.*]] = load float, float* %[[VAL_29]], align 4
-// CHECK: %[[VAL_35:.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* %[[VAL_3]], i64 0, i64 %[[VAL_15]], i64 %[[VAL_21]]
-// CHECK: store float %[[VAL_33]], float* %[[VAL_35]], align 4
-// CHECK: %[[VAL_36:.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* %[[VAL_3]], i64 0, i64 %[[VAL_15]], i64 %[[VAL_22]]
-// CHECK: store float %[[VAL_34]], float* %[[VAL_36]], align 4
-// CHECK: br label %[[VAL_27]]
+// CHECK: sort.in_bounds-true: ; preds = %[[VAL_18]]
+// CHECK: %[[VAL_19:.*]] = mul i64 %[[VAL_12]], 2
+// CHECK: %[[VAL_20:.*]] = xor i64 %[[VAL_19]], 1
+// CHECK: %[[VAL_21:.*]] = icmp slt i64 %[[VAL_19]], %[[VAL_20]]
+// CHECK: %[[VAL_22:.*]] = icmp slt i64 %[[VAL_20]], 3
+// CHECK: %[[VAL_23:.*]] = and i1 %[[VAL_21]], %[[VAL_22]]
+// CHECK: br i1 %[[VAL_23]], label %[[VAL_24:.*]], label %[[VAL_17]]
+// CHECK: smaller_comparison_index-after: ; preds = %[[VAL_25:.*]], %[[VAL_15]]
+// CHECK: br label %[[VAL_16]]
+// CHECK: smaller_comparison_index-true: ; preds = %[[VAL_15]]
+// CHECK: %[[VAL_26:.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* %[[VAL_3]], i64 0, i64 %[[VAL_13]], i64 %[[VAL_20]]
+// CHECK: %[[VAL_27:.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* %[[VAL_3]], i64 0, i64 %[[VAL_13]], i64 %[[VAL_19]]
+// CHECK: call void @region_0_4(float* %[[VAL_26]], float* %[[VAL_27]], i8* %[[VAL_1]])
+// CHECK: %[[VAL_28:.*]] = load i8, i8* %[[VAL_1]], align 1
+// CHECK: %[[VAL_29:.*]] = icmp ne i8 %[[VAL_28]], 0
+// CHECK: br i1 %[[VAL_29]], label %[[VAL_30:.*]], label %[[VAL_25]]
+// CHECK: is_smaller_than-after: ; preds = %[[VAL_30]], %[[VAL_24]]
+// CHECK: br label %[[VAL_17]]
+// CHECK: is_smaller_than-true: ; preds = %[[VAL_24]]
+// CHECK: %[[VAL_31:.*]] = load float, float* %[[VAL_26]], align 4
+// CHECK: %[[VAL_32:.*]] = load float, float* %[[VAL_27]], align 4
+// CHECK: %[[VAL_33:.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* %[[VAL_3]], i64 0, i64 %[[VAL_13]], i64 %[[VAL_19]]
+// CHECK: store float %[[VAL_31]], float* %[[VAL_33]], align 4
+// CHECK: %[[VAL_34:.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* %[[VAL_3]], i64 0, i64 %[[VAL_13]], i64 %[[VAL_20]]
+// CHECK: store float %[[VAL_32]], float* %[[VAL_34]], align 4
+// CHECK: br label %[[VAL_25]]
// CHECK: }
ENTRY main {
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index f71cf05..5e59081 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -1128,7 +1128,7 @@
}
if (!instr->metadata().source_file().empty() &&
instr->metadata().source_line() != 0) {
- lines.push_back(StrFormat("op_type: %s:%d", instr->metadata().source_file(),
+ lines.push_back(StrFormat("source: %s:%d", instr->metadata().source_file(),
instr->metadata().source_line()));
}
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index c3951d5..9b7679e 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -328,7 +328,9 @@
instruction = CreateConstant(std::move(literal));
// Literal's shape may have no/different tiling info.
TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
- instruction->shape(), shape));
+ instruction->shape(), shape))
+ << instruction->shape().ToString(true) << " vs "
+ << shape.ToString(true);
*instruction->mutable_shape() = shape;
} else {
instruction = absl::make_unique<HloConstantInstruction>(shape);
@@ -578,6 +580,12 @@
if (proto.has_window()) {
custom_call_instr->set_window(proto.window());
}
+ if (proto.has_literal()) {
+ TF_ASSIGN_OR_RETURN(
+ auto literal,
+ Literal::CreateFromProto(proto.literal(), prohibit_empty_literal));
+ custom_call_instr->set_literal(std::move(literal));
+ }
if (proto.has_convolution_dimension_numbers()) {
custom_call_instr->set_convolution_dimension_numbers(
proto.convolution_dimension_numbers());
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index f2a7fe1..7a77e86 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -1328,19 +1328,7 @@
options.print_large_constants())) {
// Literal::ToString emits multidimensional arrays over multiple
// lines. Compact this into one line by stripping out white space.
- string tmp = literal().ToStringWithoutShape();
- std::replace(tmp.begin(), tmp.end(), '\n', ' ');
- std::vector<string> v = absl::StrSplit(tmp, ' ');
- bool first = true;
- // Concatenate elements in "v" with spaces separating them, but ignoring
- // empty entries.
- for (const auto& s : v) {
- if (s.empty()) {
- continue;
- }
- StrAppend(&operands, (first ? "" : " "), s);
- first = false;
- }
+ operands = literal_->ToStringWithoutShapeOneline();
} else {
// Do not show large constants or tuples.
operands = "{...}";
@@ -2441,6 +2429,9 @@
}
}
proto.set_custom_call_has_side_effect(custom_call_has_side_effect_);
+ if (literal_.has_value()) {
+ *proto.mutable_literal() = literal_->ToProto();
+ }
for (const auto& pair : output_to_operand_aliasing_) {
auto aliasing = proto.add_custom_call_output_operand_aliasing();
aliasing->set_operand_index(pair.second.first);
@@ -2495,6 +2486,9 @@
if (custom_call_has_side_effect_) {
extra.push_back("custom_call_has_side_effect=true");
}
+ if (literal_.has_value()) {
+ extra.push_back(StrCat("literal=(", literal_->ToStringOneline(), ")"));
+ }
if (!output_to_operand_aliasing_.empty()) {
std::vector<string> pair_strings;
for (const auto& pair : output_to_operand_aliasing_) {
@@ -2571,6 +2565,13 @@
return false;
}
}
+ if (HasLiteral() == casted_other.HasLiteral()) {
+ if (HasLiteral() && literal() == casted_other.literal()) {
+ return false;
+ }
+ } else {
+ return true;
+ }
// Note: backend_config comparison is done in Identical, which is the
// intended/exposed way to compare computations, and so not repeated here.
@@ -2593,6 +2594,9 @@
if (convolution_dimension_numbers_ != nullptr) {
cloned->set_convolution_dimension_numbers(*convolution_dimension_numbers_);
}
+ if (HasLiteral()) {
+ cloned->set_literal(literal().Clone());
+ }
cloned->set_feature_group_count(feature_group_count_);
cloned->set_batch_group_count(batch_group_count_);
cloned->set_custom_call_has_side_effect(custom_call_has_side_effect_);
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index bacbce1..4df82e1 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -1466,6 +1466,13 @@
padding_type_ = padding_type;
}
+ // Returns the literal associated with this instruction.
+ const Literal& literal() const { return *literal_; }
+ // Set the value of literal to a new one.
+ void set_literal(Literal&& literal) { literal_.emplace(std::move(literal)); }
+ // Returns whether there is literal associated with this instruction.
+ bool HasLiteral() const { return literal_.has_value(); }
+
const PrecisionConfig& precision_config() const { return precision_config_; }
PrecisionConfig* mutable_precision_config() { return &precision_config_; }
@@ -1532,6 +1539,7 @@
// output_to_operand_aliasing().
std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_to_operand_aliasing_;
+ absl::optional<Literal> literal_;
};
class HloPadInstruction : public HloInstruction {
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 3341864..6f5a877 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -253,6 +253,7 @@
bool ParseInstructionRhs(HloComputation::Builder* builder,
const std::string& name, LocTy name_loc);
bool ParseControlPredecessors(HloInstruction* instruction);
+ bool ParseLiteral(Literal* literal);
bool ParseLiteral(Literal* literal, const Shape& shape);
bool ParseTupleLiteral(Literal* literal, const Shape& shape);
bool ParseNonTupleLiteral(Literal* literal, const Shape& shape);
@@ -307,6 +308,7 @@
kInt32,
kFloat,
kString,
+ kLiteral,
kBracedInt64List,
kBracedInt64ListList,
kHloComputation,
@@ -2268,6 +2270,9 @@
attrs["padding_type"] = {/*required=*/false, AttrTy::kPaddingType,
&padding_type};
+
+ optional<Literal> literal;
+ attrs["literal"] = {/*required=*/false, AttrTy::kLiteral, &literal};
optional<std::vector<PrecisionConfig::Precision>> operand_precision;
attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
&operand_precision};
@@ -2357,6 +2362,9 @@
custom_call_instr->set_output_to_operand_aliasing(
std::move(*output_to_operand_aliasing));
}
+ if (literal.has_value()) {
+ custom_call_instr->set_literal(std::move(*literal));
+ }
PrecisionConfig precision_config;
if (operand_precision) {
*precision_config.mutable_operand_precision() = {
@@ -3048,6 +3056,14 @@
return true;
}
+bool HloParserImpl::ParseLiteral(Literal* literal) {
+ Shape literal_shape;
+ if (!ParseShape(&literal_shape)) {
+ return false;
+ }
+ return ParseLiteral(literal, literal_shape);
+}
+
// literal
// ::= tuple
// ::= non_tuple
@@ -3830,6 +3846,21 @@
->emplace(std::move(aliasing_output_operand_pairs));
return true;
}
+ case AttrTy::kLiteral: {
+ if (!ParseToken(TokKind::kLparen, "expects '(' before literal")) {
+ return false;
+ }
+ Literal result;
+ if (!ParseLiteral(&result)) {
+ return false;
+ }
+ if (!ParseToken(TokKind::kRparen, "expects ')' after literal")) {
+ return false;
+ }
+ static_cast<optional<Literal>*>(attr_out_ptr)
+ ->emplace(std::move(result));
+ return true;
+ }
}
}();
if (!success) {
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 4dac92b..696f809 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -401,6 +401,32 @@
)"
},
+
+// CustomCall with literal.
+{
+"CustomCallWithLiteral",
+R"(HloModule custom_call
+
+ENTRY %CustomCall () -> f32[1,2,3] {
+ %constant = f32[1]{0} constant({12345})
+ ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar", literal=(f32[1] {0.1})
+}
+
+)"
+},
+
+// CustomCall with literal R0.
+{
+"CustomCallWithLiteralR0",
+R"(HloModule custom_call
+
+ENTRY %CustomCall () -> f32[1,2,3] {
+ %constant = f32[1]{0} constant({12345})
+ ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar", literal=(f32[] 0.1)
+}
+
+)"
+},
// reduce window
{
"ReduceWindow",
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
index 25b2df0..5a8e1a9 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
@@ -180,7 +180,7 @@
RecordPassEndMetadata(*hlo, pass_name, pass_changed);
changed |= pass_changed;
if (pass_changed) {
- VLOG(3) << " Pass caused changes" << pass->name();
+ VLOG(3) << " Pass caused changes " << pass->name();
}
TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, pass_name));
if (!pass->IsPassPipeline()) {
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index 1c27f81..ac2f8cc 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -557,6 +557,7 @@
}
if (fusion_instruction == nullptr) {
+ fusion_queue->NotFusingInstruction(operand, instruction);
continue;
}
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h
index d51bf70..f6d9f5c 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.h
+++ b/tensorflow/compiler/xla/service/instruction_fusion.h
@@ -144,10 +144,15 @@
bool ReusesOperandElements(const HloInstruction* consumer,
int64 operand_index);
- private:
// The set of producers whose consumers we cannot fuse into.
using HloInstructionSet = std::unordered_set<HloInstruction*>;
+ // Computes the set of nodes that we do not want to fuse into any of their
+ // consumers based on a global analysis of the HLO graph.
+ virtual HloInstructionSet ComputeGloballyUnfusible(
+ absl::Span<HloInstruction* const> post_order);
+
+ private:
HloInstruction* AddFusionInstruction(HloInstruction* producer,
HloInstruction* consumer);
@@ -163,11 +168,6 @@
absl::flat_hash_map<std::pair<HloInstruction*, HloInstruction*>, bool>*
result_cache);
- // Computes the set of nodes that we do not want to fuse into any of their
- // consumers based on a global analysis of the HLO graph.
- HloInstructionSet ComputeGloballyUnfusible(
- absl::Span<HloInstruction* const> post_order);
-
// Used to determine if an HLO is expensive. Expensive operations will not be
// duplicated.
std::function<bool(const HloInstruction& instruction)> is_expensive_;
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 4882c5d..72561cb 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -1949,9 +1949,19 @@
computation->root_instruction()));
computation->set_root_instruction(new_root);
} else {
- // Use the specified shape including tiling info in layout.
- *(computation->root_instruction()->mutable_shape()) =
- constraints.ResultLayout()->shape();
+ // Copy the specified tiling info.
+ auto assign_tiling = [&constraints](xla::Shape* subshape,
+ const xla::ShapeIndex& index) {
+ if (subshape->IsArray()) {
+ const Shape& result_shape = ShapeUtil::GetSubshape(
+ constraints.ResultLayout()->shape(), index);
+ subshape->mutable_layout()->mutable_tiles()->assign(
+ result_shape.layout().tiles().begin(),
+ result_shape.layout().tiles().end());
+ }
+ };
+ xla::ShapeUtil::ForEachMutableSubshape(
+ computation->root_instruction()->mutable_shape(), assign_tiling);
}
}
return Status::OK();
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index 304a80c..987ed90 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -1422,5 +1422,29 @@
ExpectLayoutIs(alltoall->operand(1)->shape(), {1, 0});
}
+TEST_F(LayoutAssignmentTest, DynamicRoot) {
+ const char* module_str = R"(
+HloModule test_module
+
+ENTRY entry_computation {
+ param = f32[1,<=16]{0,1} parameter(0)
+ ROOT abs = f32[1,<=16]{0,1} abs(param)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
+ ParseAndReturnVerifiedModule(module_str));
+ ComputationLayout computation_layout(
+ m->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false);
+ computation_layout.mutable_result_layout()->ClearDynamicShape();
+
+ AssignLayouts(m.get(), &computation_layout);
+
+ const HloInstruction* abs = FindInstruction(m.get(), "abs");
+ ExpectLayoutIs(abs->operand(0)->shape(), {0, 1});
+ ExpectLayoutIs(abs->shape(), {0, 1});
+ EXPECT_TRUE(abs->shape().is_dynamic_dimension(1));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
index 0a26a2b..135fd4e 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
@@ -102,7 +102,7 @@
global,
llvm_ir::ShapeToIrType(literal.shape(), module_)->getPointerTo());
return IrArray(shape_constant, constant->shape())
- .EmitReadArrayElement(index, b_);
+ .EmitReadArrayElement(index, b_, constant->name());
};
return Status::OK();
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
index a21e7fa..beb06c3 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
@@ -504,7 +504,7 @@
bool use_linear_index) const {
llvm::Value* element_address =
EmitArrayElementAddress(index, b, name, use_linear_index);
- llvm::LoadInst* load = b->CreateLoad(element_address);
+ llvm::LoadInst* load = b->CreateLoad(element_address, name.data());
AnnotateLoadStoreInstructionWithMetadata(load);
return load;
}
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
index a00156a..9632bc6 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
@@ -83,36 +83,39 @@
llvm::CallInst* EmitCallToIntrinsic(
llvm::Intrinsic::ID intrinsic_id, absl::Span<llvm::Value* const> operands,
- absl::Span<llvm::Type* const> overloaded_types, llvm::IRBuilder<>* b) {
+ absl::Span<llvm::Type* const> overloaded_types, llvm::IRBuilder<>* b,
+ absl::string_view name) {
llvm::Module* module = ModuleFromIRBuilder(b);
llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration(
module, intrinsic_id, AsArrayRef(overloaded_types));
- return b->CreateCall(intrinsic, AsArrayRef(operands));
+ return b->CreateCall(intrinsic, AsArrayRef(operands), name.data());
}
llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
- llvm::IRBuilder<>* b, bool enable_fast_min_max) {
+ llvm::IRBuilder<>* b, bool enable_fast_min_max,
+ absl::string_view name) {
if (b->getFastMathFlags().noNaNs() || enable_fast_min_max) {
auto cmp = b->CreateFCmpUGE(lhs_value, rhs_value);
- return b->CreateSelect(cmp, lhs_value, rhs_value);
+ return b->CreateSelect(cmp, lhs_value, rhs_value, name.data());
} else {
auto cmp_ge = b->CreateFCmpOGE(lhs_value, rhs_value);
auto lhs_is_nan = b->CreateFCmpUNE(lhs_value, lhs_value);
auto sel_lhs = b->CreateOr(cmp_ge, lhs_is_nan);
- return b->CreateSelect(sel_lhs, lhs_value, rhs_value);
+ return b->CreateSelect(sel_lhs, lhs_value, rhs_value, name.data());
}
}
llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
- llvm::IRBuilder<>* b, bool enable_fast_min_max) {
+ llvm::IRBuilder<>* b, bool enable_fast_min_max,
+ absl::string_view name) {
if (b->getFastMathFlags().noNaNs() || enable_fast_min_max) {
auto cmp = b->CreateFCmpULE(lhs_value, rhs_value);
- return b->CreateSelect(cmp, lhs_value, rhs_value);
+ return b->CreateSelect(cmp, lhs_value, rhs_value, name.data());
} else {
auto cmp_le = b->CreateFCmpOLE(lhs_value, rhs_value);
auto lhs_is_nan = b->CreateFCmpUNE(lhs_value, lhs_value);
auto sel_lhs = b->CreateOr(cmp_le, lhs_is_nan);
- return b->CreateSelect(sel_lhs, lhs_value, rhs_value);
+ return b->CreateSelect(sel_lhs, lhs_value, rhs_value, name.data());
}
}
@@ -351,12 +354,14 @@
llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate,
llvm::Value* lhs_value, llvm::Value* rhs_value,
- llvm::IRBuilder<>* b) {
+ llvm::IRBuilder<>* b, absl::string_view name) {
llvm::Value* comparison_result;
if (lhs_value->getType()->isIntegerTy()) {
- comparison_result = b->CreateICmp(predicate, lhs_value, rhs_value);
+ comparison_result =
+ b->CreateICmp(predicate, lhs_value, rhs_value, name.data());
} else {
- comparison_result = b->CreateFCmp(predicate, lhs_value, rhs_value);
+ comparison_result =
+ b->CreateFCmp(predicate, lhs_value, rhs_value, name.data());
}
// comparison_result is i1, but the NVPTX codegen incorrectly lowers i1
// arrays. So we extend it to i8 so that it's addressable.
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
index 3a3b4b7..1171d29 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
@@ -103,17 +103,20 @@
// overloaded type.
llvm::CallInst* EmitCallToIntrinsic(
llvm::Intrinsic::ID intrinsic_id, absl::Span<llvm::Value* const> operands,
- absl::Span<llvm::Type* const> overloaded_types, llvm::IRBuilder<>* b);
+ absl::Span<llvm::Type* const> overloaded_types, llvm::IRBuilder<>* b,
+ absl::string_view name = "");
// Emit float max. Emit maxnum intrinsic is fast math is disabled, or
// fcmp+select otherwise
llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
- llvm::IRBuilder<>* b, bool enable_fast_min_max);
+ llvm::IRBuilder<>* b, bool enable_fast_min_max,
+ absl::string_view name = "");
// Emit float min. Emit minnum intrinsic is fast math is disabled, or
// fcmp+select otherwise
llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
- llvm::IRBuilder<>* b, bool enable_fast_min_max);
+ llvm::IRBuilder<>* b, bool enable_fast_min_max,
+ absl::string_view name = "");
// Convenience methods for emitting a GEP instruction that indexes into a buffer
// (1-dimensional array), equivalent to array[index]. The type is automatically
@@ -214,7 +217,7 @@
// and then converts the result to i8 so that it is addressable.
llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate,
llvm::Value* lhs, llvm::Value* rhs,
- llvm::IRBuilder<>* b);
+ llvm::IRBuilder<>* b, absl::string_view name = "");
// Emits a call that logs the given value with the given tag as a prefix.
// The provided tag and value are passed to a runtime logging call that is
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc
index 10c3d60..a065065 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc
@@ -254,12 +254,6 @@
cost_analysis_.per_second_rate(HloCostAnalysis::kTranscendentalsKey));
}
-float MemorySpaceAssignmentCostAnalysis::
- GetInstructionElapsedDueToMemorySlowdown(int64 bytes) const {
- return bytes /
- cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey);
-}
-
float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToMemory(
const HloInstruction& instruction,
absl::optional<int64> operand_in_alternate_mem,
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h
index 7bffcc2..f6104d6 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.h
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.h
@@ -123,12 +123,6 @@
absl::optional<int64> operand_in_alternate_mem = absl::nullopt,
bool output_in_alternate_mem = false) const;
- // Returns the elapsed time in seconds that other BufferIntervals are slowed
- // down, due to the prefetching of current bytes. Assuming other
- // BufferIntervals needs default memory bandwidth, and only current
- // BufferInterval is prefetched.
- float GetInstructionElapsedDueToMemorySlowdown(int64 bytes) const;
-
// Returns the estimated elapsed duration of the instruction in seconds. It
// assumes all operands and outputs of the instruction are in the default
// memory.
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 73da11b..ddc6090 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -2689,11 +2689,19 @@
auto result_shape = operand_shape;
- // If any of the operand shape and update shape is dynamic, update the result
- // dimension to dynamic.
+ // If any of the operand shape is dynamic, the result dimension is also
+ // dynamic.
+ // If update shape is dynamic, only propagate dynamic dimension to result if
+ // the update is a full update (update_shape[i] == operand_shape[i]).
for (int64 i = 0; i < update_shape.rank(); ++i) {
- if (update_shape.is_dynamic_dimension(i) ||
- operand_shape.is_dynamic_dimension(i)) {
+ if (operand_shape.is_dynamic_dimension(i)) {
+ result_shape.set_dynamic_dimension(i, true);
+ }
+
+ if (update_shape.is_dynamic_dimension(i) &&
+ update_shape.dimensions(i) == operand_shape.dimensions(i)) {
+ // When update/replace a full dimension, propagate dynamic dimension to
+ // the result.
result_shape.set_dynamic_dimension(i, true);
}
}
diff --git a/tensorflow/compiler/xla/service/spmd/BUILD b/tensorflow/compiler/xla/service/spmd/BUILD
index 4c7dddc..b126904 100644
--- a/tensorflow/compiler/xla/service/spmd/BUILD
+++ b/tensorflow/compiler/xla/service/spmd/BUILD
@@ -73,6 +73,8 @@
deps = [
":spmd_partitioner",
"//tensorflow/compiler/xla:xla_data_proto_cc",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_casting_utils",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
diff --git a/tensorflow/compiler/xla/service/spmd/dot_handler.cc b/tensorflow/compiler/xla/service/spmd/dot_handler.cc
index 85ebc31..49ae866 100644
--- a/tensorflow/compiler/xla/service/spmd/dot_handler.cc
+++ b/tensorflow/compiler/xla/service/spmd/dot_handler.cc
@@ -92,37 +92,259 @@
namespace {
-StatusOr<HloInstruction*> PartitionBaseCase(
- PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
- const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
- int64 num_partitions,
- const std::function<StatusOr<HloInstruction*>(
- HloInstruction*, HloInstruction*, SpmdBuilder*,
- const Window& conv_window)>& create_sharded_dot,
- const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
- int64 lhs_batch_partitions, int64 rhs_batch_partitions,
- int64 output_batch_partitions, int64 lhs_contracting_partitions,
- int64 rhs_contracting_partitions, int64 lhs_non_contracting_partitions,
- int64 rhs_non_contracting_partitions,
- int64 output_lhs_non_contracting_partitions,
- int64 output_rhs_non_contracting_partitions,
- const SpmdPartitionerOptions& options, SpmdBuilder* b,
- std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
- windowed_dot_general_loops,
- bool may_reshard_without_detecting_match) {
- const HloSharding& lhs_sharding = lhs.sharding();
- const HloSharding& rhs_sharding = rhs.sharding();
- if (lhs_sharding.ReplicateOnLastTileDim() ||
- rhs_sharding.ReplicateOnLastTileDim() ||
- output_sharding.ReplicateOnLastTileDim()) {
- return nullptr;
+enum class WindowedEinsumOperand { LHS, RHS };
+
+struct WindowedEinsumConfig {
+ WindowedEinsumOperand windowed_op;
+ bool windowed_at_contracting_dims;
+ bool windowed_at_batch_dims;
+ bool operands_sharded_at_contracting_dims;
+};
+
+struct DotDimensionIndexMapping {
+ std::vector<int64> lhs_to_rhs_indices;
+ std::vector<int64> lhs_to_output_indices;
+ std::vector<int64> rhs_to_lhs_indices;
+ std::vector<int64> rhs_to_output_indices;
+ std::vector<int64> output_to_lhs_indices;
+ std::vector<int64> output_to_rhs_indices;
+};
+
+void UpdateDDNums(DotDimensionNumbers* new_ddnums, int64 reshaped_dim,
+ bool lhs) {
+ auto update_dims =
+ [&reshaped_dim](tensorflow::protobuf::RepeatedField<int64>* dims) {
+ for (int64 i = 0; i < dims->size(); ++i) {
+ auto dim = dims->at(i);
+ if (reshaped_dim <= dim) {
+ dims->Set(i, dim + 1);
+ }
+ }
+ if (absl::c_linear_search(*dims, reshaped_dim)) {
+ dims->Add(reshaped_dim);
+ }
+ };
+
+ if (lhs) {
+ update_dims(new_ddnums->mutable_lhs_contracting_dimensions());
+ update_dims(new_ddnums->mutable_lhs_batch_dimensions());
+ } else { // rhs
+ update_dims(new_ddnums->mutable_rhs_contracting_dimensions());
+ update_dims(new_ddnums->mutable_rhs_batch_dimensions());
}
- std::vector<int64> lhs_to_rhs_indices(lhs.base_shape().rank(), -1);
- std::vector<int64> lhs_to_output_indices(lhs.base_shape().rank(), -1);
- std::vector<int64> rhs_to_lhs_indices(rhs.base_shape().rank(), -1);
- std::vector<int64> rhs_to_output_indices(rhs.base_shape().rank(), -1);
- std::vector<int64> output_to_lhs_indices(output_base_shape.rank(), -1);
- std::vector<int64> output_to_rhs_indices(output_base_shape.rank(), -1);
+}
+
+Window GenNewWindow(const HloInstruction* original_dot,
+ const HloInstruction* dot_lhs,
+ const HloInstruction* dot_rhs, int64 lhs_concat_dim,
+ int64 rhs_concat_dim, bool windowed_at_contracting_dims,
+ bool windowed_at_batch_dims) {
+ auto new_window = original_dot->window();
+ const ConvolutionDimensionNumbers& conv_dnums =
+ original_dot->convolution_dimension_numbers();
+ if (lhs_concat_dim != -1) {
+ for (int64 i = 0; i < conv_dnums.input_spatial_dimensions_size(); ++i) {
+ if (conv_dnums.input_spatial_dimensions(i) == lhs_concat_dim) {
+ auto wd = new_window.mutable_dimensions(i);
+ auto lhs_size = dot_lhs->shape().dimensions(lhs_concat_dim + 1);
+ if (windowed_at_contracting_dims) {
+ wd->set_size(lhs_size);
+ }
+ if (windowed_at_batch_dims) {
+ wd->set_size(lhs_size);
+ wd->set_padding_low(0);
+ wd->set_padding_high(0);
+ wd->set_stride(std::max<int64>(1, lhs_size - 1));
+ wd->set_window_dilation(1);
+ wd->set_base_dilation(lhs_size);
+ wd->set_window_reversal(false);
+ }
+ }
+ }
+ }
+ if (rhs_concat_dim != -1) {
+ for (int64 i = 0; i < conv_dnums.kernel_spatial_dimensions_size(); ++i) {
+ if (conv_dnums.kernel_spatial_dimensions(i) == rhs_concat_dim &&
+ !windowed_at_contracting_dims && !windowed_at_batch_dims &&
+ lhs_concat_dim == -1) {
+ auto wd = new_window.mutable_dimensions(i);
+ auto rhs_size = dot_rhs->shape().dimensions(rhs_concat_dim + 1);
+ wd->set_size(rhs_size);
+ wd->set_padding_low(rhs_size - 1);
+ wd->set_padding_high(rhs_size - 1);
+ }
+ }
+ }
+ // Add the extra dimension to window.
+ WindowDimension* new_dim = new_window.add_dimensions();
+ if (windowed_at_contracting_dims) {
+ new_dim->set_size(2);
+ new_dim->set_padding_low(0);
+ new_dim->set_padding_high(0);
+ new_dim->set_stride(1);
+ new_dim->set_window_dilation(1);
+ new_dim->set_base_dilation(1);
+ new_dim->set_window_reversal(false);
+ } else if (windowed_at_batch_dims) {
+ new_dim->set_size(2);
+ new_dim->set_padding_low(0);
+ new_dim->set_padding_high(0);
+ new_dim->set_stride(1); // std::max<int64>(1, 2 - 1)
+ new_dim->set_window_dilation(1);
+ new_dim->set_base_dilation(2);
+ new_dim->set_window_reversal(false);
+ } else {
+ if (lhs_concat_dim != -1) {
+ new_dim->set_size(1);
+ new_dim->set_padding_low(0);
+ new_dim->set_padding_high(0);
+ new_dim->set_stride(1);
+ new_dim->set_window_dilation(1);
+ new_dim->set_base_dilation(1);
+ new_dim->set_window_reversal(false);
+ }
+ if (rhs_concat_dim != -1) {
+ new_dim->set_size(2); // rhs_size
+ new_dim->set_padding_low(1); // rhs_size - 1
+ new_dim->set_padding_high(1); // rhs_size - 1
+ new_dim->set_stride(1);
+ new_dim->set_window_dilation(1);
+ new_dim->set_base_dilation(1);
+ new_dim->set_window_reversal(true);
+ }
+ }
+
+ VLOG(2) << "new_window: " << new_window.ShortDebugString();
+ return new_window;
+}
+
+ConvolutionDimensionNumbers GenNewConvDNums(
+ const HloInstruction* original_dot, const HloInstruction* dot_lhs,
+ const HloInstruction* dot_rhs, int64 lhs_concat_dim, int64 rhs_concat_dim,
+ bool windowed_at_contracting_dims, bool windowed_at_batch_dims,
+ const std::vector<int64>& lhs_to_output_indices,
+ const std::vector<int64>& rhs_to_output_indices,
+ const Shape& new_dot_shape) {
+ // Generate the new conv dimension numbers.
+ const ConvolutionDimensionNumbers& dnums =
+ original_dot->convolution_dimension_numbers();
+ // Handle the LHS dimension numbers.
+ int64 input_batch_dimension = dnums.input_batch_dimension();
+ int64 input_feature_dimension = dnums.input_feature_dimension();
+ std::vector<int64> input_spatial_dimensions(
+ dnums.input_spatial_dimensions().begin(),
+ dnums.input_spatial_dimensions().end());
+ if (lhs_concat_dim != -1) {
+ if (lhs_concat_dim <= input_batch_dimension) {
+ input_batch_dimension++;
+ }
+ if (lhs_concat_dim <= input_feature_dimension) {
+ input_feature_dimension++;
+ }
+ for (int64 i = 0; i < input_spatial_dimensions.size(); ++i) {
+ if (lhs_concat_dim <= input_spatial_dimensions[i]) {
+ input_spatial_dimensions[i]++;
+ }
+ }
+ input_spatial_dimensions.push_back(lhs_concat_dim);
+ }
+ if (rhs_concat_dim != -1 && !windowed_at_contracting_dims &&
+ !windowed_at_batch_dims) {
+ input_spatial_dimensions.push_back(dot_lhs->shape().dimensions_size() - 1);
+ }
+ // Handle the RHS dimension numbers.
+ int64 kernel_input_feature_dimension = dnums.kernel_input_feature_dimension();
+ int64 kernel_output_feature_dimension =
+ dnums.kernel_output_feature_dimension();
+ std::vector<int64> kernel_spatial_dimensions(
+ dnums.kernel_spatial_dimensions().begin(),
+ dnums.kernel_spatial_dimensions().end());
+ if (rhs_concat_dim != -1) {
+ if (rhs_concat_dim <= kernel_input_feature_dimension) {
+ kernel_input_feature_dimension++;
+ }
+ if (rhs_concat_dim <= kernel_output_feature_dimension) {
+ kernel_output_feature_dimension++;
+ }
+ for (int64 i = 0; i < kernel_spatial_dimensions.size(); ++i) {
+ if (rhs_concat_dim <= kernel_spatial_dimensions[i]) {
+ kernel_spatial_dimensions[i]++;
+ }
+ }
+ kernel_spatial_dimensions.push_back(rhs_concat_dim);
+ }
+ if (lhs_concat_dim != -1 && !windowed_at_contracting_dims &&
+ !windowed_at_batch_dims) {
+ kernel_spatial_dimensions.push_back(dot_rhs->shape().dimensions_size() - 1);
+ }
+ // Handle the Output dimension numbers.
+ int64 output_batch_dimension = dnums.output_batch_dimension();
+ int64 output_feature_dimension = dnums.output_feature_dimension();
+ std::vector<int64> output_spatial_dimensions(
+ dnums.output_spatial_dimensions().begin(),
+ dnums.output_spatial_dimensions().end());
+ if (!windowed_at_contracting_dims) {
+ auto output_slice_dim = lhs_concat_dim != -1
+ ? lhs_to_output_indices[lhs_concat_dim]
+ : rhs_to_output_indices[rhs_concat_dim];
+ if (output_slice_dim <= output_batch_dimension) {
+ output_batch_dimension++;
+ }
+ if (output_slice_dim <= output_feature_dimension) {
+ output_feature_dimension++;
+ }
+ for (int64 i = 0; i < output_spatial_dimensions.size(); ++i) {
+ if (output_slice_dim <= output_spatial_dimensions[i]) {
+ output_spatial_dimensions[i]++;
+ }
+ }
+ output_spatial_dimensions.push_back(output_slice_dim);
+ } else {
+ output_spatial_dimensions.push_back(new_dot_shape.dimensions_size() - 1);
+ }
+ // Construct the new dot dimension numbers.
+ ConvolutionDimensionNumbers new_dnums;
+ new_dnums.set_input_batch_dimension(input_batch_dimension);
+ new_dnums.set_input_feature_dimension(input_feature_dimension);
+ for (auto dim : input_spatial_dimensions) {
+ new_dnums.add_input_spatial_dimensions(dim);
+ }
+ new_dnums.set_kernel_input_feature_dimension(kernel_input_feature_dimension);
+ new_dnums.set_kernel_output_feature_dimension(
+ kernel_output_feature_dimension);
+ for (auto dim : kernel_spatial_dimensions) {
+ new_dnums.add_kernel_spatial_dimensions(dim);
+ }
+ new_dnums.set_output_batch_dimension(output_batch_dimension);
+ new_dnums.set_output_feature_dimension(output_feature_dimension);
+ for (auto dim : output_spatial_dimensions) {
+ new_dnums.add_output_spatial_dimensions(dim);
+ }
+
+ return new_dnums;
+}
+
+int64 FirstShardingDimWithPartitionOfSize(int64 num_partitions,
+ const HloSharding& sharding) {
+ int64 sharding_dim = -1;
+ for (int64 i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) {
+ if (sharding.tile_assignment().dim(i) == num_partitions) {
+ sharding_dim = i;
+ break;
+ }
+ }
+ return sharding_dim;
+}
+
+DotDimensionIndexMapping ComputeDimensionIndexMapping(
+ const DotConvDimsMapping& dims_mapping, int64 lhs_rank, int64 rhs_rank,
+ int64 output_rank) {
+ std::vector<int64> lhs_to_rhs_indices(lhs_rank, -1);
+ std::vector<int64> lhs_to_output_indices(lhs_rank, -1);
+ std::vector<int64> rhs_to_lhs_indices(rhs_rank, -1);
+ std::vector<int64> rhs_to_output_indices(rhs_rank, -1);
+ std::vector<int64> output_to_lhs_indices(output_rank, -1);
+ std::vector<int64> output_to_rhs_indices(output_rank, -1);
auto populate_indices_mapping =
[&](const DotConvDimsMapping::DimsMapping& mapping) {
if (mapping.lhs >= 0) {
@@ -153,24 +375,146 @@
for (const auto& mapping : dims_mapping.conv_spatial_dims) {
populate_indices_mapping(mapping);
}
+ return DotDimensionIndexMapping{lhs_to_rhs_indices, lhs_to_output_indices,
+ rhs_to_lhs_indices, rhs_to_output_indices,
+ output_to_lhs_indices, output_to_rhs_indices};
+}
+
+absl::optional<WindowedEinsumConfig> GetWindowedEinsumConfiguration(
+ int64 num_partitions, int64 output_lhs_non_contracting_partitions,
+ int64 output_rhs_non_contracting_partitions,
+ int64 rhs_contracting_partitions, int64 rhs_non_contracting_partitions,
+ int64 rhs_batch_partitions, int64 lhs_contracting_partitions,
+ int64 lhs_non_contracting_partitions, int64 lhs_batch_partitions,
+ int64 output_sharding_dim, int64 rhs_shape_size, int64 lhs_shape_size,
+ int64 output_shape_size, int64 einsum_threshold_mib,
+ const absl::optional<HloSharding>& output_sharding_transposed_to_match_lhs,
+ const absl::optional<HloSharding>& output_sharding_transposed_to_match_rhs,
+ const HloSharding& lhs_sharding, const HloSharding& rhs_sharding) {
+ if (output_lhs_non_contracting_partitions == num_partitions &&
+ output_sharding_transposed_to_match_lhs == lhs_sharding &&
+ rhs_shape_size >= einsum_threshold_mib * 1024 * 1024) {
+ if (rhs_contracting_partitions == num_partitions) {
+ return WindowedEinsumConfig{
+ /*windowed_op=*/WindowedEinsumOperand::RHS,
+ /*windowed_at_contracting_dims*/ true,
+ /*windowed_at_batch_dims=*/false,
+ /*operands_sharded_at_contracting_dims=*/false};
+ }
+ if (rhs_non_contracting_partitions == num_partitions) {
+ return WindowedEinsumConfig{
+ /*windowed_op=*/WindowedEinsumOperand::RHS,
+ /*windowed_at_contracting_dims*/ false,
+ /*windowed_at_batch_dims=*/false,
+ /*operands_sharded_at_contracting_dims=*/false};
+ }
+ if (rhs_batch_partitions == num_partitions) {
+ return WindowedEinsumConfig{
+ /*windowed_op=*/WindowedEinsumOperand::RHS,
+ /*windowed_at_contracting_dims*/ false,
+ /*windowed_at_batch_dims=*/true,
+ /*operands_sharded_at_contracting_dims=*/false};
+ }
+ }
+ if (output_rhs_non_contracting_partitions == num_partitions &&
+ output_sharding_transposed_to_match_rhs == rhs_sharding &&
+ lhs_shape_size >= einsum_threshold_mib * 1024 * 1024) {
+ if (lhs_contracting_partitions == num_partitions) {
+ return WindowedEinsumConfig{
+ /*windowed_op=*/WindowedEinsumOperand::LHS,
+ /*windowed_at_contracting_dims*/ true,
+ /*windowed_at_batch_dims=*/false,
+ /*operands_sharded_at_contracting_dims=*/false};
+ }
+ if (lhs_non_contracting_partitions == num_partitions) {
+ return WindowedEinsumConfig{
+ /*windowed_op=*/WindowedEinsumOperand::LHS,
+ /*windowed_at_contracting_dims*/ false,
+ /*windowed_at_batch_dims=*/false,
+ /*operands_sharded_at_contracting_dims=*/false};
+ }
+ if (lhs_batch_partitions == num_partitions) {
+ return WindowedEinsumConfig{
+ /*windowed_op=*/WindowedEinsumOperand::LHS,
+ /*windowed_at_contracting_dims*/ false,
+ /*windowed_at_batch_dims=*/true,
+ /*operands_sharded_at_contracting_dims=*/false};
+ }
+ }
+ if (lhs_contracting_partitions == rhs_contracting_partitions &&
+ lhs_contracting_partitions == num_partitions &&
+ output_sharding_dim > -1 &&
+ output_shape_size >= einsum_threshold_mib * 1024 * 1024) {
+ if (output_lhs_non_contracting_partitions == num_partitions) {
+ return WindowedEinsumConfig{
+ /*windowed_op=*/WindowedEinsumOperand::RHS,
+ /*windowed_at_contracting_dims*/ false,
+ /*windowed_at_batch_dims=*/false,
+ /*operands_sharded_at_contracting_dims=*/true};
+ }
+ if (output_rhs_non_contracting_partitions == num_partitions) {
+ return WindowedEinsumConfig{
+ /*windowed_op=*/WindowedEinsumOperand::LHS,
+ /*windowed_at_contracting_dims*/ false,
+ /*windowed_at_batch_dims=*/false,
+ /*operands_sharded_at_contracting_dims=*/true};
+ }
+ }
+ return absl::nullopt;
+}
+
+StatusOr<HloInstruction*> PartitionBaseCase(
+ PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
+ const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
+ int64 num_partitions,
+ const std::function<StatusOr<HloInstruction*>(
+ HloInstruction*, HloInstruction*, SpmdBuilder*,
+ const Window& conv_window)>& create_sharded_dot,
+ const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
+ int64 lhs_batch_partitions, int64 rhs_batch_partitions,
+ int64 output_batch_partitions, int64 lhs_contracting_partitions,
+ int64 rhs_contracting_partitions, int64 lhs_non_contracting_partitions,
+ int64 rhs_non_contracting_partitions,
+ int64 output_lhs_non_contracting_partitions,
+ int64 output_rhs_non_contracting_partitions,
+ const SpmdPartitionerOptions& options, SpmdBuilder* b,
+ std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
+ windowed_dot_general_loops,
+ bool may_reshard_without_detecting_match) {
+ const HloSharding& lhs_sharding = lhs.sharding();
+ const HloSharding& rhs_sharding = rhs.sharding();
+ if (lhs_sharding.ReplicateOnLastTileDim() ||
+ rhs_sharding.ReplicateOnLastTileDim() ||
+ output_sharding.ReplicateOnLastTileDim()) {
+ return nullptr;
+ }
+ DotDimensionIndexMapping indices_map = ComputeDimensionIndexMapping(
+ dims_mapping, lhs.base_shape().rank(), rhs.base_shape().rank(),
+ output_base_shape.rank());
auto lhs_sharding_transposed_to_match_rhs =
hlo_sharding_util::TransposeShardingWithCollapsedDims(
- lhs_sharding, lhs_to_rhs_indices, rhs_to_lhs_indices);
+ lhs_sharding, indices_map.lhs_to_rhs_indices,
+ indices_map.rhs_to_lhs_indices);
auto rhs_sharding_transposed_to_match_lhs =
hlo_sharding_util::TransposeShardingWithCollapsedDims(
- rhs_sharding, rhs_to_lhs_indices, lhs_to_rhs_indices);
+ rhs_sharding, indices_map.rhs_to_lhs_indices,
+ indices_map.lhs_to_rhs_indices);
auto lhs_sharding_transposed_to_match_output =
hlo_sharding_util::TransposeShardingWithCollapsedDims(
- lhs_sharding, lhs_to_output_indices, output_to_lhs_indices);
+ lhs_sharding, indices_map.lhs_to_output_indices,
+ indices_map.output_to_lhs_indices);
auto rhs_sharding_transposed_to_match_output =
hlo_sharding_util::TransposeShardingWithCollapsedDims(
- rhs_sharding, rhs_to_output_indices, output_to_rhs_indices);
+ rhs_sharding, indices_map.rhs_to_output_indices,
+ indices_map.output_to_rhs_indices);
auto output_sharding_transposed_to_match_lhs =
hlo_sharding_util::TransposeShardingWithCollapsedDims(
- output_sharding, output_to_lhs_indices, lhs_to_output_indices);
+ output_sharding, indices_map.output_to_lhs_indices,
+ indices_map.lhs_to_output_indices);
auto output_sharding_transposed_to_match_rhs =
hlo_sharding_util::TransposeShardingWithCollapsedDims(
- output_sharding, output_to_rhs_indices, rhs_to_output_indices);
+ output_sharding, indices_map.output_to_rhs_indices,
+ indices_map.rhs_to_output_indices);
// LHS and RHS are partitioned the same way and only partitioned in batch
// dimensions.
@@ -238,29 +582,28 @@
}
}
- int64 output_sharding_dim = -1;
- for (int64 i = 0; i < output_sharding.tile_assignment().num_dimensions();
- ++i) {
- if (output_sharding.tile_assignment().dim(i) == num_partitions) {
- output_sharding_dim = i;
- break;
- }
- }
+ const int64 output_sharding_dim =
+ FirstShardingDimWithPartitionOfSize(num_partitions, output_sharding);
// Try to emit windowed DotGeneral when one operand is partitioned in the same
// way as the output along non-contracting dimensions, but the other operand
// is tiled in other dimensions. Or both operands are partitioned in the same
// way along contracting dimensions, but the output is partitioned along
// non-contracting dimensions.
auto emit_windowed_dot_general =
- [&](int64 matching_operand, int64 windowing_operand,
- bool windowed_at_contracting_dims, bool windowed_at_batch_dims,
- bool operands_sharded_at_contracting_dims)
+ [&](const WindowedEinsumConfig& einsum_config)
-> StatusOr<HloInstruction*> {
- CHECK_EQ(matching_operand + windowing_operand, 1);
- CHECK(!windowed_at_batch_dims || !windowed_at_contracting_dims);
+ CHECK(!einsum_config.windowed_at_batch_dims ||
+ !einsum_config.windowed_at_contracting_dims);
+ const bool windowed_at_batch_dims = einsum_config.windowed_at_batch_dims;
+ const bool windowed_at_contracting_dims =
+ einsum_config.windowed_at_contracting_dims;
+ const bool operands_sharded_at_contracting_dims =
+ einsum_config.operands_sharded_at_contracting_dims;
auto unpadded_result_buffer_shape =
MakePartitionedShape(output_base_shape, output_sharding);
auto padded_result_buffer_shape = unpadded_result_buffer_shape;
+ const bool windowed_op_is_lhs =
+ einsum_config.windowed_op == WindowedEinsumOperand::LHS;
// For windowing at batch/non-contracting dims, we produce the result one
// partition at a time, so we need to pad the shape in case of uneven
// partitioning in order to make dynamic-update-slice in-bound.
@@ -268,13 +611,13 @@
!operands_sharded_at_contracting_dims) {
padded_result_buffer_shape = GetPaddedShapeForUnevenPartitioning(
padded_result_buffer_shape,
- windowing_operand == 0 ? *lhs_sharding_transposed_to_match_output
- : *rhs_sharding_transposed_to_match_output);
+ windowed_op_is_lhs ? *lhs_sharding_transposed_to_match_output
+ : *rhs_sharding_transposed_to_match_output);
}
// Mask the padding area of the windowed operand with zero if there is
// uneven partitioning.
if (windowed_at_contracting_dims) {
- auto& to_mask = windowing_operand == 0 ? lhs : rhs;
+ auto& to_mask = windowed_op_is_lhs ? lhs : rhs;
to_mask =
to_mask.PadWithValue(b->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(output_base_shape.element_type()))));
@@ -290,8 +633,8 @@
(!(options.bidirectional_windowed_einsum && num_partitions % 4 == 0) ||
operands_sharded_at_contracting_dims)
? CreateZero(padded_result_buffer_shape, b)
- : windowing_operand == 0 ? lhs.hlo()
- : rhs.hlo();
+ : windowed_op_is_lhs ? lhs.hlo()
+ : rhs.hlo();
if (options.bidirectional_windowed_einsum && num_partitions % 4 == 0 &&
!operands_sharded_at_contracting_dims) {
@@ -356,72 +699,27 @@
cw_data_partition_id = body_b.AddInstruction(
HloInstruction::CreateBinary(i->shape(), HloOpcode::kRemainder,
cw_data_partition_id, partition_count));
- auto ccw_dot_lhs = l;
- auto ccw_dot_rhs = r;
- auto cw_dot_lhs = windowing_operand == 0 ? extra_inout : l;
- auto cw_dot_rhs = windowing_operand == 0 ? r : extra_inout;
- if (windowed_at_contracting_dims || windowed_at_batch_dims ||
- operands_sharded_at_contracting_dims) {
- // Slice the matching operand according to the partitioned dimensions on
- // the windowed operand or the output.
- auto slice_operand = matching_operand == 0 ? l : r;
-
- // We do this by treating the matching operand as replicated, and
- // resharding it to match the windowed operand or the output.
- auto gen_slice =
- [&](HloInstruction* data_partition_id) -> HloInstruction* {
- slice_operand->set_sharding(HloSharding::Replicate());
- auto state = lhs.state();
- state.b = &body_b;
- state.partition_id = data_partition_id;
- state.reshard_cache->per_hlo_cache.erase(slice_operand);
- const HloSharding* slice_sharding;
- if (operands_sharded_at_contracting_dims) {
- slice_sharding = windowing_operand == 0
- ? &*output_sharding_transposed_to_match_rhs
- : &*output_sharding_transposed_to_match_lhs;
- } else {
- slice_sharding = windowing_operand == 0
- ? &*lhs_sharding_transposed_to_match_rhs
- : &*rhs_sharding_transposed_to_match_lhs;
- }
- auto slice =
- PartitionedHlo(slice_operand, slice_operand->shape(), state)
- .Reshard(*slice_sharding)
- .hlo();
- slice_operand->clear_sharding();
- return slice;
- };
-
- auto ccw_slice = gen_slice(ccw_data_partition_id);
- auto cw_slice = gen_slice(cw_data_partition_id);
- if (matching_operand == 0) {
- ccw_dot_lhs = ccw_slice;
- cw_dot_lhs = cw_slice;
- } else {
- ccw_dot_rhs = ccw_slice;
- cw_dot_rhs = cw_slice;
- }
- }
-
+ // Calculate concat dim.
const HloSharding* slice_sharding;
if (operands_sharded_at_contracting_dims) {
- slice_sharding = windowing_operand == 0
+ slice_sharding = windowed_op_is_lhs
? &*output_sharding_transposed_to_match_rhs
: &*output_sharding_transposed_to_match_lhs;
} else if (windowed_at_contracting_dims || windowed_at_batch_dims) {
- slice_sharding = windowing_operand == 0
+ slice_sharding = windowed_op_is_lhs
? &*lhs_sharding_transposed_to_match_rhs
: &*rhs_sharding_transposed_to_match_lhs;
} else {
- slice_sharding = windowing_operand == 0
+ slice_sharding = windowed_op_is_lhs
? &*lhs_sharding_transposed_to_match_output
: &*rhs_sharding_transposed_to_match_output;
}
+ CHECK_EQ(Product(slice_sharding->tile_assignment().dimensions()),
+ num_partitions);
int64 slice_sharding_dim = -1;
for (int64 i = 0; i < slice_sharding->tile_assignment().num_dimensions();
++i) {
- if (slice_sharding->tile_assignment().dim(i) == num_partitions) {
+ if (slice_sharding->tile_assignment().dim(i) > 1) {
slice_sharding_dim = i;
break;
}
@@ -429,47 +727,275 @@
int64 lhs_concat_dim = -1;
int64 rhs_concat_dim = -1;
if (operands_sharded_at_contracting_dims) {
- if (windowing_operand == 0) {
+ if (windowed_op_is_lhs) {
rhs_concat_dim = slice_sharding_dim;
} else {
lhs_concat_dim = slice_sharding_dim;
}
} else if (windowed_at_contracting_dims || windowed_at_batch_dims) {
- lhs_concat_dim = windowing_operand == 0
- ? rhs_to_lhs_indices[slice_sharding_dim]
- : slice_sharding_dim;
- rhs_concat_dim = windowing_operand == 0
- ? slice_sharding_dim
- : lhs_to_rhs_indices[slice_sharding_dim];
+ lhs_concat_dim =
+ windowed_op_is_lhs
+ ? indices_map.rhs_to_lhs_indices[slice_sharding_dim]
+ : slice_sharding_dim;
+ rhs_concat_dim =
+ windowed_op_is_lhs
+ ? slice_sharding_dim
+ : indices_map.lhs_to_rhs_indices[slice_sharding_dim];
} else {
- if (windowing_operand == 0) {
- lhs_concat_dim = output_to_lhs_indices[slice_sharding_dim];
+ if (windowed_op_is_lhs) {
+ lhs_concat_dim =
+ indices_map.output_to_lhs_indices[slice_sharding_dim];
} else {
- rhs_concat_dim = output_to_rhs_indices[slice_sharding_dim];
+ rhs_concat_dim =
+ indices_map.output_to_rhs_indices[slice_sharding_dim];
}
}
- auto dot_lhs = ccw_dot_lhs;
- auto dot_rhs = ccw_dot_rhs;
- if (lhs_concat_dim != -1) {
+ DotDimensionNumbers new_ddnums;
+ if (original_hlo->opcode() == HloOpcode::kDot) {
+ new_ddnums = original_hlo->dot_dimension_numbers();
+ }
+
+ auto dot_lhs = l;
+ auto dot_rhs = r;
+ auto original_dot_lhs = l;
+ auto original_dot_rhs = r;
+ if (windowed_at_contracting_dims || windowed_at_batch_dims ||
+ operands_sharded_at_contracting_dims) {
+ // Slice the matching operand according to the partitioned dimensions
+ // on the windowed operand or the output.
+ auto slice_operand = !windowed_op_is_lhs ? l : r;
+
+ // Pad the sharding dim first (then the concat dim) for correctness.
+ auto sharding_dim_size =
+ slice_operand->shape().dimensions(slice_sharding_dim);
+ if (sharding_dim_size % num_partitions != 0) {
+ slice_operand = PadBaseShapeBeforeUnevenTiledSharding(
+ slice_operand, *slice_sharding, &body_b);
+ }
+
+ // We do this by treating the matching operand as replicated, and
+ // resharding it to match the windowed operand or the output.
+ auto gen_slice = [&](HloInstruction* data_partition_id,
+ bool ccw) -> HloInstruction* {
+ std::vector<int64> new_dims;
+ for (int64 i = 0; i < slice_operand->shape().dimensions_size(); ++i) {
+ if (i == slice_sharding_dim) {
+ new_dims.push_back(1);
+ }
+ new_dims.push_back(slice_operand->shape().dimensions(i));
+ }
+ auto reshaped_slice_operand =
+ body_b.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(slice_operand->shape().element_type(),
+ new_dims),
+ slice_operand));
+ auto min = body_b.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::MinValue(
+ reshaped_slice_operand->shape().element_type())));
+ std::vector<int64> min_padding(
+ reshaped_slice_operand->shape().rank());
+ auto padded_slice_operand = reshaped_slice_operand;
+ auto padded_shape = padded_slice_operand->shape();
+ int64 padding_dim = slice_sharding_dim;
+ padded_shape.set_dimensions(padding_dim, 2);
+ if (ccw) {
+ // ccw pad high
+ PaddingConfig ccw_pad_config =
+ window_util::MakeSymmetricPadding(min_padding);
+ ccw_pad_config.mutable_dimensions(padding_dim)
+ ->set_edge_padding_low(0);
+ ccw_pad_config.mutable_dimensions(padding_dim)
+ ->set_edge_padding_high(1);
+ padded_slice_operand =
+ body_b.AddInstruction(HloInstruction::CreatePad(
+ padded_shape, padded_slice_operand, min, ccw_pad_config));
+ } else {
+ // cw pad low
+ PaddingConfig cw_pad_config =
+ window_util::MakeSymmetricPadding(min_padding);
+ cw_pad_config.mutable_dimensions(padding_dim)
+ ->set_edge_padding_low(1);
+ cw_pad_config.mutable_dimensions(padding_dim)
+ ->set_edge_padding_high(0);
+ padded_slice_operand =
+ body_b.AddInstruction(HloInstruction::CreatePad(
+ padded_shape, padded_slice_operand, min, cw_pad_config));
+ }
+
+ padded_slice_operand->set_sharding(HloSharding::Replicate());
+ auto state = lhs.state();
+ state.b = &body_b;
+ state.partition_id = data_partition_id;
+ state.reshard_cache->per_hlo_cache.erase(padded_slice_operand);
+ auto padded_slice_sharding = hlo_sharding_util::ReshapeSharding(
+ slice_operand->shape(), reshaped_slice_operand->shape(),
+ *slice_sharding);
+ auto padded_slice =
+ PartitionedHlo(padded_slice_operand,
+ padded_slice_operand->shape(), state)
+ .Reshard(*padded_slice_sharding)
+ .hlo();
+ padded_slice_operand->clear_sharding();
+ return padded_slice;
+ };
+
+ auto ccw_slice = gen_slice(ccw_data_partition_id, true);
+ auto cw_slice = gen_slice(cw_data_partition_id, false);
+ auto slice = body_b.AddInstruction(HloInstruction::CreateBinary(
+ ccw_slice->shape(), HloOpcode::kMaximum, ccw_slice, cw_slice));
+ // Reshape. The reshaped slice will not be used to produce the final
+ // result, but used as a hint for the shape inference.
+ std::vector<int64> reshaped_slice_dims;
+ for (int64 i = 0; i < slice->shape().dimensions_size(); ++i) {
+ auto dim_size = slice->shape().dimensions(i);
+ if (i == (slice_sharding_dim + 1)) {
+ reshaped_slice_dims.push_back(dim_size * 2);
+ } else if (i != slice_sharding_dim) {
+ reshaped_slice_dims.push_back(dim_size);
+ }
+ }
+ auto reshaped_slice =
+ body_b.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(slice->shape().element_type(),
+ reshaped_slice_dims),
+ slice));
+
+ if (!windowed_op_is_lhs) {
+ dot_lhs = slice;
+ original_dot_lhs = reshaped_slice;
+ if (original_hlo->opcode() == HloOpcode::kDot) {
+ UpdateDDNums(&new_ddnums, slice_sharding_dim, true);
+ }
+ } else {
+ dot_rhs = slice;
+ original_dot_rhs = reshaped_slice;
+ if (original_hlo->opcode() == HloOpcode::kDot) {
+ UpdateDDNums(&new_ddnums, slice_sharding_dim, false);
+ }
+ }
+ }
+
+ auto ccw_dot_lhs = l;
+ auto ccw_dot_rhs = r;
+ auto cw_dot_lhs = windowed_op_is_lhs ? extra_inout : l;
+ auto cw_dot_rhs = windowed_op_is_lhs ? r : extra_inout;
+ if (lhs_concat_dim != -1 && windowed_op_is_lhs) {
auto lhs_concat_shape = ccw_dot_lhs->shape();
lhs_concat_shape.set_dimensions(
lhs_concat_dim,
ccw_dot_lhs->shape().dimensions(lhs_concat_dim) * 2);
dot_lhs = body_b.AddInstruction(HloInstruction::CreateConcatenate(
lhs_concat_shape, {ccw_dot_lhs, cw_dot_lhs}, lhs_concat_dim));
+ original_dot_lhs = dot_lhs;
+
+ // Reshape
+ std::vector<int64> reshaped_dims(dot_lhs->shape().dimensions().begin(),
+ dot_lhs->shape().dimensions().end());
+ reshaped_dims[lhs_concat_dim] /= 2;
+ reshaped_dims.insert(reshaped_dims.begin() + lhs_concat_dim, 2);
+ dot_lhs = body_b.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(dot_lhs->shape().element_type(),
+ reshaped_dims),
+ dot_lhs));
+
+ if (original_hlo->opcode() == HloOpcode::kDot) {
+ UpdateDDNums(&new_ddnums, lhs_concat_dim, true);
+ }
}
- if (rhs_concat_dim != -1) {
+ if (rhs_concat_dim != -1 && !windowed_op_is_lhs) {
auto rhs_concat_shape = ccw_dot_rhs->shape();
rhs_concat_shape.set_dimensions(
rhs_concat_dim,
ccw_dot_rhs->shape().dimensions(rhs_concat_dim) * 2);
dot_rhs = body_b.AddInstruction(HloInstruction::CreateConcatenate(
rhs_concat_shape, {ccw_dot_rhs, cw_dot_rhs}, rhs_concat_dim));
+ original_dot_rhs = dot_rhs;
+
+ // Reshape
+ std::vector<int64> reshaped_dims(dot_rhs->shape().dimensions().begin(),
+ dot_rhs->shape().dimensions().end());
+ reshaped_dims[rhs_concat_dim] /= 2;
+ reshaped_dims.insert(reshaped_dims.begin() + rhs_concat_dim, 2);
+ dot_rhs = body_b.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(dot_rhs->shape().element_type(),
+ reshaped_dims),
+ dot_rhs));
+
+ if (original_hlo->opcode() == HloOpcode::kDot) {
+ UpdateDDNums(&new_ddnums, rhs_concat_dim, false);
+ }
}
- TF_ASSIGN_OR_RETURN(
- auto dot, create_sharded_dot(dot_lhs, dot_rhs, &body_b, conv_window));
+ // The generated original dot will not be used.
+ TF_ASSIGN_OR_RETURN(auto original_dot,
+ create_sharded_dot(original_dot_lhs, original_dot_rhs,
+ &body_b, conv_window));
+ VLOG(2) << original_dot->ToString();
+
+ // Generate the correct shape of the new dot/conv.
+ auto original_sharded_dot_shape = original_dot->shape();
+ auto new_dot_shape = original_sharded_dot_shape;
+ std::vector<int64> new_dims(new_dot_shape.dimensions().begin(),
+ new_dot_shape.dimensions().end());
+ if (!windowed_at_contracting_dims) {
+ auto slice_dim =
+ lhs_concat_dim != -1
+ ? indices_map.lhs_to_output_indices[lhs_concat_dim]
+ : indices_map.rhs_to_output_indices[rhs_concat_dim];
+ new_dims[slice_dim] /= 2;
+ new_dims.insert(new_dims.begin() + slice_dim, 2);
+ } else {
+ new_dims.push_back(1);
+ }
+ new_dot_shape =
+ ShapeUtil::MakeShape(original_hlo->shape().element_type(), new_dims);
+
+ HloInstruction* dot;
+ if (original_hlo->opcode() == HloOpcode::kDot) {
+ dot = body_b.AddInstruction(HloInstruction::CreateDot(
+ new_dot_shape, dot_lhs, dot_rhs, new_ddnums,
+ original_hlo->precision_config()));
+ } else {
+ if (!windowed_at_contracting_dims && !windowed_at_batch_dims) {
+ if (lhs_concat_dim != -1) {
+ std::vector<int64> new_dims(dot_rhs->shape().dimensions().begin(),
+ dot_rhs->shape().dimensions().end());
+ new_dims.push_back(1);
+ dot_rhs = body_b.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(dot_rhs->shape().element_type(), new_dims),
+ dot_rhs));
+ }
+ if (rhs_concat_dim != -1) {
+ std::vector<int64> new_dims(dot_lhs->shape().dimensions().begin(),
+ dot_lhs->shape().dimensions().end());
+ new_dims.push_back(1);
+ dot_lhs = body_b.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(dot_lhs->shape().element_type(), new_dims),
+ dot_lhs));
+ }
+ }
+
+ dot = body_b.AddInstruction(HloInstruction::CreateConvolve(
+ new_dot_shape, dot_lhs, dot_rhs,
+ original_dot->feature_group_count(),
+ original_dot->batch_group_count(),
+ GenNewWindow(original_dot, dot_lhs, dot_rhs, lhs_concat_dim,
+ rhs_concat_dim, windowed_at_contracting_dims,
+ windowed_at_batch_dims),
+ GenNewConvDNums(original_dot, dot_lhs, dot_rhs, lhs_concat_dim,
+ rhs_concat_dim, windowed_at_contracting_dims,
+ windowed_at_batch_dims,
+ indices_map.lhs_to_output_indices,
+ indices_map.rhs_to_output_indices, new_dot_shape),
+ original_dot->precision_config()));
+ }
+ VLOG(2) << dot->ToString();
+
+ // Reshape to the original sharded dot shape.
+ dot = body_b.AddInstruction(
+ HloInstruction::CreateReshape(original_sharded_dot_shape, dot));
+
if (windowed_at_contracting_dims) {
// Accumulate the partial output to the result buffer.
o = body_b.AddInstruction(
@@ -479,9 +1005,10 @@
// dimensions, so we need a dynamic-update-slice to save the partial
// output in the result buffer.
auto slice_shape = dot->shape();
- auto slice_dim = lhs_concat_dim != -1
- ? lhs_to_output_indices[lhs_concat_dim]
- : rhs_to_output_indices[rhs_concat_dim];
+ auto slice_dim =
+ lhs_concat_dim != -1
+ ? indices_map.lhs_to_output_indices[lhs_concat_dim]
+ : indices_map.rhs_to_output_indices[rhs_concat_dim];
slice_shape.set_dimensions(slice_dim,
dot->shape().dimensions(slice_dim) / 2);
std::vector<int64> ccw_start_indices(dot->shape().rank(), 0);
@@ -503,13 +1030,13 @@
} else {
auto ccw_offsets = MakePartitionOffsets(
o->shape(),
- windowing_operand == 0 ? *lhs_sharding_transposed_to_match_output
- : *rhs_sharding_transposed_to_match_output,
+ windowed_op_is_lhs ? *lhs_sharding_transposed_to_match_output
+ : *rhs_sharding_transposed_to_match_output,
ccw_data_partition_id, &body_b);
auto cw_offsets = MakePartitionOffsets(
o->shape(),
- windowing_operand == 0 ? *lhs_sharding_transposed_to_match_output
- : *rhs_sharding_transposed_to_match_output,
+ windowed_op_is_lhs ? *lhs_sharding_transposed_to_match_output
+ : *rhs_sharding_transposed_to_match_output,
cw_data_partition_id, &body_b);
o = body_b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
o->shape(), o, ccw_dot, ccw_offsets));
@@ -545,7 +1072,7 @@
operands_sharded_at_contracting_dims) {
// Slice the matching operand according to the partitioned dimensions on
// the windowed operand or the output.
- auto slice_operand = matching_operand == 0 ? l : r;
+ auto slice_operand = !windowed_op_is_lhs ? l : r;
// We do this by treating the matching operand as replicated, and
// resharding it to match the windowed operand or the output.
slice_operand->set_sharding(HloSharding::Replicate());
@@ -555,11 +1082,11 @@
state.reshard_cache->per_hlo_cache.erase(slice_operand);
const HloSharding* slice_sharding;
if (operands_sharded_at_contracting_dims) {
- slice_sharding = windowing_operand == 0
+ slice_sharding = windowed_op_is_lhs
? &*output_sharding_transposed_to_match_rhs
: &*output_sharding_transposed_to_match_lhs;
} else {
- slice_sharding = windowing_operand == 0
+ slice_sharding = windowed_op_is_lhs
? &*lhs_sharding_transposed_to_match_rhs
: &*rhs_sharding_transposed_to_match_lhs;
}
@@ -568,7 +1095,7 @@
.Reshard(*slice_sharding)
.hlo();
slice_operand->clear_sharding();
- if (matching_operand == 0) {
+ if (!windowed_op_is_lhs) {
dot_lhs = slice;
} else {
dot_rhs = slice;
@@ -587,8 +1114,8 @@
// output in the result buffer.
auto offsets = MakePartitionOffsets(
o->shape(),
- windowing_operand == 0 ? *lhs_sharding_transposed_to_match_output
- : *rhs_sharding_transposed_to_match_output,
+ windowed_op_is_lhs ? *lhs_sharding_transposed_to_match_output
+ : *rhs_sharding_transposed_to_match_output,
data_partition_id, &body_b);
o = body_b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
o->shape(), o, dot, offsets));
@@ -632,7 +1159,7 @@
auto next_l = l;
auto next_r = r;
auto ccw_cp_input = operands_sharded_at_contracting_dims ? o
- : windowing_operand == 0 ? l
+ : windowed_op_is_lhs ? l
: r;
auto ccw_cp_output =
lhs.state()
@@ -641,7 +1168,7 @@
(*lhs.state().next_channel_id)++);
if (operands_sharded_at_contracting_dims) {
o = ccw_cp_output;
- } else if (windowing_operand == 0) {
+ } else if (windowed_op_is_lhs) {
next_l = ccw_cp_output;
} else {
next_r = ccw_cp_output;
@@ -667,7 +1194,7 @@
auto second_next_l = next_l;
auto second_next_r = next_r;
ccw_cp_input = operands_sharded_at_contracting_dims ? o
- : windowing_operand == 0 ? next_l
+ : windowed_op_is_lhs ? next_l
: next_r;
ccw_cp_output =
lhs.state()
@@ -676,7 +1203,7 @@
(*lhs.state().next_channel_id)++);
if (operands_sharded_at_contracting_dims) {
o = ccw_cp_output;
- } else if (windowing_operand == 0) {
+ } else if (windowed_op_is_lhs) {
second_next_l = ccw_cp_output;
} else {
second_next_r = ccw_cp_output;
@@ -749,13 +1276,13 @@
// Even number iteration.
auto next_l = l;
auto next_r = r;
- auto cp_input = windowing_operand == 0 ? l : r;
+ auto cp_input = windowed_op_is_lhs ? l : r;
auto cp_output = lhs.state()
.collective_ops_creator
.create_cross_partition_collective_permute(
&body_b, cp_input, sd_pairs,
(*lhs.state().next_channel_id)++);
- if (windowing_operand == 0) {
+ if (windowed_op_is_lhs) {
next_l = cp_output;
} else {
next_r = cp_output;
@@ -771,13 +1298,13 @@
// Odd number iteration.
auto second_next_l = next_l;
auto second_next_r = next_r;
- cp_input = windowing_operand == 0 ? next_l : next_r;
+ cp_input = windowed_op_is_lhs ? next_l : next_r;
cp_output = lhs.state()
.collective_ops_creator
.create_cross_partition_collective_permute(
&body_b, cp_input, sd_pairs,
(*lhs.state().next_channel_id)++);
- if (windowing_operand == 0) {
+ if (windowed_op_is_lhs) {
second_next_l = cp_output;
} else {
second_next_r = cp_output;
@@ -824,7 +1351,7 @@
auto p = cp_b.AddInstruction(HloInstruction::CreateParameter(
0,
operands_sharded_at_contracting_dims ? o->shape()
- : windowing_operand == 0 ? l->shape()
+ : windowed_op_is_lhs ? l->shape()
: r->shape(),
"window"));
std::vector<std::pair<int64, int64>> sd_pairs(num_partitions);
@@ -842,27 +1369,27 @@
ncp_b.AddInstruction(HloInstruction::CreateParameter(
0,
operands_sharded_at_contracting_dims ? o->shape()
- : windowing_operand == 0 ? l->shape()
+ : windowed_op_is_lhs ? l->shape()
: r->shape(),
"window"));
}
conditional = body_b.AddInstruction(HloInstruction::CreateConditional(
operands_sharded_at_contracting_dims ? o->shape()
- : windowing_operand == 0 ? l->shape()
+ : windowed_op_is_lhs ? l->shape()
: r->shape(),
has_more,
operands_sharded_at_contracting_dims ? o
- : windowing_operand == 0 ? l
+ : windowed_op_is_lhs ? l
: r,
module->AddEmbeddedComputation(cp_b.Build()),
operands_sharded_at_contracting_dims ? o
- : windowing_operand == 0 ? l
+ : windowed_op_is_lhs ? l
: r,
module->AddEmbeddedComputation(ncp_b.Build())));
}
if (operands_sharded_at_contracting_dims) {
o = conditional;
- } else if (windowing_operand == 0) {
+ } else if (windowed_op_is_lhs) {
l = conditional;
} else {
r = conditional;
@@ -895,7 +1422,7 @@
b->AddInstruction(HloInstruction::CreateTuple(
{lhs.hlo(), rhs.hlo(), result_buffer, extra_buffer, iteration}))));
windowed_dot_general_loops->push_back(
- {while_loop, windowing_operand, windowed_at_contracting_dims,
+ {while_loop, windowed_op_is_lhs ? 0 : 1, windowed_at_contracting_dims,
windowed_at_batch_dims, operands_sharded_at_contracting_dims});
auto result = b->AddInstruction(HloInstruction::CreateGetTupleElement(
result_buffer->shape(), while_loop, 2));
@@ -937,45 +1464,21 @@
}
return result;
};
- if (output_lhs_non_contracting_partitions == num_partitions &&
- output_sharding_transposed_to_match_lhs == lhs_sharding &&
- ShapeSizeInBytes(rhs.base_shape()) >=
- options.threshold_for_windowed_einsum_mib * 1024 * 1024) {
- if (rhs_contracting_partitions == num_partitions) {
- return emit_windowed_dot_general(0, 1, true, false, false);
- }
- if (rhs_non_contracting_partitions == num_partitions) {
- return emit_windowed_dot_general(0, 1, false, false, false);
- }
- if (rhs_batch_partitions == num_partitions) {
- return emit_windowed_dot_general(0, 1, false, true, false);
- }
- }
- if (output_rhs_non_contracting_partitions == num_partitions &&
- output_sharding_transposed_to_match_rhs == rhs_sharding &&
- ShapeSizeInBytes(lhs.base_shape()) >=
- options.threshold_for_windowed_einsum_mib * 1024 * 1024) {
- if (lhs_contracting_partitions == num_partitions) {
- return emit_windowed_dot_general(1, 0, true, false, false);
- }
- if (lhs_non_contracting_partitions == num_partitions) {
- return emit_windowed_dot_general(1, 0, false, false, false);
- }
- if (lhs_batch_partitions == num_partitions) {
- return emit_windowed_dot_general(1, 0, false, true, false);
- }
- }
- if (lhs_contracting_partitions == rhs_contracting_partitions &&
- lhs_contracting_partitions == num_partitions &&
- output_sharding_dim > -1 &&
- ShapeSizeInBytes(output_base_shape) >=
- options.threshold_for_windowed_einsum_mib * 1024 * 1024) {
- if (output_lhs_non_contracting_partitions == num_partitions) {
- return emit_windowed_dot_general(0, 1, false, false, true);
- }
- if (output_rhs_non_contracting_partitions == num_partitions) {
- return emit_windowed_dot_general(1, 0, false, false, true);
- }
+ absl::optional<WindowedEinsumConfig> e_config =
+ GetWindowedEinsumConfiguration(
+ num_partitions, output_lhs_non_contracting_partitions,
+ output_rhs_non_contracting_partitions, rhs_contracting_partitions,
+ rhs_non_contracting_partitions, rhs_batch_partitions,
+ lhs_contracting_partitions, lhs_non_contracting_partitions,
+ lhs_batch_partitions, output_sharding_dim,
+ ShapeSizeInBytes(rhs.base_shape()),
+ ShapeSizeInBytes(lhs.base_shape()),
+ ShapeSizeInBytes(output_base_shape),
+ options.threshold_for_windowed_einsum_mib,
+ output_sharding_transposed_to_match_lhs,
+ output_sharding_transposed_to_match_rhs, lhs_sharding, rhs_sharding);
+ if (e_config) {
+ return emit_windowed_dot_general(*e_config);
}
{
@@ -1374,6 +1877,105 @@
.hlo();
}
+GroupedSharding GetNonContractingPartitionGroupedShardingForMatchedOperand(
+ bool lhs_matching, const HloSharding& matching_sharding,
+ const HloSharding& output_sharding,
+ absl::Span<const DotConvDimsMapping::DimsMapping> partitioned_dims) {
+ std::vector<int64> matching_sharding_dims =
+ matching_sharding.tile_assignment().dimensions();
+ std::vector<int64> matching_dims;
+ std::vector<int64> output_dims;
+ // Make sure the partitioning on matching's non-contracting dimensions
+ // defines the same device groups for both matching and output.
+ for (const auto& dim : partitioned_dims) {
+ int64 md = lhs_matching ? dim.lhs : dim.rhs;
+ matching_sharding_dims[md] =
+ output_sharding.tile_assignment().dim(dim.output);
+ matching_dims.push_back(md);
+ output_dims.push_back(dim.output);
+ }
+ GroupedSharding output_grouped =
+ GroupShardingOnDims(output_sharding, output_dims);
+ Array<int64> reshaped_matching_tiling = matching_sharding.tile_assignment();
+ reshaped_matching_tiling.Reshape(matching_sharding_dims);
+ return AlignGroupsWith(
+ GroupShardingOnDims(
+ matching_sharding.ReplicateOnLastTileDim()
+ ? HloSharding::PartialTile(reshaped_matching_tiling)
+ : HloSharding::Tile(reshaped_matching_tiling),
+ matching_dims),
+ output_grouped);
+}
+
+absl::optional<GroupedSharding>
+GetNonContractingPartitionGroupedShardingForOtherOperand(
+ bool lhs_matching, const Shape& output_base_shape, const Shape& other_shape,
+ int64 other_contracting_partitions, int64 other_non_contracting_partitions,
+ int64 matching_contracting_partitions,
+ int64 output_other_non_contracting_partitions,
+ const HloSharding& other_sharding, const HloSharding& output_sharding,
+ absl::Span<const DotConvDimsMapping::DimsMapping> matching_partitioned_dims,
+ absl::Span<const DotConvDimsMapping::DimsMapping>
+ other_non_contracting_dims,
+ absl::Span<const DotConvDimsMapping::DimsMapping> other_contracting_dims) {
+ int64 group_count = 1;
+ std::vector<int64> output_dims;
+ for (const auto& dim : matching_partitioned_dims) {
+ output_dims.push_back(dim.output);
+ group_count *= output_sharding.tile_assignment().dim(dim.output);
+ }
+ GroupedSharding output_grouped =
+ GroupShardingOnDims(output_sharding, output_dims);
+ std::vector<int64> other_group_dims;
+ if (other_sharding.ReplicateOnLastTileDim() &&
+ other_sharding.tile_assignment().dimensions().back() % group_count == 0) {
+ other_group_dims.push_back(
+ other_sharding.tile_assignment().num_dimensions() - 1);
+ } else {
+ const bool may_replicate_other_contracting_dims =
+ (other_contracting_partitions == group_count &&
+ other_non_contracting_partitions ==
+ output_other_non_contracting_partitions);
+ const bool may_replicate_other_non_contracting_dims =
+ group_count == other_non_contracting_partitions &&
+ matching_contracting_partitions == other_contracting_partitions;
+ if (auto found_dims = FindMatchingPartitionedDimsForGrouping(
+ other_sharding, output_grouped.device_groups)) {
+ other_group_dims = std::move(*found_dims);
+ } else if (may_replicate_other_contracting_dims &&
+ (!may_replicate_other_non_contracting_dims ||
+ ShapeUtil::ByteSizeOf(other_shape)) <=
+ ShapeUtil::ByteSizeOf(MakePartitionedShape(
+ output_base_shape, output_sharding))) {
+ for (const auto& dim : other_contracting_dims) {
+ other_group_dims.push_back(lhs_matching ? dim.rhs : dim.lhs);
+ }
+ } else if (may_replicate_other_non_contracting_dims) {
+ for (const auto& dim : other_non_contracting_dims) {
+ other_group_dims.push_back(lhs_matching ? dim.rhs : dim.lhs);
+ }
+ } else {
+ return absl::nullopt;
+ }
+ }
+ if (other_group_dims.size() == 1 &&
+ other_group_dims[0] ==
+ other_sharding.tile_assignment().num_dimensions() - 1) {
+ return AlignGroupsWith(
+ GroupShardingOnDims(
+ other_sharding, {other_group_dims[0]},
+ {other_sharding.tile_assignment().dimensions().back() /
+ group_count}),
+ output_grouped, /*ignore_group_order=*/true);
+
+ } else if (!other_sharding.IsReplicated()) {
+ return AlignGroupsWith(
+ GroupShardingOnDims(other_sharding, other_group_dims), output_grouped,
+ /*ignore_group_order=*/true);
+ }
+ return absl::nullopt;
+}
+
StatusOr<HloInstruction*> PartitionDotGroupOnNonContracting(
bool lhs_matching, PartitionedHlo matching, PartitionedHlo other,
int64 matching_contracting_partitions, int64 other_contracting_partitions,
@@ -1399,71 +2001,34 @@
}
});
- auto matching_sharding_dims =
- matching.sharding().tile_assignment().dimensions();
- std::vector<int64> matching_dims;
std::vector<int64> output_dims;
- int64 group_count = 1;
- // Make sure the partitioning on matching's non-contracting dimensions
- // defines the same device groups for both matching and output.
for (const auto& dim : partitioned_non_contracting_dims) {
- int64 md = lhs_matching ? dim.lhs : dim.rhs;
- matching_sharding_dims[md] =
- output_sharding.tile_assignment().dim(dim.output);
- matching_dims.push_back(md);
output_dims.push_back(dim.output);
- group_count *= output_sharding.tile_assignment().dim(dim.output);
}
- auto output_grouped = GroupShardingOnDims(output_sharding, output_dims);
- auto reshaped_matching_tiling = matching.sharding().tile_assignment();
- reshaped_matching_tiling.Reshape(matching_sharding_dims);
- auto matching_grouped = AlignGroupsWith(
- GroupShardingOnDims(
- matching.sharding().ReplicateOnLastTileDim()
- ? HloSharding::PartialTile(reshaped_matching_tiling)
- : HloSharding::Tile(reshaped_matching_tiling),
- matching_dims),
- output_grouped);
+ GroupedSharding output_grouped =
+ GroupShardingOnDims(output_sharding, output_dims);
+ GroupedSharding matching_grouped =
+ GetNonContractingPartitionGroupedShardingForMatchedOperand(
+ lhs_matching, matching.sharding(), output_sharding,
+ partitioned_non_contracting_dims);
if (require_matching_devices_to_group &&
matching.sharding() != UngroupSharding(matching_grouped)) {
return nullptr;
}
+ absl::optional<GroupedSharding> other_grouped =
+ GetNonContractingPartitionGroupedShardingForOtherOperand(
+ lhs_matching, output_base_shape, other.hlo()->shape(),
+ other_contracting_partitions, other_non_contracting_partitions,
+ matching_contracting_partitions,
+ output_other_non_contracting_partitions, other.sharding(),
+ output_sharding, partitioned_non_contracting_dims,
+ lhs_matching ? dims_mapping.rhs_non_contracting_dims
+ : dims_mapping.lhs_non_contracting_dims,
+ dims_mapping.contracting_dims);
- std::vector<int64> other_group_dims;
- if (other.sharding().ReplicateOnLastTileDim() &&
- other.sharding().tile_assignment().dimensions().back() % group_count ==
- 0) {
- other_group_dims.push_back(other.base_shape().rank());
- } else {
- const bool may_replicate_other_contracting_dims =
- (other_contracting_partitions == group_count &&
- other_non_contracting_partitions ==
- output_other_non_contracting_partitions);
- const bool may_replicate_other_non_contracting_dims =
- group_count == other_non_contracting_partitions &&
- matching_contracting_partitions == other_contracting_partitions;
- if (auto found_dims = FindMatchingPartitionedDimsForGrouping(
- other.sharding(), output_grouped.device_groups)) {
- other_group_dims = std::move(*found_dims);
- } else if (may_replicate_other_contracting_dims &&
- (!may_replicate_other_non_contracting_dims ||
- ShapeUtil::ByteSizeOf(other.hlo()->shape()) <=
- ShapeUtil::ByteSizeOf(MakePartitionedShape(
- output_base_shape, output_sharding)))) {
- for (const auto& dim : dims_mapping.contracting_dims) {
- other_group_dims.push_back(lhs_matching ? dim.rhs : dim.lhs);
- }
- } else if (may_replicate_other_non_contracting_dims) {
- for (const auto& dim : lhs_matching
- ? dims_mapping.rhs_non_contracting_dims
- : dims_mapping.lhs_non_contracting_dims) {
- other_group_dims.push_back(lhs_matching ? dim.rhs : dim.lhs);
- }
- } else {
- other = other.Replicate();
- }
+ if (!other_grouped) {
+ other = other.Replicate();
}
-
matching = matching.Reshard(UngroupSharding(matching_grouped));
auto per_group_partitioner_state = CreatePerGroupPartitioningState(
matching.state(), matching_grouped.device_groups, b);
@@ -1475,32 +2040,23 @@
per_group_partitioner_state);
auto partially_replicated_other = other.hlo();
- if (other_group_dims.size() == 1 &&
- other_group_dims[0] == other.base_shape().rank()) {
+ if (other_grouped && other_grouped->group_dims.size() == 1 &&
+ other_grouped->group_dims[0] == other.base_shape().rank()) {
// Group on replication dim.
- auto grouped = AlignGroupsWith(
- GroupShardingOnDims(
- other.sharding(), {other_group_dims[0]},
- {other.sharding().tile_assignment().dimensions().back() /
- group_count}),
- output_grouped, /*ignore_group_order=*/true);
- other = other.Reshard(UngroupSharding(grouped));
+ other = other.Reshard(UngroupSharding(*other_grouped));
partially_replicated_other = other.hlo();
top_level_sharding_to_reset.emplace_back(other.hlo(), other.sharding());
- partially_replicated_other->set_sharding(grouped.sharding);
+ partially_replicated_other->set_sharding(other_grouped->sharding);
} else if (!other.sharding().IsReplicated()) {
- auto other_grouped =
- AlignGroupsWith(GroupShardingOnDims(other.sharding(), other_group_dims),
- output_grouped, /*ignore_group_order=*/true);
- other = other.Reshard(UngroupSharding(other_grouped));
+ other = other.Reshard(UngroupSharding(*other_grouped));
partially_replicated_other =
other
.Reshard(hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
- other.sharding(), other_grouped.group_dims))
+ other.sharding(), other_grouped->group_dims))
.hlo();
top_level_sharding_to_reset.emplace_back(
partially_replicated_other, partially_replicated_other->sharding());
- partially_replicated_other->set_sharding(other_grouped.sharding);
+ partially_replicated_other->set_sharding(other_grouped->sharding);
}
auto other_p = PartitionedHlo(partially_replicated_other, other.base_shape(),
per_group_partitioner_state);
@@ -1728,6 +2284,147 @@
return new_dims_mapping;
}
+bool LhsIsBestMatchForNonContractingPartitioning(
+ const DotConvDimsMapping& dims_mapping, const PartitionedHlo& lhs,
+ const PartitionedHlo& rhs, const Shape& output_base_shape,
+ const HloSharding& output_sharding, const SpmdPartitionerOptions& options,
+ int64 num_partitions, int64 lhs_non_contracting_partitions,
+ int64 rhs_non_contracting_partitions, int64 lhs_contracting_partitions,
+ int64 rhs_contracting_partitions,
+ int64 output_lhs_non_contracting_partitions,
+ int64 output_rhs_non_contracting_partitions, int64 lhs_batch_partitions,
+ int64 rhs_batch_partitions, bool may_group_on_lhs_non_contracting,
+ bool may_group_on_rhs_non_contracting) {
+ // If both match output non-contracting dimensions, choose the one which
+ // will result in smaller replication of the other operand.
+ bool lhs_matching = may_group_on_lhs_non_contracting &&
+ (!may_group_on_rhs_non_contracting ||
+ lhs_non_contracting_partitions *
+ ShapeUtil::ByteSizeOf(rhs.hlo()->shape()) <
+ rhs_non_contracting_partitions *
+ ShapeUtil::ByteSizeOf(lhs.hlo()->shape()));
+ // If both grouping are available and the option to choose faster windowed
+ // einsums vs saving memory is enabled then try to determine which of the
+ // operands will generate the least amount of iterations for the windowed
+ // einsum when matched (if a windowed einsum is gonna be generated at all).
+ if (may_group_on_lhs_non_contracting && may_group_on_rhs_non_contracting &&
+ options.choose_faster_windowed_einsum_over_mem) {
+ const DotDimensionIndexMapping indices_map = ComputeDimensionIndexMapping(
+ dims_mapping, lhs.base_shape().rank(), rhs.base_shape().rank(),
+ output_base_shape.rank());
+ auto subsequent_einsum_iterations_estimate =
+ [&](bool assume_lhs_match) -> absl::optional<int64> {
+ const std::vector<DotConvDimsMapping::DimsMapping>&
+ matching_non_contracting_dims =
+ assume_lhs_match ? dims_mapping.lhs_non_contracting_dims
+ : dims_mapping.rhs_non_contracting_dims;
+ const std::vector<DotConvDimsMapping::DimsMapping>&
+ other_non_contracting_dims =
+ assume_lhs_match ? dims_mapping.rhs_non_contracting_dims
+ : dims_mapping.lhs_non_contracting_dims;
+ const std::vector<int64>& output_to_matching_indices =
+ assume_lhs_match ? indices_map.output_to_lhs_indices
+ : indices_map.output_to_rhs_indices;
+ const std::vector<int64>& output_to_other_indices =
+ assume_lhs_match ? indices_map.output_to_rhs_indices
+ : indices_map.output_to_lhs_indices;
+ const std::vector<int64>& matching_to_output_indices =
+ assume_lhs_match ? indices_map.lhs_to_output_indices
+ : indices_map.rhs_to_output_indices;
+ const std::vector<int64>& other_to_output_indices =
+ assume_lhs_match ? indices_map.rhs_to_output_indices
+ : indices_map.lhs_to_output_indices;
+ const HloSharding& matching_sharding =
+ assume_lhs_match ? lhs.sharding() : rhs.sharding();
+ const HloSharding& other_sharding =
+ assume_lhs_match ? rhs.sharding() : lhs.sharding();
+ const PartitionedHlo& matching_partitioned = assume_lhs_match ? lhs : rhs;
+ const PartitionedHlo& other_partitioned = assume_lhs_match ? rhs : lhs;
+ const int64 matching_non_contracting_partitions =
+ assume_lhs_match ? lhs_non_contracting_partitions
+ : rhs_non_contracting_partitions;
+ const int64 other_non_contracting_partitions =
+ assume_lhs_match ? rhs_non_contracting_partitions
+ : lhs_non_contracting_partitions;
+ const int64 matching_contracting_partitions =
+ assume_lhs_match ? lhs_contracting_partitions
+ : rhs_contracting_partitions;
+ const int64 other_contracting_partitions =
+ assume_lhs_match ? rhs_contracting_partitions
+ : lhs_contracting_partitions;
+ const int64 output_matching_non_contracting_partitions =
+ assume_lhs_match ? output_lhs_non_contracting_partitions
+ : output_rhs_non_contracting_partitions;
+ const int64 output_other_non_contracting_partitions =
+ assume_lhs_match ? output_rhs_non_contracting_partitions
+ : output_lhs_non_contracting_partitions;
+ const int64 matching_batch_partitions =
+ assume_lhs_match ? lhs_batch_partitions : rhs_batch_partitions;
+ const int64 other_batch_partitions =
+ assume_lhs_match ? rhs_batch_partitions : lhs_batch_partitions;
+ std::vector<int64> output_dims;
+ output_dims.reserve(matching_non_contracting_dims.size());
+ for (const DotConvDimsMapping::DimsMapping& dim :
+ matching_non_contracting_dims) {
+ output_dims.push_back(dim.output);
+ }
+ GroupedSharding output_grouped =
+ GroupShardingOnDims(output_sharding, output_dims);
+ GroupedSharding matching_grouped =
+ GetNonContractingPartitionGroupedShardingForMatchedOperand(
+ assume_lhs_match, matching_sharding, output_sharding,
+ matching_non_contracting_dims);
+ absl::optional<GroupedSharding> other_grouped =
+ GetNonContractingPartitionGroupedShardingForOtherOperand(
+ assume_lhs_match, output_base_shape,
+ other_partitioned.hlo()->shape(), other_contracting_partitions,
+ other_non_contracting_partitions, matching_contracting_partitions,
+ output_other_non_contracting_partitions, other_sharding,
+ output_sharding, matching_non_contracting_dims,
+ other_non_contracting_dims, dims_mapping.contracting_dims);
+ absl::optional<HloSharding> output_sharding_transposed_to_match_matching =
+ hlo_sharding_util::TransposeShardingWithCollapsedDims(
+ output_grouped.sharding, output_to_matching_indices,
+ matching_to_output_indices);
+ absl::optional<HloSharding> output_sharding_transposed_to_match_other =
+ hlo_sharding_util::TransposeShardingWithCollapsedDims(
+ output_grouped.sharding, output_to_other_indices,
+ other_to_output_indices);
+ const int64 new_num_partitions =
+ num_partitions / matching_non_contracting_partitions;
+ const int64 output_sharding_dim = FirstShardingDimWithPartitionOfSize(
+ new_num_partitions, output_grouped.sharding);
+ absl::optional<WindowedEinsumConfig> e_config =
+ GetWindowedEinsumConfiguration(
+ new_num_partitions, output_matching_non_contracting_partitions,
+ output_other_non_contracting_partitions, 1,
+ other_non_contracting_partitions, other_batch_partitions,
+ matching_contracting_partitions, 1, matching_batch_partitions,
+ output_sharding_dim,
+ ShapeSizeInBytes(other_partitioned.base_shape()),
+ ShapeSizeInBytes(matching_partitioned.base_shape()) /
+ matching_non_contracting_partitions,
+ ShapeSizeInBytes(
+ GetPerGroupBaseShape(output_grouped, output_base_shape)),
+ options.threshold_for_windowed_einsum_mib,
+ output_sharding_transposed_to_match_matching,
+ output_sharding_transposed_to_match_other,
+ matching_grouped.sharding, other_grouped->sharding);
+ return e_config ? new_num_partitions
+ : absl::optional<int64>(absl::nullopt);
+ };
+ absl::optional<int64> lhs_matching_iterations =
+ subsequent_einsum_iterations_estimate(true);
+ absl::optional<int64> rhs_matching_iterations =
+ subsequent_einsum_iterations_estimate(false);
+ if (lhs_matching_iterations && rhs_matching_iterations &&
+ *lhs_matching_iterations != *rhs_matching_iterations) {
+ lhs_matching = *lhs_matching_iterations < *rhs_matching_iterations;
+ }
+ }
+ return lhs_matching;
+}
+
// Recursive partitioning function. If there are partial dimensions matching in
// the operands and output, group the devices and recursively partition the
// in-group dot.
@@ -1929,15 +2626,14 @@
rhs_non_contracting_partitions == output_rhs_non_contracting_partitions &&
rhs_non_contracting_partitions > 1;
if (may_group_on_lhs_non_contracting || may_group_on_rhs_non_contracting) {
- // If both match output non-contracting dimensions, choose the one which
- // will result in smaller replication of the other operand.
- const bool lhs_matching =
- may_group_on_lhs_non_contracting &&
- (!may_group_on_rhs_non_contracting ||
- lhs_non_contracting_partitions *
- ShapeUtil::ByteSizeOf(rhs.hlo()->shape()) <
- rhs_non_contracting_partitions *
- ShapeUtil::ByteSizeOf(lhs.hlo()->shape()));
+ const bool lhs_matching = LhsIsBestMatchForNonContractingPartitioning(
+ dims_mapping, lhs, rhs, output_base_shape, output_sharding, options,
+ num_partitions, lhs_non_contracting_partitions,
+ rhs_non_contracting_partitions, lhs_contracting_partitions,
+ rhs_contracting_partitions, output_lhs_non_contracting_partitions,
+ output_rhs_non_contracting_partitions, lhs_batch_partitions,
+ rhs_batch_partitions, may_group_on_lhs_non_contracting,
+ may_group_on_rhs_non_contracting);
TF_ASSIGN_OR_RETURN(
auto dot,
PartitionDotGroupOnNonContracting(
diff --git a/tensorflow/compiler/xla/service/spmd/gather_scatter_handler.cc b/tensorflow/compiler/xla/service/spmd/gather_scatter_handler.cc
index 013d155..67a7076 100644
--- a/tensorflow/compiler/xla/service/spmd/gather_scatter_handler.cc
+++ b/tensorflow/compiler/xla/service/spmd/gather_scatter_handler.cc
@@ -165,7 +165,9 @@
CHECK(pgather_sharding.has_value());
pgather->set_sharding(*pgather_sharding);
VLOG(5) << "[Gather partitioning]: Partitioned as index only";
- return pgather;
+ return PartitionedHlo(pgather, gather->shape(), operand.state())
+ .Reshard(gather->sharding())
+ .hlo();
}
}
return nullptr;
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
index 51187ad..052c2e6 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
@@ -1594,7 +1594,7 @@
dim->set_window_reversal(false);
dim->set_padding_low(-hlo->slice_starts(i));
dim->set_padding_high(hlo->slice_limits(i) -
- hlo->operand(0)->shape().dimensions(i));
+ operand.base_shape().dimensions(i));
dim->set_base_dilation(1);
}
@@ -1726,14 +1726,14 @@
const Shape final_topk_shape = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(element_type, replicated_dimensions),
ShapeUtil::MakeShape(index_type, replicated_dimensions)});
- auto final_sort = b_.AddInstruction(HloInstruction::CreateSort(
+ HloInstruction* final_sort = b_.AddInstruction(HloInstruction::CreateSort(
final_topk_shape, sort_dim,
{replicated_slice_input, replicated_slice_index}, sort->to_apply(),
sort->is_stable()));
final_sort->set_sharding(HloSharding::Replicate()
.GetTupleSharding(final_sort->shape())
.ValueOrDie());
- PartitionedHlo replicated_sort(final_sort, final_topk_shape,
+ PartitionedHlo replicated_sort(final_sort, final_sort->shape(),
MakePartitioningState());
SetPartitionedHlo(hlo, replicated_sort.Reshard(hlo->sharding()));
@@ -2477,14 +2477,14 @@
auto gte = b_.AddInstruction(HloInstruction::CreateGetTupleElement(
ShapeUtil::GetTupleElementShape(tuple.hlo()->shape(), hlo->tuple_index()),
tuple.hlo(), hlo->tuple_index()));
- SetPartitionedHlo(hlo, [&]() {
- const auto source_sharding = tuple.sharding().GetSubSharding(
- tuple.base_shape(), {hlo->tuple_index()});
- gte->set_sharding(source_sharding);
- PartitionedHlo source_partitioned_gte(gte, hlo->shape(),
- MakePartitioningState());
- return source_partitioned_gte.Reshard(hlo->sharding()).hlo();
- });
+ const auto source_sharding =
+ tuple.sharding().GetSubSharding(tuple.base_shape(), {hlo->tuple_index()});
+ gte->set_sharding(source_sharding);
+ PartitionedHlo source_partitioned_gte(
+ gte, tuple.base_shape().tuple_shapes(hlo->tuple_index()),
+ MakePartitioningState());
+ source_partitioned_gte = source_partitioned_gte.Reshard(hlo->sharding());
+ SetPartitionedHlo(hlo, source_partitioned_gte);
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h
index f92cd0b..ad4f1ce 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h
@@ -60,6 +60,9 @@
// memory-efficient, and the compiler can use the ScheduleAwareAllGatherCSE
// pass to CSE some all-gathers which are relatively close to each other.
bool cache_all_gather = true;
+ // When making a compromise between windowed einsum speed and memory usage
+ // prefer the former if true.
+ bool choose_faster_windowed_einsum_over_mem = false;
};
// Class to wrap the computation builder to capture information during SPMD
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
index f44da51..ab19d00 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
@@ -15,6 +15,8 @@
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
@@ -35,12 +37,15 @@
public:
StatusOr<std::unique_ptr<HloModule>> PartitionComputation(
absl::string_view hlo_module, int64 num_devices,
- bool conv_halo_exchange_always_on_lhs = true) {
+ bool conv_halo_exchange_always_on_lhs = true,
+ bool choose_faster_windowed_einsum = false) {
// Some tests (BackpropFilter convs) set this flag false to test two
// different paths of the implementation.
SpmdPartitionerOptions options;
options.conv_halo_exchange_always_on_lhs = conv_halo_exchange_always_on_lhs;
options.allow_module_signature_change = true;
+ options.choose_faster_windowed_einsum_over_mem =
+ choose_faster_windowed_einsum;
auto collective_ops_creator =
GetDefaultCollectiveOpsCreator(num_devices, /*num_replicas=*/1);
// Do not use all-gather for pattern-matching purpose, as the partitioner
@@ -6954,6 +6959,267 @@
_, op::AllReduce(op::Select(_, _, gather)), _, _, _, _)));
}
+TEST_F(SpmdPartitioningTest, SortTopKNonSortDimension) {
+ absl::string_view hlo_string = R"(
+HloModule module
+
+%compare-greater-than.42077 (p.0.lhs.42078: f32[],
+ p.0.rhs.42079: f32[], p.1.lhs.42080: s32[], p.1.rhs.42081: s32[]) -> pred[] {
+ %p.0.lhs.42078 = f32[] parameter(0)
+ %bitcast-convert.135 = s32[] bitcast-convert(f32[] %p.0.lhs.42078)
+ %constant.45054 = s32[] constant(0)
+ %compare.133 = pred[] compare(s32[] %bitcast-convert.135,
+ s32[] %constant.45054), direction=LT
+ %constant.45278 = u32[] constant(2147483647)
+ %bitcast-convert.136 = u32[] bitcast-convert(f32[] %p.0.lhs.42078)
+ %subtract.337 = u32[] subtract(u32[] %constant.45278,
+ u32[] %bitcast-convert.136)
+ %bitcast-convert.137 = s32[] bitcast-convert(u32[] %subtract.337)
+ %select.282 = s32[] select(pred[] %compare.133, s32[] %bitcast-convert.137,
+ s32[] %bitcast-convert.135)
+ %p.0.rhs.42079 = f32[] parameter(1)
+ %bitcast-convert.138 = s32[] bitcast-convert(f32[] %p.0.rhs.42079)
+ %compare.134 = pred[] compare(s32[] %bitcast-convert.138,
+ s32[] %constant.45054), direction=LT
+ %bitcast-convert.139 = u32[] bitcast-convert(f32[] %p.0.rhs.42079)
+ %subtract.338 = u32[] subtract(u32[] %constant.45278,
+ u32[] %bitcast-convert.139)
+ %bitcast-convert.140 = s32[] bitcast-convert(u32[] %subtract.338)
+ %select.283 = s32[] select(pred[] %compare.134, s32[] %bitcast-convert.140,
+ s32[] %bitcast-convert.138)
+ %compare.135 = pred[] compare(s32[] %select.282,
+ s32[] %select.283), direction=GT
+ %compare.428 = pred[] compare(s32[] %select.283,
+ s32[] %select.282), direction=GT
+ %compare.429 = pred[] compare(pred[] %compare.135,
+ pred[] %compare.428), direction=EQ
+ %p.1.lhs.42080 = s32[] parameter(2)
+ %p.1.rhs.42081 = s32[] parameter(3)
+ %compare.430 = pred[] compare(s32[] %p.1.lhs.42080,
+ s32[] %p.1.rhs.42081), direction=LT
+ ROOT %select.579 = pred[] select(pred[] %compare.429,
+ pred[] %compare.430, pred[] %compare.135)
+}
+
+ENTRY %module {
+ %parameter.0 = f32[2,64,32128]{2,1,0} parameter(0),
+ sharding={devices=[2,1,4]0,1,2,3,4,5,6,7}
+ %iota = s32[2,64,32128]{2,1,0} iota(), iota_dimension=2,
+ sharding={devices=[2,1,4]0,1,2,3,4,5,6,7}
+ %sort.18 = (f32[2,64,32128]{2,1,0}, s32[2,64,32128]{2,1,0}) sort(
+ f32[2,64,32128]{2,1,0} %parameter.0, s32[2,64,32128]{2,1,0} %iota),
+ dimensions={2}, is_stable=true, to_apply=%compare-greater-than.42077,
+ sharding={{devices=[2,1,4]0,1,2,3,4,5,6,7},
+ {devices=[2,1,4]0,1,2,3,4,5,6,7}}
+ output = f32[2,64,32128]{2,1,0} get-tuple-element(%sort.18), index=0,
+ sharding={devices=[2,1,4]0,1,2,3,4,5,6,7}
+ %slice.0 = f32[2,64,2]{2,1,0} slice(f32[2,64,32128]{2,1,0} output),
+ slice={[0:2], [0:64], [0:2]},
+ sharding={devices=[2,1,4]0,1,2,3,4,5,6,7}
+ output2 = s32[2,64,32128]{2,1,0} get-tuple-element(%sort.18), index=1,
+ sharding={replicated}
+ %slice.1 = s32[2,64,2]{2,1,0} slice(s32[2,64,32128]{2,1,0} output2),
+ slice={[0:2], [0:64], [0:2]},
+ sharding={devices=[2,1,4]0,1,2,3,4,5,6,7}
+ ROOT output.t = (f32[2,64,2]{2,1,0},
+ s32[2,64,2]{2,1,0}) tuple(slice.0, slice.1),
+ sharding={{replicated}, {replicated}}
+})";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ PartitionComputation(hlo_string, /*num_devices=*/8));
+
+ const HloInstruction* sort = FindInstruction(module.get(), "sort");
+ EXPECT_NE(sort, nullptr);
+ auto sort_match =
+ AllOf(op::Shape("(f32[2,64,32128], s32[2,64,32128])"), op::Sort(_, _));
+ EXPECT_THAT(sort, sort_match);
+}
+
+TEST_F(SpmdPartitioningTest, SortTopKPropagateBaseShape) {
+ absl::string_view hlo_string = R"(
+HloModule module
+
+%compare-greater-than.42077 (p.0.lhs.42078: f32[],
+ p.0.rhs.42079: f32[], p.1.lhs.42080: s32[], p.1.rhs.42081: s32[]) -> pred[] {
+ %p.0.lhs.42078 = f32[] parameter(0)
+ %bitcast-convert.135 = s32[] bitcast-convert(f32[] %p.0.lhs.42078)
+ %constant.45054 = s32[] constant(0)
+ %compare.133 = pred[] compare(s32[] %bitcast-convert.135,
+ s32[] %constant.45054), direction=LT
+ %constant.45278 = u32[] constant(2147483647)
+ %bitcast-convert.136 = u32[] bitcast-convert(f32[] %p.0.lhs.42078)
+ %subtract.337 = u32[] subtract(u32[] %constant.45278,
+ u32[] %bitcast-convert.136)
+ %bitcast-convert.137 = s32[] bitcast-convert(u32[] %subtract.337)
+ %select.282 = s32[] select(pred[] %compare.133, s32[] %bitcast-convert.137,
+ s32[] %bitcast-convert.135)
+ %p.0.rhs.42079 = f32[] parameter(1)
+ %bitcast-convert.138 = s32[] bitcast-convert(f32[] %p.0.rhs.42079)
+ %compare.134 = pred[] compare(s32[] %bitcast-convert.138,
+ s32[] %constant.45054), direction=LT
+ %bitcast-convert.139 = u32[] bitcast-convert(f32[] %p.0.rhs.42079)
+ %subtract.338 = u32[] subtract(u32[] %constant.45278,
+ u32[] %bitcast-convert.139)
+ %bitcast-convert.140 = s32[] bitcast-convert(u32[] %subtract.338)
+ %select.283 = s32[] select(pred[] %compare.134, s32[] %bitcast-convert.140,
+ s32[] %bitcast-convert.138)
+ %compare.135 = pred[] compare(s32[] %select.282,
+ s32[] %select.283), direction=GT
+ %compare.428 = pred[] compare(s32[] %select.283,
+ s32[] %select.282), direction=GT
+ %compare.429 = pred[] compare(pred[] %compare.135,
+ pred[] %compare.428), direction=EQ
+ %p.1.lhs.42080 = s32[] parameter(2)
+ %p.1.rhs.42081 = s32[] parameter(3)
+ %compare.430 = pred[] compare(s32[] %p.1.lhs.42080,
+ s32[] %p.1.rhs.42081), direction=LT
+ ROOT %select.579 = pred[] select(pred[] %compare.429,
+ pred[] %compare.430, pred[] %compare.135)
+}
+
+ENTRY %module {
+ %parameter.0 = f32[2,64,32128]{2,1,0} parameter(0),
+ sharding={devices=[1,1,8]0,1,2,3,4,5,6,7}
+ %iota = s32[2,64,32128]{2,1,0} iota(), iota_dimension=2,
+ sharding={devices=[1,1,8]0,1,2,3,4,5,6,7}
+ %sort.18 = (f32[2,64,32128]{2,1,0}, s32[2,64,32128]{2,1,0}) sort(
+ f32[2,64,32128]{2,1,0} %parameter.0, s32[2,64,32128]{2,1,0} %iota),
+ dimensions={2}, is_stable=true, to_apply=%compare-greater-than.42077,
+ sharding={{devices=[1,1,8]0,1,2,3,4,5,6,7},
+ {devices=[1,1,8]0,1,2,3,4,5,6,7}}
+ output = f32[2,64,32128]{2,1,0} get-tuple-element(%sort.18), index=0,
+ sharding={devices=[1,1,8]0,1,2,3,4,5,6,7}
+ %slice.0 = f32[2,64,2]{2,1,0} slice(f32[2,64,32128]{2,1,0} output),
+ slice={[0:2], [0:64], [0:2]},
+ sharding={devices=[1,1,8]0,1,2,3,4,5,6,7}
+ output2 = s32[2,64,32128]{2,1,0} get-tuple-element(%sort.18), index=1,
+ sharding={replicated}
+ %slice.1 = s32[2,64,2]{2,1,0} slice(s32[2,64,32128]{2,1,0} output2),
+ slice={[0:2], [0:64], [0:2]},
+ sharding={devices=[1,1,8]0,1,2,3,4,5,6,7}
+ ROOT output.t = (f32[2,64,2]{2,1,0},
+ s32[2,64,2]{2,1,0}) tuple(slice.0, slice.1),
+ sharding={{replicated}, {replicated}}
+})";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ PartitionComputation(hlo_string, /*num_devices=*/8));
+
+ const HloInstruction* root = module->entry_computation()->root_instruction();
+ auto all_reduce_val =
+ AllOf(op::Shape("f32[2,64,2]"),
+ op::Slice(op::AllReduce(op::DynamicUpdateSlice(_, _, _, _, _))));
+ auto all_reduce_idx =
+ AllOf(op::Shape("s32[2,64,2]"),
+ op::Slice(op::AllReduce(op::DynamicUpdateSlice(_, _, _, _, _))));
+ auto tuple = AllOf(op::Shape("(f32[2,64,2], s32[2,64,2])"),
+ op::Tuple(all_reduce_val, all_reduce_idx));
+ EXPECT_THAT(root, tuple);
+}
+
+TEST_F(SpmdPartitioningTest, GatherIndexOnlyCorrectReplacement) {
+ absl::string_view hlo_string = R"(
+HloModule module
+
+ENTRY %module {
+ %parameter.0 = bf16[1,8,6,6]{3,2,1,0} parameter(0),
+ sharding={replicated}
+ %parameter.1 = s32[2,4]{1,0} parameter(1),
+ sharding={devices=[2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
+ %gather.100 = bf16[2,1,8,1,6]{4,3,2,1,0} gather(
+ bf16[1,8,6,6]{3,2,1,0} %parameter.0, s32[2,4]{1,0} %parameter.1),
+ offset_dims={1,2,3,4}, collapsed_slice_dims={}, start_index_map={0,1,2,3},
+ index_vector_dim=1, slice_sizes={1,8,1,6},
+ sharding={devices=[2,1,4,1,1]0,1,2,3,4,5,6,7}
+ %constant.45590 = s32[] constant(0), sharding={replicated}
+ %broadcast.54515 = s32[2,64,1,1]{3,2,1,0} broadcast(s32[] %constant.45590),
+ dimensions={},
+ sharding={devices=[2,1,1,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
+ ROOT %reshape.4243 = bf16[2,8,6]{2,1,0} reshape(
+ bf16[2,1,8,1,6]{4,3,2,1,0} %gather.100),
+ sharding={devices=[2,4,1]0,1,2,3,4,5,6,7}
+})";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ PartitionComputation(hlo_string, /*num_devices=*/8));
+
+ const HloInstruction* root = module->entry_computation()->root_instruction();
+ auto param0 = AllOf(op::Shape("bf16[1,8,6,6]"), op::Parameter());
+ auto param1 = AllOf(op::Shape("s32[1,4]"), op::Parameter());
+ auto reshape = AllOf(
+ op::Shape("bf16[1,2,6]"),
+ op::Reshape(op::DynamicSlice(op::Gather(param0, param1), _, _, _, _, _)));
+ EXPECT_THAT(root, reshape);
+}
+
+TEST_F(SpmdPartitioningTest, WindowedEinsumPreferMemoryFootprint) {
+ absl::string_view hlo_string = R"(
+HloModule module
+
+ENTRY %module {
+ %parameter.0 = bf16[128,1024,4,4,1152,1,1]{6,5,4,3,2,1,0} parameter(0),
+ sharding={devices=[4,1,2,1,1,1,1]0,1,2,3,4,5,6,7}
+ %parameter.1 = bf16[4,4,1152,4,176,256,1]{6,5,4,3,2,1,0} parameter(1),
+ sharding={devices=[2,2,1,2,1,1,1]0,1,2,3,4,5,6,7}
+ %convolution.3 = bf16[128,1024,4,176,256,1,1]{6,5,4,3,2,1,0}
+ convolution(bf16[128,1024,4,4,1152,1,1]{6,5,4,3,2,1,0} %parameter.0,
+ bf16[4,4,1152,4,176,256,1]{6,5,4,3,2,1,0} %parameter.1),
+ window={size=1x4x176x4x4 pad=0_0x3_3x175_175x0_0x0_0
+ rhs_reversal=0x1x1x0x0}, dim_labels=0b34f12_34i12o0->0b12f34,
+ sharding={devices=[4,1,2,1,1,1,1]0,1,2,3,4,5,6,7}
+ ROOT %reshape.3973 = bf16[128,1024,4,176,256]{4,3,2,1,0}
+ reshape(bf16[128,1024,4,176,256,1,1]{6,5,4,3,2,1,0} %convolution.3),
+ sharding={replicated}
+})";
+ TF_ASSERT_OK_AND_ASSIGN(
+ auto module,
+ PartitionComputation(hlo_string, /*num_devices=*/8,
+ /*conv_halo_exchange_always_on_lhs =*/true,
+ /*choose_faster_windowed_einsum =*/false));
+ const HloInstruction* while_inst = FindInstruction(module.get(), "while");
+ EXPECT_NE(while_inst, nullptr);
+ const HloComputation* cond_comp = while_inst->while_condition();
+ const HloInstruction* root = cond_comp->root_instruction();
+ EXPECT_THAT(root, op::Compare(_, op::Constant()));
+ const HloConstantInstruction* iterations =
+ Cast<HloConstantInstruction>(root->operand(1));
+ EXPECT_TRUE(iterations->literal().GetFirstInteger());
+ EXPECT_EQ(*iterations->literal().GetFirstInteger(), 4);
+}
+
+TEST_F(SpmdPartitioningTest, WindowedEinsumPreferNumberIterations) {
+ absl::string_view hlo_string = R"(
+HloModule module
+
+ENTRY %module {
+ %parameter.0 = bf16[128,1024,4,4,1152,1,1]{6,5,4,3,2,1,0} parameter(0),
+ sharding={devices=[4,1,2,1,1,1,1]0,1,2,3,4,5,6,7}
+ %parameter.1 = bf16[4,4,1152,4,176,256,1]{6,5,4,3,2,1,0} parameter(1),
+ sharding={devices=[2,2,1,2,1,1,1]0,1,2,3,4,5,6,7}
+ %convolution.3 = bf16[128,1024,4,176,256,1,1]{6,5,4,3,2,1,0}
+ convolution(bf16[128,1024,4,4,1152,1,1]{6,5,4,3,2,1,0} %parameter.0,
+ bf16[4,4,1152,4,176,256,1]{6,5,4,3,2,1,0} %parameter.1),
+ window={size=1x4x176x4x4 pad=0_0x3_3x175_175x0_0x0_0
+ rhs_reversal=0x1x1x0x0}, dim_labels=0b34f12_34i12o0->0b12f34,
+ sharding={devices=[4,1,2,1,1,1,1]0,1,2,3,4,5,6,7}
+ ROOT %reshape.3973 = bf16[128,1024,4,176,256]{4,3,2,1,0}
+ reshape(bf16[128,1024,4,176,256,1,1]{6,5,4,3,2,1,0} %convolution.3),
+ sharding={replicated}
+})";
+ TF_ASSERT_OK_AND_ASSIGN(
+ auto module,
+ PartitionComputation(hlo_string, /*num_devices=*/8,
+ /*conv_halo_exchange_always_on_lhs =*/true,
+ /*choose_faster_windowed_einsum =*/true));
+ const HloInstruction* while_inst = FindInstruction(module.get(), "while");
+ EXPECT_NE(while_inst, nullptr);
+ const HloComputation* cond_comp = while_inst->while_condition();
+ const HloInstruction* root = cond_comp->root_instruction();
+ EXPECT_THAT(root, op::Compare(_, op::Constant()));
+ const HloConstantInstruction* iterations =
+ Cast<HloConstantInstruction>(root->operand(1));
+ EXPECT_TRUE(iterations->literal().GetFirstInteger());
+ EXPECT_EQ(*iterations->literal().GetFirstInteger(), 2);
+}
+
} // namespace
} // namespace spmd
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc
index cae327c..e8b16cd 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc
@@ -1231,7 +1231,8 @@
}
// Check if partitioned at sort dimension.
- for (int64 dim : sort->dimensions()) {
+ for (int64 dim = 0; dim < sort->shape().tuple_shapes(0).dimensions_size();
+ ++dim) {
if (sharding.tile_assignment().dim(dim) > 1) {
if (dim != sort_dim) {
return absl::nullopt;
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index 2ea7ca5..2f0f9c6 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -1810,9 +1810,7 @@
// dot(reshape(A), reshape(transpose(reshape(Const)))),
// and then fold the reshape and transpose on the Const side.
// We can compare performance with and without algsimp pass to see the impact.
-void DOT_ReorderContracting(int num_iters) {
- tensorflow::testing::StopTiming();
-
+void DOT_ReorderContracting(::testing::benchmark::State& state) {
se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie();
auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie();
se::StreamExecutorMemoryAllocator allocator(platform, executors);
@@ -1864,16 +1862,13 @@
}
const int64 total_bytes = d0 * d1 * d2 + d1 * d2 * d3 + d0 * d3;
- tensorflow::testing::BytesProcessed(static_cast<int64>(num_iters) *
- total_bytes * sizeof(float));
- tensorflow::testing::UseRealTime();
- tensorflow::testing::StartTiming();
- for (int i = 0; i < num_iters; ++i) {
+ for (auto s : state) {
ASSERT_IS_OK(executable->Run({&buffer0}, options));
}
+ state.SetBytesProcessed(state.iterations() * total_bytes * sizeof(float));
}
-BENCHMARK(DOT_ReorderContracting);
+BENCHMARK(DOT_ReorderContracting)->UseRealTime();
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc
index 0bf5d0e..0317c41 100644
--- a/tensorflow/compiler/xla/tools/replay_computation.cc
+++ b/tensorflow/compiler/xla/tools/replay_computation.cc
@@ -86,7 +86,7 @@
// Command-line opts to this tool. See main() for descriptions of these
// fields.
struct Options {
- Options() : intra_op_thread_pool_size(tensorflow::port::MaxParallelism()) {}
+ Options() {}
bool NeedsRealData() const { return !use_fake_data && !compile_only; }
@@ -106,7 +106,7 @@
bool print_result = true;
int num_runs = 1;
- int intra_op_thread_pool_size;
+ int intra_op_thread_pool_size = -1;
bool compile_only = false;
};
@@ -173,7 +173,7 @@
if (!fake_xfeed_shape.empty()) {
xfeed_shape = std::move(ParseShape(fake_xfeed_shape)).ValueOrDie();
} else if (generate_fake_xfeed) {
- CHECK_LT(xfeed_instrs.size(), 2)
+ QCHECK_LT(xfeed_instrs.size(), 2)
<< "--generate_fake_" << xfeed_name
<< " only works if the model has 0 or 1 " << xfeed_name << " ops.";
if (xfeed_instrs.empty()) {
@@ -196,7 +196,7 @@
<< " ops, but this model has " << xfeed_instrs.size()
<< " of them:";
log_xfeed_instrs();
- LOG(FATAL) << "Can't run model with --generate_fake_infeed.";
+ LOG(QFATAL) << "Can't run model with --generate_fake_infeed.";
}
} else if (!xfeed_instrs.empty()) {
LOG(ERROR) << "Model contains " << xfeed_instrs.size() << " " << xfeed_name
@@ -314,8 +314,11 @@
if (xla_hlo_profile && is_final_result) {
LOG(INFO) << "\n\n***** Final run below ******";
}
+ int thread_pool_size = opts.intra_op_thread_pool_size < 0
+ ? tensorflow::port::MaxParallelism()
+ : opts.intra_op_thread_pool_size;
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen",
- opts.intra_op_thread_pool_size);
+ thread_pool_size);
Eigen::ThreadPoolDevice thread_pool(pool.AsEigenThreadPool(),
pool.NumThreads());
@@ -366,10 +369,10 @@
LOG(ERROR) << "Encountered bad proto";
}
}
- CHECK(!snapshots.empty())
+ QCHECK(!snapshots.empty())
<< "No proto is successfully parsed from the file - the file possibly "
"has a mismatched compression option, format, etc.";
- CHECK(!opts.NeedsRealData())
+ QCHECK(!opts.NeedsRealData())
<< "Without --use_fake_data or --compile_only, you must pass an "
"HloSnapshot -- HloProto and textual HLO don't carry real data.";
return snapshots;
@@ -387,7 +390,7 @@
if (s.code() == tensorflow::error::NOT_FOUND) {
return s;
}
- CHECK(!opts.NeedsRealData())
+ QCHECK(!opts.NeedsRealData())
<< "Without --use_fake_data or --compile_only, you must pass an "
"HloSnapshot -- HloProto and textual HLO don't carry real data.";
fprintf(stderr, "%s: is not HloSnapshot. Trying HloProto.\n",
diff --git a/tensorflow/core/api_def/base_api/api_def_Relu.pbtxt b/tensorflow/core/api_def/base_api/api_def_Relu.pbtxt
index f8460ec..fb2e27e 100644
--- a/tensorflow/core/api_def/base_api/api_def_Relu.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_Relu.pbtxt
@@ -4,7 +4,7 @@
description: <<END
See: https://en.wikipedia.org/wiki/Rectifier_(neural_networks)
Example usage:
->>> tf.nn.relu([-2., 0., -0., 3.]).numpy()
-array([ 0., 0., -0., 3.], dtype=float32)
+>>> tf.nn.relu([-2., 0., 3.]).numpy()
+array([0., 0., 3.], dtype=float32)
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_UniqueWithCountsV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_UniqueWithCountsV2.pbtxt
index e21f56b..4f42650 100644
--- a/tensorflow/core/api_def/base_api/api_def_UniqueWithCountsV2.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_UniqueWithCountsV2.pbtxt
@@ -48,33 +48,33 @@
For example:
```
-# tensor 'x' is [1, 1, 2, 4, 4, 4, 7, 8, 8]
-y, idx, count = unique_with_counts(x)
+x = tf.constant([1, 1, 2, 4, 4, 4, 7, 8, 8])
+y, idx, count = UniqueWithCountsV2(x, axis = [0])
y ==> [1, 2, 4, 7, 8]
idx ==> [0, 0, 1, 2, 2, 2, 3, 4, 4]
count ==> [2, 1, 3, 1, 2]
```
-For an `2-D` tensor `x` with `axis = 0`:
+For a `2-D` tensor `x` with `axis = 0`:
```
-# tensor 'x' is [[1, 0, 0],
-# [1, 0, 0],
-# [2, 0, 0]]
-y, idx, count = unique_with_counts(x, axis=0)
+x = tf.constant([[1, 0, 0],
+ [1, 0, 0],
+ [2, 0, 0]])
+y, idx, count = UniqueWithCountsV2(x, axis=[0])
y ==> [[1, 0, 0],
[2, 0, 0]]
idx ==> [0, 0, 1]
count ==> [2, 1]
```
-For an `2-D` tensor `x` with `axis = 1`:
+For a `2-D` tensor `x` with `axis = 1`:
```
-# tensor 'x' is [[1, 0, 0],
-# [1, 0, 0],
-# [2, 0, 0]]
-y, idx, count = unique_with_counts(x, axis=1)
+x = tf.constant([[1, 0, 0],
+ [1, 0, 0],
+ [2, 0, 0]])
+y, idx, count = UniqueWithCountsV2(x, axis=[1])
y ==> [[1, 0],
[1, 0],
[2, 0]]
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index 7fe6e00..56e1f33 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -1067,9 +1067,24 @@
CollectiveExecutorMgrInterface* rpc_collective_executor_mgr) {
collective_executor_mgr_.Reset(rpc_collective_executor_mgr);
- local_device_manager_.Reset(device_mgr);
+ if (device_mgr != local_device_manager_.Get()) {
+ if (local_device_manager_.Owned()) {
+ old_local_device_managers_.push_back(
+ std::move(local_device_manager_.owned_object));
+ }
+ local_device_manager_.Reset(device_mgr);
+ }
host_cpu_device_ = local_device_manager_.Get()->HostCPU();
+ if (reuse_rendezvous_for_functions_) {
+ // If reuse_rendezvous_for_functions_ is true, CreateRendezvous is
+ // idempotent and ignores its step_id argument. Create a rendezvous now to
+ // replace the old one, preventing the old one from getting used.
+ if (rendezvous_ != nullptr) rendezvous_->Unref();
+ rendezvous_ = CreateRendezvous(/*step_id=*/-1);
+ return errors::Aborted("Cannot create a valid rendezvous.");
+ }
+
InitPrioritizedDeviceTypeList();
ClearCachesAndThreadExecutors();
default_executor_.ClearError();
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc
index 659d296..c837f1d 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc
@@ -516,6 +516,22 @@
if (on_same_task) {
continue;
}
+ // Compare with default_device if it has a narrower scope matching
+ // requested device.
+ int colocated_on_default_device = 0;
+ for (int i = 0; i < matching_devices.size(); ++i) {
+ if (DeviceNameUtils::IsSameAddressSpace(
+ default_device->parsed_name(),
+ matching_devices.at(i)->parsed_name())) {
+ colocated_on_default_device++;
+ }
+ }
+ // Continue to raise error if multiple colocated devices are
+ // found.
+ if (colocated_on_default_device == 1) {
+ continue;
+ }
+
// Convert a vector of devices to a string.
// Using absl::StrJoin did not work in Android builds.
string devices = "[";
diff --git a/tensorflow/core/data/service/data_service.cc b/tensorflow/core/data/service/data_service.cc
index ed6d68a..5120c17 100644
--- a/tensorflow/core/data/service/data_service.cc
+++ b/tensorflow/core/data/service/data_service.cc
@@ -233,8 +233,31 @@
CredentialsFactory::CreateClientCredentials(protocol_, &credentials));
grpc::ChannelArguments args;
args.SetMaxReceiveMessageSize(std::numeric_limits<int32>::max());
+ args.SetInt(GRPC_ARG_USE_LOCAL_SUBCHANNEL_POOL, true);
auto channel = grpc::CreateCustomChannel(address_, credentials, args);
stub_ = DispatcherService::NewStub(channel);
+ GetVersionRequest req;
+ GetVersionResponse resp;
+ TF_RETURN_IF_ERROR(grpc_util::Retry(
+ [&] {
+ grpc::ClientContext ctx;
+ grpc::Status s = stub_->GetVersion(&ctx, req, &resp);
+ if (!s.ok()) {
+ return grpc_util::WrapError("Failed to get dispatcher version", s);
+ }
+ return Status::OK();
+ },
+ "checking service version",
+ /*deadline_micros=*/kint64max));
+ if (resp.version() != kDataServiceVersion) {
+ return errors::FailedPrecondition(
+ "Version mismatch with tf.data service server. The server is running "
+ "version ",
+ resp.version(), ", while the client is running version ",
+ kDataServiceVersion,
+ ". Please ensure that the client and server side are running the "
+ "same version of TensorFlow.");
+ }
return Status::OK();
}
diff --git a/tensorflow/core/data/service/data_service.h b/tensorflow/core/data/service/data_service.h
index 150750c..a101972 100644
--- a/tensorflow/core/data/service/data_service.h
+++ b/tensorflow/core/data/service/data_service.h
@@ -28,6 +28,10 @@
namespace tensorflow {
namespace data {
+// Increment this when making backwards-incompatible changes to communication
+// between tf.data servers.
+constexpr int kDataServiceVersion = 1;
+
// Modes for how a tf.data service job should process a dataset.
enum class ProcessingMode : int64 {
UNSET = 0,
diff --git a/tensorflow/core/data/service/dispatcher.proto b/tensorflow/core/data/service/dispatcher.proto
index 91f3c5c..efcddee 100644
--- a/tensorflow/core/data/service/dispatcher.proto
+++ b/tensorflow/core/data/service/dispatcher.proto
@@ -48,6 +48,12 @@
bool end_of_splits = 2;
}
+message GetVersionRequest {}
+
+message GetVersionResponse {
+ int64 version = 1;
+}
+
message GetOrRegisterDatasetRequest {
// The dataset to register.
DatasetDef dataset = 1;
@@ -146,6 +152,9 @@
// Gets the next split for a given job.
rpc GetSplit(GetSplitRequest) returns (GetSplitResponse);
+ // Returns the API version of the server.
+ rpc GetVersion(GetVersionRequest) returns (GetVersionResponse);
+
// Registers a dataset with the server, or returns its id if it is already
// registered.
//
diff --git a/tensorflow/core/data/service/dispatcher_impl.cc b/tensorflow/core/data/service/dispatcher_impl.cc
index 61d6c1e..b2dbe51 100644
--- a/tensorflow/core/data/service/dispatcher_impl.cc
+++ b/tensorflow/core/data/service/dispatcher_impl.cc
@@ -361,6 +361,12 @@
return Status::OK();
}
+Status DataServiceDispatcherImpl::GetVersion(const GetVersionRequest* request,
+ GetVersionResponse* response) {
+ response->set_version(kDataServiceVersion);
+ return Status::OK();
+}
+
Status DataServiceDispatcherImpl::GetOrRegisterDataset(
const GetOrRegisterDatasetRequest* request,
GetOrRegisterDatasetResponse* response) {
diff --git a/tensorflow/core/data/service/dispatcher_impl.h b/tensorflow/core/data/service/dispatcher_impl.h
index 79e8b3a..8dae65e 100644
--- a/tensorflow/core/data/service/dispatcher_impl.h
+++ b/tensorflow/core/data/service/dispatcher_impl.h
@@ -135,6 +135,8 @@
Status GetSplit(const GetSplitRequest* request, GetSplitResponse* response);
/// Client-facing API.
+ Status GetVersion(const GetVersionRequest* request,
+ GetVersionResponse* response);
Status GetOrRegisterDataset(const GetOrRegisterDatasetRequest* request,
GetOrRegisterDatasetResponse* response);
Status GetOrCreateJob(const GetOrCreateJobRequest* request,
diff --git a/tensorflow/core/data/service/grpc_dispatcher_impl.cc b/tensorflow/core/data/service/grpc_dispatcher_impl.cc
index 4426069..64edc44 100644
--- a/tensorflow/core/data/service/grpc_dispatcher_impl.cc
+++ b/tensorflow/core/data/service/grpc_dispatcher_impl.cc
@@ -44,6 +44,7 @@
HANDLER(WorkerUpdate);
HANDLER(GetDatasetDef);
HANDLER(GetSplit);
+HANDLER(GetVersion);
HANDLER(GetOrRegisterDataset);
HANDLER(ReleaseJobClient);
HANDLER(GetOrCreateJob);
diff --git a/tensorflow/core/data/service/grpc_dispatcher_impl.h b/tensorflow/core/data/service/grpc_dispatcher_impl.h
index 0feeb23..294898a 100644
--- a/tensorflow/core/data/service/grpc_dispatcher_impl.h
+++ b/tensorflow/core/data/service/grpc_dispatcher_impl.h
@@ -43,6 +43,7 @@
HANDLER(WorkerUpdate);
HANDLER(GetDatasetDef);
HANDLER(GetSplit);
+ HANDLER(GetVersion);
HANDLER(GetOrRegisterDataset);
HANDLER(ReleaseJobClient);
HANDLER(GetOrCreateJob);
diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc
index 8053b61..6772212 100644
--- a/tensorflow/core/framework/model.cc
+++ b/tensorflow/core/framework/model.cc
@@ -1638,13 +1638,38 @@
void Model::Optimize(AutotuneAlgorithm algorithm, int64 cpu_budget,
int64 ram_budget, double model_input_time) {
+ std::shared_ptr<Node> snapshot;
+ {
+ tf_shared_lock lock(mu_);
+ snapshot = output_->Snapshot();
+ }
+ OptimizationParams optimization_params;
+ optimization_params.set_algorithm(algorithm);
+ optimization_params.set_cpu_budget(cpu_budget);
+ optimization_params.set_ram_budget(ram_budget);
+ optimization_params.set_model_input_time(model_input_time);
switch (algorithm) {
case AutotuneAlgorithm::HILL_CLIMB:
- OptimizeHillClimb(cpu_budget, ram_budget, model_input_time);
+ OptimizeHillClimb(snapshot, optimization_params);
break;
case AutotuneAlgorithm::GRADIENT_DESCENT:
- OptimizeGradientDescent(cpu_budget, ram_budget, model_input_time);
+ OptimizeGradientDescent(snapshot, optimization_params);
break;
+ default:
+ VLOG(2) << "Autotuning algorithm was not recognized. Aborting "
+ "optimization.";
+ return;
+ }
+ if (!save_dir_.empty()) {
+ mutex_lock lock(mu_);
+ Status status = EnsureSaveLoopThreadStarted();
+ if (status.ok() && save_buffer_.size() < kMaxNumBufferedOptimizeArgs) {
+ save_buffer_.push_back(std::make_pair(snapshot, optimization_params));
+ save_cond_var_.notify_all();
+ } else if (save_buffer_.size() >= kMaxNumBufferedOptimizeArgs) {
+ VLOG(3) << "Saved snapshots buffer is full. Current snapshot and "
+ "optimization parameters will not be saved.";
+ }
}
}
@@ -1707,7 +1732,7 @@
cancellation_manager,
[this]() {
mutex_lock l(mu_);
- cond_var_.notify_all();
+ optimize_cond_var_.notify_all();
},
/*deregister_fn=*/&unused));
@@ -1721,7 +1746,7 @@
auto wait_ms =
last_optimization_ms + optimization_period_ms_ - current_time_ms;
VLOG(2) << "Waiting for " << wait_ms << " ms.";
- cond_var_.wait_for(l, std::chrono::milliseconds(wait_ms));
+ optimize_cond_var_.wait_for(l, std::chrono::milliseconds(wait_ms));
current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
}
if (cancellation_manager->IsCancelled()) {
@@ -1747,13 +1772,9 @@
}
}
-void Model::OptimizeGradientDescent(int64 cpu_budget, int64 ram_budget,
- double model_input_time) {
- std::shared_ptr<Node> snapshot;
- {
- tf_shared_lock lock(mu_);
- snapshot = output_->Snapshot();
- }
+void Model::OptimizeGradientDescent(
+ std::shared_ptr<Node> snapshot,
+ const OptimizationParams& optimization_params) {
VLOG(2) << "Starting optimization of tunable parameters with Gradient "
"Descent.";
auto parameters = CollectTunableParameters(snapshot);
@@ -1788,13 +1809,15 @@
// and we only increase the buffer size parameters.
bool cpu_budget_reached = false;
- for (int i = 0;
- i < kMaxIterations &&
- !ShouldStop(cpu_budget, ram_budget, parameters, parallelism_parameters,
- buffer_size_parameters, snapshot, &cpu_budget_reached);
+ for (int i = 0; i < kMaxIterations &&
+ !ShouldStop(optimization_params.cpu_budget(),
+ optimization_params.ram_budget(), parameters,
+ parallelism_parameters, buffer_size_parameters,
+ snapshot, &cpu_budget_reached);
++i) {
absl::flat_hash_map<string, double> gradients;
- new_output_time = OutputTime(snapshot, model_input_time, &gradients);
+ new_output_time = OutputTime(
+ snapshot, optimization_params.model_input_time(), &gradients);
// We also terminate once the improvement of the output latency is too
// small.
if (std::abs(output_time - new_output_time) < kOptimizationPrecision) {
@@ -1812,13 +1835,8 @@
UpdateStateValues(¶meters);
}
-void Model::OptimizeHillClimb(int64 cpu_budget, int64 ram_budget,
- double model_input_time) {
- std::shared_ptr<Node> snapshot;
- {
- tf_shared_lock lock(mu_);
- snapshot = output_->Snapshot();
- }
+void Model::OptimizeHillClimb(std::shared_ptr<Node> snapshot,
+ const OptimizationParams& optimization_params) {
VLOG(2) << "Starting optimization of tunable parameters with Hill Climb.";
const double processing_time = TotalProcessingTime(snapshot);
auto parameters = CollectTunableParameters(snapshot);
@@ -1838,7 +1856,8 @@
}
while (true) {
const double output_time =
- OutputTime(snapshot, model_input_time, /*gradients=*/nullptr);
+ OutputTime(snapshot, optimization_params.model_input_time(),
+ /*gradients=*/nullptr);
bool all_max = true;
for (auto& pair : parameters) {
if (pair.second->value < pair.second->max) {
@@ -1846,8 +1865,10 @@
break;
}
}
- if (output_time < processing_time / cpu_budget || all_max ||
- TotalMaximumBufferedBytes(snapshot) > ram_budget) {
+ if (output_time < processing_time / optimization_params.cpu_budget() ||
+ all_max ||
+ TotalMaximumBufferedBytes(snapshot) >
+ optimization_params.ram_budget()) {
break;
}
double best_delta = -1.0L;
@@ -1858,7 +1879,8 @@
}
pair.second->value++;
double new_output_time =
- OutputTime(snapshot, model_input_time, /*gradients=*/nullptr);
+ OutputTime(snapshot, optimization_params.model_input_time(),
+ /*gradients=*/nullptr);
double delta = output_time - new_output_time;
if (delta > best_delta &&
(delta > kBufferSizeMinDelta || pair.second->name != kBufferSize)) {
@@ -1930,6 +1952,72 @@
return Status::OK();
}
+Status Model::Save(const string& fname, std::shared_ptr<Node> snapshot,
+ const OptimizationParams& optimization_params) {
+ ModelProto model_proto;
+ std::unique_ptr<Model> model_snapshot = std::make_unique<Model>();
+ {
+ mutex_lock lock(model_snapshot->mu_);
+ model_snapshot->output_ = std::move(snapshot);
+ model_snapshot->id_counter_ = id_counter_;
+ model_snapshot->collect_resource_usage_.store(collect_resource_usage_);
+ }
+ TF_RETURN_IF_ERROR(model_snapshot->ToProto(&model_proto));
+ OptimizationParams* saved_optimization_params =
+ model_proto.mutable_optimization_params();
+ *saved_optimization_params = optimization_params;
+ return WriteBinaryProto(Env::Default(), fname, model_proto);
+}
+
+Status Model::Load(const string& fname, std::unique_ptr<Model>* model,
+ OptimizationParams* optimization_params) {
+ ModelProto model_proto;
+ TF_RETURN_IF_ERROR(ReadBinaryProto(Env::Default(), fname, &model_proto));
+ TF_RETURN_IF_ERROR(FromProto(model_proto, model));
+ const OptimizationParams restored_optimization_params =
+ model_proto.optimization_params();
+ *optimization_params = restored_optimization_params;
+ return Status::OK();
+}
+
+Status Model::EnsureSaveLoopThreadStarted() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (!save_thread_) {
+ save_thread_ = absl::WrapUnique(
+ Env::Default()->StartThread({}, "tf_data_model_save", [this]() {
+ Status status = SaveLoop();
+ if (!status.ok()) {
+ VLOG(2) << "Model save loop failed: " << status.ToString();
+ }
+ }));
+ }
+ return Status::OK();
+}
+
+Status Model::SaveLoop() {
+ TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(save_dir_));
+ while (true) {
+ std::pair<std::shared_ptr<Node>, OptimizationParams> to_save;
+ {
+ mutex_lock l(mu_);
+ while (!save_thread_cancelled_ && save_buffer_.empty()) {
+ save_cond_var_.wait(l);
+ }
+ if (save_thread_cancelled_) {
+ return Status::OK();
+ }
+ to_save = save_buffer_.front();
+ save_buffer_.pop_front();
+ }
+ string model_name =
+ absl::StrCat("autotune_model_",
+ Hash64Combine(static_cast<uint64>(EnvTime::NowMicros()),
+ reinterpret_cast<uint64>(this)));
+ string fname = io::JoinPath(save_dir_, model_name);
+ TF_RETURN_IF_ERROR(Save(fname, to_save.first, to_save.second));
+ VLOG(2) << "Model was saved as " << fname;
+ }
+}
+
} // namespace model
} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/framework/model.h b/tensorflow/core/framework/model.h
index d3a9cf5..fdd975a 100644
--- a/tensorflow/core/framework/model.h
+++ b/tensorflow/core/framework/model.h
@@ -35,6 +35,7 @@
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/path.h"
namespace tensorflow {
namespace data {
@@ -48,11 +49,6 @@
// A key used to identify the input time of the model.
constexpr char kModelInputTimeKey[] = "model_input_time";
-enum class AutotuneAlgorithm {
- HILL_CLIMB = 0,
- GRADIENT_DESCENT = 1,
-};
-
enum class TraversalOrder {
BFS = 0,
REVERSE_BFS = 1,
@@ -641,10 +637,24 @@
// implementation of `DatasetBase` and `DatasetBaseIterator` respectively.
class Model {
public:
+ using OptimizationParams = ModelProto::OptimizationParams;
+
// Creates a new model.
Model()
: collect_resource_usage_(false),
- optimization_period_ms_(kOptimizationPeriodMinMs) {}
+ optimization_period_ms_(kOptimizationPeriodMinMs) {
+ const char* save_dir = std::getenv("TF_DATA_AUTOTUNE_DEBUG_DIR");
+ if (save_dir) {
+ save_dir_ = string(save_dir);
+ }
+ }
+
+ ~Model() {
+ if (!save_dir_.empty()) {
+ save_thread_cancelled_ = true;
+ save_cond_var_.notify_all();
+ }
+ }
// Indicates whether to collect resource usage.
bool collect_resource_usage() const { return collect_resource_usage_; }
@@ -664,7 +674,7 @@
// autotuning optimization.
//
// To terminate the execution of the optimization loop, the caller needs to
- // to invoke `cancellation_mgr->StartCancel()`.
+ // invoke `cancellation_mgr->StartCancel()`.
Status OptimizeLoop(AutotuneAlgorithm algorithm, int64 cpu_budget,
int64 ram_budget, CancellationManager* cancellation_mgr);
@@ -683,11 +693,24 @@
static Status FromProto(ModelProto model_proto,
std::unique_ptr<Model>* model);
+ // Saves this model with a given snapshot and its optimization parameters to a
+ // file. Note that the file directory must already exist.
+ Status Save(const string& fname, std::shared_ptr<Node> snapshot,
+ const OptimizationParams& optimization_params);
+
+ // Loads a model and its optimization parameters from a file with the given
+ // name.
+ static Status Load(const string& fname, std::unique_ptr<Model>* model,
+ OptimizationParams* optimization_params);
+
private:
static constexpr int64 kOptimizationPeriodMinMs = 10;
static constexpr int64 kOptimizationPeriodMaxMs =
60 * EnvTime::kSecondsToMillis;
+ // Maximum number of optimization snapshots kept in a buffer for saving.
+ static constexpr int64 kMaxNumBufferedOptimizeArgs = 100;
+
// Collects tunable parameters in the tree rooted in the given node, returning
// a mapping from a (unique) node name to a tunable parameter.
absl::flat_hash_map<string, std::shared_ptr<Parameter>>
@@ -702,8 +725,8 @@
// This process is repeated until all parameters reach their maximum values or
// the projected output time is less than or equal to the processing time
// needed to produce an element divided by CPU budget.
- void OptimizeHillClimb(int64 cpu_budget, int64 ram_budget,
- double model_input_time);
+ void OptimizeHillClimb(std::shared_ptr<Node> snapshot,
+ const OptimizationParams& optimization_params);
// This optimization algorithm starts by setting all tunable parallelism
// parameters to the minimum value. It then improves current parameters by
@@ -712,8 +735,8 @@
// repeated until either the output time improvement is smaller than threshold
// value or the output time is less than the processing time needed to produce
// an element divided by CPU budget.
- void OptimizeGradientDescent(int64 cpu_budget, int64 ram_budget,
- double model_input_time);
+ void OptimizeGradientDescent(std::shared_ptr<Node> snapshot,
+ const OptimizationParams& optimization_params);
// Collects the output time and if `gradients` is not `nullptr`, the output
// time gradient w.r.t. tunable parameters of the subtree rooted in the given
@@ -746,12 +769,21 @@
// buffers were full.
double TotalMaximumBufferedBytes(std::shared_ptr<Node> node);
+ // Starts a model saving thread if it hasn't started yet.
+ Status EnsureSaveLoopThreadStarted();
+
+ // Periodically saves the state of optimization that is kept in
+ // `save_buffer_`.
+ //
+ // The saving loop is terminated when the model is destroyed.
+ Status SaveLoop();
+
// Used for coordination between different input pipeline threads. Exclusive
// access is required only when adding or removing nodes. Concurrent access to
// existing nodes is protected by a node mutex.
mutex mu_;
// Used for coordinating the optimization loop and model modifications.
- condition_variable cond_var_;
+ condition_variable optimize_cond_var_;
int64 id_counter_ TF_GUARDED_BY(mu_) = 1;
std::shared_ptr<Node> output_ TF_GUARDED_BY(mu_);
@@ -766,6 +798,25 @@
// Determines the time the optimization loop should wait between
// running optimizations.
int64 optimization_period_ms_ TF_GUARDED_BY(mu_);
+
+ // Thread that runs the model saving loop.
+ std::unique_ptr<Thread> save_thread_ TF_GUARDED_BY(mu_);
+
+ // Used for coordinating the saving loop and model optimization.
+ condition_variable save_cond_var_;
+
+ // Indicates whether the save thread is cancelled.
+ bool save_thread_cancelled_ = false;
+
+ // Contains path to the model saving directory if saving is enabled, empty
+ // otherwise.
+ string save_dir_;
+
+ // Contains pairs of model snapshots and optimization parameters to be saved
+ // if model saving is enabled, empty otherwise. Buffer elements are pushed by
+ // `OptimizeLoop` and popped by `SaveLoop`.
+ std::deque<std::pair<std::shared_ptr<Node>, OptimizationParams>> save_buffer_
+ TF_GUARDED_BY(mu_);
};
} // namespace model
diff --git a/tensorflow/core/framework/model.proto b/tensorflow/core/framework/model.proto
index a0e2565..ba74d7a 100644
--- a/tensorflow/core/framework/model.proto
+++ b/tensorflow/core/framework/model.proto
@@ -14,6 +14,12 @@
UNKNOWN_RATIO = 5;
}
+// Algorithm used for model autotuning optimization.
+enum AutotuneAlgorithm {
+ HILL_CLIMB = 0;
+ GRADIENT_DESCENT = 1;
+}
+
// Protocol buffer representing the data used by the autotuning modeling
// framework.
message ModelProto {
@@ -103,4 +109,22 @@
// Indicates whether the modeling framework should collect resource usage,
// e.g. CPU, memory.
bool collect_resource_usage = 3;
+
+ // Contains parameters of the model autotuning optimization.
+ message OptimizationParams {
+ // Algorithm used for autotuning optimization.
+ AutotuneAlgorithm algorithm = 1;
+
+ // Number of available logical threads.
+ int64 cpu_budget = 2;
+
+ // Amount of available memory in bytes.
+ int64 ram_budget = 3;
+
+ // Time between two consecutive `GetNext` calls to the iterator represented
+ // by the output node.
+ double model_input_time = 4;
+ }
+
+ OptimizationParams optimization_params = 4;
}
diff --git a/tensorflow/core/framework/model_test.cc b/tensorflow/core/framework/model_test.cc
index 2408b4d..7c0b3cb 100644
--- a/tensorflow/core/framework/model_test.cc
+++ b/tensorflow/core/framework/model_test.cc
@@ -885,7 +885,7 @@
}
}
-TEST(SerializeModelTest, Model) {
+TEST(SaveModelTest, Model) {
model::Model model;
std::shared_ptr<Node> root = model::MakeUnknownNode({0, "unknown0", nullptr});
model.AddNode([&root](model::Node::Args args) { return root; }, root->name(),
@@ -941,13 +941,29 @@
current = input;
}
- // Make ToProto->FromProto roundtrip.
- ModelProto model_proto;
- Status status = model.ToProto(&model_proto);
- TF_ASSERT_OK(status);
+ // Make Save->Load roundtrip.
+ ModelProto::OptimizationParams optimization_params;
+ optimization_params.set_algorithm(AutotuneAlgorithm::GRADIENT_DESCENT);
+ optimization_params.set_cpu_budget(64);
+ optimization_params.set_ram_budget(1024);
+ optimization_params.set_model_input_time(43653.34534);
+ TF_ASSERT_OK(model.Save("/tmp/autotune_model_test",
+ model.output()->Snapshot(), optimization_params));
+
std::unique_ptr<model::Model> restored_model;
- status = model::Model::FromProto(model_proto, &restored_model);
- TF_ASSERT_OK(status);
+ ModelProto::OptimizationParams restored_optimization_params;
+ TF_ASSERT_OK(model.Load("/tmp/autotune_model_test", &restored_model,
+ &restored_optimization_params));
+
+ // Check optimization parameters.
+ EXPECT_EQ(optimization_params.algorithm(),
+ restored_optimization_params.algorithm());
+ EXPECT_EQ(optimization_params.cpu_budget(),
+ restored_optimization_params.cpu_budget());
+ EXPECT_EQ(optimization_params.ram_budget(),
+ restored_optimization_params.ram_budget());
+ EXPECT_EQ(optimization_params.model_input_time(),
+ restored_optimization_params.model_input_time());
// Check that original and restored models hold the same data.
EXPECT_EQ(model.collect_resource_usage(),
diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc
index 683dac0..e5eb512 100644
--- a/tensorflow/core/framework/tensor.cc
+++ b/tensorflow/core/framework/tensor.cc
@@ -650,6 +650,12 @@
RefIfNonNull(buf);
}
+Tensor::Tensor(DataType type, const TensorShape& shape,
+ core::RefCountPtr<TensorBuffer> buf)
+ : shape_(shape), buf_(buf.release()) {
+ set_dtype(type);
+}
+
bool Tensor::IsInitialized() const {
return (buf_ != nullptr && buf_->data() != nullptr) ||
shape_.num_elements() == 0;
diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h
index 062e541..33a240d 100644
--- a/tensorflow/core/framework/tensor.h
+++ b/tensorflow/core/framework/tensor.h
@@ -157,6 +157,12 @@
/// Acquires a ref on buf that belongs to this Tensor.
Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf);
+ /// \brief Creates a tensor with the input datatype, shape and buf.
+ ///
+ /// Takes an ownership of the bufffer from the reference counted pointer.
+ Tensor(DataType type, const TensorShape& shape,
+ core::RefCountPtr<TensorBuffer> buf);
+
/// \brief Creates an empty Tensor of the given data type.
///
/// Like Tensor(), returns a 1-dimensional, 0-element Tensor with
diff --git a/tensorflow/core/framework/variant_op_registry.cc b/tensorflow/core/framework/variant_op_registry.cc
index aa3bdea..c63f1a3 100644
--- a/tensorflow/core/framework/variant_op_registry.cc
+++ b/tensorflow/core/framework/variant_op_registry.cc
@@ -26,6 +26,26 @@
namespace tensorflow {
+const char* VariantUnaryOpToString(VariantUnaryOp op) {
+ switch (op) {
+ case INVALID_VARIANT_UNARY_OP:
+ return "INVALID";
+ case ZEROS_LIKE_VARIANT_UNARY_OP:
+ return "ZEROS_LIKE";
+ case CONJ_VARIANT_UNARY_OP:
+ return "CONJ";
+ }
+}
+
+const char* VariantBinaryOpToString(VariantBinaryOp op) {
+ switch (op) {
+ case INVALID_VARIANT_BINARY_OP:
+ return "INVALID";
+ case ADD_VARIANT_BINARY_OP:
+ return "ADD";
+ }
+}
+
std::unordered_set<string>* UnaryVariantOpRegistry::PersistentStringStorage() {
static std::unordered_set<string>* string_storage =
new std::unordered_set<string>();
diff --git a/tensorflow/core/framework/variant_op_registry.h b/tensorflow/core/framework/variant_op_registry.h
index edfb9c5..6095407 100644
--- a/tensorflow/core/framework/variant_op_registry.h
+++ b/tensorflow/core/framework/variant_op_registry.h
@@ -44,11 +44,15 @@
CONJ_VARIANT_UNARY_OP = 2,
};
+const char* VariantUnaryOpToString(VariantUnaryOp op);
+
enum VariantBinaryOp {
INVALID_VARIANT_BINARY_OP = 0,
ADD_VARIANT_BINARY_OP = 1,
};
+const char* VariantBinaryOpToString(VariantBinaryOp op);
+
enum VariantDeviceCopyDirection {
INVALID_DEVICE_COPY_DIRECTION = 0,
HOST_TO_DEVICE = 1,
@@ -311,9 +315,10 @@
UnaryVariantOpRegistry::VariantUnaryOpFn* unary_op_fn =
UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeId());
if (unary_op_fn == nullptr) {
- return errors::Internal(
- "No unary variant unary_op function found for unary variant op enum: ",
- op, " Variant type_name: ", v.TypeName(), " for device type: ", device);
+ return errors::Internal("No unary variant unary_op function found for op ",
+ VariantUnaryOpToString(op),
+ " Variant type_name: ", v.TypeName(),
+ " for device type: ", device);
}
return (*unary_op_fn)(ctx, v, v_out);
}
@@ -340,11 +345,10 @@
UnaryVariantOpRegistry::VariantBinaryOpFn* binary_op_fn =
UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeId());
if (binary_op_fn == nullptr) {
- return errors::Internal(
- "No unary variant binary_op function found for binary variant op "
- "enum: ",
- op, " Variant type_name: '", a.TypeName(), "' for device type: ",
- device);
+ return errors::Internal("No unary variant binary_op function found for op ",
+ VariantBinaryOpToString(op),
+ " Variant type_name: '", a.TypeName(),
+ "' for device type: ", device);
}
return (*binary_op_fn)(ctx, a, b, out);
}
diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization.cc b/tensorflow/core/grappler/optimizers/data/map_parallelization.cc
index 368c957..2ba0ff0 100644
--- a/tensorflow/core/grappler/optimizers/data/map_parallelization.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_parallelization.cc
@@ -32,7 +32,7 @@
namespace {
constexpr char kMapDataset[] = "MapDataset";
-constexpr char kParallelMapDataset[] = "ParallelMapDataset";
+constexpr char kParallelMapDataset[] = "ParallelMapDatasetV2";
NodeDef MakeParallelMap(const string& name, MutableGraphView* graph) {
// The inputs of the node to be parallelized could be changed by the
@@ -45,8 +45,9 @@
¶llel_map);
parallel_map.set_op(kParallelMapDataset);
auto* num_parallel_calls = graph_utils::AddScalarConstNode(
- static_cast<int32>(data::model::kAutotune), graph);
+ static_cast<int64>(data::model::kAutotune), graph);
parallel_map.add_input(num_parallel_calls->name());
+ AddNodeAttr("deterministic", "true", ¶llel_map);
return parallel_map;
}
diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc b/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc
index 0d59cfc..8449873 100644
--- a/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc
@@ -68,7 +68,7 @@
GraphDef output;
TF_ASSERT_OK(OptimizeWithMapParallelization(item, &output, autotune));
- EXPECT_EQ(graph_utils::ContainsNodeWithOp("ParallelMapDataset", output),
+ EXPECT_EQ(graph_utils::ContainsNodeWithOp("ParallelMapDatasetV2", output),
autotune);
EXPECT_EQ(graph_utils::ContainsGraphNodeWithName("map", output), !autotune);
}
@@ -99,7 +99,7 @@
GraphDef output;
TF_ASSERT_OK(OptimizeWithMapParallelization(item, &output, true));
- EXPECT_EQ(graph_utils::ContainsNodeWithOp("ParallelMapDataset", output),
+ EXPECT_EQ(graph_utils::ContainsNodeWithOp("ParallelMapDatasetV2", output),
!from_function_def);
EXPECT_EQ(graph_utils::ContainsGraphNodeWithName("map", output),
from_function_def);
@@ -131,7 +131,7 @@
GraphDef output;
TF_ASSERT_OK(OptimizeWithMapParallelization(item, &output, true));
- EXPECT_TRUE(graph_utils::ContainsNodeWithOp("ParallelMapDataset", output));
+ EXPECT_TRUE(graph_utils::ContainsNodeWithOp("ParallelMapDatasetV2", output));
EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName("map1", output));
EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map2", output));
}
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index a3037c2..198b3ee 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -3328,8 +3328,10 @@
tf_kernel_library(
name = "cast_op",
+ copts = if_mlir_generated_gpu_kernels_enabled(if_true = ["-DMLIR_GENERATED_GPU_KERNELS_ENABLED=1"]) +
+ if_mlir_experimental_kernels_enabled(if_true = ["-DMLIR_GENERATED_EXPERIMENTAL_KERNELS_ENABLED=1"]),
prefix = "cast_op",
- deps = MATH_DEPS,
+ deps = MATH_DEPS + if_mlir_experimental_kernels_enabled(if_true = ["//tensorflow/core/kernels/mlir_generated:cast_op"]),
)
tf_kernel_library(
@@ -6225,6 +6227,7 @@
"stateless_random_gamma_op.cc",
"stateless_random_ops.cc",
"stateless_random_ops_v2.cc",
+ "string_format_op.cc",
"string_join_op.cc",
"string_length_op.cc",
"string_lower_op.cc",
diff --git a/tensorflow/core/kernels/avgpooling_op.cc b/tensorflow/core/kernels/avgpooling_op.cc
index 58004d1..654446b 100644
--- a/tensorflow/core/kernels/avgpooling_op.cc
+++ b/tensorflow/core/kernels/avgpooling_op.cc
@@ -77,6 +77,11 @@
OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
errors::Unimplemented(
"Pooling is not yet supported on the batch dimension."));
+
+ for (int i = 0; i < ksize_.size(); ++i) {
+ OP_REQUIRES(context, ksize_[i] != 0,
+ errors::InvalidArgument("ksize cannot be zero"));
+ }
}
void Compute(OpKernelContext* context) override {
diff --git a/tensorflow/core/kernels/batch_kernels.cc b/tensorflow/core/kernels/batch_kernels.cc
index b52b4ab..5c0e6cd 100644
--- a/tensorflow/core/kernels/batch_kernels.cc
+++ b/tensorflow/core/kernels/batch_kernels.cc
@@ -160,9 +160,12 @@
opts.cancellation_manager = last_task_context->cancellation_manager();
opts.collective_executor = last_task_context->collective_executor();
opts.stats_collector = last_task_context->stats_collector();
- opts.rendezvous = last_task_context->rendezvous();
opts.runner = last_task_context->runner();
opts.run_all_kernels_inline = last_task_context->run_all_kernels_inline();
+ // We do not set 'opts.rendezvous', since if the function is run multiple
+ // times in parallel with the same rendezvous, a _Send node from one run
+ // might be matched with a _Recv node of a different run. Not setting the
+ // rendezvous causes a new rendezvous to be used for each run.
Notification done_notif;
flib_->Run(opts, fhandle_, inputs, combined_outputs,
diff --git a/tensorflow/core/kernels/cast_op.cc b/tensorflow/core/kernels/cast_op.cc
index 5f32291..0bc5131 100644
--- a/tensorflow/core/kernels/cast_op.cc
+++ b/tensorflow/core/kernels/cast_op.cc
@@ -230,11 +230,9 @@
.Device(DEVICE_GPU), \
GpuCastOp)
+#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
+ !defined(MLIR_GENERATED_EXPERIMENTAL_KERNELS_ENABLED)
CURRY_TYPES2(REGISTER_CAST_GPU, bool);
-CURRY_TYPES2(REGISTER_CAST_GPU, uint8);
-CURRY_TYPES2(REGISTER_CAST_GPU, uint16);
-CURRY_TYPES2(REGISTER_CAST_GPU, uint32);
-CURRY_TYPES2(REGISTER_CAST_GPU, uint64);
CURRY_TYPES2(REGISTER_CAST_GPU, int8);
CURRY_TYPES2(REGISTER_CAST_GPU, int16);
CURRY_TYPES2(REGISTER_CAST_GPU, int32);
@@ -242,6 +240,33 @@
CURRY_TYPES2(REGISTER_CAST_GPU, Eigen::half);
CURRY_TYPES2(REGISTER_CAST_GPU, float);
CURRY_TYPES2(REGISTER_CAST_GPU, double);
+#else
+
+#define CURRY_SUBSET_OF_TYPES(FN, arg0) \
+ FN(arg0, uint8); \
+ FN(arg0, uint16); \
+ FN(arg0, uint32); \
+ FN(arg0, uint64); \
+ FN(arg0, std::complex<float>); \
+ FN(arg0, std::complex<double>)
+
+CURRY_SUBSET_OF_TYPES(REGISTER_CAST_GPU, bool);
+CURRY_SUBSET_OF_TYPES(REGISTER_CAST_GPU, int8);
+CURRY_SUBSET_OF_TYPES(REGISTER_CAST_GPU, int16);
+CURRY_SUBSET_OF_TYPES(REGISTER_CAST_GPU, int32);
+CURRY_SUBSET_OF_TYPES(REGISTER_CAST_GPU, int64);
+CURRY_SUBSET_OF_TYPES(REGISTER_CAST_GPU, Eigen::half);
+CURRY_SUBSET_OF_TYPES(REGISTER_CAST_GPU, float);
+CURRY_SUBSET_OF_TYPES(REGISTER_CAST_GPU, double);
+
+#undef CURRY_SUBSET_OF_TYPES
+
+#endif
+
+CURRY_TYPES2(REGISTER_CAST_GPU, uint8);
+CURRY_TYPES2(REGISTER_CAST_GPU, uint16);
+CURRY_TYPES2(REGISTER_CAST_GPU, uint32);
+CURRY_TYPES2(REGISTER_CAST_GPU, uint64);
CURRY_TYPES2(REGISTER_CAST_GPU, std::complex<float>);
CURRY_TYPES2(REGISTER_CAST_GPU, std::complex<double>);
REGISTER_CAST_GPU(float, bfloat16);
diff --git a/tensorflow/core/kernels/conv_grad_filter_ops_benchmark_test.cc b/tensorflow/core/kernels/conv_grad_filter_ops_benchmark_test.cc
index 9714894..0be09d8 100644
--- a/tensorflow/core/kernels/conv_grad_filter_ops_benchmark_test.cc
+++ b/tensorflow/core/kernels/conv_grad_filter_ops_benchmark_test.cc
@@ -116,12 +116,15 @@
#define BM_Conv2DBwdFilter(T, FMT, N, H, W, C, FH, FW, FC, SH, SW, PADDING, \
type) \
static void BM_NAME(BM_Conv2DBackpropFilter, type, T, FMT, N, H, W, C, FH, \
- FW, FC, SH, SW, PADDING)(int iters) { \
- testing::ItemsProcessed(static_cast<int64>(iters) * (N) * (H) * (W) * \
- (C)); \
- test::Benchmark(#type, Conv2DBackpropFilter<T>(N, H, W, C, FH, FW, FC, SH, \
- SW, PADDING, FORMAT_##FMT)) \
- .Run(iters); \
+ FW, FC, SH, SW, \
+ PADDING)(::testing::benchmark::State & state) { \
+ test::Benchmark(#type, \
+ Conv2DBackpropFilter<T>(N, H, W, C, FH, FW, FC, SH, SW, \
+ PADDING, FORMAT_##FMT), \
+ /*old_benchmark_api*/ false) \
+ .Run(state); \
+ state.SetItemsProcessed(static_cast<int64>(state.iterations()) * (N) * \
+ (H) * (W) * (C)); \
} \
BENCHMARK(BM_NAME(BM_Conv2DBackpropFilter, type, T, FMT, N, H, W, C, FH, FW, \
FC, SH, SW, PADDING));
diff --git a/tensorflow/core/kernels/conv_grad_input_ops_benchmark_test.cc b/tensorflow/core/kernels/conv_grad_input_ops_benchmark_test.cc
index 713c935..575e914 100644
--- a/tensorflow/core/kernels/conv_grad_input_ops_benchmark_test.cc
+++ b/tensorflow/core/kernels/conv_grad_input_ops_benchmark_test.cc
@@ -84,9 +84,9 @@
.Input(backprop)
.Attr("T", DataTypeToEnum<T>::value)
.Attr("strides", {1, stride_h, stride_w, 1})
- .Attr("padding", padding == Padding::SAME
- ? "SAME"
- : padding == Padding::VALID ? "VALID" : "N/A")
+ .Attr("padding", padding == Padding::SAME ? "SAME"
+ : padding == Padding::VALID ? "VALID"
+ : "N/A")
.Attr("data_format", ToString(data_format))
.Finalize(graph, &conv2d));
@@ -115,12 +115,14 @@
#define BM_Conv2DBwdInput(T, FMT, N, H, W, C, FW, FH, FC, SH, SW, PADDING, \
type) \
static void BM_NAME(BM_Conv2DBackpropInput, type, T, FMT, N, H, W, C, FH, \
- FW, FC, SH, SW, PADDING)(int iters) { \
- testing::ItemsProcessed(static_cast<int64>(iters) * (N) * (H) * (W) * \
- (C)); \
- test::Benchmark(#type, Conv2DBackpropInput<T>(N, H, W, C, FH, FW, FC, SH, \
- SW, PADDING, FORMAT_##FMT)) \
- .Run(iters); \
+ FW, FC, SH, SW, \
+ PADDING)(::testing::benchmark::State & state) { \
+ test::Benchmark(#type, \
+ Conv2DBackpropInput<T>(N, H, W, C, FH, FW, FC, SH, SW, \
+ PADDING, FORMAT_##FMT), \
+ /*old_benchmark_api*/ false) \
+ .Run(state); \
+ state.SetItemsProcessed(state.iterations() * (N) * (H) * (W) * (C)); \
} \
BENCHMARK(BM_NAME(BM_Conv2DBackpropInput, type, T, FMT, N, H, W, C, FH, FW, \
FC, SH, SW, PADDING));
diff --git a/tensorflow/core/kernels/cwise_op_abs.cc b/tensorflow/core/kernels/cwise_op_abs.cc
index f3a7592..5426efc 100644
--- a/tensorflow/core/kernels/cwise_op_abs.cc
+++ b/tensorflow/core/kernels/cwise_op_abs.cc
@@ -17,7 +17,8 @@
namespace tensorflow {
-#if !defined(MLIR_GENERATED_EXPERIMENTAL_KERNELS_ENABLED)
+#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
+ !defined(MLIR_GENERATED_EXPERIMENTAL_KERNELS_ENABLED)
REGISTER8(UnaryOp, CPU, "Abs", functor::abs, Eigen::half, bfloat16, float,
double, int8, int16, int32, int64);
#else
diff --git a/tensorflow/core/kernels/cwise_op_add_1.cc b/tensorflow/core/kernels/cwise_op_add_1.cc
index 4c740c4..d10da05 100644
--- a/tensorflow/core/kernels/cwise_op_add_1.cc
+++ b/tensorflow/core/kernels/cwise_op_add_1.cc
@@ -19,7 +19,8 @@
REGISTER6(BinaryOp, CPU, "Add", functor::add, float, Eigen::half, double, int32,
int64, bfloat16);
-#if !defined(MLIR_GENERATED_EXPERIMENTAL_KERNELS_ENABLED)
+#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
+ !defined(MLIR_GENERATED_EXPERIMENTAL_KERNELS_ENABLED)
REGISTER6(BinaryOp, CPU, "AddV2", functor::add, float, Eigen::half, double,
int32, int64, bfloat16);
#else
diff --git a/tensorflow/core/kernels/data/experimental/snapshot_util.cc b/tensorflow/core/kernels/data/experimental/snapshot_util.cc
index d3c5e4b..1f46522 100644
--- a/tensorflow/core/kernels/data/experimental/snapshot_util.cc
+++ b/tensorflow/core/kernels/data/experimental/snapshot_util.cc
@@ -598,10 +598,13 @@
}
// Rotate the vector such that the first dataset contains the next element
- // to be produced.
- std::rotate(datasets.begin(),
- datasets.begin() + (start_index % shard_dirs.size()),
- datasets.end());
+ // to be produced, but not if there are no shards at all (then we just
+ // construct an empty dataset).
+ if (!shard_dirs.empty()) {
+ std::rotate(datasets.begin(),
+ datasets.begin() + (start_index % shard_dirs.size()),
+ datasets.end());
+ }
*output = new NestedDataset(
datasets, DatasetContext::Params({"snapshot_util::Reader::NestedDataset",
diff --git a/tensorflow/core/kernels/data/experimental/snapshot_util_test.cc b/tensorflow/core/kernels/data/experimental/snapshot_util_test.cc
index e253014..83a5b40 100644
--- a/tensorflow/core/kernels/data/experimental/snapshot_util_test.cc
+++ b/tensorflow/core/kernels/data/experimental/snapshot_util_test.cc
@@ -91,10 +91,8 @@
SnapshotRoundTrip(io::compression::kSnappy, 2);
}
-void SnapshotReaderBenchmarkLoop(int iters, std::string compression_type,
- int version) {
- tensorflow::testing::StopTiming();
-
+void SnapshotReaderBenchmarkLoop(::testing::benchmark::State& state,
+ std::string compression_type, int version) {
tensorflow::DataTypeVector dtypes;
std::vector<Tensor> tensors;
GenerateTensorVector(dtypes, tensors);
@@ -106,7 +104,7 @@
TF_ASSERT_OK(Writer::Create(tensorflow::Env::Default(), filename,
compression_type, version, dtypes, &writer));
- for (int i = 0; i < iters; ++i) {
+ for (auto s : state) {
writer->WriteTensors(tensors).IgnoreError();
}
TF_ASSERT_OK(writer->Close());
@@ -115,34 +113,32 @@
TF_ASSERT_OK(Reader::Create(Env::Default(), filename, compression_type,
version, dtypes, &reader));
- tensorflow::testing::StartTiming();
- for (int i = 0; i < iters; ++i) {
+ for (auto s : state) {
std::vector<Tensor> read_tensors;
reader->ReadTensors(&read_tensors).IgnoreError();
}
- tensorflow::testing::StopTiming();
TF_ASSERT_OK(Env::Default()->DeleteFile(filename));
}
-void SnapshotCustomReaderNoneBenchmark(int iters) {
- SnapshotReaderBenchmarkLoop(iters, io::compression::kNone, 1);
+void SnapshotCustomReaderNoneBenchmark(::testing::benchmark::State& state) {
+ SnapshotReaderBenchmarkLoop(state, io::compression::kNone, 1);
}
-void SnapshotCustomReaderGzipBenchmark(int iters) {
- SnapshotReaderBenchmarkLoop(iters, io::compression::kGzip, 1);
+void SnapshotCustomReaderGzipBenchmark(::testing::benchmark::State& state) {
+ SnapshotReaderBenchmarkLoop(state, io::compression::kGzip, 1);
}
-void SnapshotCustomReaderSnappyBenchmark(int iters) {
- SnapshotReaderBenchmarkLoop(iters, io::compression::kSnappy, 1);
+void SnapshotCustomReaderSnappyBenchmark(::testing::benchmark::State& state) {
+ SnapshotReaderBenchmarkLoop(state, io::compression::kSnappy, 1);
}
-void SnapshotTFRecordReaderNoneBenchmark(int iters) {
- SnapshotReaderBenchmarkLoop(iters, io::compression::kNone, 2);
+void SnapshotTFRecordReaderNoneBenchmark(::testing::benchmark::State& state) {
+ SnapshotReaderBenchmarkLoop(state, io::compression::kNone, 2);
}
-void SnapshotTFRecordReaderGzipBenchmark(int iters) {
- SnapshotReaderBenchmarkLoop(iters, io::compression::kGzip, 2);
+void SnapshotTFRecordReaderGzipBenchmark(::testing::benchmark::State& state) {
+ SnapshotReaderBenchmarkLoop(state, io::compression::kGzip, 2);
}
BENCHMARK(SnapshotCustomReaderNoneBenchmark);
@@ -151,10 +147,8 @@
BENCHMARK(SnapshotTFRecordReaderNoneBenchmark);
BENCHMARK(SnapshotTFRecordReaderGzipBenchmark);
-void SnapshotWriterBenchmarkLoop(int iters, std::string compression_type,
- int version) {
- tensorflow::testing::StopTiming();
-
+void SnapshotWriterBenchmarkLoop(::testing::benchmark::State& state,
+ std::string compression_type, int version) {
tensorflow::DataTypeVector dtypes;
std::vector<Tensor> tensors;
GenerateTensorVector(dtypes, tensors);
@@ -166,38 +160,36 @@
TF_ASSERT_OK(Writer::Create(tensorflow::Env::Default(), filename,
compression_type, version, dtypes, &writer));
- tensorflow::testing::StartTiming();
- for (int i = 0; i < iters; ++i) {
+ for (auto s : state) {
writer->WriteTensors(tensors).IgnoreError();
}
writer->Close().IgnoreError();
- tensorflow::testing::StopTiming();
TF_ASSERT_OK(Env::Default()->DeleteFile(filename));
}
-void SnapshotCustomWriterNoneBenchmark(int iters) {
- SnapshotWriterBenchmarkLoop(iters, io::compression::kNone, 1);
+void SnapshotCustomWriterNoneBenchmark(::testing::benchmark::State& state) {
+ SnapshotWriterBenchmarkLoop(state, io::compression::kNone, 1);
}
-void SnapshotCustomWriterGzipBenchmark(int iters) {
- SnapshotWriterBenchmarkLoop(iters, io::compression::kGzip, 1);
+void SnapshotCustomWriterGzipBenchmark(::testing::benchmark::State& state) {
+ SnapshotWriterBenchmarkLoop(state, io::compression::kGzip, 1);
}
-void SnapshotCustomWriterSnappyBenchmark(int iters) {
- SnapshotWriterBenchmarkLoop(iters, io::compression::kSnappy, 1);
+void SnapshotCustomWriterSnappyBenchmark(::testing::benchmark::State& state) {
+ SnapshotWriterBenchmarkLoop(state, io::compression::kSnappy, 1);
}
-void SnapshotTFRecordWriterNoneBenchmark(int iters) {
- SnapshotWriterBenchmarkLoop(iters, io::compression::kNone, 2);
+void SnapshotTFRecordWriterNoneBenchmark(::testing::benchmark::State& state) {
+ SnapshotWriterBenchmarkLoop(state, io::compression::kNone, 2);
}
-void SnapshotTFRecordWriterGzipBenchmark(int iters) {
- SnapshotWriterBenchmarkLoop(iters, io::compression::kGzip, 2);
+void SnapshotTFRecordWriterGzipBenchmark(::testing::benchmark::State& state) {
+ SnapshotWriterBenchmarkLoop(state, io::compression::kGzip, 2);
}
-void SnapshotTFRecordWriterSnappyBenchmark(int iters) {
- SnapshotWriterBenchmarkLoop(iters, io::compression::kSnappy, 2);
+void SnapshotTFRecordWriterSnappyBenchmark(::testing::benchmark::State& state) {
+ SnapshotWriterBenchmarkLoop(state, io::compression::kSnappy, 2);
}
BENCHMARK(SnapshotCustomWriterNoneBenchmark);
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index f79d8e7..59c9114 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -967,7 +967,7 @@
Status IteratorGetNextOp::DoCompute(OpKernelContext* ctx) {
profiler::TraceMe traceme(
[&] {
- int64 mem_bw = port::GetMemoryInfo().bw_used;
+ int64 mem_bw = port::GetMemoryBandwidthInfo().bw_used;
if (mem_bw != INT64_MAX) {
return profiler::TraceMeEncode(
diff --git a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc
index f3f67bc..9562fd8 100644
--- a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc
+++ b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc
@@ -225,6 +225,7 @@
elem.end_of_sequence = true;
} else {
buffer_[shard_num].callbacks.push_back(std::move(callback));
+ buffer_[shard_num].cond_var.notify_all();
callback = nullptr;
}
}
@@ -297,7 +298,8 @@
{
mutex_lock l(mu_);
while (!cancelled_ &&
- buffer_[shard_to_fetch].data.size() >= max_buffer_size_) {
+ buffer_[shard_to_fetch].data.size() >= max_buffer_size_ &&
+ buffer_[shard_to_fetch].callbacks.empty()) {
buffer_[shard_to_fetch].cond_var.wait(l);
}
diff --git a/tensorflow/core/kernels/deserialize_sparse_string_op.cc b/tensorflow/core/kernels/deserialize_sparse_string_op.cc
index 2e15107..3acd86e 100644
--- a/tensorflow/core/kernels/deserialize_sparse_string_op.cc
+++ b/tensorflow/core/kernels/deserialize_sparse_string_op.cc
@@ -35,6 +35,8 @@
namespace tensorflow {
+using CPUDevice = Eigen::ThreadPoolDevice;
+
namespace {
using sparse::SparseTensor;
@@ -204,9 +206,9 @@
target_shape.vec<int64>()(i + ndims - 1) = output.shape().data()[i + 1];
}
- ReshapeSparseTensor(context, output.indices(), input_shape, target_shape,
- 0 /* output indices index */,
- 2 /* output shape index */);
+ ReshapeSparseTensor<CPUDevice>(context, output.indices(), input_shape,
+ target_shape, 0 /* output indices index */,
+ 2 /* output shape index */);
context->set_output(1, output.values());
}
diff --git a/tensorflow/core/kernels/eigen_benchmark.h b/tensorflow/core/kernels/eigen_benchmark.h
index 87e41b8..8b35bfd 100644
--- a/tensorflow/core/kernels/eigen_benchmark.h
+++ b/tensorflow/core/kernels/eigen_benchmark.h
@@ -35,8 +35,9 @@
using Dimensions = Eigen::DSizes<Eigen::Index, 4>;
- SpatialConvolutionBenchmarksSuite(int iters, Device& device)
- : iters_(iters), device_(device) {}
+ SpatialConvolutionBenchmarksSuite(::testing::benchmark::State& state,
+ Device& device)
+ : state_(state), device_(device) {}
Eigen::Index BufferSize(const Dimensions& dims) {
return dims.TotalSize() * sizeof(Scalar);
@@ -62,12 +63,10 @@
Filter filter(filter_data, filter_dims);
Output output(output_data, output_dims);
- ::tensorflow::testing::StartTiming();
- for (int i = 0; i < iters_; ++i) {
+ for (auto s : state_) {
output.device(device_) = Eigen::SpatialConvolution(input, filter);
tensorflow::testing::DoNotOptimize(output);
}
- ::tensorflow::testing::StopTiming();
device_.deallocate(input_data);
device_.deallocate(filter_data);
@@ -102,13 +101,11 @@
OutputBackward output_backward(output_backward_data, output_dims);
InputBackward input_backward(input_backward_data, input_dims);
- ::tensorflow::testing::StartTiming();
- for (int i = 0; i < iters_; ++i) {
+ for (auto s : state_) {
input_backward.device(device_) = Eigen::SpatialConvolutionBackwardInput(
filter, output_backward, input_rows, input_cols);
tensorflow::testing::DoNotOptimize(input_backward);
}
- ::tensorflow::testing::StopTiming();
device_.deallocate(filter_data);
device_.deallocate(output_backward_data);
@@ -143,13 +140,11 @@
OutputBackward output_backward(output_backward_data, input_dims);
FilterBackward filter_backward(filter_backward_data, filter_dims);
- ::tensorflow::testing::StartTiming();
- for (int i = 0; i < iters_; ++i) {
+ for (auto s : state_) {
filter_backward.device(device_) = Eigen::SpatialConvolutionBackwardKernel(
input, output_backward, filter_rows, filter_cols);
tensorflow::testing::DoNotOptimize(filter_backward);
}
- ::tensorflow::testing::StopTiming();
device_.deallocate(input_data);
device_.deallocate(output_backward_data);
@@ -157,7 +152,8 @@
}
private:
- int iters_;
+ ::testing::benchmark::State& state_;
+
Device& device_;
};
@@ -170,8 +166,9 @@
using Dimensions = Eigen::DSizes<Eigen::Index, 5>;
- CuboidConvolutionBenchmarksSuite(int iters, Device& device)
- : iters_(iters), device_(device) {}
+ CuboidConvolutionBenchmarksSuite(::testing::benchmark::State& state,
+ Device& device)
+ : state_(state), device_(device) {}
Eigen::Index BufferSize(const Dimensions& dims) {
return dims.TotalSize() * sizeof(Scalar);
@@ -198,12 +195,10 @@
Filter filter(filter_data, filter_dims);
Output output(output_data, output_dims);
- ::tensorflow::testing::StartTiming();
- for (int i = 0; i < iters_; ++i) {
+ for (auto s : state_) {
output.device(device_) = Eigen::CuboidConvolution(input, filter);
tensorflow::testing::DoNotOptimize(output);
}
- ::tensorflow::testing::StopTiming();
device_.deallocate(input_data);
device_.deallocate(filter_data);
@@ -240,13 +235,11 @@
OutputBackward output_backward(output_backward_data, output_dims);
InputBackward input_backward(input_backward_data, input_dims);
- ::tensorflow::testing::StartTiming();
- for (int i = 0; i < iters_; ++i) {
+ for (auto s : state_) {
input_backward.device(device_) = Eigen::CuboidConvolutionBackwardInput(
filter, output_backward, input_planes, input_rows, input_cols);
tensorflow::testing::DoNotOptimize(input_backward);
}
- ::tensorflow::testing::StopTiming();
device_.deallocate(filter_data);
device_.deallocate(output_backward_data);
@@ -283,13 +276,11 @@
OutputBackward output_backward(output_backward_data, output_dims);
FilterBackward filter_backward(filter_backward_data, filter_dims);
- ::tensorflow::testing::StartTiming();
- for (int i = 0; i < iters_; ++i) {
+ for (auto s : state_) {
filter_backward.device(device_) = Eigen::CuboidConvolutionBackwardKernel(
input, output_backward, filter_planes, filter_rows, filter_cols);
tensorflow::testing::DoNotOptimize(filter_backward);
}
- ::tensorflow::testing::StopTiming();
device_.deallocate(input_data);
device_.deallocate(output_backward_data);
@@ -297,7 +288,7 @@
}
private:
- int iters_;
+ ::testing::benchmark::State& state_;
Device& device_;
};
diff --git a/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc b/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc
index 12fa7f3..2abc2e9 100644
--- a/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc
+++ b/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc
@@ -27,19 +27,17 @@
// Spatial Convolutions //
// -------------------------------------------------------------------------- //
-void SpatialConvolution(int iters, int num_threads,
+void SpatialConvolution(::testing::benchmark::State& state, int num_threads,
/* Input dimensions: */
int input_batches, int input_height, int input_width,
int input_depth,
/* Filter (kernel) dimensions: */
int filter_count, int filter_height, int filter_width) {
- ::tensorflow::testing::StopTiming();
-
CREATE_THREAD_POOL(num_threads);
using Benchmark =
SpatialConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice>;
- auto benchmark = Benchmark(iters, device);
+ auto benchmark = Benchmark(state, device);
typename Benchmark::Dimensions input_dims(input_batches, input_height,
input_width, input_depth);
@@ -52,23 +50,22 @@
(input_dims.TotalSize() / input_depth) * filter_count;
auto flops =
num_computed_elements * (input_depth * filter_height * filter_width);
- ::tensorflow::testing::ItemsProcessed(flops * iters);
+ state.SetItemsProcessed(flops * state.iterations());
}
-void SpatialConvolutionBackwardInput(int iters, int num_threads,
+void SpatialConvolutionBackwardInput(::testing::benchmark::State& state,
+ int num_threads,
/* Input dimensions: */
int input_batches, int input_height,
int input_width, int input_depth,
/* Filter (kernel) dimensions: */
int filter_count, int filter_height,
int filter_width) {
- ::tensorflow::testing::StopTiming();
-
CREATE_THREAD_POOL(num_threads);
using Benchmark =
SpatialConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice>;
- auto benchmark = Benchmark(iters, device);
+ auto benchmark = Benchmark(state, device);
typename Benchmark::Dimensions input_dims(input_batches, input_height,
input_width, input_depth);
@@ -80,23 +77,22 @@
auto num_computed_elements = input_dims.TotalSize();
auto flops =
num_computed_elements * (input_depth * filter_height * filter_width);
- ::tensorflow::testing::ItemsProcessed(flops * iters);
+ state.SetItemsProcessed(flops * state.iterations());
}
-void SpatialConvolutionBackwardKernel(int iters, int num_threads,
+void SpatialConvolutionBackwardKernel(::testing::benchmark::State& state,
+ int num_threads,
/* Input dimensions: */
int input_batches, int input_height,
int input_width, int input_depth,
/* Filter (kernel) dimensions: */
int filter_count, int filter_height,
int filter_width) {
- ::tensorflow::testing::StopTiming();
-
CREATE_THREAD_POOL(num_threads);
using Benchmark =
SpatialConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice>;
- auto benchmark = Benchmark(iters, device);
+ auto benchmark = Benchmark(state, device);
typename Benchmark::Dimensions input_dims(input_batches, input_height,
input_width, input_depth);
@@ -108,7 +104,7 @@
auto num_computed_elements = filter_dims.TotalSize();
auto flops =
num_computed_elements * (input_batches * input_height * input_width);
- ::tensorflow::testing::ItemsProcessed(flops * iters);
+ state.SetItemsProcessed(flops * state.iterations());
}
// Macro arguments names: --------------------------------------------------- //
@@ -126,26 +122,26 @@
#define BM_SpatialConvolution(NT, N, H, W, C, FC, FH, FW, LABEL) \
static void BM_SPATIAL_NAME(SpatialConvolution, NT, N, H, W, C, FC, FH, \
- FW)(int iters) { \
- ::tensorflow::testing::SetLabel(LABEL); \
- SpatialConvolution(iters, NT, N, H, W, C, FC, FH, FW); \
+ FW)(::testing::benchmark::State & state) { \
+ state.SetLabel(LABEL); \
+ SpatialConvolution(state, NT, N, H, W, C, FC, FH, FW); \
} \
BENCHMARK(BM_SPATIAL_NAME(SpatialConvolution, NT, N, H, W, C, FC, FH, FW))
#define BM_SpatialConvolutionBwdInput(NT, N, H, W, C, FC, FH, FW, LABEL) \
static void BM_SPATIAL_NAME(SpatialConvolutionBwdInput, NT, N, H, W, C, FC, \
- FH, FW)(int iters) { \
- ::tensorflow::testing::SetLabel(LABEL); \
- SpatialConvolutionBackwardInput(iters, NT, N, H, W, C, FC, FH, FW); \
+ FH, FW)(::testing::benchmark::State & state) { \
+ state.SetLabel(LABEL); \
+ SpatialConvolutionBackwardInput(state, NT, N, H, W, C, FC, FH, FW); \
} \
BENCHMARK( \
BM_SPATIAL_NAME(SpatialConvolutionBwdInput, NT, N, H, W, C, FC, FH, FW))
#define BM_SpatialConvolutionBwdKernel(NT, N, H, W, C, FC, FH, FW, LABEL) \
static void BM_SPATIAL_NAME(SpatialConvolutionBwdKernel, NT, N, H, W, C, FC, \
- FH, FW)(int iters) { \
- ::tensorflow::testing::SetLabel(LABEL); \
- SpatialConvolutionBackwardKernel(iters, NT, N, H, W, C, FC, FH, FW); \
+ FH, FW)(::testing::benchmark::State & state) { \
+ state.SetLabel(LABEL); \
+ SpatialConvolutionBackwardKernel(state, NT, N, H, W, C, FC, FH, FW); \
} \
BENCHMARK(BM_SPATIAL_NAME(SpatialConvolutionBwdKernel, NT, N, H, W, C, FC, \
FH, FW))
@@ -248,20 +244,18 @@
// Cuboid Convolutions //
// -------------------------------------------------------------------------- //
-void CuboidConvolution(int iters, int num_threads,
+void CuboidConvolution(::testing::benchmark::State& state, int num_threads,
/* Input dimensions: */
int input_batches, int input_height, int input_width,
int input_planes, int input_depth,
/* Filter (kernel) dimensions: */
int filter_count, int filter_height, int filter_width,
int filter_planes) {
- ::tensorflow::testing::StopTiming();
-
CREATE_THREAD_POOL(num_threads);
using Benchmark =
CuboidConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice>;
- auto benchmark = Benchmark(iters, device);
+ auto benchmark = Benchmark(state, device);
typename Benchmark::Dimensions input_dims(
input_batches, input_height, input_width, input_planes, input_depth);
@@ -274,10 +268,11 @@
(input_dims.TotalSize() / input_depth) * filter_count;
auto flops = num_computed_elements *
(input_depth * filter_height * filter_width * filter_planes);
- ::tensorflow::testing::ItemsProcessed(flops * iters);
+ state.SetItemsProcessed(flops * state.iterations());
}
-void CuboidConvolutionBackwardInput(int iters, int num_threads,
+void CuboidConvolutionBackwardInput(::testing::benchmark::State& state,
+ int num_threads,
/* Input dimensions: */
int input_batches, int input_height,
int input_width, int input_planes,
@@ -285,13 +280,11 @@
/* Filter (kernel) dimensions: */
int filter_count, int filter_height,
int filter_width, int filter_planes) {
- ::tensorflow::testing::StopTiming();
-
CREATE_THREAD_POOL(num_threads);
using Benchmark =
CuboidConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice>;
- auto benchmark = Benchmark(iters, device);
+ auto benchmark = Benchmark(state, device);
typename Benchmark::Dimensions input_dims(
input_batches, input_height, input_width, input_planes, input_depth);
@@ -303,10 +296,11 @@
auto num_computed_elements = input_dims.TotalSize();
auto flops = num_computed_elements *
(input_depth * filter_height * filter_width * filter_planes);
- ::tensorflow::testing::ItemsProcessed(flops * iters);
+ state.SetItemsProcessed(flops * state.iterations());
}
-void CuboidConvolutionBackwardKernel(int iters, int num_threads,
+void CuboidConvolutionBackwardKernel(::testing::benchmark::State& state,
+ int num_threads,
/* Input dimensions: */
int input_batches, int input_height,
int input_width, int input_planes,
@@ -314,13 +308,11 @@
/* Filter (kernel) dimensions: */
int filter_count, int filter_height,
int filter_width, int filter_planes) {
- ::tensorflow::testing::StopTiming();
-
CREATE_THREAD_POOL(num_threads);
using Benchmark =
CuboidConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice>;
- auto benchmark = Benchmark(iters, device);
+ auto benchmark = Benchmark(state, device);
typename Benchmark::Dimensions input_dims(
input_batches, input_height, input_width, input_planes, input_depth);
@@ -332,9 +324,16 @@
auto num_computed_elements = filter_dims.TotalSize();
auto flops = num_computed_elements *
(input_batches * input_height * input_width * input_planes);
- ::tensorflow::testing::ItemsProcessed(flops * iters);
+ state.SetItemsProcessed(flops * state.iterations());
}
+// The multiple #'s in the function names + the `::testing::benchmark::State&`
+// as parameters apparently confuses clang if they are not on the same line. So
+// we need to turn off LINT and clang-format for this block.
+//
+// clang-format off
+// NOLINTBEGIN
+
// Macro arguments names: --------------------------------------------------- //
// NT: num threads
// N: batch size
@@ -354,33 +353,33 @@
_f_##FC##_##FH##_##FW##_##FP)
#define BM_CuboidConvolution(NT, N, H, W, P, C, FC, FH, FW, FP, LABEL) \
- static void BM_CUBOID_NAME(CuboidConvolution, NT, N, H, W, P, C, FC, FH, FW, \
- FP)(int iters) { \
- ::tensorflow::testing::SetLabel(LABEL); \
- CuboidConvolution(iters, NT, N, H, W, P, C, FC, FH, FW, FP); \
+ static void BM_CUBOID_NAME(CuboidConvolution, NT, N, H, W, P, C, FC, FH, FW, FP)(::testing::benchmark::State & state) { \
+ state.SetLabel(LABEL); \
+ CuboidConvolution(state, NT, N, H, W, P, C, FC, FH, FW, FP); \
} \
BENCHMARK( \
BM_CUBOID_NAME(CuboidConvolution, NT, N, H, W, P, C, FC, FH, FW, FP))
#define BM_CuboidConvolutionBwdInput(NT, N, H, W, P, C, FC, FH, FW, FP, LABEL) \
- static void BM_CUBOID_NAME(CuboidConvolutionBwdInput, NT, N, H, W, P, C, FC, \
- FH, FW, FP)(int iters) { \
- ::tensorflow::testing::SetLabel(LABEL); \
- CuboidConvolutionBackwardInput(iters, NT, N, H, W, P, C, FC, FH, FW, FP); \
+ static void BM_CUBOID_NAME(CuboidConvolutionBwdInput, NT, N, H, W, P, C, FC, FH, FW, FP)(::testing::benchmark::State & state) { \
+ state.SetLabel(LABEL); \
+ CuboidConvolutionBackwardInput(state, NT, N, H, W, P, C, FC, FH, FW, FP); \
} \
BENCHMARK(BM_CUBOID_NAME(CuboidConvolutionBwdInput, NT, N, H, W, P, C, FC, \
FH, FW, FP))
#define BM_CuboidConvolutionBwdKernel(NT, N, H, W, P, C, FC, FH, FW, FP, \
LABEL) \
- static void BM_CUBOID_NAME(CuboidConvolutionBwdKernel, NT, N, H, W, P, C, \
- FC, FH, FW, FP)(int iters) { \
- ::tensorflow::testing::SetLabel(LABEL); \
- CuboidConvolutionBackwardKernel(iters, NT, N, H, W, P, C, FC, FH, FW, FP); \
+ static void BM_CUBOID_NAME(CuboidConvolutionBwdKernel, NT, N, H, W, P, C, FC, FH, FW, FP)(::testing::benchmark::State & state) { \
+ state.SetLabel(LABEL); \
+ CuboidConvolutionBackwardKernel(state, NT, N, H, W, P, C, FC, FH, FW, FP); \
} \
BENCHMARK(BM_CUBOID_NAME(CuboidConvolutionBwdKernel, NT, N, H, W, P, C, FC, \
FH, FW, FP))
+// NOLINTEND
+// clang-format on
+
#define BM_CuboidConvolutions(N, H, W, P, C, FC, FH, FW, FP, LABEL) \
BM_CuboidConvolution(2, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
BM_CuboidConvolution(4, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
diff --git a/tensorflow/core/kernels/fused_batch_norm_op_test.cc b/tensorflow/core/kernels/fused_batch_norm_op_test.cc
index 734fb29..989fbc2 100644
--- a/tensorflow/core/kernels/fused_batch_norm_op_test.cc
+++ b/tensorflow/core/kernels/fused_batch_norm_op_test.cc
@@ -283,18 +283,23 @@
// -------------------------------------------------------------------------- //
// FusedBatchNorm inference
// -------------------------------------------------------------------------- //
+// clang-format off
+// NOLINTBEGIN
+#define BM_FusedBatchNorm(N, H, W, C, T, IS_TRAINING, FORMAT, DEVICE) \
+ static void BM_NAME(FusedBatchNorm, N, H, W, C, T, IS_TRAINING, FORMAT, DEVICE)(::testing::benchmark::State & state) { \
+ test::Benchmark( \
+ #DEVICE, \
+ FusedBatchNormInference<T>(N, H, W, C, IS_TRAINING, FORMAT_##FORMAT), \
+ /*old_benchmark_api*/ false) \
+ .Run(state); \
+ state.SetItemsProcessed(state.iterations() * N * H * W * C); \
+ } \
+ BENCHMARK( \
+ BM_NAME(FusedBatchNorm, N, H, W, C, T, IS_TRAINING, FORMAT, DEVICE)) \
+ ->UseRealTime();
-#define BM_FusedBatchNorm(N, H, W, C, T, IS_TRAINING, FORMAT, DEVICE) \
- static void BM_NAME(FusedBatchNorm, N, H, W, C, T, IS_TRAINING, FORMAT, \
- DEVICE)(int iters) { \
- testing::UseRealTime(); \
- testing::ItemsProcessed(static_cast<int64>(iters) * N * H * W * C); \
- test::Benchmark(#DEVICE, FusedBatchNormInference<T>( \
- N, H, W, C, IS_TRAINING, FORMAT_##FORMAT)) \
- .Run(iters); \
- } \
- BENCHMARK( \
- BM_NAME(FusedBatchNorm, N, H, W, C, T, IS_TRAINING, FORMAT, DEVICE));
+// NOLINTEND
+// clang-format on
BM_FusedBatchNorm(64, 14, 14, 256, fp32, false, NHWC, cpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp16, false, NHWC, cpu);
@@ -320,17 +325,19 @@
// FusedBatchNorm gradient
// -------------------------------------------------------------------------- //
-#define BM_FusedBatchNormGrad(N, H, W, C, T, IS_TRAINING, FORMAT, DEVICE) \
- static void BM_NAME(FusedBatchNormGrad, N, H, W, C, T, IS_TRAINING, FORMAT, \
- DEVICE)(int iters) { \
- testing::UseRealTime(); \
- testing::ItemsProcessed(static_cast<int64>(iters) * N * H * W * C); \
- test::Benchmark(#DEVICE, FusedBatchNormGrad<T>(N, H, W, C, IS_TRAINING, \
- FORMAT_##FORMAT)) \
- .Run(iters); \
- } \
- BENCHMARK(BM_NAME(FusedBatchNormGrad, N, H, W, C, T, IS_TRAINING, FORMAT, \
- DEVICE));
+#define BM_FusedBatchNormGrad(N, H, W, C, T, IS_TRAINING, FORMAT, DEVICE) \
+ static void BM_NAME(FusedBatchNormGrad, N, H, W, C, T, IS_TRAINING, FORMAT, \
+ DEVICE)(::testing::benchmark::State & state) { \
+ test::Benchmark( \
+ #DEVICE, \
+ FusedBatchNormGrad<T>(N, H, W, C, IS_TRAINING, FORMAT_##FORMAT), \
+ /*old_benchmark_api*/ false) \
+ .Run(state); \
+ state.SetItemsProcessed(state.iterations() * N * H * W * C); \
+ } \
+ BENCHMARK( \
+ BM_NAME(FusedBatchNormGrad, N, H, W, C, T, IS_TRAINING, FORMAT, DEVICE)) \
+ ->UseRealTime();
#define BM_FusedBatchNormGradResnetShapes(T, IS_TRAINING, FORMAT, DEVICE) \
BM_FusedBatchNormGrad(64, 56, 56, 64, T, IS_TRAINING, FORMAT, DEVICE); \
diff --git a/tensorflow/core/kernels/generate_vocab_remapping_op.cc b/tensorflow/core/kernels/generate_vocab_remapping_op.cc
index d4cf838..e60abc4 100644
--- a/tensorflow/core/kernels/generate_vocab_remapping_op.cc
+++ b/tensorflow/core/kernels/generate_vocab_remapping_op.cc
@@ -72,6 +72,7 @@
kUnusedLookupDelim,
-1, // key_index, use the line number.
-2, // value_index, use the whole line/token.
+ 0, // No offset.
context->env(), new_vocab_table));
OP_REQUIRES(context,
new_vocab_offset_ + num_new_vocab_ <= new_vocab_table->size(),
@@ -101,6 +102,7 @@
old_vocab_filename, old_vocab_size_, kUnusedLookupDelim,
-2, // key_index, use the whole line/token.
-1, // value_index, use the line number.
+ 0, // No offset.
context->env(), old_vocab_table));
// Fill out new_ids = [new_vocab_offset, new_vocab_offset + 1, ...,
diff --git a/tensorflow/core/kernels/image/resize_bilinear_op_test.cc b/tensorflow/core/kernels/image/resize_bilinear_op_test.cc
index fe0d4d1..bb9ce96 100644
--- a/tensorflow/core/kernels/image/resize_bilinear_op_test.cc
+++ b/tensorflow/core/kernels/image/resize_bilinear_op_test.cc
@@ -144,7 +144,7 @@
TensorShape({batch_size, output_width, output_height, channels})));
ResizeBilinearBaseline(input->tensor<float, 4>(),
expected->tensor<float, 4>());
- test::ExpectClose(*expected, *GetOutput(0), /*atol=*/3e-5);
+ test::ExpectClose(*expected, *GetOutput(0), /*atol=*/4e-5);
}
void RunManyRandomTests(int channels) {
diff --git a/tensorflow/core/kernels/linalg/banded_triangular_solve_op_test.cc b/tensorflow/core/kernels/linalg/banded_triangular_solve_op_test.cc
index 7c20b88..f4b54fb 100644
--- a/tensorflow/core/kernels/linalg/banded_triangular_solve_op_test.cc
+++ b/tensorflow/core/kernels/linalg/banded_triangular_solve_op_test.cc
@@ -98,14 +98,16 @@
// BS: boolean indicating whether to use the banded solver
// T: C++ type of scalars (e.g. float, std::complex)
// TT: TensorFlow type of scalars (e.g. DT_FLOAT, DT_COMPLEX128
-#define BM_BandedTriangularSolveDev(K, N, M, BS, T, TT, D) \
- static void BM_BandedTriangularSolve##_##K##_##N##_##M##_##BS##_##TT( \
- int iters) { \
- testing::UseRealTime(); \
- testing::ItemsProcessed(static_cast<int64>(iters) * K * N + N * M); \
- test::Benchmark(#D, BandedTriangularSolve<T>(K, N, M, BS, TT)).Run(iters); \
- } \
- BENCHMARK(BM_BandedTriangularSolve##_##K##_##N##_##M##_##BS##_##TT);
+#define BM_BandedTriangularSolveDev(K, N, M, BS, T, TT, D) \
+ static void BM_BandedTriangularSolve##_##K##_##N##_##M##_##BS##_##TT( \
+ ::testing::benchmark::State& state) { \
+ test::Benchmark(#D, BandedTriangularSolve<T>(K, N, M, BS, TT), \
+ /*old_benchmark_api*/ false) \
+ .Run(state); \
+ state.SetItemsProcessed(state.iterations() * K * N + N * M); \
+ } \
+ BENCHMARK(BM_BandedTriangularSolve##_##K##_##N##_##M##_##BS##_##TT) \
+ ->UseRealTime();
#define BM_BandedTriangularSolve(K, N, M, BS, D) \
BM_BandedTriangularSolveDev(K, N, M, BS, float, DT_FLOAT, D); \
diff --git a/tensorflow/core/kernels/linalg/matrix_triangular_solve_op_test.cc b/tensorflow/core/kernels/linalg/matrix_triangular_solve_op_test.cc
index 7bb71ae..e03f293 100644
--- a/tensorflow/core/kernels/linalg/matrix_triangular_solve_op_test.cc
+++ b/tensorflow/core/kernels/linalg/matrix_triangular_solve_op_test.cc
@@ -101,18 +101,18 @@
// T: C++ type of scalars (e.g. float, std::complex)
// TT: TensorFlow type of scalars (e.g. DT_FLOAT, DT_COMPLEX128
// D: Device (e.g. cpu, gpu)
-#define BM_MatrixTriangularSolveDev(B1, B2, M, N, MB, T, TT, D) \
- static void \
- BM_MatrixTriangularSolve##_##B1##_##B2##_##M##_##N##_##MB##_##TT##_##D( \
- int iters) { \
- testing::UseRealTime(); \
- testing::ItemsProcessed(static_cast<int64>(iters) * std::max(B1, B2) * M * \
- M * N * 2); \
- test::Benchmark( \
- #D, MatrixTriangularSolveWithBroadcast<T>(B1, B2, M, N, MB, TT)) \
- .Run(iters); \
- } \
- BENCHMARK( \
+#define BM_MatrixTriangularSolveDev(B1, B2, M, N, MB, T, TT, D) \
+ static void \
+ BM_MatrixTriangularSolve##_##B1##_##B2##_##M##_##N##_##MB##_##TT##_##D( \
+ ::testing::benchmark::State& state) { \
+ state.SetItemsProcessed(state.iterations() * std::max(B1, B2) * M * M * \
+ N * 2); \
+ test::Benchmark( \
+ #D, MatrixTriangularSolveWithBroadcast<T>(B1, B2, M, N, MB, TT), \
+ /*old_benchmark_api*/ false) \
+ .Run(state); \
+ } \
+ BENCHMARK( \
BM_MatrixTriangularSolve##_##B1##_##B2##_##M##_##N##_##MB##_##TT##_##D);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
diff --git a/tensorflow/core/kernels/lookup_table_init_op.cc b/tensorflow/core/kernels/lookup_table_init_op.cc
index cb757ac..d21ac54 100644
--- a/tensorflow/core/kernels/lookup_table_init_op.cc
+++ b/tensorflow/core/kernels/lookup_table_init_op.cc
@@ -105,6 +105,7 @@
OP_REQUIRES_OK(ctx, ctx->GetAttr("vocab_size", &vocab_size_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("key_index", &key_index_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("value_index", &value_index_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("offset", &offset_));
string delimiter;
OP_REQUIRES_OK(ctx, ctx->GetAttr("delimiter", &delimiter));
OP_REQUIRES(ctx, delimiter.size() == 1,
@@ -141,7 +142,7 @@
}
OP_REQUIRES_OK(ctx, lookup::InitializeTableFromTextFile(
vocab_filename, vocab_size_, delimiter_, key_index_,
- value_index_, ctx->env(), table));
+ value_index_, offset_, ctx->env(), table));
if (ctx->track_allocations()) {
ctx->record_persistent_memory_allocation(table->MemoryUsed() -
memory_used_before);
@@ -154,6 +155,7 @@
char delimiter_;
int64 key_index_;
int64 value_index_;
+ int64 offset_;
TF_DISALLOW_COPY_AND_ASSIGN(InitializeTableFromTextFileOp);
};
diff --git a/tensorflow/core/kernels/lookup_util.cc b/tensorflow/core/kernels/lookup_util.cc
index d07b525..aa39063 100644
--- a/tensorflow/core/kernels/lookup_util.cc
+++ b/tensorflow/core/kernels/lookup_util.cc
@@ -77,7 +77,7 @@
// delimiter.
Status Init(const string& filename, int64 vocab_size, char delimiter,
DataType key_dtype, int64 key_index, DataType value_dtype,
- int64 value_index, Env* env) {
+ int64 value_index, int64 offset, Env* env) {
filename_ = filename;
vocab_size_ = vocab_size;
delimiter_ = delimiter;
@@ -93,6 +93,7 @@
input_buffer_.reset(new io::InputBuffer(file_.get(), kInputBufferSize));
valid_ = true;
next_id_ = 0;
+ offset_ = offset;
ignore_split_ = std::max(key_index_, value_index_) < 0;
Next();
return status_;
@@ -143,6 +144,7 @@
return;
}
}
+
status_ = SetValue(line, tokens, key_index_, &key_);
if (!status_.ok()) {
valid_ = false;
@@ -186,6 +188,7 @@
int64 value_index_;
Env* env_;
int64 next_id_;
+ int64 offset_;
int64 vocab_size_;
string filename_;
char delimiter_;
@@ -199,7 +202,7 @@
Status SetValue(const string& line, const std::vector<string>& tokens,
int64 index, Tensor* tensor) {
if (index == kLineNumber) {
- tensor->flat<int64>()(0) = next_id_;
+ tensor->flat<int64>()(0) = next_id_ + offset_;
return Status::OK();
}
const string& token = (index == kWholeLine) ? line : tokens[index];
@@ -212,7 +215,7 @@
return errors::InvalidArgument("Field ", token, " in line ", next_id_,
" is not a valid int32.");
}
- tensor->flat<int32>()(0) = value;
+ tensor->flat<int32>()(0) = value + offset_;
} break;
case DT_INT64: {
int64 value;
@@ -352,7 +355,7 @@
// Helper function to initialize an InitializableLookupTable from a text file.
Status InitializeTableFromTextFile(const string& filename, int64 vocab_size,
char delimiter, int32 key_index,
- int32 value_index, Env* env,
+ int32 value_index, int64 offset, Env* env,
InitializableLookupTable* table) {
if (key_index == kLineNumber && table->key_dtype() != DT_INT64) {
return errors::InvalidArgument(
@@ -380,7 +383,8 @@
TextFileLineIterator iter;
TF_RETURN_IF_ERROR(iter.Init(filename, vocab_size, delimiter, key_dtype,
- key_index, value_dtype, value_index, env));
+ key_index, value_dtype, value_index, offset,
+ env));
// For initialization from files, ignore if the table is already
// initialized. The table shared name should contain the filename to
// avoid trying to initialize the same table from the same file at the same
diff --git a/tensorflow/core/kernels/lookup_util.h b/tensorflow/core/kernels/lookup_util.h
index 7e53ed5..26974ab 100644
--- a/tensorflow/core/kernels/lookup_util.h
+++ b/tensorflow/core/kernels/lookup_util.h
@@ -53,7 +53,7 @@
Status InitializeTableFromTextFile(const string& filename, int64 vocab_size,
char delimiter, int32 key_index,
- int32 value_index, Env* env,
+ int32 value_index, int64 offset, Env* env,
InitializableLookupTable* table);
// Initializes `table` from `dataset` by iterating over it. Caller retains
diff --git a/tensorflow/core/kernels/mlir_generated/BUILD b/tensorflow/core/kernels/mlir_generated/BUILD
index 04d9241..5166f4b 100644
--- a/tensorflow/core/kernels/mlir_generated/BUILD
+++ b/tensorflow/core/kernels/mlir_generated/BUILD
@@ -49,7 +49,6 @@
"gpu_op_atan.cc",
"gpu_op_atanh.cc",
"gpu_op_ceil.cc",
- "gpu_op_complex.cc",
"gpu_op_complex_abs.cc",
"gpu_op_conj.cc",
"gpu_op_cos.cc",
@@ -100,6 +99,55 @@
compatible_with = get_compatible_with_cloud(),
)
+filegroup(
+ name = "enabled_binary_gpu_kernel_srcs",
+ srcs = [
+ "gpu_op_complex.cc",
+ ],
+ compatible_with = get_compatible_with_cloud(),
+)
+
+filegroup(
+ name = "experimental_binary_gpu_kernel_srcs",
+ srcs = [
+ "gpu_op_add.cc",
+ "gpu_op_atan2.cc",
+ "gpu_op_bitwise_and.cc",
+ "gpu_op_bitwise_or.cc",
+ "gpu_op_bitwise_xor.cc",
+ "gpu_op_div.cc",
+ "gpu_op_equal.cc",
+ "gpu_op_floor_div.cc",
+ "gpu_op_greater.cc",
+ "gpu_op_greater_equal.cc",
+ "gpu_op_left_shift.cc",
+ "gpu_op_less.cc",
+ "gpu_op_less_equal.cc",
+ "gpu_op_logical_and.cc",
+ "gpu_op_logical_or.cc",
+ "gpu_op_maximum.cc",
+ "gpu_op_minimum.cc",
+ "gpu_op_mul.cc",
+ "gpu_op_not_equal.cc",
+ "gpu_op_pow.cc",
+ "gpu_op_right_shift.cc",
+ "gpu_op_squared_difference.cc",
+ "gpu_op_sub.cc",
+ "gpu_op_zeta.cc",
+ ],
+ compatible_with = get_compatible_with_cloud(),
+)
+
+filegroup(
+ name = "binary_gpu_kernel_srcs",
+ srcs = [
+ ":enabled_binary_gpu_kernel_srcs",
+ ] + if_mlir_experimental_kernels_enabled(
+ if_true = [":experimental_binary_gpu_kernel_srcs"],
+ ),
+ compatible_with = get_compatible_with_cloud(),
+)
+
cc_library(
name = "base_op",
srcs = ["base_op.cc"],
@@ -150,7 +198,6 @@
":gpu_atanh_kernels",
":gpu_ceil_kernels",
":gpu_complex_abs_kernels",
- ":gpu_complex_kernels",
":gpu_conj_kernels",
":gpu_cos_kernels",
":gpu_cosh_kernels",
@@ -201,35 +248,15 @@
tf_kernel_library(
name = "gpu_cwise_binary_op",
- srcs = [
- "gpu_op_add.cc",
- "gpu_op_atan2.cc",
- "gpu_op_bitwise_and.cc",
- "gpu_op_bitwise_or.cc",
- "gpu_op_bitwise_xor.cc",
- "gpu_op_div.cc",
- "gpu_op_equal.cc",
- "gpu_op_floor_div.cc",
- "gpu_op_greater.cc",
- "gpu_op_greater_equal.cc",
- "gpu_op_left_shift.cc",
- "gpu_op_less.cc",
- "gpu_op_less_equal.cc",
- "gpu_op_logical_and.cc",
- "gpu_op_logical_or.cc",
- "gpu_op_maximum.cc",
- "gpu_op_minimum.cc",
- "gpu_op_mul.cc",
- "gpu_op_not_equal.cc",
- "gpu_op_pow.cc",
- "gpu_op_right_shift.cc",
- "gpu_op_squared_difference.cc",
- "gpu_op_sub.cc",
- "gpu_op_zeta.cc",
- ],
+ srcs = [":binary_gpu_kernel_srcs"],
tags = [
"manual",
],
+ # Technically we only need to depend on the kernel libraries for the
+ # kernels which are enabled by default. But this would make our BUILD
+ # target structure uglier. We already need to make sure that those
+ # targets can be built, so it should not hurt to link them in even if
+ # they are currently not needed yet.
deps = [
":base_gpu_op",
":gpu_add_v2_kernels",
@@ -237,6 +264,7 @@
":gpu_bitwise_and_kernels",
":gpu_bitwise_or_kernels",
":gpu_bitwise_xor_kernels",
+ ":gpu_complex_kernels",
":gpu_div_kernels",
":gpu_equal_kernels",
":gpu_floor_div_kernels",
@@ -281,6 +309,7 @@
# but we want to avoid building them if they are not needed.
deps = if_cuda_or_rocm([
":gpu_cwise_unary_op",
+ ":gpu_cwise_binary_op",
]) + if_mlir_experimental_kernels_enabled([":experimental_cwise_op"]),
)
@@ -288,9 +317,22 @@
name = "experimental_cwise_op",
srcs = [],
deps = [
- ":cpu_cwise_unary_op",
":cpu_cwise_binary_op",
- ] + if_cuda_or_rocm([":gpu_cwise_binary_op"]),
+ ":cpu_cwise_unary_op",
+ ],
+)
+
+tf_kernel_library(
+ name = "cast_op",
+ srcs = ["gpu_op_cast.cc"],
+ tags = [
+ "manual",
+ ],
+ deps = [
+ ":base_gpu_op",
+ ":gpu_cast_kernels",
+ "//third_party/eigen3",
+ ],
)
cc_library(
@@ -453,7 +495,7 @@
"f32",
"f64",
],
- unroll_factors = "4",
+ # Cannot vectorize.
)
gpu_kernel_library(
@@ -464,7 +506,8 @@
"f32",
"f64",
],
- unroll_factors = "4",
+ # May be compute-bound.
+ # unroll_factors = "4",
)
gpu_kernel_library(
@@ -490,7 +533,7 @@
"f32",
"f64",
],
- unroll_factors = "4",
+ # Cannot vectorize.
)
gpu_kernel_library(
@@ -501,7 +544,7 @@
"f32",
"f64",
],
- unroll_factors = "4",
+ # Cannot vectorize.
)
gpu_kernel_library(
@@ -545,7 +588,8 @@
"f32",
"f64",
],
- unroll_factors = "4",
+ # May be compute-bound.
+ # unroll_factors = "4",
)
gpu_kernel_library(
@@ -643,7 +687,8 @@
"f32",
"f64",
],
- unroll_factors = "4",
+ # May be compute-bound.
+ # unroll_factors = "4",
)
[
@@ -707,7 +752,6 @@
"i16",
"i64",
],
- unroll_factors = "4",
)
gpu_kernel_library(
@@ -988,7 +1032,6 @@
"f64",
"i64",
],
- unroll_factors = "4",
)
gpu_kernel_library(
diff --git a/tensorflow/core/kernels/mlir_generated/base_binary_ops_test.h b/tensorflow/core/kernels/mlir_generated/base_binary_ops_test.h
index 54019c6..a6b0535 100644
--- a/tensorflow/core/kernels/mlir_generated/base_binary_ops_test.h
+++ b/tensorflow/core/kernels/mlir_generated/base_binary_ops_test.h
@@ -43,16 +43,16 @@
void SetOpKernel(const std::string& op_name, const TensorShape& lhs_shape,
const absl::InlinedVector<T, 10>& lhs_input,
const TensorShape& rhs_shape,
- const absl::InlinedVector<T, 10>& rhs_input, bool add_t,
- bool add_tout) {
+ const absl::InlinedVector<T, 10>& rhs_input,
+ const test::OpsTestConfig& config) {
auto builder = NodeDefBuilder("some_name", op_name)
.Input(FakeInput(DataTypeToEnum<T>::v()))
.Input(FakeInput(DataTypeToEnum<T>::v()));
- if (add_t) {
- builder.Attr("T", DataTypeToEnum<T>::v());
+ if (config.add_t) {
+ builder.Attr(config.input_attribute, DataTypeToEnum<T>::v());
}
- if (add_tout) {
- builder.Attr("Tout", DataTypeToEnum<OutT>::v());
+ if (config.add_tout) {
+ builder.Attr(config.output_attribute, DataTypeToEnum<OutT>::v());
}
TF_ASSERT_OK(builder.Finalize(node_def()));
@@ -73,7 +73,7 @@
const absl::InlinedVector<OutT, 10>& expected_output,
const test::OpsTestConfig& config) {
SetOpKernel<T, OutT>(op_name, lhs_shape, lhs_input, rhs_shape, rhs_input,
- config.add_t, config.add_tout);
+ config);
TF_ASSERT_OK(RunOpKernel());
// Compare output to expectation.
@@ -96,7 +96,7 @@
const absl::InlinedVector<T, 10>& rhs_input,
const test::OpsTestConfig& config) {
SetOpKernel<T, OutT>(op_name, lhs_shape, lhs_input, rhs_shape, rhs_input,
- config.add_t, config.add_tout);
+ config);
auto status = RunOpKernel();
EXPECT_FALSE(status.ok());
EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
@@ -188,6 +188,42 @@
template <typename T, typename BaselineT, typename OutT,
typename BaselineOutT>
+ void TestOneEffectiveScalar(const std::string& op_name, T scalar_input,
+ const TensorShape& other_shape,
+ const absl::InlinedVector<T, 10>& other_input,
+ BaselineOutT (*baseline_callback)(BaselineT,
+ BaselineT),
+ const test::OpsTestConfig& config) {
+ // Prepare inputs.
+ TensorShape effective_scalar_shape{1, 1, 1, 1, 1, 1, 1};
+ CHECK(other_input.size() <= other_shape.num_elements() &&
+ "expect other input shape to hold all input values");
+ auto repeated_other_input =
+ test::RepeatInputToMatchShape(other_input, other_shape.num_elements());
+
+ // Compute expected results.
+ absl::InlinedVector<OutT, 10> expected_output;
+ for (auto it = repeated_other_input.begin(),
+ end = repeated_other_input.end();
+ it != end; ++it) {
+ auto scalar = static_cast<BaselineT>(scalar_input);
+ auto other_value = static_cast<BaselineT>(*it);
+ auto result = static_cast<OutT>(baseline_callback(scalar, other_value));
+ expected_output.push_back(result);
+ }
+
+ auto scalar_input_vector = test::InputAsVector<T>({scalar_input});
+ TensorShape expected_shape = other_shape;
+ while (expected_shape.dims() < effective_scalar_shape.dims()) {
+ expected_shape.InsertDim(0, 1);
+ }
+ RunAndExpectResult<T, OutT>(
+ op_name, effective_scalar_shape, scalar_input_vector, other_shape,
+ repeated_other_input, expected_shape, expected_output, config);
+ }
+
+ template <typename T, typename BaselineT, typename OutT,
+ typename BaselineOutT>
void TestBroadcastingExpand(const std::string& op_name,
const absl::InlinedVector<T, 10>& lhs_input,
const absl::InlinedVector<T, 10>& rhs_input,
@@ -330,6 +366,13 @@
baseline_callback, config); \
} \
\
+ TEST_F(BinaryOpsTest, op_name##TestOneEffectiveScalar##test_name) { \
+ TestOneEffectiveScalar<T, BaselineT, OutT, BaselineOutT>( \
+ #op_name, /*scalar_input=*/lhs_input.front(), \
+ /*other_shape=*/test::DefaultInputShape(), /*other_input=*/rhs_input, \
+ baseline_callback, config); \
+ } \
+ \
TEST_F(BinaryOpsTest, op_name##IncompatibleShapes##test_name) { \
TestIncompatibleShapes<T, OutT>(#op_name, lhs_input, rhs_input, config); \
} \
diff --git a/tensorflow/core/kernels/mlir_generated/base_cpu_op.h b/tensorflow/core/kernels/mlir_generated/base_cpu_op.h
index d680b2b..8eaf227 100644
--- a/tensorflow/core/kernels/mlir_generated/base_cpu_op.h
+++ b/tensorflow/core/kernels/mlir_generated/base_cpu_op.h
@@ -20,57 +20,39 @@
namespace tensorflow {
-#define GENERATE_AND_REGISTER_UNARY_CPU_KERNEL(tf_op, mlir_type, tf_data_type, \
- data_type) \
- GENERATE_AND_REGISTER_UNARY_KERNEL(tf_op, CPU, mlir_type, tf_data_type, \
- data_type)
+#define GENERATE_AND_REGISTER_UNARY_CPU_KERNEL(tf_op, input_type) \
+ GENERATE_AND_REGISTER_UNARY_KERNEL(tf_op, CPU, input_type)
-#define GENERATE_UNARY_CPU_KERNEL(tf_op, mlir_type, tf_data_type, data_type) \
- GENERATE_UNARY_KERNEL(tf_op, CPU, mlir_type, tf_data_type, data_type)
+#define GENERATE_UNARY_CPU_KERNEL(tf_op, input_type) \
+ GENERATE_UNARY_KERNEL(tf_op, CPU, input_type)
-#define GENERATE_UNARY_CPU_KERNEL2(tf_op, mlir_type, mlir_output_type, \
- tf_data_type, result_data_type, \
- input_data_type) \
- GENERATE_UNARY_KERNEL2(tf_op, CPU, mlir_type, mlir_output_type, \
- tf_data_type, result_data_type, input_data_type)
+#define GENERATE_UNARY_CPU_KERNEL2(tf_op, input_type, output_type) \
+ GENERATE_UNARY_KERNEL2(tf_op, CPU, input_type, output_type)
-#define REGISTER_ALIASED_CPU_KERNEL(tf_op, mlir_op, mlir_type, \
- mlir_output_type, data_type) \
- REGISTER_ALIASED_KERNEL(tf_op, mlir_op, CPU, mlir_type, mlir_output_type, \
- data_type)
+#define REGISTER_ALIASED_CPU_KERNEL(tf_op, mlir_op, input_type, output_type) \
+ REGISTER_ALIASED_KERNEL(tf_op, mlir_op, CPU, input_type, output_type)
-#define REGISTER_CPU_KERNEL(tf_op, mlir_type, mlir_output_type, data_type) \
- REGISTER_KERNEL(tf_op, CPU, mlir_type, mlir_output_type, data_type)
+#define REGISTER_CPU_KERNEL(tf_op, input_type, output_type) \
+ REGISTER_KERNEL(tf_op, CPU, input_type, output_type)
-#define REGISTER_COMPLEX_CPU_KERNEL(tf_op, mlir_type, mlir_output_type, \
- data_type, input_data_type) \
- REGISTER_COMPLEX_KERNEL(tf_op, CPU, mlir_type, mlir_output_type, data_type, \
- input_data_type)
+#define REGISTER_COMPLEX_CPU_KERNEL(tf_op, input_type, output_type) \
+ REGISTER_COMPLEX_KERNEL(tf_op, CPU, input_type, output_type)
-#define REGISTER_CPU_KERNEL_NO_TYPE_CONSTRAINT(tf_op, mlir_type, \
- mlir_output_type) \
- REGISTER_KERNEL_NO_TYPE_CONSTRAINT(tf_op, CPU, mlir_type, mlir_output_type)
+#define REGISTER_CPU_KERNEL_NO_TYPE_CONSTRAINT(tf_op, input_type) \
+ REGISTER_KERNEL_NO_TYPE_CONSTRAINT(tf_op, CPU, input_type)
-#define GENERATE_AND_REGISTER_BINARY_CPU_KERNEL(tf_op, mlir_type, \
- tf_data_type, data_type) \
- GENERATE_AND_REGISTER_BINARY_KERNEL(tf_op, CPU, mlir_type, tf_data_type, \
- data_type)
+#define GENERATE_AND_REGISTER_BINARY_CPU_KERNEL(tf_op, input_type) \
+ GENERATE_AND_REGISTER_BINARY_KERNEL(tf_op, CPU, input_type)
-#define GENERATE_AND_REGISTER_BINARY_CPU_KERNEL2( \
- tf_op, mlir_type, mlir_output_type, tf_data_type, result_data_type, \
- input_data_type) \
- GENERATE_AND_REGISTER_BINARY_KERNEL2(tf_op, CPU, mlir_type, \
- mlir_output_type, tf_data_type, \
- result_data_type, input_data_type)
+#define GENERATE_AND_REGISTER_BINARY_CPU_KERNEL2(tf_op, input_type, \
+ output_type) \
+ GENERATE_AND_REGISTER_BINARY_KERNEL2(tf_op, CPU, input_type, output_type)
-#define GENERATE_BINARY_CPU_KERNEL(tf_op, mlir_type, tf_data_type, data_type) \
- GENERATE_BINARY_KERNEL(tf_op, CPU, mlir_type, tf_data_type, data_type)
+#define GENERATE_BINARY_CPU_KERNEL(tf_op, input_type) \
+ GENERATE_BINARY_KERNEL(tf_op, CPU, input_type)
-#define GENERATE_BINARY_CPU_KERNEL2(tf_op, mlir_type, mlir_output_type, \
- tf_data_type, result_data_type, \
- input_data_type) \
- GENERATE_BINARY_KERNEL2(tf_op, CPU, mlir_type, mlir_output_type, \
- tf_data_type, result_data_type, input_data_type)
+#define GENERATE_BINARY_CPU_KERNEL2(tf_op, input_type, output_type) \
+ GENERATE_BINARY_KERNEL2(tf_op, CPU, input_type, output_type)
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/base_gpu_op.h b/tensorflow/core/kernels/mlir_generated/base_gpu_op.h
index 941a029..ea67c83 100644
--- a/tensorflow/core/kernels/mlir_generated/base_gpu_op.h
+++ b/tensorflow/core/kernels/mlir_generated/base_gpu_op.h
@@ -20,57 +20,39 @@
namespace tensorflow {
-#define GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(tf_op, mlir_type, tf_data_type, \
- data_type) \
- GENERATE_AND_REGISTER_UNARY_KERNEL(tf_op, GPU, mlir_type, tf_data_type, \
- data_type)
+#define GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(tf_op, input_type) \
+ GENERATE_AND_REGISTER_UNARY_KERNEL(tf_op, GPU, input_type)
-#define GENERATE_UNARY_GPU_KERNEL(tf_op, mlir_type, tf_data_type, data_type) \
- GENERATE_UNARY_KERNEL(tf_op, GPU, mlir_type, tf_data_type, data_type)
+#define GENERATE_UNARY_GPU_KERNEL(tf_op, input_type) \
+ GENERATE_UNARY_KERNEL(tf_op, GPU, input_type)
-#define GENERATE_UNARY_GPU_KERNEL2(tf_op, mlir_type, mlir_output_type, \
- tf_data_type, result_data_type, \
- input_data_type) \
- GENERATE_UNARY_KERNEL2(tf_op, GPU, mlir_type, mlir_output_type, \
- tf_data_type, result_data_type, input_data_type)
+#define GENERATE_UNARY_GPU_KERNEL2(tf_op, input_type, output_type) \
+ GENERATE_UNARY_KERNEL2(tf_op, GPU, input_type, output_type)
-#define REGISTER_ALIASED_GPU_KERNEL(tf_op, mlir_op, mlir_type, \
- mlir_output_type, data_type) \
- REGISTER_ALIASED_KERNEL(tf_op, mlir_op, GPU, mlir_type, mlir_output_type, \
- data_type)
+#define REGISTER_ALIASED_GPU_KERNEL(tf_op, mlir_op, input_type, output_type) \
+ REGISTER_ALIASED_KERNEL(tf_op, mlir_op, GPU, input_type, output_type)
-#define REGISTER_GPU_KERNEL(tf_op, mlir_type, mlir_output_type, data_type) \
- REGISTER_KERNEL(tf_op, GPU, mlir_type, mlir_output_type, data_type)
+#define REGISTER_GPU_KERNEL(tf_op, input_type, output_type) \
+ REGISTER_KERNEL(tf_op, GPU, input_type, output_type)
-#define REGISTER_COMPLEX_GPU_KERNEL(tf_op, mlir_type, mlir_output_type, \
- data_type, input_data_type) \
- REGISTER_COMPLEX_KERNEL(tf_op, GPU, mlir_type, mlir_output_type, data_type, \
- input_data_type)
+#define REGISTER_COMPLEX_GPU_KERNEL(tf_op, input_type, output_type) \
+ REGISTER_COMPLEX_KERNEL(tf_op, GPU, input_type, output_type)
-#define REGISTER_GPU_KERNEL_NO_TYPE_CONSTRAINT(tf_op, mlir_type, \
- mlir_output_type) \
- REGISTER_KERNEL_NO_TYPE_CONSTRAINT(tf_op, GPU, mlir_type, mlir_output_type)
+#define REGISTER_GPU_KERNEL_NO_TYPE_CONSTRAINT(tf_op, input_type) \
+ REGISTER_KERNEL_NO_TYPE_CONSTRAINT(tf_op, GPU, input_type)
-#define GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(tf_op, mlir_type, \
- tf_data_type, data_type) \
- GENERATE_AND_REGISTER_BINARY_KERNEL(tf_op, GPU, mlir_type, tf_data_type, \
- data_type)
+#define GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(tf_op, input_type) \
+ GENERATE_AND_REGISTER_BINARY_KERNEL(tf_op, GPU, input_type)
-#define GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2( \
- tf_op, mlir_type, mlir_output_type, tf_data_type, result_data_type, \
- input_data_type) \
- GENERATE_AND_REGISTER_BINARY_KERNEL2(tf_op, GPU, mlir_type, \
- mlir_output_type, tf_data_type, \
- result_data_type, input_data_type)
+#define GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(tf_op, input_type, \
+ output_type) \
+ GENERATE_AND_REGISTER_BINARY_KERNEL2(tf_op, GPU, input_type, output_type)
-#define GENERATE_BINARY_GPU_KERNEL(tf_op, mlir_type, tf_data_type, data_type) \
- GENERATE_BINARY_KERNEL(tf_op, GPU, mlir_type, tf_data_type, data_type)
+#define GENERATE_BINARY_GPU_KERNEL(tf_op, input_type) \
+ GENERATE_BINARY_KERNEL(tf_op, GPU, input_type)
-#define GENERATE_BINARY_GPU_KERNEL2(tf_op, mlir_type, mlir_output_type, \
- tf_data_type, result_data_type, \
- input_data_type) \
- GENERATE_BINARY_KERNEL2(tf_op, GPU, mlir_type, mlir_output_type, \
- tf_data_type, result_data_type, input_data_type)
+#define GENERATE_BINARY_GPU_KERNEL2(tf_op, input_type, output_type) \
+ GENERATE_BINARY_KERNEL2(tf_op, GPU, input_type, output_type)
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/base_op.h b/tensorflow/core/kernels/mlir_generated/base_op.h
index bdd651b..6e7e9cb 100644
--- a/tensorflow/core/kernels/mlir_generated/base_op.h
+++ b/tensorflow/core/kernels/mlir_generated/base_op.h
@@ -107,6 +107,7 @@
input_descs.push_back(
std::move(ConvertTensorToDescriptor<InputDataType>(ctx->input(i))));
}
+ VLOG(4) << ctx->op_kernel().TraceString(*ctx, true);
auto result_desc = Kernel::Invoke(ctx, input_descs);
for (const auto& input_desc : input_descs) {
free(input_desc.descriptor);
@@ -142,116 +143,112 @@
}
};
-#define MLIR_FUNCTION(tf_op, platform, mlir_type, mlir_output_type) \
- _mlir_ciface_##tf_op##_##platform##_##mlir_type##_##mlir_output_type
+#define MLIR_FUNCTION(tf_op, platform, input_type, output_type) \
+ _mlir_ciface_##tf_op##_##platform##_##input_type##_##output_type
-#define REGISTER_ALIASED_KERNEL(tf_op, mlir_op, platform, mlir_type, \
- mlir_output_type, data_type) \
- REGISTER_KERNEL_BUILDER( \
- Name(#tf_op).Device(DEVICE_##platform).TypeConstraint<data_type>("T"), \
- Mlir##mlir_op##platform##mlir_type##mlir_output_type##Op);
+#define MLIR_OP(tf_op, platform, input_type, output_type) \
+ Mlir##tf_op##platform##input_type##output_type##Op
-#define REGISTER_KERNEL(tf_op, platform, mlir_type, mlir_output_type, \
- data_type) \
- REGISTER_ALIASED_KERNEL(tf_op, tf_op, platform, mlir_type, mlir_output_type, \
- data_type)
+#define REGISTER_ALIASED_KERNEL(tf_op, mlir_op, platform, input_type, \
+ output_type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name(#tf_op) \
+ .Device(DEVICE_##platform) \
+ .TypeConstraint<typename EnumToDataType<input_type>::Type>("T"), \
+ MLIR_OP(mlir_op, platform, input_type, output_type));
-#define REGISTER_COMPLEX_KERNEL(tf_op, platform, mlir_type, mlir_output_type, \
- data_type, input_data_type) \
- REGISTER_KERNEL_BUILDER( \
- Name(#tf_op) \
- .Device(DEVICE_##platform) \
- .TypeConstraint<input_data_type>("T") \
- .TypeConstraint<data_type>("Tout"), \
- Mlir##tf_op##platform##mlir_type##mlir_output_type##Op);
+#define REGISTER_KERNEL(tf_op, platform, input_type, output_type) \
+ REGISTER_ALIASED_KERNEL(tf_op, tf_op, platform, input_type, output_type)
-#define REGISTER_KERNEL_NO_TYPE_CONSTRAINT(tf_op, platform, mlir_type, \
- mlir_output_type) \
- REGISTER_KERNEL_BUILDER( \
- Name(#tf_op).Device(DEVICE_##platform), \
- Mlir##tf_op##platform##mlir_type##mlir_output_type##Op);
+#define REGISTER_COMPLEX_KERNEL(tf_op, platform, input_type, output_type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name(#tf_op) \
+ .Device(DEVICE_##platform) \
+ .TypeConstraint<typename EnumToDataType<input_type>::Type>("T") \
+ .TypeConstraint<typename EnumToDataType<output_type>::Type>("Tout"), \
+ MLIR_OP(tf_op, platform, input_type, output_type));
+
+#define REGISTER_KERNEL_NO_TYPE_CONSTRAINT(tf_op, platform, input_type) \
+ REGISTER_KERNEL_BUILDER(Name(#tf_op).Device(DEVICE_##platform), \
+ MLIR_OP(tf_op, platform, input_type, input_type));
// OpKernel with Compute function that converts input tensors to unranked
// memref descriptors and calls mlir-generated unranked kernel. The outputs
// are converted back to tensors using MlirTensorBuffer to take ownership of
// pre-allocated memory.
-#define GENERATE_AND_REGISTER_BINARY_KERNEL(tf_op, platform, mlir_type, \
- tf_data_type, data_type) \
- GENERATE_BINARY_KERNEL(tf_op, platform, mlir_type, tf_data_type, data_type) \
- REGISTER_KERNEL(tf_op, platform, mlir_type, mlir_type, data_type)
+#define GENERATE_AND_REGISTER_BINARY_KERNEL(tf_op, platform, input_type) \
+ GENERATE_BINARY_KERNEL(tf_op, platform, input_type) \
+ REGISTER_KERNEL(tf_op, platform, input_type, input_type)
-#define GENERATE_AND_REGISTER_BINARY_KERNEL2( \
- tf_op, platform, mlir_type, mlir_output_type, tf_data_type, \
- result_data_type, input_data_type) \
- GENERATE_BINARY_KERNEL2(tf_op, platform, mlir_type, mlir_output_type, \
- tf_data_type, result_data_type, input_data_type) \
- REGISTER_KERNEL(tf_op, platform, mlir_type, mlir_output_type, input_data_type)
+#define GENERATE_AND_REGISTER_BINARY_KERNEL2(tf_op, platform, input_type, \
+ output_type) \
+ GENERATE_BINARY_KERNEL2(tf_op, platform, input_type, output_type) \
+ REGISTER_KERNEL(tf_op, platform, input_type, output_type)
-#define GENERATE_BINARY_KERNEL(tf_op, platform, mlir_type, tf_data_type, \
- data_type) \
- GENERATE_BINARY_KERNEL2(tf_op, platform, mlir_type, mlir_type, tf_data_type, \
- data_type, data_type)
+#define GENERATE_BINARY_KERNEL(tf_op, platform, input_type) \
+ GENERATE_BINARY_KERNEL2(tf_op, platform, input_type, input_type)
-#define GENERATE_BINARY_KERNEL2(tf_op, platform, mlir_type, mlir_output_type, \
- tf_data_type, result_data_type, \
- input_data_type) \
- extern "C" UntypedUnrankedMemRefType MLIR_FUNCTION( \
- tf_op, platform, mlir_type, mlir_output_type)( \
- tensorflow::OpKernelContext * ctx, \
- const ::UnrankedMemRefType<input_data_type>* arg1, \
- const ::UnrankedMemRefType<input_data_type>* arg2); \
- \
- namespace { \
- class Mlir##tf_op##platform##mlir_type##mlir_output_type##Op \
- : public MlirOp<tf_data_type, result_data_type, \
- Mlir##tf_op##platform##mlir_type##mlir_output_type##Op, \
- input_data_type> { \
- public: \
- using MlirOp::MlirOp; \
- \
- static ::UnrankedMemRefType<result_data_type> Invoke( \
- OpKernelContext* ctx, \
- llvm::ArrayRef<::UnrankedMemRefType<input_data_type>> args) { \
- return ConvertToTyped<result_data_type>( \
- MLIR_FUNCTION(tf_op, platform, mlir_type, mlir_output_type)( \
- ctx, &args[0], &args[1])); \
- } \
- }; \
+#define GENERATE_BINARY_KERNEL2(tf_op, platform, input_type, output_type) \
+ extern "C" UntypedUnrankedMemRefType MLIR_FUNCTION(tf_op, platform, \
+ input_type, output_type)( \
+ tensorflow::OpKernelContext * ctx, \
+ const ::UnrankedMemRefType<typename EnumToDataType<input_type>::Type>* \
+ arg1, \
+ const ::UnrankedMemRefType<typename EnumToDataType<input_type>::Type>* \
+ arg2); \
+ \
+ namespace { \
+ class MLIR_OP(tf_op, platform, input_type, output_type) \
+ : public MlirOp<output_type, typename EnumToDataType<output_type>::Type, \
+ MLIR_OP(tf_op, platform, input_type, output_type), \
+ typename EnumToDataType<input_type>::Type> { \
+ public: \
+ using MlirOp::MlirOp; \
+ using ResultDataType = EnumToDataType<output_type>::Type; \
+ \
+ static ::UnrankedMemRefType<ResultDataType> Invoke( \
+ OpKernelContext* ctx, \
+ llvm::ArrayRef< \
+ ::UnrankedMemRefType<typename EnumToDataType<input_type>::Type>> \
+ args) { \
+ return ConvertToTyped<ResultDataType>(MLIR_FUNCTION( \
+ tf_op, platform, input_type, output_type)(ctx, &args[0], &args[1])); \
+ } \
+ }; \
}
-#define GENERATE_AND_REGISTER_UNARY_KERNEL(tf_op, platform, mlir_type, \
- tf_data_type, data_type) \
- GENERATE_UNARY_KERNEL(tf_op, platform, mlir_type, tf_data_type, data_type) \
- REGISTER_KERNEL(tf_op, platform, mlir_type, mlir_type, data_type)
+#define GENERATE_AND_REGISTER_UNARY_KERNEL(tf_op, platform, input_type) \
+ GENERATE_UNARY_KERNEL(tf_op, platform, input_type) \
+ REGISTER_KERNEL(tf_op, platform, input_type, input_type)
-#define GENERATE_UNARY_KERNEL(tf_op, platform, mlir_type, tf_data_type, \
- data_type) \
- GENERATE_UNARY_KERNEL2(tf_op, platform, mlir_type, mlir_type, tf_data_type, \
- data_type, data_type)
+#define GENERATE_UNARY_KERNEL(tf_op, platform, input_type) \
+ GENERATE_UNARY_KERNEL2(tf_op, platform, input_type, input_type)
-#define GENERATE_UNARY_KERNEL2(tf_op, platform, mlir_type, mlir_output_type, \
- tf_data_type, result_data_type, \
- input_data_type) \
- extern "C" UntypedUnrankedMemRefType MLIR_FUNCTION( \
- tf_op, platform, mlir_type, mlir_output_type)( \
- tensorflow::OpKernelContext * ctx, \
- const ::UnrankedMemRefType<input_data_type>* arg); \
- \
- namespace { \
- class Mlir##tf_op##platform##mlir_type##mlir_output_type##Op \
- : public MlirOp<tf_data_type, result_data_type, \
- Mlir##tf_op##platform##mlir_type##mlir_output_type##Op, \
- input_data_type> { \
- public: \
- using MlirOp::MlirOp; \
- \
- static ::UnrankedMemRefType<result_data_type> Invoke( \
- OpKernelContext* ctx, \
- llvm::ArrayRef<::UnrankedMemRefType<input_data_type>> args) { \
- return ConvertToTyped<result_data_type>(MLIR_FUNCTION( \
- tf_op, platform, mlir_type, mlir_output_type)(ctx, &args[0])); \
- } \
- }; \
+#define GENERATE_UNARY_KERNEL2(tf_op, platform, input_type, output_type) \
+ extern "C" UntypedUnrankedMemRefType MLIR_FUNCTION(tf_op, platform, \
+ input_type, output_type)( \
+ tensorflow::OpKernelContext * ctx, \
+ const ::UnrankedMemRefType<typename EnumToDataType<input_type>::Type>* \
+ arg); \
+ \
+ namespace { \
+ class MLIR_OP(tf_op, platform, input_type, output_type) \
+ : public MlirOp<output_type, typename EnumToDataType<output_type>::Type, \
+ MLIR_OP(tf_op, platform, input_type, output_type), \
+ typename EnumToDataType<input_type>::Type> { \
+ public: \
+ using MlirOp::MlirOp; \
+ using ResultDataType = EnumToDataType<output_type>::Type; \
+ \
+ static ::UnrankedMemRefType<ResultDataType> Invoke( \
+ OpKernelContext* ctx, \
+ llvm::ArrayRef< \
+ ::UnrankedMemRefType<typename EnumToDataType<input_type>::Type>> \
+ args) { \
+ return ConvertToTyped<ResultDataType>(MLIR_FUNCTION( \
+ tf_op, platform, input_type, output_type)(ctx, &args[0])); \
+ } \
+ }; \
}
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/base_ops_test.h b/tensorflow/core/kernels/mlir_generated/base_ops_test.h
index a728d64..a048f74 100644
--- a/tensorflow/core/kernels/mlir_generated/base_ops_test.h
+++ b/tensorflow/core/kernels/mlir_generated/base_ops_test.h
@@ -63,6 +63,8 @@
// Negative atol/rtol will make ExpectClose use the default.
double atol = -1;
double rtol = -1;
+ std::string input_attribute = "T";
+ std::string output_attribute = "Tout";
OpsTestConfig ExpectStrictlyEqual() {
OpsTestConfig config = *this;
config.expect_strictly_equal = true;
@@ -93,6 +95,16 @@
config.atol = new_atol;
return config;
}
+ OpsTestConfig InputAttribute(const std::string& attr) {
+ OpsTestConfig config = *this;
+ config.input_attribute = attr;
+ return config;
+ }
+ OpsTestConfig OutputAttribute(const std::string& attr) {
+ OpsTestConfig config = *this;
+ config.output_attribute = attr;
+ return config;
+ }
};
/// Helper functions to get more specific input data.
diff --git a/tensorflow/core/kernels/mlir_generated/base_unary_ops_test.h b/tensorflow/core/kernels/mlir_generated/base_unary_ops_test.h
index 91c7148..fa996eb 100644
--- a/tensorflow/core/kernels/mlir_generated/base_unary_ops_test.h
+++ b/tensorflow/core/kernels/mlir_generated/base_unary_ops_test.h
@@ -37,15 +37,15 @@
template <typename T, typename OutT>
void SetOpKernel(const std::string& op_name, const TensorShape& shape,
- const absl::InlinedVector<T, 10>& input, bool add_t,
- bool add_tout) {
+ const absl::InlinedVector<T, 10>& input,
+ const test::OpsTestConfig& config) {
NodeDefBuilder builder("some_name", op_name);
builder.Input(FakeInput(DataTypeToEnum<T>::v()));
- if (add_t) {
- builder.Attr("T", DataTypeToEnum<T>::v());
+ if (config.add_t) {
+ builder.Attr(config.input_attribute, DataTypeToEnum<T>::v());
}
- if (add_tout) {
- builder.Attr("Tout", DataTypeToEnum<OutT>::v());
+ if (config.add_tout) {
+ builder.Attr(config.output_attribute, DataTypeToEnum<OutT>::v());
}
TF_ASSERT_OK(builder.Finalize(node_def()));
@@ -58,7 +58,7 @@
const absl::InlinedVector<T, 10>& input,
const absl::InlinedVector<OutT, 10>& expected_output,
const test::OpsTestConfig& config) {
- SetOpKernel<T, OutT>(op_name, shape, input, config.add_t, config.add_tout);
+ SetOpKernel<T, OutT>(op_name, shape, input, config);
TF_ASSERT_OK(RunOpKernel());
// Assert buffer reuse if expected.
@@ -147,7 +147,7 @@
#define GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2( \
op_name, InT, BaselineT, OutT, BaselineOutT, input_values, \
baseline_callback, config) \
- TEST_F(UnaryOpsTest, op_name##InT) { \
+ TEST_F(UnaryOpsTest, op_name##InT##OutT) { \
using NativeT = EnumToDataType<InT>::Type; \
using NativeBaselineT = EnumToDataType<BaselineT>::Type; \
using NativeOutT = EnumToDataType<OutT>::Type; \
@@ -156,7 +156,7 @@
#op_name, test::DefaultInputShape(), input_values, baseline_callback, \
config); \
} \
- TEST_F(UnaryOpsTest, op_name##InT##EmptyShape) { \
+ TEST_F(UnaryOpsTest, op_name##InT##OutT##EmptyShape) { \
using NativeT = EnumToDataType<InT>::Type; \
using NativeOutT = EnumToDataType<OutT>::Type; \
TestEmptyShape<NativeT, NativeOutT>(#op_name, config); \
diff --git a/tensorflow/core/kernels/mlir_generated/build_defs.bzl b/tensorflow/core/kernels/mlir_generated/build_defs.bzl
index a096525..951f329 100644
--- a/tensorflow/core/kernels/mlir_generated/build_defs.bzl
+++ b/tensorflow/core/kernels/mlir_generated/build_defs.bzl
@@ -35,6 +35,19 @@
"c128": "complex<f64>",
}
+type_to_tf_dtype = {
+ "i1": "DT_BOOL",
+ "i8": "DT_INT8",
+ "i16": "DT_INT16",
+ "i32": "DT_INT32",
+ "i64": "DT_INT64",
+ "f16": "DT_HALF",
+ "f32": "DT_FLOAT",
+ "f64": "DT_DOUBLE",
+ "c64": "DT_COMPLEX64",
+ "c128": "DT_COMPLEX128",
+}
+
def _get_mlir_type(type):
"""Return the mlir type corresponding to 'type'"""
if type in type_to_mlir:
@@ -54,9 +67,9 @@
"sed 's/output_type/%s/g' > %s")) % (
ctx.file.template.path,
ctx.attr.platform.upper(),
- ctx.attr.type,
+ type_to_tf_dtype[ctx.attr.type],
mlir_type,
- ctx.attr.output_type,
+ type_to_tf_dtype[ctx.attr.output_type],
mlir_output_type,
ctx.outputs.out.path,
)
diff --git a/tensorflow/core/kernels/mlir_generated/cpu_op_abs.cc b/tensorflow/core/kernels/mlir_generated/cpu_op_abs.cc
index efe98dc..69ba49d 100644
--- a/tensorflow/core/kernels/mlir_generated/cpu_op_abs.cc
+++ b/tensorflow/core/kernels/mlir_generated/cpu_op_abs.cc
@@ -18,13 +18,13 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_UNARY_CPU_KERNEL(Abs, f16, DT_HALF, Eigen::half);
-GENERATE_AND_REGISTER_UNARY_CPU_KERNEL(Abs, f64, DT_DOUBLE, double);
-GENERATE_AND_REGISTER_UNARY_CPU_KERNEL(Abs, f32, DT_FLOAT, float);
+GENERATE_AND_REGISTER_UNARY_CPU_KERNEL(Abs, DT_HALF);
+GENERATE_AND_REGISTER_UNARY_CPU_KERNEL(Abs, DT_DOUBLE);
+GENERATE_AND_REGISTER_UNARY_CPU_KERNEL(Abs, DT_FLOAT);
-GENERATE_AND_REGISTER_UNARY_CPU_KERNEL(Abs, i8, DT_INT8, int8);
-GENERATE_AND_REGISTER_UNARY_CPU_KERNEL(Abs, i16, DT_INT16, int16);
-GENERATE_AND_REGISTER_UNARY_CPU_KERNEL(Abs, i32, DT_INT32, int32);
-GENERATE_AND_REGISTER_UNARY_CPU_KERNEL(Abs, i64, DT_INT64, int64);
+GENERATE_AND_REGISTER_UNARY_CPU_KERNEL(Abs, DT_INT8);
+GENERATE_AND_REGISTER_UNARY_CPU_KERNEL(Abs, DT_INT16);
+GENERATE_AND_REGISTER_UNARY_CPU_KERNEL(Abs, DT_INT32);
+GENERATE_AND_REGISTER_UNARY_CPU_KERNEL(Abs, DT_INT64);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/cpu_op_add.cc b/tensorflow/core/kernels/mlir_generated/cpu_op_add.cc
index c1d064e..7df2e59 100644
--- a/tensorflow/core/kernels/mlir_generated/cpu_op_add.cc
+++ b/tensorflow/core/kernels/mlir_generated/cpu_op_add.cc
@@ -17,10 +17,10 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_BINARY_CPU_KERNEL(AddV2, f16, DT_HALF, Eigen::half);
-GENERATE_AND_REGISTER_BINARY_CPU_KERNEL(AddV2, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_BINARY_CPU_KERNEL(AddV2, f64, DT_DOUBLE, double);
-GENERATE_AND_REGISTER_BINARY_CPU_KERNEL(AddV2, i32, DT_INT32, int32);
-GENERATE_AND_REGISTER_BINARY_CPU_KERNEL(AddV2, i64, DT_INT64, int64);
+GENERATE_AND_REGISTER_BINARY_CPU_KERNEL(AddV2, DT_HALF);
+GENERATE_AND_REGISTER_BINARY_CPU_KERNEL(AddV2, DT_FLOAT);
+GENERATE_AND_REGISTER_BINARY_CPU_KERNEL(AddV2, DT_DOUBLE);
+GENERATE_AND_REGISTER_BINARY_CPU_KERNEL(AddV2, DT_INT32);
+GENERATE_AND_REGISTER_BINARY_CPU_KERNEL(AddV2, DT_INT64);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_abs.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_abs.cc
index 8036ba1..9a75ce2 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_abs.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_abs.cc
@@ -18,10 +18,10 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Abs, f16, DT_HALF, Eigen::half);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Abs, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Abs, f64, DT_DOUBLE, double);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Abs, DT_HALF);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Abs, DT_FLOAT);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Abs, DT_DOUBLE);
// TODO(b/25387198): Add an int32 kernel.
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Abs, i64, DT_INT64, int64);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Abs, DT_INT64);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_acos.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_acos.cc
index defb02d..ef5ec94 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_acos.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_acos.cc
@@ -18,7 +18,7 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Acos, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Acos, f64, DT_DOUBLE, double);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Acos, DT_FLOAT);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Acos, DT_DOUBLE);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_acosh.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_acosh.cc
index 29cf490..ed990b8 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_acosh.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_acosh.cc
@@ -18,7 +18,7 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Acosh, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Acosh, f64, DT_DOUBLE, double);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Acosh, DT_FLOAT);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Acosh, DT_DOUBLE);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_add.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_add.cc
index b372c1f..c946168 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_add.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_add.cc
@@ -17,15 +17,15 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(AddV2, f16, DT_HALF, Eigen::half);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(AddV2, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(AddV2, f64, DT_DOUBLE, double);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(AddV2, i64, DT_INT64, int64);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(AddV2, DT_HALF);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(AddV2, DT_FLOAT);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(AddV2, DT_DOUBLE);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(AddV2, DT_INT64);
// Add is the same as AddV2 except for strings, which we do not support on gpu.
-REGISTER_ALIASED_GPU_KERNEL(Add, AddV2, f16, f16, Eigen::half);
-REGISTER_ALIASED_GPU_KERNEL(Add, AddV2, f32, f32, float);
-REGISTER_ALIASED_GPU_KERNEL(Add, AddV2, f64, f64, double);
-REGISTER_ALIASED_GPU_KERNEL(Add, AddV2, i64, i64, int64);
+REGISTER_ALIASED_GPU_KERNEL(Add, AddV2, DT_HALF, DT_HALF);
+REGISTER_ALIASED_GPU_KERNEL(Add, AddV2, DT_FLOAT, DT_FLOAT);
+REGISTER_ALIASED_GPU_KERNEL(Add, AddV2, DT_DOUBLE, DT_DOUBLE);
+REGISTER_ALIASED_GPU_KERNEL(Add, AddV2, DT_INT64, DT_INT64);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_angle.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_angle.cc
index cda99e0..55da943 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_angle.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_angle.cc
@@ -20,11 +20,9 @@
namespace tensorflow {
-GENERATE_UNARY_GPU_KERNEL2(Angle, c64, f32, DT_FLOAT, float,
- std::complex<float>);
-REGISTER_COMPLEX_GPU_KERNEL(Angle, c64, f32, float, std::complex<float>);
-GENERATE_UNARY_GPU_KERNEL2(Angle, c128, f64, DT_DOUBLE, double,
- std::complex<double>);
-REGISTER_COMPLEX_GPU_KERNEL(Angle, c128, f64, double, std::complex<double>);
+GENERATE_UNARY_GPU_KERNEL2(Angle, DT_COMPLEX64, DT_FLOAT);
+REGISTER_COMPLEX_GPU_KERNEL(Angle, DT_COMPLEX64, DT_FLOAT);
+GENERATE_UNARY_GPU_KERNEL2(Angle, DT_COMPLEX128, DT_DOUBLE);
+REGISTER_COMPLEX_GPU_KERNEL(Angle, DT_COMPLEX128, DT_DOUBLE);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_asin.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_asin.cc
index 95e1cc0..953c59d 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_asin.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_asin.cc
@@ -18,7 +18,7 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Asin, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Asin, f64, DT_DOUBLE, double);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Asin, DT_FLOAT);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Asin, DT_DOUBLE);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_asinh.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_asinh.cc
index a5512c1..3b9e7bb 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_asinh.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_asinh.cc
@@ -18,7 +18,7 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Asinh, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Asinh, f64, DT_DOUBLE, double);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Asinh, DT_FLOAT);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Asinh, DT_DOUBLE);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_atan.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_atan.cc
index a5a39ba..132d75a 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_atan.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_atan.cc
@@ -18,7 +18,7 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Atan, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Atan, f64, DT_DOUBLE, double);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Atan, DT_FLOAT);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Atan, DT_DOUBLE);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_atan2.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_atan2.cc
index 4b74f47..22414a8 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_atan2.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_atan2.cc
@@ -17,7 +17,7 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Atan2, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Atan2, f64, DT_DOUBLE, double);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Atan2, DT_FLOAT);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Atan2, DT_DOUBLE);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_atanh.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_atanh.cc
index 9fbc79b..e0ea7c9 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_atanh.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_atanh.cc
@@ -18,7 +18,7 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Atanh, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Atanh, f64, DT_DOUBLE, double);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Atanh, DT_FLOAT);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Atanh, DT_DOUBLE);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_bitwise_and.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_bitwise_and.cc
index 7ba4c16..71e6929 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_bitwise_and.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_bitwise_and.cc
@@ -17,15 +17,15 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseAnd, i8, DT_INT8, int8);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseAnd, i16, DT_INT16, int16);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseAnd, i32, DT_INT32, int32);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseAnd, i64, DT_INT64, int64);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseAnd, DT_INT8);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseAnd, DT_INT16);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseAnd, DT_INT32);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseAnd, DT_INT64);
// TODO(b/172804967): Enable once fixed.
-// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseAnd, ui8, DT_UINT8, uint8);
-// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseAnd, ui16, DT_UINT16, uint16);
-// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseAnd, ui32, DT_UINT32, uint32);
-// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseAnd, ui64, DT_UINT64, uint64);
+// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseAnd, DT_UINT8);
+// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseAnd, DT_UINT16);
+// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseAnd, DT_UINT32);
+// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseAnd, DT_UINT64);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_bitwise_or.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_bitwise_or.cc
index e237b5c..791896e 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_bitwise_or.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_bitwise_or.cc
@@ -17,15 +17,15 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseOr, i8, DT_INT8, int8);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseOr, i16, DT_INT16, int16);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseOr, i32, DT_INT32, int32);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseOr, i64, DT_INT64, int64);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseOr, DT_INT8);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseOr, DT_INT16);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseOr, DT_INT32);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseOr, DT_INT64);
// TODO(b/172804967): Enable once fixed.
-// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseOr, ui8, DT_UINT8, uint8);
-// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseOr, ui16, DT_UINT16, uint16);
-// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseOr, ui32, DT_UINT32, uint32);
-// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseOr, ui64, DT_UINT64, uint64);
+// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseOr, DT_UINT8);
+// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseOr, DT_UINT16);
+// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseOr, DT_UINT32);
+// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseOr, DT_UINT64);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_bitwise_xor.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_bitwise_xor.cc
index d836d48..593b539 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_bitwise_xor.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_bitwise_xor.cc
@@ -17,15 +17,15 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseXor, i8, DT_INT8, int8);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseXor, i16, DT_INT16, int16);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseXor, i32, DT_INT32, int32);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseXor, i64, DT_INT64, int64);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseXor, DT_INT8);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseXor, DT_INT16);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseXor, DT_INT32);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseXor, DT_INT64);
// TODO(b/172804967): Enable once fixed.
-// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseXor, ui8, DT_UINT8, uint8);
-// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseXor, ui16, DT_UINT16, uint16);
-// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseXor, ui32, DT_UINT32, uint32);
-// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseXor, ui64, DT_UINT64, uint64);
+// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseXor, DT_UINT8);
+// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseXor, DT_UINT16);
+// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseXor, DT_UINT32);
+// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseXor, DT_UINT64);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_cast.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_cast.cc
new file mode 100644
index 0000000..d410224
--- /dev/null
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_cast.cc
@@ -0,0 +1,52 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
+
+namespace tensorflow {
+
+#define CURRY_TYPES(FN, arg0) \
+ FN(arg0, DT_BOOL); \
+ FN(arg0, DT_INT8); \
+ FN(arg0, DT_INT16); \
+ FN(arg0, DT_INT32); \
+ FN(arg0, DT_INT64); \
+ FN(arg0, DT_HALF); \
+ FN(arg0, DT_FLOAT); \
+ FN(arg0, DT_DOUBLE)
+
+#define GENERATE_AND_REGISTER_CAST_GPU(input_type, output_type) \
+ GENERATE_UNARY_GPU_KERNEL2(Cast, input_type, output_type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Cast") \
+ .TypeConstraint<typename EnumToDataType<input_type>::Type>("SrcT") \
+ .TypeConstraint<typename EnumToDataType<output_type>::Type>("DstT") \
+ .Device(DEVICE_GPU), \
+ MLIR_OP(Cast, GPU, input_type, output_type))
+
+CURRY_TYPES(GENERATE_AND_REGISTER_CAST_GPU, DT_BOOL)
+CURRY_TYPES(GENERATE_AND_REGISTER_CAST_GPU, DT_INT8)
+CURRY_TYPES(GENERATE_AND_REGISTER_CAST_GPU, DT_INT16)
+CURRY_TYPES(GENERATE_AND_REGISTER_CAST_GPU, DT_INT32)
+CURRY_TYPES(GENERATE_AND_REGISTER_CAST_GPU, DT_INT64)
+CURRY_TYPES(GENERATE_AND_REGISTER_CAST_GPU, DT_HALF)
+CURRY_TYPES(GENERATE_AND_REGISTER_CAST_GPU, DT_FLOAT)
+CURRY_TYPES(GENERATE_AND_REGISTER_CAST_GPU, DT_DOUBLE)
+
+#undef REGISTER_CAST_GPU
+#undef CURRY_TYPES
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_ceil.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_ceil.cc
index 3e2767e..fd223cb 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_ceil.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_ceil.cc
@@ -18,8 +18,8 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Ceil, f16, DT_HALF, Eigen::half);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Ceil, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Ceil, f64, DT_DOUBLE, double);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Ceil, DT_HALF);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Ceil, DT_FLOAT);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Ceil, DT_DOUBLE);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_complex.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_complex.cc
index af94664..36526a3 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_complex.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_complex.cc
@@ -20,11 +20,9 @@
namespace tensorflow {
-GENERATE_BINARY_GPU_KERNEL2(Complex, f32, c64, DT_COMPLEX64,
- std::complex<float>, float);
-REGISTER_COMPLEX_GPU_KERNEL(Complex, f32, c64, std::complex<float>, float);
-GENERATE_BINARY_GPU_KERNEL2(Complex, f64, c128, DT_COMPLEX128,
- std::complex<double>, double);
-REGISTER_COMPLEX_GPU_KERNEL(Complex, f64, c128, std::complex<double>, double);
+GENERATE_BINARY_GPU_KERNEL2(Complex, DT_FLOAT, DT_COMPLEX64);
+REGISTER_COMPLEX_GPU_KERNEL(Complex, DT_FLOAT, DT_COMPLEX64);
+GENERATE_BINARY_GPU_KERNEL2(Complex, DT_DOUBLE, DT_COMPLEX128);
+REGISTER_COMPLEX_GPU_KERNEL(Complex, DT_DOUBLE, DT_COMPLEX128);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_complex_abs.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_complex_abs.cc
index 4c69dfc..69f45d5 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_complex_abs.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_complex_abs.cc
@@ -20,12 +20,9 @@
namespace tensorflow {
-GENERATE_UNARY_GPU_KERNEL2(ComplexAbs, c64, f32, DT_FLOAT, float,
- std::complex<float>);
-REGISTER_COMPLEX_GPU_KERNEL(ComplexAbs, c64, f32, float, std::complex<float>);
-GENERATE_UNARY_GPU_KERNEL2(ComplexAbs, c128, f64, DT_DOUBLE, double,
- std::complex<double>);
-REGISTER_COMPLEX_GPU_KERNEL(ComplexAbs, c128, f64, double,
- std::complex<double>);
+GENERATE_UNARY_GPU_KERNEL2(ComplexAbs, DT_COMPLEX64, DT_FLOAT);
+REGISTER_COMPLEX_GPU_KERNEL(ComplexAbs, DT_COMPLEX64, DT_FLOAT);
+GENERATE_UNARY_GPU_KERNEL2(ComplexAbs, DT_COMPLEX128, DT_DOUBLE);
+REGISTER_COMPLEX_GPU_KERNEL(ComplexAbs, DT_COMPLEX128, DT_DOUBLE);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_conj.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_conj.cc
index 375a575..755cc49 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_conj.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_conj.cc
@@ -20,9 +20,7 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Conj, c64, DT_COMPLEX64,
- std::complex<float>);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Conj, c128, DT_COMPLEX128,
- std::complex<double>);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Conj, DT_COMPLEX64);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Conj, DT_COMPLEX128);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_cos.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_cos.cc
index a9270dd..1a716a7 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_cos.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_cos.cc
@@ -18,8 +18,8 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Cos, f16, DT_HALF, Eigen::half);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Cos, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Cos, f64, DT_DOUBLE, double);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Cos, DT_HALF);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Cos, DT_FLOAT);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Cos, DT_DOUBLE);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_cosh.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_cosh.cc
index 86a8f7e..3f94cc6 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_cosh.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_cosh.cc
@@ -18,7 +18,7 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Cosh, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Cosh, f64, DT_DOUBLE, double);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Cosh, DT_FLOAT);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Cosh, DT_DOUBLE);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_digamma.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_digamma.cc
index 3a90e8e..bcba969 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_digamma.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_digamma.cc
@@ -18,8 +18,8 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Digamma, f16, DT_HALF, Eigen::half);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Digamma, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Digamma, f64, DT_DOUBLE, double);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Digamma, DT_HALF);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Digamma, DT_FLOAT);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Digamma, DT_DOUBLE);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_div.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_div.cc
index c5c03fe..abd38d7 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_div.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_div.cc
@@ -18,17 +18,17 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Div, f16, DT_HALF, Eigen::half);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Div, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Div, f64, DT_DOUBLE, double);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Div, i16, DT_INT16, int16);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Div, i64, DT_INT64, int64);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Div, DT_HALF);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Div, DT_FLOAT);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Div, DT_DOUBLE);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Div, DT_INT16);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Div, DT_INT64);
-REGISTER_ALIASED_GPU_KERNEL(RealDiv, Div, f16, f16, Eigen::half)
-REGISTER_ALIASED_GPU_KERNEL(RealDiv, Div, f32, f32, float)
-REGISTER_ALIASED_GPU_KERNEL(RealDiv, Div, f64, f64, double)
+REGISTER_ALIASED_GPU_KERNEL(RealDiv, Div, DT_HALF, DT_HALF);
+REGISTER_ALIASED_GPU_KERNEL(RealDiv, Div, DT_FLOAT, DT_FLOAT);
+REGISTER_ALIASED_GPU_KERNEL(RealDiv, Div, DT_DOUBLE, DT_DOUBLE);
-REGISTER_ALIASED_GPU_KERNEL(TruncateDiv, Div, i16, i16, int16)
-REGISTER_ALIASED_GPU_KERNEL(TruncateDiv, Div, i64, i64, int64)
+REGISTER_ALIASED_GPU_KERNEL(TruncateDiv, Div, DT_INT16, DT_INT16);
+REGISTER_ALIASED_GPU_KERNEL(TruncateDiv, Div, DT_INT64, DT_INT64);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_equal.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_equal.cc
index f408a04..88f1d50 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_equal.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_equal.cc
@@ -20,14 +20,13 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Equal, f16, i1, DT_BOOL, bool,
- Eigen::half);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Equal, f32, i1, DT_BOOL, bool, float);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Equal, f64, i1, DT_BOOL, bool, double);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Equal, i1, i1, DT_BOOL, bool, bool);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Equal, i8, i1, DT_BOOL, bool, int8);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Equal, i16, i1, DT_BOOL, bool, int16);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Equal, DT_HALF, DT_BOOL);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Equal, DT_FLOAT, DT_BOOL);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Equal, DT_DOUBLE, DT_BOOL);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Equal, DT_BOOL, DT_BOOL);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Equal, DT_INT8, DT_BOOL);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Equal, DT_INT16, DT_BOOL);
// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Equal, i64, i1, DT_BOOL, bool, int64);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Equal, DT_INT64, DT_BOOL);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_erf.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_erf.cc
index 61cc5ee..9c2cb44 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_erf.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_erf.cc
@@ -18,8 +18,8 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Erf, f16, DT_HALF, Eigen::half);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Erf, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Erf, f64, DT_DOUBLE, double);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Erf, DT_HALF);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Erf, DT_FLOAT);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Erf, DT_DOUBLE);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_erfc.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_erfc.cc
index ebe6bba..2d5059c 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_erfc.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_erfc.cc
@@ -18,8 +18,8 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Erfc, f64, DT_DOUBLE, double);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Erfc, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Erfc, f16, DT_HALF, Eigen::half);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Erfc, DT_DOUBLE);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Erfc, DT_FLOAT);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Erfc, DT_HALF);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_exp.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_exp.cc
index 24d04e9..ab40a8d 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_exp.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_exp.cc
@@ -18,8 +18,8 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Exp, f16, DT_HALF, Eigen::half);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Exp, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Exp, f64, DT_DOUBLE, double);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Exp, DT_HALF);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Exp, DT_FLOAT);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Exp, DT_DOUBLE);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_expm1.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_expm1.cc
index 7695c20..ce05e0f 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_expm1.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_expm1.cc
@@ -18,8 +18,8 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Expm1, f16, DT_HALF, Eigen::half);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Expm1, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Expm1, f64, DT_DOUBLE, double);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Expm1, DT_HALF);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Expm1, DT_FLOAT);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Expm1, DT_DOUBLE);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_floor.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_floor.cc
index 804b806..4e60987 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_floor.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_floor.cc
@@ -18,8 +18,8 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Floor, f16, DT_HALF, Eigen::half);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Floor, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Floor, f64, DT_DOUBLE, double);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Floor, DT_HALF);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Floor, DT_FLOAT);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Floor, DT_DOUBLE);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_floor_div.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_floor_div.cc
index 2f8c23d..db0df1f 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_floor_div.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_floor_div.cc
@@ -17,8 +17,8 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(FloorDiv, f16, DT_HALF, Eigen::half);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(FloorDiv, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(FloorDiv, f64, DT_DOUBLE, double);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(FloorDiv, DT_HALF);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(FloorDiv, DT_FLOAT);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(FloorDiv, DT_DOUBLE);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_greater.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_greater.cc
index bbf8278..83f3347 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_greater.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_greater.cc
@@ -20,17 +20,12 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Greater, f16, i1, DT_BOOL, bool,
- Eigen::half);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Greater, f32, i1, DT_BOOL, bool,
- float);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Greater, f64, i1, DT_BOOL, bool,
- double);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Greater, i8, i1, DT_BOOL, bool, int8);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Greater, i16, i1, DT_BOOL, bool,
- int16);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Greater, DT_HALF, DT_BOOL);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Greater, DT_FLOAT, DT_BOOL);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Greater, DT_DOUBLE, DT_BOOL);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Greater, DT_INT8, DT_BOOL);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Greater, DT_INT16, DT_BOOL);
// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Greater, i64, i1, DT_BOOL, bool,
- int64);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Greater, DT_INT64, DT_BOOL);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_greater_equal.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_greater_equal.cc
index 801eb43..cd082b5 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_greater_equal.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_greater_equal.cc
@@ -20,18 +20,12 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(GreaterEqual, f16, i1, DT_BOOL, bool,
- Eigen::half);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(GreaterEqual, f32, i1, DT_BOOL, bool,
- float);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(GreaterEqual, f64, i1, DT_BOOL, bool,
- double);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(GreaterEqual, i8, i1, DT_BOOL, bool,
- int8);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(GreaterEqual, i16, i1, DT_BOOL, bool,
- int16);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(GreaterEqual, DT_HALF, DT_BOOL);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(GreaterEqual, DT_FLOAT, DT_BOOL);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(GreaterEqual, DT_DOUBLE, DT_BOOL);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(GreaterEqual, DT_INT8, DT_BOOL);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(GreaterEqual, DT_INT16, DT_BOOL);
// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(GreaterEqual, i64, i1, DT_BOOL, bool,
- int64);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(GreaterEqual, DT_INT64, DT_BOOL);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_imag.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_imag.cc
index f2fc6be..1c0f210 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_imag.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_imag.cc
@@ -20,11 +20,9 @@
namespace tensorflow {
-GENERATE_UNARY_GPU_KERNEL2(Imag, c64, f32, DT_FLOAT, float,
- std::complex<float>);
-REGISTER_COMPLEX_GPU_KERNEL(Imag, c64, f32, float, std::complex<float>);
-GENERATE_UNARY_GPU_KERNEL2(Imag, c128, f64, DT_DOUBLE, double,
- std::complex<double>);
-REGISTER_COMPLEX_GPU_KERNEL(Imag, c128, f64, double, std::complex<double>);
+GENERATE_UNARY_GPU_KERNEL2(Imag, DT_COMPLEX64, DT_FLOAT);
+REGISTER_COMPLEX_GPU_KERNEL(Imag, DT_COMPLEX64, DT_FLOAT);
+GENERATE_UNARY_GPU_KERNEL2(Imag, DT_COMPLEX128, DT_DOUBLE);
+REGISTER_COMPLEX_GPU_KERNEL(Imag, DT_COMPLEX128, DT_DOUBLE);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_invert.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_invert.cc
index 5b7c37f..3c4f4ac 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_invert.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_invert.cc
@@ -20,9 +20,9 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Invert, i8, DT_INT8, int8);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Invert, i16, DT_INT16, int16);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Invert, i32, DT_INT32, int32);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Invert, i64, DT_INT64, int64);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Invert, DT_INT8);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Invert, DT_INT16);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Invert, DT_INT32);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Invert, DT_INT64);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_is_finite.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_is_finite.cc
index 1eb46c8..7f4c4e1 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_is_finite.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_is_finite.cc
@@ -20,11 +20,11 @@
namespace tensorflow {
-GENERATE_UNARY_GPU_KERNEL2(IsFinite, f16, i1, DT_BOOL, bool, Eigen::half);
-REGISTER_GPU_KERNEL(IsFinite, f16, i1, Eigen::half);
-GENERATE_UNARY_GPU_KERNEL2(IsFinite, f32, i1, DT_BOOL, bool, float);
-REGISTER_GPU_KERNEL(IsFinite, f32, i1, float);
-GENERATE_UNARY_GPU_KERNEL2(IsFinite, f64, i1, DT_BOOL, bool, double);
-REGISTER_GPU_KERNEL(IsFinite, f64, i1, double);
+GENERATE_UNARY_GPU_KERNEL2(IsFinite, DT_HALF, DT_BOOL);
+REGISTER_GPU_KERNEL(IsFinite, DT_HALF, DT_BOOL);
+GENERATE_UNARY_GPU_KERNEL2(IsFinite, DT_FLOAT, DT_BOOL);
+REGISTER_GPU_KERNEL(IsFinite, DT_FLOAT, DT_BOOL);
+GENERATE_UNARY_GPU_KERNEL2(IsFinite, DT_DOUBLE, DT_BOOL);
+REGISTER_GPU_KERNEL(IsFinite, DT_DOUBLE, DT_BOOL);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_is_inf.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_is_inf.cc
index 07286bb..d5ebdf9 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_is_inf.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_is_inf.cc
@@ -20,11 +20,11 @@
namespace tensorflow {
-GENERATE_UNARY_GPU_KERNEL2(IsInf, f16, i1, DT_BOOL, bool, Eigen::half);
-REGISTER_GPU_KERNEL(IsInf, f16, i1, Eigen::half);
-GENERATE_UNARY_GPU_KERNEL2(IsInf, f32, i1, DT_BOOL, bool, float);
-REGISTER_GPU_KERNEL(IsInf, f32, i1, float);
-GENERATE_UNARY_GPU_KERNEL2(IsInf, f64, i1, DT_BOOL, bool, double);
-REGISTER_GPU_KERNEL(IsInf, f64, i1, double);
+GENERATE_UNARY_GPU_KERNEL2(IsInf, DT_HALF, DT_BOOL);
+REGISTER_GPU_KERNEL(IsInf, DT_HALF, DT_BOOL);
+GENERATE_UNARY_GPU_KERNEL2(IsInf, DT_FLOAT, DT_BOOL);
+REGISTER_GPU_KERNEL(IsInf, DT_FLOAT, DT_BOOL);
+GENERATE_UNARY_GPU_KERNEL2(IsInf, DT_DOUBLE, DT_BOOL);
+REGISTER_GPU_KERNEL(IsInf, DT_DOUBLE, DT_BOOL);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_is_nan.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_is_nan.cc
index 819158c..8da39e8 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_is_nan.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_is_nan.cc
@@ -20,11 +20,11 @@
namespace tensorflow {
-GENERATE_UNARY_GPU_KERNEL2(IsNan, f16, i1, DT_BOOL, bool, Eigen::half);
-REGISTER_GPU_KERNEL(IsNan, f16, i1, Eigen::half);
-GENERATE_UNARY_GPU_KERNEL2(IsNan, f32, i1, DT_BOOL, bool, float);
-REGISTER_GPU_KERNEL(IsNan, f32, i1, float);
-GENERATE_UNARY_GPU_KERNEL2(IsNan, f64, i1, DT_BOOL, bool, double);
-REGISTER_GPU_KERNEL(IsNan, f64, i1, double);
+GENERATE_UNARY_GPU_KERNEL2(IsNan, DT_HALF, DT_BOOL);
+REGISTER_GPU_KERNEL(IsNan, DT_HALF, DT_BOOL);
+GENERATE_UNARY_GPU_KERNEL2(IsNan, DT_FLOAT, DT_BOOL);
+REGISTER_GPU_KERNEL(IsNan, DT_FLOAT, DT_BOOL);
+GENERATE_UNARY_GPU_KERNEL2(IsNan, DT_DOUBLE, DT_BOOL);
+REGISTER_GPU_KERNEL(IsNan, DT_DOUBLE, DT_BOOL);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_left_shift.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_left_shift.cc
index ea8e748..4a0f99a 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_left_shift.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_left_shift.cc
@@ -17,9 +17,9 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(LeftShift, i8, DT_INT8, int8);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(LeftShift, i16, DT_INT16, int16);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(LeftShift, i32, DT_INT32, int32);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(LeftShift, i64, DT_INT64, int64);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(LeftShift, DT_INT8);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(LeftShift, DT_INT16);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(LeftShift, DT_INT32);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(LeftShift, DT_INT64);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_less.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_less.cc
index f144afe..c4b72b1 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_less.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_less.cc
@@ -20,13 +20,12 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Less, f16, i1, DT_BOOL, bool,
- Eigen::half);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Less, f32, i1, DT_BOOL, bool, float);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Less, f64, i1, DT_BOOL, bool, double);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Less, i8, i1, DT_BOOL, bool, int8);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Less, i16, i1, DT_BOOL, bool, int16);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Less, DT_HALF, DT_BOOL);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Less, DT_FLOAT, DT_BOOL);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Less, DT_DOUBLE, DT_BOOL);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Less, DT_INT8, DT_BOOL);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Less, DT_INT16, DT_BOOL);
// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Less, i64, i1, DT_BOOL, bool, int64);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Less, DT_INT64, DT_BOOL);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_less_equal.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_less_equal.cc
index cacf2ae..676225f 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_less_equal.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_less_equal.cc
@@ -20,18 +20,12 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(LessEqual, f16, i1, DT_BOOL, bool,
- Eigen::half);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(LessEqual, f32, i1, DT_BOOL, bool,
- float);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(LessEqual, f64, i1, DT_BOOL, bool,
- double);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(LessEqual, i8, i1, DT_BOOL, bool,
- int8);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(LessEqual, i16, i1, DT_BOOL, bool,
- int16);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(LessEqual, DT_HALF, DT_BOOL);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(LessEqual, DT_FLOAT, DT_BOOL);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(LessEqual, DT_DOUBLE, DT_BOOL);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(LessEqual, DT_INT8, DT_BOOL);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(LessEqual, DT_INT16, DT_BOOL);
// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(LessEqual, i64, i1, DT_BOOL, bool,
- int64);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(LessEqual, DT_INT64, DT_BOOL);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_lgamma.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_lgamma.cc
index 9f5e6e9..855a450 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_lgamma.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_lgamma.cc
@@ -18,8 +18,8 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Lgamma, f16, DT_HALF, Eigen::half);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Lgamma, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Lgamma, f64, DT_DOUBLE, double);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Lgamma, DT_HALF);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Lgamma, DT_FLOAT);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Lgamma, DT_DOUBLE);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_log.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_log.cc
index b7b7aa4..c99ee78 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_log.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_log.cc
@@ -18,8 +18,8 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Log, f16, DT_HALF, Eigen::half);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Log, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Log, f64, DT_DOUBLE, double);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Log, DT_HALF);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Log, DT_FLOAT);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Log, DT_DOUBLE);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_log1p.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_log1p.cc
index 62d5cc83..dfe42aa 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_log1p.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_log1p.cc
@@ -18,8 +18,8 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Log1p, f16, DT_HALF, Eigen::half);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Log1p, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Log1p, f64, DT_DOUBLE, double);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Log1p, DT_HALF);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Log1p, DT_FLOAT);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Log1p, DT_DOUBLE);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_logical_and.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_logical_and.cc
index 35b93b9..d113015 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_logical_and.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_logical_and.cc
@@ -17,9 +17,9 @@
namespace tensorflow {
-GENERATE_BINARY_GPU_KERNEL(LogicalAnd, i1, DT_BOOL, bool);
+GENERATE_BINARY_GPU_KERNEL(LogicalAnd, DT_BOOL);
// LogicalAnd does not have a "T" attribute because it only works with type
// bool. So we need to register it without TypeConstraint<bool>("T").
-REGISTER_GPU_KERNEL_NO_TYPE_CONSTRAINT(LogicalAnd, i1, i1);
+REGISTER_GPU_KERNEL_NO_TYPE_CONSTRAINT(LogicalAnd, DT_BOOL);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_logical_not.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_logical_not.cc
index fcb3ad3..fdce45d 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_logical_not.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_logical_not.cc
@@ -17,9 +17,9 @@
namespace tensorflow {
-GENERATE_UNARY_GPU_KERNEL(LogicalNot, i1, DT_BOOL, bool);
+GENERATE_UNARY_GPU_KERNEL(LogicalNot, DT_BOOL);
// LogicalNot does not have a "T" attribute because it only works with type
// bool. So we need to register it without TypeConstraint<bool>("T").
-REGISTER_GPU_KERNEL_NO_TYPE_CONSTRAINT(LogicalNot, i1, i1);
+REGISTER_GPU_KERNEL_NO_TYPE_CONSTRAINT(LogicalNot, DT_BOOL);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_logical_or.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_logical_or.cc
index e0eb451..04f5c7a 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_logical_or.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_logical_or.cc
@@ -17,9 +17,9 @@
namespace tensorflow {
-GENERATE_BINARY_GPU_KERNEL(LogicalOr, i1, DT_BOOL, bool);
+GENERATE_BINARY_GPU_KERNEL(LogicalOr, DT_BOOL);
// LogicalOr does not have a "T" attribute because it only works with type
// bool. So we need to register it without TypeConstraint<bool>("T").
-REGISTER_GPU_KERNEL_NO_TYPE_CONSTRAINT(LogicalOr, i1, i1);
+REGISTER_GPU_KERNEL_NO_TYPE_CONSTRAINT(LogicalOr, DT_BOOL);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_maximum.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_maximum.cc
index d4b60d1..3c45821 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_maximum.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_maximum.cc
@@ -18,11 +18,11 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Maximum, f16, DT_HALF, Eigen::half);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Maximum, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Maximum, f64, DT_DOUBLE, double);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Maximum, i16, DT_INT16, int16);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Maximum, DT_HALF);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Maximum, DT_FLOAT);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Maximum, DT_DOUBLE);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Maximum, DT_INT16);
// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Maximum, i64, DT_INT64, int64);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Maximum, DT_INT64);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_minimum.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_minimum.cc
index 4d243cf..4b57487 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_minimum.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_minimum.cc
@@ -18,11 +18,11 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Minimum, f16, DT_HALF, Eigen::half);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Minimum, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Minimum, f64, DT_DOUBLE, double);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Minimum, i16, DT_INT16, int16);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Minimum, DT_HALF);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Minimum, DT_FLOAT);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Minimum, DT_DOUBLE);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Minimum, DT_INT16);
// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Minimum, i64, DT_INT64, int64);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Minimum, DT_INT64);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_mul.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_mul.cc
index 8230f42..1cb25ed 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_mul.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_mul.cc
@@ -18,12 +18,12 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Mul, f16, DT_HALF, Eigen::half);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Mul, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Mul, f64, DT_DOUBLE, double);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Mul, i8, DT_INT8, int8);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Mul, DT_HALF);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Mul, DT_FLOAT);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Mul, DT_DOUBLE);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Mul, DT_INT8);
// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Mul, i16, DT_INT16, int16);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Mul, i64, DT_INT64, int64);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Mul, DT_INT16);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Mul, DT_INT64);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_neg.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_neg.cc
index 3c50067..86cea9d 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_neg.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_neg.cc
@@ -18,12 +18,12 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Neg, f16, DT_HALF, Eigen::half);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Neg, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Neg, f64, DT_DOUBLE, double);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Neg, i8, DT_INT8, int8);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Neg, i16, DT_INT16, int16);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Neg, DT_HALF);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Neg, DT_FLOAT);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Neg, DT_DOUBLE);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Neg, DT_INT8);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Neg, DT_INT16);
// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Neg, i64, DT_INT64, int64);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Neg, DT_INT64);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_not_equal.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_not_equal.cc
index e500a52..7175162 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_not_equal.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_not_equal.cc
@@ -20,18 +20,13 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(NotEqual, f16, i1, DT_BOOL, bool,
- Eigen::half);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(NotEqual, f32, i1, DT_BOOL, bool,
- float);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(NotEqual, f64, i1, DT_BOOL, bool,
- double);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(NotEqual, i1, i1, DT_BOOL, bool, bool);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(NotEqual, i8, i1, DT_BOOL, bool, int8);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(NotEqual, i16, i1, DT_BOOL, bool,
- int16);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(NotEqual, DT_HALF, DT_BOOL);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(NotEqual, DT_FLOAT, DT_BOOL);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(NotEqual, DT_DOUBLE, DT_BOOL);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(NotEqual, DT_BOOL, DT_BOOL);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(NotEqual, DT_INT8, DT_BOOL);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(NotEqual, DT_INT16, DT_BOOL);
// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(NotEqual, i64, i1, DT_BOOL, bool,
- int64);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(NotEqual, DT_INT64, DT_BOOL);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_pow.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_pow.cc
index 401bcb3..456f81f 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_pow.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_pow.cc
@@ -17,9 +17,9 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Pow, f16, DT_HALF, Eigen::half);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Pow, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Pow, f64, DT_DOUBLE, double);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Pow, i64, DT_INT64, int64);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Pow, DT_HALF);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Pow, DT_FLOAT);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Pow, DT_DOUBLE);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Pow, DT_INT64);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_real.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_real.cc
index e892e9e..e46f73f 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_real.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_real.cc
@@ -20,11 +20,9 @@
namespace tensorflow {
-GENERATE_UNARY_GPU_KERNEL2(Real, c64, f32, DT_FLOAT, float,
- std::complex<float>);
-REGISTER_COMPLEX_GPU_KERNEL(Real, c64, f32, float, std::complex<float>);
-GENERATE_UNARY_GPU_KERNEL2(Real, c128, f64, DT_DOUBLE, double,
- std::complex<double>);
-REGISTER_COMPLEX_GPU_KERNEL(Real, c128, f64, double, std::complex<double>);
+GENERATE_UNARY_GPU_KERNEL2(Real, DT_COMPLEX64, DT_FLOAT);
+REGISTER_COMPLEX_GPU_KERNEL(Real, DT_COMPLEX64, DT_FLOAT);
+GENERATE_UNARY_GPU_KERNEL2(Real, DT_COMPLEX128, DT_DOUBLE);
+REGISTER_COMPLEX_GPU_KERNEL(Real, DT_COMPLEX128, DT_DOUBLE);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_right_shift.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_right_shift.cc
index 757b8e7..659ed57 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_right_shift.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_right_shift.cc
@@ -17,9 +17,9 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(RightShift, i8, DT_INT8, int8);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(RightShift, i16, DT_INT16, int16);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(RightShift, i32, DT_INT32, int32);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(RightShift, i64, DT_INT64, int64);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(RightShift, DT_INT8);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(RightShift, DT_INT16);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(RightShift, DT_INT32);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(RightShift, DT_INT64);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_rsqrt.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_rsqrt.cc
index cbc1b23..078a8f2 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_rsqrt.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_rsqrt.cc
@@ -18,8 +18,8 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Rsqrt, f16, DT_HALF, Eigen::half);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Rsqrt, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Rsqrt, f64, DT_DOUBLE, double);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Rsqrt, DT_HALF);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Rsqrt, DT_FLOAT);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Rsqrt, DT_DOUBLE);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_sign.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_sign.cc
index 609e107..d224214 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_sign.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_sign.cc
@@ -18,11 +18,11 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sign, f16, DT_HALF, Eigen::half);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sign, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sign, f64, DT_DOUBLE, double);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sign, DT_HALF);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sign, DT_FLOAT);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sign, DT_DOUBLE);
// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sign, i64, DT_INT64, int64);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sign, DT_INT64);
// TODO(b/162577610): Register the kernel for complex types and bfloat.
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_sin.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_sin.cc
index 420e00e..b80c076 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_sin.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_sin.cc
@@ -18,8 +18,8 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sin, f16, DT_HALF, Eigen::half);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sin, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sin, f64, DT_DOUBLE, double);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sin, DT_HALF);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sin, DT_FLOAT);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sin, DT_DOUBLE);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_sinh.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_sinh.cc
index 6040b9b..55c074c 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_sinh.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_sinh.cc
@@ -18,7 +18,7 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sinh, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sinh, f64, DT_DOUBLE, double);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sinh, DT_FLOAT);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sinh, DT_DOUBLE);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_sqrt.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_sqrt.cc
index c83753d..506729f 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_sqrt.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_sqrt.cc
@@ -18,8 +18,8 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sqrt, f16, DT_HALF, Eigen::half);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sqrt, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sqrt, f64, DT_DOUBLE, double);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sqrt, DT_HALF);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sqrt, DT_FLOAT);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sqrt, DT_DOUBLE);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_square.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_square.cc
index 1373129..39938fc 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_square.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_square.cc
@@ -18,9 +18,9 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Square, f16, DT_HALF, Eigen::half);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Square, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Square, f64, DT_DOUBLE, double);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Square, i64, DT_INT64, int64);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Square, DT_HALF);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Square, DT_FLOAT);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Square, DT_DOUBLE);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Square, DT_INT64);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_squared_difference.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_squared_difference.cc
index 65af65b..cbb1a93 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_squared_difference.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_squared_difference.cc
@@ -17,13 +17,9 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(SquaredDifference, f16, DT_HALF,
- Eigen::half);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(SquaredDifference, f32, DT_FLOAT,
- float);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(SquaredDifference, f64, DT_DOUBLE,
- double);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(SquaredDifference, i64, DT_INT64,
- int64);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(SquaredDifference, DT_HALF);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(SquaredDifference, DT_FLOAT);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(SquaredDifference, DT_DOUBLE);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(SquaredDifference, DT_INT64);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_sub.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_sub.cc
index 702dd4c..098b129 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_sub.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_sub.cc
@@ -18,9 +18,9 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Sub, f16, DT_HALF, Eigen::half);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Sub, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Sub, f64, DT_DOUBLE, double);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Sub, i64, DT_INT64, int64);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Sub, DT_HALF);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Sub, DT_FLOAT);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Sub, DT_DOUBLE);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Sub, DT_INT64);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_tan.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_tan.cc
index 354016a..6643745 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_tan.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_tan.cc
@@ -18,8 +18,8 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Tan, f16, DT_HALF, Eigen::half);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Tan, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Tan, f64, DT_DOUBLE, double);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Tan, DT_HALF);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Tan, DT_FLOAT);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Tan, DT_DOUBLE);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_tanh.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_tanh.cc
index b062273..0839242 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_tanh.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_tanh.cc
@@ -18,8 +18,8 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Tanh, f16, DT_HALF, Eigen::half);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Tanh, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Tanh, f64, DT_DOUBLE, double);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Tanh, DT_HALF);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Tanh, DT_FLOAT);
+GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Tanh, DT_DOUBLE);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_zeta.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_zeta.cc
index e17a14c..104d50b 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_zeta.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_zeta.cc
@@ -17,7 +17,7 @@
namespace tensorflow {
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Zeta, f32, DT_FLOAT, float);
-GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Zeta, f64, DT_DOUBLE, double);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Zeta, DT_FLOAT);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Zeta, DT_DOUBLE);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_unary_ops_test.cc b/tensorflow/core/kernels/mlir_generated/gpu_unary_ops_test.cc
index 2496dcc..f031be0 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_unary_ops_test.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_unary_ops_test.cc
@@ -124,6 +124,43 @@
Atanh, DT_DOUBLE, DT_DOUBLE, test::DefaultInputBetweenZeroAndOne<double>(),
std::atanh, test::OpsTestConfig())
+/// Test `tf.Cast`.
+
+template <typename SrcT, typename DstT>
+DstT baseline_cast(SrcT x) {
+ return static_cast<DstT>(x);
+}
+
+#define TEST_CAST_FROM_TO(from_type, to_type) \
+ GENERATE_DEFAULT_TEST(Cast, from_type, to_type, baseline_cast, \
+ test::OpsTestConfig() \
+ .AddTout() \
+ .NoBufferReuse() \
+ .ExpectStrictlyEqual() \
+ .InputAttribute("SrcT") \
+ .OutputAttribute("DstT"))
+
+#define TEST_CAST_TO(from_type) \
+ TEST_CAST_FROM_TO(from_type, DT_BOOL) \
+ TEST_CAST_FROM_TO(from_type, DT_INT8) \
+ TEST_CAST_FROM_TO(from_type, DT_INT16) \
+ TEST_CAST_FROM_TO(from_type, DT_INT32) \
+ TEST_CAST_FROM_TO(from_type, DT_INT64) \
+ TEST_CAST_FROM_TO(from_type, DT_FLOAT) \
+ TEST_CAST_FROM_TO(from_type, DT_DOUBLE)
+
+TEST_CAST_TO(DT_BOOL)
+TEST_CAST_TO(DT_INT8)
+TEST_CAST_TO(DT_INT16)
+TEST_CAST_TO(DT_INT32)
+TEST_CAST_TO(DT_INT64)
+TEST_CAST_TO(DT_HALF)
+TEST_CAST_TO(DT_FLOAT)
+TEST_CAST_TO(DT_DOUBLE)
+
+#undef TEST_CAST_FROM_TO
+#undef TEST_CAST_TO
+
/// Test `tf.Ceil`.
GENERATE_DEFAULT_TEST(Ceil, DT_FLOAT, DT_FLOAT, std::ceil,
diff --git a/tensorflow/core/kernels/relu_op_functor.h b/tensorflow/core/kernels/relu_op_functor.h
index 913d5f7..f83252c 100644
--- a/tensorflow/core/kernels/relu_op_functor.h
+++ b/tensorflow/core/kernels/relu_op_functor.h
@@ -32,7 +32,8 @@
// activations: same shape as "features".
void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
typename TTypes<T>::Tensor activations) {
- activations.device(d) = features.cwiseMax(static_cast<T>(0));
+ activations.device(d) =
+ features.template cwiseMax<Eigen::PropagateNaN>(static_cast<T>(0));
}
};
@@ -66,7 +67,8 @@
void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
typename TTypes<T>::Tensor activations) {
activations.device(d) =
- features.cwiseMax(static_cast<T>(0)).cwiseMin(static_cast<T>(6));
+ features.template cwiseMax<Eigen::PropagateNaN>(static_cast<T>(0))
+ .template cwiseMin<Eigen::PropagateNaN>(static_cast<T>(6));
}
};
diff --git a/tensorflow/core/kernels/reshape_util.cc b/tensorflow/core/kernels/reshape_util.cc
index 1fce80f..72fd0eb 100644
--- a/tensorflow/core/kernels/reshape_util.cc
+++ b/tensorflow/core/kernels/reshape_util.cc
@@ -31,6 +31,53 @@
namespace tensorflow {
+using CPUDevice = Eigen::ThreadPoolDevice;
+
+namespace functor {
+
+template <>
+struct ReshapeSparseTensorFunctor<CPUDevice> {
+ Status operator()(const TensorShape &input_shape,
+ const TensorShape &output_shape,
+ typename TTypes<int64>::ConstMatrix input_indices,
+ typename TTypes<int64>::Matrix output_indices) const {
+ const int64 input_rank = input_shape.dims();
+ const int64 output_rank = output_shape.dims();
+ const int64 nnz = input_indices.dimension(0);
+ gtl::InlinedVector<int64, 8> input_strides(input_rank);
+ if (input_rank > 0) {
+ input_strides[input_rank - 1] = 1;
+ for (int d = input_rank - 2; d >= 0; --d) {
+ input_strides[d] = input_strides[d + 1] * input_shape.dim_size(d + 1);
+ }
+ }
+
+ gtl::InlinedVector<int64, 8> output_strides(output_rank);
+ if (output_rank > 0) {
+ output_strides[output_rank - 1] = 1;
+ for (int d = output_rank - 2; d >= 0; --d) {
+ output_strides[d] =
+ output_strides[d + 1] * output_shape.dim_size(d + 1);
+ }
+ }
+
+ for (int i = 0; i < nnz; ++i) {
+ int64 id = 0;
+ for (int j = 0; j < input_rank; ++j) {
+ id += input_indices(i, j) * input_strides[j];
+ }
+ for (int j = 0; j < output_rank; ++j) {
+ output_indices(i, j) = id / output_strides[j];
+ id %= output_strides[j];
+ }
+ }
+ return Status::OK();
+ }
+};
+
+} // namespace functor
+
+template <typename Device>
void ReshapeSparseTensor(OpKernelContext *context,
const Tensor &input_indices_in,
const Tensor &input_shape_in,
@@ -49,7 +96,6 @@
"Target shape should be a vector but received shape ",
target_shape_in.shape().DebugString()));
- const int64 input_rank = input_shape_in.NumElements();
const int64 output_rank = target_shape_in.NumElements();
const TensorShape input_shape(input_shape_in.vec<int64>());
const int64 dense_size = input_shape.num_elements();
@@ -111,40 +157,6 @@
return;
}
- gtl::InlinedVector<int64, 8> input_strides(input_rank);
- if (input_rank > 0) {
- input_strides[input_rank - 1] = 1;
- for (int d = input_rank - 2; d >= 0; --d) {
- input_strides[d] = input_strides[d + 1] * input_shape.dim_size(d + 1);
- }
- }
-
- gtl::InlinedVector<int64, 8> output_strides(output_rank);
- if (output_rank > 0) {
- output_strides[output_rank - 1] = 1;
- for (int d = output_rank - 2; d >= 0; --d) {
- output_strides[d] = output_strides[d + 1] * output_shape.dim_size(d + 1);
- }
- }
-
- Tensor *result_indices = nullptr;
- OP_REQUIRES_OK(context,
- context->allocate_output(output_indices_idx,
- TensorShape({nnz, output_rank}),
- &result_indices));
- auto input_ind = input_indices_in.matrix<int64>();
- auto output_ind = result_indices->matrix<int64>();
- for (int i = 0; i < nnz; ++i) {
- int64 id = 0;
- for (int j = 0; j < input_rank; ++j) {
- id += input_ind(i, j) * input_strides[j];
- }
- for (int j = 0; j < output_rank; ++j) {
- output_ind(i, j) = id / output_strides[j];
- id %= output_strides[j];
- }
- }
-
Tensor *result_shape = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(output_shape_idx,
TensorShape({output_rank}),
@@ -153,6 +165,26 @@
for (int j = 0; j < output_shape.dims(); ++j) {
output_shape_vec(j) = output_shape.dim_size(j);
}
+
+ Tensor *result_indices = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(output_indices_idx,
+ TensorShape({nnz, output_rank}),
+ &result_indices));
+ if (nnz > 0) {
+ OP_REQUIRES_OK(context, functor::ReshapeSparseTensorFunctor<Device>()(
+ input_shape, output_shape,
+ input_indices_in.matrix<int64>(),
+ result_indices->matrix<int64>()));
+ }
}
+#define EXPLICITLY_INSTANTIATE_FUNCTION(Device) \
+ template void ReshapeSparseTensor<Device>( \
+ OpKernelContext * context, const Tensor &input_indices_in, \
+ const Tensor &input_shape_in, const Tensor &target_shape_in, \
+ int output_indices_idx, int output_shape_idx)
+EXPLICITLY_INSTANTIATE_FUNCTION(CPUDevice);
+#undef EXPLICITLY_INSTANTIATE_FUNCTION
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/reshape_util.h b/tensorflow/core/kernels/reshape_util.h
index 7e1809e..b3a3565 100644
--- a/tensorflow/core/kernels/reshape_util.h
+++ b/tensorflow/core/kernels/reshape_util.h
@@ -16,18 +16,36 @@
#ifndef TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_
#define TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/lib/core/status.h"
+
namespace tensorflow {
class OpKernelContext;
class Tensor;
// Reshapes the input indices and input shape to the target shape.
+// Note: This template is explicitly instantiated for CPU device only.
+template <typename Device>
void ReshapeSparseTensor(OpKernelContext *context,
const Tensor &input_indices_in,
const Tensor &input_shape_in,
const Tensor &target_shape_in, int output_indices_idx,
int output_shape_idx);
+namespace functor {
+
+template <typename Device>
+struct ReshapeSparseTensorFunctor {
+ Status operator()(const TensorShape &input_shape,
+ const TensorShape &output_shape,
+ typename TTypes<int64>::ConstMatrix input_indices,
+ typename TTypes<int64>::Matrix output_indices) const;
+};
+
+} // namespace functor
+
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index 90463d2..37d22ea 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -221,7 +221,6 @@
OP_REQUIRES_OK(context, context->GetAttr("shared_name", &name_));
OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_and_shape_.dtype));
- PartialTensorShape shape;
OP_REQUIRES_OK(context, context->GetAttr("shape", &dtype_and_shape_.shape));
is_anonymous_ = name_ == ResourceHandle::ANONYMOUS_NAME;
diff --git a/tensorflow/core/kernels/spacetobatch_benchmark_test.cc b/tensorflow/core/kernels/spacetobatch_benchmark_test.cc
index 92ddf8e..a321f47 100644
--- a/tensorflow/core/kernels/spacetobatch_benchmark_test.cc
+++ b/tensorflow/core/kernels/spacetobatch_benchmark_test.cc
@@ -56,20 +56,25 @@
// The BM_Expand macro is needed for this to build with VC++.
#define BM_Expand(x) x
+// Macro is already longer than 80 chars.
+// NOLINTBEGIN
#define BM_SpaceToBatchDev(OP, DEVICE, DTYPE, B, H, W, D, BS, P00, P01, P10, \
P11) \
static void \
BM_##OP##_##DEVICE##_##DTYPE##_##B##_##H##_##W##_##D##_bs##BS##_pad##P00##_##P01##_##P10##_##P11( \
- int iters) { \
- testing::ItemsProcessed(static_cast<int64>(iters) * B * (H + P00 + P01) * \
+ ::testing::benchmark::State& state) { \
+ test::Benchmark( \
+ #DEVICE, \
+ ConstructSpaceToBatchGraph(#OP, TensorShape({B, H, W, D}), BS, DTYPE, \
+ {{P00, P01}, {P10, P11}}), \
+ /*old_benchmark_api*/ false) \
+ .Run(state); \
+ state.SetItemsProcessed(state.iterations() * B * (H + P00 + P01) * \
(W + P10 + P11) * D); \
- test::Benchmark(#DEVICE, ConstructSpaceToBatchGraph( \
- #OP, TensorShape({B, H, W, D}), BS, DTYPE, \
- {{P00, P01}, {P10, P11}})) \
- .Run(iters); \
} \
BENCHMARK( \
BM_##OP##_##DEVICE##_##DTYPE##_##B##_##H##_##W##_##D##_bs##BS##_pad##P00##_##P01##_##P10##_##P11);
+// NOLINTEND
#define BM_SpaceToBatch(OP, ...) \
BM_Expand(BM_SpaceToBatchDev(OP, cpu, DT_FLOAT, __VA_ARGS__)); \
BM_Expand(BM_SpaceToBatchDev(OP, gpu, DT_FLOAT, __VA_ARGS__)); \
diff --git a/tensorflow/core/kernels/sparse_matmul_op_test.cc b/tensorflow/core/kernels/sparse_matmul_op_test.cc
index 1dc51cd..a0f07d4 100644
--- a/tensorflow/core/kernels/sparse_matmul_op_test.cc
+++ b/tensorflow/core/kernels/sparse_matmul_op_test.cc
@@ -107,36 +107,30 @@
#define BM_SPARSE(M, K, N, S1, S2, TRA, TRB, TA, TB) \
static void \
BM_Sparse##_##M##_##K##_##N##_##S1##_##S2##_##TRA##_##TRB##_##TA##_##TB( \
- int iters) { \
- testing::StopTiming(); \
- testing::ItemsProcessed(static_cast<int64>(iters) * M * K * N * 2); \
+ ::testing::benchmark::State& state) { \
auto label = strings::Printf("tr_a: %d tr_b: %d sp_a: %0.2f sp_b: %0.2f", \
TRA, TRB, S1 / 100.0, S2 / 100.0); \
- testing::SetLabel(label); \
- testing::UseRealTime(); \
+ state.SetLabel(label); \
auto g = SparseMatMul<TA, TB>(M, N, K, S1 / 100.0, S2 / 100.0, TRA, TRB); \
- testing::StartTiming(); \
- test::Benchmark("cpu", g).Run(iters); \
+ test::Benchmark("cpu", g, /*old_benchmark_api*/ false).Run(state); \
} \
BENCHMARK( \
- BM_Sparse##_##M##_##K##_##N##_##S1##_##S2##_##TRA##_##TRB##_##TA##_##TB);
+ BM_Sparse##_##M##_##K##_##N##_##S1##_##S2##_##TRA##_##TRB##_##TA##_##TB) \
+ ->UseRealTime();
#define BM_SPARSE_REPLICATED(M, K, N, S1, S2, Copies) \
static void BM_Sparse_replicated##_##M##_##K##_##N##_##S1##_##S2##_##Copies( \
- int iters) { \
- testing::StopTiming(); \
- testing::ItemsProcessed(static_cast<int64>(iters) * M * K * N * Copies * \
- 2); \
+ ::testing::benchmark::State& state) { \
auto label = strings::Printf("copies: %d sp_a: %0.2f sp_b: %0.2f", \
(Copies), S1 / 100.0, S2 / 100.0); \
- testing::SetLabel(label); \
- testing::UseRealTime(); \
+ state.SetLabel(label); \
auto g = \
ReplicatedSparseMatMul(M, N, K, S1 / 100.0, S2 / 100.0, (Copies)); \
- testing::StartTiming(); \
- test::Benchmark("cpu", g).Run(iters); \
+ test::Benchmark("cpu", g, /*old_benchmark_api*/ false).Run(state); \
+ state.SetItemsProcessed(state.iterations() * M * K * N * Copies * 2); \
} \
- BENCHMARK(BM_Sparse_replicated##_##M##_##K##_##N##_##S1##_##S2##_##Copies);
+ BENCHMARK(BM_Sparse_replicated##_##M##_##K##_##N##_##S1##_##S2##_##Copies) \
+ ->UseRealTime();
#define BM_SPARSE_FLOAT(M, K, N, S1, S2, TRA, TRB) \
BM_SPARSE(M, K, N, S1, S2, TRA, TRB, float, float)
@@ -219,22 +213,21 @@
return g;
}
-#define BM_SPARSE_MULTI(M, K, N, S1, S2, Copies) \
- static void BM_Sparse_Multi##_##M##_##K##_##N##_##S1##_##S2##_##Copies( \
- int iters) { \
- testing::StopTiming(); \
- testing::ItemsProcessed(static_cast<int64>(iters) * M * K * N * 2 * 2 * \
- Copies); \
- auto label = strings::Printf("%d_%d_%d_%d_%0.2f_%0.2f", M, K, N, Copies, \
- S1 / 100.0, S2 / 100.0); \
- testing::SetLabel(label); \
- testing::UseRealTime(); \
- auto g = MultiSparseMatMul(M, N, K, S1 / 100.0, S2 / 100.0, Copies); \
- testing::StartTiming(); \
- test::Benchmark("cpu", g).Run(iters); \
- } \
- BENCHMARK(BM_Sparse_Multi##_##M##_##K##_##N##_##S1##_##S2##_##Copies);
-
+// clang-format off
+// NOLINTBEGIN
+#define BM_SPARSE_MULTI(M, K, N, S1, S2, Copies) \
+ static void BM_Sparse_Multi##_##M##_##K##_##N##_##S1##_##S2##_##Copies(::testing::benchmark::State& state) { \
+ auto label = strings::Printf("%d_%d_%d_%d_%0.2f_%0.2f", M, K, N, Copies, \
+ S1 / 100.0, S2 / 100.0); \
+ state.SetLabel(label); \
+ auto g = MultiSparseMatMul(M, N, K, S1 / 100.0, S2 / 100.0, Copies); \
+ test::Benchmark("cpu", g, /*old_benchmark_api*/ false).Run(state); \
+ state.SetItemsProcessed(state.iterations() * M * K * N * 2 * 2 * Copies); \
+ } \
+ BENCHMARK(BM_Sparse_Multi##_##M##_##K##_##N##_##S1##_##S2##_##Copies) \
+ ->UseRealTime();
+// NOLINTEND
+// clang-format on
BM_SPARSE_MULTI(1024, 2140, 4096, 0, 82, 1);
BM_SPARSE_MULTI(1024, 4096, 2048, 83, 83, 1);
BM_SPARSE_MULTI(400, 800, 2560, 85, 85, 1);
diff --git a/tensorflow/core/kernels/sparse_reshape_op.cc b/tensorflow/core/kernels/sparse_reshape_op.cc
index 6eb5f0a..472a7a2 100644
--- a/tensorflow/core/kernels/sparse_reshape_op.cc
+++ b/tensorflow/core/kernels/sparse_reshape_op.cc
@@ -29,17 +29,21 @@
namespace tensorflow {
+using CPUDevice = Eigen::ThreadPoolDevice;
+
+template <typename Device>
class SparseReshapeOp : public OpKernel {
public:
explicit SparseReshapeOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
- ReshapeSparseTensor(context, context->input(0), context->input(1),
- context->input(2), 0 /* output indices index */,
- 1 /* output shape index */);
+ ReshapeSparseTensor<Device>(context, context->input(0), context->input(1),
+ context->input(2), 0 /* output indices index */,
+ 1 /* output shape index */);
}
};
REGISTER_KERNEL_BUILDER(Name("SparseReshape").Device(DEVICE_CPU),
- SparseReshapeOp)
+ SparseReshapeOp<CPUDevice>)
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_test.cc b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_test.cc
index 249ddbe..b06f72d4 100644
--- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_test.cc
+++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_test.cc
@@ -68,19 +68,22 @@
return g;
}
+// NOLINTBEGIN
#define BM_SparseTensorDenseMatmulDev(NNZ, M, K, N, TA, TB, DEVICE) \
static void \
BM_SparseTensorDenseMatmul##_##NNZ##_##M##_##K##_##N##_##TA##_##TB##_##DEVICE( \
- int iters) { \
+ ::testing::benchmark::State& state) { \
int64 items_per_iter = (static_cast<int64>(NNZ) * (TB ? K : N)); \
- testing::ItemsProcessed(static_cast<int64>(iters) * items_per_iter); \
- testing::BytesProcessed(static_cast<int64>(iters) * items_per_iter * \
+ test::Benchmark(#DEVICE, SparseTensorDenseMatmul(NNZ, M, K, N, TA, TB), \
+ /*old_benchmark_api*/ false) \
+ .Run(state); \
+ state.SetItemsProcessed(state.iterations() * items_per_iter); \
+ state.SetBytesProcessed(state.iterations() * items_per_iter * \
sizeof(float)); \
- test::Benchmark(#DEVICE, SparseTensorDenseMatmul(NNZ, M, K, N, TA, TB)) \
- .Run(iters); \
} \
BENCHMARK( \
BM_SparseTensorDenseMatmul##_##NNZ##_##M##_##K##_##N##_##TA##_##TB##_##DEVICE);
+// NOLINTEND
#define BM_SparseTensorDenseMatmul(NNZ, M, K, N, TA, TB) \
BM_SparseTensorDenseMatmulDev(NNZ, M, K, N, TA, TB, cpu); \
diff --git a/tensorflow/core/kernels/substr_op.cc b/tensorflow/core/kernels/substr_op.cc
index ab83efd..8ca14c4 100644
--- a/tensorflow/core/kernels/substr_op.cc
+++ b/tensorflow/core/kernels/substr_op.cc
@@ -151,15 +151,6 @@
auto pos_shaped = pos_tensor.shaped<T, 1>(bcast.y_reshape());
auto len_shaped = len_tensor.shaped<T, 1>(bcast.y_reshape());
- // Allocate temporary buffer for broadcasted input tensor
- Tensor input_buffer;
- OP_REQUIRES_OK(context, context->allocate_temp(
- DT_STRING, output_shape, &input_buffer));
- TTypes<tstring, 1>::Tensor input_bcast =
- input_buffer.shaped<tstring, 1>(bcast.result_shape());
- input_bcast =
- input.broadcast(BCast::ToIndexArray<1>(bcast.x_bcast()));
-
// Allocate temporary buffer for broadcasted position tensor
Tensor pos_buffer;
OP_REQUIRES_OK(context,
@@ -182,7 +173,7 @@
// Iterate through broadcasted tensors and perform substr
for (int i = 0; i < output_shape.dim_size(0); ++i) {
- StringPiece in(input_bcast(i));
+ StringPiece in(input(input.dimension(0) > 1 ? i : 0));
const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i));
const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i));
T byte_pos = pos;
@@ -197,8 +188,7 @@
case CharUnit::BYTE:
byte_pos = AdjustedPosIndex(byte_pos, in);
OP_REQUIRES(
- context,
- FastBoundsCheck(byte_pos, input_bcast(i).size() + 1),
+ context, FastBoundsCheck(byte_pos, in.size() + 1),
errors::InvalidArgument("pos ", pos, " out of range for ",
"string b'", in, "' at index ", i));
}
@@ -214,15 +204,6 @@
auto pos_shaped = pos_tensor.shaped<T, 2>(bcast.y_reshape());
auto len_shaped = len_tensor.shaped<T, 2>(bcast.y_reshape());
- // Allocate temporary buffer for broadcasted input tensor
- Tensor input_buffer;
- OP_REQUIRES_OK(context, context->allocate_temp(
- DT_STRING, output_shape, &input_buffer));
- TTypes<tstring, 2>::Tensor input_bcast =
- input_buffer.shaped<tstring, 2>(bcast.result_shape());
- input_bcast =
- input.broadcast(BCast::ToIndexArray<2>(bcast.x_bcast()));
-
// Allocate temporary buffer for broadcasted position tensor
Tensor pos_buffer;
OP_REQUIRES_OK(context,
@@ -246,7 +227,8 @@
// Iterate through broadcasted tensors and perform substr
for (int i = 0; i < output_shape.dim_size(0); ++i) {
for (int j = 0; j < output_shape.dim_size(1); ++j) {
- StringPiece in(input_bcast(i, j));
+ StringPiece in(input(input.dimension(0) > 1 ? i : 0,
+ input.dimension(1) > 1 ? j : 0));
const T pos =
tensorflow::internal::SubtleMustCopy(pos_bcast(i, j));
const T len =
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index ad673d3..8b2ded5 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -965,7 +965,7 @@
.Input("dims: bool")
.Output("output: T")
.Attr(
- "T: {uint8, int8, uint16, int16, int32, int64, bool, half, "
+ "T: {uint8, int8, uint16, int16, int32, int64, bool, bfloat16, half, "
"float, double, complex64, complex128, string}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle input = c->input(0);
diff --git a/tensorflow/core/ops/compat/ops_history_v2/InitializeTableFromTextFile.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/InitializeTableFromTextFile.pbtxt
index c4de3da..77be4ca 100644
--- a/tensorflow/core/ops/compat/ops_history_v2/InitializeTableFromTextFile.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history_v2/InitializeTableFromTextFile.pbtxt
@@ -38,3 +38,50 @@
}
}
}
+op {
+ name: "InitializeTableFromTextFile"
+ input_arg {
+ name: "table_handle"
+ type: DT_STRING
+ is_ref: true
+ }
+ input_arg {
+ name: "filename"
+ type: DT_STRING
+ }
+ attr {
+ name: "key_index"
+ type: "int"
+ has_minimum: true
+ minimum: -2
+ }
+ attr {
+ name: "value_index"
+ type: "int"
+ has_minimum: true
+ minimum: -2
+ }
+ attr {
+ name: "vocab_size"
+ type: "int"
+ default_value {
+ i: -1
+ }
+ has_minimum: true
+ minimum: -1
+ }
+ attr {
+ name: "delimiter"
+ type: "string"
+ default_value {
+ s: "\t"
+ }
+ }
+ attr {
+ name: "offset"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/InitializeTableFromTextFileV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/InitializeTableFromTextFileV2.pbtxt
index 0096e94..6593434 100644
--- a/tensorflow/core/ops/compat/ops_history_v2/InitializeTableFromTextFileV2.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history_v2/InitializeTableFromTextFileV2.pbtxt
@@ -38,3 +38,50 @@
}
is_stateful: true
}
+op {
+ name: "InitializeTableFromTextFileV2"
+ input_arg {
+ name: "table_handle"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "filename"
+ type: DT_STRING
+ }
+ attr {
+ name: "key_index"
+ type: "int"
+ has_minimum: true
+ minimum: -2
+ }
+ attr {
+ name: "value_index"
+ type: "int"
+ has_minimum: true
+ minimum: -2
+ }
+ attr {
+ name: "vocab_size"
+ type: "int"
+ default_value {
+ i: -1
+ }
+ has_minimum: true
+ minimum: -1
+ }
+ attr {
+ name: "delimiter"
+ type: "string"
+ default_value {
+ s: "\t"
+ }
+ }
+ attr {
+ name: "offset"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ }
+ is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/Reverse.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/Reverse.pbtxt
index 99b3f2e..7a267b4 100644
--- a/tensorflow/core/ops/compat/ops_history_v2/Reverse.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history_v2/Reverse.pbtxt
@@ -101,3 +101,40 @@
}
}
}
+op {
+ name: "Reverse"
+ input_arg {
+ name: "tensor"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "dims"
+ type: DT_BOOL
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_UINT8
+ type: DT_INT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_BOOL
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ type: DT_STRING
+ }
+ }
+ }
+}
diff --git a/tensorflow/core/ops/lookup_ops.cc b/tensorflow/core/ops/lookup_ops.cc
index 05aa229..f9aea52 100644
--- a/tensorflow/core/ops/lookup_ops.cc
+++ b/tensorflow/core/ops/lookup_ops.cc
@@ -480,6 +480,7 @@
.Attr("value_index: int >= -2")
.Attr("vocab_size: int >= -1 = -1")
.Attr("delimiter: string = '\t'")
+ .Attr("offset: int = 0")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle handle;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
@@ -497,6 +498,7 @@
.Attr("value_index: int >= -2")
.Attr("vocab_size: int >= -1 = -1")
.Attr("delimiter: string = '\t'")
+ .Attr("offset: int = 0")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle handle;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index e3ad43c..d892e75 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -19797,6 +19797,13 @@
s: "\t"
}
}
+ attr {
+ name: "offset"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ }
}
op {
name: "InitializeTableFromTextFileV2"
@@ -19836,6 +19843,13 @@
s: "\t"
}
}
+ attr {
+ name: "offset"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ }
is_stateful: true
}
op {
@@ -41595,6 +41609,7 @@
type: DT_INT32
type: DT_INT64
type: DT_BOOL
+ type: DT_BFLOAT16
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
diff --git a/tensorflow/core/platform/default/port.cc b/tensorflow/core/platform/default/port.cc
index 6e82c67..8c8c864 100644
--- a/tensorflow/core/platform/default/port.cc
+++ b/tensorflow/core/platform/default/port.cc
@@ -357,7 +357,7 @@
}
MemoryInfo GetMemoryInfo() {
- MemoryInfo mem_info = {INT64_MAX, INT64_MAX, INT64_MAX};
+ MemoryInfo mem_info = {INT64_MAX, INT64_MAX};
#if defined(__linux__) && !defined(__ANDROID__)
struct sysinfo info;
int err = sysinfo(&info);
@@ -369,5 +369,10 @@
return mem_info;
}
+MemoryBandwidthInfo GetMemoryBandwidthInfo() {
+ MemoryBandwidthInfo membw_info = {INT64_MAX};
+ return membw_info;
+}
+
} // namespace port
} // namespace tensorflow
diff --git a/tensorflow/core/platform/mem.h b/tensorflow/core/platform/mem.h
index e01d495..36954c7 100644
--- a/tensorflow/core/platform/mem.h
+++ b/tensorflow/core/platform/mem.h
@@ -62,6 +62,9 @@
struct MemoryInfo {
int64 total = 0;
int64 free = 0;
+};
+
+struct MemoryBandwidthInfo {
int64 bw_used = 0; // memory bandwidth used across all CPU (in MBs/second)
};
@@ -70,6 +73,10 @@
// available.
MemoryInfo GetMemoryInfo();
+// Retrieves the host memory bandwidth information. If any field in the returned
+// structure is INT64_MAX, it means such information is not available.
+MemoryBandwidthInfo GetMemoryBandwidthInfo();
+
// Returns the amount of RAM available in bytes, or INT64_MAX if unknown.
static inline int64 AvailableRam() { return GetMemoryInfo().free; }
diff --git a/tensorflow/core/platform/s3/s3_file_system.cc b/tensorflow/core/platform/s3/s3_file_system.cc
index 8d74ea6..51ff982 100644
--- a/tensorflow/core/platform/s3/s3_file_system.cc
+++ b/tensorflow/core/platform/s3/s3_file_system.cc
@@ -1195,6 +1195,6 @@
return Status::OK();
}
-REGISTER_FILE_SYSTEM("s3", RetryingS3FileSystem);
+REGISTER_LEGACY_FILE_SYSTEM("s3", RetryingS3FileSystem);
} // namespace tensorflow
diff --git a/tensorflow/core/platform/windows/port.cc b/tensorflow/core/platform/windows/port.cc
index 256f525..250e271 100644
--- a/tensorflow/core/platform/windows/port.cc
+++ b/tensorflow/core/platform/windows/port.cc
@@ -192,7 +192,7 @@
}
MemoryInfo GetMemoryInfo() {
- MemoryInfo mem_info = {INT64_MAX, INT64_MAX, INT64_MAX};
+ MemoryInfo mem_info = {INT64_MAX, INT64_MAX};
MEMORYSTATUSEX statex;
statex.dwLength = sizeof(statex);
if (GlobalMemoryStatusEx(&statex)) {
@@ -202,6 +202,11 @@
return mem_info;
}
+MemoryBandwidthInfo GetMemoryBandwidthInfo() {
+ MemoryBandwidthInfo membw_info = {INT64_MAX};
+ return membw_info;
+}
+
int NumHyperthreadsPerCore() {
static const int ht_per_core = tensorflow::port::CPUIDNumSMT();
return (ht_per_core > 0) ? ht_per_core : 1;
diff --git a/tensorflow/core/profiler/convert/xplane_to_memory_profile.cc b/tensorflow/core/profiler/convert/xplane_to_memory_profile.cc
index 1663d63..fdfefe2 100644
--- a/tensorflow/core/profiler/convert/xplane_to_memory_profile.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_memory_profile.cc
@@ -51,29 +51,6 @@
// MemoryActivityMetadata proto it contains.
using IndexMetaPair = std::pair<int64 /*index*/, const MemoryActivityMetadata*>;
-// Aggregated memory stats from an allocator. Temporary container to fill
-// MemoryAggregationStats.
-struct AggregationStats {
- int64 bytes_reserved = 0;
- int64 bytes_allocated = 0;
- int64 bytes_available = 0;
- double fragmentation = 0;
- int64 peak_bytes_in_use = 0;
-};
-
-// Metadata associated with each memory allocation/deallocation activity.
-// Temporary container to fill MemoryActivityMetadata.
-struct ActivityMetadata {
- int64 requested_bytes = 0;
- int64 allocation_bytes = 0;
- uint64 address = 0;
- absl::string_view tf_op_name;
- int64 step_id = kInvalidStepId;
- absl::string_view region_type;
- int64 data_type = 0;
- absl::string_view tensor_shape;
-};
-
bool IsMemoryAllocation(int64 event_type) {
return event_type == HostEventType::kMemoryAllocation;
}
@@ -82,51 +59,22 @@
return event_type == HostEventType::kMemoryDeallocation;
}
-void FillAggregationStats(const AggregationStats& src,
- MemoryAggregationStats* dst) {
- dst->set_stack_reserved_bytes(src.bytes_reserved);
- dst->set_heap_allocated_bytes(src.bytes_allocated);
- dst->set_free_memory_bytes(src.bytes_available);
- dst->set_fragmentation(src.fragmentation);
- dst->set_peak_bytes_in_use(src.peak_bytes_in_use);
-}
-
-void FillActivityMetadata(int64 event_type, const ActivityMetadata& src,
- MemoryActivityMetadata* dst) {
- if (IsMemoryAllocation(event_type)) {
- dst->set_memory_activity(ALLOCATION);
- } else if (IsMemoryDeallocation(event_type)) {
- dst->set_memory_activity(DEALLOCATION);
- }
- dst->set_requested_bytes(src.requested_bytes);
- dst->set_allocation_bytes(src.allocation_bytes);
- dst->set_address(src.address);
- dst->set_tf_op_name(std::string(src.tf_op_name));
- dst->set_step_id(src.step_id);
- dst->set_region_type(std::string(src.region_type));
- dst->set_data_type(tensorflow::DataTypeString(
- static_cast<tensorflow::DataType>(src.data_type)));
- dst->set_tensor_shape(std::string(src.tensor_shape));
-}
-
-void UpdateProfileSummary(const AggregationStats& stats, int64 time_offset_ps,
- MemoryProfileSummary* summary) {
+void UpdateProfileSummary(const MemoryAggregationStats& stats,
+ int64 time_offset_ps, MemoryProfileSummary* summary) {
// Update the peak memory usage over allocator's lifetime.
- summary->set_peak_bytes_usage_lifetime(stats.peak_bytes_in_use);
+ summary->set_peak_bytes_usage_lifetime(stats.peak_bytes_in_use());
MemoryAggregationStats* peak_stats = summary->mutable_peak_stats();
// If we reach (or stay at) peak memory usage within the profiling window,
// update memory profile summary.
- if (stats.bytes_reserved + stats.bytes_allocated >=
+ if (stats.stack_reserved_bytes() + stats.heap_allocated_bytes() >=
peak_stats->peak_bytes_in_use()) {
- peak_stats->set_peak_bytes_in_use(stats.bytes_reserved +
- stats.bytes_allocated);
- peak_stats->set_stack_reserved_bytes(stats.bytes_reserved);
- peak_stats->set_heap_allocated_bytes(stats.bytes_allocated);
- peak_stats->set_free_memory_bytes(stats.bytes_available);
- peak_stats->set_fragmentation(stats.fragmentation);
+ *peak_stats = stats;
+ peak_stats->set_peak_bytes_in_use(stats.stack_reserved_bytes() +
+ stats.heap_allocated_bytes());
summary->set_peak_stats_time_ps(time_offset_ps);
- summary->set_memory_capacity(stats.bytes_reserved + stats.bytes_allocated +
- stats.bytes_available);
+ summary->set_memory_capacity(stats.stack_reserved_bytes() +
+ stats.heap_allocated_bytes() +
+ stats.free_memory_bytes());
}
}
@@ -145,8 +93,15 @@
return;
}
- AggregationStats stats;
- ActivityMetadata metadata;
+ MemoryAggregationStats stats;
+ MemoryActivityMetadata metadata;
+ if (IsMemoryAllocation(event_type)) {
+ metadata.set_memory_activity(ALLOCATION);
+ } else if (IsMemoryDeallocation(event_type)) {
+ metadata.set_memory_activity(DEALLOCATION);
+ }
+ metadata.set_step_id(kInvalidStepId);
+
std::string memory_id;
event.ForEachStat([&](const XStatVisitor& stat) {
if (!stat.Type().has_value()) return;
@@ -159,59 +114,59 @@
memory_id = std::string(stat.StrOrRefValue());
break;
case StatType::kBytesReserved:
- stats.bytes_reserved = stat.IntValue();
+ stats.set_stack_reserved_bytes(stat.IntValue());
break;
case StatType::kBytesAllocated:
- stats.bytes_allocated = stat.IntValue();
+ stats.set_heap_allocated_bytes(stat.IntValue());
break;
case StatType::kBytesAvailable:
- stats.bytes_available = stat.IntValue();
+ stats.set_free_memory_bytes(stat.IntValue());
break;
case StatType::kFragmentation:
- stats.fragmentation = stat.DoubleValue();
+ stats.set_fragmentation(stat.DoubleValue());
break;
case StatType::kPeakBytesInUse:
- stats.peak_bytes_in_use = stat.IntValue();
+ stats.set_peak_bytes_in_use(stat.IntValue());
break;
case StatType::kRequestedBytes:
- metadata.requested_bytes = stat.IntValue();
+ metadata.set_requested_bytes(stat.IntValue());
break;
case StatType::kAllocationBytes:
- metadata.allocation_bytes = stat.IntValue();
+ metadata.set_allocation_bytes(stat.IntValue());
break;
case StatType::kAddress:
- metadata.address = stat.IntValue();
+ metadata.set_address(stat.IntValue());
break;
case StatType::kTfOp:
- metadata.tf_op_name = stat.StrOrRefValue();
+ metadata.set_tf_op_name(std::string(stat.StrOrRefValue()));
break;
case StatType::kGroupId:
- metadata.step_id = stat.IntValue();
+ metadata.set_step_id(stat.IntValue());
break;
case StatType::kRegionType:
- metadata.region_type = stat.StrOrRefValue();
+ metadata.set_region_type(std::string(stat.StrOrRefValue()));
break;
case StatType::kDataType:
- metadata.data_type = stat.IntValue();
+ metadata.set_data_type(tensorflow::DataTypeString(
+ static_cast<tensorflow::DataType>(stat.IntValue())));
break;
case StatType::kTensorShapes:
- metadata.tensor_shape = stat.StrOrRefValue();
+ metadata.set_tensor_shape(std::string(stat.StrOrRefValue()));
break;
}
});
- MemoryProfileSnapshot* snapshot =
- (*memory_profile.mutable_memory_profile_per_allocator())[memory_id]
- .add_memory_profile_snapshots();
- snapshot->set_time_offset_ps(event.OffsetPs());
- FillAggregationStats(stats, snapshot->mutable_aggregation_stats());
- FillActivityMetadata(event_type, metadata,
- snapshot->mutable_activity_metadata());
-
MemoryProfileSummary* summary =
(*memory_profile.mutable_memory_profile_per_allocator())[memory_id]
.mutable_profile_summary();
UpdateProfileSummary(stats, event.OffsetPs(), summary);
+
+ MemoryProfileSnapshot* snapshot =
+ (*memory_profile.mutable_memory_profile_per_allocator())[memory_id]
+ .add_memory_profile_snapshots();
+ snapshot->set_time_offset_ps(event.OffsetPs());
+ *snapshot->mutable_aggregation_stats() = std::move(stats);
+ *snapshot->mutable_activity_metadata() = std::move(metadata);
});
});
return memory_profile;
@@ -320,11 +275,16 @@
if (unmapped_allocation_bytes > 0) {
MemoryActivityMetadata* special_allocation =
memory_profile->add_special_allocations();
- FillActivityMetadata(HostEventType::kMemoryAllocation,
- {unmapped_allocation_bytes, unmapped_allocation_bytes,
- 0, "unused preallocated device memory", step_id,
- "persist/dynamic", 0, "unknown"},
- special_allocation);
+ special_allocation->set_memory_activity(ALLOCATION);
+ special_allocation->set_requested_bytes(unmapped_allocation_bytes);
+ special_allocation->set_allocation_bytes(unmapped_allocation_bytes);
+ special_allocation->set_address(0);
+ special_allocation->set_tf_op_name("unused preallocated device memory");
+ special_allocation->set_step_id(step_id);
+ special_allocation->set_region_type("persist/dynamic");
+ special_allocation->set_data_type(
+ tensorflow::DataTypeString(static_cast<tensorflow::DataType>(0)));
+ special_allocation->set_tensor_shape("unknown");
active_allocs->push_back({--index, special_allocation});
}
int64 stack_bytes =
@@ -332,10 +292,16 @@
if (stack_bytes > 0) {
MemoryActivityMetadata* special_allocation =
memory_profile->add_special_allocations();
- FillActivityMetadata(
- HostEventType::kMemoryAllocation,
- {stack_bytes, stack_bytes, 0, "stack", step_id, "stack", 0, "unknown"},
- special_allocation);
+ special_allocation->set_memory_activity(ALLOCATION);
+ special_allocation->set_requested_bytes(stack_bytes);
+ special_allocation->set_allocation_bytes(stack_bytes);
+ special_allocation->set_address(0);
+ special_allocation->set_tf_op_name("stack");
+ special_allocation->set_step_id(step_id);
+ special_allocation->set_region_type("stack");
+ special_allocation->set_data_type(
+ tensorflow::DataTypeString(static_cast<tensorflow::DataType>(0)));
+ special_allocation->set_tensor_shape("unknown");
active_allocs->push_back({--index, special_allocation});
}
}
diff --git a/tensorflow/core/profiler/internal/cpu/python_tracer.cc b/tensorflow/core/profiler/internal/cpu/python_tracer.cc
index 00f5cac..cc1fb8d 100644
--- a/tensorflow/core/profiler/internal/cpu/python_tracer.cc
+++ b/tensorflow/core/profiler/internal/cpu/python_tracer.cc
@@ -52,13 +52,13 @@
private:
bool recording_ = false;
const PythonHooksOptions options_;
+ std::unique_ptr<tensorflow::profiler::PythonHookContext> context_;
TF_DISALLOW_COPY_AND_ASSIGN(PythonTracer);
};
PythonTracer::~PythonTracer() {
Stop().IgnoreError();
- PythonHooks::GetSingleton()->Finalize(nullptr);
}
Status PythonTracer::Start() {
@@ -76,7 +76,7 @@
return errors::Internal("TraceMeRecorder not started");
}
VLOG(1) << __FUNCTION__;
- PythonHooks::GetSingleton()->Stop();
+ context_ = PythonHooks::GetSingleton()->Stop();
recording_ = false;
return Status::OK();
}
@@ -88,13 +88,16 @@
// We had assumed HostTracer::Stop is called when ProfilerSession try to
// serialize PythonTracer.
VLOG(2) << "Collecting data to RunMetaData from PythonTracer.";
- PythonHooks::GetSingleton()->Finalize(nullptr);
+ context_.reset();
return Status::OK();
}
Status PythonTracer::CollectData(XSpace* space) {
VLOG(2) << "Collecting data to XSpace from PythonTracer.";
- PythonHooks::GetSingleton()->Finalize(space);
+ if (context_) {
+ context_->Finalize(space);
+ context_.reset();
+ }
return Status::OK();
}
diff --git a/tensorflow/core/profiler/internal/tpu/BUILD b/tensorflow/core/profiler/internal/tpu/BUILD
index 67c5c34..0817946 100644
--- a/tensorflow/core/profiler/internal/tpu/BUILD
+++ b/tensorflow/core/profiler/internal/tpu/BUILD
@@ -23,6 +23,7 @@
"//tensorflow/core/profiler/utils:xplane_utils",
"//tensorflow/core/tpu:tpu_api",
"//tensorflow/core/tpu:tpu_api_dlsym_initializer",
+ "//tensorflow/core/tpu:tpu_initializer_helper",
"//tensorflow/core/tpu:tpu_ops_c_api_hdrs",
"//tensorflow/stream_executor/tpu:status_helper",
"@com_google_absl//absl/strings",
diff --git a/tensorflow/core/profiler/internal/tpu/tpu_tracer.cc b/tensorflow/core/profiler/internal/tpu/tpu_tracer.cc
index 528432f..b69b5b7 100644
--- a/tensorflow/core/profiler/internal/tpu/tpu_tracer.cc
+++ b/tensorflow/core/profiler/internal/tpu/tpu_tracer.cc
@@ -29,6 +29,7 @@
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
#include "tensorflow/core/tpu/tpu_api.h"
+#include "tensorflow/core/tpu/tpu_initializer_helper.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"
@@ -134,7 +135,9 @@
}
auto register_tpu_tracer_factory = [] {
- RegisterProfilerFactory(&CreateTpuTracer);
+ if (tensorflow::tpu::TryAcquireTpuLock()) {
+ RegisterProfilerFactory(&CreateTpuTracer);
+ }
return 0;
}();
diff --git a/tensorflow/core/profiler/lib/BUILD b/tensorflow/core/profiler/lib/BUILD
index 538a3c8..9294464 100644
--- a/tensorflow/core/profiler/lib/BUILD
+++ b/tensorflow/core/profiler/lib/BUILD
@@ -227,6 +227,10 @@
hdrs = ["profiler_lock.h"],
copts = tf_profiler_copts(),
visibility = ["//tensorflow/core/profiler:internal"],
+ deps = [
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
)
filegroup(
diff --git a/tensorflow/core/profiler/lib/profiler_lock.cc b/tensorflow/core/profiler/lib/profiler_lock.cc
index b276b00..a5b1ead 100644
--- a/tensorflow/core/profiler/lib/profiler_lock.cc
+++ b/tensorflow/core/profiler/lib/profiler_lock.cc
@@ -16,6 +16,9 @@
#include <atomic>
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/util/env_var.h"
+
namespace tensorflow {
namespace profiler {
@@ -23,7 +26,22 @@
// Prevents another profiler session from creating ProfilerInterface(s).
std::atomic<bool> session_active = ATOMIC_VAR_INIT(false);
-bool AcquireProfilerLock() { return !session_active.exchange(true); }
+bool AcquireProfilerLock() {
+ // Use environment variable to permanently lock the profiler.
+ // This allows running TensorFlow under an external profiling tool with all
+ // built-in profiling disabled.
+ static bool tf_profiler_disabled = [] {
+ bool disabled = false;
+ ReadBoolFromEnvVar("TF_DISABLE_PROFILING", false, &disabled).IgnoreError();
+ return disabled;
+ }();
+ if (TF_PREDICT_FALSE(tf_profiler_disabled)) {
+ LOG(WARNING) << "TensorFlow Profiler is permanently disabled by env var "
+ "TF_DISABLE_PROFILING.";
+ return false;
+ }
+ return !session_active.exchange(true);
+}
void ReleaseProfilerLock() { session_active.store(false); }
diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index 4bf996da9..57d2388 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -108,7 +108,7 @@
#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
#define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
-#define TF_GRAPH_DEF_VERSION 673 // Updated: 2021/2/10
+#define TF_GRAPH_DEF_VERSION 680 // Updated: 2021/2/17
// Checkpoint compatibility versions (the versions field in SavedSliceMeta).
//
diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD
index 8753f9e..a9faa18 100644
--- a/tensorflow/core/tpu/BUILD
+++ b/tensorflow/core/tpu/BUILD
@@ -11,6 +11,7 @@
"//tensorflow/compiler/mlir/tensorflow:__subpackages__",
"//tensorflow/compiler/tf2xla/kernels:__subpackages__",
"//tensorflow/compiler/xrt:__subpackages__",
+ "//tensorflow/core/profiler/internal/tpu:__subpackages__",
"//tensorflow/core/tpu:__subpackages__",
"//tensorflow/stream_executor/tpu:__subpackages__",
],
@@ -105,7 +106,11 @@
name = "tpu_initializer_helper",
srcs = ["tpu_initializer_helper.cc"],
hdrs = ["tpu_initializer_helper.h"],
- deps = ["@com_google_absl//absl/strings"],
+ deps = [
+ "//tensorflow/core/platform:logging",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/synchronization",
+ ],
)
cc_library(
@@ -314,3 +319,23 @@
"//tensorflow/core/lib/core:status",
],
)
+
+cc_library(
+ name = "tpu_model_server_initializer",
+ srcs = ["tpu_model_server_initializer.cc"],
+ hdrs = ["tpu_model_server_initializer.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":libtftpu_header",
+ ":tpu_api",
+ ":tpu_api_dlsym_set_fn",
+ ":tpu_executor_init_fns",
+ ":tpu_initializer_helper",
+ ":tpu_library_init_fns",
+ ":tpu_ops_c_api_hdrs",
+ "//tensorflow/core:lib",
+ "//tensorflow/stream_executor/tpu:tpu_executor",
+ "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
+ ],
+ alwayslink = True,
+)
diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc
index aed0add..a183c3d 100644
--- a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc
+++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc
@@ -3934,6 +3934,7 @@
TF_RETURN_IF_ERROR(GetTPUDeviceNames(replicate_node.requested_device(),
device_set, tpu_compilation_device,
&num_tpus_per_task, &tpu_devices));
+ *num_tasks = tpu_devices.size();
string topology;
TF_RETURN_IF_ERROR(
diff --git a/tensorflow/core/tpu/tpu_api_dlsym_initializer.cc b/tensorflow/core/tpu/tpu_api_dlsym_initializer.cc
index 4c67d59..eb32962 100644
--- a/tensorflow/core/tpu/tpu_api_dlsym_initializer.cc
+++ b/tensorflow/core/tpu/tpu_api_dlsym_initializer.cc
@@ -65,6 +65,8 @@
}
bool FindAndLoadTpuLibrary() {
+ if (!TryAcquireTpuLock()) return false;
+
void* library = dlopen("libtpu.so", RTLD_NOW);
if (library) {
InitializeTpuLibrary(library);
diff --git a/tensorflow/core/tpu/tpu_executor_dlsym_initializer.cc b/tensorflow/core/tpu/tpu_executor_dlsym_initializer.cc
index 8c2ae85..d3a70da 100644
--- a/tensorflow/core/tpu/tpu_executor_dlsym_initializer.cc
+++ b/tensorflow/core/tpu/tpu_executor_dlsym_initializer.cc
@@ -62,6 +62,8 @@
}
bool FindAndLoadTpuLibrary() {
+ if (!TryAcquireTpuLock()) return false;
+
void* library = dlopen("libtpu.so", RTLD_NOW);
if (library) {
InitializeTpuLibrary(library);
diff --git a/tensorflow/core/tpu/tpu_initializer_helper.cc b/tensorflow/core/tpu/tpu_initializer_helper.cc
index c97a09b..518c961 100644
--- a/tensorflow/core/tpu/tpu_initializer_helper.cc
+++ b/tensorflow/core/tpu/tpu_initializer_helper.cc
@@ -15,13 +15,56 @@
#include "tensorflow/core/tpu/tpu_initializer_helper.h"
+#if defined(LIBTPU_ON_GCE)
+#include <fcntl.h>
#include <stdlib.h>
+#include <unistd.h>
+#endif // LIBTPU_ON_GCE
#include "absl/strings/str_split.h"
+#include "absl/synchronization/mutex.h"
+#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
namespace tpu {
+bool TryAcquireTpuLock() {
+#if defined(LIBTPU_ON_GCE)
+ static absl::Mutex* mu = new absl::Mutex();
+ absl::MutexLock l(mu);
+
+ static bool attempted_file_open = false;
+ static bool should_load_library = false;
+
+ if (!attempted_file_open) {
+ should_load_library = true;
+
+ // if the TPU_HOST_BOUNDS env var is set, that means we are loading each
+ // chip in a different process and thus multiple libtpu loads are OK.
+ if (getenv("TPU_HOST_BOUNDS") == nullptr) {
+ int fd = open("/tmp/libtpu_lockfile", O_CREAT | O_RDWR, 0644);
+
+ // This lock is held until the process exits intentionally. The underlying
+ // TPU device will be held on until it quits.
+ if (lockf(fd, F_TLOCK, 0) != 0) {
+ LOG(WARNING) << "libtpu.so already in used by another process. Not "
+ "attempting to load libtpu.so in this process.";
+ should_load_library = false;
+ } else {
+ should_load_library = true;
+ }
+ } else {
+ LOG(INFO) << "TPU_HOST_BOUNDS is set, allowing multiple libtpu.so loads.";
+ should_load_library = true;
+ }
+ }
+
+ return should_load_library;
+#else // LIBTPU_ON_GCE
+ return false;
+#endif
+}
+
std::pair<std::vector<std::string>, std::vector<const char*>>
GetLibTpuInitArguments() {
// We make copies of the arguments returned by getenv because the memory
diff --git a/tensorflow/core/tpu/tpu_initializer_helper.h b/tensorflow/core/tpu/tpu_initializer_helper.h
index cd9b419..3abad8b 100644
--- a/tensorflow/core/tpu/tpu_initializer_helper.h
+++ b/tensorflow/core/tpu/tpu_initializer_helper.h
@@ -22,6 +22,11 @@
namespace tensorflow {
namespace tpu {
+// This will acquire a system-wide lock on behalf of the whole process. Follow
+// up calls to this function will return true if the lock has been acquired and
+// false if we failed to acquire the lock.
+bool TryAcquireTpuLock();
+
// Returns arguments (e.g. flags) set in the LIBTPU_INIT_ARGS environment
// variable. The first return value is the arguments, the second return value is
// pointers to the arguments suitable for passing into the C API.
diff --git a/tensorflow/core/tpu/tpu_model_server_initializer.cc b/tensorflow/core/tpu/tpu_model_server_initializer.cc
new file mode 100644
index 0000000..554d115
--- /dev/null
+++ b/tensorflow/core/tpu/tpu_model_server_initializer.cc
@@ -0,0 +1,78 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/tpu/tpu_model_server_initializer.h"
+
+#include <dlfcn.h>
+
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/tpu/tpu_api_dlsym_set_fn.h"
+
+#if !defined(PLATFORM_GOOGLE)
+#include "tensorflow/core/tpu/tpu_api.h"
+#include "tensorflow/core/tpu/tpu_initializer_helper.h"
+#include "tensorflow/stream_executor/tpu/tpu_platform.h"
+#endif
+
+namespace tensorflow {
+namespace tpu {
+
+
+#if defined(PLATFORM_GOOGLE)
+Status InitializeTpuModelServer(void* library_handle) {
+ return errors::Unimplemented("You must statically link in a TPU library.");
+}
+#else // PLATFORM_GOOGLE
+#include "tensorflow/core/tpu/tpu_library_init_fns.inc"
+
+Status InitializeTpuModelServer(void* library_handle) {
+ Status s = InitializeTpuStructFns(library_handle);
+
+ // Retrieve arguments from environment if applicable
+ std::pair<std::vector<std::string>, std::vector<const char*> > args =
+ GetLibTpuInitArguments();
+
+ // TPU platform registration must only be performed after the library is
+ // loaded. We do not want to register a TPU platform in XLA without the
+ // supporting library providing the necessary APIs.
+ if (s.ok()) {
+ void (*initialize_fn)(bool init_library, int num_args, const char** args);
+ initialize_fn = reinterpret_cast<decltype(initialize_fn)>(
+ dlsym(library_handle, "TfTpu_Initialize"));
+ (*initialize_fn)(/*init_library=*/true, args.second.size(),
+ args.second.data());
+
+ RegisterTpuPlatform();
+ }
+
+ OpsApiFn()->TfTpu_InitializeTpuModelServerFn();
+ return s;
+}
+
+bool FindAndLoadTpuModelServer() {
+ if (!TryAcquireTpuLock()) return false;
+ void* library = dlopen("libtpu.so", RTLD_NOW);
+ if (library) {
+ InitializeTpuModelServer(library);
+ }
+ return true;
+}
+
+static bool tpu_library_finder = FindAndLoadTpuModelServer();
+#endif // PLATFORM_GOOGLE
+
+} // namespace tpu
+} // namespace tensorflow
diff --git a/tensorflow/core/tpu/tpu_model_server_initializer.h b/tensorflow/core/tpu/tpu_model_server_initializer.h
new file mode 100644
index 0000000..d85786c
--- /dev/null
+++ b/tensorflow/core/tpu/tpu_model_server_initializer.h
@@ -0,0 +1,33 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_TPU_TPU_MODEL_SERVER_INITIALIZER_H_
+#define TENSORFLOW_CORE_TPU_TPU_MODEL_SERVER_INITIALIZER_H_
+
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/tpu/libtftpu.h"
+#include "tensorflow/core/tpu/tpu_ops_c_api.h"
+#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
+
+
+namespace tensorflow {
+namespace tpu {
+
+Status InitializeTpuModelServer(void* library_handle);
+
+} // namespace tpu
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_TPU_TPU_MODEL_SERVER_INITIALIZER_H_
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 72f8fb6..2dc55ef 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -1827,33 +1827,33 @@
// For example:
//
// ```
-// # tensor 'x' is [1, 1, 2, 4, 4, 4, 7, 8, 8]
-// y, idx, count = unique_with_counts(x)
+// x = tf.constant([1, 1, 2, 4, 4, 4, 7, 8, 8])
+// y, idx, count = UniqueWithCountsV2(x, axis = [0])
// y ==> [1, 2, 4, 7, 8]
// idx ==> [0, 0, 1, 2, 2, 2, 3, 4, 4]
// count ==> [2, 1, 3, 1, 2]
// ```
//
-// For an `2-D` tensor `x` with `axis = 0`:
+// For a `2-D` tensor `x` with `axis = 0`:
//
// ```
-// # tensor 'x' is [[1, 0, 0],
-// # [1, 0, 0],
-// # [2, 0, 0]]
-// y, idx, count = unique_with_counts(x, axis=0)
+// x = tf.constant([[1, 0, 0],
+// [1, 0, 0],
+// [2, 0, 0]])
+// y, idx, count = UniqueWithCountsV2(x, axis=[0])
// y ==> [[1, 0, 0],
// [2, 0, 0]]
// idx ==> [0, 0, 1]
// count ==> [2, 1]
// ```
//
-// For an `2-D` tensor `x` with `axis = 1`:
+// For a `2-D` tensor `x` with `axis = 1`:
//
// ```
-// # tensor 'x' is [[1, 0, 0],
-// # [1, 0, 0],
-// # [2, 0, 0]]
-// y, idx, count = unique_with_counts(x, axis=1)
+// x = tf.constant([[1, 0, 0],
+// [1, 0, 0],
+// [2, 0, 0]])
+// y, idx, count = UniqueWithCountsV2(x, axis=[1])
// y ==> [[1, 0],
// [1, 0],
// [2, 0]]
@@ -32571,8 +32571,8 @@
//
// See: https://en.wikipedia.org/wiki/Rectifier_(neural_networks)
// Example usage:
-// >>> tf.nn.relu([-2., 0., -0., 3.]).numpy()
-// array([ 0., 0., -0., 3.], dtype=float32)
+// >>> tf.nn.relu([-2., 0., 3.]).numpy()
+// array([0., 0., 3.], dtype=float32)
func Relu(scope *Scope, features tf.Output) (activations tf.Output) {
if scope.Err() != nil {
return
@@ -41793,6 +41793,14 @@
}
}
+// InitializeTableFromTextFileV2Offset sets the optional offset attribute to value.
+// If not specified, defaults to 0
+func InitializeTableFromTextFileV2Offset(value int64) InitializeTableFromTextFileV2Attr {
+ return func(m optionalAttr) {
+ m["offset"] = value
+ }
+}
+
// Initializes a table from a text file.
//
// It inserts one key-value pair into the table for each line of the file.
diff --git a/tensorflow/lite/c/c_api_types.h b/tensorflow/lite/c/c_api_types.h
index d6dc514..0128477 100644
--- a/tensorflow/lite/c/c_api_types.h
+++ b/tensorflow/lite/c/c_api_types.h
@@ -75,6 +75,7 @@
kTfLiteUInt64 = 13,
kTfLiteResource = 14,
kTfLiteVariant = 15,
+ kTfLiteUInt32 = 16,
} TfLiteType;
// Legacy. Will be deprecated in favor of TfLiteAffineQuantization.
diff --git a/tensorflow/lite/c/common.c b/tensorflow/lite/c/common.c
index d47ec4e..aaa98a9 100644
--- a/tensorflow/lite/c/common.c
+++ b/tensorflow/lite/c/common.c
@@ -199,6 +199,8 @@
return "INT16";
case kTfLiteInt32:
return "INT32";
+ case kTfLiteUInt32:
+ return "UINT32";
case kTfLiteUInt8:
return "UINT8";
case kTfLiteInt8:
diff --git a/tensorflow/lite/c/common.h b/tensorflow/lite/c/common.h
index 59ad977..e7d97ed 100644
--- a/tensorflow/lite/c/common.h
+++ b/tensorflow/lite/c/common.h
@@ -296,6 +296,7 @@
* GetTensorData<TYPE>(tensor) instead, otherwise only access .data, as other
* members are deprecated. */
int32_t* i32;
+ uint32_t* u32;
int64_t* i64;
uint64_t* u64;
float* f;
diff --git a/tensorflow/lite/c/common_test.cc b/tensorflow/lite/c/common_test.cc
index e8425f1..7a45db1 100644
--- a/tensorflow/lite/c/common_test.cc
+++ b/tensorflow/lite/c/common_test.cc
@@ -83,6 +83,7 @@
EXPECT_EQ(type_name(kTfLiteFloat16), "FLOAT16");
EXPECT_EQ(type_name(kTfLiteInt16), "INT16");
EXPECT_EQ(type_name(kTfLiteInt32), "INT32");
+ EXPECT_EQ(type_name(kTfLiteUInt32), "UINT32");
EXPECT_EQ(type_name(kTfLiteUInt8), "UINT8");
EXPECT_EQ(type_name(kTfLiteUInt64), "UINT64");
EXPECT_EQ(type_name(kTfLiteInt8), "INT8");
diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc
index 73e841f..40b0c9b 100644
--- a/tensorflow/lite/core/api/flatbuffer_conversions.cc
+++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc
@@ -851,6 +851,9 @@
case TensorType_INT32:
*type = kTfLiteInt32;
return kTfLiteOk;
+ case TensorType_UINT32:
+ *type = kTfLiteUInt32;
+ return kTfLiteOk;
case TensorType_UINT8:
*type = kTfLiteUInt8;
return kTfLiteOk;
@@ -1990,6 +1993,14 @@
void**) {
return kTfLiteOk;
}
+//
+// We have this parse function instead of directly returning kTfLiteOk from the
+// switch-case in ParseOpData because this function is used as part of the
+// selective registration for the OpResolver implementation in micro.
+TfLiteStatus ParseTranspose(const Operator*, ErrorReporter*,
+ BuiltinDataAllocator*, void**) {
+ return kTfLiteOk;
+}
TfLiteStatus ParseTransposeConv(const Operator* op,
ErrorReporter* error_reporter,
diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.h b/tensorflow/lite/core/api/flatbuffer_conversions.h
index 57276de..a1f5856 100644
--- a/tensorflow/lite/core/api/flatbuffer_conversions.h
+++ b/tensorflow/lite/core/api/flatbuffer_conversions.h
@@ -324,6 +324,10 @@
TfLiteStatus ParseTanh(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
+TfLiteStatus ParseTranspose(const Operator* op, ErrorReporter* error_reporter,
+ BuiltinDataAllocator* allocator,
+ void** builtin_data);
+
TfLiteStatus ParseTransposeConv(const Operator* op,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
diff --git a/tensorflow/lite/core/shims/BUILD b/tensorflow/lite/core/shims/BUILD
index 12bf5a9..2ad8356 100644
--- a/tensorflow/lite/core/shims/BUILD
+++ b/tensorflow/lite/core/shims/BUILD
@@ -223,5 +223,21 @@
)
#------------------------------------------------------------------------------
+# JNI bindings (Java API and Java Tasks library)
+
+# Contains code to initialize TFLite through JNI in the internal version.
+cc_library(
+ name = "jni_initialization",
+ srcs = [],
+ # Prevent automated tools from removing this target as a dependency due to
+ # it being empty.
+ tags = ["keep_dep"],
+ visibility = [
+ "//tensorflow/lite:__subpackages__",
+ "//tensorflow_lite_support:__subpackages__",
+ ],
+)
+
+#------------------------------------------------------------------------------
tflite_portable_test_suite()
diff --git a/tensorflow/lite/delegates/flex/allowlisted_flex_ops.cc b/tensorflow/lite/delegates/flex/allowlisted_flex_ops.cc
index 95af889..e53e9f7 100644
--- a/tensorflow/lite/delegates/flex/allowlisted_flex_ops.cc
+++ b/tensorflow/lite/delegates/flex/allowlisted_flex_ops.cc
@@ -579,6 +579,7 @@
"StridedSlice",
"StridedSliceAssign",
"StridedSliceGrad",
+ "StringFormat",
"StringJoin",
"StringLength",
"StringLower",
diff --git a/tensorflow/lite/delegates/flex/util.cc b/tensorflow/lite/delegates/flex/util.cc
index 2ba9161..ffb5bc2 100644
--- a/tensorflow/lite/delegates/flex/util.cc
+++ b/tensorflow/lite/delegates/flex/util.cc
@@ -68,6 +68,8 @@
return TF_INT16;
case kTfLiteInt32:
return TF_INT32;
+ case kTfLiteUInt32:
+ return TF_UINT32;
case kTfLiteUInt8:
return TF_UINT8;
case kTfLiteInt8:
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/BUILD b/tensorflow/lite/delegates/gpu/cl/kernels/BUILD
index 15fb046..17803ca 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/BUILD
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/BUILD
@@ -106,7 +106,7 @@
":cl_test",
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:status",
- "//tensorflow/lite/delegates/gpu/common/tasks:conv_powervr",
+ "//tensorflow/lite/delegates/gpu/common/tasks:conv_powervr_test_util",
"@com_google_googletest//:gtest_main",
],
)
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr_test.cc
index 0b2792d..0a97589 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr_test.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr_test.cc
@@ -13,8 +13,6 @@
limitations under the License.
==============================================================================*/
-#include "tensorflow/lite/delegates/gpu/common/tasks/conv_powervr.h"
-
#include <vector>
#include <gmock/gmock.h>
@@ -22,164 +20,32 @@
#include "tensorflow/lite/delegates/gpu/cl/kernels/cl_test.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
-
-using ::testing::FloatNear;
-using ::testing::Pointwise;
+#include "tensorflow/lite/delegates/gpu/common/tasks/conv_powervr_test_util.h"
namespace tflite {
namespace gpu {
namespace cl {
-namespace {
TEST_F(OpenCLOperationTest, ConvPowerVR1x1SimpleWeights) {
- TensorFloat32 src_tensor;
- src_tensor.shape = BHWC(1, 2, 2, 2);
- src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f};
-
- Convolution2DAttributes attr;
- attr.padding.prepended = HW(0, 0);
- attr.padding.appended = HW(0, 0);
- attr.strides = HW(1, 1);
- attr.dilations = HW(1, 1);
- attr.weights.shape = OHWI(2, 1, 1, 2);
- attr.weights.data = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
- attr.bias.shape = Linear(1);
- attr.bias.data = {0.0f};
-
- for (auto storage : env_.GetSupportedStorages()) {
- for (auto precision : env_.GetSupportedPrecisions()) {
- const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
- OperationDef op_def;
- op_def.precision = precision;
- auto data_type = DeduceDataTypeFromPrecision(precision);
- op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
- op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
- TensorFloat32 dst_tensor;
- ConvPowerVR operation =
- CreateConvPowerVR(creation_context_.GetGpuInfo(), op_def, attr);
- ASSERT_OK(ExecuteGPUOperation(
- src_tensor, creation_context_,
- absl::make_unique<ConvPowerVR>(std::move(operation)),
- BHWC(1, 2, 2, 2), &dst_tensor));
- EXPECT_THAT(dst_tensor.data,
- Pointwise(FloatNear(eps), {1.0f, 1.0f, 5.0f, 5.0f, 9.0f, 9.0f,
- 13.0f, 13.0f}));
- }
- }
+ const auto status = ConvPowerVR1x1SimpleWeightsTest(&exec_env_);
+ ASSERT_TRUE(status.ok()) << status.error_message();
}
TEST_F(OpenCLOperationTest, ConvPowerVR1x1) {
- TensorFloat32 src_tensor;
- src_tensor.shape = BHWC(1, 2, 2, 2);
- src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f};
-
- Convolution2DAttributes attr;
- attr.padding.prepended = HW(0, 0);
- attr.padding.appended = HW(0, 0);
- attr.strides = HW(1, 1);
- attr.dilations = HW(1, 1);
- attr.weights.shape = OHWI(2, 1, 1, 2);
- attr.weights.data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
- attr.bias.shape = Linear(2);
- attr.bias.data = {0.5f, -0.5f};
-
- for (auto storage : env_.GetSupportedStorages()) {
- for (auto precision : env_.GetSupportedPrecisions()) {
- const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
- OperationDef op_def;
- op_def.precision = precision;
- auto data_type = DeduceDataTypeFromPrecision(precision);
- op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
- op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
- TensorFloat32 dst_tensor;
- ConvPowerVR operation =
- CreateConvPowerVR(creation_context_.GetGpuInfo(), op_def, attr);
- ASSERT_OK(ExecuteGPUOperation(
- src_tensor, creation_context_,
- absl::make_unique<ConvPowerVR>(std::move(operation)),
- BHWC(1, 2, 2, 2), &dst_tensor));
- EXPECT_THAT(dst_tensor.data,
- Pointwise(FloatNear(eps), {2.5f, 3.5f, 8.5f, 17.5f, 14.5f,
- 31.5f, 20.5f, 45.5f}));
- }
- }
+ const auto status = ConvPowerVR1x1Test(&exec_env_);
+ ASSERT_TRUE(status.ok()) << status.error_message();
}
TEST_F(OpenCLOperationTest, ConvPowerVRSimpleWeights) {
- TensorFloat32 src_tensor;
- src_tensor.shape = BHWC(1, 2, 2, 2);
- src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f};
-
- Convolution2DAttributes attr;
- attr.padding.prepended = HW(0, 0);
- attr.padding.appended = HW(1, 1);
- attr.strides = HW(1, 1);
- attr.dilations = HW(1, 1);
- attr.weights.shape = OHWI(1, 2, 2, 2);
- attr.weights.data = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
- attr.bias.shape = Linear(1);
- attr.bias.data = {0.0f};
-
- for (auto storage : env_.GetSupportedStorages()) {
- for (auto precision : env_.GetSupportedPrecisions()) {
- const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
- OperationDef op_def;
- op_def.precision = precision;
- auto data_type = DeduceDataTypeFromPrecision(precision);
- op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
- op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
- TensorFloat32 dst_tensor;
- ConvPowerVR operation =
- CreateConvPowerVR(creation_context_.GetGpuInfo(), op_def, attr);
- ASSERT_OK(ExecuteGPUOperation(
- src_tensor, creation_context_,
- absl::make_unique<ConvPowerVR>(std::move(operation)),
- BHWC(1, 2, 2, 1), &dst_tensor));
- EXPECT_THAT(dst_tensor.data,
- Pointwise(FloatNear(eps), {28.0f, 18.0f, 22.0f, 13.0f}));
- }
- }
+ const auto status = ConvPowerVRSimpleWeightsTest(&exec_env_);
+ ASSERT_TRUE(status.ok()) << status.error_message();
}
TEST_F(OpenCLOperationTest, ConvPowerVR) {
- TensorFloat32 src_tensor;
- src_tensor.shape = BHWC(1, 2, 2, 2);
- src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f};
-
- Convolution2DAttributes attr;
- attr.padding.prepended = HW(0, 0);
- attr.padding.appended = HW(1, 1);
- attr.strides = HW(1, 1);
- attr.dilations = HW(1, 1);
- attr.weights.shape = OHWI(2, 2, 2, 2);
- attr.weights.data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
- 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f};
- attr.bias.shape = Linear(2);
- attr.bias.data = {0.5f, -0.5f};
-
- for (auto storage : env_.GetSupportedStorages()) {
- for (auto precision : env_.GetSupportedPrecisions()) {
- const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
- OperationDef op_def;
- op_def.precision = precision;
- auto data_type = DeduceDataTypeFromPrecision(precision);
- op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
- op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
- TensorFloat32 dst_tensor;
- ConvPowerVR operation =
- CreateConvPowerVR(creation_context_.GetGpuInfo(), op_def, attr);
- ASSERT_OK(ExecuteGPUOperation(
- src_tensor, creation_context_,
- absl::make_unique<ConvPowerVR>(std::move(operation)),
- BHWC(1, 2, 2, 2), &dst_tensor));
- EXPECT_THAT(dst_tensor.data,
- Pointwise(FloatNear(eps), {168.5f, 391.5f, 80.5f, 223.5f,
- 60.5f, 235.5f, 20.5f, 123.5f}));
- }
- }
+ const auto status = ConvPowerVRTest(&exec_env_);
+ ASSERT_TRUE(status.ok()) << status.error_message();
}
-} // namespace
} // namespace cl
} // namespace gpu
} // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc
index dad780f..6351f8d 100644
--- a/tensorflow/lite/delegates/gpu/common/model_builder.cc
+++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc
@@ -1819,6 +1819,40 @@
}
};
+class SplitVOperationParser : public TFLiteOperationParser {
+ public:
+ absl::Status IsSupported(const TfLiteContext* context,
+ const TfLiteNode* tflite_node,
+ const TfLiteRegistration* registration) final {
+ const TfLiteSplitVParams* split_params;
+ RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &split_params));
+ if (split_params->num_splits == 1) {
+ return absl::InvalidArgumentError(
+ "SplitV with num_splits = 1 is a no-op.");
+ }
+ return absl::OkStatus();
+ }
+
+ absl::Status Parse(const TfLiteNode* tflite_node,
+ const TfLiteRegistration* registration,
+ GraphFloat32* graph, ObjectReader* reader) final {
+ const TfLiteTensor* input = reader->GetInputTensor(0);
+ const TfLiteTensor* axis_tensor = reader->GetInputTensor(2);
+ SplitAttributes attr;
+ RETURN_IF_ERROR(
+ ExtractAxisFromIndex(*input, axis_tensor->data.i32[0], &attr.axis));
+
+ Node* node = graph->NewNode();
+ node->operation.type = ToString(OperationType::SPLIT);
+ node->operation.attributes = attr;
+ RETURN_IF_ERROR(reader->AddInput(node, 0));
+ for (int i = 0; i < tflite_node->outputs->size; ++i) {
+ RETURN_IF_ERROR(reader->AddOutput(node, i));
+ }
+ return absl::OkStatus();
+ }
+};
+
class StridedSliceOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
@@ -2382,6 +2416,8 @@
return std::make_unique<SoftmaxOperationParser>();
case kTfLiteBuiltinSpaceToDepth:
return std::make_unique<SpaceToDepthOperationParser>();
+ case kTfLiteBuiltinSplitV:
+ return std::make_unique<SplitVOperationParser>();
case kTfLiteBuiltinSqrt:
return std::make_unique<ElementwiseOperationParser>(OperationType::SQRT);
case kTfLiteBuiltinSquare:
diff --git a/tensorflow/lite/delegates/gpu/common/operations.cc b/tensorflow/lite/delegates/gpu/common/operations.cc
index 958df8d..1e94754 100644
--- a/tensorflow/lite/delegates/gpu/common/operations.cc
+++ b/tensorflow/lite/delegates/gpu/common/operations.cc
@@ -176,6 +176,8 @@
return "space_to_batch";
case OperationType::SPACE_TO_DEPTH:
return "space_to_depth";
+ case OperationType::SPLIT:
+ return "split";
case OperationType::SQRT:
return "sqrt";
case OperationType::SQUARE:
@@ -246,6 +248,7 @@
{"slice", OperationType::SLICE},
{"softmax", OperationType::SOFTMAX},
{"space_to_depth", OperationType::SPACE_TO_DEPTH},
+ {"split", OperationType::SPLIT},
{"sqrt", OperationType::SQRT},
{"square", OperationType::SQUARE},
{"squared_diff", OperationType::SQUARED_DIFF},
diff --git a/tensorflow/lite/delegates/gpu/common/operations.h b/tensorflow/lite/delegates/gpu/common/operations.h
index 47787f9..90312cb 100644
--- a/tensorflow/lite/delegates/gpu/common/operations.h
+++ b/tensorflow/lite/delegates/gpu/common/operations.h
@@ -85,6 +85,7 @@
SOFTMAX,
SPACE_TO_BATCH,
SPACE_TO_DEPTH,
+ SPLIT,
SQRT,
SQUARE,
SQUARED_DIFF,
@@ -547,6 +548,11 @@
int block_size;
};
+struct SplitAttributes {
+ // Defines axis by which to split.
+ Axis axis = Axis::UNKNOWN;
+};
+
// These help perform a combination of Quantize & Dequantize to adjust float
// values like quantized inference would.
struct QuantizeAndDequantizeAttributes {
diff --git a/tensorflow/lite/delegates/gpu/common/selectors/BUILD b/tensorflow/lite/delegates/gpu/common/selectors/BUILD
index 9304c13..e21bea2 100644
--- a/tensorflow/lite/delegates/gpu/common/selectors/BUILD
+++ b/tensorflow/lite/delegates/gpu/common/selectors/BUILD
@@ -136,6 +136,7 @@
"//tensorflow/lite/delegates/gpu/common/tasks:softmax",
"//tensorflow/lite/delegates/gpu/common/tasks:softmax1x1",
"//tensorflow/lite/delegates/gpu/common/tasks:space_to_depth",
+ "//tensorflow/lite/delegates/gpu/common/tasks:split",
"//tensorflow/lite/delegates/gpu/common/tasks:strided_slice",
"//tensorflow/lite/delegates/gpu/common/tasks:transpose",
"//tensorflow/lite/delegates/gpu/common/tasks:winograd",
diff --git a/tensorflow/lite/delegates/gpu/common/selectors/default/BUILD b/tensorflow/lite/delegates/gpu/common/selectors/default/BUILD
index 46515d0..aa52f3e 100644
--- a/tensorflow/lite/delegates/gpu/common/selectors/default/BUILD
+++ b/tensorflow/lite/delegates/gpu/common/selectors/default/BUILD
@@ -18,6 +18,7 @@
"//tensorflow/lite/delegates/gpu/common/task:work_group_picking",
"//tensorflow/lite/delegates/gpu/common/tasks:conv_buffer_1x1",
"//tensorflow/lite/delegates/gpu/common/tasks:conv_constants",
+ "//tensorflow/lite/delegates/gpu/common/tasks:conv_metal",
"//tensorflow/lite/delegates/gpu/common/tasks:conv_powervr",
"//tensorflow/lite/delegates/gpu/common/tasks:conv_weights_converter",
"@com_google_absl//absl/memory",
@@ -79,6 +80,7 @@
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
"//tensorflow/lite/delegates/gpu/common/tasks:conv_buffer_1x1",
+ "//tensorflow/lite/delegates/gpu/common/tasks:conv_metal",
"//tensorflow/lite/delegates/gpu/common/tasks:conv_powervr",
"//tensorflow/lite/delegates/gpu/common/tasks:fully_connected",
"@com_google_absl//absl/memory",
diff --git a/tensorflow/lite/delegates/gpu/common/selectors/default/convolution_selector.cc b/tensorflow/lite/delegates/gpu/common/selectors/default/convolution_selector.cc
index 9f0fdb5..be2a12a 100644
--- a/tensorflow/lite/delegates/gpu/common/selectors/default/convolution_selector.cc
+++ b/tensorflow/lite/delegates/gpu/common/selectors/default/convolution_selector.cc
@@ -24,6 +24,7 @@
#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/conv_buffer_1x1.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/conv_constants.h"
+#include "tensorflow/lite/delegates/gpu/common/tasks/conv_metal.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/conv_powervr.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/conv_weights_converter.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
@@ -137,10 +138,14 @@
const Convolution2DAttributes& attr, const BHWC& dst_shape,
const GpuInfo& gpu_info, const OperationDef& op_def,
ModelHints hints) {
- if (gpu_info.IsAdreno()) {
+ if (gpu_info.IsApiMetal() && IsConvolutionMetalSupported(op_def)) {
+ ConvolutionMetal conv =
+ CreateConvolutionMetal(op_def, dst_shape, attr, gpu_info);
+ return absl::make_unique<ConvolutionMetal>(std::move(conv));
+ } else if (gpu_info.IsAdreno()) {
return SelectConvolutionAdreno(attr, dst_shape, gpu_info, op_def, hints);
- } else if (gpu_info.IsPowerVR() || gpu_info.IsAMD() ||
- gpu_info.IsIntel()) {
+ } else if (gpu_info.IsPowerVR() || gpu_info.IsAMD() || gpu_info.IsIntel() ||
+ gpu_info.IsApple()) {
return SelectConvolutionPowerVR(attr, gpu_info, op_def);
} else if (gpu_info.IsNvidia()) {
return SelectConvolutionNVidia(attr, dst_shape, gpu_info, op_def);
@@ -155,11 +160,15 @@
const Convolution2DAttributes& attr, const BHWC& dst_shape,
const GpuInfo& gpu_info, const OperationDef& op_def,
ModelHints hints) {
- if (gpu_info.IsAdreno()) {
+ if (gpu_info.IsApiMetal() && IsConvolutionMetalSupported(op_def)) {
+ ConvolutionMetal conv =
+ CreateConvolutionMetalWino4x4To6x6(op_def, dst_shape, attr, gpu_info);
+ return absl::make_unique<ConvolutionMetal>(std::move(conv));
+ } else if (gpu_info.IsAdreno()) {
return SelectConvolutionWinogradAdreno(attr, dst_shape, gpu_info, op_def,
hints);
- } else if (gpu_info.IsPowerVR() || gpu_info.IsAMD() ||
- gpu_info.IsNvidia() || gpu_info.IsIntel()) {
+ } else if (gpu_info.IsPowerVR() || gpu_info.IsAMD() || gpu_info.IsNvidia() ||
+ gpu_info.IsIntel() || gpu_info.IsApple()) {
ConvPowerVR conv =
CreateConvPowerVRWino4x4To6x6(gpu_info, op_def, attr, &dst_shape);
return absl::make_unique<ConvPowerVR>(std::move(conv));
@@ -176,7 +185,14 @@
const BHWC& dst_shape, const GpuInfo& gpu_info,
const OperationDef& op_def, ModelHints hints,
WeightsDescription* weights_desc) {
- if (gpu_info.IsAdreno()) {
+ if (gpu_info.IsApiMetal() && IsConvolutionMetalSupported(op_def)) {
+ Convolution2DAttributes attr_copy = attr;
+ attr_copy.weights.shape = OHWI(weights_shape.b, weights_shape.h,
+ weights_shape.w, weights_shape.c);
+ ConvolutionMetal conv =
+ CreateConvolutionMetal(op_def, dst_shape, attr_copy, gpu_info);
+ return absl::make_unique<ConvolutionMetal>(std::move(conv));
+ } else if (gpu_info.IsAdreno()) {
return SelectConvolutionDynamicWeightsAdreno(attr, weights_shape, dst_shape,
gpu_info, op_def, hints,
weights_desc);
diff --git a/tensorflow/lite/delegates/gpu/common/selectors/default/fully_connected_selector.cc b/tensorflow/lite/delegates/gpu/common/selectors/default/fully_connected_selector.cc
index a2fbe37..c03556d 100644
--- a/tensorflow/lite/delegates/gpu/common/selectors/default/fully_connected_selector.cc
+++ b/tensorflow/lite/delegates/gpu/common/selectors/default/fully_connected_selector.cc
@@ -17,6 +17,7 @@
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/conv_buffer_1x1.h"
+#include "tensorflow/lite/delegates/gpu/common/tasks/conv_metal.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/conv_powervr.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/fully_connected.h"
@@ -83,7 +84,24 @@
std::unique_ptr<GPUOperation> SelectFullyConnected(
const FullyConnectedAttributes& attr, const GpuInfo& gpu_info,
const OperationDef& op_def, int batch_size) {
- if (gpu_info.IsAdreno()) {
+ if (gpu_info.IsApiMetal()) {
+ if (op_def.IsBatchSupported() && IsConvolutionMetalSupported(op_def)) {
+ BHWC dst_shape = BHWC(batch_size, 1, 1, attr.weights.shape.o);
+ Convolution2DAttributes conv_attr;
+ conv_attr.padding.prepended = HW(0, 0);
+ conv_attr.padding.appended = HW(0, 0);
+ conv_attr.strides = HW(1, 1);
+ conv_attr.dilations = HW(1, 1);
+ conv_attr.weights = attr.weights;
+ conv_attr.bias = attr.bias;
+ ConvolutionMetal conv =
+ CreateConvolutionMetal(op_def, dst_shape, conv_attr, gpu_info);
+ return absl::make_unique<ConvolutionMetal>(std::move(conv));
+ } else {
+ FullyConnected fc = CreateFullyConnected(gpu_info, op_def, attr);
+ return absl::make_unique<FullyConnected>(std::move(fc));
+ }
+ } else if (gpu_info.IsAdreno()) {
return SelectFullyConnectedAdreno(attr, gpu_info, op_def, batch_size);
} else if (gpu_info.IsPowerVR() || gpu_info.IsAMD() || gpu_info.IsNvidia() ||
gpu_info.IsIntel() || gpu_info.IsApple()) {
diff --git a/tensorflow/lite/delegates/gpu/common/selectors/operation_selector.cc b/tensorflow/lite/delegates/gpu/common/selectors/operation_selector.cc
index 959855c..201d7c9 100644
--- a/tensorflow/lite/delegates/gpu/common/selectors/operation_selector.cc
+++ b/tensorflow/lite/delegates/gpu/common/selectors/operation_selector.cc
@@ -48,9 +48,21 @@
const int total_tiles = tiles_x * tiles_y;
const int src_depth = DivideRoundUp(attr.weights.shape.i, 4);
const int dst_depth = DivideRoundUp(attr.weights.shape.o, 4);
- // Mali among other devices has smaller SIMD line size
- int min_depth = gpu_info.IsMali() ? 16 : 32;
- const int min_tiles = gpu_info.IsMali() ? 32 : 128;
+ int min_depth = 16;
+ if (gpu_info.IsAdreno() || gpu_info.IsAMD()) {
+ min_depth = 32;
+ }
+ int min_tiles = 32;
+ if (gpu_info.IsAdreno()) {
+ if (gpu_info.adreno_info.IsAdreno6xx()) {
+ min_tiles = 128;
+ } else {
+ min_tiles = 64;
+ }
+ }
+ if (gpu_info.IsAMD()) {
+ min_tiles = 64;
+ }
if (total_tiles >= min_tiles * 8) {
min_depth /= 4;
min_depth = std::max(min_depth, 8);
@@ -59,7 +71,7 @@
min_depth = std::max(min_depth, 8);
}
const bool recommended_channels =
- dst_depth % 4 == 0 && src_depth >= min_depth && dst_depth >= min_depth;
+ src_depth >= min_depth && dst_depth >= min_depth;
const bool recommended_hw = total_tiles >= min_tiles;
return recommended_channels && recommended_hw;
}
@@ -479,6 +491,11 @@
SelectSpaceToDepth(attr, op_def, gpu_op);
return absl::OkStatus();
}
+ case OperationType::SPLIT: {
+ auto attr = absl::any_cast<SplitAttributes>(node.operation.attributes);
+ RETURN_IF_ERROR(SelectSplit(attr, op_def, gpu_op));
+ return absl::OkStatus();
+ }
case OperationType::TRANSPOSE: {
auto attr =
absl::any_cast<TransposeAttributes>(node.operation.attributes);
diff --git a/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.cc b/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.cc
index 3834fa3..237d8b6 100644
--- a/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.cc
+++ b/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.cc
@@ -38,6 +38,7 @@
#include "tensorflow/lite/delegates/gpu/common/tasks/softmax.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/softmax1x1.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/space_to_depth.h"
+#include "tensorflow/lite/delegates/gpu/common/tasks/split.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/strided_slice.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/transpose.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/winograd.h"
@@ -134,6 +135,17 @@
*ptr = absl::make_unique<GPUOperation>(std::move(operation));
}
+absl::Status SelectSplit(const SplitAttributes& attr,
+ const OperationDef& op_def,
+ std::unique_ptr<GPUOperation>* ptr) {
+ if (attr.axis != Axis::CHANNELS) {
+ return absl::UnimplementedError("No split for this axis.");
+ }
+ Split operation = CreateSplit(op_def, attr);
+ *ptr = absl::make_unique<Split>(std::move(operation));
+ return absl::OkStatus();
+}
+
void SelectPadding(const PadAttributes& attr, const OperationDef& op_def,
std::unique_ptr<GPUOperation>* ptr) {
GPUOperation operation = CreatePadding(op_def, attr);
diff --git a/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.h b/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.h
index 4f757a8..42d8045 100644
--- a/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.h
+++ b/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.h
@@ -82,6 +82,10 @@
const OperationDef& op_def,
std::unique_ptr<GPUOperation>* ptr);
+absl::Status SelectSplit(const SplitAttributes& attr,
+ const OperationDef& op_def,
+ std::unique_ptr<GPUOperation>* ptr);
+
void SelectTranspose(const TransposeAttributes& attr,
const OperationDef& op_def,
std::unique_ptr<GPUOperation>* ptr);
diff --git a/tensorflow/lite/delegates/gpu/common/tasks/BUILD b/tensorflow/lite/delegates/gpu/common/tasks/BUILD
index c086937..12ef26f 100644
--- a/tensorflow/lite/delegates/gpu/common/tasks/BUILD
+++ b/tensorflow/lite/delegates/gpu/common/tasks/BUILD
@@ -112,6 +112,26 @@
)
cc_library(
+ name = "conv_metal",
+ srcs = ["conv_metal.cc"],
+ hdrs = ["conv_metal.h"],
+ deps = [
+ "//tensorflow/lite/delegates/gpu/common:data_type",
+ "//tensorflow/lite/delegates/gpu/common:gpu_info",
+ "//tensorflow/lite/delegates/gpu/common:operations",
+ "//tensorflow/lite/delegates/gpu/common:shape",
+ "//tensorflow/lite/delegates/gpu/common:types",
+ "//tensorflow/lite/delegates/gpu/common:util",
+ "//tensorflow/lite/delegates/gpu/common:winograd_util",
+ "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
+ "//tensorflow/lite/delegates/gpu/common/task:util",
+ "//tensorflow/lite/delegates/gpu/common/task:weights_conversion",
+ "//tensorflow/lite/delegates/gpu/common/task:weights_layout",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
name = "conv_powervr",
srcs = ["conv_powervr.cc"],
hdrs = ["conv_powervr.h"],
@@ -137,6 +157,19 @@
)
cc_library(
+ name = "conv_powervr_test_util",
+ testonly = 1,
+ srcs = ["conv_powervr_test_util.cc"],
+ hdrs = ["conv_powervr_test_util.h"],
+ deps = [
+ ":conv_powervr",
+ "//tensorflow/lite/delegates/gpu/common:operations",
+ "//tensorflow/lite/delegates/gpu/common:status",
+ "//tensorflow/lite/delegates/gpu/common/task:testing_util",
+ ],
+)
+
+cc_library(
name = "conv_weights_converter",
srcs = ["conv_weights_converter.cc"],
hdrs = ["conv_weights_converter.h"],
@@ -829,6 +862,19 @@
)
cc_library(
+ name = "split",
+ srcs = ["split.cc"],
+ hdrs = ["split.h"],
+ deps = [
+ "//tensorflow/lite/delegates/gpu/common:operations",
+ "//tensorflow/lite/delegates/gpu/common:status",
+ "//tensorflow/lite/delegates/gpu/common:types",
+ "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
+ "//tensorflow/lite/delegates/gpu/common/task:work_group_picking",
+ ],
+)
+
+cc_library(
name = "strided_slice",
srcs = ["strided_slice.cc"],
hdrs = ["strided_slice.h"],
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/conv.cc b/tensorflow/lite/delegates/gpu/common/tasks/conv_metal.cc
similarity index 70%
rename from tensorflow/lite/delegates/gpu/metal/kernels/conv.cc
rename to tensorflow/lite/delegates/gpu/common/tasks/conv_metal.cc
index e736424..bd4e63a 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/conv.cc
+++ b/tensorflow/lite/delegates/gpu/common/tasks/conv_metal.cc
@@ -13,7 +13,7 @@
limitations under the License.
==============================================================================*/
-#include "tensorflow/lite/delegates/gpu/metal/kernels/conv.h"
+#include "tensorflow/lite/delegates/gpu/common/tasks/conv_metal.h"
#include <cmath>
#include <cstdint>
@@ -28,14 +28,15 @@
#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
+#include "tensorflow/lite/delegates/gpu/common/task/util.h"
+#include "tensorflow/lite/delegates/gpu/common/task/weights_conversion.h"
+#include "tensorflow/lite/delegates/gpu/common/task/weights_layout.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/common/winograd_util.h"
-#include "tensorflow/lite/delegates/gpu/metal/kernels/util.h"
namespace tflite {
namespace gpu {
-namespace metal {
namespace {
@@ -154,8 +155,9 @@
return c;
}
-std::string GenerateConvolution(const ConvolutionGeneric::ConvParams& params,
- const OperationDef& definition) {
+std::string GenerateConvolution(const ConvolutionMetal::ConvParams& params,
+ const OperationDef& definition,
+ bool stride_correction) {
GlobalIdsParams ids_params;
ids_params.group_ids = {"group_id.x", "group_id.y", "group_id.z"};
ids_params.global_ids = {"ugid.x", "ugid.y", "ugid.z"};
@@ -170,33 +172,31 @@
std::string addr_space =
params.weights_upload_type ==
- ConvolutionGeneric::WeightsUploadType::CONSTANT_MEM
+ ConvolutionMetal::WeightsUploadType::CONSTANT_MEM
? "constant"
: "device";
const bool use_local_mem =
params.weights_upload_type ==
- ConvolutionGeneric::WeightsUploadType::LOCAL_MEM_BY_THREADS;
+ ConvolutionMetal::WeightsUploadType::LOCAL_MEM_BY_THREADS;
const int local_mem_size =
params.block_size.z * 4 * params.src_depth_loop_size;
const bool use_simd_broadcast =
params.weights_upload_type ==
- ConvolutionGeneric::WeightsUploadType::PRIVATE_MEM_SIMD8_BROADCAST ||
+ ConvolutionMetal::WeightsUploadType::PRIVATE_MEM_SIMD8_BROADCAST ||
params.weights_upload_type ==
- ConvolutionGeneric::WeightsUploadType::PRIVATE_MEM_SIMD16_BROADCAST ||
+ ConvolutionMetal::WeightsUploadType::PRIVATE_MEM_SIMD16_BROADCAST ||
params.weights_upload_type ==
- ConvolutionGeneric::WeightsUploadType::PRIVATE_MEM_SIMD32_BROADCAST;
+ ConvolutionMetal::WeightsUploadType::PRIVATE_MEM_SIMD32_BROADCAST;
int simd_size = 1;
if (params.weights_upload_type ==
- ConvolutionGeneric::WeightsUploadType::PRIVATE_MEM_SIMD8_BROADCAST) {
+ ConvolutionMetal::WeightsUploadType::PRIVATE_MEM_SIMD8_BROADCAST) {
simd_size = 8;
- } else if (params.weights_upload_type ==
- ConvolutionGeneric::WeightsUploadType::
- PRIVATE_MEM_SIMD16_BROADCAST) {
+ } else if (params.weights_upload_type == ConvolutionMetal::WeightsUploadType::
+ PRIVATE_MEM_SIMD16_BROADCAST) {
simd_size = 16;
- } else if (params.weights_upload_type ==
- ConvolutionGeneric::WeightsUploadType::
- PRIVATE_MEM_SIMD32_BROADCAST) {
+ } else if (params.weights_upload_type == ConvolutionMetal::WeightsUploadType::
+ PRIVATE_MEM_SIMD32_BROADCAST) {
simd_size = 32;
}
@@ -204,6 +204,15 @@
!params.need_dst_loop && !params.need_src_loop && params.x_kernel_is_1 &&
params.y_kernel_is_1;
+ const auto src_storage_type = definition.src_tensors[0].storage_type;
+ const auto dst_storage_type = definition.dst_tensors[0].storage_type;
+ const bool src_is_linear =
+ src_storage_type == TensorStorageType::BUFFER ||
+ src_storage_type == TensorStorageType::IMAGE_BUFFER;
+ const bool dst_is_linear =
+ dst_storage_type == TensorStorageType::BUFFER ||
+ dst_storage_type == TensorStorageType::IMAGE_BUFFER;
+
std::string channels[4] = {"x", "y", "z", "w"};
std::string c;
c.reserve(16 * 1024); // Reserve large enough buffer.
@@ -274,8 +283,15 @@
if (!params.x_kernel_is_1) {
for (int x = 0; x < params.block_size.x; ++x) {
const std::string s_x = std::to_string(x);
- c += " int x" + s_x + " = (X + " + s_x +
- ") * args.stride_x + args.padding_x;\n";
+ if (stride_correction) {
+ c += " int x" + s_x + " = " +
+ GetXStrideCorrected("(X + " + s_x + ")", "args.src_tensor.Batch()",
+ "args.stride_x", "args.padding_x") +
+ ";\n";
+ } else {
+ c += " int x" + s_x + " = (X + " + s_x +
+ ") * args.stride_x + args.padding_x;\n";
+ }
}
}
if (!params.y_kernel_is_1) {
@@ -295,10 +311,12 @@
for (int y = 0; y < params.block_size.y; ++y) {
const std::string s_y = std::to_string(y);
c += " int c_y" + s_y + " = y * args.dilation_y + y" + s_y + ";\n";
- c += " bool y" + s_y + "_out = c_y" + s_y + " < 0 || c_y" + s_y +
- " >= args.src_tensor.Height();\n";
- c += " c_y" + s_y + " = clamp(c_y" + s_y +
- ", 0, args.src_tensor.Height() - 1);\n";
+ if (src_is_linear) {
+ c += " bool y" + s_y + "_out = c_y" + s_y + " < 0 || c_y" + s_y +
+ " >= args.src_tensor.Height();\n";
+ c += " c_y" + s_y + " = clamp(c_y" + s_y +
+ ", 0, args.src_tensor.Height() - 1);\n";
+ }
}
} else {
for (int y = 0; y < params.block_size.y; ++y) {
@@ -313,10 +331,12 @@
for (int x = 0; x < params.block_size.x; ++x) {
const std::string s_x = std::to_string(x);
c += " int c_x" + s_x + " = x * args.dilation_x + x" + s_x + ";\n";
- c += " bool x" + s_x + "_out = c_x" + s_x + " < 0 || c_x" + s_x +
- " >= args.src_tensor.Width();\n";
- c += " c_x" + s_x + " = clamp(c_x" + s_x +
- ", 0, args.src_tensor.Width() - 1);\n";
+ if (src_is_linear) {
+ c += " bool x" + s_x + "_out = c_x" + s_x + " < 0 || c_x" + s_x +
+ " >= args.src_tensor.Width();\n";
+ c += " c_x" + s_x + " = clamp(c_x" + s_x +
+ ", 0, args.src_tensor.Width() - 1);\n";
+ }
}
} else {
for (int x = 0; x < params.block_size.x; ++x) {
@@ -325,34 +345,38 @@
", 0, args.src_tensor.Width() - 1);\n";
}
}
- for (int y = 0; y < params.block_size.y; ++y) {
- const std::string s_y = std::to_string(y);
- for (int x = 0; x < params.block_size.x; ++x) {
- const std::string s_x = std::to_string(x);
- const std::string s_yx = s_y + s_x;
- if (!params.y_kernel_is_1 && !params.x_kernel_is_1) {
- c += " FLT m" + s_yx + " = !(y" + s_y + "_out || x" + s_x + "_out);\n";
- } else if (!params.y_kernel_is_1) {
- c += " FLT m" + s_yx + " = !y" + s_y + "_out;\n";
- } else if (!params.x_kernel_is_1) {
- c += " FLT m" + s_yx + " = !x" + s_x + "_out;\n";
+ if (src_is_linear) {
+ for (int y = 0; y < params.block_size.y; ++y) {
+ const std::string s_y = std::to_string(y);
+ for (int x = 0; x < params.block_size.x; ++x) {
+ const std::string s_x = std::to_string(x);
+ const std::string s_yx = s_y + s_x;
+ if (!params.y_kernel_is_1 && !params.x_kernel_is_1) {
+ c += " FLT m" + s_yx + " = !(y" + s_y + "_out || x" + s_x +
+ "_out);\n";
+ } else if (!params.y_kernel_is_1) {
+ c += " FLT m" + s_yx + " = !y" + s_y + "_out;\n";
+ } else if (!params.x_kernel_is_1) {
+ c += " FLT m" + s_yx + " = !x" + s_x + "_out;\n";
+ }
}
}
- }
- for (int y = 0; y < params.block_size.y; ++y) {
- const std::string s_y = std::to_string(y);
- for (int x = 0; x < params.block_size.x; ++x) {
- const std::string s_x = std::to_string(x);
- const std::string s_yx = s_y + s_x;
- if (definition.src_tensors[0].storage_type == TensorStorageType::BUFFER) {
- c +=
- " device FLT4* src_loc_" + s_yx +
- " = args.src_tensor.GetHandle() + args.src_tensor.GetWHOffset(c_x" +
- s_x + ", c_y" + s_y + ");\n";
- } else if (definition.src_tensors[0].storage_type ==
- TensorStorageType::IMAGE_BUFFER) {
- c += " int src_loc_" + s_yx + " = args.src_tensor.GetWHOffset(c_x" +
- s_x + ", c_y" + s_y + ");\n";
+ for (int y = 0; y < params.block_size.y; ++y) {
+ const std::string s_y = std::to_string(y);
+ for (int x = 0; x < params.block_size.x; ++x) {
+ const std::string s_x = std::to_string(x);
+ const std::string s_yx = s_y + s_x;
+ if (definition.src_tensors[0].storage_type ==
+ TensorStorageType::BUFFER) {
+ c += " device FLT4* src_loc_" + s_yx +
+ " = args.src_tensor.GetHandle() + "
+ "args.src_tensor.GetWHOffset(c_x" +
+ s_x + ", c_y" + s_y + ");\n";
+ } else if (definition.src_tensors[0].storage_type ==
+ TensorStorageType::IMAGE_BUFFER) {
+ c += " int src_loc_" + s_yx + " = args.src_tensor.GetWHOffset(c_x" +
+ s_x + ", c_y" + s_y + ");\n";
+ }
}
}
}
@@ -396,30 +420,37 @@
for (int y = 0; y < params.block_size.y; ++y) {
for (int x = 0; x < params.block_size.x; ++x) {
const std::string s_yx = std::to_string(y) + std::to_string(x);
- if (definition.src_tensors[0].storage_type ==
- TensorStorageType::BUFFER) {
- if (!params.y_kernel_is_1 || !params.x_kernel_is_1) {
- c += " src" + s_yx + " = *src_loc_" + s_yx + " * m" + s_yx +
- ";\n";
- } else {
- c += " src" + s_yx + " = *src_loc_" + s_yx + ";\n";
+ if (src_is_linear) {
+ if (definition.src_tensors[0].storage_type ==
+ TensorStorageType::BUFFER) {
+ if (!params.y_kernel_is_1 || !params.x_kernel_is_1) {
+ c += " src" + s_yx + " = *src_loc_" + s_yx + " * m" + s_yx +
+ ";\n";
+ } else {
+ c += " src" + s_yx + " = *src_loc_" + s_yx + ";\n";
+ }
+ } else if (definition.src_tensors[0].storage_type ==
+ TensorStorageType::IMAGE_BUFFER) {
+ if (!params.y_kernel_is_1 || !params.x_kernel_is_1) {
+ c += " src" + s_yx + " = args.src_tensor.Read(src_loc_" +
+ s_yx + ") * m" + s_yx + ";\n";
+ } else {
+ c += " src" + s_yx + " = args.src_tensor.Read(src_loc_" +
+ s_yx + ");\n";
+ }
}
- } else if (definition.src_tensors[0].storage_type ==
- TensorStorageType::IMAGE_BUFFER) {
- if (!params.y_kernel_is_1 || !params.x_kernel_is_1) {
- c += " src" + s_yx + " = args.src_tensor.Read(src_loc_" + s_yx +
- ") * m" + s_yx + ";\n";
- } else {
- c += " src" + s_yx + " = args.src_tensor.Read(src_loc_" + s_yx +
- ");\n";
- }
+ } else {
+ c += " src" + s_yx + " = args.src_tensor.Read(c_x" +
+ std::to_string(x) + ", c_y" + std::to_string(y) + ", s);\n";
}
}
}
- for (int y = 0; y < params.block_size.y; ++y) {
- for (int x = 0; x < params.block_size.x; ++x) {
- const std::string s_yx = std::to_string(y) + std::to_string(x);
- c += " src_loc_" + s_yx + " += args.src_tensor.SliceStride();\n";
+ if (src_is_linear) {
+ for (int y = 0; y < params.block_size.y; ++y) {
+ for (int x = 0; x < params.block_size.x; ++x) {
+ const std::string s_yx = std::to_string(y) + std::to_string(x);
+ c += " src_loc_" + s_yx + " += args.src_tensor.SliceStride();\n";
+ }
}
}
};
@@ -445,8 +476,7 @@
}
std::string s_val = "src" + s_id;
std::string r_val = "r" + r_id;
- if (params.weight_layout ==
- ConvolutionGeneric::WeightsInnerBlockLayout::O4I4) {
+ if (params.weights_layout == WeightsLayout::kOHWIOGroupO4I4) {
c += " " + r_val + "." + channels[ch] + " += dot(" + f_val +
", " + s_val + ");\n";
} else { // WeightsInnerBlockLayout::I404
@@ -492,11 +522,13 @@
"return;\n";
}
- for_every_yx([](const std::string& s_yx, const std::string& s_x,
- const std::string& s_y, int x, int y) {
- return " args.dst_tensor.GetAddress(offset_" + s_yx + ", X + " + s_x +
- ", Y + " + s_y + ", Z);";
- });
+ if (dst_is_linear) {
+ for_every_yx([](const std::string& s_yx, const std::string& s_x,
+ const std::string& s_y, int x, int y) {
+ return " args.dst_tensor.GetAddress(offset_" + s_yx + ", X + " + s_x +
+ ", Y + " + s_y + ", Z);";
+ });
+ }
std::string bias_name = "args.biases.GetPtr()";
if (params.need_dst_loop) {
@@ -538,11 +570,16 @@
c += " {\n";
}
c += " FLT4 value = FLT4(r" + s_zyx + ");\n";
- c += " int linear_index = offset_" + s_yx +
- " + args.dst_tensor.SliceStride() * " + s_z + ";\n";
- c += " args.dst_tensor.Linking(value, X + " + s_x + ", Y + " +
- s_y + ", Z + " + s_z + ");\n";
- c += " args.dst_tensor.WriteLinear(value, linear_index);\n";
+ if (dst_is_linear) {
+ c += " int linear_index = offset_" + s_yx +
+ " + args.dst_tensor.SliceStride() * " + s_z + ";\n";
+ c += " args.dst_tensor.Linking(value, X + " + s_x + ", Y + " +
+ s_y + ", Z + " + s_z + ");\n";
+ c += " args.dst_tensor.WriteLinear(value, linear_index);\n";
+ } else {
+ c += " args.dst_tensor.Write(value, X + " + s_x + ", Y + " +
+ s_y + ", Z + " + s_z + ");\n";
+ }
c += " }\n";
}
}
@@ -552,50 +589,32 @@
return c;
}
-std::vector<float> ReorderWeightsForConv(
+std::vector<uint8_t> ReorderWeightsForConv(
const tflite::gpu::Tensor<OHWI, DataType::FLOAT32>& weights,
- const ConvolutionGeneric::ConvParams& params) {
- const int dst_depth = DivideRoundUp(weights.shape.o, 4);
- const int src_depth = DivideRoundUp(weights.shape.i, 4);
- std::vector<float> weights_reordered(
- weights.shape.w * weights.shape.h *
- AlignByN(dst_depth, params.block_size.z) * 4 * src_depth * 4);
+ const WeightsDescription& weights_desc, const DataType& weights_type) {
+ const int flt_count =
+ GetTotalElementsCountForLayout(weights_desc, weights.shape);
+ std::vector<uint8_t> result(flt_count * SizeOf(weights_type));
+ RearrangeWeights(weights, weights_desc, weights_type, absl::MakeSpan(result));
+ return result;
+}
- bool isO4I4 =
- params.weight_layout == ConvolutionGeneric::WeightsInnerBlockLayout::O4I4;
-
- int counter = 0;
- for (int d = 0; d < DivideRoundUp(dst_depth, params.block_size.z); ++d) {
- for (int y = 0; y < weights.shape.h; ++y) {
- for (int x = 0; x < weights.shape.w; ++x) {
- for (int s = 0; s < src_depth; ++s) {
- for (int k = 0; k < params.block_size.z; ++k) {
- for (int j = 0; j < 4; ++j) {
- for (int i = 0; i < 4; ++i) {
- int src_ch;
- int dst_ch;
- if (isO4I4) {
- src_ch = s * 4 + i;
- dst_ch = (d * params.block_size.z + k) * 4 + j;
- } else {
- src_ch = s * 4 + j;
- dst_ch = (d * params.block_size.z + k) * 4 + i;
- }
- if (src_ch >= weights.shape.i || dst_ch >= weights.shape.o) {
- weights_reordered[counter++] = 0.0f;
- } else {
- const size_t f_index =
- weights.shape.LinearIndex({dst_ch, y, x, src_ch});
- weights_reordered[counter++] = weights.data[f_index];
- }
- }
- }
- }
- }
- }
+std::vector<uint8_t> ReorderBiasesForConv(
+ const tflite::gpu::Tensor<Linear, DataType::FLOAT32>& biases,
+ const DataType& biases_type, int output_size) {
+ std::vector<uint8_t> result(output_size * SizeOf(biases_type));
+ if (biases_type == DataType::FLOAT32) {
+ float* gpu_data = reinterpret_cast<float*>(result.data());
+ for (int i = 0; i < output_size; ++i) {
+ gpu_data[i] = i < biases.shape.v ? biases.data[i] : 0.0f;
+ }
+ } else {
+ half* gpu_data = reinterpret_cast<half*>(result.data());
+ for (int i = 0; i < output_size; ++i) {
+ gpu_data[i] = i < biases.shape.v ? biases.data[i] : 0.0f;
}
}
- return weights_reordered;
+ return result;
}
int GetGroupsCount(const BHWC& dst_shape, const int3& wg_size,
@@ -669,15 +688,15 @@
}
}
-ConvolutionGeneric::ConvParams GetConvParamsForA7A8(
+ConvolutionMetal::ConvParams GetConvParamsForA7A8(
const AppleInfo& apple_info, const Convolution2DAttributes& attr,
const BHWC& dst_shape) {
const int dst_slices = DivideRoundUp(dst_shape.c, 4);
const int src_slices = DivideRoundUp(attr.weights.shape.i, 4);
- ConvolutionGeneric::ConvParams params;
+ ConvolutionMetal::ConvParams params;
params.weights_upload_type =
- ConvolutionGeneric::WeightsUploadType::LOCAL_MEM_BY_THREADS;
+ ConvolutionMetal::WeightsUploadType::LOCAL_MEM_BY_THREADS;
params.x_kernel_is_1 = IsKernelXIs1(attr);
params.y_kernel_is_1 = IsKernelYIs1(attr);
params.src_depth_loop_size = 1;
@@ -685,7 +704,7 @@
params.linear_wh = false;
params.linear_whs = false;
params.work_group_launch_order = int3(0, 1, 2);
- params.weight_layout = ConvolutionGeneric::WeightsInnerBlockLayout::O4I4;
+ params.weights_layout = WeightsLayout::kOHWIOGroupO4I4;
int blk_total_size = GetRecommendedBlockSize(apple_info, dst_shape);
@@ -729,7 +748,7 @@
params.linear_whs = true;
params.work_group_size = int3(32, 1, 1);
params.weights_upload_type =
- ConvolutionGeneric::WeightsUploadType::GLOBAL_MEM;
+ ConvolutionMetal::WeightsUploadType::GLOBAL_MEM;
}
if (params.src_depth_loop_size == src_slices) {
@@ -743,13 +762,13 @@
params.y_kernel_is_1;
if (use_filters_constants) {
params.weights_upload_type =
- ConvolutionGeneric::WeightsUploadType::CONSTANT_MEM;
+ ConvolutionMetal::WeightsUploadType::CONSTANT_MEM;
}
return params;
}
-ConvolutionGeneric::ConvParams GetConvParamsForA9AndHigher(
+ConvolutionMetal::ConvParams GetConvParamsForA9AndHigher(
const AppleInfo& apple_info, const Convolution2DAttributes& attr,
const BHWC& dst_shape) {
const int dst_slices = DivideRoundUp(dst_shape.c, 4);
@@ -776,9 +795,8 @@
blk_total_size /= 4;
}
- ConvolutionGeneric::ConvParams params;
- params.weights_upload_type =
- ConvolutionGeneric::WeightsUploadType::GLOBAL_MEM;
+ ConvolutionMetal::ConvParams params;
+ params.weights_upload_type = ConvolutionMetal::WeightsUploadType::GLOBAL_MEM;
params.x_kernel_is_1 = IsKernelXIs1(attr);
params.y_kernel_is_1 = IsKernelYIs1(attr);
params.src_depth_loop_size = 1;
@@ -787,7 +805,7 @@
params.linear_whs = false;
params.work_group_size = int3(8, 4, 1);
params.work_group_launch_order = int3(2, 0, 1);
- params.weight_layout = ConvolutionGeneric::WeightsInnerBlockLayout::O4I4;
+ params.weights_layout = WeightsLayout::kOHWIOGroupO4I4;
int g1 = GetGroupsCount(dst_shape, {8, 4, 1}, block_size);
int g2 = GetGroupsCountForLinearWH(dst_shape, {32, 1, 1}, block_size);
int g3 = GetGroupsCountForLinearWHS(dst_shape, {32, 1, 1}, block_size);
@@ -827,20 +845,20 @@
params.y_kernel_is_1;
if (use_filters_constants) {
params.weights_upload_type =
- ConvolutionGeneric::WeightsUploadType::CONSTANT_MEM;
+ ConvolutionMetal::WeightsUploadType::CONSTANT_MEM;
}
return params;
}
-ConvolutionGeneric::ConvParams GetConvParamsForIntel(
+ConvolutionMetal::ConvParams GetConvParamsForIntel(
const Convolution2DAttributes& attr, CalculationsPrecision precision,
const BHWC& dst_shape) {
const int dst_slices = DivideRoundUp(dst_shape.c, 4);
const int src_slices = DivideRoundUp(attr.weights.shape.i, 4);
- ConvolutionGeneric::ConvParams params;
+ ConvolutionMetal::ConvParams params;
params.weights_upload_type =
- ConvolutionGeneric::WeightsUploadType::PRIVATE_MEM_SIMD8_BROADCAST;
+ ConvolutionMetal::WeightsUploadType::PRIVATE_MEM_SIMD8_BROADCAST;
params.x_kernel_is_1 = IsKernelXIs1(attr);
params.y_kernel_is_1 = IsKernelYIs1(attr);
params.src_depth_loop_size = 1;
@@ -855,9 +873,9 @@
}
params.work_group_size = int3(8, 2, 1);
if (precision == CalculationsPrecision::F32_F16) {
- params.weight_layout = ConvolutionGeneric::WeightsInnerBlockLayout::O4I4;
+ params.weights_layout = WeightsLayout::kOHWIOGroupO4I4;
} else {
- params.weight_layout = ConvolutionGeneric::WeightsInnerBlockLayout::I4O4;
+ params.weights_layout = WeightsLayout::kOHWIOGroupI4O4;
}
if (src_slices % 2 == 0) {
@@ -876,10 +894,10 @@
return params;
}
-ConvolutionGeneric::ConvParams GetConvParamsForAMD(
+ConvolutionMetal::ConvParams GetConvParamsForAMD(
const Convolution2DAttributes& attr, CalculationsPrecision precision,
const BHWC& dst_shape) {
- ConvolutionGeneric::ConvParams params;
+ ConvolutionMetal::ConvParams params;
params.block_size = int3(1, 1, 4);
params.work_group_size = int3(8, 4, 1);
params.work_group_launch_order = int3(2, 0, 1);
@@ -888,22 +906,22 @@
params.need_dst_loop = true;
params.linear_wh = false;
params.linear_whs = false;
- params.weights_upload_type =
- ConvolutionGeneric::WeightsUploadType::GLOBAL_MEM;
+ params.weights_upload_type = ConvolutionMetal::WeightsUploadType::GLOBAL_MEM;
params.different_weights_for_height = false;
params.x_kernel_is_1 = IsKernelXIs1(attr);
params.y_kernel_is_1 = IsKernelYIs1(attr);
if (precision == CalculationsPrecision::F32_F16) {
- params.weight_layout = ConvolutionGeneric::WeightsInnerBlockLayout::O4I4;
+ params.weights_layout = WeightsLayout::kOHWIOGroupO4I4;
} else {
- params.weight_layout = ConvolutionGeneric::WeightsInnerBlockLayout::I4O4;
+ params.weights_layout = WeightsLayout::kOHWIOGroupI4O4;
}
return params;
}
-ConvolutionGeneric::ConvParams GetConvParams(
- const GpuInfo& gpu_info, const Convolution2DAttributes& attr,
- CalculationsPrecision precision, const BHWC& dst_shape) {
+ConvolutionMetal::ConvParams GetConvParams(const GpuInfo& gpu_info,
+ const Convolution2DAttributes& attr,
+ CalculationsPrecision precision,
+ const BHWC& dst_shape) {
if (gpu_info.IsApple()) {
if (gpu_info.apple_info.IsLocalMemoryPreferredOverGlobal()) {
return GetConvParamsForA7A8(gpu_info.apple_info, attr, dst_shape);
@@ -915,7 +933,7 @@
} else if (gpu_info.IsAMD()) {
return GetConvParamsForAMD(attr, precision, dst_shape);
} else {
- ConvolutionGeneric::ConvParams params;
+ ConvolutionMetal::ConvParams params;
params.block_size = int3(1, 1, 4);
params.work_group_size = int3(8, 4, 1);
params.work_group_launch_order = int3(2, 0, 1);
@@ -925,27 +943,31 @@
params.linear_wh = false;
params.linear_whs = false;
params.weights_upload_type =
- ConvolutionGeneric::WeightsUploadType::GLOBAL_MEM;
+ ConvolutionMetal::WeightsUploadType::GLOBAL_MEM;
params.different_weights_for_height = false;
params.x_kernel_is_1 = IsKernelXIs1(attr);
params.y_kernel_is_1 = IsKernelYIs1(attr);
- params.weight_layout = ConvolutionGeneric::WeightsInnerBlockLayout::O4I4;
+ params.weights_layout = WeightsLayout::kOHWIOGroupO4I4;
return params;
}
}
} // namespace
-absl::Status ConvolutionGeneric::BindArguments(ArgumentsBinder* args) {
- const int grid_x = DivideRoundUp(dst_[0]->Width(), params_.block_size.x);
+absl::Status ConvolutionMetal::BindArguments(ArgumentsBinder* args) {
+ RETURN_IF_ERROR(args->SetInt("padding_x", padding_.x * src_[0]->Batch()));
+ RETURN_IF_ERROR(args->SetInt("dilation_x", dilation_.x * src_[0]->Batch()));
+ const int grid_x =
+ DivideRoundUp(dst_[0]->Width() * dst_[0]->Batch(), params_.block_size.x);
const int grid_y = DivideRoundUp(dst_[0]->Height(), params_.block_size.y);
RETURN_IF_ERROR(args->SetInt("task_size_x", grid_x));
RETURN_IF_ERROR(args->SetInt("task_size_y", grid_x * grid_y));
return absl::OkStatus();
}
-int3 ConvolutionGeneric::GetGridSize() const {
- int grid_x = DivideRoundUp(dst_[0]->Width(), params_.block_size.x);
+int3 ConvolutionMetal::GetGridSize() const {
+ int grid_x =
+ DivideRoundUp(dst_[0]->Width() * dst_[0]->Batch(), params_.block_size.x);
int grid_y = DivideRoundUp(dst_[0]->Height(), params_.block_size.y);
int grid_z = DivideRoundUp(dst_[0]->Slices(), params_.block_size.z);
@@ -961,18 +983,30 @@
}
}
-ConvolutionGeneric CreateConvolutionGeneric(const OperationDef& definition,
- const BHWC& dst_shape,
- const Convolution2DAttributes& attr,
- const GpuInfo& gpu_info) {
- ConvolutionGeneric::ConvParams params =
- GetConvParams(gpu_info, attr, definition.precision, dst_shape);
+ConvolutionMetal CreateConvolutionMetal(const OperationDef& definition,
+ const BHWC& dst_shape,
+ const Convolution2DAttributes& attr,
+ const GpuInfo& gpu_info) {
+ BHWC new_shape = BHWC(1, dst_shape.h, dst_shape.w * dst_shape.b, dst_shape.c);
+ ConvolutionMetal::ConvParams params =
+ GetConvParams(gpu_info, attr, definition.precision, new_shape);
- ConvolutionGeneric desc(definition);
+ ConvolutionMetal desc(definition);
desc.params_ = params;
- desc.code_ = GenerateConvolution(params, definition);
- desc.AddSrcTensor("src_tensor", definition.src_tensors[0]);
- desc.AddDstTensor("dst_tensor", definition.dst_tensors[0]);
+ const bool stride_correction =
+ definition.IsBatchSupported() && attr.strides.w != 1;
+ desc.code_ = GenerateConvolution(params, definition, stride_correction);
+
+ auto src_desc = definition.src_tensors[0];
+ if (definition.IsBatchSupported()) {
+ src_desc.SetStateVar("BatchedWidth", "true");
+ }
+ desc.AddSrcTensor("src_tensor", src_desc);
+ auto dst_desc = definition.dst_tensors[0];
+ if (definition.IsBatchSupported()) {
+ dst_desc.SetStateVar("BatchedWidth", "true");
+ }
+ desc.AddDstTensor("dst_tensor", dst_desc);
desc.args_.AddInt("kernel_size_x", attr.weights.shape.w);
desc.args_.AddInt("kernel_size_y", attr.weights.shape.h);
@@ -982,32 +1016,43 @@
desc.args_.AddInt("stride_y", attr.strides.h);
desc.args_.AddInt("padding_x", -attr.padding.prepended.w);
desc.args_.AddInt("padding_y", -attr.padding.prepended.h);
+ desc.padding_ = int2(-attr.padding.prepended.w, -attr.padding.prepended.h);
+ desc.dilation_ = int2(attr.dilations.w, attr.dilations.h);
- auto weights_reordered = ReorderWeightsForConv(attr.weights, params);
- auto data_type = DeduceDataTypeFromPrecision(definition.precision);
- const int dst_depth = DivideRoundUp(attr.weights.shape.o, 4);
+ auto weights_type = DeduceDataTypeFromPrecision(definition.precision);
MemoryType mem_type =
params.weights_upload_type ==
- ConvolutionGeneric::WeightsUploadType::CONSTANT_MEM
+ ConvolutionMetal::WeightsUploadType::CONSTANT_MEM
? MemoryType::CONSTANT
: MemoryType::GLOBAL;
- BufferDescriptor weights_desc;
- weights_desc.element_type = data_type;
- weights_desc.element_size = 4;
- weights_desc.memory_type = mem_type;
- weights_desc.data = GetByteBufferConverted(weights_reordered, data_type);
- weights_desc.size = weights_desc.data.size();
- desc.args_.AddObject(
- "weights", absl::make_unique<BufferDescriptor>(std::move(weights_desc)));
+ if (definition.src_tensors.size() == 2) {
+ // dynamic weights
+ BufferDescriptor weights_desc;
+ weights_desc.element_type = definition.src_tensors[1].data_type;
+ weights_desc.element_size = 4;
+ weights_desc.memory_type = mem_type;
+ desc.AddSrcBuffer("weights", weights_desc);
+ } else {
+ BufferDescriptor weights_desc;
+ weights_desc.element_type = weights_type;
+ weights_desc.element_size = 4;
+ weights_desc.memory_type = mem_type;
+ weights_desc.data = ReorderWeightsForConv(
+ attr.weights, desc.GetWeightsDescription(), weights_type);
+ weights_desc.size = weights_desc.data.size();
+ desc.args_.AddObject("weights", absl::make_unique<BufferDescriptor>(
+ std::move(weights_desc)));
+ }
BufferDescriptor bias_desc;
- bias_desc.element_type = data_type;
+ bias_desc.element_type = weights_type;
bias_desc.element_size = 4;
bias_desc.memory_type = mem_type;
- bias_desc.data = GetByteBufferConvertedResized(
- attr.bias.data, data_type, AlignByN(dst_depth, params.block_size.z) * 4);
+ bias_desc.data = ReorderBiasesForConv(
+ attr.bias, weights_type,
+ AlignByN(attr.weights.shape.o, params.block_size.z * 4));
bias_desc.size = bias_desc.data.size();
desc.args_.AddObject(
"biases", absl::make_unique<BufferDescriptor>(std::move(bias_desc)));
@@ -1028,10 +1073,10 @@
return desc;
}
-ConvolutionGeneric CreateConvolutionWino4x4To6x6(
+ConvolutionMetal CreateConvolutionMetalWino4x4To6x6(
const OperationDef& definition, const BHWC& dst_shape,
const Convolution2DAttributes& attr, const GpuInfo& gpu_info) {
- ConvolutionGeneric::ConvParams params;
+ ConvolutionMetal::ConvParams params;
params.work_group_launch_order = int3(2, 0, 1);
params.src_depth_loop_size = 1;
params.need_src_loop = true;
@@ -1042,43 +1087,51 @@
params.x_kernel_is_1 = true;
params.y_kernel_is_1 = true;
if (gpu_info.IsApple()) {
- params.weight_layout = ConvolutionGeneric::WeightsInnerBlockLayout::O4I4;
+ params.weights_layout = WeightsLayout::kOHWIOGroupO4I4;
if (gpu_info.apple_info.IsLocalMemoryPreferredOverGlobal()) {
params.weights_upload_type =
- ConvolutionGeneric::WeightsUploadType::LOCAL_MEM_BY_THREADS;
+ ConvolutionMetal::WeightsUploadType::LOCAL_MEM_BY_THREADS;
params.work_group_size = int3(32, 1, 1);
params.block_size = int3(4, 1, 4);
} else {
params.weights_upload_type =
- ConvolutionGeneric::WeightsUploadType::GLOBAL_MEM;
+ ConvolutionMetal::WeightsUploadType::GLOBAL_MEM;
params.work_group_size = int3(8, 4, 1);
params.block_size = int3(4, 1, 4);
}
} else if (gpu_info.IsIntel()) {
- params.weight_layout = ConvolutionGeneric::WeightsInnerBlockLayout::I4O4;
+ params.weights_layout = WeightsLayout::kOHWIOGroupI4O4;
params.weights_upload_type =
- ConvolutionGeneric::WeightsUploadType::PRIVATE_MEM_SIMD8_BROADCAST;
+ ConvolutionMetal::WeightsUploadType::PRIVATE_MEM_SIMD8_BROADCAST;
params.work_group_size = int3(16, 1, 1);
params.block_size = int3(1, 1, 4);
} else if (gpu_info.IsAMD()) {
- params.weight_layout = ConvolutionGeneric::WeightsInnerBlockLayout::I4O4;
+ params.weights_layout = WeightsLayout::kOHWIOGroupI4O4;
params.weights_upload_type =
- ConvolutionGeneric::WeightsUploadType::GLOBAL_MEM;
+ ConvolutionMetal::WeightsUploadType::GLOBAL_MEM;
params.work_group_size = int3(32, 1, 1);
params.block_size = int3(2, 1, 4);
} else {
- params.weight_layout = ConvolutionGeneric::WeightsInnerBlockLayout::I4O4;
+ params.weights_layout = WeightsLayout::kOHWIOGroupI4O4;
params.weights_upload_type =
- ConvolutionGeneric::WeightsUploadType::GLOBAL_MEM;
+ ConvolutionMetal::WeightsUploadType::GLOBAL_MEM;
params.work_group_size = int3(32, 1, 1);
params.block_size = int3(2, 1, 4);
}
- ConvolutionGeneric desc(definition);
+ ConvolutionMetal desc(definition);
desc.params_ = params;
- desc.code_ = GenerateConvolution(params, definition);
- desc.AddSrcTensor("src_tensor", definition.src_tensors[0]);
- desc.AddDstTensor("dst_tensor", definition.dst_tensors[0]);
+ desc.code_ = GenerateConvolution(params, definition, false);
+ auto src_desc = definition.src_tensors[0];
+ if (definition.IsBatchSupported()) {
+ src_desc.SetStateVar("BatchedWidth", "true");
+ }
+ desc.AddSrcTensor("src_tensor", src_desc);
+ auto dst_desc = definition.dst_tensors[0];
+ if (definition.IsBatchSupported()) {
+ dst_desc.SetStateVar("BatchedWidth", "true");
+ }
+ desc.AddDstTensor("dst_tensor", dst_desc);
desc.args_.AddInt("kernel_size_x", 1);
desc.args_.AddInt("kernel_size_y", 1);
@@ -1088,28 +1141,32 @@
desc.args_.AddInt("stride_y", 1);
desc.args_.AddInt("padding_x", 0);
desc.args_.AddInt("padding_y", 0);
+ desc.padding_ = int2(0, 0);
+ desc.dilation_ = int2(1, 1);
- ::tflite::gpu::Tensor<OHWI, DataType::FLOAT32> wino_weights;
+ auto weights_type = DeduceDataTypeFromPrecision(definition.precision);
+
+ tflite::gpu::Tensor<OHWI, DataType::FLOAT32> wino_weights;
+ tflite::gpu::Tensor<Linear, DataType::FLOAT32> wino_biases;
RearrangeWeightsToWinograd4x4To6x6Weights(attr.weights, &wino_weights);
- auto weights_reordered = ReorderWeightsForConv(wino_weights, params);
- const int dst_slices = DivideRoundUp(attr.weights.shape.o, 4);
- std::vector<float> dummy_biases(AlignByN(dst_slices, params.block_size.z) * 4,
- 0.0f);
-
- auto data_type = DeduceDataTypeFromPrecision(definition.precision);
+ wino_biases.shape = Linear(attr.weights.shape.o);
+ wino_biases.data.resize(attr.weights.shape.o, 0.0f);
BufferDescriptor weights_desc;
- weights_desc.element_type = data_type;
+ weights_desc.element_type = weights_type;
weights_desc.element_size = 4;
- weights_desc.data = GetByteBufferConverted(weights_reordered, data_type);
+ weights_desc.data = ReorderWeightsForConv(
+ wino_weights, desc.GetWeightsDescription(), weights_type);
weights_desc.size = weights_desc.data.size();
desc.args_.AddObject(
"weights", absl::make_unique<BufferDescriptor>(std::move(weights_desc)));
BufferDescriptor bias_desc;
- bias_desc.element_type = data_type;
+ bias_desc.element_type = weights_type;
bias_desc.element_size = 4;
- bias_desc.data = GetByteBufferConverted(dummy_biases, data_type);
+ bias_desc.data = ReorderBiasesForConv(
+ wino_biases, weights_type,
+ AlignByN(attr.weights.shape.o, params.block_size.z * 4));
bias_desc.size = bias_desc.data.size();
desc.args_.AddObject(
"biases", absl::make_unique<BufferDescriptor>(std::move(bias_desc)));
@@ -1130,6 +1187,9 @@
return desc;
}
-} // namespace metal
+bool IsConvolutionMetalSupported(const OperationDef& definition) {
+ return !definition.src_tensors[0].HasAxis(Axis::DEPTH);
+}
+
} // namespace gpu
} // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/conv.h b/tensorflow/lite/delegates/gpu/common/tasks/conv_metal.h
similarity index 61%
rename from tensorflow/lite/delegates/gpu/metal/kernels/conv.h
rename to tensorflow/lite/delegates/gpu/common/tasks/conv_metal.h
index a44a664..7fb440f 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/conv.h
+++ b/tensorflow/lite/delegates/gpu/common/tasks/conv_metal.h
@@ -13,20 +13,20 @@
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_CONV_H_
-#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_CONV_H_
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_CONV_METAL_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_CONV_METAL_H_
#include <vector>
#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
+#include "tensorflow/lite/delegates/gpu/common/task/weights_layout.h"
namespace tflite {
namespace gpu {
-namespace metal {
-class ConvolutionGeneric : public GPUOperation {
+class ConvolutionMetal : public GPUOperation {
public:
enum class WeightsUploadType {
PRIVATE_MEM_SIMD8_BROADCAST,
@@ -37,11 +37,6 @@
CONSTANT_MEM,
};
- enum class WeightsInnerBlockLayout {
- O4I4,
- I4O4,
- };
-
struct ConvParams {
int3 block_size;
int3 work_group_size;
@@ -52,13 +47,13 @@
bool linear_wh;
bool linear_whs;
WeightsUploadType weights_upload_type;
- WeightsInnerBlockLayout weight_layout;
+ WeightsLayout weights_layout;
bool different_weights_for_height = false;
bool x_kernel_is_1;
bool y_kernel_is_1;
};
- ConvolutionGeneric() = default;
+ ConvolutionMetal() = default;
void GetPossibleKernelWorkGroups(
TuningType tuning_type, const GpuInfo& gpu_info,
const KernelInfo& kernel_info,
@@ -69,36 +64,46 @@
absl::Status BindArguments(ArgumentsBinder* args) override;
// Move only
- ConvolutionGeneric(ConvolutionGeneric&& kernel) = default;
- ConvolutionGeneric& operator=(ConvolutionGeneric&& kernel) = default;
- ConvolutionGeneric(const ConvolutionGeneric&) = delete;
- ConvolutionGeneric& operator=(const ConvolutionGeneric&) = delete;
+ ConvolutionMetal(ConvolutionMetal&& kernel) = default;
+ ConvolutionMetal& operator=(ConvolutionMetal&& kernel) = default;
+ ConvolutionMetal(const ConvolutionMetal&) = delete;
+ ConvolutionMetal& operator=(const ConvolutionMetal&) = delete;
+
+ WeightsDescription GetWeightsDescription() const {
+ WeightsDescription desc;
+ desc.layout = params_.weights_layout;
+ desc.output_group_size = params_.block_size.z;
+ return desc;
+ }
private:
- explicit ConvolutionGeneric(const OperationDef& definition)
+ explicit ConvolutionMetal(const OperationDef& definition)
: GPUOperation(definition) {}
- friend ConvolutionGeneric CreateConvolutionGeneric(
+ friend ConvolutionMetal CreateConvolutionMetal(
const OperationDef& definition, const BHWC& dst_shape,
const Convolution2DAttributes& attr, const GpuInfo& gpu_info);
- friend ConvolutionGeneric CreateConvolutionWino4x4To6x6(
+ friend ConvolutionMetal CreateConvolutionMetalWino4x4To6x6(
const OperationDef& definition, const BHWC& dst_shape,
const Convolution2DAttributes& attr, const GpuInfo& gpu_info);
+ int2 padding_;
+ int2 dilation_;
ConvParams params_;
};
-ConvolutionGeneric CreateConvolutionGeneric(const OperationDef& definition,
- const BHWC& dst_shape,
- const Convolution2DAttributes& attr,
- const GpuInfo& gpu_info);
+ConvolutionMetal CreateConvolutionMetal(const OperationDef& definition,
+ const BHWC& dst_shape,
+ const Convolution2DAttributes& attr,
+ const GpuInfo& gpu_info);
-ConvolutionGeneric CreateConvolutionWino4x4To6x6(
+ConvolutionMetal CreateConvolutionMetalWino4x4To6x6(
const OperationDef& definition, const BHWC& dst_shape,
const Convolution2DAttributes& attr, const GpuInfo& gpu_info);
-} // namespace metal
+bool IsConvolutionMetalSupported(const OperationDef& definition);
+
} // namespace gpu
} // namespace tflite
-#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_CONV_H_
+#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_CONV_METAL_H_
diff --git a/tensorflow/lite/delegates/gpu/common/tasks/conv_powervr.cc b/tensorflow/lite/delegates/gpu/common/tasks/conv_powervr.cc
index d48101c..7ff683e 100644
--- a/tensorflow/lite/delegates/gpu/common/tasks/conv_powervr.cc
+++ b/tensorflow/lite/delegates/gpu/common/tasks/conv_powervr.cc
@@ -72,13 +72,33 @@
std::string GenerateBlockCoords(const int4& block_size,
const int3& work_group_launch_order,
- bool linear_spatial, bool need_depth) {
+ bool linear_spatial, bool linear_all,
+ bool need_depth) {
std::string c;
int3 launch_remap;
launch_remap[work_group_launch_order.x] = 0;
launch_remap[work_group_launch_order.y] = 1;
launch_remap[work_group_launch_order.z] = 2;
- if (linear_spatial) {
+ if (linear_all) {
+ c += " int linear_id = GLOBAL_ID_0;\n";
+ c += " int DST_S = (linear_id / args.task_size_spatial) * " +
+ std::to_string(block_size.w) + ";\n";
+ c += " int linear_spatial = linear_id % args.task_size_spatial;\n";
+ if (need_depth) {
+ c += " int DST_X = (linear_spatial % args.task_size_x) * " +
+ std::to_string(block_size.x) + ";\n";
+ c += " linear_spatial = linear_spatial / args.task_size_x;\n";
+ c += " int DST_Y = (linear_spatial % args.task_size_y) * " +
+ std::to_string(block_size.y) + ";\n";
+ c += " int DST_Z = (linear_spatial / args.task_size_y) * " +
+ std::to_string(block_size.z) + ";\n";
+ } else {
+ c += " int DST_Y = (linear_spatial / args.task_size_x) * " +
+ std::to_string(block_size.y) + ";\n";
+ c += " int DST_X = (linear_spatial % args.task_size_x) * " +
+ std::to_string(block_size.x) + ";\n";
+ }
+ } else if (linear_spatial) {
if (work_group_launch_order[0] == 0) {
c += " int linear_spatial = GLOBAL_ID_0;\n";
} else {
@@ -219,7 +239,9 @@
}
void ConvPowerVR::GenerateCode(const GpuInfo& gpu_info) {
- if (conv_params_.linear_spatial) {
+ if (conv_params_.linear_all) {
+ grid_dimension_ = 1;
+ } else if (conv_params_.linear_spatial) {
grid_dimension_ = 2;
}
const bool stride_correction =
@@ -264,16 +286,16 @@
RETURN_IF_ERROR(args->SetInt("kernel_size_z", kernel_size_.z));
RETURN_IF_ERROR(args->SetInt("dilation_z", dilation_.z));
}
- if (conv_params_.linear_spatial) {
- const int grid_x = DivideRoundUp(dst_[0]->Width() * dst_[0]->Batch(),
- conv_params_.block_size.x);
- RETURN_IF_ERROR(args->SetInt("task_size_x", grid_x));
- }
- if (definition_.src_tensors[0].HasAxis(Axis::DEPTH)) {
- const int task_size_y =
- DivideRoundUp(dst_[0]->Height(), conv_params_.block_size.y);
- RETURN_IF_ERROR(args->SetInt("task_size_y", task_size_y));
- }
+ const int task_size_x = DivideRoundUp(dst_[0]->Width() * dst_[0]->Batch(),
+ conv_params_.block_size.x);
+ const int task_size_y =
+ DivideRoundUp(dst_[0]->Height(), conv_params_.block_size.y);
+ const int task_size_z =
+ DivideRoundUp(dst_[0]->Depth(), conv_params_.block_size.z);
+ RETURN_IF_ERROR(args->SetInt("task_size_x", task_size_x));
+ RETURN_IF_ERROR(args->SetInt("task_size_y", task_size_y));
+ const int task_size_spatial = task_size_x * task_size_y * task_size_z;
+ RETURN_IF_ERROR(args->SetInt("task_size_spatial", task_size_spatial));
return absl::OkStatus();
}
@@ -288,18 +310,12 @@
DivideRoundUp(dst_[0]->Slices(), conv_params_.block_size.w);
int3 wg;
- if (conv_params_.linear_spatial) {
- int grid_x = task_size_x * task_size_y;
- if (definition_.src_tensors[0].HasAxis(Axis::DEPTH)) {
- grid_x *= task_size_z;
- }
- return int3(grid_x, task_size_s, 1);
+ if (conv_params_.linear_all) {
+ return int3(task_size_x * task_size_y * task_size_z * task_size_s, 1, 1);
+ } else if (conv_params_.linear_spatial) {
+ return int3(task_size_x * task_size_y * task_size_z, task_size_s, 1);
} else {
- int grid_y = task_size_y;
- if (definition_.src_tensors[0].HasAxis(Axis::DEPTH)) {
- grid_y *= task_size_z;
- }
- return int3(task_size_x, grid_y, task_size_s);
+ return int3(task_size_x, task_size_y * task_size_z, task_size_s);
}
}
@@ -409,12 +425,9 @@
args_.AddInt("kernel_size_z");
args_.AddInt("dilation_z");
}
- if (conv_params_.linear_spatial) {
- args_.AddInt("task_size_x");
- }
- if (src_def.HasAxis(Axis::DEPTH)) {
- args_.AddInt("task_size_y");
- }
+ args_.AddInt("task_size_x");
+ args_.AddInt("task_size_y");
+ args_.AddInt("task_size_spatial");
const int wg_total_size =
work_group_size_.x * work_group_size_.y * work_group_size_.z;
@@ -470,7 +483,9 @@
}
std::string dst_oob_check;
if (src_def.HasAxis(Axis::DEPTH)) {
- if (conv_params.linear_spatial) {
+ if (conv_params.linear_all) {
+ dst_oob_check = "DST_S >= args.dst_tensor.Slices()";
+ } else if (conv_params.linear_spatial) {
dst_oob_check =
"DST_Z >= args.dst_tensor.Depth() || DST_S >= "
"args.dst_tensor.Slices()";
@@ -480,7 +495,9 @@
"args.dst_tensor.Depth() || DST_S >= args.dst_tensor.Slices()";
}
} else {
- if (conv_params.linear_spatial) {
+ if (conv_params.linear_all) {
+ dst_oob_check = "DST_S >= args.dst_tensor.Slices()";
+ } else if (conv_params.linear_spatial) {
dst_oob_check =
"DST_Y >= args.dst_tensor.Height() || DST_S >= "
"args.dst_tensor.Slices()";
@@ -492,7 +509,7 @@
}
c += "MAIN_FUNCTION($0) {\n";
c += GenerateBlockCoords(conv_params.block_size, work_group_launch_order_,
- conv_params.linear_spatial,
+ conv_params.linear_spatial, conv_params.linear_all,
src_def.HasAxis(Axis::DEPTH));
if (!late_oob_check) {
c += " if (" + dst_oob_check + ") {\n";
@@ -804,14 +821,25 @@
std::string w_val_w = "SUB_GROUP_BROADCAST(simd_w" +
std::to_string(simd_id) + ".w, " +
std::to_string(thread_id) + "u)";
- c += " " + R + ".x += " + w_val_x + " * " + S + "." +
- channels[ch] + ";\n";
- c += " " + R + ".y += " + w_val_y + " * " + S + "." +
- channels[ch] + ";\n";
- c += " " + R + ".z += " + w_val_z + " * " + S + "." +
- channels[ch] + ";\n";
- c += " " + R + ".w += " + w_val_w + " * " + S + "." +
- channels[ch] + ";\n";
+ if (GetWeightsDescription().IsI4O4()) {
+ c += " " + R + ".x += " + w_val_x + " * " + S + "." +
+ channels[ch] + ";\n";
+ c += " " + R + ".y += " + w_val_y + " * " + S + "." +
+ channels[ch] + ";\n";
+ c += " " + R + ".z += " + w_val_z + " * " + S + "." +
+ channels[ch] + ";\n";
+ c += " " + R + ".w += " + w_val_w + " * " + S + "." +
+ channels[ch] + ";\n";
+ } else {
+ c += " " + R + "." + channels[ch] + " += " + w_val_x +
+ " * " + S + ".x;\n";
+ c += " " + R + "." + channels[ch] + " += " + w_val_y +
+ " * " + S + ".y;\n";
+ c += " " + R + "." + channels[ch] + " += " + w_val_z +
+ " * " + S + ".z;\n";
+ c += " " + R + "." + channels[ch] + " += " + w_val_w +
+ " * " + S + ".w;\n";
+ }
} else {
const std::string weight_id =
std::to_string(s * 4 + ch + shared_offset);
@@ -821,8 +849,13 @@
} else {
w_val = "f" + weight_id;
}
- c += " " + R + " += " + w_val + " * " + S + "." +
- channels[ch] + ";\n";
+ if (GetWeightsDescription().IsI4O4()) {
+ c += " " + R + " += " + w_val + " * " + S + "." +
+ channels[ch] + ";\n";
+ } else {
+ c += " " + R + "." + channels[ch] + " += dot(" + w_val +
+ ", " + S + ");\n";
+ }
}
}
}
@@ -847,9 +880,16 @@
F[i] = "f" + weight_id;
}
}
- c += " " + R + " += TO_ACCUM_TYPE(" + S + ".x * " + F[0] +
- " + " + S + ".y * " + F[1] + " + " + S + ".z * " + F[2] +
- " + " + S + ".w * " + F[3] + ");\n";
+ if (GetWeightsDescription().IsI4O4()) {
+ c += " " + R + " += TO_ACCUM_TYPE(" + S + ".x * " + F[0] +
+ " + " + S + ".y * " + F[1] + " + " + S + ".z * " + F[2] +
+ " + " + S + ".w * " + F[3] + ");\n";
+ } else {
+ c += " " + R + ".x += dot(" + S + ", " + F[0] + ");\n";
+ c += " " + R + ".y += dot(" + S + ", " + F[1] + ");\n";
+ c += " " + R + ".z += dot(" + S + ", " + F[2] + ");\n";
+ c += " " + R + ".w += dot(" + S + ", " + F[3] + ");\n";
+ }
}
}
}
@@ -1021,6 +1061,8 @@
bool different_weights_for_height, const BHWC* dst_shape) {
ConvParams conv_params;
conv_params.linear_spatial = false;
+ conv_params.linear_all = false;
+ conv_params.block_size = int4(1, 1, 1, 1);
conv_params.weights_data_type =
DeduceDataTypeFromPrecision(definition.precision);
conv_params.x_kernel_is_1 = x_kernel_is_1;
@@ -1249,6 +1291,13 @@
if (src_depth % 4 == 0 && conv_params.block_size.w <= 2) {
conv_params.src_depth_loop_size = 4;
}
+ } else if (gpu_info.IsApple()) {
+ conv_params.block_size = int4(2, 2, 1, 2);
+ work_group_size_ = int3(8, 4, 1);
+ work_group_launch_order_ = int3(0, 1, 2);
+ conv_params.fixed_work_group_size = true;
+ conv_params.src_depth_loop_size = 1;
+ conv_params.weights_upload_type = WeightsUploadType::GLOBAL_MEM;
} else {
conv_params.block_size = int4(1, 1, 1, 4);
work_group_size_ = int3(8, 2, 1);
@@ -1270,6 +1319,19 @@
conv_params.src_depth_loop_size = 4;
}
}
+ if (conv_params.AreWeightsBuffer()) {
+ if (gpu_info.IsApple()) {
+ conv_params.weights_layout = WeightsLayout::kOHWIOGroupO4I4;
+ } else {
+ conv_params.weights_layout = WeightsLayout::kOHWIOGroupI4O4;
+ }
+ } else {
+ if (gpu_info.IsApple()) {
+ conv_params.weights_layout = WeightsLayout::k2DX4O4YIsHWIAndXIsOOGroupI4;
+ } else {
+ conv_params.weights_layout = WeightsLayout::k2DX4I4YIsHWIAndXIsOOGroupO4;
+ }
+ }
return conv_params;
}
diff --git a/tensorflow/lite/delegates/gpu/common/tasks/conv_powervr.h b/tensorflow/lite/delegates/gpu/common/tasks/conv_powervr.h
index 1f74369..ee3f08a 100644
--- a/tensorflow/lite/delegates/gpu/common/tasks/conv_powervr.h
+++ b/tensorflow/lite/delegates/gpu/common/tasks/conv_powervr.h
@@ -49,7 +49,7 @@
WeightsDescription GetWeightsDescription() const {
WeightsDescription desc;
- desc.layout = WeightsLayout::kOHWIOGroupI4O4;
+ desc.layout = conv_params_.weights_layout;
desc.output_group_size = conv_params_.block_size.w;
return desc;
}
@@ -82,12 +82,17 @@
int4 block_size; // WHDS
bool fixed_work_group_size;
bool linear_spatial; // spatial dimensions are Width/Height/Depth
+ bool linear_all; // linear_spatial & linear_all can not be used together,
+ // linear_all can not be used with WeightsUploadTypes
+ // that use workgroups(subgroups) for
+ // uploading(LOCAL_MEM_BY_THREADS for example).
bool different_weights_for_height;
int src_depth_loop_size;
WeightsUploadType weights_upload_type;
bool x_kernel_is_1;
bool y_kernel_is_1;
bool z_kernel_is_1;
+ WeightsLayout weights_layout;
// used only with PRIVATE_MEM_SIMD_BROADCAST
int simd_size = 1;
@@ -248,59 +253,41 @@
template <DataType T>
void ConvPowerVR::UploadWeights(const tflite::gpu::Tensor<OHWI, T>& weights) {
- const int dst_slices =
- AlignByN(DivideRoundUp(weights.shape.o, 4), conv_params_.block_size.w);
- const int src_slices = DivideRoundUp(weights.shape.i, 4);
+ const int flt_count =
+ GetTotalElementsCountForLayout(GetWeightsDescription(), weights.shape);
+ DataType weights_type = conv_params_.weights_data_type;
- const bool f32_weights = conv_params_.weights_data_type == DataType::FLOAT32;
- const int float4_size = f32_weights ? sizeof(float4) : sizeof(half4);
+ std::vector<uint8_t> weights_data(flt_count * SizeOf(weights_type));
+ RearrangeWeights(weights, GetWeightsDescription(), weights_type,
+ absl::MakeSpan(weights_data));
- const int elements_count =
- weights.shape.h * weights.shape.w * src_slices * dst_slices * 4;
-
- std::vector<uint8_t> data(float4_size * elements_count);
-
- if (f32_weights) {
- float4* ptr = reinterpret_cast<float4*>(data.data());
- if (conv_params_.AreWeightsBuffer()) {
- RearrangeWeightsToOHWIOGroupI4O4(weights, conv_params_.block_size.w,
- absl::MakeSpan(ptr, elements_count));
- } else {
- RearrangeWeightsToI4HWIOOGroupO4(weights, conv_params_.block_size.w,
- absl::MakeSpan(ptr, elements_count));
- }
- } else {
- half4* ptr = reinterpret_cast<half4*>(data.data());
- if (conv_params_.AreWeightsBuffer()) {
- RearrangeWeightsToOHWIOGroupI4O4(weights, conv_params_.block_size.w,
- absl::MakeSpan(ptr, elements_count));
- } else {
- RearrangeWeightsToI4HWIOOGroupO4(weights, conv_params_.block_size.w,
- absl::MakeSpan(ptr, elements_count));
- }
- }
if (conv_params_.AreWeightsBuffer()) {
BufferDescriptor desc;
- desc.element_type = conv_params_.weights_data_type;
+ desc.element_type = weights_type;
desc.element_size = 4;
desc.memory_type = conv_params_.weights_upload_type ==
ConvPowerVR::WeightsUploadType::CONSTANT_MEM
? MemoryType::CONSTANT
: MemoryType::GLOBAL;
- desc.size = float4_size * elements_count;
- desc.data = std::move(data);
+ desc.size = weights_data.size();
+ desc.data = std::move(weights_data);
args_.AddObject("weights",
absl::make_unique<BufferDescriptor>(std::move(desc)));
} else {
- const int texture_width = dst_slices;
- const int texture_height = src_slices * weights.shape.h * weights.shape.w;
- const int sub_size = float4_size * texture_width * texture_height;
+ const int dst_depth =
+ AlignByN(DivideRoundUp(weights.shape.o, 4), conv_params_.block_size.w);
+ const int src_depth = DivideRoundUp(weights.shape.i, 4);
+ const int kernel_x = weights.shape.w;
+ const int kernel_y = weights.shape.h;
+ int texture_width = dst_depth;
+ int texture_height = src_depth * kernel_x * kernel_y;
+ int sub_size = SizeOf(weights_type) * 4 * texture_width * texture_height;
for (int i = 0; i < 4; ++i) {
Texture2DDescriptor desc;
- desc.element_type = conv_params_.weights_data_type;
+ desc.element_type = weights_type;
desc.size = int2(texture_width, texture_height);
desc.data.resize(sub_size);
- std::memcpy(desc.data.data(), data.data() + sub_size * i, sub_size);
+ memcpy(desc.data.data(), weights_data.data() + sub_size * i, sub_size);
const std::string name = "weights" + std::to_string(i);
args_.AddObject(name,
absl::make_unique<Texture2DDescriptor>(std::move(desc)));
diff --git a/tensorflow/lite/delegates/gpu/common/tasks/conv_powervr_test_util.cc b/tensorflow/lite/delegates/gpu/common/tasks/conv_powervr_test_util.cc
new file mode 100644
index 0000000..f0c0084
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/common/tasks/conv_powervr_test_util.cc
@@ -0,0 +1,177 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/delegates/gpu/common/tasks/conv_powervr_test_util.h"
+
+#include <vector>
+
+#include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/testing_util.h"
+#include "tensorflow/lite/delegates/gpu/common/tasks/conv_powervr.h"
+
+namespace tflite {
+namespace gpu {
+
+absl::Status ConvPowerVR1x1SimpleWeightsTest(TestExecutionEnvironment* env) {
+ TensorFloat32 src_tensor;
+ src_tensor.shape = BHWC(1, 2, 2, 2);
+ src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f};
+
+ Convolution2DAttributes attr;
+ attr.padding.prepended = HW(0, 0);
+ attr.padding.appended = HW(0, 0);
+ attr.strides = HW(1, 1);
+ attr.dilations = HW(1, 1);
+ attr.weights.shape = OHWI(2, 1, 1, 2);
+ attr.weights.data = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
+ attr.bias.shape = Linear(1);
+ attr.bias.data = {0.0f};
+
+ for (auto storage : env->GetSupportedStorages()) {
+ for (auto precision : env->GetSupportedPrecisions()) {
+ const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
+ OperationDef op_def;
+ op_def.precision = precision;
+ auto data_type = DeduceDataTypeFromPrecision(precision);
+ op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
+ op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
+ TensorFloat32 dst_tensor;
+ ConvPowerVR operation =
+ CreateConvPowerVR(env->GetGpuInfo(), op_def, attr);
+ RETURN_IF_ERROR(env->ExecuteGPUOperation(
+ src_tensor, absl::make_unique<ConvPowerVR>(std::move(operation)),
+ BHWC(1, 2, 2, 2), &dst_tensor));
+ RETURN_IF_ERROR(
+ PointWiseNear({1.0f, 1.0f, 5.0f, 5.0f, 9.0f, 9.0f, 13.0f, 13.0f},
+ dst_tensor.data, eps));
+ }
+ }
+ return absl::OkStatus();
+}
+
+absl::Status ConvPowerVR1x1Test(TestExecutionEnvironment* env) {
+ TensorFloat32 src_tensor;
+ src_tensor.shape = BHWC(1, 2, 2, 2);
+ src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f};
+
+ Convolution2DAttributes attr;
+ attr.padding.prepended = HW(0, 0);
+ attr.padding.appended = HW(0, 0);
+ attr.strides = HW(1, 1);
+ attr.dilations = HW(1, 1);
+ attr.weights.shape = OHWI(2, 1, 1, 2);
+ attr.weights.data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
+ attr.bias.shape = Linear(2);
+ attr.bias.data = {0.5f, -0.5f};
+
+ for (auto storage : env->GetSupportedStorages()) {
+ for (auto precision : env->GetSupportedPrecisions()) {
+ const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
+ OperationDef op_def;
+ op_def.precision = precision;
+ auto data_type = DeduceDataTypeFromPrecision(precision);
+ op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
+ op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
+ TensorFloat32 dst_tensor;
+ ConvPowerVR operation =
+ CreateConvPowerVR(env->GetGpuInfo(), op_def, attr);
+ RETURN_IF_ERROR(env->ExecuteGPUOperation(
+ src_tensor, absl::make_unique<ConvPowerVR>(std::move(operation)),
+ BHWC(1, 2, 2, 2), &dst_tensor));
+ RETURN_IF_ERROR(
+ PointWiseNear({2.5f, 3.5f, 8.5f, 17.5f, 14.5f, 31.5f, 20.5f, 45.5f},
+ dst_tensor.data, eps));
+ }
+ }
+ return absl::OkStatus();
+}
+
+absl::Status ConvPowerVRSimpleWeightsTest(TestExecutionEnvironment* env) {
+ TensorFloat32 src_tensor;
+ src_tensor.shape = BHWC(1, 2, 2, 2);
+ src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f};
+
+ Convolution2DAttributes attr;
+ attr.padding.prepended = HW(0, 0);
+ attr.padding.appended = HW(1, 1);
+ attr.strides = HW(1, 1);
+ attr.dilations = HW(1, 1);
+ attr.weights.shape = OHWI(1, 2, 2, 2);
+ attr.weights.data = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
+ attr.bias.shape = Linear(1);
+ attr.bias.data = {0.0f};
+
+ for (auto storage : env->GetSupportedStorages()) {
+ for (auto precision : env->GetSupportedPrecisions()) {
+ const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
+ OperationDef op_def;
+ op_def.precision = precision;
+ auto data_type = DeduceDataTypeFromPrecision(precision);
+ op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
+ op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
+ TensorFloat32 dst_tensor;
+ ConvPowerVR operation =
+ CreateConvPowerVR(env->GetGpuInfo(), op_def, attr);
+ RETURN_IF_ERROR(env->ExecuteGPUOperation(
+ src_tensor, absl::make_unique<ConvPowerVR>(std::move(operation)),
+ BHWC(1, 2, 2, 1), &dst_tensor));
+ RETURN_IF_ERROR(
+ PointWiseNear({28.0f, 18.0f, 22.0f, 13.0f}, dst_tensor.data, eps));
+ }
+ }
+ return absl::OkStatus();
+}
+
+absl::Status ConvPowerVRTest(TestExecutionEnvironment* env) {
+ TensorFloat32 src_tensor;
+ src_tensor.shape = BHWC(1, 2, 2, 2);
+ src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f};
+
+ Convolution2DAttributes attr;
+ attr.padding.prepended = HW(0, 0);
+ attr.padding.appended = HW(1, 1);
+ attr.strides = HW(1, 1);
+ attr.dilations = HW(1, 1);
+ attr.weights.shape = OHWI(2, 2, 2, 2);
+ attr.weights.data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
+ 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f};
+ attr.bias.shape = Linear(2);
+ attr.bias.data = {0.5f, -0.5f};
+
+ for (auto storage : env->GetSupportedStorages()) {
+ for (auto precision : env->GetSupportedPrecisions()) {
+ const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
+ OperationDef op_def;
+ op_def.precision = precision;
+ auto data_type = DeduceDataTypeFromPrecision(precision);
+ op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
+ op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
+ TensorFloat32 dst_tensor;
+ ConvPowerVR operation =
+ CreateConvPowerVR(env->GetGpuInfo(), op_def, attr);
+ RETURN_IF_ERROR(env->ExecuteGPUOperation(
+ src_tensor, absl::make_unique<ConvPowerVR>(std::move(operation)),
+ BHWC(1, 2, 2, 2), &dst_tensor));
+ RETURN_IF_ERROR(PointWiseNear(
+ {168.5f, 391.5f, 80.5f, 223.5f, 60.5f, 235.5f, 20.5f, 123.5f},
+ dst_tensor.data, eps));
+ }
+ }
+ return absl::OkStatus();
+}
+
+} // namespace gpu
+} // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/common/tasks/conv_powervr_test_util.h b/tensorflow/lite/delegates/gpu/common/tasks/conv_powervr_test_util.h
new file mode 100644
index 0000000..c47845d
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/common/tasks/conv_powervr_test_util.h
@@ -0,0 +1,33 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_CONV_POWERVR_TEST_UTIL_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_CONV_POWERVR_TEST_UTIL_H_
+
+#include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/task/testing_util.h"
+
+namespace tflite {
+namespace gpu {
+
+absl::Status ConvPowerVR1x1SimpleWeightsTest(TestExecutionEnvironment* env);
+absl::Status ConvPowerVR1x1Test(TestExecutionEnvironment* env);
+absl::Status ConvPowerVRSimpleWeightsTest(TestExecutionEnvironment* env);
+absl::Status ConvPowerVRTest(TestExecutionEnvironment* env);
+
+} // namespace gpu
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_CONV_POWERVR_TEST_UTIL_H_
diff --git a/tensorflow/lite/delegates/gpu/common/tasks/special/depthwise_conv_plus_1x1_conv.cc b/tensorflow/lite/delegates/gpu/common/tasks/special/depthwise_conv_plus_1x1_conv.cc
index 32424a9..ac79d73 100644
--- a/tensorflow/lite/delegates/gpu/common/tasks/special/depthwise_conv_plus_1x1_conv.cc
+++ b/tensorflow/lite/delegates/gpu/common/tasks/special/depthwise_conv_plus_1x1_conv.cc
@@ -126,18 +126,17 @@
result->args_.AddInt("dilation_y", dw_attr.dilations.h);
std::string c;
- c += "__kernel void main_function(\n";
- c += "$0) {\n";
+ c += "MAIN_FUNCTION($0) {\n";
if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
- c += " int linear_id = get_global_id(0);\n";
+ c += " int linear_id = GLOBAL_ID_0;\n";
c += " int X = linear_id / args.dst_tensor.Batch();\n";
c += " int B = linear_id % args.dst_tensor.Batch();\n";
c += " args.dst_tensor.SetBatchRef(B);\n";
c += " args.src_tensor.SetBatchRef(B);\n";
} else {
- c += " int X = get_global_id(0);\n";
+ c += " int X = GLOBAL_ID_0;\n";
}
- c += " int Y = get_global_id(1);\n";
+ c += " int Y = GLOBAL_ID_1;\n";
c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height()) { "
"\n";
c += " return; \n";
@@ -194,7 +193,8 @@
for (int d = 0; d < intermediate_depth; ++d) {
const int src_ch_count = std::min(4, dw_attr.weights.shape.i - d * 4);
const std::string s_postfix = postfixes[src_ch_count - 1];
- std::string multiplier = check.empty() ? "" : " * (FLT)(" + check + ")";
+ std::string multiplier =
+ check.empty() ? "" : " * INIT_FLT(" + check + ")";
c += " src" + s_postfix + " = args.src_tensor.Read(x_c, y_c, " +
std::to_string(d) + ")" + s_postfix + multiplier + ";\n";
c += " dw_res_" + std::to_string(d) + s_postfix + " += src" +
diff --git a/tensorflow/lite/delegates/gpu/common/tasks/special/fc_fc_add.cc b/tensorflow/lite/delegates/gpu/common/tasks/special/fc_fc_add.cc
index a632dff..4b23232 100644
--- a/tensorflow/lite/delegates/gpu/common/tasks/special/fc_fc_add.cc
+++ b/tensorflow/lite/delegates/gpu/common/tasks/special/fc_fc_add.cc
@@ -92,20 +92,22 @@
c += "#define WG_X " + std::to_string(work_group_size_.x) + "\n";
c += "#define WG_Y " + std::to_string(work_group_size_.y) + "\n";
- c += R"(__kernel void main_function($0) {
+ c += R"(MAIN_FUNCTION($0) {
int gid = get_global_id(0);
- int2 tid = (int2)(get_local_id(0), get_local_id(1));
- ACCUM_FLT4 s = (ACCUM_FLT4)(0.0f);
+ int2 tid;
+ tid.x = LOCAL_ID_0;
+ tid.y = LOCAL_ID_1;
+ ACCUM_FLT4 s = INIT_ACCUM_FLT4(0.0f);
if (gid < args.dst_tensor.Slices()) {
for (int c = tid.y; c < args.src_tensor_0.Slices(); c += WG_Y) {
FLT4 v = args.src_tensor_0.Read(0, 0, c);
)";
if (weights_are_buffer) {
c += R"(FLT16 w = args.weights0.Read(c * args.dst_tensor.Slices() + gid);
- FLT4 partial = v.s0 * w.s0123;
- partial = mad(v.s1, w.s4567, partial);
- partial = mad(v.s2, w.s89ab, partial);
- partial = mad(v.s3, w.scdef, partial);
+ FLT4 partial = v.x * FLT16_0123(w);
+ partial += v.y * FLT16_4567(w);
+ partial += v.z * FLT16_89ab(w);
+ partial += v.w * FLT16_cdef(w);
s += TO_ACCUM_TYPE(partial);
)";
} else {
@@ -113,10 +115,10 @@
FLT4 w1 = args.weights0.Read(c * 4 + 1, gid);
FLT4 w2 = args.weights0.Read(c * 4 + 2, gid);
FLT4 w3 = args.weights0.Read(c * 4 + 3, gid);
- FLT4 partial = v.s0 * w0;
- partial = mad(v.s1, w1, partial);
- partial = mad(v.s2, w2, partial);
- partial = mad(v.s3, w3, partial);
+ FLT4 partial = v.x * w0;
+ partial += v.y * w1;
+ partial += v.z * w2;
+ partial += v.w * w3;
s += TO_ACCUM_TYPE(partial);
)";
}
@@ -126,10 +128,10 @@
)";
if (weights_are_buffer) {
c += R"(FLT16 w = args.weights1.Read(c * args.dst_tensor.Slices() + gid);
- FLT4 partial = v.s0 * w.s0123;
- partial = mad(v.s1, w.s4567, partial);
- partial = mad(v.s2, w.s89ab, partial);
- partial = mad(v.s3, w.scdef, partial);
+ FLT4 partial = v.x * FLT16_0123(w);
+ partial += v.y * FLT16_4567(w);
+ partial += v.z * FLT16_89ab(w);
+ partial += v.w * FLT16_cdef(w);
s += TO_ACCUM_TYPE(partial);
)";
} else {
@@ -137,10 +139,10 @@
FLT4 w1 = args.weights1.Read(c * 4 + 1, gid);
FLT4 w2 = args.weights1.Read(c * 4 + 2, gid);
FLT4 w3 = args.weights1.Read(c * 4 + 3, gid);
- FLT4 partial = v.s0 * w0;
- partial = mad(v.s1, w1, partial);
- partial = mad(v.s2, w2, partial);
- partial = mad(v.s3, w3, partial);
+ FLT4 partial = v.x * w0;
+ partial += v.y * w1;
+ partial += v.z * w2;
+ partial += v.w * w3;
s += TO_ACCUM_TYPE(partial);
)";
}
@@ -148,7 +150,7 @@
}
__local ACCUM_FLT4 temp[WG_X][WG_Y];
temp[tid.x][tid.y] = s;
- barrier(CLK_LOCAL_MEM_FENCE);
+ LOCAL_MEM_BARRIER;
if (gid >= args.dst_tensor.Slices()) {
return;
}
diff --git a/tensorflow/lite/delegates/gpu/common/tasks/split.cc b/tensorflow/lite/delegates/gpu/common/tasks/split.cc
new file mode 100644
index 0000000..398f220
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/common/tasks/split.cc
@@ -0,0 +1,156 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/delegates/gpu/common/tasks/split.h"
+
+#include <string>
+
+#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
+
+namespace tflite {
+namespace gpu {
+
+Split::Split(const OperationDef& definition, const SplitAttributes& attr)
+ : GPUOperation(definition), attr_(attr) {
+ work_group_size_ = int3(8, 4, 1);
+ code_ = attr.axis == Axis::CHANNELS ? GetSplitChannelsCode() : GetSplitCode();
+}
+
+std::string Split::GetSplitCode() {
+ AddSrcTensor("src_tensor", definition_.src_tensors[0]);
+ for (int i = 0; i < definition_.dst_tensors.size(); ++i) {
+ AddDstTensor("dst_tensor_" + std::to_string(i), definition_.dst_tensors[i]);
+ }
+ const std::string task_width =
+ attr_.axis == Axis::WIDTH ? "1" : "args.src_tensor.Width()";
+ const std::string task_height =
+ attr_.axis == Axis::HEIGHT ? "1" : "args.src_tensor.Height()";
+ const std::string task_depth =
+ attr_.axis == Axis::DEPTH ? "1" : "args.src_tensor.Depth()";
+ const std::string task_batch =
+ attr_.axis == Axis::BATCH ? "1" : "args.src_tensor.Batch()";
+ const std::string task_slices =
+ attr_.axis == Axis::CHANNELS ? "1" : "args.src_tensor.Slices()";
+
+ std::string c;
+ c += "MAIN_FUNCTION($0) {\n";
+ c += " int task_width = "
+ ";\n";
+ if (definition_.src_tensors[0].HasAxis(Axis::BATCH)) {
+ c += " int linear_id = GLOBAL_ID_0;\n";
+ c += " int X = linear_id / " + task_batch + ";\n";
+ c += " int B = linear_id % " + task_batch + ";\n";
+ } else {
+ c += " int X = GLOBAL_ID_0;\n";
+ }
+ if (definition_.src_tensors[0].HasAxis(Axis::DEPTH)) {
+ c += " int linear_id = GLOBAL_ID_1;\n";
+ c += " int Y = linear_id % " + task_height + ";\n";
+ c += " int B = linear_id / " + task_height + ";\n";
+ } else {
+ c += " int Y = GLOBAL_ID_1;\n";
+ }
+ c += " int S = GLOBAL_ID_2;\n";
+ c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
+ "S >= args.dst_tensor.Slices()) { \n";
+ c += " return; \n";
+ c += " } \n";
+ c += " int src_counter = 0;\n";
+ for (int i = 0; i < definition_.dst_tensors.size(); ++i) {
+ const std::string dst_name = "args.dst_tensor_" + std::to_string(i);
+ c += " for (int i = 0; i < " + dst_name +
+ ".Slices(); ++i, src_counter++) {\n";
+ c += " FLT4 result = args.src_tensor.Read(s_x, s_y, src_counter);\n";
+ c += " " + dst_name + ".Write(result, X, Y, i);\n";
+ c += " }\n";
+ }
+ c += "}\n";
+ return c;
+}
+
+std::string Split::GetSplitChannelsCode() {
+ AddSrcTensor("src_tensor", definition_.src_tensors[0]);
+ for (int i = 0; i < definition_.dst_tensors.size(); ++i) {
+ AddDstTensor("dst_tensor_" + std::to_string(i), definition_.dst_tensors[i]);
+ }
+
+ const std::string batch_coord =
+ definition_.src_tensors[0].HasAxis(Axis::BATCH) ? ", B" : "";
+ std::string coords = "X, Y";
+ std::string c;
+ c += "MAIN_FUNCTION($0) {\n";
+ if (definition_.src_tensors[0].HasAxis(Axis::BATCH)) {
+ c += " int linear_id = GLOBAL_ID_0;\n";
+ c += " int X = linear_id / args.src_tensor.Batch();\n";
+ c += " int B = linear_id % args.src_tensor.Batch();\n";
+ c += " if (X >= args.src_tensor.Width()) return;\n";
+ } else {
+ c += " int X = GLOBAL_ID_0;\n";
+ c += " if (X >= args.src_tensor.Width()) return;\n";
+ }
+ if (definition_.src_tensors[0].HasAxis(Axis::DEPTH)) {
+ c += " int linear_id = GLOBAL_ID_1;\n";
+ c += " int Y = linear_id % args.src_tensor.Height();\n";
+ c += " int Z = linear_id / args.src_tensor.Height();\n";
+ c += " if (Z >= args.src_tensor.Depth()) return;\n";
+ coords += ", Z";
+ } else {
+ c += " int Y = GLOBAL_ID_1;\n";
+ c += " if (Y >= args.src_tensor.Height()) return;\n";
+ }
+ c += " int src_channel = 0;\n";
+ const std::string postfixes[] = {"x", "y", "z", "w"};
+ for (int i = 0; i < definition_.dst_tensors.size(); ++i) {
+ const std::string dst_name = "args.dst_tensor_" + std::to_string(i);
+ c += " for (int i = 0; i < " + dst_name + ".Slices(); ++i) {\n";
+ c += " FLT4 result = INIT_FLT4(0.0f);\n";
+ for (int j = 0; j < 4; ++j) {
+ c += " if (i * 4 + " + std::to_string(j) + " < " + dst_name +
+ ".Channels()) {\n";
+ c += " int src_slice = src_channel >> 2;\n";
+ c += " int src_sub_ch = src_channel & 3;\n";
+ c += " FLT4 t = args.src_tensor.Read(" + coords + ", src_slice" +
+ batch_coord + ");\n";
+ c += " FLT t_ar[4] = {t.x, t.y, t.z, t.w};\n";
+ c += " result." + postfixes[j] + " = t_ar[src_sub_ch];\n";
+ c += " src_channel++;\n";
+ c += " }\n";
+ }
+ c += " " + dst_name + ".Write(result, " + coords + ", i" + batch_coord +
+ ");\n";
+ c += " }\n";
+ }
+ c += "}\n";
+ return c;
+}
+
+int3 Split::GetGridSize() const {
+ const int width = attr_.axis == Axis::WIDTH ? 1 : src_[0]->Width();
+ const int height = attr_.axis == Axis::HEIGHT ? 1 : src_[0]->Height();
+ const int depth = attr_.axis == Axis::DEPTH ? 1 : src_[0]->Depth();
+ const int batch = attr_.axis == Axis::BATCH ? 1 : src_[0]->Batch();
+ const int slices = attr_.axis == Axis::CHANNELS ? 1 : src_[0]->Slices();
+ const int grid_x = width * batch;
+ const int grid_y = height * depth;
+ const int grid_z = slices;
+ return int3(grid_x, grid_y, grid_z);
+}
+
+Split CreateSplit(const OperationDef& definition, const SplitAttributes& attr) {
+ return Split(definition, attr);
+}
+
+} // namespace gpu
+} // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/common/tasks/split.h b/tensorflow/lite/delegates/gpu/common/tasks/split.h
new file mode 100644
index 0000000..c1249b2
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/common/tasks/split.h
@@ -0,0 +1,49 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_STRIDED_SPLIT_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_STRIDED_SPLIT_H_
+
+#include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
+#include "tensorflow/lite/delegates/gpu/common/types.h"
+
+namespace tflite {
+namespace gpu {
+
+class Split : public GPUOperation {
+ public:
+ Split(const OperationDef& definition, const SplitAttributes& attr);
+ int3 GetGridSize() const override;
+
+ // Move only
+ Split(Split&& operation) = default;
+ Split& operator=(Split&& operation) = default;
+ Split(const Split&) = delete;
+ Split& operator=(const Split&) = delete;
+
+ private:
+ std::string GetSplitCode();
+ std::string GetSplitChannelsCode();
+
+ SplitAttributes attr_;
+};
+
+Split CreateSplit(const OperationDef& definition, const SplitAttributes& attr);
+
+} // namespace gpu
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_STRIDED_SPLIT_H_
diff --git a/tensorflow/lite/delegates/gpu/metal/BUILD b/tensorflow/lite/delegates/gpu/metal/BUILD
index 65463b9..63512cc 100644
--- a/tensorflow/lite/delegates/gpu/metal/BUILD
+++ b/tensorflow/lite/delegates/gpu/metal/BUILD
@@ -159,6 +159,8 @@
"//tensorflow/lite/delegates/gpu/common:shape",
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/common:util",
+ "//tensorflow/lite/delegates/gpu/common/selectors:operation_selector",
+ "//tensorflow/lite/delegates/gpu/common/selectors:special_selector",
"//tensorflow/lite/delegates/gpu/common/selectors:subgraph",
"//tensorflow/lite/delegates/gpu/common/task:profiling_info",
"//tensorflow/lite/delegates/gpu/common/task:storage_type_util",
@@ -166,7 +168,6 @@
"//tensorflow/lite/delegates/gpu/common/transformations:add_bias",
"//tensorflow/lite/delegates/gpu/common/transformations:global_pooling_to_reduce_op",
"//tensorflow/lite/delegates/gpu/common/transformations:merge_padding_with",
- "//tensorflow/lite/delegates/gpu/metal/selectors:operation_selector",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/time",
],
diff --git a/tensorflow/lite/delegates/gpu/metal/inference_context.cc b/tensorflow/lite/delegates/gpu/metal/inference_context.cc
index 83e2c3a..c605054 100644
--- a/tensorflow/lite/delegates/gpu/metal/inference_context.cc
+++ b/tensorflow/lite/delegates/gpu/metal/inference_context.cc
@@ -26,6 +26,8 @@
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/precision.h"
+#include "tensorflow/lite/delegates/gpu/common/selectors/operation_selector.h"
+#include "tensorflow/lite/delegates/gpu/common/selectors/special_selector.h"
#include "tensorflow/lite/delegates/gpu/common/selectors/subgraph.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
@@ -36,13 +38,18 @@
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task.h"
#include "tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h"
-#include "tensorflow/lite/delegates/gpu/metal/selectors/operation_selector.h"
namespace tflite {
namespace gpu {
namespace metal {
namespace {
+// returns true if actual memory for this storage type is buffer
+bool IsBufferBased(const TensorStorageType& type) {
+ return type == TensorStorageType::BUFFER ||
+ type == TensorStorageType::IMAGE_BUFFER;
+}
+
bool HasIntersection(const std::vector<ValueId>& vec_ids,
const std::set<ValueId>& ids) {
for (ValueId id : vec_ids) {
@@ -127,24 +134,27 @@
absl::Status InferenceContext::InitFromGraph(
const CreateInferenceInfo& create_info, const GraphFloat32& graph,
id<MTLDevice> device_id) {
+ std::set<ValueId> preallocated_ids;
const auto inputs = graph.inputs();
for (const auto& input : inputs) {
input_ids_.push_back(input->id);
+ preallocated_ids.insert(input->id);
}
const auto outputs = graph.outputs();
for (const auto& output : outputs) {
output_ids_.push_back(output->id);
+ preallocated_ids.insert(output->id);
}
precision_ = create_info.precision;
MetalDevice metal_device(device_id);
- ReserveGraphTensors(create_info, metal_device.GetInfo(), graph);
- RETURN_IF_ERROR(
- Compile(graph, metal_device.GetInfo(), create_info.precision));
+ ReserveGraphTensors(create_info, metal_device.GetInfo(), graph,
+ preallocated_ids);
+ RETURN_IF_ERROR(Compile(graph, metal_device.GetInfo(), create_info.hints));
RETURN_IF_ERROR(Merge());
RETURN_IF_ERROR(CompileOperations(&metal_device));
- RETURN_IF_ERROR(AllocateTensors(&metal_device));
+ RETURN_IF_ERROR(AllocateTensors(&metal_device, preallocated_ids));
BindTensorsToOperations();
RETURN_IF_ERROR(UpdateParams(metal_device.GetInfo()));
RETURN_IF_ERROR(Tune(TuningType::kFast, &metal_device));
@@ -153,12 +163,15 @@
void InferenceContext::ReserveGraphTensors(
const CreateInferenceInfo& create_info, const GpuInfo& gpu_info,
- const GraphFloat32& graph) {
+ const GraphFloat32& graph, const std::set<ValueId>& preallocated_ids) {
ValueId max_id = 0;
auto tensors = graph.values();
auto data_type = DeduceDataTypeFromPrecision(create_info.precision);
for (auto& t : tensors) {
TensorStorageType storage_type = create_info.storage_type;
+ if (preallocated_ids.find(t->id) != preallocated_ids.end()) {
+ storage_type = TensorStorageType::BUFFER;
+ }
const auto shape = graph.GetValue(t->id)->tensor.shape;
Layout layout = shape.b == 1 ? Layout::HWC : Layout::BHWC;
// Temporary disabled because no support of SINGLE_TEXTURE_2D in Metal
@@ -184,11 +197,17 @@
absl::Status InferenceContext::Compile(const GraphFloat32& graph,
const GpuInfo& gpu_info,
- CalculationsPrecision precision) {
+ ModelHints hints) {
if (!IsBatchMatchesForAllValues(graph)) {
return absl::InvalidArgumentError(
"Only identical batch dimension is supported");
}
+ std::map<ValueId, TensorDescriptor> tensor_descriptors;
+ const auto values = graph.values();
+ for (auto value : values) {
+ tensor_descriptors[value->id] = tensor_reserver_.Get(value->id).descriptor;
+ }
+ std::set<NodeId> consumed_nodes;
std::map<ValueId, int>
tensor_usages; // keeps latest index of operation that updated tensor
for (const auto& input_id : input_ids_) {
@@ -198,40 +217,62 @@
std::vector<Node*> graph_nodes = graph.nodes();
for (int i = 0; i < graph_nodes.size(); ++i) {
const Node& node = *graph_nodes[i];
- auto inputs = graph.FindInputs(node.id);
- auto outputs = graph.FindOutputs(node.id);
- // Reordering of input ids and updating of temporary tensors_usage struct.
- // This stage is necessary because we are building OperationDef that rely
- // on order of input ids. But we also should have input id on first
- // position that potentially can be "linking" tensor and as result
- // eliminated(unused) We apply it only for ADD operation, because of ADD
- // associativity and ADD can be linked. In current approach "linking"
- // tensor can be only latest written tensor(during linear order of
- // execution) among input tensors.
- if (IsGenericAdd(node, inputs, outputs)) {
- int latest_written_tensor_index = 0;
- int last_usage = tensor_usages[inputs[0]->id];
- for (int j = 1; j < inputs.size(); ++j) {
- if (tensor_usages[inputs[j]->id] > last_usage) {
- last_usage = tensor_usages[inputs[j]->id];
- latest_written_tensor_index = j;
- }
- }
- std::swap(inputs[0], inputs[latest_written_tensor_index]);
+ auto op_type = OperationTypeFromString(node.operation.type);
+ if (op_type == OperationType::CONSTANT) {
+ auto attr =
+ absl::any_cast<ConstTensorAttributes>(node.operation.attributes);
+ auto outputs = graph.FindOutputs(node.id);
+ const_tensors_descs_[outputs[0]->id] =
+ tensor_reserver_.Get(outputs[0]->id).descriptor;
+ const_tensors_descs_[outputs[0]->id].UploadData(attr.tensor);
+ continue;
}
- OperationDef op_def;
- op_def.precision = precision;
- for (int j = 0; j < inputs.size(); ++j) {
- op_def.src_tensors.push_back(
- tensor_reserver_.Get(inputs[j]->id).descriptor);
- }
- for (int j = 0; j < outputs.size(); ++j) {
- op_def.dst_tensors.push_back(
- tensor_reserver_.Get(outputs[j]->id).descriptor);
- }
+ std::string op_name = node.operation.type + " " + std::to_string(node.id);
GPUOperationsSubgraph gpu_subgraph;
- RETURN_IF_ERROR(GPUOperationFromNode(gpu_info, op_def, inputs, outputs,
- node, &gpu_subgraph));
+ if (hints.Check(ModelHints::kAllowSpecialKernels) &&
+ GPUSubgraphFromGraph(gpu_info, precision_, graph, node.id,
+ tensor_descriptors, &consumed_nodes, &gpu_subgraph,
+ &op_name)
+ .ok()) {
+ // Mapping of subgraph (set of nodes) to GPU operations. Should happen
+ // before straigtforward mapping.
+ } else {
+ // Straigtforward mapping of one graph node to GPU operations.
+ auto inputs = graph.FindInputs(node.id);
+ auto outputs = graph.FindOutputs(node.id);
+ // Reordering of input ids and updating of temporary tensors_usage struct.
+ // This stage is necessary because we are building OperationDef that rely
+ // on order of input ids. But we also should have input id on first
+ // position that potentially can be "linking" tensor and as result
+ // eliminated(unused) We apply it only for ADD operation, because of ADD
+ // associativity and ADD can be linked. In current approach "linking"
+ // tensor can be only latest written tensor(during linear order of
+ // execution) among input tensors.
+ if (IsGenericAdd(node, inputs, outputs)) {
+ int latest_written_tensor_index = 0;
+ int last_usage = tensor_usages[inputs[0]->id];
+ for (int j = 1; j < inputs.size(); ++j) {
+ if (tensor_usages[inputs[j]->id] > last_usage) {
+ last_usage = tensor_usages[inputs[j]->id];
+ latest_written_tensor_index = j;
+ }
+ }
+ std::swap(inputs[0], inputs[latest_written_tensor_index]);
+ }
+ consumed_nodes.insert(node.id);
+ OperationDef op_def;
+ op_def.precision = precision_;
+ for (int j = 0; j < inputs.size(); ++j) {
+ op_def.src_tensors.push_back(
+ tensor_reserver_.Get(inputs[j]->id).descriptor);
+ }
+ for (int j = 0; j < outputs.size(); ++j) {
+ op_def.dst_tensors.push_back(
+ tensor_reserver_.Get(outputs[j]->id).descriptor);
+ }
+ RETURN_IF_ERROR(GPUOperationFromNode(gpu_info, op_def, hints, inputs,
+ outputs, node, &gpu_subgraph));
+ }
std::map<int, ValueId> mapping_to_global_ids;
for (int j = 0; j < gpu_subgraph.new_tensors.size(); ++j) {
const auto& t = gpu_subgraph.new_tensors[j];
@@ -260,7 +301,7 @@
metal_node.outputs[j] = mapping_to_global_ids[-(id + 1)];
}
}
- metal_node.name = node.operation.type + " " + std::to_string(node.id);
+ metal_node.name = op_name;
nodes_.push_back(std::move(metal_node));
}
}
@@ -318,14 +359,8 @@
return absl::OkStatus();
}
-absl::Status InferenceContext::AllocateTensors(MetalDevice* device) {
- std::set<ValueId> preallocated_ids;
- for (auto tensor_id : input_ids_) {
- preallocated_ids.insert(tensor_id);
- }
- for (const auto& outputId : output_ids_) {
- preallocated_ids.insert(outputId);
- }
+absl::Status InferenceContext::AllocateTensors(
+ MetalDevice* device, const std::set<ValueId>& preallocated_ids) {
for (int i = 0; i < nodes_.size(); ++i) {
auto& node = nodes_[i];
if (HasIntersection(node.inputs, preallocated_ids) ||
@@ -334,24 +369,31 @@
}
}
- const bool f32_storage = precision_ == CalculationsPrecision::F32;
for (auto& tensor_id : preallocated_ids) {
const auto& t = tensor_reserver_.Get(tensor_id);
RETURN_IF_ERROR(CreateSharedBufferTensor(
nil, t.shape, t.descriptor, &preallocated_tensors_[tensor_id]));
}
+ RETURN_IF_ERROR(AllocateMemoryForConstTensors(device));
RETURN_IF_ERROR(AllocateMemoryForBuffers(device));
+ RETURN_IF_ERROR(AllocateMemoryForStrongShapes(device));
return absl::OkStatus();
}
MetalSpatialTensor* InferenceContext::GetTensor(ValueId tensor_id) {
if (preallocated_tensors_.find(tensor_id) != preallocated_tensors_.end()) {
return &preallocated_tensors_[tensor_id];
+ } else if (const_tensors_.find(tensor_id) != const_tensors_.end()) {
+ return &const_tensors_[tensor_id];
} else if (graph_ids_to_shared_buffer_tensors_.find(tensor_id) !=
graph_ids_to_shared_buffer_tensors_.end()) {
return &shared_buffer_tensors_
[graph_ids_to_shared_buffer_tensors_[tensor_id]];
+ } else if (graph_ids_to_strong_shape_tensors_.find(tensor_id) !=
+ graph_ids_to_strong_shape_tensors_.end()) {
+ return &strong_shape_tensors_
+ [graph_ids_to_strong_shape_tensors_[tensor_id]];
}
return nullptr;
}
@@ -384,36 +426,62 @@
return absl::OkStatus();
}
-void InferenceContext::GetUsages(std::map<ValueId, int2>* usages) {
+InferenceContext::TensorMemoryType InferenceContext::GetTensorMemoryType(
+ ValueId id) {
+ if (preallocated_tensors_.find(id) != preallocated_tensors_.end()) {
+ return TensorMemoryType::kPreallocated;
+ } else if (const_tensors_.find(id) != const_tensors_.end()) {
+ return TensorMemoryType::kConst;
+ } else if (IsBufferBased(tensor_reserver_.Get(id).descriptor.storage_type)) {
+ return TensorMemoryType::kBuffer;
+ } else {
+ return TensorMemoryType::kStrongShape;
+ }
+}
+
+void InferenceContext::GetUsages(const std::function<bool(ValueId)>& functor,
+ std::map<ValueId, int2>* usages) {
for (ValueId in_id : input_ids_) {
- if (preallocated_tensors_.find(in_id) == preallocated_tensors_.end()) {
+ if (functor(in_id)) {
AddUsage(in_id, 0, usages);
}
}
for (int op_index = 0; op_index < nodes_.size(); ++op_index) {
for (auto& tensor_id : nodes_[op_index].inputs) {
- if (preallocated_tensors_.find(tensor_id) ==
- preallocated_tensors_.end()) {
+ if (functor(tensor_id)) {
AddUsage(tensor_id, op_index, usages);
}
}
for (auto& tensor_id : nodes_[op_index].outputs) {
- if (preallocated_tensors_.find(tensor_id) ==
- preallocated_tensors_.end()) {
+ if (functor(tensor_id)) {
AddUsage(tensor_id, op_index, usages);
}
}
}
for (ValueId out_id : output_ids_) {
- if (preallocated_tensors_.find(out_id) == preallocated_tensors_.end()) {
+ if (functor(out_id)) {
AddUsage(out_id, nodes_.size(), usages);
}
}
}
+absl::Status InferenceContext::AllocateMemoryForConstTensors(
+ MetalDevice* device) {
+ for (auto& description : const_tensors_descs_) {
+ RETURN_IF_ERROR(const_tensors_[description.first].CreateFromDescriptor(
+ description.second, device->device()));
+ }
+ const_tensors_descs_.clear();
+ return absl::OkStatus();
+}
+
absl::Status InferenceContext::AllocateMemoryForBuffers(MetalDevice* device) {
std::map<ValueId, int2> buffer_usages;
- GetUsages(&buffer_usages);
+ GetUsages(
+ [this](ValueId id) {
+ return GetTensorMemoryType(id) == TensorMemoryType::kBuffer;
+ },
+ &buffer_usages);
std::vector<TensorUsageRecord<size_t>> buffer_usage_records;
for (auto& usage : buffer_usages) {
@@ -473,6 +541,49 @@
return absl::OkStatus();
}
+absl::Status InferenceContext::AllocateMemoryForStrongShapes(
+ MetalDevice* device) {
+ std::map<ValueId, int2> usages;
+ GetUsages(
+ [this](ValueId id) {
+ return GetTensorMemoryType(id) == TensorMemoryType::kStrongShape;
+ },
+ &usages);
+
+ std::vector<TensorUsageRecord<DummyTensor>> usage_records;
+ std::map<ValueId, ValueId> remap_from_graph_ids;
+ for (auto& usage : usages) {
+ remap_from_graph_ids[usage.first] = usage_records.size();
+ usage_records.push_back({tensor_reserver_.Get(usage.first),
+ static_cast<TaskId>(usage.second.x),
+ static_cast<TaskId>(usage.second.y)});
+ }
+
+ ObjectsAssignment<DummyTensor> assignment;
+ RETURN_IF_ERROR(AssignObjectsToTensors(
+ usage_records, MemoryStrategy::EQUALITY, &assignment));
+
+ for (auto& node : nodes_) {
+ std::vector<ValueId> all_ids = node.inputs;
+ all_ids.insert(all_ids.end(), node.outputs.begin(), node.outputs.end());
+ for (auto& tensor_id : all_ids) {
+ const auto& tensor_dummy = tensor_reserver_.Get(tensor_id);
+ if (GetTensorMemoryType(tensor_id) != TensorMemoryType::kStrongShape) {
+ continue;
+ }
+ const auto id = assignment.object_ids[remap_from_graph_ids[tensor_id]];
+ graph_ids_to_strong_shape_tensors_[tensor_id] = id;
+ const auto& it = strong_shape_tensors_.find(id);
+ if (it == strong_shape_tensors_.end()) {
+ RETURN_IF_ERROR(CreateTensor(device->device(), tensor_dummy.shape,
+ tensor_dummy.descriptor,
+ &strong_shape_tensors_[id]));
+ }
+ }
+ }
+ return absl::OkStatus();
+}
+
absl::Status InferenceContext::Tune(TuningType tuning_type,
MetalDevice* device) {
for (auto& node : nodes_) {
diff --git a/tensorflow/lite/delegates/gpu/metal/inference_context.h b/tensorflow/lite/delegates/gpu/metal/inference_context.h
index f2264d8..afdad2f 100644
--- a/tensorflow/lite/delegates/gpu/metal/inference_context.h
+++ b/tensorflow/lite/delegates/gpu/metal/inference_context.h
@@ -118,21 +118,34 @@
void Profile(id<MTLDevice> device, ProfilingInfo* result);
private:
+ enum class TensorMemoryType {
+ kStrongShape,
+ kBuffer,
+ kVariable,
+ kConst,
+ kPreallocated
+ };
absl::Status Compile(const GraphFloat32& graph, const GpuInfo& gpu_info,
- CalculationsPrecision precision);
+ ModelHints hints);
void ReserveGraphTensors(const CreateInferenceInfo& create_info,
- const GpuInfo& gpu_info, const GraphFloat32& graph);
+ const GpuInfo& gpu_info, const GraphFloat32& graph,
+ const std::set<ValueId>& preallocated_ids);
absl::Status CompileOperations(MetalDevice* device);
absl::Status Merge();
- absl::Status AllocateTensors(MetalDevice* device);
+ absl::Status AllocateTensors(MetalDevice* device,
+ const std::set<ValueId>& preallocated_ids);
+ absl::Status AllocateMemoryForConstTensors(MetalDevice* device);
absl::Status AllocateMemoryForBuffers(MetalDevice* device);
+ absl::Status AllocateMemoryForStrongShapes(MetalDevice* device);
void BindTensorsToOperations();
absl::Status UpdateParams(const GpuInfo& gpu_info);
MetalSpatialTensor* GetTensor(ValueId tensor_id);
- void GetUsages(std::map<ValueId, int2>* usages);
+ void GetUsages(const std::function<bool(ValueId)>& functor,
+ std::map<ValueId, int2>* usages);
+ TensorMemoryType GetTensorMemoryType(ValueId id);
absl::Status Tune(TuningType tuning_type, MetalDevice* device);
struct DummyTensor {
@@ -197,11 +210,17 @@
CalculationsPrecision precision_;
std::map<ValueId, MetalSpatialTensor> preallocated_tensors_;
+ std::map<ValueId, TensorDescriptor> const_tensors_descs_;
+ std::map<ValueId, MetalSpatialTensor> const_tensors_;
+
std::map<ValueId, int> graph_ids_to_shared_buffer_tensors_;
std::vector<id<MTLBuffer>> shared_buffers_;
std::vector<MetalSpatialTensor>
shared_buffer_tensors_; // use references to memory
// from _sharedBuffers
+
+ std::map<ValueId, MetalSpatialTensor> strong_shape_tensors_;
+ std::map<ValueId, ValueId> graph_ids_to_strong_shape_tensors_;
};
// Runs specific transforms for the graph.
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
index a518817..5bbf8f8 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
@@ -15,13 +15,6 @@
licenses = ["notice"], # Apache 2.0
)
-cc_library(
- name = "kernels",
- deps = [
- ":conv",
- ],
-)
-
objc_library(
name = "add_test_lib",
testonly = 1,
@@ -68,32 +61,15 @@
deps = [":concat_test_lib"],
)
-cc_library(
- name = "conv",
- srcs = ["conv.cc"],
- hdrs = ["conv.h"],
- deps = [
- ":util",
- "//tensorflow/lite/delegates/gpu/common:data_type",
- "//tensorflow/lite/delegates/gpu/common:gpu_info",
- "//tensorflow/lite/delegates/gpu/common:operations",
- "//tensorflow/lite/delegates/gpu/common:shape",
- "//tensorflow/lite/delegates/gpu/common:types",
- "//tensorflow/lite/delegates/gpu/common:util",
- "//tensorflow/lite/delegates/gpu/common:winograd_util",
- "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
- "@com_google_absl//absl/strings",
- ],
-)
-
objc_library(
name = "conv_test_lib",
testonly = 1,
srcs = ["conv_test.mm"],
sdk_frameworks = ["XCTest"],
deps = [
- ":conv",
":test_util",
+ "//tensorflow/lite/delegates/gpu/common/tasks:conv_metal",
+ "//tensorflow/lite/delegates/gpu/common/tasks:conv_powervr_test_util",
"//tensorflow/lite/delegates/gpu/common/tasks:winograd",
],
)
@@ -613,16 +589,6 @@
],
)
-cc_library(
- name = "util",
- srcs = ["util.cc"],
- hdrs = ["util.h"],
- deps = [
- "//tensorflow/lite/delegates/gpu/common:data_type",
- "//tensorflow/lite/delegates/gpu/common:types",
- ],
-)
-
objc_library(
name = "winograd_test_lib",
testonly = 1,
@@ -680,7 +646,6 @@
],
sdk_frameworks = ["XCTest"],
deps = [
- ":conv",
":test_util",
"//tensorflow/lite/delegates/gpu/common:gpu_info",
"//tensorflow/lite/delegates/gpu/common:precision",
@@ -689,6 +654,8 @@
"//tensorflow/lite/delegates/gpu/common:util",
"//tensorflow/lite/delegates/gpu/common/tasks:add_test_util",
"//tensorflow/lite/delegates/gpu/common/tasks:concat_test_util",
+ "//tensorflow/lite/delegates/gpu/common/tasks:conv_metal",
+ "//tensorflow/lite/delegates/gpu/common/tasks:conv_powervr_test_util",
"//tensorflow/lite/delegates/gpu/common/tasks:conv_weights_converter_test_util",
"//tensorflow/lite/delegates/gpu/common/tasks:convolution_transposed_4x4_test_util",
"//tensorflow/lite/delegates/gpu/common/tasks:convolution_transposed_test_util",
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/conv_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/conv_test.mm
index 95c10d1..693b12f 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/conv_test.mm
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/conv_test.mm
@@ -13,7 +13,7 @@
limitations under the License.
==============================================================================*/
-#include "tensorflow/lite/delegates/gpu/metal/kernels/conv.h"
+#include "tensorflow/lite/delegates/gpu/common/tasks/conv_metal.h"
#import <XCTest/XCTest.h>
#include "tensorflow/lite/delegates/gpu/common/tasks/winograd.h"
@@ -24,6 +24,7 @@
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/tasks/conv_powervr_test_util.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"
@@ -54,7 +55,7 @@
attr.padding.appended = HW(1, 0);
attr.strides = HW(1, 1);
- for (auto storage : {TensorStorageType::BUFFER, TensorStorageType::IMAGE_BUFFER}) {
+ for (auto storage : env->GetSupportedStorages()) {
for (auto precision : env->GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
OperationDef op_def;
@@ -63,10 +64,10 @@
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
- ConvolutionGeneric operation =
- CreateConvolutionGeneric(op_def, BHWC(1, 2, 2, 2), attr, env->GetGpuInfo());
+ ConvolutionMetal operation =
+ CreateConvolutionMetal(op_def, BHWC(1, 2, 2, 2), attr, env->GetGpuInfo());
RETURN_IF_ERROR(env->ExecuteGPUOperation(
- src_tensor, absl::make_unique<ConvolutionGeneric>(std::move(operation)), BHWC(1, 2, 2, 2),
+ src_tensor, absl::make_unique<ConvolutionMetal>(std::move(operation)), BHWC(1, 2, 2, 2),
&dst_tensor));
RETURN_IF_ERROR(PointWiseNear({4, 8, 4, 8, 2, 4, 2, 4}, dst_tensor.data, eps))
<< "Failed using precision " << ToString(precision);
@@ -90,7 +91,7 @@
attr.padding.appended = HW(0, 0);
attr.strides = HW(1, 1);
- for (auto storage : {TensorStorageType::BUFFER, TensorStorageType::IMAGE_BUFFER}) {
+ for (auto storage : env->GetSupportedStorages()) {
for (auto precision : env->GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
OperationDef op_def;
@@ -99,10 +100,10 @@
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
- ConvolutionGeneric operation =
- CreateConvolutionGeneric(op_def, BHWC(1, 1, 1, 1), attr, env->GetGpuInfo());
+ ConvolutionMetal operation =
+ CreateConvolutionMetal(op_def, BHWC(1, 1, 1, 1), attr, env->GetGpuInfo());
RETURN_IF_ERROR(env->ExecuteGPUOperation(
- src_tensor, absl::make_unique<ConvolutionGeneric>(std::move(operation)), BHWC(1, 1, 1, 1),
+ src_tensor, absl::make_unique<ConvolutionMetal>(std::move(operation)), BHWC(1, 1, 1, 1),
&dst_tensor));
RETURN_IF_ERROR(PointWiseNear({10}, dst_tensor.data, eps))
<< "Failed using precision " << ToString(precision);
@@ -126,7 +127,7 @@
attr.padding.appended = HW(0, 0);
attr.strides = HW(1, 1);
- for (auto storage : {TensorStorageType::BUFFER, TensorStorageType::IMAGE_BUFFER}) {
+ for (auto storage : env->GetSupportedStorages()) {
for (auto precision : env->GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
OperationDef op_def;
@@ -135,10 +136,10 @@
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
- ConvolutionGeneric operation =
- CreateConvolutionGeneric(op_def, BHWC(1, 1, 1, 1), attr, env->GetGpuInfo());
+ ConvolutionMetal operation =
+ CreateConvolutionMetal(op_def, BHWC(1, 1, 1, 1), attr, env->GetGpuInfo());
RETURN_IF_ERROR(env->ExecuteGPUOperation(
- src_tensor, absl::make_unique<ConvolutionGeneric>(std::move(operation)), BHWC(1, 1, 1, 1),
+ src_tensor, absl::make_unique<ConvolutionMetal>(std::move(operation)), BHWC(1, 1, 1, 1),
&dst_tensor));
RETURN_IF_ERROR(PointWiseNear({11}, dst_tensor.data, eps))
<< "Failed using precision " << ToString(precision);
@@ -162,7 +163,7 @@
attr.padding.appended = HW(0, 0);
attr.strides = HW(1, 1);
- for (auto storage : {TensorStorageType::BUFFER, TensorStorageType::IMAGE_BUFFER}) {
+ for (auto storage : env->GetSupportedStorages()) {
for (auto precision : env->GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
OperationDef op_def;
@@ -171,10 +172,10 @@
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
- ConvolutionGeneric operation =
- CreateConvolutionGeneric(op_def, BHWC(1, 2, 1, 2), attr, env->GetGpuInfo());
+ ConvolutionMetal operation =
+ CreateConvolutionMetal(op_def, BHWC(1, 2, 1, 2), attr, env->GetGpuInfo());
RETURN_IF_ERROR(env->ExecuteGPUOperation(
- src_tensor, absl::make_unique<ConvolutionGeneric>(std::move(operation)), BHWC(1, 2, 1, 2),
+ src_tensor, absl::make_unique<ConvolutionMetal>(std::move(operation)), BHWC(1, 2, 1, 2),
&dst_tensor));
RETURN_IF_ERROR(PointWiseNear({4, 8, 4, 8}, dst_tensor.data, eps))
<< "Failed using precision " << ToString(precision);
@@ -198,7 +199,7 @@
attr.padding.appended = HW(0, 0);
attr.strides = HW(2, 2);
- for (auto storage : {TensorStorageType::BUFFER, TensorStorageType::IMAGE_BUFFER}) {
+ for (auto storage : env->GetSupportedStorages()) {
for (auto precision : env->GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
OperationDef op_def;
@@ -207,10 +208,10 @@
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
- ConvolutionGeneric operation =
- CreateConvolutionGeneric(op_def, BHWC(1, 2, 2, 1), attr, env->GetGpuInfo());
+ ConvolutionMetal operation =
+ CreateConvolutionMetal(op_def, BHWC(1, 2, 2, 1), attr, env->GetGpuInfo());
RETURN_IF_ERROR(env->ExecuteGPUOperation(
- src_tensor, absl::make_unique<ConvolutionGeneric>(std::move(operation)), BHWC(1, 2, 2, 1),
+ src_tensor, absl::make_unique<ConvolutionMetal>(std::move(operation)), BHWC(1, 2, 2, 1),
&dst_tensor));
RETURN_IF_ERROR(PointWiseNear({2, 4, 8, 16}, dst_tensor.data, eps))
<< "Failed using precision " << ToString(precision);
@@ -266,8 +267,8 @@
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 output0;
- auto gpu_op0 = CreateConvolutionGeneric(op_def, dst_shape, attr, env->GetGpuInfo());
- auto op0_ptr = absl::make_unique<ConvolutionGeneric>(std::move(gpu_op0));
+ auto gpu_op0 = CreateConvolutionMetal(op_def, dst_shape, attr, env->GetGpuInfo());
+ auto op0_ptr = absl::make_unique<ConvolutionMetal>(std::move(gpu_op0));
RETURN_IF_ERROR(
env->ExecuteGPUOperation(src_tensor, std::move(op0_ptr), dst_shape, &output0));
@@ -275,8 +276,9 @@
std::unique_ptr<GPUOperation> op1_ptr =
absl::make_unique<Winograd4x4To36>(std::move(gpu_op1));
- auto gpu_op2 = CreateConvolutionWino4x4To6x6(op_def, conv_shape, attr, env->GetGpuInfo());
- auto op2_ptr = absl::make_unique<ConvolutionGeneric>(std::move(gpu_op2));
+ auto gpu_op2 =
+ CreateConvolutionMetalWino4x4To6x6(op_def, conv_shape, attr, env->GetGpuInfo());
+ auto op2_ptr = absl::make_unique<ConvolutionMetal>(std::move(gpu_op2));
auto gpu_op3 = CreateWinograd36To4x4(op_def, attr.bias);
std::unique_ptr<GPUOperation> op3_ptr =
@@ -339,4 +341,24 @@
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
+- (void)testConvPowerVR1x1SimpleWeights {
+ const auto status = ConvPowerVR1x1SimpleWeightsTest(&exec_env_);
+ XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
+}
+
+- (void)testConvPowerVR1x1 {
+ const auto status = ConvPowerVR1x1Test(&exec_env_);
+ XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
+}
+
+- (void)testConvPowerVRSimpleWeights {
+ const auto status = ConvPowerVRSimpleWeightsTest(&exec_env_);
+ XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
+}
+
+- (void)testConvPowerVR {
+ const auto status = ConvPowerVRTest(&exec_env_);
+ XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
+}
+
@end
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/util.cc b/tensorflow/lite/delegates/gpu/metal/kernels/util.cc
deleted file mode 100644
index 82529b9..0000000
--- a/tensorflow/lite/delegates/gpu/metal/kernels/util.cc
+++ /dev/null
@@ -1,57 +0,0 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/lite/delegates/gpu/metal/kernels/util.h"
-
-#include <vector>
-
-#include "tensorflow/lite/delegates/gpu/common/types.h"
-
-namespace tflite {
-namespace gpu {
-namespace metal {
-
-/// Converts float to destination type (if needed) and stores as bytes array.
-std::vector<uint8_t> GetByteBufferConverted(
- const std::vector<float>& input_vector, DataType data_type) {
- if (data_type == DataType::FLOAT32) {
- return GetByteBuffer(input_vector);
- } else {
- std::vector<uint8_t> result;
- result.reserve(input_vector.size() * sizeof(half));
- for (const float value : input_vector) {
- const half converted = half(value);
- const uint8_t* bytes = reinterpret_cast<const uint8_t*>(&converted);
- result.insert(result.end(), bytes, bytes + sizeof(half));
- }
- return result;
- }
-}
-
-/// Resizes, Converts float to destination type (if needed) and stores as bytes
-/// array.
-std::vector<uint8_t> GetByteBufferConvertedResized(
- const std::vector<float>& input_vector, DataType data_type,
- size_t elements_count) {
- auto result = GetByteBufferConverted(input_vector, data_type);
- const size_t type_size =
- data_type == DataType::FLOAT32 ? sizeof(float) : sizeof(half);
- result.resize(type_size * elements_count);
- return result;
-}
-
-} // namespace metal
-} // namespace gpu
-} // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/util.h b/tensorflow/lite/delegates/gpu/metal/kernels/util.h
deleted file mode 100644
index 078eb51..0000000
--- a/tensorflow/lite/delegates/gpu/metal/kernels/util.h
+++ /dev/null
@@ -1,54 +0,0 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_UTIL_H_
-#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_UTIL_H_
-
-#include <vector>
-
-#include "tensorflow/lite/delegates/gpu/common/data_type.h"
-
-namespace tflite {
-namespace gpu {
-namespace metal {
-
-/// Helper function to convert buffer's content into stream of bytes
-template <typename T>
-std::vector<uint8_t> GetByteBuffer(const std::vector<T>& input_vector) {
- std::vector<uint8_t> result;
- result.insert(result.begin(),
- reinterpret_cast<const uint8_t*>(input_vector.data()),
- reinterpret_cast<const uint8_t*>(input_vector.data()) +
- input_vector.size() * sizeof(*input_vector.data()));
- return result;
-}
-
-/// Converts float to destination type (if needed) and stores as bytes array.
-/// supports DataType::FLOAT32 and DataType::FLOAT16
-std::vector<uint8_t> GetByteBufferConverted(
- const std::vector<float>& input_vector, DataType data_type);
-
-/// Resizes, Converts float to destination type (if needed) and stores as bytes
-/// array.
-/// supports DataType::FLOAT32 and DataType::FLOAT16
-std::vector<uint8_t> GetByteBufferConvertedResized(
- const std::vector<float>& input_vector, DataType data_type,
- size_t elements_count);
-
-} // namespace metal
-} // namespace gpu
-} // namespace tflite
-
-#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_UTIL_H_
diff --git a/tensorflow/lite/delegates/gpu/metal/selectors/BUILD b/tensorflow/lite/delegates/gpu/metal/selectors/BUILD
deleted file mode 100644
index b0d1231..0000000
--- a/tensorflow/lite/delegates/gpu/metal/selectors/BUILD
+++ /dev/null
@@ -1,30 +0,0 @@
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
-
-cc_library(
- name = "operation_selector",
- srcs = ["operation_selector.cc"],
- hdrs = ["operation_selector.h"],
- deps = [
- "//tensorflow/lite/delegates/gpu/common:gpu_info",
- "//tensorflow/lite/delegates/gpu/common:model",
- "//tensorflow/lite/delegates/gpu/common:model_hints",
- "//tensorflow/lite/delegates/gpu/common:operations",
- "//tensorflow/lite/delegates/gpu/common:precision",
- "//tensorflow/lite/delegates/gpu/common:shape",
- "//tensorflow/lite/delegates/gpu/common:status",
- "//tensorflow/lite/delegates/gpu/common:util",
- "//tensorflow/lite/delegates/gpu/common:winograd_util",
- "//tensorflow/lite/delegates/gpu/common/selectors:convolution_transposed_selector",
- "//tensorflow/lite/delegates/gpu/common/selectors:default_selector",
- "//tensorflow/lite/delegates/gpu/common/selectors:dw_convolution_selector",
- "//tensorflow/lite/delegates/gpu/common/selectors:fully_connected_selector",
- "//tensorflow/lite/delegates/gpu/common/selectors:simple_selectors",
- "//tensorflow/lite/delegates/gpu/common/selectors:subgraph",
- "//tensorflow/lite/delegates/gpu/common/tasks:elementwise",
- "//tensorflow/lite/delegates/gpu/common/tasks:mean_stddev_normalization",
- "//tensorflow/lite/delegates/gpu/metal/kernels",
- ],
-)
diff --git a/tensorflow/lite/delegates/gpu/metal/selectors/operation_selector.cc b/tensorflow/lite/delegates/gpu/metal/selectors/operation_selector.cc
deleted file mode 100644
index fc31cf6..0000000
--- a/tensorflow/lite/delegates/gpu/metal/selectors/operation_selector.cc
+++ /dev/null
@@ -1,377 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/lite/delegates/gpu/metal/selectors/operation_selector.h"
-
-#include <vector>
-
-#include "absl/strings/substitute.h"
-#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
-#include "tensorflow/lite/delegates/gpu/common/model.h"
-#include "tensorflow/lite/delegates/gpu/common/model_hints.h"
-#include "tensorflow/lite/delegates/gpu/common/operations.h"
-#include "tensorflow/lite/delegates/gpu/common/selectors/convolution_transposed_selector.h"
-#include "tensorflow/lite/delegates/gpu/common/selectors/default_selector.h"
-#include "tensorflow/lite/delegates/gpu/common/selectors/dw_convolution_selector.h"
-#include "tensorflow/lite/delegates/gpu/common/selectors/fully_connected_selector.h"
-#include "tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.h"
-#include "tensorflow/lite/delegates/gpu/common/selectors/subgraph.h"
-#include "tensorflow/lite/delegates/gpu/common/shape.h"
-#include "tensorflow/lite/delegates/gpu/common/status.h"
-#include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
-#include "tensorflow/lite/delegates/gpu/common/tasks/elementwise.h"
-#include "tensorflow/lite/delegates/gpu/common/tasks/mean_stddev_normalization.h"
-#include "tensorflow/lite/delegates/gpu/common/util.h"
-#include "tensorflow/lite/delegates/gpu/common/winograd_util.h"
-#include "tensorflow/lite/delegates/gpu/metal/kernels/conv.h"
-
-namespace tflite {
-namespace gpu {
-namespace metal {
-namespace {
-bool IsRecommendedForWinograd4x4To6x6(const Convolution2DAttributes& attr,
- const GpuInfo& gpu_info,
- const BHWC& dst_shape) {
- const int tiles_x = DivideRoundUp(dst_shape.w, 4);
- const int tiles_y = DivideRoundUp(dst_shape.h, 4);
- const int total_tiles = tiles_x * tiles_y;
- const int src_depth = DivideRoundUp(attr.weights.shape.i, 4);
- const int dst_depth = DivideRoundUp(attr.weights.shape.o, 4);
- int min_depth = 16;
- const int min_tiles = 32;
- if (total_tiles >= min_tiles * 8) {
- min_depth /= 4;
- min_depth = std::max(min_depth, 8);
- } else if (total_tiles >= min_tiles * 4) {
- min_depth /= 2;
- min_depth = std::max(min_depth, 8);
- }
- const bool recommended_channels =
- src_depth >= min_depth && dst_depth >= min_depth;
- const bool recommended_hw = total_tiles >= min_tiles;
- return recommended_channels && recommended_hw;
-}
-
-absl::Status WinogradFromNode(const GpuInfo& gpu_info,
- const std::vector<Value*>& inputs,
- const std::vector<Value*>& outputs,
- const OperationDef& op_def,
- const BHWC& input_shape, const BHWC& output_shape,
- const Convolution2DAttributes& attr,
- GPUOperationsSubgraph* gpu_subgraph) {
- if (!IsSuitableForWinograd4x4To6x6(attr)) {
- return absl::UnimplementedError("No implementation for this case.");
- }
- if (!IsRecommendedForWinograd4x4To6x6(attr, gpu_info, output_shape)) {
- return absl::UnimplementedError("Not recommended for this case.");
- }
-
- const int tiles_x = DivideRoundUp(output_shape.w, 4);
- const int tiles_y = DivideRoundUp(output_shape.h, 4);
- const BHWC shape_0{input_shape.b, 36, tiles_x * tiles_y, input_shape.c};
- const BHWC shape_1{input_shape.b, 36, tiles_x * tiles_y, output_shape.c};
- TensorDescriptor tensor_desc = op_def.src_tensors[0];
- gpu_subgraph->new_tensors = {{shape_0, tensor_desc}, {shape_1, tensor_desc}};
- gpu_subgraph->operations.clear();
- gpu_subgraph->operations.resize(3);
-
- OperationDef winograd_up_def;
- winograd_up_def.precision = op_def.precision;
- winograd_up_def.src_tensors.push_back(op_def.src_tensors[0]);
- winograd_up_def.dst_tensors.push_back(op_def.src_tensors[0]);
- auto& winograd_up = gpu_subgraph->operations[0];
- winograd_up.operation =
- SelectWinograd4x4To36(gpu_info, attr.padding, winograd_up_def);
- winograd_up.input_ids = {static_cast<int>(inputs[0]->id)};
- winograd_up.output_ids = {-1};
-
- OperationDef conv_def;
- conv_def.precision = op_def.precision;
- conv_def.src_tensors.push_back(op_def.src_tensors[0]);
- conv_def.dst_tensors.push_back(op_def.src_tensors[0]);
- auto& conv = gpu_subgraph->operations[1];
- conv.input_ids = {-1};
- conv.output_ids = {-2};
- auto gpu_op =
- CreateConvolutionWino4x4To6x6(conv_def, shape_1, attr, gpu_info);
- conv.operation = absl::make_unique<ConvolutionGeneric>(std::move(gpu_op));
- OperationDef winograd_down_def;
- winograd_down_def.precision = op_def.precision;
- winograd_down_def.src_tensors.push_back(op_def.src_tensors[0]);
- winograd_down_def.dst_tensors.push_back(op_def.dst_tensors[0]);
- auto& winograd_down = gpu_subgraph->operations[2];
- winograd_down.input_ids = {-2};
- winograd_down.output_ids = {static_cast<int>(outputs[0]->id)};
- winograd_down.operation =
- SelectWinograd36To4x4(gpu_info, winograd_down_def, attr.bias);
- return absl::OkStatus();
-}
-
-} // namespace
-
-absl::Status GPUOperationFromNode(const GpuInfo& gpu_info,
- const OperationDef& op_def,
- const std::vector<Value*>& inputs,
- const std::vector<Value*>& outputs,
- const Node& node,
- GPUOperationsSubgraph* gpu_subgraph) {
- std::unique_ptr<GPUOperation>* gpu_op =
- InitSingleOpSubgraph(inputs, outputs, gpu_subgraph);
- auto op_type = OperationTypeFromString(node.operation.type);
- switch (op_type) {
- case OperationType::ADD: {
- if (inputs.size() == 2 &&
- (inputs[0]->tensor.shape.c == inputs[1]->tensor.shape.c ||
- inputs[1]->tensor.shape.c == 1)) {
- GPUOperation operation =
- CreateElementwiseTwoInput(op_def, op_type, inputs[1]->tensor.shape);
- *gpu_op = absl::make_unique<GPUOperation>(std::move(operation));
- return absl::OkStatus();
- } else if (inputs.size() >= 2) {
- auto output = outputs[0];
- std::vector<int> channels(inputs.size());
- for (int i = 0; i < inputs.size(); ++i) {
- channels[i] = inputs[i]->tensor.shape.c;
- }
- SelectAdd(op_def, channels, output->tensor.shape.c, gpu_op);
- return absl::OkStatus();
- } else if (inputs.size() == 1 && node.operation.attributes.has_value()) {
- auto attr =
- absl::any_cast<ElementwiseAttributes>(node.operation.attributes);
- GPUOperation operation =
- CreateElementwise(gpu_info, op_def, op_type, attr);
- *gpu_op = absl::make_unique<GPUOperation>(std::move(operation));
- return absl::OkStatus();
- }
- return absl::UnimplementedError(absl::StrCat(
- "No support of ", node.operation.type, " with this parameters"));
- }
- case OperationType::CONCAT: {
- auto attr = absl::any_cast<ConcatAttributes>(node.operation.attributes);
- std::vector<int> channels(inputs.size());
- for (int i = 0; i < inputs.size(); ++i) {
- channels[i] = inputs[i]->tensor.shape.c;
- }
- return SelectConcat(attr, channels, op_def, gpu_info, gpu_op);
- }
- case OperationType::CONVOLUTION_2D: {
- if (inputs.size() != 1) {
- return absl::UnimplementedError(
- "Convolution does not support more than 1 runtime tensor");
- }
- auto attr =
- absl::any_cast<Convolution2DAttributes>(node.operation.attributes);
- auto input_shape = inputs[0]->tensor.shape;
- auto output_shape = outputs[0]->tensor.shape;
- if (WinogradFromNode(gpu_info, inputs, outputs, op_def, input_shape,
- output_shape, attr, gpu_subgraph)
- .ok()) {
- return absl::OkStatus();
- } else {
- auto conv_op =
- CreateConvolutionGeneric(op_def, output_shape, attr, gpu_info);
- *gpu_op = absl::make_unique<ConvolutionGeneric>(std::move(conv_op));
- }
- break;
- }
- case OperationType::CONVOLUTION_TRANSPOSED: {
- if (inputs.size() != 1) {
- return absl::UnimplementedError(
- "Convolution Transposed does not support more than 1 runtime "
- "tensor");
- }
- auto attr = absl::any_cast<ConvolutionTransposedAttributes>(
- node.operation.attributes);
- *gpu_op = SelectConvolutionTransposed(attr, gpu_info, op_def);
- return absl::OkStatus();
- }
- case OperationType::DEPTHWISE_CONVOLUTION: {
- auto attr = absl::any_cast<DepthwiseConvolution2DAttributes>(
- node.operation.attributes);
- if (inputs.size() == 1) {
- *gpu_op = SelectDWConvolution(attr, gpu_info, op_def);
- } else {
- if (inputs[1]->tensor.shape.b != 1) {
- return absl::UnimplementedError(
- "No support of depthwise runtime weights with channel multiplier "
- "!= 1");
- }
- *gpu_op = SelectDWConvolutionDynamicWeights(attr, gpu_info, op_def);
- }
- return absl::OkStatus();
- }
- case OperationType::FULLY_CONNECTED: {
- auto attr =
- absl::any_cast<FullyConnectedAttributes>(node.operation.attributes);
- *gpu_op = SelectFullyConnected(attr, gpu_info, op_def,
- inputs[0]->tensor.shape.b);
- return absl::OkStatus();
- }
- case OperationType::LSTM: {
- *gpu_op = SelectLSTM(op_def, gpu_info);
- return absl::OkStatus();
- }
- case OperationType::MAX_UNPOOLING_2D: {
- auto attr =
- absl::any_cast<MaxUnpooling2DAttributes>(node.operation.attributes);
- *gpu_op = SelectMaxUnpooling(attr, op_def);
- return absl::OkStatus();
- }
- case OperationType::MEAN: {
- auto attr = absl::any_cast<MeanAttributes>(node.operation.attributes);
- *gpu_op = SelectReduce(attr.dims, inputs[0]->tensor.shape, op_type,
- op_def, gpu_info);
- return absl::OkStatus();
- }
- case OperationType::MEAN_STDDEV_NORMALIZATION: {
- MeanStdDevNormalization operation = CreateMeanStdDevNormalization(
- op_def, gpu_info, (inputs[0]->tensor.shape.c + 3) / 4);
- *gpu_op =
- absl::make_unique<MeanStdDevNormalization>(std::move(operation));
- return absl::OkStatus();
- }
- case OperationType::PAD: {
- auto attr = absl::any_cast<PadAttributes>(node.operation.attributes);
- SelectPadding(attr, op_def, gpu_op);
- return absl::OkStatus();
- }
- case OperationType::POOLING_2D: {
- auto attr =
- absl::any_cast<Pooling2DAttributes>(node.operation.attributes);
- *gpu_op = SelectPooling(attr, op_def);
- return absl::OkStatus();
- }
- case OperationType::PRELU: {
- auto attr = absl::any_cast<PReLUAttributes>(node.operation.attributes);
- *gpu_op = SelectPReLU(attr, gpu_info, op_def);
- return absl::OkStatus();
- }
- case OperationType::REDUCE_MAXIMUM:
- case OperationType::REDUCE_MINIMUM:
- case OperationType::REDUCE_PRODUCT:
- case OperationType::REDUCE_SUM: {
- auto attr = absl::any_cast<ReduceAttributes>(node.operation.attributes);
- *gpu_op = SelectReduce(attr.dims, inputs[0]->tensor.shape, op_type,
- op_def, gpu_info);
- return absl::OkStatus();
- }
- case OperationType::RELU: {
- auto attr = absl::any_cast<ReLUAttributes>(node.operation.attributes);
- *gpu_op = SelectReLU(attr, op_def);
- return absl::OkStatus();
- }
- case OperationType::QUANTIZE_AND_DEQUANTIZE: {
- auto attr = absl::any_cast<QuantizeAndDequantizeAttributes>(
- node.operation.attributes);
- *gpu_op = SelectQuantizeAndDequantize(attr, op_def);
- return absl::OkStatus();
- }
- case OperationType::RESHAPE: {
- const int src_channels = inputs[0]->tensor.shape.c;
- auto attr = absl::any_cast<ReshapeAttributes>(node.operation.attributes);
- SelectReshape(src_channels, attr.new_shape.c, op_def, gpu_op);
- return absl::OkStatus();
- }
- case OperationType::RESIZE: {
- auto attr = absl::any_cast<Resize2DAttributes>(node.operation.attributes);
- return SelectResize(attr, op_def, gpu_op);
- }
- case OperationType::SLICE: {
- auto attr = absl::any_cast<SliceAttributes>(node.operation.attributes);
- SelectStridedSlice(attr, op_def, gpu_op);
- return absl::OkStatus();
- }
- case OperationType::SOFTMAX: {
- SelectSoftmax(inputs[0]->tensor.shape, op_def, gpu_op);
- return absl::OkStatus();
- }
- case OperationType::SPACE_TO_DEPTH: {
- auto attr =
- absl::any_cast<SpaceToDepthAttributes>(node.operation.attributes);
- SelectSpaceToDepth(attr, op_def, gpu_op);
- return absl::OkStatus();
- }
- case OperationType::TRANSPOSE: {
- auto attr =
- absl::any_cast<TransposeAttributes>(node.operation.attributes);
- SelectTranspose(attr, op_def, gpu_op);
- return absl::OkStatus();
- }
- case OperationType::ABS:
- case OperationType::COPY:
- case OperationType::COS:
- case OperationType::ELU:
- case OperationType::EXP:
- case OperationType::HARD_SWISH:
- case OperationType::LOG:
- case OperationType::NEG:
- case OperationType::RSQRT:
- case OperationType::SIGMOID:
- case OperationType::SIN:
- case OperationType::SQRT:
- case OperationType::SQUARE:
- case OperationType::TANH: {
- GPUOperation operation =
- CreateElementwiseOneInput(gpu_info, op_def, op_type);
- *gpu_op = absl::make_unique<GPUOperation>(std::move(operation));
- return absl::OkStatus();
- }
- case OperationType::DIV:
- case OperationType::EQUAL:
- case OperationType::GREATER:
- case OperationType::GREATER_EQUAL:
- case OperationType::LESS:
- case OperationType::LESS_EQUAL:
- case OperationType::MAXIMUM:
- case OperationType::MINIMUM:
- case OperationType::MUL:
- case OperationType::NOT_EQUAL:
- case OperationType::POW:
- case OperationType::SQUARED_DIFF:
- case OperationType::SUB: {
- if (inputs.size() == 2) {
- GPUOperation operation =
- CreateElementwiseTwoInput(op_def, op_type, inputs[1]->tensor.shape);
- *gpu_op = absl::make_unique<GPUOperation>(std::move(operation));
- return absl::OkStatus();
- } else if (inputs.size() == 1 && node.operation.attributes.has_value()) {
- auto attr =
- absl::any_cast<ElementwiseAttributes>(node.operation.attributes);
- GPUOperation operation =
- CreateElementwise(gpu_info, op_def, op_type, attr);
- *gpu_op = absl::make_unique<GPUOperation>(std::move(operation));
- return absl::OkStatus();
- }
- return absl::UnimplementedError(absl::StrCat(
- "No support of ", node.operation.type, " with this parameters"));
- }
- case OperationType::BATCH_NORMALIZATION:
- case OperationType::BATCH_TO_SPACE:
- case OperationType::BATCHED_MATMUL:
- case OperationType::CONSTANT:
- case OperationType::SPACE_TO_BATCH:
- return absl::UnimplementedError("Unsupported op: " + node.operation.type);
- default: {
- ModelHints hints;
- return SelectDefault(gpu_info, op_def, hints, inputs, outputs, node,
- gpu_subgraph);
- }
- }
- return absl::OkStatus();
-}
-
-} // namespace metal
-} // namespace gpu
-} // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/metal/selectors/operation_selector.h b/tensorflow/lite/delegates/gpu/metal/selectors/operation_selector.h
deleted file mode 100644
index 5b73b46..0000000
--- a/tensorflow/lite/delegates/gpu/metal/selectors/operation_selector.h
+++ /dev/null
@@ -1,39 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_SELECTORS_OPERATION_SELECTOR_H_
-#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_SELECTORS_OPERATION_SELECTOR_H_
-
-#include <memory>
-
-#include "tensorflow/lite/delegates/gpu/common/model.h"
-#include "tensorflow/lite/delegates/gpu/common/selectors/subgraph.h"
-#include "tensorflow/lite/delegates/gpu/common/status.h"
-
-namespace tflite {
-namespace gpu {
-namespace metal {
-
-absl::Status GPUOperationFromNode(const GpuInfo& gpu_info,
- const OperationDef& op_def,
- const std::vector<Value*>& inputs,
- const std::vector<Value*>& outputs,
- const Node& node,
- GPUOperationsSubgraph* gpu_subgraph);
-} // namespace metal
-} // namespace gpu
-} // namespace tflite
-
-#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_SELECTORS_OPERATION_SELECTOR_H_
diff --git a/tensorflow/lite/delegates/nnapi/BUILD b/tensorflow/lite/delegates/nnapi/BUILD
index 86694d9..c41503f 100644
--- a/tensorflow/lite/delegates/nnapi/BUILD
+++ b/tensorflow/lite/delegates/nnapi/BUILD
@@ -39,6 +39,7 @@
"//tensorflow/lite/nnapi:nnapi_implementation",
"//tensorflow/lite/nnapi:nnapi_lib",
"//tensorflow/lite/nnapi:nnapi_util",
+ "@farmhash_archive//:farmhash",
],
)
@@ -76,6 +77,7 @@
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
+ "@farmhash_archive//:farmhash",
],
)
diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
index 33ac862..5e3acc3 100644
--- a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
+++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
@@ -57,6 +57,7 @@
#include "tensorflow/lite/nnapi/nnapi_implementation.h"
#include "tensorflow/lite/nnapi/nnapi_util.h"
#include "tensorflow/lite/util.h"
+#include <farmhash.h>
namespace tflite {
namespace {
@@ -2056,7 +2057,11 @@
NNAPIValidationFailureType::kUnsupportedOperandRank,
"Input rank should be less than 4", &val_ctx);
- if (context->tensors[node->inputs->data[0]].type == kTfLiteUInt8 &&
+ const auto& input_type = context->tensors[node->inputs->data[0]].type;
+ EXPECT_INPUT_TYPE_IN(input_type, kTfLiteFloat16, kTfLiteFloat32,
+ kTfLiteUInt8, kTfLiteInt8);
+
+ if (input_type == kTfLiteUInt8 &&
android_sdk_version < kMinSdkVersionForNNAPI12) {
auto first_param = context->tensors[node->inputs->data[0]].params;
for (int i = 1; i < node->inputs->size; i++) {
@@ -3657,9 +3662,10 @@
// TODO(b/133342794): use a generic token generator class.
uint64_t token_parts[4];
// Create bits from model_token.
- // TODO(b/172237993): should not use std::hash, as that is not
+ // Using farmhash fingerprint instead of std::hash, as the latter is not
// guaranteed to be stable across program invocations.
- token_parts[0] = std::hash<std::string>{}(model_token);
+ token_parts[0] =
+ ::util::Fingerprint64(model_token, std::strlen(model_token));
// Create bits from params->nodes_to_replace.
token_parts[1] = GetHash(params->nodes_to_replace);
// Create bits from params->input_tensors. These include the input tensor
diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate.h b/tensorflow/lite/delegates/nnapi/nnapi_delegate.h
index 4b12b0d..a7e6d52 100644
--- a/tensorflow/lite/delegates/nnapi/nnapi_delegate.h
+++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate.h
@@ -130,16 +130,32 @@
// Uses default options.
StatefulNnApiDelegate();
+ // The ownership of the NnApi instance is left to the caller of the
+ // StatefulNnApiDelegate constructor; the caller must ensure that the lifetime
+ // of the NnApi instance exceeds the lifetime of the StatefulNnApiDelegate.
explicit StatefulNnApiDelegate(const NnApi* nnapi);
// The constructor that accepts options from user.
+ // This makes a copy of any data that it needs from Options, so
+ // the caller can safely deallocate any storage pointed to by
+ // the 'const char *' members of Options immediately after calling this.
explicit StatefulNnApiDelegate(Options options);
+ // Constructor that accepts both an NnApi instance and options.
+ // The ownership of the NnApi instance is left to the caller of the
+ // StatefulNnApiDelegate constructor; the caller must ensure that the lifetime
+ // of the NnApi instance exceeds the lifetime of the StatefulNnApiDelegate.
+ // This constructor makes a copy of any data that it needs from Options, so
+ // the caller can safely deallocate any storage pointed to by
+ // the 'const char *' members of Options immediately after calling this.
StatefulNnApiDelegate(const NnApi* nnapi, Options options);
~StatefulNnApiDelegate() = default;
// Returns the delegate options.
+ // The lifetime of the storage pointed to by the 'const char *' members of the
+ // returned Options object is the same as the lifetime of the supplied
+ // TfLiteDelegate instance.
static const Options GetOptions(TfLiteDelegate* delegate);
// Callback function which copies data from ANeuralNetworksMemory to host
diff --git a/tensorflow/lite/experimental/acceleration/configuration/BUILD b/tensorflow/lite/experimental/acceleration/configuration/BUILD
index ec8c664..30ebd3e 100644
--- a/tensorflow/lite/experimental/acceleration/configuration/BUILD
+++ b/tensorflow/lite/experimental/acceleration/configuration/BUILD
@@ -107,11 +107,24 @@
cc_library(
name = "nnapi_plugin",
+ deps = [
+ ":nnapi_plugin_impl",
+ ],
+)
+
+cc_library(
+ name = "nnapi_plugin_impl",
srcs = ["nnapi_plugin.cc"],
+ hdrs = ["nnapi_plugin.h"],
+ visibility = [
+ "//tensorflow/lite/experimental/acceleration/configuration/c:__pkg__",
+ ],
deps = [
":configuration_fbs",
":delegate_registry",
+ "//tensorflow/lite/c:common",
"//tensorflow/lite/delegates/nnapi:nnapi_delegate",
+ "//tensorflow/lite/experimental/acceleration/configuration/c:delegate_plugin",
"@com_google_absl//absl/memory",
],
alwayslink = 1, # For registration to always run.
diff --git a/tensorflow/lite/experimental/acceleration/configuration/c/BUILD b/tensorflow/lite/experimental/acceleration/configuration/c/BUILD
new file mode 100644
index 0000000..5762f0a
--- /dev/null
+++ b/tensorflow/lite/experimental/acceleration/configuration/c/BUILD
@@ -0,0 +1,44 @@
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# C API for delegate plugins.
+
+package(
+ default_visibility = ["//visibility:private"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "delegate_plugin",
+ hdrs = ["delegate_plugin.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/lite/c:common",
+ ],
+)
+
+cc_library(
+ name = "nnapi_plugin",
+ srcs = ["nnapi_plugin.cc"],
+ hdrs = ["nnapi_plugin.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":delegate_plugin",
+ "//tensorflow/lite/c:common",
+ "//tensorflow/lite/delegates/nnapi:nnapi_delegate",
+ "//tensorflow/lite/experimental/acceleration/configuration:configuration_fbs",
+ "//tensorflow/lite/experimental/acceleration/configuration:nnapi_plugin_impl",
+ ],
+)
diff --git a/tensorflow/lite/experimental/acceleration/configuration/c/delegate_plugin.h b/tensorflow/lite/experimental/acceleration/configuration/c/delegate_plugin.h
new file mode 100644
index 0000000..d5e6c3d
--- /dev/null
+++ b/tensorflow/lite/experimental/acceleration/configuration/c/delegate_plugin.h
@@ -0,0 +1,60 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_CONFIGURATION_C_DELEGATE_PLUGIN_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_CONFIGURATION_C_DELEGATE_PLUGIN_H_
+
+// C API types for TF Lite delegate plugins.
+
+#include "tensorflow/lite/c/common.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// Type of function to allocate and construct a delegate.
+// The tflite_settings parameter should be a pointer to a FlatBuffer table
+// object of type tflite::TFLiteSettings. (We use 'void *' here since this
+// is a C API so we don't want to directly reference C++ types such
+// as tflite::TFLiteSettings.)
+typedef TfLiteDelegate *TfLiteDelegatePluginCreateFunc(
+ const void *tflite_settings);
+
+// Type of function to destroy and deallocate a delegate.
+// The delegate argument must have been created with the corresponding
+// create function from the same delegate plugin.
+typedef void TfLiteDelegatePluginDestroyFunc(TfLiteDelegate *);
+
+// Type of function to return an error code for the last delegate operation.
+// The delegate argument must have been created with the corresponding
+// create function from the same delegate plugin.
+typedef int TfLiteDelegatePluginGetDelegateErrnoFunc(TfLiteDelegate *);
+
+// Struct to hold all the methods for a delegate plugin.
+typedef struct TfLiteDelegatePlugin {
+ // Function to allocate and construct a delegate.
+ TfLiteDelegatePluginCreateFunc *create;
+
+ // Function to deallocate a delegate.
+ TfLiteDelegatePluginDestroyFunc *destroy;
+
+ // Function to return an error code for the last delegate operation.
+ TfLiteDelegatePluginGetDelegateErrnoFunc *get_delegate_errno;
+} TfLiteDelegatePlugin;
+
+#ifdef __cplusplus
+}; // extern "C"
+#endif
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_CONFIGURATION_C_DELEGATE_PLUGIN_H_
diff --git a/tensorflow/lite/experimental/acceleration/configuration/c/nnapi_plugin.cc b/tensorflow/lite/experimental/acceleration/configuration/c/nnapi_plugin.cc
new file mode 100644
index 0000000..c9d94f5
--- /dev/null
+++ b/tensorflow/lite/experimental/acceleration/configuration/c/nnapi_plugin.cc
@@ -0,0 +1,57 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file implements the Delegate Plugin for the NNAPI Delegate.
+// It provides both
+
+#include "tensorflow/lite/experimental/acceleration/configuration/c/nnapi_plugin.h"
+
+#include <memory>
+
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
+#include "tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h"
+#include "tensorflow/lite/experimental/acceleration/configuration/nnapi_plugin.h"
+
+extern "C" {
+
+static TfLiteDelegate* CreateDelegate(const void* settings) {
+ const ::tflite::TFLiteSettings* tflite_settings =
+ static_cast<const ::tflite::TFLiteSettings*>(settings);
+ tflite::delegates::NnapiPlugin nnapi_plugin(*tflite_settings);
+ return new tflite::StatefulNnApiDelegate(nnapi_plugin.Options());
+}
+
+static void DestroyDelegate(TfLiteDelegate* delegate) {
+ delete static_cast<tflite::StatefulNnApiDelegate*>(delegate);
+}
+
+static int DelegateErrno(TfLiteDelegate* from_delegate) {
+ auto nnapi_delegate =
+ static_cast<tflite::StatefulNnApiDelegate*>(from_delegate);
+ return nnapi_delegate->GetNnApiErrno();
+}
+
+static constexpr TfLiteDelegatePlugin kPluginCApi{
+ CreateDelegate,
+ DestroyDelegate,
+ DelegateErrno,
+};
+
+const TfLiteDelegatePlugin* TfLiteNnapiDelegatePluginCApi() {
+ return &kPluginCApi;
+}
+
+} // extern "C"
diff --git a/tensorflow/lite/experimental/acceleration/configuration/c/nnapi_plugin.h b/tensorflow/lite/experimental/acceleration/configuration/c/nnapi_plugin.h
new file mode 100644
index 0000000..cef0b44
--- /dev/null
+++ b/tensorflow/lite/experimental/acceleration/configuration/c/nnapi_plugin.h
@@ -0,0 +1,43 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_CONFIGURATION_C_NNAPI_PLUGIN_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_CONFIGURATION_C_NNAPI_PLUGIN_H_
+
+// This header file is for the delegate plugin for NNAPI.
+//
+// For the C++ delegate plugin interface, the NNAPI delegate plugin is added to
+// the DelegatePluginRegistry by the side effect of a constructor for a static
+// object, so there's no public API needed for this plugin, other than the API
+// of tflite::delegates::DelegatePluginRegistry, which is declared in
+// delegate_registry.h.
+//
+// But to provide a C API to access the NNAPI delegate plugin, we do expose
+// some functions, which are declared below.
+
+#include "tensorflow/lite/experimental/acceleration/configuration/c/delegate_plugin.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// C API for the NNAPI delegate plugin.
+// Returns a pointer to a statically allocated table of function pointers.
+const TfLiteDelegatePlugin* TfLiteNnapiDelegatePluginCApi();
+
+#ifdef __cplusplus
+} // extern "C"
+#endif
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_CONFIGURATION_C_NNAPI_PLUGIN_H_
diff --git a/tensorflow/lite/experimental/acceleration/configuration/nnapi_plugin.cc b/tensorflow/lite/experimental/acceleration/configuration/nnapi_plugin.cc
index 30dda0d..fdda69a 100644
--- a/tensorflow/lite/experimental/acceleration/configuration/nnapi_plugin.cc
+++ b/tensorflow/lite/experimental/acceleration/configuration/nnapi_plugin.cc
@@ -12,98 +12,14 @@
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <memory>
-#include "absl/memory/memory.h"
-#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
-#include "tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h"
-#include "tensorflow/lite/experimental/acceleration/configuration/delegate_registry.h"
+// This file implements the TFLite Delegate Plugin for the NNAPI Delegate.
+
+#include "tensorflow/lite/experimental/acceleration/configuration/nnapi_plugin.h"
namespace tflite {
namespace delegates {
-inline tflite::StatefulNnApiDelegate::Options::ExecutionPreference
-ConvertExecutionPrefence(
- NNAPIExecutionPreference from_compatibility_preference) {
- using TflitePreference =
- tflite::StatefulNnApiDelegate::Options::ExecutionPreference;
- switch (from_compatibility_preference) {
- case NNAPIExecutionPreference_NNAPI_LOW_POWER:
- return TflitePreference::kLowPower;
- case NNAPIExecutionPreference_NNAPI_FAST_SINGLE_ANSWER:
- return TflitePreference::kFastSingleAnswer;
- case NNAPIExecutionPreference_NNAPI_SUSTAINED_SPEED:
- return TflitePreference::kSustainedSpeed;
- default:
- return TflitePreference::kUndefined;
- }
-}
-
-inline int ConvertExecutionPriority(
- NNAPIExecutionPriority from_compatibility_priority) {
- switch (from_compatibility_priority) {
- case NNAPIExecutionPriority_NNAPI_PRIORITY_LOW:
- return ANEURALNETWORKS_PRIORITY_LOW;
- case NNAPIExecutionPriority_NNAPI_PRIORITY_MEDIUM:
- return ANEURALNETWORKS_PRIORITY_MEDIUM;
- case NNAPIExecutionPriority_NNAPI_PRIORITY_HIGH:
- return ANEURALNETWORKS_PRIORITY_HIGH;
- default:
- return ANEURALNETWORKS_PRIORITY_DEFAULT;
- }
-}
-
-class NnapiPlugin : public DelegatePluginInterface {
- public:
- TfLiteDelegatePtr Create() override {
- auto nnapi_delegate =
- absl::make_unique<tflite::StatefulNnApiDelegate>(options_);
- return TfLiteDelegatePtr(
- nnapi_delegate.release(), [](TfLiteDelegate* delegate) {
- delete reinterpret_cast<tflite::StatefulNnApiDelegate*>(delegate);
- });
- }
- int GetDelegateErrno(TfLiteDelegate* from_delegate) override {
- auto nnapi_delegate =
- reinterpret_cast<tflite::StatefulNnApiDelegate*>(from_delegate);
- return nnapi_delegate->GetNnApiErrno();
- }
- static std::unique_ptr<NnapiPlugin> New(
- const TFLiteSettings& tflite_settings) {
- return absl::make_unique<NnapiPlugin>(tflite_settings);
- }
- explicit NnapiPlugin(const TFLiteSettings& tflite_settings) {
- const NNAPISettings* nnapi_settings = tflite_settings.nnapi_settings();
- if (!nnapi_settings) return;
- if (nnapi_settings->accelerator_name() &&
- nnapi_settings->accelerator_name()->Length() != 0) {
- accelerator_ = nnapi_settings->accelerator_name()->str();
- options_.accelerator_name = accelerator_.c_str();
- }
- if (nnapi_settings->cache_directory() &&
- nnapi_settings->cache_directory()->Length() != 0) {
- cache_dir_ = nnapi_settings->cache_directory()->str();
- options_.cache_dir = cache_dir_.c_str();
- }
- if (nnapi_settings->model_token() &&
- nnapi_settings->model_token()->Length() != 0) {
- model_token_ = nnapi_settings->model_token()->str();
- options_.model_token = model_token_.c_str();
- }
- options_.execution_preference =
- ConvertExecutionPrefence(nnapi_settings->execution_preference());
- options_.disallow_nnapi_cpu =
- !nnapi_settings->allow_nnapi_cpu_on_android_10_plus();
- options_.execution_priority =
- ConvertExecutionPriority(nnapi_settings->execution_priority());
- options_.allow_fp16 = nnapi_settings->allow_fp16_precision_for_fp32();
- }
-
- private:
- std::string accelerator_, cache_dir_, model_token_;
- tflite::StatefulNnApiDelegate::Options options_;
-};
-
TFLITE_REGISTER_DELEGATE_FACTORY_FUNCTION(NnapiPlugin, NnapiPlugin::New);
} // namespace delegates
diff --git a/tensorflow/lite/experimental/acceleration/configuration/nnapi_plugin.h b/tensorflow/lite/experimental/acceleration/configuration/nnapi_plugin.h
new file mode 100644
index 0000000..bf70a0e
--- /dev/null
+++ b/tensorflow/lite/experimental/acceleration/configuration/nnapi_plugin.h
@@ -0,0 +1,120 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_CONFIGURATION_NNAPI_PLUGIN_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_CONFIGURATION_NNAPI_PLUGIN_H_
+
+// This file provides the NNApiPlugin class, which implements the
+// TFLite Delegate Plugin for the NNAPI Delegate.
+
+#include <memory>
+#include <string>
+
+#include "absl/memory/memory.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
+#include "tensorflow/lite/experimental/acceleration/configuration/c/delegate_plugin.h"
+#include "tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h"
+#include "tensorflow/lite/experimental/acceleration/configuration/delegate_registry.h"
+
+namespace tflite {
+namespace delegates {
+
+class NnapiPlugin : public DelegatePluginInterface {
+ public:
+ TfLiteDelegatePtr Create() override {
+ auto nnapi_delegate =
+ absl::make_unique<tflite::StatefulNnApiDelegate>(options_);
+ return TfLiteDelegatePtr(
+ nnapi_delegate.release(), [](TfLiteDelegate* delegate) {
+ delete static_cast<tflite::StatefulNnApiDelegate*>(delegate);
+ });
+ }
+ int GetDelegateErrno(TfLiteDelegate* from_delegate) override {
+ auto nnapi_delegate =
+ static_cast<tflite::StatefulNnApiDelegate*>(from_delegate);
+ return nnapi_delegate->GetNnApiErrno();
+ }
+ static std::unique_ptr<NnapiPlugin> New(
+ const TFLiteSettings& tflite_settings) {
+ return absl::make_unique<NnapiPlugin>(tflite_settings);
+ }
+ explicit NnapiPlugin(const TFLiteSettings& tflite_settings) {
+ const NNAPISettings* nnapi_settings = tflite_settings.nnapi_settings();
+ if (!nnapi_settings) return;
+ if (nnapi_settings->accelerator_name() &&
+ nnapi_settings->accelerator_name()->Length() != 0) {
+ accelerator_ = nnapi_settings->accelerator_name()->str();
+ options_.accelerator_name = accelerator_.c_str();
+ }
+ if (nnapi_settings->cache_directory() &&
+ nnapi_settings->cache_directory()->Length() != 0) {
+ cache_dir_ = nnapi_settings->cache_directory()->str();
+ options_.cache_dir = cache_dir_.c_str();
+ }
+ if (nnapi_settings->model_token() &&
+ nnapi_settings->model_token()->Length() != 0) {
+ model_token_ = nnapi_settings->model_token()->str();
+ options_.model_token = model_token_.c_str();
+ }
+ options_.execution_preference =
+ ConvertExecutionPrefence(nnapi_settings->execution_preference());
+ options_.disallow_nnapi_cpu =
+ !nnapi_settings->allow_nnapi_cpu_on_android_10_plus();
+ options_.execution_priority =
+ ConvertExecutionPriority(nnapi_settings->execution_priority());
+ options_.allow_fp16 = nnapi_settings->allow_fp16_precision_for_fp32();
+ }
+ const tflite::StatefulNnApiDelegate::Options& Options() { return options_; }
+
+ private:
+ static inline tflite::StatefulNnApiDelegate::Options::ExecutionPreference
+ ConvertExecutionPrefence(
+ NNAPIExecutionPreference from_compatibility_preference) {
+ using TflitePreference =
+ tflite::StatefulNnApiDelegate::Options::ExecutionPreference;
+ switch (from_compatibility_preference) {
+ case NNAPIExecutionPreference_NNAPI_LOW_POWER:
+ return TflitePreference::kLowPower;
+ case NNAPIExecutionPreference_NNAPI_FAST_SINGLE_ANSWER:
+ return TflitePreference::kFastSingleAnswer;
+ case NNAPIExecutionPreference_NNAPI_SUSTAINED_SPEED:
+ return TflitePreference::kSustainedSpeed;
+ default:
+ return TflitePreference::kUndefined;
+ }
+ }
+
+ static inline int ConvertExecutionPriority(
+ NNAPIExecutionPriority from_compatibility_priority) {
+ switch (from_compatibility_priority) {
+ case NNAPIExecutionPriority_NNAPI_PRIORITY_LOW:
+ return ANEURALNETWORKS_PRIORITY_LOW;
+ case NNAPIExecutionPriority_NNAPI_PRIORITY_MEDIUM:
+ return ANEURALNETWORKS_PRIORITY_MEDIUM;
+ case NNAPIExecutionPriority_NNAPI_PRIORITY_HIGH:
+ return ANEURALNETWORKS_PRIORITY_HIGH;
+ default:
+ return ANEURALNETWORKS_PRIORITY_DEFAULT;
+ }
+ }
+
+ std::string accelerator_, cache_dir_, model_token_;
+ tflite::StatefulNnApiDelegate::Options options_;
+};
+
+} // namespace delegates
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_CONFIGURATION_NNAPI_PLUGIN_H_
diff --git a/tensorflow/lite/experimental/microfrontend/BUILD b/tensorflow/lite/experimental/microfrontend/BUILD
index 3f34902..9fc6cac 100644
--- a/tensorflow/lite/experimental/microfrontend/BUILD
+++ b/tensorflow/lite/experimental/microfrontend/BUILD
@@ -2,9 +2,11 @@
load(
"//tensorflow:tensorflow.bzl",
+ "tf_copts",
"tf_custom_op_library",
"tf_gen_op_libs",
"tf_gen_op_wrapper_py",
+ "tf_opts_nortti_if_android",
"tf_py_test",
)
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
@@ -30,11 +32,31 @@
cc_library(
name = "audio_microfrontend_op_lib",
srcs = ["ops/audio_microfrontend_op.cc"],
+ copts = tf_copts(android_optimization_level_override = None) + tf_opts_nortti_if_android() + [
+ "-Wno-narrowing",
+ "-Wno-sign-compare",
+ "-Wno-overloaded-virtual",
+ ] + select({
+ "//tensorflow:android": [
+ # Selective registration uses constexprs with recursive
+ # string comparisons; that can lead to compiler errors, so
+ # we increase the constexpr recursion depth.
+ "-fconstexpr-depth=1024",
+ "-Oz",
+ ],
+ "//conditions:default": [],
+ }),
deps = [
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
"//tensorflow/lite/experimental/microfrontend/lib:frontend",
- ],
+ ] + select({
+ "//tensorflow:android": [
+ "//tensorflow/core:portable_tensorflow_lib_lite",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ ],
+ }),
alwayslink = 1,
)
diff --git a/tensorflow/lite/experimental/microfrontend/lib/BUILD b/tensorflow/lite/experimental/microfrontend/lib/BUILD
index 57f8055..1a75b40 100644
--- a/tensorflow/lite/experimental/microfrontend/lib/BUILD
+++ b/tensorflow/lite/experimental/microfrontend/lib/BUILD
@@ -1,10 +1,4 @@
# Library for generating feature vectors from audio data
-
-load(
- "//tensorflow/lite/micro/testing:micro_test.bzl",
- "tflite_micro_cc_test",
-)
-
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
@@ -123,7 +117,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "fft_test",
srcs = ["fft_test.cc"],
deps = [
@@ -132,7 +126,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "filterbank_test",
srcs = ["filterbank_test.cc"],
# Setting copts for experimental code to [], but this code should be fixed
@@ -144,7 +138,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "frontend_test",
srcs = ["frontend_test.cc"],
# Setting copts for experimental code to [], but this code should be fixed
@@ -156,7 +150,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "log_scale_test",
srcs = ["log_scale_test.cc"],
# Setting copts for experimental code to [], but this code should be fixed
@@ -168,7 +162,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "noise_reduction_test",
srcs = ["noise_reduction_test.cc"],
# Setting copts for experimental code to [], but this code should be fixed
@@ -180,7 +174,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "pcan_gain_control_test",
srcs = ["pcan_gain_control_test.cc"],
# Setting copts for experimental code to [], but this code should be fixed
@@ -192,7 +186,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "window_test",
srcs = ["window_test.cc"],
# Setting copts for experimental code to [], but this code should be fixed
diff --git a/tensorflow/lite/g3doc/_book.yaml b/tensorflow/lite/g3doc/_book.yaml
index 815e9bb..faa1a53 100644
--- a/tensorflow/lite/g3doc/_book.yaml
+++ b/tensorflow/lite/g3doc/_book.yaml
@@ -187,7 +187,9 @@
section:
- title: "Cross compilation for ARM"
path: /lite/guide/build_cmake_arm
-
+ - title: "Build Python Wheel"
+ path: /lite/guide/build_cmake_pip
+ status: experimental
- title: "Reduce binary size"
path: /lite/guide/reduce_binary_size
status: experimental
diff --git a/tensorflow/lite/g3doc/guide/build_cmake_pip.md b/tensorflow/lite/g3doc/guide/build_cmake_pip.md
new file mode 100644
index 0000000..d61ac50
--- /dev/null
+++ b/tensorflow/lite/g3doc/guide/build_cmake_pip.md
@@ -0,0 +1,103 @@
+# Build TensorFlow Lite Python Wheel Package
+
+This page describes how to build the TensorFlow Lite `tflite_runtime` Python
+library for x86_64 and various ARM devices.
+
+The following instructions have been tested on Ubuntu 16.04.3 64-bit PC (AMD64)
+, TensorFlow devel Docker image
+[tensorflow/tensorflow:devel](https://hub.docker.com/r/tensorflow/tensorflow/tags/).
+
+**Note:** This feature is currently experimental and available since version 2.4
+and may change.
+
+#### Prerequisites
+
+You need CMake installed and a copy of the TensorFlow source code. Please check
+[Build TensorFlow Lite with CMake](https://www.tensorflow.org/lite/guide/build_cmake)
+page for the details.
+
+To build the PIP package for your workstation, you can run the following
+commands.
+
+```sh
+PYTHON=python3 tensorflow/lite/tools/pip_package/build_pip_package_with_cmake.sh native
+```
+
+**Note:** If you have multiple Python interpreters available, specify the exact
+Python version with `PYTHON` variable. (Currently, it supports Python 3.5 or
+higher)
+
+## ARM cross compilation
+
+For ARM cross compilation, it's recommanded to use Docker since it makes easier
+to setup cross build environment. Also you needs a `target` option to figure out
+the target architecture.
+
+With the `container` name and the `target` name, you can run the build command
+as followings.
+
+```sh
+tensorflow/tools/ci_build/ci_build.sh <container> \
+ tensorflow/lite/tools/pip_package/build_pip_package_with_cmake.sh <target>
+```
+
+### Available Docker containers
+
+You need to select ARM cross build container for your target Python interpreter
+version. Here is the list of supported containers.
+
+Conainter | Supported Python version
+----------- | ------------------------
+PI | Python 3.5
+PI-PYTHON37 | Python 3.7
+PI-PYTHON38 | Python 3.8
+
+### Available target names
+
+`tensorflow/lite/tools/pip_package/build_pip_package_with_cmake.sh` script needs
+a target name to figure out target architecture. Here is the list of supported
+targets.
+
+Target | Target architecture | Comments
+--------- | -------------------- | --------
+armhf | ARMv7 VFP with Neon | Compatibile with Raspberry Pi 3 and 4
+rpi0 | ARMv6 | Compatibile with Raspberry Pi Zero
+aarch64 | aarch64 (ARM 64-bit) | [Coral Mendel Linux 4.0](https://coral.ai/) <br/> Raspberry Pi with [Ubuntu Server 20.04.01 LTS 64-bit](https://ubuntu.com/download/raspberry-pi)
+native | Your workstation | It builds with "-mnative" optimization
+<default> | Your workstation | Default target
+
+### Build examples
+
+Here are some example commands you can use.
+
+#### armhf target for Python 3.7
+
+```sh
+tensorflow/tools/ci_build/ci_build.sh PI-PYTHON37 \
+ tensorflow/lite/tools/pip_package/build_pip_package_with_cmake.sh armhf
+```
+
+#### aarch64 target for Python 3.8
+
+```sh
+tensorflow/tools/ci_build/ci_build.sh PI-PYTHON38 \
+ tensorflow/lite/tools/pip_package/build_pip_package_with_cmake.sh aarch64
+```
+
+#### How to use a custom toolchain?
+
+If the generated binaries are not compatibile with your target, you need to use
+your own toolchain or provide custom build flags. (Check
+[this](https://www.tensorflow.org/lite/guide/build_cmake_arm#check_your_target_environment)
+to understand your target environment) In that case, you need to modify
+`tensorflow/lite/tools/cmake/download_toolchains.sh` to use your own toolchain.
+The toolchain script defines the following two variables for the
+`build_pip_package_with_cmake.sh` script.
+
+Variable | Purpose | example
+------------ | ------------------------ | -------------------------------
+ARMCC_PREFIX | defines toolchain prefix | arm-linux-gnueabihf-
+ARMCC_FLAGS | compilation flags | -march=armv7-a -mfpu=neon-vfpv4
+
+**Note:** ARMCC_FLAGS might need to contain Python library include path. See the
+`download_toolchains.sh` for the reference.
diff --git a/tensorflow/lite/g3doc/guide/build_ios.md b/tensorflow/lite/g3doc/guide/build_ios.md
index efa70a9..4c21d24 100644
--- a/tensorflow/lite/g3doc/guide/build_ios.md
+++ b/tensorflow/lite/g3doc/guide/build_ios.md
@@ -10,8 +10,9 @@
In some cases, you might wish to use a local build of TensorFlow Lite, for
example when you want to make local changes to TensorFlow Lite and test those
-changes in your iOS app. To create a universal iOS framework for TensorFlow Lite
-locally, you need to build it using Bazel on a macOS machine.
+changes in your iOS app or you prefer using static framework to our provided
+dynamic one. To create a universal iOS framework for TensorFlow Lite locally,
+you need to build it using Bazel on a macOS machine.
### Install Xcode
@@ -42,7 +43,7 @@
answer "Yes" when the script asks if you wish to build TensorFlow with iOS
support.
-### Build TensorFlowLiteC framework
+### Build TensorFlowLiteC dynamic framework (recommended)
Note: This step is not necessary if (1) you are using Bazel for your app, or (2)
you only want to test local changes to the Swift or Objective-C APIs. In these
@@ -64,6 +65,22 @@
you specify `--config=ios_fat`, please refer to the iOS configs section in the
[`.bazelrc` file][bazelrc].
+### Build TensorFlowLiteC static framework
+
+By default, we only distribute the dynamic framework via Cocoapods. If you want
+to use the static framework instead, you can build the `TensorFlowLiteC` static
+framework with the following command:
+
+```
+bazel build --config=ios_fat -c opt \
+ //tensorflow/lite/ios:TensorFlowLiteC_static_framework
+```
+
+The command will generate a file named `TensorFlowLiteC_static_framework.zip`
+under `bazel-bin/tensorflow/lite/ios/` directory under your TensorFlow root
+directory. This static framework can be used in the exact same way as the
+dynamic one.
+
## Use in your own application
### CocoaPods developers
diff --git a/tensorflow/lite/g3doc/guide/op_select_allowlist.md b/tensorflow/lite/g3doc/guide/op_select_allowlist.md
index b9f3ac8..6051b5b 100644
--- a/tensorflow/lite/g3doc/guide/op_select_allowlist.md
+++ b/tensorflow/lite/g3doc/guide/op_select_allowlist.md
@@ -556,6 +556,7 @@
* `raw_ops.StridedSlice`
* `raw_ops.StridedSliceAssign`
* `raw_ops.StridedSliceGrad`
+* `raw_ops.StringFormat`
* `raw_ops.StringJoin`
* `raw_ops.StringLength`
* `raw_ops.StringLower`
diff --git a/tensorflow/lite/g3doc/models/pose_estimation/overview.md b/tensorflow/lite/g3doc/models/pose_estimation/overview.md
index 8aaee70..0629346 100644
--- a/tensorflow/lite/g3doc/models/pose_estimation/overview.md
+++ b/tensorflow/lite/g3doc/models/pose_estimation/overview.md
@@ -199,8 +199,7 @@
* Check out this
[blog post](https://medium.com/tensorflow/real-time-human-pose-estimation-in-the-browser-with-tensorflow-js-7dd0bc881cd5)
to learn more about pose estimation using TensorFlow JS.
-* Read the PoseNet paper
- [here](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/Kendall_PoseNet_A_Convolutional_ICCV_2015_paper.pdf)
+* Read the PoseNet paper [here](https://arxiv.org/abs/1803.08225)
Also, check out these use cases of pose estimation.
diff --git a/tensorflow/lite/interpreter_test.cc b/tensorflow/lite/interpreter_test.cc
index 010beeb..0b5b6c2 100644
--- a/tensorflow/lite/interpreter_test.cc
+++ b/tensorflow/lite/interpreter_test.cc
@@ -189,9 +189,10 @@
TfLiteType type;
size_t size;
} cases[] = {
- {kTfLiteFloat32, sizeof(float)}, {kTfLiteInt32, sizeof(int32_t)},
- {kTfLiteUInt8, sizeof(uint8_t)}, {kTfLiteInt64, sizeof(int64_t)},
- {kTfLiteInt16, sizeof(int16_t)}, {kTfLiteFloat16, sizeof(TfLiteFloat16)},
+ {kTfLiteFloat32, sizeof(float)}, {kTfLiteInt32, sizeof(int32_t)},
+ {kTfLiteUInt32, sizeof(uint32_t)}, {kTfLiteUInt8, sizeof(uint8_t)},
+ {kTfLiteInt64, sizeof(int64_t)}, {kTfLiteInt16, sizeof(int16_t)},
+ {kTfLiteFloat16, sizeof(TfLiteFloat16)},
};
for (auto test : cases) {
@@ -261,6 +262,7 @@
TEST(BasicInterpreter, CheckResize) {
const float floats[] = {-3., -4.};
const int32_t int32s[] = {-3, -4};
+ const uint32_t uint32s[] = {3, 4};
const uint8_t uint8s[] = {3, 4};
const int64_t int64s[] = {6, -7};
const int16_t int16s[] = {8, -9};
@@ -274,6 +276,7 @@
} cases[] = {
{kTfLiteFloat32, sizeof(float), reinterpret_cast<const char*>(floats)},
{kTfLiteInt32, sizeof(int32_t), reinterpret_cast<const char*>(int32s)},
+ {kTfLiteUInt32, sizeof(uint32_t), reinterpret_cast<const char*>(uint32s)},
{kTfLiteUInt8, sizeof(uint8_t), reinterpret_cast<const char*>(uint8s)},
{kTfLiteInt64, sizeof(int64_t), reinterpret_cast<const char*>(int64s)},
{kTfLiteInt16, sizeof(int16_t), reinterpret_cast<const char*>(int16s)},
@@ -313,8 +316,9 @@
TEST(BasicInterpreter, CheckAlignment) {
struct {
TfLiteType type;
- } cases[] = {{kTfLiteFloat32}, {kTfLiteInt32}, {kTfLiteUInt8},
- {kTfLiteInt64}, {kTfLiteInt16}, {kTfLiteFloat16}};
+ } cases[] = {{kTfLiteFloat32}, {kTfLiteInt32}, {kTfLiteUInt32},
+ {kTfLiteUInt8}, {kTfLiteInt64}, {kTfLiteInt16},
+ {kTfLiteFloat16}};
for (auto test : cases) {
Interpreter interpreter;
diff --git a/tensorflow/lite/ios/BUILD.apple b/tensorflow/lite/ios/BUILD.apple
index 51590cf..dc150d1 100644
--- a/tensorflow/lite/ios/BUILD.apple
+++ b/tensorflow/lite/ios/BUILD.apple
@@ -5,7 +5,7 @@
"//tensorflow/lite/ios:ios.bzl",
"TFL_MINIMUM_OS_VERSION",
"strip_common_include_path_prefix",
- "tflite_ios_static_framework",
+ "tflite_ios_framework",
)
load("@build_bazel_rules_apple//apple:ios.bzl", "ios_static_framework")
@@ -99,7 +99,7 @@
)
# bazel build -c opt --config=ios_fat //tensorflow/lite/ios:TensorFlowLiteC_framework
-tflite_ios_static_framework(
+tflite_ios_framework(
name = "TensorFlowLiteC_framework",
hdrs = [
":c_api.h",
@@ -115,6 +115,23 @@
],
)
+# Similar to TensorFlowLiteC_framework but this is a static framework and symbol
+# hiding is not applied. Note both have the same bundle name.
+ios_static_framework(
+ name = "TensorFlowLiteC_static_framework",
+ hdrs = [
+ ":c_api.h",
+ ":common.h",
+ ":xnnpack_delegate.h",
+ "//tensorflow/lite/c:c_api_types.h",
+ ],
+ bundle_name = "TensorFlowLiteC",
+ minimum_os_version = TFL_MINIMUM_OS_VERSION,
+ deps = [
+ ":tensorflow_lite_c",
+ ],
+)
+
# This target builds the flex delegate as a separate static framework, which
# does not include the TensorFlow Lite runtime. As this target does not contain
# TensorFlow Lite runtime, it is intended to be linked along with the
@@ -140,7 +157,7 @@
# TensorFlowLiteC framework above in a composable way.
#
# bazel build -c opt --config=ios_fat //tensorflow/lite/ios:TensorFlowLiteCCoreML_framework
-tflite_ios_static_framework(
+tflite_ios_framework(
name = "TensorFlowLiteCCoreML_framework",
hdrs = [
":coreml_delegate.h",
@@ -159,7 +176,7 @@
# TensorFlowLiteC framework above in a composable way.
#
# bazel build -c opt --config=ios_fat //tensorflow/lite/ios:TensorFlowLiteCMetal_framework
-tflite_ios_static_framework(
+tflite_ios_framework(
name = "TensorFlowLiteCMetal_framework",
hdrs = [
":metal_delegate.h",
@@ -208,5 +225,6 @@
":TensorFlowLiteCMetal_framework",
":TensorFlowLiteC_framework",
":TensorFlowLiteSelectTfOps_framework",
+ ":TensorFlowLiteC_static_framework",
],
)
diff --git a/tensorflow/lite/ios/ios.bzl b/tensorflow/lite/ios/ios.bzl
index acb9cab..41c13f8 100644
--- a/tensorflow/lite/ios/ios.bzl
+++ b/tensorflow/lite/ios/ios.bzl
@@ -17,17 +17,17 @@
"notsan",
]
-# iOS static framework with symbol allowlist. Exported C++ symbols might cause
-# symbol collision with other libraries. List of symbols to allowlist can be
+# iOS framework with symbol allowlist. Exported C++ symbols might cause symbol
+# collision with other libraries. List of symbols to allowlist can be
# generated by running `nm -m -g FRAMEWORK_LIBRARY | grep _TfLite` for framework
# built with `ios_static_framework` rule.
-def tflite_ios_static_framework(
+def tflite_ios_framework(
name,
bundle_name,
allowlist_symbols_file,
exclude_resources = True,
**kwargs):
- """TFLite variant of ios_static_framework with symbol hiding.
+ """Apply symbol hiding to the output of ios_static_framework.
Args:
name: The name of the target.
diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Delegate.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Delegate.java
index 5a57734..5eec1fa 100644
--- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Delegate.java
+++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Delegate.java
@@ -33,7 +33,8 @@
* <p>Note: The Java {@link Delegate} maintains ownership of the native delegate instance, and
* must ensure its existence for the duration of usage with any {@link Interpreter}.
*
- * @return The native delegate handle.
+ * @return The native delegate handle. In C/C++, this should be a pointer to
+ * 'TfLiteOpaqueDelegate'.
*/
public long getNativeHandle();
}
diff --git a/tensorflow/lite/java/src/main/native/BUILD b/tensorflow/lite/java/src/main/native/BUILD
index c92aa16..97ca710 100644
--- a/tensorflow/lite/java/src/main/native/BUILD
+++ b/tensorflow/lite/java/src/main/native/BUILD
@@ -26,12 +26,14 @@
"-ldl",
],
deps = [
+ "//tensorflow/lite:minimal_logging",
"//tensorflow/lite:op_resolver",
"//tensorflow/lite:schema_fbs_version",
"//tensorflow/lite:string_util",
"//tensorflow/lite:util",
"//tensorflow/lite/core/shims:common",
"//tensorflow/lite/core/shims:framework",
+ "//tensorflow/lite/core/shims:jni_initialization",
"//tensorflow/lite/delegates/xnnpack:xnnpack_delegate_hdrs_only",
"//tensorflow/lite/java/jni",
],
diff --git a/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
index 840985b..302e15c 100644
--- a/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
+++ b/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
@@ -19,6 +19,7 @@
#include <time.h>
#include <atomic>
+#include <map>
#include <vector>
#include "tensorflow/lite/core/shims/c/common.h"
@@ -69,13 +70,13 @@
return reinterpret_cast<BufferErrorReporter*>(handle);
}
-TfLiteDelegate* convertLongToDelegate(JNIEnv* env, jlong handle) {
+TfLiteOpaqueDelegate* convertLongToDelegate(JNIEnv* env, jlong handle) {
if (handle == 0) {
ThrowException(env, tflite::jni::kIllegalArgumentException,
"Internal error: Invalid handle to delegate.");
return nullptr;
}
- return reinterpret_cast<TfLiteDelegate*>(handle);
+ return reinterpret_cast<TfLiteOpaqueDelegate*>(handle);
}
std::vector<int> convertJIntArrayToVector(JNIEnv* env, jintArray inputs) {
@@ -162,8 +163,7 @@
// from either inputs or outputs.
// Returns -1 if invalid names are passed.
int GetTensorIndexForSignature(JNIEnv* env, jstring signature_tensor_name,
- jstring method_name,
- tflite::Interpreter* interpreter,
+ jstring method_name, Interpreter* interpreter,
bool is_input) {
// Fetch name strings.
const char* method_name_ptr = env->GetStringUTFChars(method_name, nullptr);
@@ -271,7 +271,7 @@
JNIEXPORT jobjectArray JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_getSignatureDefNames(
JNIEnv* env, jclass clazz, jlong handle) {
- tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ Interpreter* interpreter = convertLongToInterpreter(env, handle);
if (interpreter == nullptr) return nullptr;
jclass string_class = env->FindClass("java/lang/String");
if (string_class == nullptr) {
@@ -293,7 +293,7 @@
JNIEXPORT jobjectArray JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_getSignatureInputs(
JNIEnv* env, jclass clazz, jlong handle, jstring method_name) {
- tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ Interpreter* interpreter = convertLongToInterpreter(env, handle);
if (interpreter == nullptr) return nullptr;
const char* method_name_ptr = env->GetStringUTFChars(method_name, nullptr);
const jobjectArray signature_inputs = GetSignatureInputsOutputsList(
@@ -306,7 +306,7 @@
JNIEXPORT jobjectArray JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_getSignatureOutputs(
JNIEnv* env, jclass clazz, jlong handle, jstring method_name) {
- tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ Interpreter* interpreter = convertLongToInterpreter(env, handle);
if (interpreter == nullptr) return nullptr;
const char* method_name_ptr = env->GetStringUTFChars(method_name, nullptr);
const jobjectArray signature_outputs = GetSignatureInputsOutputsList(
@@ -320,7 +320,7 @@
Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensorIndexFromSignature(
JNIEnv* env, jclass clazz, jlong handle, jstring signature_input_name,
jstring method_name) {
- tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ Interpreter* interpreter = convertLongToInterpreter(env, handle);
if (interpreter == nullptr) return -1;
return GetTensorIndexForSignature(env, signature_input_name, method_name,
interpreter, /*is_input=*/true);
@@ -330,7 +330,7 @@
Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensorIndexFromSignature(
JNIEnv* env, jclass clazz, jlong handle, jstring signature_output_name,
jstring method_name) {
- tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ Interpreter* interpreter = convertLongToInterpreter(env, handle);
if (interpreter == nullptr) return -1;
return GetTensorIndexForSignature(env, signature_output_name, method_name,
interpreter, /*is_input=*/false);
@@ -646,7 +646,7 @@
if (is_changed) {
TfLiteStatus status;
if (strict) {
- status = interpreter->ResizeInputTensorStrict(
+ status = interpreter->ResizeInputTensorStrict(
tensor_idx, convertJIntArrayToVector(env, dims));
} else {
status = interpreter->ResizeInputTensor(
@@ -673,7 +673,7 @@
convertLongToErrorReporter(env, error_handle);
if (error_reporter == nullptr) return;
- TfLiteDelegate* delegate = convertLongToDelegate(env, delegate_handle);
+ TfLiteOpaqueDelegate* delegate = convertLongToDelegate(env, delegate_handle);
if (delegate == nullptr) return;
TfLiteStatus status = interpreter->ModifyGraphWithDelegate(delegate);
@@ -709,6 +709,7 @@
if (interpreter == nullptr) {
ThrowException(env, tflite::jni::kIllegalArgumentException,
"Internal error: Invalid handle to interpreter.");
+ return 0;
}
std::atomic_bool* cancellation_flag = new std::atomic_bool(false);
interpreter->SetCancellationFunction(cancellation_flag, [](void* payload) {
diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterMobileNetTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterMobileNetTest.java
index 80b3bf3..79326a5 100644
--- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterMobileNetTest.java
+++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterMobileNetTest.java
@@ -38,11 +38,11 @@
private static final ByteBuffer MOBILENET_FLOAT_MODEL_BUFFER =
TestUtils.getTestFileAsBuffer(
- "third_party/tensorflow/lite/java/demo/app/src/main/assets/mobilenet_v1_1.0_224.tflite");
+ "tensorflow/lite/java/demo/app/src/main/assets/mobilenet_v1_1.0_224.tflite");
private static final ByteBuffer MOBILENET_QUANTIZED_MODEL_BUFFER =
TestUtils.getTestFileAsBuffer(
- "third_party/tensorflow/lite/java/demo/app/src/main/assets/mobilenet_v1_1.0_224_quant.tflite");
+ "tensorflow/lite/java/demo/app/src/main/assets/mobilenet_v1_1.0_224_quant.tflite");
@Test
public void testMobileNet() {
diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/gpu/GpuDelegateTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/gpu/GpuDelegateTest.java
index de320fd..66bc35b 100644
--- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/gpu/GpuDelegateTest.java
+++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/gpu/GpuDelegateTest.java
@@ -39,7 +39,7 @@
private static final ByteBuffer MODEL_BUFFER = TestUtils.getTestFileAsBuffer(MODEL_PATH);
private static final ByteBuffer MOBILENET_QUANTIZED_MODEL_BUFFER =
TestUtils.getTestFileAsBuffer(
- "third_party/tensorflow/lite/java/demo/app/src/main/assets/mobilenet_v1_1.0_224_quant.tflite");
+ "tensorflow/lite/java/demo/app/src/main/assets/mobilenet_v1_1.0_224_quant.tflite");
@Test
public void testBasic() throws Exception {
diff --git a/tensorflow/lite/kernels/activations_test.cc b/tensorflow/lite/kernels/activations_test.cc
index 03795c5..06a4bdc 100644
--- a/tensorflow/lite/kernels/activations_test.cc
+++ b/tensorflow/lite/kernels/activations_test.cc
@@ -1884,6 +1884,26 @@
{.09752, .05352, .11911, .14548, .13164, .07984, .26509, .10778})));
}
+TEST(FloatActivationsOpTest, Softmax1DMax) {
+ FloatActivationsOpModel m(0.1f, {TensorType_FLOAT32, {8}},
+ TensorType_FLOAT32);
+ m.SetInput({std::numeric_limits<float>::max(), -6, 2, 4, 3, -2, 10, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear({1, 0, 0, 0, 0, 0, 0, 0})));
+}
+
+TEST(FloatActivationsOpTest, Softmax1DInf) {
+ FloatActivationsOpModel m(0.1f, {TensorType_FLOAT32, {8}},
+ TensorType_FLOAT32);
+ m.SetInput({std::numeric_limits<float>::infinity(), -6, 2, 4, 3, -2, 10, 1});
+ m.Invoke();
+ auto output = m.GetOutput();
+ for (int i = 0; i < 8; ++i) {
+ EXPECT_TRUE(isnan(output[i]));
+ }
+}
+
TEST(QuantizedActivationsOpTest, Softmax1DUint8) {
QuantizedActivationsOpModel m(0.1f, {TensorType_UINT8, {8}, -10, 10},
TensorType_UINT8);
diff --git a/tensorflow/lite/kernels/detection_postprocess.cc b/tensorflow/lite/kernels/detection_postprocess.cc
index bbd4708..d5c3aa9 100644
--- a/tensorflow/lite/kernels/detection_postprocess.cc
+++ b/tensorflow/lite/kernels/detection_postprocess.cc
@@ -818,6 +818,13 @@
return &r;
}
+// Since the op is named "TFLite_Detection_PostProcess", the selective build
+// tool will assume the register function is named
+// "Register_TFLITE_DETECTION_POST_PROCESS".
+TfLiteRegistration* Register_TFLITE_DETECTION_POST_PROCESS() {
+ return Register_DETECTION_POSTPROCESS();
+}
+
} // namespace custom
} // namespace ops
} // namespace tflite
diff --git a/tensorflow/lite/kernels/fully_connected.cc b/tensorflow/lite/kernels/fully_connected.cc
index 51cc86c..2a5041c 100644
--- a/tensorflow/lite/kernels/fully_connected.cc
+++ b/tensorflow/lite/kernels/fully_connected.cc
@@ -431,62 +431,77 @@
return kTfLiteOk;
}
-TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
- TfLiteFullyConnectedParams* params, OpData* data,
- const TfLiteTensor* input, const TfLiteTensor* filter,
- const TfLiteTensor* bias, TfLiteTensor* input_quantized,
- TfLiteTensor* scaling_factors,
- TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
- TfLiteTensor* input_offsets, TfLiteTensor* output) {
- int total_input_size = 1;
- for (int i = 0; i < input->dims->size; i++) {
- total_input_size *= input->dims->data[i];
- }
+void EvalHybridImpl(TfLiteContext* context, TfLiteNode* node,
+ TfLiteFullyConnectedParams* params, OpData* data,
+ const TfLiteTensor* input, const TfLiteTensor* filter,
+ const TfLiteTensor* bias, int thread_start, int thread_end,
+ TfLiteTensor* input_quantized,
+ TfLiteTensor* scaling_factors, TfLiteTensor* accum_scratch,
+ TfLiteTensor* row_sums, TfLiteTensor* input_offsets,
+ TfLiteTensor* output) {
+ ruy::profiler::ScopeLabel label("FullyConnected");
+ ruy::profiler::ScopeLabel inner_label("Hybrid Kernel");
+ const auto& input_shape = GetTensorShape(input);
+ const auto& output_shape = GetTensorShape(output);
+ const auto& filter_shape = GetTensorShape(filter);
+ const int input_dims_count = input_shape.DimensionsCount();
+ const int output_dims_count = output_shape.DimensionsCount();
+ const int filter_dims_count = filter_shape.DimensionsCount();
+ const int batch_size = thread_end - thread_start;
+ const int input_depth = MatchingDim(filter_shape, filter_dims_count - 1,
+ input_shape, input_dims_count - 1);
+ const int output_depth = MatchingDim(filter_shape, filter_dims_count - 2,
+ output_shape, output_dims_count - 1);
+ const int per_thread_input_size = batch_size * input_depth;
- const int input_size = filter->dims->data[1];
- const int batch_size = total_input_size / filter->dims->data[1];
- const int num_units = filter->dims->data[0];
const bool is_sparse = filter->sparsity != nullptr;
+ const float* per_thread_input =
+ GetTensorData<float>(input) + thread_start * input_depth;
+ float* per_thread_output =
+ GetTensorData<float>(output) + thread_start * output_depth;
+
// Output = bias if bias tensor exists.
if (bias) {
- tensor_utils::VectorBatchVectorAssign(GetTensorData<float>(bias), num_units,
- batch_size,
- GetTensorData<float>(output));
+ tensor_utils::VectorBatchVectorAssign(GetTensorData<float>(bias),
+ output_depth, batch_size,
+ per_thread_output);
} else {
- std::fill_n(GetTensorData<float>(output), batch_size * num_units, 0.0f);
+ std::fill_n(per_thread_output, batch_size * output_depth, 0.0f);
}
// Save matrix multiplication computation for all zero input.
- if (tensor_utils::IsZeroVector(GetTensorData<float>(input),
- total_input_size)) {
+ if (tensor_utils::IsZeroVector(per_thread_input, per_thread_input_size)) {
tensor_utils::ApplyActivationToVector(
- GetTensorData<float>(output), batch_size * num_units,
- params->activation, GetTensorData<float>(output));
- return kTfLiteOk;
+ per_thread_output, batch_size * output_depth, params->activation,
+ per_thread_output);
+ return;
}
// Quantize input from float to uint8 + quantization params (scaling factor).
- float* scaling_factors_ptr = GetTensorData<float>(scaling_factors);
+ float* scaling_factors_ptr =
+ GetTensorData<float>(scaling_factors) + thread_start;
int32_t* input_offset_ptr = nullptr;
int32_t* row_sums_ptr = nullptr;
if (params->asymmetric_quantize_inputs) {
- input_offset_ptr = GetTensorData<int32_t>(input_offsets);
+ input_offset_ptr = GetTensorData<int32_t>(input_offsets) + thread_start;
row_sums_ptr = GetTensorData<int32_t>(row_sums);
}
- int8_t* quant_data = GetTensorData<int8_t>(input_quantized);
+ int8_t* quant_data =
+ GetTensorData<int8_t>(input_quantized) + thread_start * input_depth;
const int8_t* filter_data = GetTensorData<int8_t>(filter);
- const float* input_ptr = GetTensorData<float>(input);
- tensor_utils::BatchQuantizeFloats(
- input_ptr, batch_size, input_size, quant_data, scaling_factors_ptr,
- input_offset_ptr, params->asymmetric_quantize_inputs);
+ tensor_utils::BatchQuantizeFloats(per_thread_input, batch_size, input_depth,
+ quant_data, scaling_factors_ptr,
+ input_offset_ptr,
+ params->asymmetric_quantize_inputs);
for (int b = 0; b < batch_size; ++b) {
// Incorporate scaling of the filter.
scaling_factors_ptr[b] *= filter->params.scale;
}
// Compute output += weight * quantized_input
- int32_t* scratch = GetTensorData<int32_t>(accum_scratch);
+ int32_t* scratch =
+ GetTensorData<int32_t>(accum_scratch) + thread_start * output_depth;
if (is_sparse) {
TfLiteTensor* filter_ledger = &context->tensors[node->temporaries->data[5]];
if (!data->ledger_initialized) {
@@ -496,20 +511,116 @@
}
tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
GetTensorData<int8_t>(filter), GetTensorData<uint8_t>(filter_ledger),
- num_units, input_size, quant_data, scaling_factors_ptr, batch_size,
- GetTensorData<float>(output));
+ output_depth, input_depth, quant_data, scaling_factors_ptr, batch_size,
+ per_thread_output);
} else {
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- filter_data, num_units, input_size, quant_data, scaling_factors_ptr,
- batch_size, GetTensorData<float>(output), /*per_channel_scale=*/nullptr,
+ filter_data, output_depth, input_depth, quant_data, scaling_factors_ptr,
+ batch_size, per_thread_output, /*per_channel_scale=*/nullptr,
input_offset_ptr, scratch, row_sums_ptr, &data->compute_row_sums,
CpuBackendContext::GetFromContext(context));
}
// Apply activation function to floats.
- tensor_utils::ApplyActivationToVector(
- GetTensorData<float>(output), batch_size * num_units, params->activation,
- GetTensorData<float>(output));
+ tensor_utils::ApplyActivationToVector(per_thread_output,
+ batch_size * output_depth,
+ params->activation, per_thread_output);
+}
+
+struct HybridFullyConnectedTask : cpu_backend_threadpool::Task {
+ HybridFullyConnectedTask(TfLiteContext* context, TfLiteNode* node,
+ TfLiteFullyConnectedParams* params, OpData* data,
+ const TfLiteTensor* input,
+ const TfLiteTensor* filter, const TfLiteTensor* bias,
+ const int thread_start, const int thread_end,
+ TfLiteTensor* input_quantized,
+ TfLiteTensor* scaling_factors,
+ TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
+ TfLiteTensor* input_offsets, TfLiteTensor* output)
+ : context(context),
+ node(node),
+ params(params),
+ data(data),
+ input(input),
+ filter(filter),
+ bias(bias),
+ thread_start(thread_start),
+ thread_end(thread_end),
+ input_quantized(input_quantized),
+ scaling_factors(scaling_factors),
+ accum_scratch(accum_scratch),
+ row_sums(row_sums),
+ input_offsets(input_offsets),
+ output(output) {}
+
+ void Run() override {
+ EvalHybridImpl(context, node, params, data, input, filter, bias,
+ thread_start, thread_end, input_quantized, scaling_factors,
+ accum_scratch, row_sums, input_offsets, output);
+ }
+
+ private:
+ TfLiteContext* context;
+ TfLiteNode* node;
+ TfLiteFullyConnectedParams* params;
+ OpData* data;
+ const TfLiteTensor* input;
+ const TfLiteTensor* filter;
+ const TfLiteTensor* bias;
+ const int thread_start;
+ const int thread_end;
+ TfLiteTensor* input_quantized;
+ TfLiteTensor* scaling_factors;
+ TfLiteTensor* accum_scratch;
+ TfLiteTensor* row_sums;
+ TfLiteTensor* input_offsets;
+ TfLiteTensor* output;
+};
+
+// The multi-threaded kernel slices the workload along the batch dimension. If
+// there's not enough batches of data, the number of threads used is equal to
+// the batch size.
+// TODO(b/173442777): If needed, we can improve this later with slicing along
+// the row dimension of the weight.
+TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
+ TfLiteFullyConnectedParams* params, OpData* data,
+ const TfLiteTensor* input, const TfLiteTensor* filter,
+ const TfLiteTensor* bias, TfLiteTensor* input_quantized,
+ TfLiteTensor* scaling_factors,
+ TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
+ TfLiteTensor* input_offsets, TfLiteTensor* output) {
+ const auto& output_shape = GetTensorShape(output);
+ CpuBackendContext* cpu_backend_context =
+ CpuBackendContext::GetFromContext(context);
+ const int max_threads = cpu_backend_context->max_num_threads();
+ const int batches =
+ FlatSizeSkipDim(output_shape, output_shape.DimensionsCount() - 1);
+ const int thread_count = std::max(1, std::min(batches, max_threads));
+ if (thread_count == 1) {
+ EvalHybridImpl(context, node, params, data, input, filter, bias, 0, batches,
+ input_quantized, scaling_factors, accum_scratch, row_sums,
+ input_offsets, output);
+ return kTfLiteOk;
+ }
+
+ std::vector<HybridFullyConnectedTask> tasks;
+ tasks.reserve(thread_count);
+ int thread_start = 0;
+ for (int i = 0; i < thread_count; ++i) {
+ // This makes sure the workload is relatively balanced when batches is not
+ // a multiple of thread_count. The first mod(batches, thread_count) tasks
+ // need to process one more batch than the rest.
+ int thread_end = thread_start + batches / thread_count;
+ if (i < batches % thread_count) thread_end++;
+
+ tasks.emplace_back(context, node, params, data, input, filter, bias,
+ thread_start, thread_end, input_quantized,
+ scaling_factors, accum_scratch, row_sums, input_offsets,
+ output);
+ thread_start = thread_end;
+ }
+ cpu_backend_threadpool::Execute(tasks.size(), tasks.data(),
+ cpu_backend_context);
return kTfLiteOk;
}
diff --git a/tensorflow/lite/kernels/fully_connected_test.cc b/tensorflow/lite/kernels/fully_connected_test.cc
index ba48030..0adc80a 100644
--- a/tensorflow/lite/kernels/fully_connected_test.cc
+++ b/tensorflow/lite/kernels/fully_connected_test.cc
@@ -296,7 +296,8 @@
HybridFullyConnectedOpModel(int units, int batches, const TensorData& input,
const TensorData& weights,
const TensorData& output = {TensorType_FLOAT32},
- bool asymmetric_inputs = false)
+ bool asymmetric_inputs = false,
+ int num_threads = 1)
: batches_(batches), units_(units) {
int total_input_size = 1;
for (size_t i = 0; i < input.shape.size(); ++i) {
@@ -322,7 +323,9 @@
resolver_ = absl::make_unique<SingleOpResolver>(
BuiltinOperator_FULLY_CONNECTED,
ops::builtin::Register_FULLY_CONNECTED_PIE());
- BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)});
+ BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)},
+ num_threads, /*allow_fp32_relax_to_fp16=*/false,
+ /*apply_delegate=*/false);
}
void SetBias(const std::vector<float>& f) { PopulateTensor(bias_, f); }
void SetWeights(const std::vector<float>& data) {
@@ -879,6 +882,44 @@
/*max_abs_error=*/1.3f)));
}
+TEST(HybridFullyConnectedOpTest, SimpleTestQuantizedInt8MultiThreaded) {
+ for (int num_threads = 1; num_threads <= 4; ++num_threads) {
+ HybridFullyConnectedOpModel m(
+ /*units=*/3, /*batches=*/4,
+ /*input=*/{TensorType_FLOAT32, {4, 10}},
+ /*weights=*/
+ {TensorType_INT8, {3, 10}, 0, 0, 10.0 / 127.0, 0},
+ /*output=*/{TensorType_FLOAT32}, /*asymmetric_inputs=*/false,
+ /*num_threads=*/num_threads); // Hybrid
+
+ m.SetSignedWeights({
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2
+ });
+ m.SetBias({1, 2, 3});
+
+ m.SetInput({
+ 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0
+ 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1
+ 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 2
+ 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 3
+ });
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutputShape(), ElementsAre(4, 3));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
+ {
+ 24, 25, 26, //
+ 58, 59, 60, //
+ 24, 25, 26, //
+ 58, 59, 60, //
+ },
+ /*max_abs_error=*/1.3f)));
+ }
+}
+
TEST(HybridAsymmetricInputFullyConnectedOpTest, SimpleTestQuantizedUint8) {
HybridFullyConnectedOpModel m(
/*units=*/3, /*batches=*/2,
@@ -1413,6 +1454,76 @@
ElementsAreArray(ArrayFloatNear(
{0, 7.4715, 85.8359, 0, 5.9655, 3.0520, 1.9480, 0}, 1e-3)));
}
+
+TEST_P(SparseFullyConnectedOpTest, SparseHybrid1x16TestMultiThreaded) {
+ std::initializer_list<float> weight_data = {
+ /* 1st row */
+ 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
+ 14.14, 15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9,
+ 10.1, 11.11, 12.12, 13.13, 14.14, 15.15, 16.16,
+ /* 2nd row */
+ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+ 0.0, -1.1, -2.2, -3.3, -4.4, -5.5, -6.6, -7.7, -8.8, -9.9, -10.1, -11.11,
+ -12.12, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+ /* 3rd row */
+ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+ 0.0, 1.1, -2.2, 3.3, -4.4, 5.5, -6.6, 7.7, -8.8, 9.9, -10.1, 11.11,
+ -12.12, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+ /* 4th row */
+ -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
+ -13.13, 14.14, -15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7,
+ 8.8, -9.9, 10.1, -11.11, 12.12, 0.0, 0.0, 0.0, 0.0};
+ TensorData weight = {};
+ weight.type = TensorType_FLOAT32;
+ weight.shape = {4, 48};
+ weight.traversal_order = {0, 1, 2};
+ weight.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
+ weight.block_map = {1};
+ weight.block_size = {16};
+ for (int num_threads = 1; num_threads <= 4; ++num_threads) {
+ SparseFullyConnectedOpModel<float> m(
+ GetRegistration(),
+ /*units=*/4, /*batches=*/4,
+ /*input=*/{TensorType_FLOAT32, {4, 48}}, weight, weight_data,
+ /*num_threads)=*/num_threads, /*symmetric_quantize_weights=*/true);
+ m.SetBias({1, 2, 3, 4});
+ m.SetInput({
+ 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
+ 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
+ 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
+ 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
+ 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, // b = 0
+ 2.5, 0.0, -2.1, 0.0, 3.0, 0.0, -1.3, 0.0, 1.3, 0.0,
+ -1.1, 0.0, 2.0, 0.0, -1.7, 0.0, 1.9, 0.0, -1.5, 0.0,
+ 0.5, 0.0, -0.7, 0.0, 0.8, 0.0, -0.3, 0.0, 2.8, 0.0,
+ -2.8, 0.0, 1.1, -2.3, 1.9, -1.9, 2.1, -0.5, 2.4, -0.1,
+ 1.0, -2.5, 0.7, -1.9, 0.2, 0.1, 0.2, 0.3, // b = 1
+ 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
+ 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
+ 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
+ 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
+ 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, // b = 2
+ 2.5, 0.0, -2.1, 0.0, 3.0, 0.0, -1.3, 0.0, 1.3, 0.0,
+ -1.1, 0.0, 2.0, 0.0, -1.7, 0.0, 1.9, 0.0, -1.5, 0.0,
+ 0.5, 0.0, -0.7, 0.0, 0.8, 0.0, -0.3, 0.0, 2.8, 0.0,
+ -2.8, 0.0, 1.1, -2.3, 1.9, -1.9, 2.1, -0.5, 2.4, -0.1,
+ 1.0, -2.5, 0.7, -1.9, 0.2, 0.1, 0.2, 0.3, // b = 3
+ });
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutputShape(), ElementsAre(4, 4));
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {0, 7.4715, 85.8359, 0, 5.9655, 3.0520, 1.9480, 0, 0,
+ 7.4715, 85.8359, 0, 5.9655, 3.0520, 1.9480, 0},
+ 1e-3)));
+ }
+}
// TODO(b/148391360): Add tests for unsupported sparsity format.
// TEST_P(SparseFullyConnectedOpTest, TestUnsupportedSparsityFormat)
diff --git a/tensorflow/lite/kernels/hashtable/hashtable.cc b/tensorflow/lite/kernels/hashtable/hashtable.cc
index bca8f2c..baea163 100644
--- a/tensorflow/lite/kernels/hashtable/hashtable.cc
+++ b/tensorflow/lite/kernels/hashtable/hashtable.cc
@@ -83,10 +83,23 @@
TfLiteTensor* resource_handle_tensor;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kResourceHandleTensor,
&resource_handle_tensor));
- TF_LITE_ENSURE_EQ(context, resource_handle_tensor->type, kTfLiteInt32);
+ TF_LITE_ENSURE(context, resource_handle_tensor->type == kTfLiteResource ||
+ resource_handle_tensor->type == kTfLiteInt32);
+
+ // Resource tensor buffer as a hash table handler will have an 32-bit integer
+ // identity.
+ size_t bytesRequired = sizeof(int32_t);
+ resource_handle_tensor->bytes = bytesRequired;
+ // Realloc space for an integer handle value.
+ TfLiteTensorRealloc(bytesRequired, resource_handle_tensor);
+
+ // Make shape be [1] to store one integer value.
TfLiteIntArray* outputSize = TfLiteIntArrayCreate(1);
outputSize->data[0] = 1;
- return context->ResizeTensor(context, resource_handle_tensor, outputSize);
+ if (resource_handle_tensor->dims)
+ TfLiteIntArrayFree(resource_handle_tensor->dims);
+ resource_handle_tensor->dims = outputSize;
+ return kTfLiteOk;
}
TfLiteStatus EvalHashtable(TfLiteContext* context, TfLiteNode* node) {
@@ -95,14 +108,12 @@
reinterpret_cast<const TfLiteHashtableParams*>(node->user_data);
// The resource id is generated based on the given table name.
- const int resource_id = std::hash<std::string>{}(params->table_name);
+ const int32_t resource_id = std::hash<std::string>{}(params->table_name);
TfLiteTensor* resource_handle_tensor;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kResourceHandleTensor,
&resource_handle_tensor));
- auto* resource_handle_data =
- GetTensorData<std::int32_t>(resource_handle_tensor);
- resource_handle_data[0] = resource_id;
+ *resource_handle_tensor->data.i32 = resource_id;
Subgraph* subgraph = reinterpret_cast<Subgraph*>(context->impl_);
auto& resources = subgraph->resources();
@@ -120,6 +131,9 @@
return &r;
}
+// Alias for selective build.
+TfLiteRegistration* Register_HASH_TABLE_V2() { return Register_HASHTABLE(); }
+
} // namespace custom
} // namespace ops
} // namespace tflite
diff --git a/tensorflow/lite/kernels/hashtable/hashtable_find.cc b/tensorflow/lite/kernels/hashtable/hashtable_find.cc
index f26fe82..39be535 100644
--- a/tensorflow/lite/kernels/hashtable/hashtable_find.cc
+++ b/tensorflow/lite/kernels/hashtable/hashtable_find.cc
@@ -37,7 +37,9 @@
const TfLiteTensor* input_resource_id_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputResourceIdTensor,
&input_resource_id_tensor));
- TF_LITE_ENSURE_EQ(context, input_resource_id_tensor->type, kTfLiteInt32);
+
+ TF_LITE_ENSURE(context, input_resource_id_tensor->type == kTfLiteResource ||
+ input_resource_id_tensor->type == kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, NumDimensions(input_resource_id_tensor), 1);
TF_LITE_ENSURE_EQ(context, SizeOfDimension(input_resource_id_tensor, 0), 1);
@@ -96,6 +98,11 @@
return &r;
}
+// Alias for selective build.
+TfLiteRegistration* Register_LOOKUP_TABLE_FIND_V2() {
+ return Register_HASHTABLE_FIND();
+}
+
} // namespace custom
} // namespace ops
} // namespace tflite
diff --git a/tensorflow/lite/kernels/hashtable/hashtable_import.cc b/tensorflow/lite/kernels/hashtable/hashtable_import.cc
index fad9345..806a398 100644
--- a/tensorflow/lite/kernels/hashtable/hashtable_import.cc
+++ b/tensorflow/lite/kernels/hashtable/hashtable_import.cc
@@ -36,7 +36,8 @@
const TfLiteTensor* input_resource_id_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputResourceIdTensor,
&input_resource_id_tensor));
- TF_LITE_ENSURE_EQ(context, input_resource_id_tensor->type, kTfLiteInt32);
+ TF_LITE_ENSURE(context, input_resource_id_tensor->type == kTfLiteResource ||
+ input_resource_id_tensor->type == kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, NumDimensions(input_resource_id_tensor), 1);
TF_LITE_ENSURE_EQ(context, SizeOfDimension(input_resource_id_tensor, 0), 1);
@@ -91,6 +92,11 @@
return &r;
}
+// Alias for selective build.
+TfLiteRegistration* Register_LOOKUP_TABLE_IMPORT_V2() {
+ return Register_HASHTABLE_IMPORT();
+}
+
} // namespace custom
} // namespace ops
} // namespace tflite
diff --git a/tensorflow/lite/kernels/hashtable/hashtable_ops_test.cc b/tensorflow/lite/kernels/hashtable/hashtable_ops_test.cc
index f4a0d3c..c010744 100644
--- a/tensorflow/lite/kernels/hashtable/hashtable_ops_test.cc
+++ b/tensorflow/lite/kernels/hashtable/hashtable_ops_test.cc
@@ -355,8 +355,8 @@
}
// Resource id tensor.
- interpreter_->SetTensorParametersReadWrite(kResourceTensorId, kTfLiteInt32,
- "", {1}, TfLiteQuantization());
+ interpreter_->SetTensorParametersReadWrite(
+ kResourceTensorId, kTfLiteResource, "", {1}, TfLiteQuantization());
// Key tensor for import.
interpreter_->SetTensorParametersReadWrite(kKeyTensorId, key_type_, "",
@@ -389,7 +389,7 @@
if (table_two_initialization) {
// Resource id tensor.
interpreter_->SetTensorParametersReadWrite(
- kResourceTwoTensorId, kTfLiteInt32, "", {1}, TfLiteQuantization());
+ kResourceTwoTensorId, kTfLiteResource, "", {1}, TfLiteQuantization());
// Key tensor for import.
interpreter_->SetTensorParametersReadWrite(
@@ -627,7 +627,8 @@
graph.SetQuery({"2", "3", "4"}, -1);
graph.AddTensors();
graph.BuildDefaultGraph();
- EXPECT_EQ(graph.AllocateTensors(), kTfLiteError);
+ EXPECT_EQ(graph.AllocateTensors(), kTfLiteOk);
+ EXPECT_EQ(graph.Invoke(), kTfLiteError);
}
// HashtableOpModel creates a model with one signle Hashtable op.
@@ -635,7 +636,7 @@
public:
explicit HashtableOpModel(const char* table_name, TensorType key_dtype,
TensorType value_dtype) {
- output_ = AddOutput(GetTensorType<int>());
+ output_ = AddOutput({TensorType_RESOURCE, {1}});
// Set up and pass in custom options using flexbuffer.
flexbuffers::Builder fbb;
@@ -650,7 +651,10 @@
BuildInterpreter({});
}
- std::vector<int> GetOutput() { return ExtractVector<int>(output_); }
+ std::vector<int> GetOutput() {
+ TfLiteTensor* tensor_ptr = interpreter_->tensor(output_);
+ return std::vector<int>(tensor_ptr->data.i32, tensor_ptr->data.i32 + 1);
+ }
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
resource::ResourceMap& GetResources() {
@@ -744,7 +748,20 @@
BaseHashtableOpModel() {}
void SetResourceId(const std::vector<int>& data) {
- PopulateTensor(resource_id_, data);
+ int32_t* tensor_buffer =
+ reinterpret_cast<int32_t*>(malloc(sizeof(int32_t)));
+ tensor_buffer[0] = data[0];
+
+ TfLiteIntArray* dims = TfLiteIntArrayCreate(1);
+ dims->data[0] = 1;
+
+ auto resource_handle_tensor = interpreter_->tensor(resource_id_);
+
+ TfLiteTensorReset(
+ resource_handle_tensor->type, resource_handle_tensor->name, dims,
+ resource_handle_tensor->params, reinterpret_cast<char*>(tensor_buffer),
+ sizeof(int32_t), kTfLiteDynamic, resource_handle_tensor->allocation,
+ resource_handle_tensor->is_variable, resource_handle_tensor);
}
void CreateHashtableResource(int resource_id) {
@@ -786,7 +803,7 @@
key_type_ = key_type;
value_type_ = value_type;
- resource_id_ = AddInput({TensorType_INT32, {1}});
+ resource_id_ = AddInput({TensorType_RESOURCE, {1}});
lookup_ = AddInput({key_type, {lookup_size}});
default_value_ = AddInput({value_type, {1}});
@@ -864,7 +881,7 @@
key_type_ = key_type;
value_type_ = value_type;
- resource_id_ = AddInput({TensorType_INT32, {1}});
+ resource_id_ = AddInput({TensorType_RESOURCE, {1}});
keys_ = AddInput({key_type, {initdata_size}});
values_ = AddInput({value_type, {initdata_size}});
@@ -941,7 +958,7 @@
key_type_ = key_type;
value_type_ = value_type;
- resource_id_ = AddInput({TensorType_INT32, {1}});
+ resource_id_ = AddInput({TensorType_RESOURCE, {1}});
output_ = AddOutput({TensorType_INT64, {1}});
diff --git a/tensorflow/lite/kernels/hashtable/hashtable_size.cc b/tensorflow/lite/kernels/hashtable/hashtable_size.cc
index 34a8031..a993f5b 100644
--- a/tensorflow/lite/kernels/hashtable/hashtable_size.cc
+++ b/tensorflow/lite/kernels/hashtable/hashtable_size.cc
@@ -35,7 +35,8 @@
const TfLiteTensor* input_resource_id_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputResourceIdTensor,
&input_resource_id_tensor));
- TF_LITE_ENSURE_EQ(context, input_resource_id_tensor->type, kTfLiteInt32);
+ TF_LITE_ENSURE(context, input_resource_id_tensor->type == kTfLiteResource ||
+ input_resource_id_tensor->type == kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, NumDimensions(input_resource_id_tensor), 1);
TF_LITE_ENSURE_EQ(context, SizeOfDimension(input_resource_id_tensor, 0), 1);
@@ -78,6 +79,11 @@
return &r;
}
+// Alias for selective build.
+TfLiteRegistration* Register_LOOKUP_TABLE_SIZE_V2() {
+ return Register_HASHTABLE_SIZE();
+}
+
} // namespace custom
} // namespace ops
} // namespace tflite
diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
index edd1ffd..1a905ef 100644
--- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
@@ -2146,21 +2146,24 @@
auto input2_map = MapAsVector(input2_data, input2_shape);
auto output_map = MapAsVector(output_data, output_shape);
if (input1_shape == input2_shape) {
- output_map.array() = input1_map.array() + input2_map.array();
+ output_map.array() = (input1_map.array() + input2_map.array())
+ .cwiseMax(params.quantized_activation_min)
+ .cwiseMin(params.quantized_activation_max);
} else if (input2_shape.FlatSize() == 1) {
auto scalar = input2_data[0];
- output_map.array() = input1_map.array() + scalar;
+ output_map.array() = (input1_map.array() + scalar)
+ .cwiseMax(params.quantized_activation_min)
+ .cwiseMin(params.quantized_activation_max);
} else if (input1_shape.FlatSize() == 1) {
auto scalar = input1_data[0];
- output_map.array() = scalar + input2_map.array();
+ output_map.array() = (scalar + input2_map.array())
+ .cwiseMax(params.quantized_activation_min)
+ .cwiseMin(params.quantized_activation_max);
} else {
reference_ops::BroadcastAdd4DSlow(params, input1_shape, input1_data,
input2_shape, input2_data, output_shape,
output_data);
- return;
}
- output_map = output_map.cwiseMax(params.quantized_activation_min);
- output_map = output_map.cwiseMin(params.quantized_activation_max);
}
template <typename T>
@@ -2715,7 +2718,23 @@
NDOpsHelper<N>(output_desc, div_func);
}
-// TODO(aselle): This is not actually optimized yet.
+template <typename T>
+inline void SubWithActivation(
+ const ArithmeticParams& params, const RuntimeShape& input1_shape,
+ const T* input1_data, const RuntimeShape& input2_shape,
+ const T* input2_data, const RuntimeShape& output_shape, T* output_data) {
+ ruy::profiler::ScopeLabel label("SubWithActivation_optimized");
+ TFLITE_DCHECK_EQ(input1_shape.FlatSize(), input2_shape.FlatSize());
+ auto input1_map = MapAsVector(input1_data, input1_shape);
+ auto input2_map = MapAsVector(input2_data, input2_shape);
+ auto output_map = MapAsVector(output_data, output_shape);
+ T activation_min, activation_max;
+ GetActivationParams(params, &activation_min, &activation_max);
+ output_map.array() = (input1_map.array() - input2_map.array())
+ .cwiseMin(activation_max)
+ .cwiseMax(activation_min);
+}
+
inline void SubNonBroadcast(const ArithmeticParams& params,
const RuntimeShape& input1_shape,
const float* input1_data,
@@ -2724,49 +2743,8 @@
const RuntimeShape& output_shape,
float* output_data) {
ruy::profiler::ScopeLabel label("SubNonBroadcast");
- const int flat_size =
- MatchingElementsSize(input1_shape, input2_shape, output_shape);
- for (int i = 0; i < flat_size; ++i) {
- output_data[i] = ActivationFunctionWithMinMax(
- input1_data[i] - input2_data[i], params.float_activation_min,
- params.float_activation_max);
- }
-}
-
-inline void SetActivationMinMax(const ArithmeticParams& params,
- int32* activation_min, int32* activation_max) {
- *activation_min = params.quantized_activation_min;
- *activation_max = params.quantized_activation_max;
-}
-
-inline void SetActivationMinMax(const ArithmeticParams& params,
- float* activation_min, float* activation_max) {
- *activation_min = params.float_activation_min;
- *activation_max = params.float_activation_max;
-}
-
-inline void SetActivationMinMax(const ArithmeticParams& params,
- int64_t* activation_min,
- int64_t* activation_max) {
- *activation_min = params.int64_activation_min;
- *activation_max = params.int64_activation_max;
-}
-
-template <typename T>
-inline void SubWithActivation(
- const ArithmeticParams& params, const RuntimeShape& input1_shape,
- const T* input1_data, const RuntimeShape& input2_shape,
- const T* input2_data, const RuntimeShape& output_shape, T* output_data) {
- ruy::profiler::ScopeLabel label("SubWithActivation_optimized");
- const int flat_size =
- MatchingElementsSize(input1_shape, input2_shape, output_shape);
- T activation_min, activation_max;
- SetActivationMinMax(params, &activation_min, &activation_max);
-
- for (int i = 0; i < flat_size; ++i) {
- output_data[i] = ActivationFunctionWithMinMax(
- input1_data[i] - input2_data[i], activation_min, activation_max);
- }
+ SubWithActivation<float>(params, input1_shape, input1_data, input2_shape,
+ input2_data, output_shape, output_data);
}
template <typename T>
diff --git a/tensorflow/lite/kernels/internal/resize_bilinear_test.cc b/tensorflow/lite/kernels/internal/resize_bilinear_test.cc
index 12ae975..59f5105 100644
--- a/tensorflow/lite/kernels/internal/resize_bilinear_test.cc
+++ b/tensorflow/lite/kernels/internal/resize_bilinear_test.cc
@@ -114,7 +114,7 @@
if (op_params.align_corners) {
// Align_corners causes small discrepencies between reference & optimized
// versions.
- error_threshold = 3e-4;
+ error_threshold = 1e-3;
}
TestOneResizeBilinear<uint8>(op_params, batch, depth, input_width,
input_height, output_width, output_height,
@@ -139,7 +139,7 @@
if (op_params.align_corners) {
// align_corners causes small discrepencies between reference & optimized
// versions.
- error_threshold = 1e-4;
+ error_threshold = 1e-3;
}
TestOneResizeBilinear<float>(op_params, batch, depth, input_width,
input_height, output_width, output_height,
@@ -164,7 +164,7 @@
if (op_params.align_corners) {
// Align_corners causes small discrepencies between reference & optimized
// versions.
- error_threshold = 1e-4;
+ error_threshold = 1e-3;
}
TestOneResizeBilinear<float>(op_params, batch, depth, input_width,
input_height, output_width, output_height,
diff --git a/tensorflow/lite/kernels/kernel_util.cc b/tensorflow/lite/kernels/kernel_util.cc
index a781cf3..41bd44a 100644
--- a/tensorflow/lite/kernels/kernel_util.cc
+++ b/tensorflow/lite/kernels/kernel_util.cc
@@ -486,6 +486,9 @@
case kTfLiteInt32:
TF_LITE_ASSERT_EQ(sizeof(int32_t), 4);
return 4;
+ case kTfLiteUInt32:
+ TF_LITE_ASSERT_EQ(sizeof(uint32_t), 4);
+ return 4;
case kTfLiteInt64:
TF_LITE_ASSERT_EQ(sizeof(int64_t), 8);
return 8;
diff --git a/tensorflow/lite/kernels/perception/dense_image_warp.cc b/tensorflow/lite/kernels/perception/dense_image_warp.cc
index f4101c0..139232e 100644
--- a/tensorflow/lite/kernels/perception/dense_image_warp.cc
+++ b/tensorflow/lite/kernels/perception/dense_image_warp.cc
@@ -144,6 +144,11 @@
return ®
}
+// Alias for selective build.
+TfLiteRegistration* Register_DENSE_IMAGE_WARP() {
+ return RegisterDenseImageWarp();
+}
+
} // namespace custom
} // namespace ops
} // namespace tflite
diff --git a/tensorflow/lite/kernels/perception/max_pool_with_argmax.cc b/tensorflow/lite/kernels/perception/max_pool_with_argmax.cc
index 4e1aca9..7159e98 100644
--- a/tensorflow/lite/kernels/perception/max_pool_with_argmax.cc
+++ b/tensorflow/lite/kernels/perception/max_pool_with_argmax.cc
@@ -244,6 +244,11 @@
return &r;
}
+// Alias for selective build.
+TfLiteRegistration* Register_MAX_POOL_WITH_ARGMAX() {
+ return RegisterMaxPoolWithArgmax();
+}
+
} // namespace custom
} // namespace ops
} // namespace tflite
diff --git a/tensorflow/lite/kernels/perception/max_unpooling_2d.cc b/tensorflow/lite/kernels/perception/max_unpooling_2d.cc
index ce51b14..5f58561 100644
--- a/tensorflow/lite/kernels/perception/max_unpooling_2d.cc
+++ b/tensorflow/lite/kernels/perception/max_unpooling_2d.cc
@@ -127,6 +127,11 @@
return ®
}
+// Alias for selective build.
+TfLiteRegistration* Register_MAX_UNPOOLING2D() {
+ return RegisterMaxUnpooling2D();
+}
+
} // namespace custom
} // namespace ops
} // namespace tflite
diff --git a/tensorflow/lite/kernels/test_util.h b/tensorflow/lite/kernels/test_util.h
index 7cc986a..ec2d248 100644
--- a/tensorflow/lite/kernels/test_util.h
+++ b/tensorflow/lite/kernels/test_util.h
@@ -915,6 +915,7 @@
if (std::is_same<T, int8_t>::value) return TensorType_INT8;
if (std::is_same<T, int16_t>::value) return TensorType_INT16;
if (std::is_same<T, int32_t>::value) return TensorType_INT32;
+ if (std::is_same<T, uint32_t>::value) return TensorType_UINT32;
if (std::is_same<T, int64_t>::value) return TensorType_INT64;
if (std::is_same<T, uint8_t>::value) return TensorType_UINT8;
if (std::is_same<T, string>::value) return TensorType_STRING;
@@ -956,6 +957,16 @@
};
template <>
+struct TypeUnion<uint32_t> {
+ public:
+ // NOLINTNEXTLINE
+ static constexpr TensorType tensor_type = TensorType::TensorType_UINT32;
+ // NOLINTNEXTLINE
+ static constexpr TfLiteType tflite_type = TfLiteType::kTfLiteUInt32;
+ typedef uint32_t ScalarType;
+};
+
+template <>
struct TypeUnion<int16_t> {
public:
// NOLINTNEXTLINE
diff --git a/tensorflow/lite/micro/BUILD b/tensorflow/lite/micro/BUILD
index 48d0848..1a20c7b 100644
--- a/tensorflow/lite/micro/BUILD
+++ b/tensorflow/lite/micro/BUILD
@@ -1,9 +1,5 @@
load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
load(
- "//tensorflow/lite/micro/testing:micro_test.bzl",
- "tflite_micro_cc_test",
-)
-load(
"//tensorflow/lite/micro:build_def.bzl",
"micro_copts",
)
@@ -111,6 +107,7 @@
"//tensorflow/lite/core/api",
"//tensorflow/lite/kernels:op_macros",
"//tensorflow/lite/kernels/internal:compatibility",
+ "//tensorflow/lite/micro/kernels:conv",
"//tensorflow/lite/micro/kernels:ethosu",
"//tensorflow/lite/micro/kernels:fully_connected",
"//tensorflow/lite/micro/kernels:micro_ops",
@@ -220,7 +217,18 @@
],
)
-tflite_micro_cc_test(
+cc_library(
+ name = "system_setup",
+ srcs = [
+ "system_setup.cc",
+ ],
+ hdrs = [
+ "system_setup.h",
+ ],
+ copts = micro_copts(),
+)
+
+cc_test(
name = "micro_error_reporter_test",
srcs = [
"micro_error_reporter_test.cc",
@@ -230,7 +238,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "micro_mutable_op_resolver_test",
srcs = [
"micro_mutable_op_resolver_test.cc",
@@ -242,7 +250,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "micro_interpreter_test",
srcs = [
"micro_interpreter_test.cc",
@@ -259,7 +267,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "simple_memory_allocator_test",
srcs = [
"simple_memory_allocator_test.cc",
@@ -271,7 +279,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "recording_simple_memory_allocator_test",
srcs = [
"recording_simple_memory_allocator_test.cc",
@@ -284,7 +292,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "micro_allocator_test",
srcs = [
"micro_allocator_test.cc",
@@ -298,7 +306,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "recording_micro_allocator_test",
srcs = [
"recording_micro_allocator_test.cc",
@@ -313,7 +321,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "memory_helpers_test",
srcs = [
"memory_helpers_test.cc",
@@ -325,7 +333,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "testing_helpers_test",
srcs = [
"testing_helpers_test.cc",
@@ -337,7 +345,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "micro_utils_test",
srcs = [
"micro_utils_test.cc",
@@ -348,7 +356,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "micro_string_test",
srcs = [
"micro_string_test.cc",
@@ -359,7 +367,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "micro_time_test",
srcs = [
"micro_time_test.cc",
@@ -370,7 +378,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "memory_arena_threshold_test",
srcs = [
"memory_arena_threshold_test.cc",
diff --git a/tensorflow/lite/micro/arduino/debug_log.cc b/tensorflow/lite/micro/arduino/debug_log.cc
index da39c76..f1babc1 100644
--- a/tensorflow/lite/micro/arduino/debug_log.cc
+++ b/tensorflow/lite/micro/arduino/debug_log.cc
@@ -1,4 +1,4 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -12,27 +12,10 @@
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-
-#include "tensorflow/lite/micro/debug_log.h"
-
-#include "Arduino.h"
-
-// The Arduino DUE uses a different object for the default serial port shown in
-// the monitor than most other models, so make sure we pick the right one. See
-// https://github.com/arduino/Arduino/issues/3088#issuecomment-406655244
-#if defined(__SAM3X8E__)
-#define DEBUG_SERIAL_OBJECT (SerialUSB)
-#else
-#define DEBUG_SERIAL_OBJECT (Serial)
-#endif
-
-// On Arduino platforms, we set up a serial port and write to it for debug
-// logging.
-extern "C" void DebugLog(const char* s) {
- static bool is_initialized = false;
- if (!is_initialized) {
- DEBUG_SERIAL_OBJECT.begin(9600);
- is_initialized = true;
- }
- DEBUG_SERIAL_OBJECT.print(s);
-}
+// This file is empty to ensure that a specialized implementation of
+// debug_log.h is used (instead of the default implementation from
+// tensorflow/lite/micro/debug_log.cc).
+//
+// The actual target-specific implementation of debug_log.h is in
+// system_setup.cc since that allows us to consolidate all the target-specific
+// specializations into one source file.
diff --git a/tensorflow/lite/micro/arduino/system_setup.cc b/tensorflow/lite/micro/arduino/system_setup.cc
new file mode 100644
index 0000000..3bf21c9
--- /dev/null
+++ b/tensorflow/lite/micro/arduino/system_setup.cc
@@ -0,0 +1,36 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/micro/system_setup.h"
+
+#include "Arduino.h"
+#include "tensorflow/lite/micro/debug_log.h"
+
+// The Arduino DUE uses a different object for the default serial port shown in
+// the monitor than most other models, so make sure we pick the right one. See
+// https://github.com/arduino/Arduino/issues/3088#issuecomment-406655244
+#if defined(__SAM3X8E__)
+#define DEBUG_SERIAL_OBJECT (SerialUSB)
+#else
+#define DEBUG_SERIAL_OBJECT (Serial)
+#endif
+
+extern "C" void DebugLog(const char* s) { DEBUG_SERIAL_OBJECT.print(s); }
+
+namespace tflite {
+
+void InitializeTarget() { DEBUG_SERIAL_OBJECT.begin(9600); }
+
+} // namespace tflite
diff --git a/tensorflow/lite/micro/benchmarks/BUILD b/tensorflow/lite/micro/benchmarks/BUILD
index 23faaa3..4394a9b 100644
--- a/tensorflow/lite/micro/benchmarks/BUILD
+++ b/tensorflow/lite/micro/benchmarks/BUILD
@@ -15,6 +15,9 @@
hdrs = [
"micro_benchmark.h",
],
+ visibility = [
+ "//visibility:public",
+ ],
deps = [
"//tensorflow/lite/micro:micro_error_reporter",
"//tensorflow/lite/micro:micro_framework",
@@ -46,6 +49,7 @@
"//tensorflow/lite/micro:micro_error_reporter",
"//tensorflow/lite/micro:micro_framework",
"//tensorflow/lite/micro:op_resolvers",
+ "//tensorflow/lite/micro:system_setup",
"//tensorflow/lite/micro/kernels:fully_connected",
],
)
@@ -63,6 +67,7 @@
"//tensorflow/lite/micro:micro_framework",
"//tensorflow/lite/micro:micro_utils",
"//tensorflow/lite/micro:op_resolvers",
+ "//tensorflow/lite/micro:system_setup",
"//tensorflow/lite/micro/examples/person_detection:model_settings",
"//tensorflow/lite/micro/examples/person_detection:person_detect_model_data",
"//tensorflow/lite/micro/examples/person_detection:simple_images_test_data",
diff --git a/tensorflow/lite/micro/benchmarks/README.md b/tensorflow/lite/micro/benchmarks/README.md
index 30275e5..74de759 100644
--- a/tensorflow/lite/micro/benchmarks/README.md
+++ b/tensorflow/lite/micro/benchmarks/README.md
@@ -29,13 +29,13 @@
To run the keyword benchmark on x86, run
```
-make -f tensorflow/lite/micro/tools/make/Makefile test_keyword_benchmark
+make -f tensorflow/lite/micro/tools/make/Makefile run_keyword_benchmark
```
To run the person detection benchmark on x86, run
```
-make -f tensorflow/lite/micro/tools/make/Makefile test_person_detection_benchmark
+make -f tensorflow/lite/micro/tools/make/Makefile run_person_detection_benchmark
```
## Run on Xtensa XPG Simulator
@@ -44,7 +44,7 @@
Xtensa toolchain and license. With these set up, run:
```
-make -f tensorflow/lite/micro/tools/make/Makefile TARGET=xtensa OPTIMIZED_KERNEL_DIR=xtensa TARGET_ARCH=<target architecture> XTENSA_CORE=<xtensa core> test_keyword_benchmark -j18
+make -f tensorflow/lite/micro/tools/make/Makefile TARGET=xtensa OPTIMIZED_KERNEL_DIR=xtensa TARGET_ARCH=<target architecture> XTENSA_CORE=<xtensa core> run_keyword_benchmark -j18
```
## Run on Sparkfun Edge
diff --git a/tensorflow/lite/micro/benchmarks/keyword_benchmark.cc b/tensorflow/lite/micro/benchmarks/keyword_benchmark.cc
index 4830d63..ba114e9 100644
--- a/tensorflow/lite/micro/benchmarks/keyword_benchmark.cc
+++ b/tensorflow/lite/micro/benchmarks/keyword_benchmark.cc
@@ -24,6 +24,7 @@
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/micro/micro_profiler.h"
+#include "tensorflow/lite/micro/system_setup.h"
/*
* Keyword Spotting Benchmark for performance optimizations. The model used in
@@ -77,6 +78,7 @@
} // namespace tflite
int main(int argc, char** argv) {
+ tflite::InitializeTarget();
tflite::MicroProfiler profiler;
uint32_t event_handle = profiler.BeginEvent("InitializeKeywordRunner");
diff --git a/tensorflow/lite/micro/benchmarks/micro_benchmark.h b/tensorflow/lite/micro/benchmarks/micro_benchmark.h
index 2eb3787..272c720 100644
--- a/tensorflow/lite/micro/benchmarks/micro_benchmark.h
+++ b/tensorflow/lite/micro/benchmarks/micro_benchmark.h
@@ -43,7 +43,7 @@
void RunSingleIteration() {
// Run the model on this input and make sure it succeeds.
TfLiteStatus invoke_status = interpreter_.Invoke();
- if (invoke_status != kTfLiteOk) {
+ if (invoke_status == kTfLiteError) {
MicroPrintf("Invoke failed.");
}
}
diff --git a/tensorflow/lite/micro/benchmarks/person_detection_benchmark.cc b/tensorflow/lite/micro/benchmarks/person_detection_benchmark.cc
index e6d5eb4..1e98bbd 100644
--- a/tensorflow/lite/micro/benchmarks/person_detection_benchmark.cc
+++ b/tensorflow/lite/micro/benchmarks/person_detection_benchmark.cc
@@ -23,6 +23,7 @@
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_utils.h"
+#include "tensorflow/lite/micro/system_setup.h"
#include "tensorflow/lite/schema/schema_generated.h"
/*
@@ -74,6 +75,8 @@
} // namespace tflite
int main(int argc, char** argv) {
+ tflite::InitializeTarget();
+
tflite::MicroProfiler profiler;
uint32_t event_handle = profiler.BeginEvent("InitializeBenchmarkRunner");
diff --git a/tensorflow/lite/micro/cortex_m_corstone_300/README.md b/tensorflow/lite/micro/cortex_m_corstone_300/README.md
new file mode 100644
index 0000000..b4ff9a1
--- /dev/null
+++ b/tensorflow/lite/micro/cortex_m_corstone_300/README.md
@@ -0,0 +1,47 @@
+ <!-- mdformat off(b/169948621#comment2) -->
+
+# Running a fixed virtual platform based on Corstone-300 software
+
+This target makes use of a fixed virtual platform (FVP) based on Arm Cortex-300
+based software. More info about Arm Corstone-300 software:
+https://developer.arm.com/ip-products/subsystem/corstone/corstone-300. More info
+about FVPs:
+https://developer.arm.com/tools-and-software/simulation-models/fixed-virtual-platforms.
+
+To fullfill the needed requirements it is depending the following projects:
+
+- Arm Ethos-U Core Platform:
+ https://review.mlplatform.org/admin/repos/ml/ethos-u/ethos-u-core-platform.
+ - Arm Ethos-U Core Platform provides the linker file as well as UART and
+ retarget functions.
+- CMSIS: https://github.com/ARM-software/CMSIS_5.
+ - CMSIS provides startup functionality, e.g. for setting up interrupt
+ handlers and clock speed.
+
+# General build info
+
+This target is based on the cortex_m_generic target and except that for now the
+only supported toolchain is GCC, the same general build info applies:
+tensorflow/lite/micro/cortex_m_generic/README.md.
+
+Required parameters:
+
+- TARGET: cortex_m_corstone_300
+- TARGET_ARCH: cortex-mXX (For all options see:
+ tensorflow/lite/micro/tools/make/targets/cortex_m_corstone_300_makefile.inc)
+
+# How to run
+
+Note that Corstone-300 is targetted for Cortex-M55 but it is backwards
+compatible. This means one could potentially run it for example with a
+Cortex-M7. Note that the clock speed would be that of an Cortex-M55. This may
+not matter when running unit tests or for debugging.
+
+Some examples:
+
+```
+make -j -f tensorflow/lite/micro/tools/make/Makefile OPTIMIZED_KERNEL_DIR=cmsis_nn TARGET=cortex_m_corstone_300 TARGET_ARCH=cortex-m55 test_kernel_fully_connected_test
+make -j -f tensorflow/lite/micro/tools/make/Makefile TARGET=cortex_m_corstone_300 TARGET_ARCH=cortex-m55 test_kernel_fully_connected_test
+make -j -f tensorflow/lite/micro/tools/make/Makefile OPTIMIZED_KERNEL_DIR=cmsis_nn TARGET=cortex_m_corstone_300 TARGET_ARCH=cortex-m7+fp test_kernel_fully_connected_test
+make -j -f tensorflow/lite/micro/tools/make/Makefile TARGET=cortex_m_corstone_300 TARGET_ARCH=cortex-m3 test_kernel_fully_connected_test
+```
diff --git a/tensorflow/lite/micro/cortex_m_corstone_300/system_setup.cc b/tensorflow/lite/micro/cortex_m_corstone_300/system_setup.cc
new file mode 100644
index 0000000..a2438e8
--- /dev/null
+++ b/tensorflow/lite/micro/cortex_m_corstone_300/system_setup.cc
@@ -0,0 +1,26 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/micro/system_setup.h"
+
+namespace tflite {
+
+extern "C" {
+void uart_init(void);
+}
+
+void InitializeTarget() { uart_init(); }
+
+} // namespace tflite
diff --git a/tensorflow/lite/micro/examples/hello_world/BUILD b/tensorflow/lite/micro/examples/hello_world/BUILD
index 34106b4..bcb443a 100644
--- a/tensorflow/lite/micro/examples/hello_world/BUILD
+++ b/tensorflow/lite/micro/examples/hello_world/BUILD
@@ -1,10 +1,5 @@
# Description:
# TensorFlow Lite for Microcontrollers "hello world" example.
-
-load(
- "//tensorflow/lite/micro/testing:micro_test.bzl",
- "tflite_micro_cc_test",
-)
load(
"//tensorflow/lite/micro:build_def.bzl",
"micro_copts",
@@ -27,7 +22,7 @@
copts = micro_copts(),
)
-tflite_micro_cc_test(
+cc_test(
name = "hello_world_test",
srcs = [
"hello_world_test.cc",
@@ -87,6 +82,7 @@
"//tensorflow/lite/micro:micro_error_reporter",
"//tensorflow/lite/micro:micro_framework",
"//tensorflow/lite/micro:op_resolvers",
+ "//tensorflow/lite/micro:system_setup",
"//tensorflow/lite/schema:schema_fbs",
],
)
diff --git a/tensorflow/lite/micro/examples/hello_world/README.md b/tensorflow/lite/micro/examples/hello_world/README.md
index c23b355..7dcc891 100644
--- a/tensorflow/lite/micro/examples/hello_world/README.md
+++ b/tensorflow/lite/micro/examples/hello_world/README.md
@@ -45,7 +45,7 @@
command:
```
-make -f tensorflow/lite/micro/tools/make/Makefile TARGET=arc_emsdp TAGS=no_arc_mli generate_hello_world_make_project
+make -f tensorflow/lite/micro/tools/make/Makefile TARGET=arc_emsdp OPTIMIZED_KERNEL_DIR=arc_mli ARC_TAGS=no_arc_mli generate_hello_world_make_project
```
### Build and Run Example
@@ -245,7 +245,7 @@
Generate hello world project
```
-make -f tensorflow/lite/micro/tools/make/Makefile generate_hello_world_make_project TARGET=himax_we1_evb TAGS=no_arc_mli
+make -f tensorflow/lite/micro/tools/make/Makefile generate_hello_world_make_project TARGET=himax_we1_evb ARC_TAGS=no_arc_mli
```
### Build and Burn Example
@@ -454,7 +454,7 @@
- STM32F7 discovery kit board
- Mini-USB cable
-- ARM Mbed CLI ([installation instructions](https://os.mbed.com/docs/mbed-os/v5.12/tools/installation-and-setup.html))
+- ARM Mbed CLI ([installation instructions](https://os.mbed.com/docs/mbed-os/v5.12/tools/installation-and-setup.html). Check it out for MacOS Catalina - [mbed-cli is broken on MacOS Catalina #930](https://github.com/ARMmbed/mbed-cli/issues/930#issuecomment-660550734))
- Python 2.7 and pip
Since Mbed requires a special folder structure for projects, we'll first run a
diff --git a/tensorflow/lite/micro/examples/hello_world/main_functions.cc b/tensorflow/lite/micro/examples/hello_world/main_functions.cc
index 0c8541f..b8c630c 100644
--- a/tensorflow/lite/micro/examples/hello_world/main_functions.cc
+++ b/tensorflow/lite/micro/examples/hello_world/main_functions.cc
@@ -21,6 +21,7 @@
#include "tensorflow/lite/micro/examples/hello_world/output_handler.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
+#include "tensorflow/lite/micro/system_setup.h"
#include "tensorflow/lite/schema/schema_generated.h"
// Globals, used for compatibility with Arduino-style sketches.
@@ -38,6 +39,8 @@
// The name of this function is important for Arduino compatibility.
void setup() {
+ tflite::InitializeTarget();
+
// Set up logging. Google style is to avoid globals or statics because of
// lifetime uncertainty, but since this has a trivial destructor it's okay.
// NOLINTNEXTLINE(runtime-global-variables)
diff --git a/tensorflow/lite/micro/examples/image_recognition_experimental/BUILD b/tensorflow/lite/micro/examples/image_recognition_experimental/BUILD
index 2f707d9..69c37ee 100644
--- a/tensorflow/lite/micro/examples/image_recognition_experimental/BUILD
+++ b/tensorflow/lite/micro/examples/image_recognition_experimental/BUILD
@@ -1,11 +1,5 @@
# Description:
# TensorFlow Lite for Microcontrollers image recognition example.
-
-load(
- "//tensorflow/lite/micro/testing:micro_test.bzl",
- "tflite_micro_cc_test",
-)
-
package(
features = ["-layering_check"],
licenses = ["notice"], # Apache 2.0
@@ -27,7 +21,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "image_recognition_test",
srcs = ["image_recognition_test.cc"],
tags = [
diff --git a/tensorflow/lite/micro/examples/image_recognition_experimental/Makefile.inc b/tensorflow/lite/micro/examples/image_recognition_experimental/Makefile.inc
index 76b21cb..feb6ed4 100644
--- a/tensorflow/lite/micro/examples/image_recognition_experimental/Makefile.inc
+++ b/tensorflow/lite/micro/examples/image_recognition_experimental/Makefile.inc
@@ -33,7 +33,7 @@
endif
$(eval $(call microlite_test,image_recognition,\
-$(IMAGE_RECOGNITION_SRCS),$(IMAGE_RECOGNITION_HDRS)))
+$(IMAGE_RECOGNITION_SRCS),$(IMAGE_RECOGNITION_HDRS), exclude))
$(eval $(call microlite_test,image_recognition_test,\
$(IMAGE_RECOGNITION_TEST_SRCS),$(IMAGE_RECOGNITION_TEST_HDRS)))
diff --git a/tensorflow/lite/micro/examples/image_recognition_experimental/main.cc b/tensorflow/lite/micro/examples/image_recognition_experimental/main.cc
index 4249309..87d68ed 100644
--- a/tensorflow/lite/micro/examples/image_recognition_experimental/main.cc
+++ b/tensorflow/lite/micro/examples/image_recognition_experimental/main.cc
@@ -23,6 +23,7 @@
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
+#include "tensorflow/lite/micro/system_setup.h"
#include "tensorflow/lite/schema/schema_generated.h"
#define NUM_OUT_CH 3
@@ -34,6 +35,7 @@
"Dog", "Frog", "Horse", "Ship", "Truck"};
int main(int argc, char** argv) {
+ tflite::InitializeTarget();
init_lcd();
wait_ms(100);
diff --git a/tensorflow/lite/micro/examples/magic_wand/BUILD b/tensorflow/lite/micro/examples/magic_wand/BUILD
index 0f9c517..2223c6e 100644
--- a/tensorflow/lite/micro/examples/magic_wand/BUILD
+++ b/tensorflow/lite/micro/examples/magic_wand/BUILD
@@ -1,11 +1,5 @@
# Description:
# TensorFlow Lite for Microcontrollers "gesture recognition" example.
-
-load(
- "//tensorflow/lite/micro/testing:micro_test.bzl",
- "tflite_micro_cc_test",
-)
-
package(
default_visibility = ["//visibility:public"],
features = ["-layering_check"],
@@ -34,7 +28,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "magic_wand_test",
srcs = [
"magic_wand_test.cc",
@@ -71,7 +65,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "accelerometer_handler_test",
srcs = [
"accelerometer_handler_test.cc",
@@ -99,7 +93,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "gesture_predictor_test",
srcs = [
"gesture_predictor_test.cc",
@@ -126,7 +120,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "output_handler_test",
srcs = [
"output_handler_test.cc",
@@ -160,6 +154,7 @@
"//tensorflow/lite/micro:micro_error_reporter",
"//tensorflow/lite/micro:micro_framework",
"//tensorflow/lite/micro:op_resolvers",
+ "//tensorflow/lite/micro:system_setup",
"//tensorflow/lite/schema:schema_fbs",
],
)
diff --git a/tensorflow/lite/micro/examples/magic_wand/main_functions.cc b/tensorflow/lite/micro/examples/magic_wand/main_functions.cc
index 15abe0d..583cee8 100644
--- a/tensorflow/lite/micro/examples/magic_wand/main_functions.cc
+++ b/tensorflow/lite/micro/examples/magic_wand/main_functions.cc
@@ -23,6 +23,7 @@
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
+#include "tensorflow/lite/micro/system_setup.h"
#include "tensorflow/lite/schema/schema_generated.h"
// Globals, used for compatibility with Arduino-style sketches.
@@ -42,6 +43,8 @@
// The name of this function is important for Arduino compatibility.
void setup() {
+ tflite::InitializeTarget();
+
// Set up logging. Google style is to avoid globals or statics because of
// lifetime uncertainty, but since this has a trivial destructor it's okay.
static tflite::MicroErrorReporter micro_error_reporter; // NOLINT
diff --git a/tensorflow/lite/micro/examples/micro_speech/BUILD b/tensorflow/lite/micro/examples/micro_speech/BUILD
index e7acd65..cdd7516 100644
--- a/tensorflow/lite/micro/examples/micro_speech/BUILD
+++ b/tensorflow/lite/micro/examples/micro_speech/BUILD
@@ -1,11 +1,5 @@
# Description:
# TensorFlow Lite microcontroller example.
-
-load(
- "//tensorflow/lite/micro/testing:micro_test.bzl",
- "tflite_micro_cc_test",
-)
-
package(
default_visibility = ["//visibility:public"],
features = ["-layering_check"],
@@ -44,7 +38,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "micro_speech_test",
srcs = [
"micro_speech_test.cc",
@@ -111,7 +105,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "simple_features_generator_reference_test",
srcs = [
"simple_features/simple_features_generator_test.cc",
@@ -143,7 +137,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "simple_features_generator_fixed_test",
srcs = [
"simple_features/simple_features_generator_test.cc",
@@ -191,7 +185,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "audio_provider_test",
srcs = [
"audio_provider_test.cc",
@@ -206,7 +200,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "audio_provider_mock_test",
srcs = [
"audio_provider_mock_test.cc",
@@ -239,7 +233,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "feature_provider_test",
srcs = [
"feature_provider_test.cc",
@@ -272,11 +266,15 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "feature_provider_mock_test",
+ size = "small",
srcs = [
"feature_provider_mock_test.cc",
],
+ tags = [
+ "noasan", # TODO(b/179930607): Fix with asan.
+ ],
deps = [
":feature_provider_mock",
"//tensorflow/lite/c:common",
@@ -303,7 +301,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "recognize_commands_test",
srcs = [
"recognize_commands_test.cc",
@@ -335,7 +333,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "command_responder_test",
srcs = [
"command_responder_test.cc",
@@ -364,6 +362,7 @@
"//tensorflow/lite/micro:micro_error_reporter",
"//tensorflow/lite/micro:micro_framework",
"//tensorflow/lite/micro:op_resolvers",
+ "//tensorflow/lite/micro:system_setup",
"//tensorflow/lite/micro/examples/micro_speech/micro_features:micro_model_settings",
"//tensorflow/lite/micro/examples/micro_speech/micro_features:model",
"//tensorflow/lite/schema:schema_fbs",
@@ -385,6 +384,7 @@
"//tensorflow/lite/micro:micro_error_reporter",
"//tensorflow/lite/micro:micro_framework",
"//tensorflow/lite/micro:op_resolvers",
+ "//tensorflow/lite/micro:system_setup",
"//tensorflow/lite/micro/examples/micro_speech/micro_features:micro_model_settings",
"//tensorflow/lite/micro/examples/micro_speech/micro_features:model",
"//tensorflow/lite/schema:schema_fbs",
diff --git a/tensorflow/lite/micro/examples/micro_speech/README.md b/tensorflow/lite/micro/examples/micro_speech/README.md
index 5fd5ff9..25bd8c9 100644
--- a/tensorflow/lite/micro/examples/micro_speech/README.md
+++ b/tensorflow/lite/micro/examples/micro_speech/README.md
@@ -66,11 +66,12 @@
```
make -f tensorflow/lite/micro/tools/make/Makefile \
-TARGET=arc_emsdp TAGS=reduce_codesize \
+TARGET=arc_emsdp ARC_TAGS=reduce_codesize \
+OPTIMIZED_KERNEL_DIR=arc_mli \
generate_micro_speech_mock_make_project
```
-Note that `TAGS=reduce_codesize` applies example specific changes of code to
+Note that `ARC_TAGS=reduce_codesize` applies example specific changes of code to
reduce total size of application. It can be omitted.
### Build and Run Example
diff --git a/tensorflow/lite/micro/examples/micro_speech/arc_emsdp/Makefile.inc b/tensorflow/lite/micro/examples/micro_speech/arc_emsdp/Makefile.inc
index 74860da..d59adc2 100644
--- a/tensorflow/lite/micro/examples/micro_speech/arc_emsdp/Makefile.inc
+++ b/tensorflow/lite/micro/examples/micro_speech/arc_emsdp/Makefile.inc
@@ -4,7 +4,7 @@
# In particular:
# - Extend Heap and stack size for application needs
# - Use Linker command file with better usage of fast memory
-# - Optional (TAGS=reduce_codesize): In case project was
+# - Optional (ARC_TAGS=reduce_codesize): In case project was
# generated with MLI usage, reduce scratch buffers.
MICRO_SPEECH_HDRS += \
@@ -36,7 +36,7 @@
@echo Makefile: No Reference fallback for MLI supported functions >> $@
-ifneq ($(filter $(ALL_TAGS), reduce_codesize),)
+ifneq ($(filter $(ARC_TAGS), reduce_codesize),)
# In case 'reduce_codesize' tag is present, we replace common MLI functions with
# specializations appropriate for this particular graph. But such changes of code
# with high probability may not be acceptable for other graphs and will need
diff --git a/tensorflow/lite/micro/examples/micro_speech/main_functions.cc b/tensorflow/lite/micro/examples/micro_speech/main_functions.cc
index a0a858b..55b0d30 100644
--- a/tensorflow/lite/micro/examples/micro_speech/main_functions.cc
+++ b/tensorflow/lite/micro/examples/micro_speech/main_functions.cc
@@ -24,6 +24,7 @@
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
+#include "tensorflow/lite/micro/system_setup.h"
#include "tensorflow/lite/schema/schema_generated.h"
// Globals, used for compatibility with Arduino-style sketches.
@@ -47,6 +48,8 @@
// The name of this function is important for Arduino compatibility.
void setup() {
+ tflite::InitializeTarget();
+
// Set up logging. Google style is to avoid globals or statics because of
// lifetime uncertainty, but since this has a trivial destructor it's okay.
// NOLINTNEXTLINE(runtime-global-variables)
diff --git a/tensorflow/lite/micro/examples/micro_speech/micro_features/BUILD b/tensorflow/lite/micro/examples/micro_speech/micro_features/BUILD
index 3dff486..39dd4ca 100644
--- a/tensorflow/lite/micro/examples/micro_speech/micro_features/BUILD
+++ b/tensorflow/lite/micro/examples/micro_speech/micro_features/BUILD
@@ -1,10 +1,4 @@
# Library for generating feature vectors from audio data
-
-load(
- "//tensorflow/lite/micro/testing:micro_test.bzl",
- "tflite_micro_cc_test",
-)
-
package(
default_visibility = ["//visibility:public"],
features = ["-layering_check"],
@@ -76,11 +70,15 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "micro_features_generator_test",
+ size = "small",
srcs = [
"micro_features_generator_test.cc",
],
+ tags = [
+ "noasan", # TODO(b/179930607): Fix with asan.
+ ],
deps = [
":micro_features_generator",
":micro_features_generator_test_data",
diff --git a/tensorflow/lite/micro/examples/person_detection/BUILD b/tensorflow/lite/micro/examples/person_detection/BUILD
index cf69ef9..2c0800e 100644
--- a/tensorflow/lite/micro/examples/person_detection/BUILD
+++ b/tensorflow/lite/micro/examples/person_detection/BUILD
@@ -1,11 +1,5 @@
# Description:
# TensorFlow Lite for Microcontrollers Vision Example.
-
-load(
- "//tensorflow/lite/micro/testing:micro_test.bzl",
- "tflite_micro_cc_test",
-)
-
package(
default_visibility = ["//visibility:public"],
features = ["-layering_check"],
@@ -53,7 +47,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "person_detection_test",
srcs = ["person_detection_test.cc"],
tags = [
@@ -87,7 +81,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "image_provider_test",
srcs = [
"image_provider_test.cc",
@@ -115,7 +109,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "detection_responder_test",
srcs = [
"detection_responder_test.cc",
@@ -144,6 +138,7 @@
"//tensorflow/lite/micro:micro_error_reporter",
"//tensorflow/lite/micro:micro_framework",
"//tensorflow/lite/micro:op_resolvers",
+ "//tensorflow/lite/micro:system_setup",
"//tensorflow/lite/schema:schema_fbs",
],
)
diff --git a/tensorflow/lite/micro/examples/person_detection/README.md b/tensorflow/lite/micro/examples/person_detection/README.md
index 52e59ac..9877343 100644
--- a/tensorflow/lite/micro/examples/person_detection/README.md
+++ b/tensorflow/lite/micro/examples/person_detection/README.md
@@ -52,11 +52,12 @@
```
make -f tensorflow/lite/micro/tools/make/Makefile \
-TARGET=arc_emsdp TAGS=reduce_codesize \
+TARGET=arc_emsdp ARC_TAGS=reduce_codesize \
+OPTIMIZED_KERNEL_DIR=arc_mli \
generate_person_detection_int8_make_project
```
-Note that `TAGS=reduce_codesize` applies example specific changes of code to
+Note that `ARC_TAGS=reduce_codesize` applies example specific changes of code to
reduce total size of application. It can be omitted.
### Build and Run Example
diff --git a/tensorflow/lite/micro/examples/person_detection/arc_emsdp/Makefile.inc b/tensorflow/lite/micro/examples/person_detection/arc_emsdp/Makefile.inc
index 1555f78..85a0846 100644
--- a/tensorflow/lite/micro/examples/person_detection/arc_emsdp/Makefile.inc
+++ b/tensorflow/lite/micro/examples/person_detection/arc_emsdp/Makefile.inc
@@ -25,7 +25,7 @@
@sed -E -i 's#MLI_ONLY *\?= *false#MLI_ONLY \?= true#' $(word 2, $^)
@echo Makefile: No Reference fallback for MLI supported functions >> $@
-ifneq ($(filter $(ALL_TAGS), reduce_codesize),)
+ifneq ($(filter $(ARC_TAGS), reduce_codesize),)
#In case 'reduce_codesize' tag is present, we replace common MLI functions with
#specializations appropriate for this particular graph.But such changes of code
#with high probability may not be acceptable for other graphs and will need
diff --git a/tensorflow/lite/micro/examples/person_detection/main_functions.cc b/tensorflow/lite/micro/examples/person_detection/main_functions.cc
index b97d4e1..7e6e40d 100644
--- a/tensorflow/lite/micro/examples/person_detection/main_functions.cc
+++ b/tensorflow/lite/micro/examples/person_detection/main_functions.cc
@@ -22,6 +22,7 @@
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
+#include "tensorflow/lite/micro/system_setup.h"
#include "tensorflow/lite/schema/schema_generated.h"
// Globals, used for compatibility with Arduino-style sketches.
@@ -45,6 +46,8 @@
// The name of this function is important for Arduino compatibility.
void setup() {
+ tflite::InitializeTarget();
+
// Set up logging. Google style is to avoid globals or statics because of
// lifetime uncertainty, but since this has a trivial destructor it's okay.
// NOLINTNEXTLINE(runtime-global-variables)
diff --git a/tensorflow/lite/micro/kernels/BUILD b/tensorflow/lite/micro/kernels/BUILD
index f62ea11..2a0a2d9 100644
--- a/tensorflow/lite/micro/kernels/BUILD
+++ b/tensorflow/lite/micro/kernels/BUILD
@@ -1,8 +1,4 @@
load(
- "//tensorflow/lite/micro/testing:micro_test.bzl",
- "tflite_micro_cc_test",
-)
-load(
"//tensorflow/lite/micro:build_def.bzl",
"micro_copts",
)
@@ -27,6 +23,10 @@
packages = ["//tensorflow/lite/micro"],
)
+####################################
+# C++ libraries
+####################################
+
cc_library(
name = "activation_utils",
hdrs = ["activation_utils.h"],
@@ -37,6 +37,43 @@
)
cc_library(
+ name = "conv",
+ srcs = [
+ "conv_common.cc",
+ ] + select({
+ "//conditions:default": [
+ "conv.cc",
+ ],
+ ":xtensa_hifimini": [
+ "xtensa/conv.cc",
+ ],
+ }),
+ hdrs = ["conv.h"],
+ copts = micro_copts(),
+ visibility = [
+ # Kernel variants need to be visible to the examples and benchmarks.
+ ":micro",
+ ],
+ deps = [
+ ":fixedpoint_utils",
+ ":kernel_util",
+ ":xtensa",
+ "//tensorflow/lite/c:common",
+ "//tensorflow/lite/kernels/internal:common",
+ "//tensorflow/lite/kernels/internal:quantization_util",
+ "//tensorflow/lite/kernels/internal:reference_base",
+ "//tensorflow/lite/kernels/internal:tensor",
+ "//tensorflow/lite/kernels:kernel_util",
+ "//tensorflow/lite/kernels:padding",
+ ] + select({
+ "//conditions:default": [],
+ ":xtensa_hifimini": [
+ #"//third_party/xtensa/cstub64s:hifi_mini",
+ ],
+ }),
+)
+
+cc_library(
name = "conv_test_common",
srcs = [
"conv_test_common.cc",
@@ -211,14 +248,12 @@
"zeros_like.cc",
] + select({
"//conditions:default": [
- "conv.cc",
"depthwise_conv.cc",
"quantize.cc",
"softmax.cc",
"svdf.cc",
],
":xtensa_hifimini": [
- "xtensa/conv.cc",
"xtensa/depthwise_conv.cc",
"xtensa/quantize.cc",
"xtensa/softmax.cc",
@@ -287,23 +322,11 @@
}),
)
+####################################
+# C++ tests
+####################################
+
cc_test(
- name = "shape_test",
- srcs = ["shape_test.cc"],
- deps = [
- ":kernel_runner",
- "//tensorflow/lite/c:common",
- "//tensorflow/lite/micro:op_resolvers",
- "//tensorflow/lite/micro:test_helpers",
- "//tensorflow/lite/micro/testing:micro_test",
- ],
-)
-
-test_suite(
- name = "all_tests",
-)
-
-tflite_micro_cc_test(
name = "activations_test",
srcs = [
"activations_test.cc",
@@ -317,7 +340,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "add_test",
srcs = [
"add_test.cc",
@@ -331,7 +354,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "arg_min_max_test",
srcs = [
"arg_min_max_test.cc",
@@ -345,7 +368,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "batch_to_space_nd_test",
srcs = [
"batch_to_space_nd_test.cc",
@@ -359,7 +382,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "cast_test",
srcs = ["cast_test.cc"],
deps = [
@@ -372,7 +395,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "ceil_test",
srcs = [
"ceil_test.cc",
@@ -386,7 +409,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "circular_buffer_test",
srcs = [
"circular_buffer_test.cc",
@@ -401,7 +424,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "comparisons_test",
srcs = [
"comparisons_test.cc",
@@ -414,7 +437,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "concatenation_test",
srcs = [
"concatenation_test.cc",
@@ -427,7 +450,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "conv_test",
srcs = [
"conv_test.cc",
@@ -442,7 +465,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "depthwise_conv_test",
srcs = [
"depthwise_conv_test.cc",
@@ -456,7 +479,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "dequantize_test",
srcs = [
"dequantize_test.cc",
@@ -469,7 +492,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "detection_postprocess_test",
srcs = [
"detection_postprocess_test.cc",
@@ -485,7 +508,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "elementwise_test",
srcs = ["elementwise_test.cc"],
deps = [
@@ -498,7 +521,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "exp_test",
srcs = ["exp_test.cc"],
deps = [
@@ -511,7 +534,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "floor_test",
srcs = [
"floor_test.cc",
@@ -525,7 +548,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "fully_connected_test",
srcs = [
"fully_connected_test.cc",
@@ -540,7 +563,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "hard_swish_test",
srcs = ["hard_swish_test.cc"],
deps = [
@@ -552,7 +575,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "l2norm_test",
srcs = [
"l2norm_test.cc",
@@ -566,7 +589,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "logical_test",
srcs = [
"logical_test.cc",
@@ -580,7 +603,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "logistic_test",
srcs = [
"logistic_test.cc",
@@ -594,7 +617,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "maximum_minimum_test",
srcs = [
"maximum_minimum_test.cc",
@@ -608,7 +631,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "mul_test",
srcs = [
"mul_test.cc",
@@ -621,7 +644,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "neg_test",
srcs = [
"neg_test.cc",
@@ -635,7 +658,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "pack_test",
srcs = [
"pack_test.cc",
@@ -649,7 +672,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "pad_test",
srcs = [
"pad_test.cc",
@@ -667,7 +690,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "pooling_test",
srcs = [
"pooling_test.cc",
@@ -680,7 +703,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "prelu_test",
srcs = [
"prelu_test.cc",
@@ -693,7 +716,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "quantization_util_test",
srcs = [
"quantization_util_test.cc",
@@ -705,7 +728,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "quantize_test",
srcs = [
"quantize_test.cc",
@@ -718,7 +741,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "reduce_test",
srcs = [
"reduce_test.cc",
@@ -732,7 +755,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "reshape_test",
srcs = [
"reshape_test.cc",
@@ -747,7 +770,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "resize_nearest_neighbor_test",
srcs = [
"resize_nearest_neighbor_test.cc",
@@ -761,7 +784,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "round_test",
srcs = [
"round_test.cc",
@@ -775,7 +798,19 @@
],
)
-tflite_micro_cc_test(
+cc_test(
+ name = "shape_test",
+ srcs = ["shape_test.cc"],
+ deps = [
+ ":kernel_runner",
+ "//tensorflow/lite/c:common",
+ "//tensorflow/lite/micro:op_resolvers",
+ "//tensorflow/lite/micro:test_helpers",
+ "//tensorflow/lite/micro/testing:micro_test",
+ ],
+)
+
+cc_test(
name = "softmax_test",
srcs = [
"softmax_test.cc",
@@ -789,7 +824,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "space_to_batch_nd_test",
srcs = [
"space_to_batch_nd_test.cc",
@@ -804,7 +839,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "split_test",
srcs = [
"split_test.cc",
@@ -819,7 +854,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "split_v_test",
srcs = [
"split_v_test.cc",
@@ -834,7 +869,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "strided_slice_test",
srcs = [
"strided_slice_test.cc",
@@ -848,7 +883,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "sub_test",
srcs = [
"sub_test.cc",
@@ -861,7 +896,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "svdf_test",
srcs = [
"svdf_test.cc",
@@ -874,7 +909,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "tanh_test",
srcs = ["tanh_test.cc"],
deps = [
@@ -885,7 +920,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "transpose_conv_test",
srcs = [
"transpose_conv_test.cc",
@@ -895,12 +930,13 @@
":kernel_runner",
"//tensorflow/lite/c:common",
"//tensorflow/lite/micro:micro_utils",
+ "//tensorflow/lite/micro:op_resolvers",
"//tensorflow/lite/micro:test_helpers",
"//tensorflow/lite/micro/testing:micro_test",
],
)
-tflite_micro_cc_test(
+cc_test(
name = "unpack_test",
srcs = [
"unpack_test.cc",
@@ -914,7 +950,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "zeros_like_test",
srcs = ["zeros_like_test.cc"],
deps = [
diff --git a/tensorflow/lite/micro/kernels/add_n.cc b/tensorflow/lite/micro/kernels/add_n.cc
new file mode 100644
index 0000000..390d285
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/add_n.cc
@@ -0,0 +1,101 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <stdint.h>
+
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/lite/kernels/internal/tensor.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace add_n {
+
+constexpr int kInputTensor1 = 0;
+constexpr int kOutputTensor = 0;
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ int num_inputs = NumInputs(node);
+ TF_LITE_ENSURE(context, num_inputs >= 2);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ const TfLiteTensor* input1;
+ TF_LITE_ENSURE_OK(context,
+ GetInputSafe(context, node, kInputTensor1, &input1));
+ TfLiteTensor* output;
+ TF_LITE_ENSURE_OK(context,
+ GetOutputSafe(context, node, kOutputTensor, &output));
+ output->type = input1->type;
+
+ // Check that all input tensors have the same shape and type.
+ for (int i = kInputTensor1 + 1; i < num_inputs; ++i) {
+ const TfLiteTensor* input;
+ TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &input));
+ TF_LITE_ENSURE(context, HaveSameShapes(input1, input));
+ TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input->type);
+ }
+
+ // Use the first input node's dimension to be the dimension of the output
+ // node.
+ TfLiteIntArray* input1_dims = input1->dims;
+ TfLiteIntArray* output_dims = TfLiteIntArrayCopy(input1_dims);
+ return context->ResizeTensor(context, output, output_dims);
+}
+
+template <typename T>
+void EvalAddN(TfLiteContext* context, TfLiteNode* node) {
+ // OLD-TODO(haoliang): Initialize all_inputs only once during init.
+ VectorOfTensors<T> all_inputs(*context, *node->inputs);
+ // Safe to use unchecked since caller checks that tensor is valid
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ int num_inputs = NumInputs(node);
+ // Safe to use unchecked since caller checks that tensor is valid
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ reference_ops::AddN<T>(GetTensorShape(input1), num_inputs, all_inputs.data(),
+ GetTensorData<T>(output));
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* input1;
+ TF_LITE_ENSURE_OK(context,
+ GetInputSafe(context, node, kInputTensor1, &input1));
+ TfLiteTensor* output;
+ TF_LITE_ENSURE_OK(context,
+ GetOutputSafe(context, node, kOutputTensor, &output));
+ if (output->type == kTfLiteFloat32) {
+ EvalAddN<float>(context, node);
+ } else if (output->type == kTfLiteInt32) {
+ EvalAddN<int32_t>(context, node);
+ } else {
+ TF_LITE_KERNEL_LOG(context, "AddN only supports FLOAT32|INT32 now, got %s.",
+ TfLiteTypeGetName(output->type));
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+} // namespace add_n
+
+TfLiteRegistration* Register_ADD_N() {
+ static TfLiteRegistration r = {/*init*/ nullptr, /*free*/ nullptr,
+ add_n::Prepare, add_n::Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/add_n_test.cc b/tensorflow/lite/micro/kernels/add_n_test.cc
new file mode 100644
index 0000000..4b229b65
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/add_n_test.cc
@@ -0,0 +1,92 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <stdint.h>
+
+#include <vector>
+
+#include "flatbuffers/flatbuffers.h" // from @flatbuffers
+#include "tensorflow/lite/kernels/test_util.h"
+#include "tensorflow/lite/schema/schema_generated.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class BaseAddNOpModel : public SingleOpModel {
+ public:
+ BaseAddNOpModel(const std::vector<TensorData>& inputs,
+ const TensorData& output) {
+ int num_inputs = inputs.size();
+ std::vector<std::vector<int>> input_shapes;
+
+ for (int i = 0; i < num_inputs; ++i) {
+ inputs_.push_back(AddInput(inputs[i]));
+ input_shapes.push_back(GetShape(inputs_[i]));
+ }
+
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_ADD_N, BuiltinOptions_AddNOptions,
+ CreateAddNOptions(builder_).Union());
+ BuildInterpreter(input_shapes);
+ }
+
+ int input(int i) { return inputs_[i]; }
+
+ protected:
+ std::vector<int> inputs_;
+ int output_;
+};
+
+class FloatAddNOpModel : public BaseAddNOpModel {
+ public:
+ using BaseAddNOpModel::BaseAddNOpModel;
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+};
+
+class IntegerAddNOpModel : public BaseAddNOpModel {
+ public:
+ using BaseAddNOpModel::BaseAddNOpModel;
+
+ std::vector<int32_t> GetOutput() { return ExtractVector<int32_t>(output_); }
+};
+
+TEST(FloatAddNOpModel, AddMultipleTensors) {
+ FloatAddNOpModel m({{TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {1, 2, 2, 1}}},
+ {TensorType_FLOAT32, {}});
+ m.PopulateTensor<float>(m.input(0), {-2.0, 0.2, 0.7, 0.8});
+ m.PopulateTensor<float>(m.input(1), {0.1, 0.2, 0.3, 0.5});
+ m.PopulateTensor<float>(m.input(2), {0.5, 0.1, 0.1, 0.2});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.4, 0.5, 1.1, 1.5}));
+}
+
+TEST(IntegerAddNOpModel, AddMultipleTensors) {
+ IntegerAddNOpModel m({{TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1, 2, 2, 1}}},
+ {TensorType_INT32, {}});
+ m.PopulateTensor<int32_t>(m.input(0), {-20, 2, 7, 8});
+ m.PopulateTensor<int32_t>(m.input(1), {1, 2, 3, 5});
+ m.PopulateTensor<int32_t>(m.input(2), {10, -5, 1, -2});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({-9, -1, 11, 11}));
+}
+
+} // namespace
+} // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/arc_mli/README.md b/tensorflow/lite/micro/kernels/arc_mli/README.md
index 1995eda..3d6ddc0 100644
--- a/tensorflow/lite/micro/kernels/arc_mli/README.md
+++ b/tensorflow/lite/micro/kernels/arc_mli/README.md
@@ -21,16 +21,16 @@
For example:
```
-make -f tensorflow/lite/micro/tools/make/Makefile TARGET=arc_emsdp generate_person_detection_int8_make_project
+make -f tensorflow/lite/micro/tools/make/Makefile TARGET=arc_emsdp OPTIMIZED_KERNEL_DIR=arc_mli generate_person_detection_int8_make_project
```
In case MLI implementation can’t be used, kernels in this folder fallback to
TFLM reference implementations. For applications which may not benefit from MLI
library, projects can be generated without these implementations by adding
-`TAGS=no_arc_mli` in the command line, which can reduce overall code size:
+`ARC_TAGS=no_arc_mli` in the command line, which can reduce overall code size:
```
-make -f tensorflow/lite/micro/tools/make/Makefile TARGET=arc_emsdp TAGS=no_arc_mli generate_person_detection_int8_make_project
+make -f tensorflow/lite/micro/tools/make/Makefile TARGET=arc_emsdp OPTIMIZED_KERNEL_DIR=arc_mli ARC_TAGS=no_arc_mli generate_person_detection_int8_make_project
```
For ARC EM SDP board, a pre-compiled MLI library is downloaded and used in the
@@ -39,7 +39,7 @@
ARC EM SDP platform, add `BUILD_ARC_MLI=true` option to make command:
```
-make -f tensorflow/lite/micro/tools/make/Makefile TARGET=arc_emsdp BUILD_ARC_MLI=true generate_person_detection_int8_make_project
+make -f tensorflow/lite/micro/tools/make/Makefile TARGET=arc_emsdp OPTIMIZED_KERNEL_DIR=arc_mli BUILD_ARC_MLI=true generate_person_detection_int8_make_project
```
If an application exclusively uses accelerated MLI kernel implementations, one
diff --git a/tensorflow/lite/micro/kernels/cmsis_nn/conv.cc b/tensorflow/lite/micro/kernels/cmsis_nn/conv.cc
index 3691cd9..e82fee6 100644
--- a/tensorflow/lite/micro/kernels/cmsis_nn/conv.cc
+++ b/tensorflow/lite/micro/kernels/cmsis_nn/conv.cc
@@ -13,7 +13,7 @@
limitations under the License.
==============================================================================*/
-#include "tensorflow/lite/kernels/internal/reference/conv.h"
+#include "tensorflow/lite/micro/kernels/conv.h"
#include "CMSIS/NN/Include/arm_nn_types.h"
#include "CMSIS/NN/Include/arm_nnfunctions.h"
@@ -21,6 +21,7 @@
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/conv.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
@@ -30,90 +31,13 @@
namespace tflite {
namespace {
-constexpr int kInputTensor = 0;
-constexpr int kFilterTensor = 1;
-constexpr int kBiasTensor = 2;
-constexpr int kOutputTensor = 0;
-
-// Conv is quantized along dimension 0:
-// https://www.tensorflow.org/lite/performance/quantization_spec
-constexpr int kConvQuantizedDimension = 0;
-
struct OpData {
- TfLitePaddingValues padding;
-
- // Cached tensor zero point values for quantized operations.
- int32_t input_zero_point;
- int32_t filter_zero_point;
- int32_t output_zero_point;
-
- // The scaling factor from input to output (aka the 'real multiplier') can
- // be represented as a fixed point multiplier plus a left shift.
- int32_t output_multiplier;
- int output_shift;
-
- // Per channel output multiplier and shift.
- int32_t* per_channel_output_multiplier;
- int32_t* per_channel_output_shift;
-
- // The range of the fused activation layer. For example for kNone and
- // uint8_t these would be 0 and 255.
- int32_t output_activation_min;
- int32_t output_activation_max;
+ OpDataConv reference_op_data;
// Index to buffer for optimizations if applicable.
int buffer_idx;
};
-inline PaddingType RuntimePaddingType(TfLitePadding padding) {
- switch (padding) {
- case TfLitePadding::kTfLitePaddingSame:
- return PaddingType::kSame;
- case TfLitePadding::kTfLitePaddingValid:
- return PaddingType::kValid;
- case TfLitePadding::kTfLitePaddingUnknown:
- default:
- return PaddingType::kNone;
- }
-}
-
-TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
- const TfLiteConvParams* params, int width,
- int height, int filter_width, int filter_height,
- int out_width, int out_height,
- const TfLiteType data_type, OpData* data) {
- bool has_bias = node->inputs->size == 3;
- // Check number of inputs/outputs
- TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
-
- // Matching GetWindowedOutputSize in TensorFlow.
- auto padding = params->padding;
- data->padding = ComputePaddingHeightWidth(
- params->stride_height, params->stride_width,
- params->dilation_height_factor, params->dilation_width_factor, height,
- width, filter_height, filter_width, padding, &out_height, &out_width);
-
- // Note that quantized inference requires that all tensors have their
- // parameters set. This is usually done during quantized training.
- if (data_type != kTfLiteFloat32) {
- const TfLiteTensor* input = GetInput(context, node, kInputTensor);
- const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
- const TfLiteTensor* bias =
- GetOptionalInputTensor(context, node, kBiasTensor);
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- int num_channels = filter->dims->data[kConvQuantizedDimension];
-
- TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams(
- context, input, filter, bias, output, params->activation,
- &data->output_multiplier, &data->output_shift,
- &data->output_activation_min, &data->output_activation_max,
- data->per_channel_output_multiplier,
- reinterpret_cast<int*>(data->per_channel_output_shift), num_channels));
- }
- return kTfLiteOk;
-}
-
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
return context->AllocatePersistentBuffer(context, sizeof(OpData));
@@ -124,12 +48,13 @@
TFLITE_DCHECK(node->builtin_data != nullptr);
int32_t buf_size = 0;
- const auto params = static_cast<const TfLiteConvParams*>(node->builtin_data);
+ const auto& params =
+ *(static_cast<const TfLiteConvParams*>(node->builtin_data));
OpData* data = static_cast<OpData*>(node->user_data);
- const TfLiteTensor* input = GetInput(context, node, kInputTensor);
- const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
- const TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ const TfLiteTensor* input = GetInput(context, node, kConvInputTensor);
+ const TfLiteTensor* filter = GetInput(context, node, kConvWeightsTensor);
+ const TfLiteTensor* output = GetOutput(context, node, kConvOutputTensor);
RuntimeShape input_shape = GetTensorShape(input);
RuntimeShape output_shape = GetTensorShape(output);
@@ -161,34 +86,31 @@
// non-int8 cases. Protect this section with a if (input->type == kTfLiteInt8)
// when the issue is fixed.
const int num_channels = filter->dims->data[kConvQuantizedDimension];
- data->per_channel_output_multiplier =
+ data->reference_op_data.per_channel_output_multiplier =
static_cast<int32_t*>(context->AllocatePersistentBuffer(
context, num_channels * sizeof(int32_t)));
- data->per_channel_output_shift =
+ data->reference_op_data.per_channel_output_shift =
static_cast<int32_t*>(context->AllocatePersistentBuffer(
context, num_channels * sizeof(int32_t)));
- TF_LITE_ENSURE_STATUS(CalculateOpData(
+ TF_LITE_ENSURE_STATUS(CalculateOpDataConv(
context, node, params, input_dims.w, input_dims.h, filter_dims.w,
- filter_dims.h, output_dims.w, output_dims.h, input->type, data));
-
- data->input_zero_point = input->params.zero_point;
- data->filter_zero_point = filter->params.zero_point;
- data->output_zero_point = output->params.zero_point;
+ filter_dims.h, output_dims.w, output_dims.h, input->type,
+ &data->reference_op_data));
if (input->type == kTfLiteInt8) {
// Initialize cmsis_nn convolution parameters
cmsis_nn_conv_params conv_params;
conv_params.input_offset = -input->params.zero_point;
conv_params.output_offset = output->params.zero_point;
- conv_params.stride.h = params->stride_height;
- conv_params.stride.w = params->stride_width;
- conv_params.dilation.h = params->dilation_height_factor;
- conv_params.dilation.w = params->dilation_width_factor;
- conv_params.padding.h = data->padding.height;
- conv_params.padding.w = data->padding.width;
- conv_params.activation.min = data->output_activation_min;
- conv_params.activation.max = data->output_activation_max;
+ conv_params.stride.h = params.stride_height;
+ conv_params.stride.w = params.stride_width;
+ conv_params.dilation.h = params.dilation_height_factor;
+ conv_params.dilation.w = params.dilation_width_factor;
+ conv_params.padding.h = data->reference_op_data.padding.height;
+ conv_params.padding.w = data->reference_op_data.padding.width;
+ conv_params.activation.min = data->reference_op_data.output_activation_min;
+ conv_params.activation.max = data->reference_op_data.output_activation_max;
buf_size = arm_convolve_wrapper_s8_get_buffer_size(
&conv_params, &input_dims, &filter_dims, &output_dims);
@@ -203,73 +125,34 @@
return kTfLiteOk;
}
-TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
- TfLiteConvParams* params, const OpData& data,
- const TfLiteEvalTensor* input,
- const TfLiteEvalTensor* filter,
- const TfLiteEvalTensor* bias,
- TfLiteEvalTensor* im2col,
- TfLiteEvalTensor* hwcn_weights,
- TfLiteEvalTensor* output) {
- const int32_t input_offset = -data.input_zero_point;
- const int32_t filter_offset = -data.filter_zero_point;
- const int32_t output_offset = data.output_zero_point;
-
- ConvParams op_params;
- op_params.padding_type = RuntimePaddingType(params->padding);
- op_params.padding_values.width = data.padding.width;
- op_params.padding_values.height = data.padding.height;
- op_params.stride_width = params->stride_width;
- op_params.stride_height = params->stride_height;
- op_params.dilation_width_factor = params->dilation_width_factor;
- op_params.dilation_height_factor = params->dilation_height_factor;
- op_params.input_offset = input_offset;
- op_params.weights_offset = filter_offset;
- op_params.output_offset = output_offset;
- op_params.output_multiplier = data.output_multiplier;
- op_params.output_shift = -data.output_shift;
- op_params.quantized_activation_min = data.output_activation_min;
- op_params.quantized_activation_max = data.output_activation_max;
- reference_ops::Conv(op_params, tflite::micro::GetTensorShape(input),
- tflite::micro::GetTensorData<uint8_t>(input),
- tflite::micro::GetTensorShape(filter),
- tflite::micro::GetTensorData<uint8_t>(filter),
- tflite::micro::GetTensorShape(bias),
- tflite::micro::GetTensorData<int32_t>(bias),
- tflite::micro::GetTensorShape(output),
- tflite::micro::GetTensorData<uint8_t>(output),
- tflite::micro::GetTensorShape(im2col),
- tflite::micro::GetTensorData<uint8_t>(im2col), nullptr);
- return kTfLiteOk;
-}
-
TfLiteStatus EvalQuantizedPerChannel(
- TfLiteContext* context, TfLiteNode* node, TfLiteConvParams* params,
+ TfLiteContext* context, TfLiteNode* node, const TfLiteConvParams& params,
const OpData& data, const TfLiteEvalTensor* input,
const TfLiteEvalTensor* filter, const TfLiteEvalTensor* bias,
TfLiteEvalTensor* output, TfLiteEvalTensor* im2col) {
cmsis_nn_conv_params conv_params;
- conv_params.dilation.h = params->dilation_height_factor;
- conv_params.dilation.w = params->dilation_width_factor;
+ conv_params.dilation.h = params.dilation_height_factor;
+ conv_params.dilation.w = params.dilation_width_factor;
// TODO(#43557) Remove checks for dilation and call to reference
// implementation when dilation is supported in the optimized implementation
// by CMSIS-NN.
if (conv_params.dilation.h == 1 && conv_params.dilation.w == 1) {
// Initialize cmsis_nn convolution parameters
- conv_params.input_offset = -data.input_zero_point;
- conv_params.output_offset = data.output_zero_point;
- conv_params.stride.h = params->stride_height;
- conv_params.stride.w = params->stride_width;
- conv_params.padding.h = data.padding.height;
- conv_params.padding.w = data.padding.width;
- conv_params.activation.min = data.output_activation_min;
- conv_params.activation.max = data.output_activation_max;
+ conv_params.input_offset = -data.reference_op_data.input_zero_point;
+ conv_params.output_offset = data.reference_op_data.output_zero_point;
+ conv_params.stride.h = params.stride_height;
+ conv_params.stride.w = params.stride_width;
+ conv_params.padding.h = data.reference_op_data.padding.height;
+ conv_params.padding.w = data.reference_op_data.padding.width;
+ conv_params.activation.min = data.reference_op_data.output_activation_min;
+ conv_params.activation.max = data.reference_op_data.output_activation_max;
// Initialize cmsis_nn per channel quantization parameters
cmsis_nn_per_channel_quant_params quant_params;
- quant_params.multiplier =
- const_cast<int32_t*>(data.per_channel_output_multiplier);
- quant_params.shift = const_cast<int32_t*>(data.per_channel_output_shift);
+ quant_params.multiplier = const_cast<int32_t*>(
+ data.reference_op_data.per_channel_output_multiplier);
+ quant_params.shift =
+ const_cast<int32_t*>(data.reference_op_data.per_channel_output_shift);
RuntimeShape filter_shape = tflite::micro::GetTensorShape(filter);
RuntimeShape input_shape = tflite::micro::GetTensorShape(input);
@@ -340,22 +223,11 @@
tflite::micro::GetTensorData<int8_t>(output)),
ARM_MATH_SUCCESS);
} else {
- // TODO(b/154032858): Investigate removing extra copies.
- ConvParams op_params;
- op_params.input_offset = -data.input_zero_point;
- op_params.output_offset = data.output_zero_point;
- op_params.stride_height = params->stride_height;
- op_params.stride_width = params->stride_width;
- op_params.dilation_height_factor = params->dilation_height_factor;
- op_params.dilation_width_factor = params->dilation_width_factor;
- op_params.padding_values.height = data.padding.height;
- op_params.padding_values.width = data.padding.width;
- op_params.quantized_activation_min = data.output_activation_min;
- op_params.quantized_activation_max = data.output_activation_max;
-
reference_integer_ops::ConvPerChannel(
- op_params, data.per_channel_output_multiplier,
- data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
+ ConvParamsQuantized(params, data.reference_op_data),
+ data.reference_op_data.per_channel_output_multiplier,
+ data.reference_op_data.per_channel_output_shift,
+ tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<int8_t>(filter),
@@ -367,54 +239,20 @@
return kTfLiteOk;
}
-TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
- TfLiteConvParams* params, const OpData& data,
- const TfLiteEvalTensor* input,
- const TfLiteEvalTensor* filter,
- const TfLiteEvalTensor* bias, TfLiteEvalTensor* im2col,
- TfLiteEvalTensor* hwcn_weights,
- TfLiteEvalTensor* output) {
- float output_activation_min, output_activation_max;
- CalculateActivationRange(params->activation, &output_activation_min,
- &output_activation_max);
- // TODO(b/154032858): Investigate removing extra copies.
- ConvParams op_params;
- op_params.padding_type = RuntimePaddingType(params->padding);
- op_params.padding_values.width = data.padding.width;
- op_params.padding_values.height = data.padding.height;
- op_params.stride_width = params->stride_width;
- op_params.stride_height = params->stride_height;
- op_params.dilation_width_factor = params->dilation_width_factor;
- op_params.dilation_height_factor = params->dilation_height_factor;
- op_params.float_activation_min = output_activation_min;
- op_params.float_activation_max = output_activation_max;
-
- reference_ops::Conv(op_params, tflite::micro::GetTensorShape(input),
- tflite::micro::GetTensorData<float>(input),
- tflite::micro::GetTensorShape(filter),
- tflite::micro::GetTensorData<float>(filter),
- tflite::micro::GetTensorShape(bias),
- tflite::micro::GetTensorData<float>(bias),
- tflite::micro::GetTensorShape(output),
- tflite::micro::GetTensorData<float>(output),
- tflite::micro::GetTensorShape(im2col),
- tflite::micro::GetTensorData<float>(im2col));
- return kTfLiteOk;
-}
-
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
+ const auto& params =
+ *(reinterpret_cast<TfLiteConvParams*>(node->builtin_data));
const TfLiteEvalTensor* input =
- tflite::micro::GetEvalInput(context, node, kInputTensor);
+ tflite::micro::GetEvalInput(context, node, kConvInputTensor);
const TfLiteEvalTensor* filter =
- tflite::micro::GetEvalInput(context, node, kFilterTensor);
+ tflite::micro::GetEvalInput(context, node, kConvWeightsTensor);
const TfLiteEvalTensor* bias =
(NumInputs(node) == 3)
- ? tflite::micro::GetEvalInput(context, node, kBiasTensor)
+ ? tflite::micro::GetEvalInput(context, node, kConvBiasTensor)
: nullptr;
TfLiteEvalTensor* output =
- tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+ tflite::micro::GetEvalOutput(context, node, kConvOutputTensor);
TFLITE_DCHECK(node->user_data != nullptr);
const OpData& data = *(static_cast<const OpData*>(node->user_data));
@@ -424,18 +262,38 @@
"Hybrid models are not supported on TFLite Micro.");
switch (input->type) { // Already know in/out types are same.
- case kTfLiteFloat32:
- EvalFloat(context, node, params, data, input, filter, bias, nullptr,
- nullptr, output);
+ case kTfLiteFloat32: {
+ tflite::reference_ops::Conv(
+ ConvParamsFloat(params, data.reference_op_data),
+ tflite::micro::GetTensorShape(input),
+ tflite::micro::GetTensorData<float>(input),
+ tflite::micro::GetTensorShape(filter),
+ tflite::micro::GetTensorData<float>(filter),
+ tflite::micro::GetTensorShape(bias),
+ tflite::micro::GetTensorData<float>(bias),
+ tflite::micro::GetTensorShape(output),
+ tflite::micro::GetTensorData<float>(output),
+ tflite::micro::GetTensorShape(nullptr), nullptr);
break;
+ }
case kTfLiteInt8:
return EvalQuantizedPerChannel(context, node, params, data, input, filter,
bias, output, nullptr);
break;
- case kTfLiteUInt8:
- return EvalQuantized(context, node, params, data, input, filter, bias,
- nullptr, nullptr, output);
+ case kTfLiteUInt8: {
+ reference_ops::Conv(ConvParamsQuantized(params, data.reference_op_data),
+ tflite::micro::GetTensorShape(input),
+ tflite::micro::GetTensorData<uint8_t>(input),
+ tflite::micro::GetTensorShape(filter),
+ tflite::micro::GetTensorData<uint8_t>(filter),
+ tflite::micro::GetTensorShape(bias),
+ tflite::micro::GetTensorData<int32_t>(bias),
+ tflite::micro::GetTensorShape(output),
+ tflite::micro::GetTensorData<uint8_t>(output),
+ tflite::micro::GetTensorShape(nullptr), nullptr,
+ nullptr);
break;
+ }
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
diff --git a/tensorflow/lite/micro/kernels/conv.cc b/tensorflow/lite/micro/kernels/conv.cc
index dc821df..4530f94 100644
--- a/tensorflow/lite/micro/kernels/conv.cc
+++ b/tensorflow/lite/micro/kernels/conv.cc
@@ -13,12 +13,13 @@
limitations under the License.
==============================================================================*/
-#include "tensorflow/lite/kernels/internal/reference/conv.h"
+#include "tensorflow/lite/micro/kernels/conv.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/conv.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
@@ -28,294 +29,74 @@
namespace tflite {
namespace {
-constexpr int kInputTensor = 0;
-constexpr int kFilterTensor = 1;
-constexpr int kBiasTensor = 2;
-constexpr int kOutputTensor = 0;
-
-// Conv is quantized along dimension 0:
-// https://www.tensorflow.org/lite/performance/quantization_spec
-constexpr int kConvQuantizedDimension = 0;
-
-// This file has 2 implementation of Conv.
-
-struct OpData {
- TfLitePaddingValues padding;
-
- // Cached tensor zero point values for quantized operations.
- int32_t input_zero_point;
- int32_t filter_zero_point;
- int32_t output_zero_point;
-
- // The scaling factor from input to output (aka the 'real multiplier') can
- // be represented as a fixed point multiplier plus a left shift.
- int32_t output_multiplier;
- int output_shift;
-
- // Per channel output multiplier and shift.
- int32_t* per_channel_output_multiplier;
- int32_t* per_channel_output_shift;
-
- // The range of the fused activation layer. For example for kNone and
- // uint8_t these would be 0 and 255.
- int32_t output_activation_min;
- int32_t output_activation_max;
-};
-
-inline PaddingType RuntimePaddingType(TfLitePadding padding) {
- switch (padding) {
- case TfLitePadding::kTfLitePaddingSame:
- return PaddingType::kSame;
- case TfLitePadding::kTfLitePaddingValid:
- return PaddingType::kValid;
- case TfLitePadding::kTfLitePaddingUnknown:
- default:
- return PaddingType::kNone;
- }
-}
-
-TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
- const TfLiteConvParams* params, int width,
- int height, int filter_width, int filter_height,
- int out_width, int out_height,
- const TfLiteType data_type, OpData* data) {
- bool has_bias = node->inputs->size == 3;
- // Check number of inputs/outputs
- TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
-
- // Matching GetWindowedOutputSize in TensorFlow.
- auto padding = params->padding;
- data->padding = ComputePaddingHeightWidth(
- params->stride_height, params->stride_width,
- params->dilation_height_factor, params->dilation_width_factor, height,
- width, filter_height, filter_width, padding, &out_height, &out_width);
-
- // Note that quantized inference requires that all tensors have their
- // parameters set. This is usually done during quantized training.
- if (data_type != kTfLiteFloat32) {
- const TfLiteTensor* input = GetInput(context, node, kInputTensor);
- TF_LITE_ENSURE(context, input != nullptr);
- const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
- TF_LITE_ENSURE(context, filter != nullptr);
- const TfLiteTensor* bias =
- GetOptionalInputTensor(context, node, kBiasTensor);
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- TF_LITE_ENSURE(context, output != nullptr);
- int output_channels = filter->dims->data[kConvQuantizedDimension];
-
- TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams(
- context, input, filter, bias, output, params->activation,
- &data->output_multiplier, &data->output_shift,
- &data->output_activation_min, &data->output_activation_max,
- data->per_channel_output_multiplier,
- reinterpret_cast<int*>(data->per_channel_output_shift),
- output_channels));
- }
- return kTfLiteOk;
-}
-
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
- return context->AllocatePersistentBuffer(context, sizeof(OpData));
-}
-
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- TFLITE_DCHECK(node->user_data != nullptr);
- TFLITE_DCHECK(node->builtin_data != nullptr);
-
- OpData* data = static_cast<OpData*>(node->user_data);
- const auto params = static_cast<const TfLiteConvParams*>(node->builtin_data);
-
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- TF_LITE_ENSURE(context, output != nullptr);
- const TfLiteTensor* input = GetInput(context, node, kInputTensor);
- TF_LITE_ENSURE(context, input != nullptr);
- const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
- TF_LITE_ENSURE(context, filter != nullptr);
-
- int input_width = input->dims->data[2];
- int input_height = input->dims->data[1];
- int filter_width = filter->dims->data[2];
- int filter_height = filter->dims->data[1];
- int output_width = output->dims->data[2];
- int output_height = output->dims->data[1];
-
- // Dynamically allocate per-channel quantization parameters.
- const int num_channels = filter->dims->data[kConvQuantizedDimension];
- data->per_channel_output_multiplier =
- static_cast<int32_t*>(context->AllocatePersistentBuffer(
- context, num_channels * sizeof(int32_t)));
- data->per_channel_output_shift =
- static_cast<int32_t*>(context->AllocatePersistentBuffer(
- context, num_channels * sizeof(int32_t)));
-
- // All per-channel quantized tensors need valid zero point and scale arrays.
- if (input->type == kTfLiteInt8) {
- TF_LITE_ENSURE_EQ(context, filter->quantization.type,
- kTfLiteAffineQuantization);
-
- const auto* affine_quantization =
- static_cast<TfLiteAffineQuantization*>(filter->quantization.params);
- TF_LITE_ENSURE(context, affine_quantization);
- TF_LITE_ENSURE(context, affine_quantization->scale);
- TF_LITE_ENSURE(context, affine_quantization->zero_point);
-
- TF_LITE_ENSURE(context,
- affine_quantization->scale->size == 1 ||
- affine_quantization->scale->size ==
- filter->dims->data[kConvQuantizedDimension]);
- TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
- affine_quantization->zero_point->size);
- }
-
- TF_LITE_ENSURE_STATUS(CalculateOpData(
- context, node, params, input_width, input_height, filter_width,
- filter_height, output_width, output_height, input->type, data));
-
- data->input_zero_point = input->params.zero_point;
- data->filter_zero_point = filter->params.zero_point;
- data->output_zero_point = output->params.zero_point;
-
- return kTfLiteOk;
-} // namespace conv
-
-void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
- TfLiteConvParams* params, const OpData& data,
- const TfLiteEvalTensor* input,
- const TfLiteEvalTensor* filter, const TfLiteEvalTensor* bias,
- TfLiteEvalTensor* im2col, TfLiteEvalTensor* hwcn_weights,
- TfLiteEvalTensor* output) {
- const int32_t input_offset = -data.input_zero_point;
- const int32_t filter_offset = -data.filter_zero_point;
- const int32_t output_offset = data.output_zero_point;
-
- // TODO(b/154032858): Investigate removing extra copies.
- ConvParams op_params;
- op_params.padding_type = RuntimePaddingType(params->padding);
- op_params.padding_values.width = data.padding.width;
- op_params.padding_values.height = data.padding.height;
- op_params.stride_width = params->stride_width;
- op_params.stride_height = params->stride_height;
- op_params.dilation_width_factor = params->dilation_width_factor;
- op_params.dilation_height_factor = params->dilation_height_factor;
- op_params.input_offset = input_offset;
- op_params.weights_offset = filter_offset;
- op_params.output_offset = output_offset;
- op_params.output_multiplier = data.output_multiplier;
- op_params.output_shift = -data.output_shift;
- op_params.quantized_activation_min = data.output_activation_min;
- op_params.quantized_activation_max = data.output_activation_max;
- reference_ops::Conv(op_params, tflite::micro::GetTensorShape(input),
- tflite::micro::GetTensorData<uint8_t>(input),
- tflite::micro::GetTensorShape(filter),
- tflite::micro::GetTensorData<uint8_t>(filter),
- tflite::micro::GetTensorShape(bias),
- tflite::micro::GetTensorData<int32_t>(bias),
- tflite::micro::GetTensorShape(output),
- tflite::micro::GetTensorData<uint8_t>(output),
- tflite::micro::GetTensorShape(im2col),
- tflite::micro::GetTensorData<uint8_t>(im2col), nullptr);
-}
-
-void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
- TfLiteConvParams* params, const OpData& data,
- const TfLiteEvalTensor* input,
- const TfLiteEvalTensor* filter,
- const TfLiteEvalTensor* bias,
- TfLiteEvalTensor* output,
- TfLiteEvalTensor* im2col) {
- // TODO(b/154032858): Investigate removing extra copies.
- ConvParams op_params;
- op_params.input_offset = -data.input_zero_point;
- op_params.output_offset = data.output_zero_point;
- op_params.stride_height = params->stride_height;
- op_params.stride_width = params->stride_width;
- op_params.dilation_height_factor = params->dilation_height_factor;
- op_params.dilation_width_factor = params->dilation_width_factor;
- op_params.padding_values.height = data.padding.height;
- op_params.padding_values.width = data.padding.width;
- op_params.quantized_activation_min = data.output_activation_min;
- op_params.quantized_activation_max = data.output_activation_max;
-
- reference_integer_ops::ConvPerChannel(
- op_params, data.per_channel_output_multiplier,
- data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
- tflite::micro::GetTensorData<int8_t>(input),
- tflite::micro::GetTensorShape(filter),
- tflite::micro::GetTensorData<int8_t>(filter),
- tflite::micro::GetTensorShape(bias),
- tflite::micro::GetTensorData<int32_t>(bias),
- tflite::micro::GetTensorShape(output),
- tflite::micro::GetTensorData<int8_t>(output));
-}
-
-void EvalFloat(TfLiteContext* context, TfLiteNode* node,
- TfLiteConvParams* params, const OpData& data,
- const TfLiteEvalTensor* input, const TfLiteEvalTensor* filter,
- const TfLiteEvalTensor* bias, TfLiteEvalTensor* im2col,
- TfLiteEvalTensor* hwcn_weights, TfLiteEvalTensor* output) {
- float output_activation_min, output_activation_max;
- CalculateActivationRange(params->activation, &output_activation_min,
- &output_activation_max);
- // TODO(b/154032858): Investigate removing extra copies.
- ConvParams op_params;
- op_params.padding_type = RuntimePaddingType(params->padding);
- op_params.padding_values.width = data.padding.width;
- op_params.padding_values.height = data.padding.height;
- op_params.stride_width = params->stride_width;
- op_params.stride_height = params->stride_height;
- op_params.dilation_width_factor = params->dilation_width_factor;
- op_params.dilation_height_factor = params->dilation_height_factor;
- op_params.float_activation_min = output_activation_min;
- op_params.float_activation_max = output_activation_max;
-
- reference_ops::Conv(op_params, tflite::micro::GetTensorShape(input),
- tflite::micro::GetTensorData<float>(input),
- tflite::micro::GetTensorShape(filter),
- tflite::micro::GetTensorData<float>(filter),
- tflite::micro::GetTensorShape(bias),
- tflite::micro::GetTensorData<float>(bias),
- tflite::micro::GetTensorShape(output),
- tflite::micro::GetTensorData<float>(output),
- tflite::micro::GetTensorShape(im2col),
- tflite::micro::GetTensorData<float>(im2col));
+ return context->AllocatePersistentBuffer(context, sizeof(OpDataConv));
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
-
const TfLiteEvalTensor* input =
- tflite::micro::GetEvalInput(context, node, kInputTensor);
+ tflite::micro::GetEvalInput(context, node, kConvInputTensor);
const TfLiteEvalTensor* filter =
- tflite::micro::GetEvalInput(context, node, kFilterTensor);
+ tflite::micro::GetEvalInput(context, node, kConvWeightsTensor);
const TfLiteEvalTensor* bias =
(NumInputs(node) == 3)
- ? tflite::micro::GetEvalInput(context, node, kBiasTensor)
+ ? tflite::micro::GetEvalInput(context, node, kConvBiasTensor)
: nullptr;
TfLiteEvalTensor* output =
- tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+ tflite::micro::GetEvalOutput(context, node, kConvOutputTensor);
+ TFLITE_DCHECK(node->builtin_data != nullptr);
+ const auto& params =
+ *(reinterpret_cast<TfLiteConvParams*>(node->builtin_data));
TFLITE_DCHECK(node->user_data != nullptr);
- const OpData& data = *(static_cast<const OpData*>(node->user_data));
+ const auto& data = *(static_cast<const OpDataConv*>(node->user_data));
TF_LITE_ENSURE_EQ(context, input->type, output->type);
TF_LITE_ENSURE_MSG(context, input->type == filter->type,
"Hybrid models are not supported on TFLite Micro.");
switch (input->type) { // Already know in/out types are same.
- case kTfLiteFloat32:
- EvalFloat(context, node, params, data, input, filter, bias, nullptr,
- nullptr, output);
+ case kTfLiteFloat32: {
+ tflite::reference_ops::Conv(
+ ConvParamsFloat(params, data), tflite::micro::GetTensorShape(input),
+ tflite::micro::GetTensorData<float>(input),
+ tflite::micro::GetTensorShape(filter),
+ tflite::micro::GetTensorData<float>(filter),
+ tflite::micro::GetTensorShape(bias),
+ tflite::micro::GetTensorData<float>(bias),
+ tflite::micro::GetTensorShape(output),
+ tflite::micro::GetTensorData<float>(output),
+ tflite::micro::GetTensorShape(nullptr), nullptr);
break;
- case kTfLiteInt8:
- EvalQuantizedPerChannel(context, node, params, data, input, filter, bias,
- output, nullptr);
+ }
+ case kTfLiteInt8: {
+ reference_integer_ops::ConvPerChannel(
+ ConvParamsQuantized(params, data), data.per_channel_output_multiplier,
+ data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
+ tflite::micro::GetTensorData<int8_t>(input),
+ tflite::micro::GetTensorShape(filter),
+ tflite::micro::GetTensorData<int8_t>(filter),
+ tflite::micro::GetTensorShape(bias),
+ tflite::micro::GetTensorData<int32_t>(bias),
+ tflite::micro::GetTensorShape(output),
+ tflite::micro::GetTensorData<int8_t>(output));
break;
- case kTfLiteUInt8:
- EvalQuantized(context, node, params, data, input, filter, bias, nullptr,
- nullptr, output);
+ }
+ case kTfLiteUInt8: {
+ reference_ops::Conv(ConvParamsQuantized(params, data),
+ tflite::micro::GetTensorShape(input),
+ tflite::micro::GetTensorData<uint8_t>(input),
+ tflite::micro::GetTensorShape(filter),
+ tflite::micro::GetTensorData<uint8_t>(filter),
+ tflite::micro::GetTensorShape(bias),
+ tflite::micro::GetTensorData<int32_t>(bias),
+ tflite::micro::GetTensorShape(output),
+ tflite::micro::GetTensorData<uint8_t>(output),
+ tflite::micro::GetTensorShape(nullptr), nullptr,
+ nullptr);
break;
+ }
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
@@ -329,7 +110,7 @@
TfLiteRegistration Register_CONV_2D() {
return {/*init=*/Init,
/*free=*/nullptr,
- /*prepare=*/Prepare,
+ /*prepare=*/ConvPrepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
diff --git a/tensorflow/lite/micro/kernels/conv.h b/tensorflow/lite/micro/kernels/conv.h
new file mode 100644
index 0000000..46bc731
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/conv.h
@@ -0,0 +1,77 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_MICRO_KERNELS_CONV_H_
+#define TENSORFLOW_LITE_MICRO_KERNELS_CONV_H_
+
+#include <cstdint>
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+
+namespace tflite {
+
+struct OpDataConv {
+ TfLitePaddingValues padding;
+
+ // Cached tensor zero point values for quantized operations.
+ int32_t input_zero_point;
+ int32_t filter_zero_point;
+ int32_t output_zero_point;
+
+ // The scaling factor from input to output (aka the 'real multiplier') can
+ // be represented as a fixed point multiplier plus a left shift.
+ int32_t output_multiplier;
+ int output_shift;
+
+ // Per channel output multiplier and shift.
+ int32_t* per_channel_output_multiplier;
+ int32_t* per_channel_output_shift;
+
+ // The range of the fused activation layer. For example for kNone and
+ // uint8_t these would be 0 and 255.
+ int32_t output_activation_min;
+ int32_t output_activation_max;
+};
+
+extern const int kConvInputTensor;
+extern const int kConvWeightsTensor;
+extern const int kConvBiasTensor;
+extern const int kConvOutputTensor;
+extern const int kConvQuantizedDimension;
+
+// Returns a ConvParams struct with all the parameters needed for a
+// float computation.
+ConvParams ConvParamsFloat(const TfLiteConvParams& params,
+ const OpDataConv& data);
+
+// Returns a ConvParams struct with all the parameters needed for a
+// quantized computation.
+ConvParams ConvParamsQuantized(const TfLiteConvParams& params,
+ const OpDataConv& data);
+
+TfLiteStatus CalculateOpDataConv(TfLiteContext* context, TfLiteNode* node,
+ const TfLiteConvParams& params, int width,
+ int height, int filter_width,
+ int filter_height, int out_width,
+ int out_height, const TfLiteType data_type,
+ OpDataConv* data);
+
+TfLiteStatus ConvPrepare(TfLiteContext* context, TfLiteNode* node);
+
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_MICRO_KERNELS_CONV_H_
diff --git a/tensorflow/lite/micro/kernels/conv_common.cc b/tensorflow/lite/micro/kernels/conv_common.cc
new file mode 100644
index 0000000..a4a36ae
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/conv_common.cc
@@ -0,0 +1,182 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/conv.h"
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/padding.h"
+#include "tensorflow/lite/micro/kernels/conv.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+
+namespace tflite {
+
+const int kConvInputTensor = 0;
+const int kConvWeightsTensor = 1;
+const int kConvBiasTensor = 2;
+const int kConvOutputTensor = 0;
+
+// Conv is quantized along dimension 0:
+// https://www.tensorflow.org/lite/performance/quantization_spec
+const int kConvQuantizedDimension = 0;
+
+// Returns a ConvParams struct with all the parameters needed for a
+// float computation.
+ConvParams ConvParamsFloat(const TfLiteConvParams& params,
+ const OpDataConv& data) {
+ ConvParams op_params;
+ CalculateActivationRange(params.activation, &op_params.float_activation_min,
+ &op_params.float_activation_max);
+ op_params.padding_type = tflite::micro::RuntimePaddingType(params.padding);
+ op_params.padding_values.width = data.padding.width;
+ op_params.padding_values.height = data.padding.height;
+ op_params.stride_width = params.stride_width;
+ op_params.stride_height = params.stride_height;
+ op_params.dilation_width_factor = params.dilation_width_factor;
+ op_params.dilation_height_factor = params.dilation_height_factor;
+ return op_params;
+}
+
+// Returns a ConvParams struct with all the parameters needed for a
+// quantized computation.
+ConvParams ConvParamsQuantized(const TfLiteConvParams& params,
+ const OpDataConv& data) {
+ ConvParams op_params;
+ op_params.input_offset = -data.input_zero_point;
+ op_params.weights_offset = -data.filter_zero_point;
+ op_params.output_offset = data.output_zero_point;
+ op_params.output_multiplier = data.output_multiplier;
+ op_params.output_shift = -data.output_shift;
+ op_params.padding_type = tflite::micro::RuntimePaddingType(params.padding);
+ op_params.padding_values.height = data.padding.height;
+ op_params.padding_values.width = data.padding.width;
+ op_params.stride_height = params.stride_height;
+ op_params.stride_width = params.stride_width;
+ op_params.dilation_height_factor = params.dilation_height_factor;
+ op_params.dilation_width_factor = params.dilation_width_factor;
+ op_params.quantized_activation_min = data.output_activation_min;
+ op_params.quantized_activation_max = data.output_activation_max;
+ return op_params;
+}
+
+TfLiteStatus CalculateOpDataConv(TfLiteContext* context, TfLiteNode* node,
+ const TfLiteConvParams& params, int width,
+ int height, int filter_width,
+ int filter_height, int out_width,
+ int out_height, const TfLiteType data_type,
+ OpDataConv* data) {
+ bool has_bias = node->inputs->size == 3;
+ // Check number of inputs/outputs
+ TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+
+ // Matching GetWindowedOutputSize in TensorFlow.
+ auto padding = params.padding;
+ data->padding = ComputePaddingHeightWidth(
+ params.stride_height, params.stride_width, params.dilation_height_factor,
+ params.dilation_width_factor, height, width, filter_height, filter_width,
+ padding, &out_height, &out_width);
+
+ const TfLiteTensor* input = GetInput(context, node, kConvInputTensor);
+ TF_LITE_ENSURE(context, input != nullptr);
+ const TfLiteTensor* filter = GetInput(context, node, kConvWeightsTensor);
+ TF_LITE_ENSURE(context, filter != nullptr);
+ const TfLiteTensor* bias =
+ GetOptionalInputTensor(context, node, kConvBiasTensor);
+ TfLiteTensor* output = GetOutput(context, node, kConvOutputTensor);
+ TF_LITE_ENSURE(context, output != nullptr);
+
+ // Note that quantized inference requires that all tensors have their
+ // parameters set. This is usually done during quantized training.
+ if (data_type != kTfLiteFloat32) {
+ int output_channels = filter->dims->data[kConvQuantizedDimension];
+
+ TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams(
+ context, input, filter, bias, output, params.activation,
+ &data->output_multiplier, &data->output_shift,
+ &data->output_activation_min, &data->output_activation_max,
+ data->per_channel_output_multiplier,
+ reinterpret_cast<int*>(data->per_channel_output_shift),
+ output_channels));
+ }
+
+ data->input_zero_point = input->params.zero_point;
+ data->filter_zero_point = filter->params.zero_point;
+ data->output_zero_point = output->params.zero_point;
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus ConvPrepare(TfLiteContext* context, TfLiteNode* node) {
+ TFLITE_DCHECK(node->user_data != nullptr);
+ TFLITE_DCHECK(node->builtin_data != nullptr);
+
+ OpDataConv* data = static_cast<OpDataConv*>(node->user_data);
+ const auto& params =
+ *(static_cast<const TfLiteConvParams*>(node->builtin_data));
+
+ TfLiteTensor* output = GetOutput(context, node, kConvOutputTensor);
+ TF_LITE_ENSURE(context, output != nullptr);
+ const TfLiteTensor* input = GetInput(context, node, kConvInputTensor);
+ TF_LITE_ENSURE(context, input != nullptr);
+ const TfLiteTensor* filter = GetInput(context, node, kConvWeightsTensor);
+ TF_LITE_ENSURE(context, filter != nullptr);
+
+ const int input_width = input->dims->data[2];
+ const int input_height = input->dims->data[1];
+ const int filter_width = filter->dims->data[2];
+ const int filter_height = filter->dims->data[1];
+ const int output_width = output->dims->data[2];
+ const int output_height = output->dims->data[1];
+
+ // Dynamically allocate per-channel quantization parameters.
+ const int num_channels = filter->dims->data[kConvQuantizedDimension];
+ data->per_channel_output_multiplier =
+ static_cast<int32_t*>(context->AllocatePersistentBuffer(
+ context, num_channels * sizeof(int32_t)));
+ data->per_channel_output_shift =
+ static_cast<int32_t*>(context->AllocatePersistentBuffer(
+ context, num_channels * sizeof(int32_t)));
+
+ // All per-channel quantized tensors need valid zero point and scale arrays.
+ if (input->type == kTfLiteInt8) {
+ TF_LITE_ENSURE_EQ(context, filter->quantization.type,
+ kTfLiteAffineQuantization);
+
+ const auto* affine_quantization =
+ static_cast<TfLiteAffineQuantization*>(filter->quantization.params);
+ TFLITE_DCHECK(affine_quantization != nullptr);
+ TFLITE_DCHECK(affine_quantization->scale != nullptr);
+ TFLITE_DCHECK(affine_quantization->zero_point != nullptr);
+
+ TF_LITE_ENSURE(context,
+ affine_quantization->scale->size == 1 ||
+ affine_quantization->scale->size ==
+ filter->dims->data[kConvQuantizedDimension]);
+ TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
+ affine_quantization->zero_point->size);
+ }
+
+ TF_LITE_ENSURE_STATUS(CalculateOpDataConv(
+ context, node, params, input_width, input_height, filter_width,
+ filter_height, output_width, output_height, input->type, data));
+
+ return kTfLiteOk;
+}
+} // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/elu.cc b/tensorflow/lite/micro/kernels/elu.cc
index ec8cc36..12d287d 100644
--- a/tensorflow/lite/micro/kernels/elu.cc
+++ b/tensorflow/lite/micro/kernels/elu.cc
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -12,59 +12,31 @@
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <stddef.h>
+
+#include "tensorflow/lite/kernels/internal/reference/elu.h"
#include <algorithm>
#include <cmath>
-#include <cstdint>
#include <functional>
#include <limits>
-#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
-#include "tensorflow/lite/kernels/cpu_backend_context.h"
-#include "tensorflow/lite/kernels/internal/common.h"
-#include "tensorflow/lite/kernels/internal/compatibility.h"
-#include "tensorflow/lite/kernels/internal/cppmath.h"
-#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
-#include "tensorflow/lite/kernels/internal/reference/binary_function.h"
-#include "tensorflow/lite/kernels/internal/reference/integer_ops/log_softmax.h"
-#include "tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h"
-#include "tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h"
-#include "tensorflow/lite/kernels/internal/reference/logistic.h"
-#include "tensorflow/lite/kernels/internal/reference/prelu.h"
-#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
-#include "tensorflow/lite/kernels/internal/reference/softmax.h"
-#include "tensorflow/lite/kernels/internal/reference/tanh.h"
-#include "tensorflow/lite/kernels/internal/tensor.h"
-#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
-
-#if __aarch64__ && __clang__
-#include <arm_neon.h>
-#endif
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
namespace tflite {
namespace ops {
-namespace builtin {
+namespace micro {
namespace activations {
+namespace {
// OLD-TODO(b/142762739): We should figure out a multi-threading plan for most
// of the activation ops below.
-enum KernelType {
- kReference,
- kGenericOptimized,
- kFixedPointOptimized,
-};
-
struct OpData {
- int32_t input_multiplier = 0;
- int input_left_shift = 0;
- int32_t input_range_radius = 0;
- int diff_min = 0;
uint8_t table[256] = {0};
};
@@ -97,42 +69,19 @@
uint8_t* output_data = GetTensorData<uint8_t>(output);
const uint8_t* input_data = GetTensorData<uint8_t>(input);
int i = 0;
-#if __aarch64__ && __clang__
- // This code uses ARM64-only instructions.
- // OLD-TODO(b/143709993): Port to ARMv7
- // Load the tables into registers. (4*4 128-bit registers)
- uint8x16x4_t table[4];
- table[0] = vld1q_u8_x4(data->table + 16 * 4 * 0);
- table[1] = vld1q_u8_x4(data->table + 16 * 4 * 1);
- table[2] = vld1q_u8_x4(data->table + 16 * 4 * 2);
- table[3] = vld1q_u8_x4(data->table + 16 * 4 * 3);
-
- // Vectorized loop; process uint8x16_t (16 elements) at a time.
- constexpr int vectorized_16_loop_step = 16;
- const int vectorized_16_loop_end =
- size / vectorized_16_loop_step * vectorized_16_loop_step;
- for (; i < vectorized_16_loop_end; i += vectorized_16_loop_step) {
- uint8x16_t input = vld1q_u8(input_data + i);
- uint8x16_t output = optimized_ops::aarch64_lookup_vector(table, input);
- vst1q_u8(output_data + i, output);
- }
- // Postamble and non-ARM64 code: simple for loop.
-#endif
for (; i < size; ++i) {
output_data[i] = data->table[input_data[i]];
}
}
+} // namespace
+
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
// This is a builtin op, so we don't use the contents in 'buffer', if any.
// Instead, we allocate a new object to carry information from Prepare() to
// Eval().
- return new OpData;
-}
-
-void Free(TfLiteContext* context, void* buffer) {
- delete reinterpret_cast<OpData*>(buffer);
+ return nullptr;
}
TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
@@ -144,8 +93,7 @@
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
- return context->ResizeTensor(context, output,
- TfLiteIntArrayCopy(input->dims));
+ return kTfLiteError;
}
TfLiteStatus EluPrepare(TfLiteContext* context, TfLiteNode* node) {
@@ -174,12 +122,12 @@
optimized_ops::Elu(GetTensorShape(input), GetTensorData<float>(input),
GetTensorShape(output), GetTensorData<float>(output));
return kTfLiteOk;
- } break;
+ }
case kTfLiteInt8: {
OpData* data = reinterpret_cast<OpData*>(node->user_data);
EvalUsingLookupTable(data, input, output);
return kTfLiteOk;
- } break;
+ }
default:
TF_LITE_KERNEL_LOG(
context, "Only float32 and int8 is supported currently, got %s.",
@@ -190,12 +138,8 @@
} // namespace activations
-TfLiteRegistration* Register_ELU() {
- static TfLiteRegistration r = {activations::Init, activations::Free,
- activations::EluPrepare, activations::EluEval};
- return &r;
-}
+TfLiteRegistration* Register_ELU() { return nullptr; }
-} // namespace builtin
+} // namespace micro
} // namespace ops
} // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/elu_test.cc b/tensorflow/lite/micro/kernels/elu_test.cc
index 5dedc7a..5eb893b 100644
--- a/tensorflow/lite/micro/kernels/elu_test.cc
+++ b/tensorflow/lite/micro/kernels/elu_test.cc
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -12,150 +12,33 @@
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <math.h>
-#include <stdint.h>
-#include <stdlib.h>
-
-#include <algorithm>
-#include <initializer_list>
#include <limits>
-#include <map>
-#include <memory>
-#include <random>
-#include <string>
-#include <utility>
-#include <vector>
+#include <type_traits>
-#include "absl/memory/memory.h"
-#include "flatbuffers/flatbuffers.h" // from @flatbuffers
-#include "tensorflow/lite/core/api/op_resolver.h"
-#include "tensorflow/lite/interpreter.h"
-#include "tensorflow/lite/kernels/test_util.h"
-#include "tensorflow/lite/schema/schema_generated.h"
-#include "tensorflow/lite/string_type.h"
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/micro/kernels/kernel_runner.h"
+#include "tensorflow/lite/micro/test_helpers.h"
+#include "tensorflow/lite/micro/testing/micro_test.h"
namespace tflite {
-
+namespace testing {
namespace {
-using ::testing::ElementsAreArray;
-
-class BaseActivationsOpModel : public SingleOpModel {
- public:
- // Most activations don't take any options, so this constructor works for
- // them.
- BaseActivationsOpModel(BuiltinOperator type, TensorData input) {
- input_ = AddInput(input);
- if (input.type == TensorType_UINT8) {
- output_ = AddOutput({input.type, {}, 0, 0, 1. / 256});
- } else if (input.type == TensorType_INT8) {
- output_ = AddOutput({input.type, {}, 0, 0, 1. / 256, -128});
- } else {
- output_ = AddOutput({input.type, {}});
- }
- SetBuiltinOp(type, BuiltinOptions_NONE, 0);
- BuildInterpreter({GetShape(input_)});
+#ifdef notdef
+BaseActivationsOpModel(BuiltinOperator type, TensorData input) {
+ input_ = AddInput(input);
+ if (input.type == TensorType_UINT8) {
+ output_ = AddOutput({input.type, {}, 0, 0, 1. / 256});
+ } else if (input.type == TensorType_INT8) {
+ output_ = AddOutput({input.type, {}, 0, 0, 1. / 256, -128});
+ } else {
+ output_ = AddOutput({input.type, {}});
}
-
- BaseActivationsOpModel(TfLiteRegistration* registration, BuiltinOperator type,
- TensorData input) {
- input_ = AddInput(input);
- if (input.type == TensorType_UINT8) {
- output_ = AddOutput({input.type, {}, 0, 0, 1. / 256});
- } else if (input.type == TensorType_INT8) {
- output_ = AddOutput({input.type, {}, 0, 0, 1. / 256, -128});
- } else {
- output_ = AddOutput({input.type, {}});
- }
- SetBuiltinOp(type, BuiltinOptions_NONE, 0);
- resolver_ = absl::make_unique<SingleOpResolver>(type, registration);
- BuildInterpreter({GetShape(input_)});
- }
-
- // A dedicated constructor for SOFTMAX, which does some options.
- BaseActivationsOpModel(float softmax_beta, TensorData input,
- TensorType output_type) {
- input_ = AddInput(input);
- if (output_type == TensorType_UINT8) {
- output_ = AddOutput({TensorType_UINT8, {}, 0, 0, 1. / 256});
- } else if (output_type == TensorType_INT8) {
- output_ = AddOutput({TensorType_INT8, {}, 0, 0, 1. / 256, -128});
- } else if (input.type == TensorType_INT16 &&
- output_type == TensorType_INT16) {
- output_ = AddOutput({TensorType_INT16,
- {},
- 0,
- 0,
- 1.0f / (std::numeric_limits<int16_t>::max() + 1),
- 0});
- } else if (input.type != TensorType_INT16 &&
- output_type == TensorType_INT16) {
- output_ = AddOutput({TensorType_INT16, {}, 0, 0, 1. / 32768, -16384});
- } else {
- output_ = AddOutput({output_type, {}});
- }
- SetBuiltinOp(BuiltinOperator_SOFTMAX, BuiltinOptions_SoftmaxOptions,
- CreateSoftmaxOptions(builder_, softmax_beta).Union());
- BuildInterpreter({GetShape(input_)});
- }
-
- // A dedicated constructor for LeakyRelu, which does some options.
- BaseActivationsOpModel(TensorData input, float alpha) {
- input_ = AddInput(input);
- // The output scale and input scale might be different.
- if (input.type == TensorType_UINT8 || input.type == TensorType_INT8 ||
- input.type == TensorType_INT16) {
- auto output_min = (input.min >= 0) ? input.min : input.min * alpha;
- auto output_max = (input.max >= 0) ? input.max : input.max * alpha;
- if (input.type == TensorType_INT16) {
- output_ = AddOutput({TensorType_INT16,
- {},
- 0,
- 0,
- output_max / (std::numeric_limits<int16_t>::max()),
- 0});
- } else {
- output_ = AddOutput({input.type, {}, output_min, output_max});
- }
- } else {
- output_ = AddOutput({input.type, {}});
- }
- SetBuiltinOp(BuiltinOperator_LEAKY_RELU, BuiltinOptions_LeakyReluOptions,
- CreateLeakyReluOptions(builder_, alpha).Union());
- BuildInterpreter({GetShape(input_)});
- }
-
- BaseActivationsOpModel(BuiltinOperator type, const TensorData& input,
- const TensorData& output) {
- input_ = AddInput(input);
- output_ = AddOutput(output);
- SetBuiltinOp(type, BuiltinOptions_NONE, 0);
- BuildInterpreter({GetShape(input_)});
- }
-
- BaseActivationsOpModel(TfLiteRegistration* registration, BuiltinOperator type,
- const TensorData& input, const TensorData& output) {
- input_ = AddInput(input);
- output_ = AddOutput(output);
- SetBuiltinOp(type, BuiltinOptions_NONE, 0);
- resolver_ = absl::make_unique<SingleOpResolver>(type, registration);
- BuildInterpreter({GetShape(input_)});
- }
-
- protected:
- int input_;
- int output_;
-};
-
-class FloatActivationsOpModel : public BaseActivationsOpModel {
- public:
- using BaseActivationsOpModel::BaseActivationsOpModel;
-
- void SetInput(const std::vector<float>& data) {
- PopulateTensor(input_, data);
- }
- std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
-};
+ SetBuiltinOp(type, BuiltinOptions_NONE, 0);
+ BuildInterpreter({GetShape(input_)});
+}
+#endif // notdef
// Our fixed-point math function implementations have roughly 12 bits of
// accuracy, when specialized to 16-bit fixed-point arithmetic.
@@ -176,41 +59,25 @@
const float kQuantizedTolerance = 2 * (1. / 256);
const float kQuantizedToleranceInt16 = 2 * (1. / 4096);
-class QuantizedActivationsOpModel : public BaseActivationsOpModel {
- public:
- using BaseActivationsOpModel::BaseActivationsOpModel;
+TF_LITE_MICRO_TESTS_BEGIN
- template <typename T>
- void SetInput(const std::vector<float>& data) {
- QuantizeAndPopulate<T>(input_, data);
- }
- template <typename T>
- std::vector<T> GetOutput() {
- return ExtractVector<T>(output_);
- }
-
- template <typename T>
- std::vector<float> GetDequantizedOutput() {
- return Dequantize<T>(ExtractVector<T>(output_), GetScale(output_),
- GetZeroPoint(output_));
- }
-};
-
-TEST(FloatActivationsOpTest, Elu) {
+TF_LITE_MICRO_TEST(FloatActivationsOpTestElu) {
+#ifdef notdef
FloatActivationsOpModel m(BuiltinOperator_ELU,
/*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
m.SetInput({
0, -6, 2, -4, //
3, -2, 10, -0.1, //
});
- m.Invoke();
EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
0.0, -0.997521, 2.0, -0.981684, //
3.0, -0.864665, 10.0, -0.0951626, //
})));
+#endif // notdef
}
-TEST(QuantizedActivationsOpTest, EluInt8) {
+TF_LITE_MICRO_TEST(QuantizedActivationsOpTestEluInt8) {
+#ifdef notdef
const float kMin = -1;
const float kMax = 127.f / 128.f;
QuantizedActivationsOpModel model(
@@ -231,7 +98,11 @@
3.0, -0.875, 6.0, -0.125, //
},
kQuantizedTolerance)));
+#endif // notdef
}
+TF_LITE_MICRO_TESTS_END
+
} // namespace
+} // namespace testing
} // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/kernel_util.cc b/tensorflow/lite/micro/kernels/kernel_util.cc
index deca92b..d769f9e 100644
--- a/tensorflow/lite/micro/kernels/kernel_util.cc
+++ b/tensorflow/lite/micro/kernels/kernel_util.cc
@@ -37,5 +37,17 @@
return RuntimeShape(dims_size, dims_data);
}
+PaddingType RuntimePaddingType(TfLitePadding padding) {
+ switch (padding) {
+ case TfLitePadding::kTfLitePaddingSame:
+ return PaddingType::kSame;
+ case TfLitePadding::kTfLitePaddingValid:
+ return PaddingType::kValid;
+ case TfLitePadding::kTfLitePaddingUnknown:
+ default:
+ return PaddingType::kNone;
+ }
+}
+
} // namespace micro
} // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/kernel_util.h b/tensorflow/lite/micro/kernels/kernel_util.h
index 79cd58e..043fb02 100644
--- a/tensorflow/lite/micro/kernels/kernel_util.h
+++ b/tensorflow/lite/micro/kernels/kernel_util.h
@@ -18,6 +18,7 @@
#include <cstdint>
+#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/types.h"
@@ -69,6 +70,8 @@
bool HaveSameShapes(const TfLiteEvalTensor* input1,
const TfLiteEvalTensor* input2);
+PaddingType RuntimePaddingType(TfLitePadding padding);
+
} // namespace micro
} // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/quantize_test.cc b/tensorflow/lite/micro/kernels/quantize_test.cc
index ad302f0..b5da979 100644
--- a/tensorflow/lite/micro/kernels/quantize_test.cc
+++ b/tensorflow/lite/micro/kernels/quantize_test.cc
@@ -49,7 +49,7 @@
}
}
-#if !defined(HIFIMINI)
+#if !defined(XTENSA)
template <typename T>
void TestQuantizeFloat(const int* input_dims_data, const float* input_data,
const int* output_dims_data, const float* golden,
@@ -79,7 +79,7 @@
ValidateQuantizeGoldens(tensors, tensors_size, golden, golden_quantized,
scale, zero_point, output_dims_count, output_data);
}
-#endif // defined(HIFIMINI)
+#endif // defined(XTENSA)
template <typename InputType, typename OutputType>
void TestRequantize(const int* input_dims_data, const float* input_data,
@@ -121,7 +121,7 @@
TF_LITE_MICRO_TESTS_BEGIN
-#if !defined(HIFIMINI)
+#if !defined(XTENSA)
TF_LITE_MICRO_TEST(QuantizeOpTestUint8) {
const int length = 10;
const int dims[] = {2, 2, 5};
@@ -267,9 +267,9 @@
values_quantized, output_scale,
output_zero_point, output_quantized);
}
-#endif // defined(HIFIMINI)
+#endif // defined(XTENSA)
-#if !defined(HIFIMINI)
+#if !defined(XTENSA)
// TODO(b/155682734): Hifimini optimized quantize requires input scale to be
// smaller then output scale.
TF_LITE_MICRO_TEST(QuantizeOpTestInt16toInt8) {
@@ -288,7 +288,7 @@
values_quantized, output_scale,
output_zero_point, output_quantized);
}
-#endif // defined(HIFIMINI)
+#endif // defined(XTENSA)
TF_LITE_MICRO_TEST(QuantizeOpTestInt16toInt32) {
const int length = 10;
diff --git a/tensorflow/lite/micro/kernels/xtensa/conv.cc b/tensorflow/lite/micro/kernels/xtensa/conv.cc
index 41a11a8..04c2c15 100644
--- a/tensorflow/lite/micro/kernels/xtensa/conv.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/conv.cc
@@ -13,12 +13,13 @@
limitations under the License.
==============================================================================*/
-#include "tensorflow/lite/kernels/internal/reference/conv.h"
+#include "tensorflow/lite/micro/kernels/conv.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/conv.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
@@ -30,36 +31,6 @@
namespace tflite {
namespace {
-constexpr int kInputTensor = 0;
-constexpr int kFilterTensor = 1;
-constexpr int kBiasTensor = 2;
-constexpr int kOutputTensor = 0;
-
-// Conv is quantized along dimension 0:
-// https://www.tensorflow.org/lite/performance/quantization_spec
-constexpr int kConvQuantizedDimension = 0;
-
-struct OpData {
- TfLitePaddingValues padding;
- // The scaling factor from input to output (aka the 'real multiplier') can
- // be represented as a fixed point multiplier plus a left shift.
- int32_t output_multiplier;
- int output_shift;
-
- // Cached tensor zero point values for quantized operations.
- int32_t input_zero_point;
- int32_t output_zero_point;
-
- // Per channel output multiplier and shift.
- int32_t* per_channel_output_multiplier;
- int32_t* per_channel_output_shift;
-
- // The range of the fused activation layer. For example for kNone and
- // uint8_t these would be 0 and 255.
- int32_t output_activation_min;
- int32_t output_activation_max;
-};
-
#if defined(HIFIMINI)
void ConvPerChannel(const ConvParams& params, const int32_t* output_multiplier,
const int32_t* output_shift,
@@ -263,164 +234,27 @@
}
#endif
-TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
- TfLiteConvParams* params, int width, int height,
- int filter_width, int filter_height, int out_width,
- int out_height, const TfLiteType data_type,
- OpData* data) {
- bool has_bias = node->inputs->size == 3;
- // Check number of inputs/outputs
- TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
-
- // Matching GetWindowedOutputSize in TensorFlow.
- auto padding = params->padding;
- data->padding = ComputePaddingHeightWidth(
- params->stride_height, params->stride_width,
- params->dilation_height_factor, params->dilation_width_factor, height,
- width, filter_height, filter_width, padding, &out_height, &out_width);
-
- // Note that quantized inference requires that all tensors have their
- // parameters set. This is usually done during quantized training.
- if (data_type != kTfLiteFloat32) {
- const TfLiteTensor* input = GetInput(context, node, kInputTensor);
- const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
- const TfLiteTensor* bias =
- GetOptionalInputTensor(context, node, kBiasTensor);
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- int output_channels = filter->dims->data[kConvQuantizedDimension];
-
- return tflite::PopulateConvolutionQuantizationParams(
- context, input, filter, bias, output, params->activation,
- &data->output_multiplier, &data->output_shift,
- &data->output_activation_min, &data->output_activation_max,
- data->per_channel_output_multiplier,
- reinterpret_cast<int*>(data->per_channel_output_shift),
- output_channels);
- }
- return kTfLiteOk;
-}
-
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
- return context->AllocatePersistentBuffer(context, sizeof(OpData));
-}
-
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- TFLITE_DCHECK(node->user_data != nullptr);
- TFLITE_DCHECK(node->builtin_data != nullptr);
- auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
-
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- const TfLiteTensor* input = GetInput(context, node, kInputTensor);
- const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
-
- auto* op_data = reinterpret_cast<OpData*>(node->user_data);
-
- int input_width = input->dims->data[2];
- int input_height = input->dims->data[1];
- int filter_width = filter->dims->data[2];
- int filter_height = filter->dims->data[1];
- int output_width = output->dims->data[2];
- int output_height = output->dims->data[1];
-
- // Per channel quantization is only needed for int8_t inference. For other
- // quantized types, only a single scale and zero point is needed.
- const int num_channels = filter->dims->data[kConvQuantizedDimension];
- // Dynamically allocate per-channel quantization parameters.
- op_data->per_channel_output_multiplier =
- reinterpret_cast<int32_t*>(context->AllocatePersistentBuffer(
- context, num_channels * sizeof(int32_t)));
- op_data->per_channel_output_shift =
- reinterpret_cast<int32_t*>(context->AllocatePersistentBuffer(
- context, num_channels * sizeof(int32_t)));
- op_data->input_zero_point = input->params.zero_point;
- op_data->output_zero_point = output->params.zero_point;
- // All per-channel quantized tensors need valid zero point and scale arrays.
- if (input->type == kTfLiteInt8) {
- TF_LITE_ENSURE_EQ(context, filter->quantization.type,
- kTfLiteAffineQuantization);
-
- const auto* affine_quantization =
- reinterpret_cast<TfLiteAffineQuantization*>(
- filter->quantization.params);
- TF_LITE_ENSURE(context, affine_quantization);
- TF_LITE_ENSURE(context, affine_quantization->scale);
- TF_LITE_ENSURE(context, affine_quantization->zero_point);
-
- TF_LITE_ENSURE(context,
- affine_quantization->scale->size == 1 ||
- affine_quantization->scale->size ==
- filter->dims->data[kConvQuantizedDimension]);
- TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
- affine_quantization->zero_point->size);
- }
-
- return CalculateOpData(context, node, params, input_width, input_height,
- filter_width, filter_height, output_width,
- output_height, input->type, op_data);
-}
-
-void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
- TfLiteConvParams* params, OpData* data,
- const TfLiteEvalTensor* input,
- const TfLiteEvalTensor* filter,
- const TfLiteEvalTensor* bias,
- TfLiteEvalTensor* output,
- TfLiteEvalTensor* im2col) {
- // TODO(b/154032858): Investigate removing extra copies.
- ConvParams op_params;
- op_params.input_offset = -data->input_zero_point;
- op_params.output_offset = data->output_zero_point;
- op_params.stride_height = params->stride_height;
- op_params.stride_width = params->stride_width;
- op_params.dilation_height_factor = params->dilation_height_factor;
- op_params.dilation_width_factor = params->dilation_width_factor;
- op_params.padding_values.height = data->padding.height;
- op_params.padding_values.width = data->padding.width;
- op_params.quantized_activation_min = data->output_activation_min;
- op_params.quantized_activation_max = data->output_activation_max;
-
-#if defined(HIFIMINI)
- ConvPerChannel(op_params, data->per_channel_output_multiplier,
- data->per_channel_output_shift,
- tflite::micro::GetTensorShape(input),
- tflite::micro::GetTensorData<int8_t>(input),
- tflite::micro::GetTensorShape(filter),
- tflite::micro::GetTensorData<int8_t>(filter),
- tflite::micro::GetTensorShape(bias),
- tflite::micro::GetTensorData<int32_t>(bias),
- tflite::micro::GetTensorShape(output),
- tflite::micro::GetTensorData<int8_t>(output));
-#else
- reference_integer_ops::ConvPerChannel(
- op_params, data->per_channel_output_multiplier,
- data->per_channel_output_shift, tflite::micro::GetTensorShape(input),
- tflite::micro::GetTensorData<int8_t>(input),
- tflite::micro::GetTensorShape(filter),
- tflite::micro::GetTensorData<int8_t>(filter),
- tflite::micro::GetTensorShape(bias),
- tflite::micro::GetTensorData<int32_t>(bias),
- tflite::micro::GetTensorShape(output),
- tflite::micro::GetTensorData<int8_t>(output));
-#endif
+ return context->AllocatePersistentBuffer(context, sizeof(OpDataConv));
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->user_data != nullptr);
TFLITE_DCHECK(node->builtin_data != nullptr);
- auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
- auto* op_data = reinterpret_cast<OpData*>(node->user_data);
+ const auto& params =
+ *(reinterpret_cast<TfLiteConvParams*>(node->builtin_data));
+ const auto& op_data = *(reinterpret_cast<OpDataConv*>(node->user_data));
TfLiteEvalTensor* output =
- tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+ tflite::micro::GetEvalOutput(context, node, kConvOutputTensor);
const TfLiteEvalTensor* input =
- tflite::micro::GetEvalInput(context, node, kInputTensor);
+ tflite::micro::GetEvalInput(context, node, kConvInputTensor);
const TfLiteEvalTensor* filter =
- tflite::micro::GetEvalInput(context, node, kFilterTensor);
+ tflite::micro::GetEvalInput(context, node, kConvWeightsTensor);
const TfLiteEvalTensor* bias =
(NumInputs(node) == 3)
- ? tflite::micro::GetEvalInput(context, node, kBiasTensor)
+ ? tflite::micro::GetEvalInput(context, node, kConvBiasTensor)
: nullptr;
#if defined(HIFIMINI)
@@ -430,10 +264,10 @@
input_dims[3] == 32 && filter_dims[0] == 32 && filter_dims[1] == 1 &&
filter_dims[2] == 1 && filter_dims[3] == 32) {
Conv1x32Input32x32Filter(
- -op_data->input_zero_point, op_data->output_zero_point,
- op_data->output_activation_min, op_data->output_activation_max,
- op_data->per_channel_output_multiplier,
- op_data->per_channel_output_shift, tflite::micro::GetTensorShape(input),
+ -op_data.input_zero_point, op_data.output_zero_point,
+ op_data.output_activation_min, op_data.output_activation_max,
+ op_data.per_channel_output_multiplier, op_data.per_channel_output_shift,
+ tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<int8_t>(filter),
@@ -446,10 +280,35 @@
#endif
switch (input->type) {
- case kTfLiteInt8:
- EvalQuantizedPerChannel(context, node, params, op_data, input, filter,
- bias, output, nullptr);
+ case kTfLiteInt8: {
+#if defined(HIFIMINI)
+ ConvPerChannel(ConvParamsQuantized(params, op_data),
+ op_data.per_channel_output_multiplier,
+ op_data.per_channel_output_shift,
+ tflite::micro::GetTensorShape(input),
+ tflite::micro::GetTensorData<int8_t>(input),
+ tflite::micro::GetTensorShape(filter),
+ tflite::micro::GetTensorData<int8_t>(filter),
+ tflite::micro::GetTensorShape(bias),
+ tflite::micro::GetTensorData<int32_t>(bias),
+ tflite::micro::GetTensorShape(output),
+ tflite::micro::GetTensorData<int8_t>(output));
+#else
+ reference_integer_ops::ConvPerChannel(
+ ConvParamsQuantized(params, op_data),
+ op_data.per_channel_output_multiplier,
+ op_data.per_channel_output_shift,
+ tflite::micro::GetTensorShape(input),
+ tflite::micro::GetTensorData<int8_t>(input),
+ tflite::micro::GetTensorShape(filter),
+ tflite::micro::GetTensorData<int8_t>(filter),
+ tflite::micro::GetTensorShape(bias),
+ tflite::micro::GetTensorData<int32_t>(bias),
+ tflite::micro::GetTensorShape(output),
+ tflite::micro::GetTensorData<int8_t>(output));
+#endif
break;
+ }
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
@@ -462,7 +321,7 @@
TfLiteRegistration Register_CONV_2D() {
return {/*init=*/Init,
/*free=*/nullptr,
- /*prepare=*/Prepare,
+ /*prepare=*/ConvPrepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
diff --git a/tensorflow/lite/micro/kernels/xtensa/quantize.cc b/tensorflow/lite/micro/kernels/xtensa/quantize.cc
index 3b84e06..e5b5a07 100644
--- a/tensorflow/lite/micro/kernels/xtensa/quantize.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/quantize.cc
@@ -109,25 +109,56 @@
}
}
-TfLiteStatus EvalHifimini(TfLiteContext* context, TfLiteNode* node) {
+#endif // defined(HIFIMINI)
+
+#if defined(HIFIMINI) || defined(FUSION_F1)
+TfLiteStatus EvalXtensa(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->user_data != nullptr);
+#if defined(HIFIMINI)
auto* op_data = static_cast<OpData*>(node->user_data);
+#elif defined(FUSION_F1)
+ auto* op_data = static_cast<OpDataQuantizeReference*>(node->user_data);
+#endif
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
if (output->type == kTfLiteInt8 && input->type == kTfLiteInt16) {
+#if defined(HIFIMINI)
AffineQuantize(op_data->scale_multiplier, op_data->zero_point,
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int16_t>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
+#elif defined(FUSION_F1)
+ int size = ElementCount(*input->dims);
+ TF_LITE_ENSURE_EQ(
+ context,
+ xa_nn_elm_quantize_asym16s_asym8s(
+ tflite::micro::GetTensorData<int8_t>(output),
+ tflite::micro::GetTensorData<int16_t>(input),
+ op_data->input_zero_point, op_data->quantization_params.zero_point,
+ op_data->requantize_output_shift,
+ op_data->requantize_output_multiplier, size),
+ 0);
+#else
+ static_assert(false, "Unsupported xtensa architecture.");
+#endif
} else if (output->type == kTfLiteInt32 && input->type == kTfLiteInt16) {
int size = ElementCount(*input->dims);
+
+ // This ifdef is only needed because the hifimini code is not following the
+ // convention of the rest of the codebase. Ideally we would be using the
+ // same structs as much as possible and reduce the need for such ifdefs.
+#if defined(HIFIMINI)
+ int32_t zero_point = op_data->zero_point;
+#elif defined(FUSION_F1)
+ int32_t zero_point = op_data->quantization_params.zero_point;
+#endif
reference_ops::Requantize(tflite::micro::GetTensorData<int16_t>(input),
size, op_data->requantize_output_multiplier,
op_data->requantize_output_shift,
- op_data->input_zero_point, op_data->zero_point,
+ op_data->input_zero_point, zero_point,
tflite::micro::GetTensorData<int32_t>(output));
} else {
TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
@@ -137,7 +168,7 @@
}
return kTfLiteOk;
}
-#endif // defined(HIFIMINI)
+#endif // defined(HIFIMINI) || defined(FUSION_F1)
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
@@ -179,8 +210,8 @@
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-#if defined(HIFIMINI)
- return EvalHifimini(context, node);
+#if defined(HIFIMINI) || defined(FUSION_F1)
+ return EvalXtensa(context, node);
#else
return EvalQuantizeReference(context, node);
#endif
diff --git a/tensorflow/lite/micro/kernels/xtensa/softmax.cc b/tensorflow/lite/micro/kernels/xtensa/softmax.cc
index b1d7ecd..a609adc 100644
--- a/tensorflow/lite/micro/kernels/xtensa/softmax.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/softmax.cc
@@ -24,6 +24,7 @@
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/xtensa/xtensa.h"
namespace tflite {
namespace {
@@ -32,7 +33,14 @@
struct OpData {
uint16_t* exp_lut;
};
+#elif defined(FUSION_F1)
+struct OpData {
+ SoftmaxParams params;
+ int scratch_tensor_index;
+};
+#endif
+#if defined(HIFIMINI)
// Number of unique int8_t and int16_t values. Used in exponent lookup table
// computation.
constexpr int kInt8Range =
@@ -173,8 +181,63 @@
}
#endif // defined(HIFIMINI)
+#if defined(FUSION_F1)
+TfLiteStatus PrepareHifi4(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_OK(context, SoftmaxPrepare(context, node));
+
+ // Calculate scratch memory requirements and request scratch buffer
+ const TfLiteTensor* input = GetInput(context, node, 0);
+ const TfLiteTensor* output = GetOutput(context, node, 0);
+
+ const RuntimeShape& input_shape = GetTensorShape(input);
+ const RuntimeShape& output_shape = GetTensorShape(output);
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
+
+ if (input->type == kTfLiteInt8) {
+ int required_scratch =
+ get_softmax_scratch_size(PREC_ASYM8S, PREC_ASYM8S, depth);
+ TF_LITE_ENSURE(context, required_scratch > 0);
+
+ auto* data = static_cast<OpData*>(node->user_data);
+ TF_LITE_ENSURE_OK(
+ context, context->RequestScratchBufferInArena(
+ context, required_scratch, &(data->scratch_tensor_index)));
+ }
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus EvalHifi4(const OpData* op_data, const TfLiteEvalTensor* input,
+ TfLiteEvalTensor* output, TfLiteContext* context) {
+ const RuntimeShape& input_shape = tflite::micro::GetTensorShape(input);
+ const int8_t* input_data = tflite::micro::GetTensorData<int8_t>(input);
+ const RuntimeShape& output_shape = tflite::micro::GetTensorShape(output);
+ int16_t* output_data = tflite::micro::GetTensorData<int16_t>(output);
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
+
+ void* p_scratch = static_cast<void*>(
+ context->GetScratchBuffer(context, op_data->scratch_tensor_index));
+
+ for (int i = 0; i < outer_size; ++i) {
+ int err = xa_nn_vec_softmax_asym8s_16(
+ &output_data[i * depth], &input_data[i * depth],
+ op_data->params.diff_min, op_data->params.input_left_shift,
+ op_data->params.input_multiplier, depth, p_scratch);
+ TF_LITE_ENSURE(context, err == 0);
+ }
+ return kTfLiteOk;
+}
+
+#endif // defined(FUSION_F1)
+
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
-#if defined(HIFIMINI)
+#if defined(HIFIMINI) || defined(FUSION_F1)
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
return context->AllocatePersistentBuffer(context, sizeof(OpData));
#else
@@ -185,6 +248,8 @@
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
#if defined(HIFIMINI)
return PrepareHifimini(context, node);
+#elif defined(FUSION_F1)
+ return PrepareHifi4(context, node);
#else
return SoftmaxPrepare(context, node);
#endif
@@ -208,7 +273,7 @@
TfLiteTypeGetName(input->type), input->type);
return kTfLiteError;
}
-#else // !defined(HIFIMINI)
+#else // !defined(HIFIMINI)
switch (input->type) {
case kTfLiteFloat32: {
SoftmaxParams op_data = *static_cast<SoftmaxParams*>(node->user_data);
@@ -221,12 +286,17 @@
}
case kTfLiteInt8: {
if (output->type == kTfLiteInt16) {
+#if defined(FUSION_F1)
+ return EvalHifi4(static_cast<OpData*>(node->user_data), input, output,
+ context);
+#else
SoftmaxParams op_data = *static_cast<SoftmaxParams*>(node->user_data);
tflite::reference_ops::Softmax(
op_data, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output));
+#endif
} else {
SoftmaxParams op_data = *static_cast<SoftmaxParams*>(node->user_data);
tflite::reference_ops::Softmax(
diff --git a/tensorflow/lite/micro/kernels/xtensa/svdf.cc b/tensorflow/lite/micro/kernels/xtensa/svdf.cc
index f9d6e18..6aea649 100644
--- a/tensorflow/lite/micro/kernels/xtensa/svdf.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/svdf.cc
@@ -51,14 +51,14 @@
* Note: passing OpData by value might seem like an oversight but it helps
* reduce the latency. See b/155656675 for more details.
*/
-void EvalIntegerSVDF(TfLiteContext* context, TfLiteNode* node,
- const TfLiteEvalTensor* input_tensor,
- const TfLiteEvalTensor* weights_feature_tensor,
- const TfLiteEvalTensor* weights_time_tensor,
- const TfLiteEvalTensor* bias_tensor,
- const TfLiteSVDFParams* params,
- TfLiteEvalTensor* activation_state_tensor,
- TfLiteEvalTensor* output_tensor, OpData data) {
+void EvalIntegerSvdfHifimini(TfLiteContext* context, TfLiteNode* node,
+ const TfLiteEvalTensor* input_tensor,
+ const TfLiteEvalTensor* weights_feature_tensor,
+ const TfLiteEvalTensor* weights_time_tensor,
+ const TfLiteEvalTensor* bias_tensor,
+ const TfLiteSVDFParams* params,
+ TfLiteEvalTensor* activation_state_tensor,
+ TfLiteEvalTensor* output_tensor, OpData data) {
const int n_rank = params->rank;
const int n_batch = input_tensor->dims->data[0];
const int n_input = input_tensor->dims->data[1];
@@ -243,7 +243,76 @@
}
}
}
-#endif
+
+#elif defined(FUSION_F1)
+
+TfLiteStatus EvalIntegerSvdfHifi4(
+ TfLiteContext* context, TfLiteNode* node,
+ const TfLiteEvalTensor* input_tensor,
+ const TfLiteEvalTensor* weights_feature_tensor,
+ const TfLiteEvalTensor* weights_time_tensor,
+ const TfLiteEvalTensor* bias_tensor, const TfLiteSVDFParams* params,
+ TfLiteEvalTensor* activation_state_tensor, TfLiteEvalTensor* output_tensor,
+ const OpData& data) {
+ const int n_rank = params->rank;
+ const int n_batch = input_tensor->dims->data[0];
+ const int n_input = input_tensor->dims->data[1];
+ const int n_filter = weights_feature_tensor->dims->data[0];
+ const int n_unit = n_filter / n_rank;
+ const int n_memory = weights_time_tensor->dims->data[1];
+
+ TFLITE_DCHECK(context != nullptr);
+ TFLITE_DCHECK(context->GetScratchBuffer != nullptr);
+
+ // Shift states.
+ int16_t* const state_ptr =
+ tflite::micro::GetTensorData<int16_t>(activation_state_tensor);
+
+ // Left shift the activation_state.
+ int num_bytes = sizeof(*state_ptr) * (n_batch * n_filter * n_memory - 1);
+ xa_nn_memmove_16(state_ptr, state_ptr + 1, num_bytes);
+
+ // Note: no need to clear the latest activation, matmul is not accumulative.
+
+ // Feature matmul.
+ const int8_t* input = tflite::micro::GetTensorData<int8_t>(input_tensor);
+ const int8_t* weight_feature =
+ tflite::micro::GetTensorData<int8_t>(weights_feature_tensor);
+ int16_t* result_in_batch = state_ptr + (n_memory - 1);
+
+ for (int b = 0; b < n_batch; b++) {
+ TF_LITE_ENSURE_EQ(context,
+ xa_nn_matXvec_out_stride_sym8sxasym8s_16(
+ &result_in_batch[b * n_filter * n_memory],
+ weight_feature, &input[b * n_input], NULL, n_filter,
+ n_input, n_input, n_memory, -data.input_zero_point,
+ (data.effective_scale_1_a), data.effective_scale_1_b),
+ 0);
+ }
+
+ // Time weights dot product + activation
+ for (int b = 0; b < n_batch; ++b) {
+ const int16_t* vector1_ptr =
+ tflite::micro::GetTensorData<int16_t>(weights_time_tensor);
+ const int16_t* vector2_ptr =
+ tflite::micro::GetTensorData<int16_t>(activation_state_tensor) +
+ b * n_memory * n_filter;
+ const int32_t* bias_ptr =
+ tflite::micro::GetTensorData<int32_t>(bias_tensor);
+ int8_t* output_ptr =
+ tflite::micro::GetTensorData<int8_t>(output_tensor) + b * n_unit;
+
+ TF_LITE_ENSURE_EQ(
+ context,
+ xa_nn_dot_prod_16x16_asym8s(
+ output_ptr, vector1_ptr, vector2_ptr, bias_ptr, n_memory * n_rank,
+ (data.effective_scale_2_a), data.effective_scale_2_b,
+ data.output_zero_point, n_unit),
+ 0);
+ }
+ return kTfLiteOk;
+}
+#endif // defined(FUSION_F1) || defined(HIFIMINI)
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context != nullptr);
@@ -274,11 +343,14 @@
const int rank = params->rank;
const int input_size = input->dims->data[1];
const int batch_size = input->dims->data[0];
+
+#if defined(HIFIMINI)
// Ensure the input size is a multiple of two. This is necessary since
// optimized kernels access the memory in chunks of two, and all accesses
// must be aligned to 16 bits.
// TODO(b/153202598): Remove when padding is allowed in TFLite tensors.
TF_LITE_ENSURE_EQ(context, input_size % 2, 0);
+#endif // defined(HIFIMINI)
const int num_filters = weights_feature->dims->data[0];
TF_LITE_ENSURE_EQ(context, num_filters % rank, 0);
@@ -339,9 +411,10 @@
static_cast<double>(activation_state->params.scale *
weights_time->params.scale / output->params.scale);
- TF_LITE_ENSURE_EQ(context, static_cast<double>(bias->params.scale),
- static_cast<double>(activation_state->params.scale *
- weights_time->params.scale));
+ TF_LITE_ENSURE_NEAR(context, static_cast<double>(bias->params.scale),
+ static_cast<double>(activation_state->params.scale *
+ weights_time->params.scale),
+ 1e-5);
TFLITE_DCHECK(node->user_data != nullptr);
OpData* data = static_cast<OpData*>(node->user_data);
@@ -396,13 +469,18 @@
const OpData& data = *(static_cast<const OpData*>(node->user_data));
#if defined(HIFIMINI)
- EvalIntegerSVDF(context, node, input, weights_feature, weights_time, bias,
- params, activation_state, output, data);
+ EvalIntegerSvdfHifimini(context, node, input, weights_feature, weights_time,
+ bias, params, activation_state, output, data);
+ return kTfLiteOk;
+#elif defined(FUSION_F1)
+ return EvalIntegerSvdfHifi4(context, node, input, weights_feature,
+ weights_time, bias, params, activation_state,
+ output, data);
#else
EvalIntegerSvdfReference(context, node, input, weights_feature, weights_time,
bias, params, activation_state, output, data);
-#endif
return kTfLiteOk;
+#endif
}
} // namespace
diff --git a/tensorflow/lite/micro/kernels/xtensa/xtensa.h b/tensorflow/lite/micro/kernels/xtensa/xtensa.h
index 18a68b3..1554b55 100644
--- a/tensorflow/lite/micro/kernels/xtensa/xtensa.h
+++ b/tensorflow/lite/micro/kernels/xtensa/xtensa.h
@@ -20,6 +20,7 @@
#include <xtensa/tie/xt_hifi2.h>
#elif defined(FUSION_F1)
#include "include/nnlib/xa_nnlib_api.h"
+#include "include/nnlib/xa_nnlib_standards.h"
#endif
#endif // TENSORFLOW_LITE_MICRO_KERNELS_XTENSA_XTENSA_H_
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/activations.cc b/tensorflow/lite/micro/kernels/xtensa_hifi/activations.cc
deleted file mode 100644
index 01a2f4e..0000000
--- a/tensorflow/lite/micro/kernels/xtensa_hifi/activations.cc
+++ /dev/null
@@ -1,240 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/lite/c/builtin_op_data.h"
-#include "tensorflow/lite/c/common.h"
-#include "tensorflow/lite/kernels/internal/common.h"
-#include "tensorflow/lite/kernels/internal/quantization_util.h"
-#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
-#include "tensorflow/lite/kernels/kernel_util.h"
-#include "tensorflow/lite/kernels/op_macros.h"
-#include "tensorflow/lite/micro/kernels/xtensa_hifi/xtensa_tf_micro_common.h"
-#include "tensorflow/lite/micro/micro_utils.h"
-
-namespace tflite {
-namespace ops {
-namespace micro {
-namespace activations {
-
-constexpr int kInputTensor = 0;
-constexpr int kOutputTensor = 0;
-
-template <typename Q>
-inline void ReluQuantized(int32_t lower, const RuntimeShape& input_shape,
- const Q* input_data, const RuntimeShape& output_shape,
- Q* output_data) {
- const int flat_size = MatchingFlatSize(input_shape, output_shape);
- for (int i = 0; i < flat_size; ++i) {
- const Q val = input_data[i];
- const Q clamped = val < lower ? lower : val;
- output_data[i] = clamped;
- }
-}
-
-inline void ReluFloat(const RuntimeShape& input_shape, const float* input_data,
- const RuntimeShape& output_shape, float* output_data) {
- const int flat_size = MatchingFlatSize(input_shape, output_shape);
- for (int i = 0; i < flat_size; ++i) {
- const float val = input_data[i];
- const float lower = 0.0f;
- const float clamped = val < lower ? lower : val;
- output_data[i] = clamped;
- }
-}
-
-inline void Relu6Float(const RuntimeShape& input_shape, const float* input_data,
- const RuntimeShape& output_shape, float* output_data) {
- const int flat_size = MatchingFlatSize(input_shape, output_shape);
- for (int i = 0; i < flat_size; ++i) {
- const float val = input_data[i];
- const float upper = 6.0f;
- const float lower = 0.0f;
- const float clamped = val > upper ? upper : val < lower ? lower : val;
- output_data[i] = clamped;
- }
-}
-
-template <typename Q>
-inline void Relu6Quantized(Q lower, Q upper, const RuntimeShape& input_shape,
- const Q* input_data,
- const RuntimeShape& output_shape, Q* output_data) {
- const int flat_size = MatchingFlatSize(input_shape, output_shape);
- for (int i = 0; i < flat_size; ++i) {
- const Q val = input_data[i];
- const Q clamped = val > upper ? upper : val < lower ? lower : val;
- output_data[i] = clamped;
- }
-}
-
-TfLiteStatus ReluPrepare(TfLiteContext* context, TfLiteNode* node) {
- return kTfLiteOk;
-}
-
-TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
- const TfLiteTensor* input = GetInput(context, node, kInputTensor);
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-
- switch (input->type) {
- case kTfLiteFloat32: {
-#if HIFI_VFPU
- int err;
- const float* inp_data_ptr;
- float* out_data_ptr;
- const RuntimeShape& input_shape = GetTensorShape(input);
- const RuntimeShape& output_shape = GetTensorShape(output);
- const int flat_size = MatchingFlatSize(input_shape, output_shape);
-
- inp_data_ptr = GetTensorData<float>(input);
- out_data_ptr = GetTensorData<float>(output);
-
- err = xa_nn_vec_relu_std_f32_f32(out_data_ptr, inp_data_ptr, flat_size);
-
- CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_vec_relu_std_f32_f32 failed");
-#else
- ReluFloat(GetTensorShape(input), GetTensorData<float>(input),
- GetTensorShape(output), GetTensorData<float>(output));
-#endif /* HIFI_VFPU */
- return kTfLiteOk;
- }
- case kTfLiteInt8: {
- ReluQuantized<int8_t>(input->params.zero_point, GetTensorShape(input),
- GetTensorData<int8_t>(input),
- GetTensorShape(output),
- GetTensorData<int8_t>(output));
- return kTfLiteOk;
- }
- case kTfLiteUInt8: {
- int err;
- const uint8_t* inp_data_ptr;
- uint8_t* out_data_ptr;
- const RuntimeShape& input_shape = GetTensorShape(input);
- const RuntimeShape& output_shape = GetTensorShape(output);
- const int flat_size = MatchingFlatSize(input_shape, output_shape);
- const uint8_t zero = input->params.zero_point;
-
- inp_data_ptr = GetTensorData<uint8_t>(input);
- out_data_ptr = GetTensorData<uint8_t>(output);
-
- err = xa_nn_vec_activation_min_max_asym8_asym8(
- out_data_ptr, inp_data_ptr, zero, std::numeric_limits<uint8_t>::max(),
- flat_size);
-
- CHECK_ERR_HIFI_NNLIB_KER(
- err, "xa_nn_vec_activation_min_max_asym8_asym8 failed");
- return kTfLiteOk;
- }
- default: {
- TF_LITE_KERNEL_LOG(context, "Only float32 is supported currently, got %s",
- TfLiteTypeGetName(input->type));
- return kTfLiteError;
- }
- }
-}
-
-TfLiteStatus Relu6Prepare(TfLiteContext* context, TfLiteNode* node) {
- return kTfLiteOk;
-}
-
-TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
- const TfLiteTensor* input = GetInput(context, node, kInputTensor);
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-
- switch (input->type) {
- case kTfLiteFloat32: {
-#if HIFI_VFPU
- int err;
- const float* inp_data_ptr;
- float* out_data_ptr;
- const RuntimeShape& input_shape = GetTensorShape(input);
- const RuntimeShape& output_shape = GetTensorShape(output);
- const int flat_size = MatchingFlatSize(input_shape, output_shape);
-
- inp_data_ptr = GetTensorData<float>(input);
- out_data_ptr = GetTensorData<float>(output);
-
- err = xa_nn_vec_relu6_f32_f32(out_data_ptr, inp_data_ptr, flat_size);
-
- CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_vec_relu6_f32_f32 failed");
-#else
- Relu6Float(GetTensorShape(input), GetTensorData<float>(input),
- GetTensorShape(output), GetTensorData<float>(output));
-#endif /* HIFI_VFPU */
- return kTfLiteOk;
- }
- case kTfLiteInt8: {
- const int8_t six = FloatToAsymmetricQuantizedInt8(
- 6.0f, input->params.scale, input->params.zero_point);
- const int8_t zero = input->params.zero_point;
- Relu6Quantized<int8_t>(
- zero, six, GetTensorShape(input), GetTensorData<int8_t>(input),
- GetTensorShape(output), GetTensorData<int8_t>(output));
- return kTfLiteOk;
- }
- case kTfLiteUInt8: {
- const uint8_t six = FloatToAsymmetricQuantizedUInt8(
- 6.0f, input->params.scale, input->params.zero_point);
- const uint8_t zero = input->params.zero_point;
- int err;
- const uint8_t* inp_data_ptr;
- uint8_t* out_data_ptr;
- const RuntimeShape& input_shape = GetTensorShape(input);
- const RuntimeShape& output_shape = GetTensorShape(output);
- const int flat_size = MatchingFlatSize(input_shape, output_shape);
-
- inp_data_ptr = GetTensorData<uint8_t>(input);
- out_data_ptr = GetTensorData<uint8_t>(output);
-
- err = xa_nn_vec_activation_min_max_asym8_asym8(out_data_ptr, inp_data_ptr,
- zero, six, flat_size);
-
- CHECK_ERR_HIFI_NNLIB_KER(
- err, "xa_nn_vec_activation_min_max_asym8_asym8 failed");
- return kTfLiteOk;
- }
- default: {
- TF_LITE_KERNEL_LOG(context, "Only float32 is supported currently, got %s",
- TfLiteTypeGetName(input->type));
- return kTfLiteError;
- }
- }
-}
-
-} // namespace activations
-
-TfLiteRegistration Register_RELU() {
- return {/*init=*/nullptr,
- /*free=*/nullptr,
- /*prepare=*/activations::ReluPrepare,
- /*invoke=*/activations::ReluEval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
-}
-
-TfLiteRegistration Register_RELU6() {
- return {/*init=*/nullptr,
- /*free=*/nullptr,
- /*prepare=*/activations::Relu6Prepare,
- /*invoke=*/activations::Relu6Eval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
-}
-
-} // namespace micro
-} // namespace ops
-} // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/add.cc b/tensorflow/lite/micro/kernels/xtensa_hifi/add.cc
deleted file mode 100644
index 90590ab..0000000
--- a/tensorflow/lite/micro/kernels/xtensa_hifi/add.cc
+++ /dev/null
@@ -1,273 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/lite/kernels/internal/reference/add.h"
-
-#include "tensorflow/lite/c/builtin_op_data.h"
-#include "tensorflow/lite/c/common.h"
-#include "tensorflow/lite/kernels/internal/quantization_util.h"
-#include "tensorflow/lite/kernels/internal/reference/integer_ops/add.h"
-#include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
-#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
-#include "tensorflow/lite/kernels/kernel_util.h"
-#include "tensorflow/lite/kernels/op_macros.h"
-#include "tensorflow/lite/micro/kernels/xtensa_hifi/xtensa_tf_micro_common.h"
-#include "tensorflow/lite/micro/memory_helpers.h"
-
-namespace tflite {
-namespace ops {
-namespace micro {
-namespace add {
-
-constexpr int kInputTensor1 = 0;
-constexpr int kInputTensor2 = 1;
-constexpr int kOutputTensor = 0;
-
-struct OpData {
- bool requires_broadcast;
-
- // These fields are used in both the general 8-bit -> 8bit quantized path,
- // and the special 16-bit -> 16bit quantized path
- int input1_shift;
- int input2_shift;
- int32_t output_activation_min;
- int32_t output_activation_max;
-
- // These fields are used only in the general 8-bit -> 8bit quantized path
- int32_t input1_multiplier;
- int32_t input2_multiplier;
- int32_t output_multiplier;
- int output_shift;
- int left_shift;
- int32_t input1_offset;
- int32_t input2_offset;
- int32_t output_offset;
-};
-
-TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteAddParams* params,
- const TfLiteTensor* input1,
- const TfLiteTensor* input2, TfLiteTensor* output,
- OpData* data) {
- data->requires_broadcast = !HaveSameShapes(input1, input2);
-
- if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
- // 8bit -> 8bit general quantized path, with general rescalings
- data->input1_offset = -input1->params.zero_point;
- data->input2_offset = -input2->params.zero_point;
- data->output_offset = output->params.zero_point;
- data->left_shift = 20;
- const double twice_max_input_scale =
- 2 * static_cast<double>(
- std::max(input1->params.scale, input2->params.scale));
- const double real_input1_multiplier =
- static_cast<double>(input1->params.scale) / twice_max_input_scale;
- const double real_input2_multiplier =
- static_cast<double>(input2->params.scale) / twice_max_input_scale;
- const double real_output_multiplier =
- twice_max_input_scale /
- ((1 << data->left_shift) * static_cast<double>(output->params.scale));
-
- QuantizeMultiplierSmallerThanOneExp(
- real_input1_multiplier, &data->input1_multiplier, &data->input1_shift);
-
- QuantizeMultiplierSmallerThanOneExp(
- real_input2_multiplier, &data->input2_multiplier, &data->input2_shift);
-
- QuantizeMultiplierSmallerThanOneExp(
- real_output_multiplier, &data->output_multiplier, &data->output_shift);
-
- TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
- context, params->activation, output, &data->output_activation_min,
- &data->output_activation_max));
- }
-
- return kTfLiteOk;
-}
-
-TfLiteStatus EvalAdd(TfLiteContext* context, TfLiteNode* node,
- TfLiteAddParams* params, const OpData* data,
- const TfLiteTensor* input1, const TfLiteTensor* input2,
- TfLiteTensor* output) {
- float output_activation_min, output_activation_max;
- CalculateActivationRange(params->activation, &output_activation_min,
- &output_activation_max);
- tflite::ArithmeticParams op_params;
- SetActivationParams(output_activation_min, output_activation_max, &op_params);
-#define TF_LITE_ADD(opname) \
- reference_ops::opname(op_params, GetTensorShape(input1), \
- GetTensorData<float>(input1), GetTensorShape(input2), \
- GetTensorData<float>(input2), GetTensorShape(output), \
- GetTensorData<float>(output))
- if (data->requires_broadcast) {
- TF_LITE_ADD(BroadcastAdd4DSlow);
- } else {
-#if HIFI_VFPU
- int err;
- const RuntimeShape& input1_shape = GetTensorShape(input1);
- const RuntimeShape& input2_shape = GetTensorShape(input2);
- const RuntimeShape& output_shape = GetTensorShape(output);
- const int flat_size =
- MatchingElementsSize(input1_shape, input2_shape, output_shape);
-
- err = xa_nn_elm_add_f32xf32_f32(GetTensorData<float>(output),
- GetTensorData<float>(input1),
- GetTensorData<float>(input2), flat_size);
-
- CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_elm_add_f32xf32_f32 failed");
-
- err = xa_nn_vec_activation_min_max_f32_f32(
- GetTensorData<float>(output), GetTensorData<float>(output),
- output_activation_min, output_activation_max, flat_size);
-
- CHECK_ERR_HIFI_NNLIB_KER(err,
- "xa_nn_vec_activation_min_max_f32_f32 failed");
-#else
- TF_LITE_ADD(Add);
-#endif /* HIFI_VFPU */
- }
-#undef TF_LITE_ADD
- return kTfLiteOk;
-}
-
-TfLiteStatus EvalAddQuantized(TfLiteContext* context, TfLiteNode* node,
- TfLiteAddParams* params, const OpData* data,
- const TfLiteTensor* input1,
- const TfLiteTensor* input2,
- TfLiteTensor* output) {
- if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
- tflite::ArithmeticParams op_params;
- op_params.left_shift = data->left_shift;
- op_params.input1_offset = data->input1_offset;
- op_params.input1_multiplier = data->input1_multiplier;
- op_params.input1_shift = data->input1_shift;
- op_params.input2_offset = data->input2_offset;
- op_params.input2_multiplier = data->input2_multiplier;
- op_params.input2_shift = data->input2_shift;
- op_params.output_offset = data->output_offset;
- op_params.output_multiplier = data->output_multiplier;
- op_params.output_shift = data->output_shift;
- SetActivationParams(data->output_activation_min,
- data->output_activation_max, &op_params);
- bool need_broadcast = reference_ops::ProcessBroadcastShapes(
- GetTensorShape(input1), GetTensorShape(input2), &op_params);
-#define TF_LITE_ADD(type, opname, dtype) \
- type::opname(op_params, GetTensorShape(input1), \
- GetTensorData<dtype>(input1), GetTensorShape(input2), \
- GetTensorData<dtype>(input2), GetTensorShape(output), \
- GetTensorData<dtype>(output));
- if (output->type == kTfLiteInt8) {
- if (need_broadcast) {
- TF_LITE_ADD(reference_integer_ops, BroadcastAdd4DSlow, int8_t);
- } else {
- TF_LITE_ADD(reference_integer_ops, Add, int8_t);
- }
- } else {
- if (need_broadcast) {
- TF_LITE_ADD(reference_ops, BroadcastAdd4DSlow, uint8_t);
- } else {
- int err;
- const RuntimeShape& input1_shape = GetTensorShape(input1);
- const RuntimeShape& input2_shape = GetTensorShape(input2);
- const RuntimeShape& output_shape = GetTensorShape(output);
- const int flat_size =
- MatchingElementsSize(input1_shape, input2_shape, output_shape);
-
- err = xa_nn_elm_add_asym8xasym8_asym8(
- GetTensorData<uint8_t>(output), op_params.output_offset,
- op_params.output_shift, op_params.output_multiplier,
- op_params.quantized_activation_min,
- op_params.quantized_activation_max, GetTensorData<uint8_t>(input1),
- op_params.input1_offset, op_params.input1_shift,
- op_params.input1_multiplier, GetTensorData<uint8_t>(input2),
- op_params.input2_offset, op_params.input2_shift,
- op_params.input2_multiplier, op_params.left_shift, flat_size);
-
- CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_elm_add_asym8xasym8_asym8 failed");
- }
- }
-#undef TF_LITE_ADD
- }
-
- return kTfLiteOk;
-}
-
-void* Init(TfLiteContext* context, const char* buffer, size_t length) {
- TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
- void* data = nullptr;
- if (context->AllocatePersistentBuffer(context, sizeof(OpData), &data) ==
- kTfLiteError) {
- return nullptr;
- }
- return data;
-}
-
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- TFLITE_DCHECK(node->user_data != nullptr);
- TFLITE_DCHECK(node->builtin_data != nullptr);
-
- const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
- const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-
- OpData* data = static_cast<OpData*>(node->user_data);
- auto* params = reinterpret_cast<TfLiteAddParams*>(node->builtin_data);
-
- TF_LITE_ENSURE_STATUS(
- CalculateOpData(context, params, input1, input2, output, data));
-
- return kTfLiteOk;
-}
-
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- auto* params = reinterpret_cast<TfLiteAddParams*>(node->builtin_data);
-
- TFLITE_DCHECK(node->user_data != nullptr);
- const OpData* data = static_cast<const OpData*>(node->user_data);
-
- const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
- const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-
- if (output->type == kTfLiteFloat32) {
- TF_LITE_ENSURE_OK(
- context, EvalAdd(context, node, params, data, input1, input2, output));
- } else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
- TF_LITE_ENSURE_OK(context, EvalAddQuantized(context, node, params, data,
- input1, input2, output));
- } else {
- TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
- TfLiteTypeGetName(output->type), output->type);
- return kTfLiteError;
- }
-
- return kTfLiteOk;
-}
-
-} // namespace add
-
-TfLiteRegistration Register_ADD() {
- return {/*init=*/add::Init,
- /*free=*/nullptr,
- /*prepare=*/add::Prepare,
- /*invoke=*/add::Eval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
-}
-
-} // namespace micro
-} // namespace ops
-} // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/conv.cc b/tensorflow/lite/micro/kernels/xtensa_hifi/conv.cc
deleted file mode 100755
index 68fe4f5..0000000
--- a/tensorflow/lite/micro/kernels/xtensa_hifi/conv.cc
+++ /dev/null
@@ -1,536 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/lite/kernels/internal/reference/conv.h"
-
-#include "tensorflow/lite/c/builtin_op_data.h"
-#include "tensorflow/lite/c/common.h"
-#include "tensorflow/lite/kernels/internal/common.h"
-#include "tensorflow/lite/kernels/internal/quantization_util.h"
-#include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h"
-#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
-#include "tensorflow/lite/kernels/kernel_util.h"
-#include "tensorflow/lite/kernels/padding.h"
-#include "tensorflow/lite/micro/kernels/xtensa_hifi/xtensa_tf_micro_common.h"
-
-namespace tflite {
-namespace ops {
-namespace micro {
-namespace conv {
-
-constexpr int kInputTensor = 0;
-constexpr int kFilterTensor = 1;
-constexpr int kBiasTensor = 2;
-constexpr int kOutputTensor = 0;
-
-// Conv is quantized along dimension 0:
-// https://www.tensorflow.org/lite/performance/quantization_spec
-constexpr int kConvQuantizedDimension = 0;
-
-// This file has 2 implementation of Conv.
-
-struct OpData {
- TfLitePaddingValues padding;
- // The scaling factor from input to output (aka the 'real multiplier') can
- // be represented as a fixed point multiplier plus a left shift.
- int32_t output_multiplier;
- int output_shift;
-
- // Per channel output multiplier and shift.
- int32_t* per_channel_output_multiplier;
- int32_t* per_channel_output_shift;
-
- // The range of the fused activation layer. For example for kNone and
- // uint8_t these would be 0 and 255.
- int32_t output_activation_min;
- int32_t output_activation_max;
-};
-
-inline PaddingType RuntimePaddingType(TfLitePadding padding) {
- switch (padding) {
- case TfLitePadding::kTfLitePaddingSame:
- return PaddingType::kSame;
- case TfLitePadding::kTfLitePaddingValid:
- return PaddingType::kValid;
- case TfLitePadding::kTfLitePaddingUnknown:
- default:
- return PaddingType::kNone;
- }
-}
-
-TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
- const TfLiteConvParams* params, int width,
- int height, int filter_width, int filter_height,
- int out_width, int out_height,
- const TfLiteType data_type, OpData* data) {
- bool has_bias = node->inputs->size == 3;
- // Check number of inputs/outputs
- TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
-
- // Matching GetWindowedOutputSize in TensorFlow.
- auto padding = params->padding;
- data->padding = ComputePaddingHeightWidth(
- params->stride_height, params->stride_width,
- params->dilation_height_factor, params->dilation_width_factor, height,
- width, filter_height, filter_width, padding, &out_height, &out_width);
-
- // Note that quantized inference requires that all tensors have their
- // parameters set. This is usually done during quantized training.
- if (data_type != kTfLiteFloat32) {
- const TfLiteTensor* input = GetInput(context, node, kInputTensor);
- const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
- const TfLiteTensor* bias =
- GetOptionalInputTensor(context, node, kBiasTensor);
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- int output_channels = filter->dims->data[kConvQuantizedDimension];
-
- TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams(
- context, input, filter, bias, output, params->activation,
- &data->output_multiplier, &data->output_shift,
- &data->output_activation_min, &data->output_activation_max,
- data->per_channel_output_multiplier,
- reinterpret_cast<int*>(data->per_channel_output_shift),
- output_channels));
- }
- return kTfLiteOk;
-}
-
-void* Init(TfLiteContext* context, const char* buffer, size_t length) {
- TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
- void* data = nullptr;
- if (context->AllocatePersistentBuffer(context, sizeof(OpData), &data) ==
- kTfLiteError) {
- return nullptr;
- }
- return data;
-}
-
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- TFLITE_DCHECK(node->user_data != nullptr);
- TFLITE_DCHECK(node->builtin_data != nullptr);
-
- OpData* data = static_cast<OpData*>(node->user_data);
- const auto params = static_cast<const TfLiteConvParams*>(node->builtin_data);
-
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- const TfLiteTensor* input = GetInput(context, node, kInputTensor);
- const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
-
- int input_width = input->dims->data[2];
- int input_height = input->dims->data[1];
- int filter_width = filter->dims->data[2];
- int filter_height = filter->dims->data[1];
- int output_width = output->dims->data[2];
- int output_height = output->dims->data[1];
-
- // Dynamically allocate per-channel quantization parameters.
- const int num_channels = filter->dims->data[kConvQuantizedDimension];
- TF_LITE_ENSURE_STATUS(context->AllocatePersistentBuffer(
- context, num_channels * sizeof(int32_t),
- reinterpret_cast<void**>(&data->per_channel_output_multiplier)));
- TF_LITE_ENSURE_STATUS(context->AllocatePersistentBuffer(
- context, num_channels * sizeof(int32_t),
- reinterpret_cast<void**>(&data->per_channel_output_shift)));
-
- // All per-channel quantized tensors need valid zero point and scale arrays.
- if (input->type == kTfLiteInt8) {
- TF_LITE_ENSURE_EQ(context, filter->quantization.type,
- kTfLiteAffineQuantization);
-
- const auto* affine_quantization =
- static_cast<TfLiteAffineQuantization*>(filter->quantization.params);
- TF_LITE_ENSURE(context, affine_quantization);
- TF_LITE_ENSURE(context, affine_quantization->scale);
- TF_LITE_ENSURE(context, affine_quantization->zero_point);
-
- TF_LITE_ENSURE(context,
- affine_quantization->scale->size == 1 ||
- affine_quantization->scale->size ==
- filter->dims->data[kConvQuantizedDimension]);
- TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
- affine_quantization->zero_point->size);
- }
-
- return CalculateOpData(context, node, params, input_width, input_height,
- filter_width, filter_height, output_width,
- output_height, input->type, data);
-} // namespace conv
-
-TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
- TfLiteConvParams* params, const OpData& data,
- const TfLiteTensor* input,
- const TfLiteTensor* filter, const TfLiteTensor* bias,
- TfLiteTensor* im2col, TfLiteTensor* hwcn_weights,
- TfLiteTensor* output) {
- const int32_t input_offset = -input->params.zero_point;
- const int32_t filter_offset = -filter->params.zero_point;
- const int32_t output_offset = output->params.zero_point;
-
- if ((params->dilation_width_factor == 1) &&
- (params->dilation_height_factor == 1)) {
- const uint8_t *input_data, *filter_data;
- const int32_t* bias_data;
- uint8_t* output_data;
- const RuntimeShape& input_shape = GetTensorShape(input);
- const RuntimeShape& filter_shape = GetTensorShape(filter);
- const RuntimeShape& output_shape = GetTensorShape(output);
- const RuntimeShape& bias_shape = GetTensorShape(bias);
-
- input_data = GetTensorData<uint8_t>(input);
- filter_data = GetTensorData<uint8_t>(filter);
- bias_data = GetTensorData<int32_t>(bias);
- output_data = GetTensorData<uint8_t>(output);
-
- const int stride_width = params->stride_width;
- const int stride_height = params->stride_height;
- const int pad_width = data.padding.width;
- const int pad_height = data.padding.height;
- const int32_t output_activation_min = data.output_activation_min;
- const int32_t output_activation_max = data.output_activation_max;
- const int32_t output_multiplier = data.output_multiplier;
- const int output_shift = -data.output_shift;
- TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
- TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
- TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
-
- const int batches = MatchingDim(input_shape, 0, output_shape, 0);
- const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
- const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
- if (bias_data) {
- TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
- }
- const int input_height = input_shape.Dims(1);
- const int input_width = input_shape.Dims(2);
- const int filter_height = filter_shape.Dims(1);
- const int filter_width = filter_shape.Dims(2);
- const int output_height = output_shape.Dims(1);
- const int output_width = output_shape.Dims(2);
- const int filter_depth = filter_shape.Dims(3);
-
- int err, output_data_format = 0;
- uint8_t* p_scratch;
- uint8_t* p_filter;
- // Calculate filter_depth_padded as next near multiple of 4
- int filter_depth_padded = (filter_depth + 3) & (~3);
- int out_length = output_height * output_width * output_depth;
- int filter_size_padded = filter_height * filter_width * filter_depth_padded;
- int required_scratch, input_precision = PREC_ASYM8;
- int h, c;
-
- required_scratch = xa_nn_conv2d_std_getsize(
- input_height, input_depth, filter_height, filter_width, stride_height,
- pad_height, output_height, input_precision);
-
- if (required_scratch <= 0) {
- TF_LITE_KERNEL_LOG(context,
- "conv2d_std_asym8: xa_nn_conv2d_std_getsize failed");
- return kTfLiteError;
- }
-
- ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM;
- p_scratch = xtensa_nnlib_scratch_buf;
-
- p_filter = p_scratch;
- required_scratch +=
- ALIGNED_SIZE((sizeof(uint8_t) * filter_size_padded * output_depth), 8);
- p_scratch +=
- ALIGNED_SIZE(sizeof(uint8_t) * filter_size_padded * output_depth, 8);
-
- if (required_scratch > (int)XTENSA_NNLIB_MAX_SCRATCH_SIZE) {
- TF_LITE_KERNEL_LOG(context,
- "conv2d_std_asym8: insufficient scratch memory");
- return kTfLiteError;
- }
-
- // Padding filter coefficients depthwise
- for (h = 0; h < filter_height * filter_width * output_depth; h++) {
- for (c = 0; c < filter_depth; c++) {
- p_filter[h * filter_depth_padded + c] =
- filter_data[h * filter_depth + c];
- }
- for (c = input_depth; c < filter_depth_padded; c++) {
- p_filter[h * filter_depth_padded + c] =
- -filter_offset; // filter_depth[h*input_depth + c];
- }
- }
-
- for (int batch = 0; batch < batches; ++batch) {
- uint8_t* p_out_temp;
- p_out_temp = &output_data[batch * out_length];
-
- err = xa_nn_conv2d_std_asym8xasym8(
- p_out_temp,
- &input_data[batch * input_height * input_width * input_depth],
- p_filter, // filter_data,
- bias_data, input_height, input_width, input_depth, filter_height,
- filter_width, output_depth, stride_width, stride_height, pad_width,
- pad_height, output_height, output_width, input_offset, filter_offset,
- output_multiplier, output_shift, output_offset, output_data_format,
- static_cast<void*>(p_scratch));
-
- CHECK_ERR_HIFI_NNLIB_KER(
- err, "conv2d_std_asym8: xa_nn_conv2d_std_asym8xasym8 failed");
-
- err = xa_nn_vec_activation_min_max_asym8_asym8(
- p_out_temp, p_out_temp, output_activation_min, output_activation_max,
- out_length);
-
- CHECK_ERR_HIFI_NNLIB_KER(
- err, "xa_nn_vec_activation_min_max_asym8_asym8 failed");
- }
- } else {
- // TODO(b/154032858): Investigate removing extra copies.
- ConvParams op_params;
- op_params.padding_type = RuntimePaddingType(params->padding);
- op_params.padding_values.width = data.padding.width;
- op_params.padding_values.height = data.padding.height;
- op_params.stride_width = params->stride_width;
- op_params.stride_height = params->stride_height;
- op_params.dilation_width_factor = params->dilation_width_factor;
- op_params.dilation_height_factor = params->dilation_height_factor;
- op_params.input_offset = input_offset;
- op_params.weights_offset = filter_offset;
- op_params.output_offset = output_offset;
- op_params.output_multiplier = data.output_multiplier;
- op_params.output_shift = -data.output_shift;
- op_params.quantized_activation_min = data.output_activation_min;
- op_params.quantized_activation_max = data.output_activation_max;
- reference_ops::Conv(op_params, GetTensorShape(input),
- GetTensorData<uint8_t>(input), GetTensorShape(filter),
- GetTensorData<uint8_t>(filter), GetTensorShape(bias),
- GetTensorData<int32_t>(bias), GetTensorShape(output),
- GetTensorData<uint8_t>(output), GetTensorShape(im2col),
- GetTensorData<uint8_t>(im2col), nullptr);
- }
- return kTfLiteOk;
-}
-
-void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
- TfLiteConvParams* params, const OpData& data,
- const TfLiteTensor* input,
- const TfLiteTensor* filter,
- const TfLiteTensor* bias, TfLiteTensor* output,
- TfLiteTensor* im2col) {
- // TODO(b/154032858): Investigate removing extra copies.
- ConvParams op_params;
- op_params.input_offset = -input->params.zero_point;
- op_params.output_offset = output->params.zero_point;
- op_params.stride_height = params->stride_height;
- op_params.stride_width = params->stride_width;
- op_params.dilation_height_factor = params->dilation_height_factor;
- op_params.dilation_width_factor = params->dilation_width_factor;
- op_params.padding_values.height = data.padding.height;
- op_params.padding_values.width = data.padding.width;
- op_params.quantized_activation_min = data.output_activation_min;
- op_params.quantized_activation_max = data.output_activation_max;
-
- reference_integer_ops::ConvPerChannel(
- op_params, data.per_channel_output_multiplier,
- data.per_channel_output_shift, GetTensorShape(input),
- GetTensorData<int8_t>(input), GetTensorShape(filter),
- GetTensorData<int8_t>(filter), GetTensorShape(bias),
- GetTensorData<int32_t>(bias), GetTensorShape(output),
- GetTensorData<int8_t>(output));
-}
-
-TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
- TfLiteConvParams* params, const OpData& data,
- const TfLiteTensor* input, const TfLiteTensor* filter,
- const TfLiteTensor* bias, TfLiteTensor* im2col,
- TfLiteTensor* hwcn_weights, TfLiteTensor* output) {
- float output_activation_min, output_activation_max;
- CalculateActivationRange(params->activation, &output_activation_min,
- &output_activation_max);
-
-#if HIFI_VFPU
- if ((params->dilation_width_factor == 1) &&
- (params->dilation_height_factor == 1)) {
- const float *input_data, *filter_data;
- const float* bias_data;
- float* output_data;
- const RuntimeShape& input_shape = GetTensorShape(input);
- const RuntimeShape& filter_shape = GetTensorShape(filter);
- const RuntimeShape& output_shape = GetTensorShape(output);
- const RuntimeShape& bias_shape = GetTensorShape(bias);
-
- input_data = GetTensorData<float>(input);
- filter_data = GetTensorData<float>(filter);
- bias_data = GetTensorData<float>(bias);
- output_data = GetTensorData<float>(output);
-
- const int stride_width = params->stride_width;
- const int stride_height = params->stride_height;
- const int pad_width = data.padding.width;
- const int pad_height = data.padding.height;
- TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
- TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
- TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
-
- const int batches = MatchingDim(input_shape, 0, output_shape, 0);
- const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
- const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
- if (bias_data) {
- TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
- }
- const int input_height = input_shape.Dims(1);
- const int input_width = input_shape.Dims(2);
- const int filter_height = filter_shape.Dims(1);
- const int filter_width = filter_shape.Dims(2);
- const int output_height = output_shape.Dims(1);
- const int output_width = output_shape.Dims(2);
- const int filter_depth = filter_shape.Dims(3);
- int err, output_data_format = 0;
- uint8_t* p_scratch;
- float* p_filter;
- // Calculate filter_depth_padded as next near multiple of 2
- int filter_depth_padded = (filter_depth + 1) & (~1);
- int out_length = output_height * output_width * output_depth;
- int filter_size_padded = filter_height * filter_width * filter_depth_padded;
- int required_scratch, input_precision = PREC_F32;
- int h, c;
-
- required_scratch = xa_nn_conv2d_std_getsize(
- input_height, input_depth, filter_height, filter_width, stride_height,
- pad_height, output_height, input_precision);
-
- if (required_scratch <= 0) {
- TF_LITE_KERNEL_LOG(context,
- "conv2d_std_f32: xa_nn_conv2d_std_getsize failed");
- return kTfLiteError;
- }
-
- ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM;
- p_scratch = xtensa_nnlib_scratch_buf;
-
- p_filter = reinterpret_cast<float*>(p_scratch);
- p_scratch +=
- ALIGNED_SIZE((sizeof(float) * filter_size_padded * output_depth), 8);
- required_scratch +=
- ALIGNED_SIZE((sizeof(float) * filter_size_padded * output_depth), 8);
-
- if (required_scratch > (int)XTENSA_NNLIB_MAX_SCRATCH_SIZE) {
- TF_LITE_KERNEL_LOG(context,
- "conv2d_std_f32: insufficient scratch memory");
- return kTfLiteError;
- }
-
- // Padding filter coefficients depthwise
- for (h = 0; h < filter_height * filter_width * output_depth; h++) {
- for (c = 0; c < filter_depth; c++) {
- p_filter[h * filter_depth_padded + c] =
- filter_data[h * filter_depth + c];
- }
- for (c = input_depth; c < filter_depth_padded; c++) {
- p_filter[h * filter_depth_padded + c] = 0;
- }
- }
-
- for (int batch = 0; batch < batches; ++batch) {
- float* p_out_temp;
- p_out_temp = &output_data[batch * out_length];
-
- err = xa_nn_conv2d_std_f32(
- p_out_temp,
- &input_data[batch * input_height * input_width * input_depth],
- p_filter, bias_data, input_height, input_width, input_depth,
- filter_height, filter_width, output_depth, stride_width,
- stride_height, pad_width, pad_height, output_height, output_width,
- output_data_format, static_cast<void*>(p_scratch));
-
- CHECK_ERR_HIFI_NNLIB_KER(
- err, "conv2d_std_f32: xa_nn_conv2d_std_f32xf32 failed");
-
- err = xa_nn_vec_activation_min_max_f32_f32(
- p_out_temp, p_out_temp, output_activation_min, output_activation_max,
- out_length);
-
- CHECK_ERR_HIFI_NNLIB_KER(err,
- "xa_nn_vec_activation_min_max_f32_f32 failed");
- }
- } else
-#endif /* HIFI_VFPU */
- {
- // TODO(b/154032858): Investigate removing extra copies.
- ConvParams op_params;
- op_params.padding_type = RuntimePaddingType(params->padding);
- op_params.padding_values.width = data.padding.width;
- op_params.padding_values.height = data.padding.height;
- op_params.stride_width = params->stride_width;
- op_params.stride_height = params->stride_height;
- op_params.dilation_width_factor = params->dilation_width_factor;
- op_params.dilation_height_factor = params->dilation_height_factor;
- op_params.float_activation_min = output_activation_min;
- op_params.float_activation_max = output_activation_max;
-
- reference_ops::Conv(op_params, GetTensorShape(input),
- GetTensorData<float>(input), GetTensorShape(filter),
- GetTensorData<float>(filter), GetTensorShape(bias),
- GetTensorData<float>(bias), GetTensorShape(output),
- GetTensorData<float>(output), GetTensorShape(im2col),
- GetTensorData<float>(im2col));
- }
- return kTfLiteOk;
-}
-
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
-
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- const TfLiteTensor* input = GetInput(context, node, kInputTensor);
- const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
- const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
-
- TFLITE_DCHECK(node->user_data != nullptr);
- const OpData& data = *(static_cast<const OpData*>(node->user_data));
-
- switch (input->type) { // Already know in/out types are same.
- case kTfLiteFloat32:
- EvalFloat(context, node, params, data, input, filter, bias, nullptr,
- nullptr, output);
- break;
- case kTfLiteInt8:
- EvalQuantizedPerChannel(context, node, params, data, input, filter, bias,
- output, nullptr);
- break;
- case kTfLiteUInt8:
- EvalQuantized(context, node, params, data, input, filter, bias, nullptr,
- nullptr, output);
- break;
- default:
- TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
- TfLiteTypeGetName(input->type), input->type);
- return kTfLiteError;
- }
- return kTfLiteOk;
-}
-
-} // namespace conv
-
-TfLiteRegistration Register_CONV_2D() {
- return {/*init=*/conv::Init,
- /*free=*/nullptr,
- /*prepare=*/conv::Prepare,
- /*invoke=*/conv::Eval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
-}
-
-} // namespace micro
-} // namespace ops
-} // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/depthwise_conv.cc b/tensorflow/lite/micro/kernels/xtensa_hifi/depthwise_conv.cc
deleted file mode 100755
index dbebfc9..0000000
--- a/tensorflow/lite/micro/kernels/xtensa_hifi/depthwise_conv.cc
+++ /dev/null
@@ -1,521 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h"
-
-#include "tensorflow/lite/c/builtin_op_data.h"
-#include "tensorflow/lite/c/common.h"
-#include "tensorflow/lite/kernels/internal/common.h"
-#include "tensorflow/lite/kernels/internal/quantization_util.h"
-#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h"
-#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h"
-#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
-#include "tensorflow/lite/kernels/kernel_util.h"
-#include "tensorflow/lite/kernels/padding.h"
-#include "tensorflow/lite/micro/kernels/xtensa_hifi/xtensa_tf_micro_common.h"
-
-namespace tflite {
-namespace ops {
-namespace micro {
-namespace depthwise_conv {
-namespace {
-
-constexpr int kInputTensor = 0;
-constexpr int kFilterTensor = 1;
-constexpr int kBiasTensor = 2;
-constexpr int kOutputTensor = 0;
-
-// Depthwise conv is quantized along dimension 3:
-// https://www.tensorflow.org/lite/performance/quantization_spec
-constexpr int kDepthwiseConvQuantizedDimension = 3;
-
-struct OpData {
- TfLitePaddingValues padding;
- // The scaling factor from input to output (aka the 'real multiplier') can
- // be represented as a fixed point multiplier plus a left shift.
- int32_t output_multiplier;
- int output_shift;
-
- // Per channel output multiplier and shift.
- int32_t* per_channel_output_multiplier;
- int32_t* per_channel_output_shift;
- // The range of the fused activation layer. For example for kNone and
- // uint8_t these would be 0 and 255.
- int32_t output_activation_min;
- int32_t output_activation_max;
-};
-
-TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
- TfLiteDepthwiseConvParams* params, int width,
- int height, int filter_width, int filter_height,
- const TfLiteType data_type, OpData* data) {
- bool has_bias = node->inputs->size == 3;
- // Check number of inputs/outputs
- TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
-
- int unused_output_height, unused_output_width;
- data->padding = ComputePaddingHeightWidth(
- params->stride_height, params->stride_width, 1, 1, height, width,
- filter_height, filter_width, params->padding, &unused_output_height,
- &unused_output_width);
-
- // Note that quantized inference requires that all tensors have their
- // parameters set. This is usually done during quantized training.
- if (data_type != kTfLiteFloat32) {
- const TfLiteTensor* input = GetInput(context, node, kInputTensor);
- const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
- const TfLiteTensor* bias =
- GetOptionalInputTensor(context, node, kBiasTensor);
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- int num_channels = filter->dims->data[kDepthwiseConvQuantizedDimension];
-
- return tflite::PopulateConvolutionQuantizationParams(
- context, input, filter, bias, output, params->activation,
- &data->output_multiplier, &data->output_shift,
- &data->output_activation_min, &data->output_activation_max,
- data->per_channel_output_multiplier,
- reinterpret_cast<int*>(data->per_channel_output_shift), num_channels);
- }
- return kTfLiteOk;
-}
-
-} // namespace
-
-void* Init(TfLiteContext* context, const char* buffer, size_t length) {
- TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
- void* data = nullptr;
- if (context->AllocatePersistentBuffer(context, sizeof(OpData), &data) ==
- kTfLiteError) {
- return nullptr;
- }
- return data;
-}
-
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- TFLITE_DCHECK(node->user_data != nullptr);
- TFLITE_DCHECK(node->builtin_data != nullptr);
-
- auto* params =
- reinterpret_cast<TfLiteDepthwiseConvParams*>(node->builtin_data);
- OpData* data = static_cast<OpData*>(node->user_data);
-
- const TfLiteTensor* input = GetInput(context, node, kInputTensor);
- const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
-
- const TfLiteType data_type = input->type;
- int width = SizeOfDimension(input, 2);
- int height = SizeOfDimension(input, 1);
- int filter_width = SizeOfDimension(filter, 2);
- int filter_height = SizeOfDimension(filter, 1);
-
- // Per channel quantization is only needed for int8_t inference. For other
- // quantized types, only a single scale and zero point is needed.
- const int num_channels = filter->dims->data[kDepthwiseConvQuantizedDimension];
- // Dynamically allocate per-channel quantization parameters.
- TF_LITE_ENSURE_STATUS(context->AllocatePersistentBuffer(
- context, num_channels * sizeof(int32_t),
- reinterpret_cast<void**>(&data->per_channel_output_multiplier)));
- TF_LITE_ENSURE_STATUS(context->AllocatePersistentBuffer(
- context, num_channels * sizeof(int32_t),
- reinterpret_cast<void**>(&data->per_channel_output_shift)));
-
- // All per-channel quantized tensors need valid zero point and scale arrays.
- if (input->type == kTfLiteInt8) {
- TF_LITE_ENSURE_EQ(context, filter->quantization.type,
- kTfLiteAffineQuantization);
-
- const auto* affine_quantization =
- reinterpret_cast<TfLiteAffineQuantization*>(
- filter->quantization.params);
- TF_LITE_ENSURE(context, affine_quantization);
- TF_LITE_ENSURE(context, affine_quantization->scale);
- TF_LITE_ENSURE(context, affine_quantization->zero_point);
- TF_LITE_ENSURE(
- context, affine_quantization->scale->size == 1 ||
- affine_quantization->scale->size ==
- filter->dims->data[kDepthwiseConvQuantizedDimension]);
- TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
- affine_quantization->zero_point->size);
- }
-
- return CalculateOpData(context, node, params, width, height, filter_width,
- filter_height, data_type, data);
-}
-
-TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
- TfLiteDepthwiseConvParams* params, const OpData* data,
- const TfLiteTensor* input, const TfLiteTensor* filter,
- const TfLiteTensor* bias, TfLiteTensor* output) {
- float output_activation_min, output_activation_max;
- CalculateActivationRange(params->activation, &output_activation_min,
- &output_activation_max);
-
-#if HIFI_VFPU
- if ((params->dilation_width_factor == 1) &&
- (params->dilation_height_factor == 1)) {
- const float *input_data, *filter_data, *bias_data;
- float* output_data;
- const RuntimeShape& input_shape = GetTensorShape(input);
- const RuntimeShape& filter_shape = GetTensorShape(filter);
- const RuntimeShape& output_shape = GetTensorShape(output);
- const RuntimeShape& bias_shape = GetTensorShape(bias);
-
- input_data = GetTensorData<float>(input);
- filter_data = GetTensorData<float>(filter);
- bias_data = GetTensorData<float>(bias);
- output_data = GetTensorData<float>(output);
-
- const int stride_width = params->stride_width;
- const int stride_height = params->stride_height;
- const int pad_width = data->padding.width;
- const int pad_height = data->padding.height;
- const int depth_multiplier = params->depth_multiplier;
- TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
- TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
- TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
-
- const int batches = MatchingDim(input_shape, 0, output_shape, 0);
- const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
- const int input_height = input_shape.Dims(1);
- const int input_width = input_shape.Dims(2);
- const int input_depth = input_shape.Dims(3);
- const int filter_height = filter_shape.Dims(1);
- const int filter_width = filter_shape.Dims(2);
- const int output_height = output_shape.Dims(1);
- const int output_width = output_shape.Dims(2);
- const int filter_depth = filter_shape.Dims(3);
- TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
- TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
-
- int32_t err, input_data_format = 0, output_data_format = 0;
- uint8_t* p_scratch;
- float* p_filter;
- int filter_depth_padded, filter_size_padded, required_scratch;
- int input_precision = PREC_F32;
- int h, c, i;
-
- ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM;
- p_scratch = xtensa_nnlib_scratch_buf;
-
- filter_depth_padded = (filter_depth + 1) & (~1);
- filter_size_padded = filter_height * filter_width * filter_depth_padded;
-
- required_scratch = xa_nn_conv2d_depthwise_getsize(
- input_height, input_width, input_depth, filter_height, filter_width,
- depth_multiplier, stride_width, stride_height, pad_width, pad_height,
- output_height, output_width, input_precision, input_data_format);
-
- if (required_scratch <= 0) {
- TF_LITE_KERNEL_LOG(
- context, "DepthwiseConvFloat: xa_nn_conv2d_depthwise_getsize failed");
- return kTfLiteError;
- }
-
- required_scratch += ALIGNED_SIZE(sizeof(float) * filter_size_padded, 8);
- if (required_scratch > (int)XTENSA_NNLIB_MAX_SCRATCH_SIZE) {
- TF_LITE_KERNEL_LOG(context,
- "DepthwiseConvFloat: insufficient scratch memory");
- return kTfLiteError;
- }
-
- p_filter = reinterpret_cast<float*>(p_scratch);
- p_scratch += ALIGNED_SIZE(sizeof(float) * filter_size_padded, 8);
-
- for (h = 0; h < filter_height * filter_width; h++) {
- for (c = 0; c < filter_depth; c++) {
- p_filter[h * filter_depth_padded + c] =
- filter_data[h * filter_depth + c];
- }
- for (c = filter_depth; c < filter_depth_padded; c++) {
- p_filter[h * filter_depth_padded + c] = 0;
- }
- }
-
- for (i = 0; i < batches; i++) {
- err = xa_nn_conv2d_depthwise_f32(
- &output_data[i * output_height * output_width * output_depth],
- p_filter, // filter_data,
- &input_data[i * input_height * input_width * input_depth], bias_data,
- input_height, input_width, input_depth, filter_height, filter_width,
- depth_multiplier, stride_width, stride_height, pad_width, pad_height,
- output_height, output_width, input_data_format, output_data_format,
- static_cast<void*>(p_scratch));
-
- CHECK_ERR_HIFI_NNLIB_KER(
- err, "DepthwiseConvFloat: xa_nn_conv2d_depthwise_f32 failed");
- }
-
- int out_length = batches * output_height * output_width * output_depth;
- err = xa_nn_vec_activation_min_max_f32_f32(
- output_data, output_data, output_activation_min, output_activation_max,
- out_length);
-
- CHECK_ERR_HIFI_NNLIB_KER(
- err, "DepthwiseConvFloat: xa_nn_vec_activation_min_max_f32_f32 failed");
- } else
-#endif /* HIFI_VFPU */
- {
- tflite::DepthwiseParams op_params;
- // Padding type is ignored, but still set.
- op_params.padding_type = PaddingType::kSame;
- op_params.padding_values.width = data->padding.width;
- op_params.padding_values.height = data->padding.height;
- op_params.stride_width = params->stride_width;
- op_params.stride_height = params->stride_height;
- op_params.dilation_width_factor = params->dilation_width_factor;
- op_params.dilation_height_factor = params->dilation_height_factor;
- op_params.depth_multiplier = params->depth_multiplier;
- op_params.float_activation_min = output_activation_min;
- op_params.float_activation_max = output_activation_max;
-
- tflite::reference_ops::DepthwiseConv(
- op_params, GetTensorShape(input), GetTensorData<float>(input),
- GetTensorShape(filter), GetTensorData<float>(filter),
- GetTensorShape(bias), GetTensorData<float>(bias),
- GetTensorShape(output), GetTensorData<float>(output));
- }
- return kTfLiteOk;
-}
-
-void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
- TfLiteDepthwiseConvParams* params,
- const OpData* data, const TfLiteTensor* input,
- const TfLiteTensor* filter,
- const TfLiteTensor* bias, TfLiteTensor* output) {
- DepthwiseParams op_params;
- op_params.padding_type = PaddingType::kSame;
- op_params.padding_values.width = data->padding.width;
- op_params.padding_values.height = data->padding.height;
- op_params.stride_width = params->stride_width;
- op_params.stride_height = params->stride_height;
- op_params.dilation_width_factor = params->dilation_width_factor;
- op_params.dilation_height_factor = params->dilation_height_factor;
- op_params.depth_multiplier = params->depth_multiplier;
- op_params.input_offset = -input->params.zero_point;
- op_params.weights_offset = 0;
- op_params.output_offset = output->params.zero_point;
- // TODO(b/130439627): Use calculated value for clamping.
- op_params.quantized_activation_min = std::numeric_limits<int8_t>::min();
- op_params.quantized_activation_max = std::numeric_limits<int8_t>::max();
-
- reference_integer_ops::DepthwiseConvPerChannel(
- op_params, data->per_channel_output_multiplier,
- data->per_channel_output_shift, GetTensorShape(input),
- GetTensorData<int8_t>(input), GetTensorShape(filter),
- GetTensorData<int8_t>(filter), GetTensorShape(bias),
- GetTensorData<int32_t>(bias), GetTensorShape(output),
- GetTensorData<int8_t>(output));
-}
-
-TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
- TfLiteDepthwiseConvParams* params,
- const OpData* data, const TfLiteTensor* input,
- const TfLiteTensor* filter, const TfLiteTensor* bias,
- TfLiteTensor* output) {
- const int32_t input_offset = -input->params.zero_point;
- const int32_t filter_offset = -filter->params.zero_point;
- const int32_t output_offset = output->params.zero_point;
-
- if ((params->dilation_width_factor == 1) &&
- (params->dilation_height_factor == 1)) {
- const uint8_t *input_data, *filter_data;
- const int32_t* bias_data;
- uint8_t* output_data;
- const RuntimeShape& input_shape = GetTensorShape(input);
- const RuntimeShape& filter_shape = GetTensorShape(filter);
- const RuntimeShape& output_shape = GetTensorShape(output);
- const RuntimeShape& bias_shape = GetTensorShape(bias);
-
- input_data = GetTensorData<uint8_t>(input);
- filter_data = GetTensorData<uint8_t>(filter);
- bias_data = GetTensorData<int32_t>(bias);
- output_data = GetTensorData<uint8_t>(output);
-
- const int stride_width = params->stride_width;
- const int stride_height = params->stride_height;
- const int pad_width = data->padding.width;
- const int pad_height = data->padding.height;
- const int depth_multiplier = params->depth_multiplier;
- const int32_t output_activation_min = data->output_activation_min;
- const int32_t output_activation_max = data->output_activation_max;
- const int32_t output_multiplier = data->output_multiplier;
- // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
- const int output_shift = -data->output_shift;
- TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
- TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
- TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
-
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- const int batches = MatchingDim(input_shape, 0, output_shape, 0);
- const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
- const int input_height = input_shape.Dims(1);
- const int input_width = input_shape.Dims(2);
- const int input_depth = input_shape.Dims(3);
- const int filter_height = filter_shape.Dims(1);
- const int filter_width = filter_shape.Dims(2);
- const int output_height = output_shape.Dims(1);
- const int output_width = output_shape.Dims(2);
- const int filter_depth = filter_shape.Dims(3);
- TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
- TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
-
- int32_t err, i, input_data_format = 0, output_data_format = 0;
- uint8_t* p_scratch;
- uint8_t* p_filter;
- int filter_depth_padded, filter_size_padded, required_scratch;
- int input_precision = PREC_ASYM8;
- int h;
-
- ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM;
- p_scratch = xtensa_nnlib_scratch_buf;
-
- required_scratch = xa_nn_conv2d_depthwise_getsize(
- input_height, input_width, input_depth, filter_height, filter_width,
- depth_multiplier, stride_width, stride_height, pad_width, pad_height,
- output_height, output_width, input_precision, input_data_format);
-
- if (required_scratch <= 0) {
- TF_LITE_KERNEL_LOG(
- context, "DepthwiseConvAsym8: xa_nn_conv2d_depthwise_getsize failed");
- return kTfLiteError;
- }
-
- filter_depth_padded = (filter_depth + 3) & (~3);
- filter_size_padded = filter_height * filter_width * filter_depth_padded;
- required_scratch += ALIGNED_SIZE(sizeof(uint8_t) * filter_size_padded, 8);
-
- if (required_scratch > (int)XTENSA_NNLIB_MAX_SCRATCH_SIZE) {
- TF_LITE_KERNEL_LOG(context,
- "DepthwiseConvAsym8: insufficient scratch memory");
- return kTfLiteError;
- }
-
- p_filter = p_scratch;
- p_scratch += ALIGNED_SIZE(sizeof(uint8_t) * filter_size_padded, 8);
- int pad_value = filter_depth_padded - filter_depth;
-
- for (h = 0; h < filter_height * filter_width; h++) {
- memcpy(&p_filter[h * filter_depth_padded], &filter_data[h * filter_depth],
- filter_depth);
- memset(&p_filter[h * filter_depth_padded + filter_depth], -filter_offset,
- pad_value);
- }
-
- for (i = 0; i < batches; i++) {
- err = xa_nn_conv2d_depthwise_asym8xasym8(
- &output_data[i * output_height * output_width * output_depth],
- p_filter, // filter_data,
- &input_data[i * input_height * input_width * input_depth], bias_data,
- input_height, input_width, input_depth, filter_height, filter_width,
- depth_multiplier, stride_width, stride_height, pad_width, pad_height,
- output_height, output_width, input_offset, filter_offset,
- output_multiplier, output_shift, output_offset, input_data_format,
- output_data_format, static_cast<void*>(p_scratch));
-
- CHECK_ERR_HIFI_NNLIB_KER(
- err, "DepthwiseConvAsym8: xa_nn_conv2d_depthwise_asym8xasym8 failed");
- }
-
- int out_length = batches * output_height * output_width * output_depth;
- err = xa_nn_vec_activation_min_max_asym8_asym8(
- output_data, output_data, output_activation_min, output_activation_max,
- out_length);
-
- CHECK_ERR_HIFI_NNLIB_KER(
- err,
- "DepthwiseConvAsym8: xa_nn_vec_activation_min_max_asym8_asym8 "
- "failed");
-
- } else {
- tflite::DepthwiseParams op_params;
- // Padding type is ignored, but still set.
- op_params.padding_type = PaddingType::kSame;
- op_params.padding_values.width = data->padding.width;
- op_params.padding_values.height = data->padding.height;
- op_params.stride_width = params->stride_width;
- op_params.stride_height = params->stride_height;
- op_params.dilation_width_factor = params->dilation_width_factor;
- op_params.dilation_height_factor = params->dilation_height_factor;
- op_params.depth_multiplier = params->depth_multiplier;
- op_params.quantized_activation_min = data->output_activation_min;
- op_params.quantized_activation_max = data->output_activation_max;
- op_params.input_offset = input_offset;
- op_params.weights_offset = filter_offset;
- op_params.output_offset = output_offset;
- op_params.output_multiplier = data->output_multiplier;
- // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
- op_params.output_shift = -data->output_shift;
-
- tflite::reference_ops::DepthwiseConv(
- op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
- GetTensorShape(filter), GetTensorData<uint8_t>(filter),
- GetTensorShape(bias), GetTensorData<int32_t>(bias),
- GetTensorShape(output), GetTensorData<uint8_t>(output));
- }
- return kTfLiteOk;
-}
-
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- TFLITE_DCHECK(node->user_data != nullptr);
- TFLITE_DCHECK(node->builtin_data != nullptr);
-
- auto* params =
- reinterpret_cast<TfLiteDepthwiseConvParams*>(node->builtin_data);
- const OpData& data = *(static_cast<const OpData*>(node->user_data));
-
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- const TfLiteTensor* input = GetInput(context, node, kInputTensor);
- const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
- const TfLiteTensor* bias =
- (NumInputs(node) == 3) ? GetInput(context, node, kBiasTensor) : nullptr;
-
- // TODO(aselle): Consider whether float conv and quantized conv should be
- // separate ops to avoid dispatch overhead here.
- switch (input->type) { // Already know in/out types are same.
- case kTfLiteFloat32:
- EvalFloat(context, node, params, &data, input, filter, bias, output);
- break;
- case kTfLiteInt8:
- EvalQuantizedPerChannel(context, node, params, &data, input, filter, bias,
- output);
- break;
- case kTfLiteUInt8:
- EvalQuantized(context, node, params, &data, input, filter, bias, output);
- break;
- default:
- TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
- TfLiteTypeGetName(input->type), input->type);
- return kTfLiteError;
- }
- return kTfLiteOk;
-}
-
-} // namespace depthwise_conv
-
-TfLiteRegistration Register_DEPTHWISE_CONV_2D() {
- return {/*init=*/depthwise_conv::Init,
- /*free=*/nullptr,
- /*prepare=*/depthwise_conv::Prepare,
- /*invoke=*/depthwise_conv::Eval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
-}
-
-} // namespace micro
-} // namespace ops
-} // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/floor.cc b/tensorflow/lite/micro/kernels/xtensa_hifi/floor.cc
deleted file mode 100644
index 1f2b71e..0000000
--- a/tensorflow/lite/micro/kernels/xtensa_hifi/floor.cc
+++ /dev/null
@@ -1,70 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/lite/kernels/internal/reference/floor.h"
-
-#include "tensorflow/lite/c/common.h"
-#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
-#include "tensorflow/lite/kernels/kernel_util.h"
-#include "tensorflow/lite/micro/kernels/xtensa_hifi/xtensa_tf_micro_common.h"
-
-namespace tflite {
-namespace ops {
-namespace micro {
-namespace floor {
-
-constexpr int kInputTensor = 0;
-constexpr int kOutputTensor = 0;
-
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- const TfLiteTensor* input = GetInput(context, node, kInputTensor);
- TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-#if HIFI_VFPU
- int err;
- const float* inp_data_ptr;
- float* out_data_ptr;
- const RuntimeShape& input_shape = GetTensorShape(input);
- const RuntimeShape& output_shape = GetTensorShape(output);
- const int flat_size = MatchingFlatSize(input_shape, output_shape);
-
- inp_data_ptr = GetTensorData<float>(input);
- out_data_ptr = GetTensorData<float>(output);
-
- err = xa_nn_elm_floor_f32_f32(out_data_ptr, inp_data_ptr, flat_size);
-
- CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_elm_floor_f32_f32 failed");
-#else
- reference_ops::Floor(GetTensorShape(input), GetTensorData<float>(input),
- GetTensorShape(output), GetTensorData<float>(output));
-#endif /* HIFI_VFPU */
- return kTfLiteOk;
-}
-} // namespace floor
-
-TfLiteRegistration Register_FLOOR() {
- return {/*init=*/nullptr,
- /*free=*/nullptr,
- /*prepare=*/nullptr,
- /*invoke=*/floor::Eval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
-}
-
-} // namespace micro
-} // namespace ops
-} // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/fully_connected.cc b/tensorflow/lite/micro/kernels/xtensa_hifi/fully_connected.cc
deleted file mode 100644
index 3347af9..0000000
--- a/tensorflow/lite/micro/kernels/xtensa_hifi/fully_connected.cc
+++ /dev/null
@@ -1,283 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
-
-#include "tensorflow/lite/c/builtin_op_data.h"
-#include "tensorflow/lite/c/common.h"
-#include "tensorflow/lite/kernels/internal/common.h"
-#include "tensorflow/lite/kernels/internal/quantization_util.h"
-#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
-#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
-#include "tensorflow/lite/kernels/kernel_util.h"
-#include "tensorflow/lite/micro/kernels/xtensa_hifi/xtensa_tf_micro_common.h"
-
-namespace tflite {
-namespace ops {
-namespace micro {
-namespace fully_connected {
-namespace {
-
-struct OpData {
- // The scaling factor from input to output (aka the 'real multiplier') can
- // be represented as a fixed point multiplier plus a left shift.
- int32_t output_multiplier;
- int output_shift;
- // The range of the fused activation layer. For example for kNone and
- // uint8_t these would be 0 and 255.
- int32_t output_activation_min;
- int32_t output_activation_max;
- // The index of the temporary tensor where the quantized inputs are cached.
- int input_quantized_index;
-};
-
-constexpr int kInputTensor = 0;
-constexpr int kWeightsTensor = 1;
-constexpr int kBiasTensor = 2;
-constexpr int kOutputTensor = 0;
-
-TfLiteStatus CalculateOpData(TfLiteContext* context,
- TfLiteFusedActivation activation,
- TfLiteType data_type, const TfLiteTensor* input,
- const TfLiteTensor* filter,
- const TfLiteTensor* bias, TfLiteTensor* output,
- OpData* data) {
- TfLiteStatus status = kTfLiteOk;
- if (data_type != kTfLiteFloat32) {
- double real_multiplier = 0.0;
- TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
- context, input, filter, bias, output, &real_multiplier));
- int exponent;
- QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent);
- data->output_shift = -exponent;
- TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
- context, activation, output, &data->output_activation_min,
- &data->output_activation_max));
- }
- return status;
-}
-
-} // namespace
-
-void* Init(TfLiteContext* context, const char* buffer, size_t length) {
- TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
- void* data = nullptr;
- if (context->AllocatePersistentBuffer(context, sizeof(OpData), &data) ==
- kTfLiteError) {
- return nullptr;
- }
- return data;
-}
-
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- TFLITE_DCHECK(node->user_data != nullptr);
- TFLITE_DCHECK(node->builtin_data != nullptr);
-
- OpData* data = static_cast<OpData*>(node->user_data);
- const auto params =
- static_cast<const TfLiteFullyConnectedParams*>(node->builtin_data);
-
- const TfLiteTensor* input = GetInput(context, node, kInputTensor);
- const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
- const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-
- TF_LITE_ENSURE_EQ(context, input->type, output->type);
- TF_LITE_ENSURE_MSG(context, input->type == filter->type,
- "Hybrid models are not supported on TFLite Micro.");
-
- return CalculateOpData(context, params->activation, input->type, input,
- filter, bias, output, data);
-}
-
-TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
- const OpData& data, const TfLiteTensor* input,
- const TfLiteTensor* filter,
- const TfLiteTensor* bias, TfLiteTensor* output) {
- tflite::FullyConnectedParams op_params;
- op_params.input_offset = -input->params.zero_point;
- op_params.weights_offset = -filter->params.zero_point;
- op_params.output_offset = output->params.zero_point;
- op_params.output_multiplier = data.output_multiplier;
- // TODO(b/138810107): Figure out whether output shift should be inverted
- op_params.output_shift = -data.output_shift;
- op_params.quantized_activation_min = data.output_activation_min;
- op_params.quantized_activation_max = data.output_activation_max;
-
- reference_integer_ops::FullyConnected(
- op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
- GetTensorShape(filter), GetTensorData<int8_t>(filter),
- GetTensorShape(bias), GetTensorData<int32_t>(bias),
- GetTensorShape(output), GetTensorData<int8_t>(output));
- return kTfLiteOk;
-}
-
-TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
- const OpData& data, const TfLiteTensor* input,
- const TfLiteTensor* filter, const TfLiteTensor* bias,
- TfLiteTensor* output) {
- const int32_t input_offset = -input->params.zero_point;
- const int32_t filter_offset = -filter->params.zero_point;
- const int32_t output_offset = output->params.zero_point;
-
- tflite::FullyConnectedParams op_params;
- op_params.input_offset = input_offset;
- op_params.weights_offset = filter_offset;
- op_params.output_offset = output_offset;
- op_params.output_multiplier = data.output_multiplier;
- // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
- op_params.output_shift = -data.output_shift;
- op_params.quantized_activation_min = data.output_activation_min;
- op_params.quantized_activation_max = data.output_activation_max;
-
-#define TF_LITE_FULLY_CONNECTED(output_data_type) \
- reference_ops::FullyConnected( \
- op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), \
- GetTensorShape(filter), GetTensorData<uint8_t>(filter), \
- GetTensorShape(bias), GetTensorData<int32_t>(bias), \
- GetTensorShape(output), GetTensorData<output_data_type>(output))
- switch (output->type) {
- case kTfLiteUInt8: {
- int ret, b, weight_depth, out_depth, batches;
- uint8_t* p_out = GetTensorData<uint8_t>(output);
- weight_depth = GetTensorShape(filter).Dims(
- GetTensorShape(filter).DimensionsCount() - 1);
- out_depth = GetTensorShape(output).Dims(
- GetTensorShape(output).DimensionsCount() - 1);
- batches = FlatSizeSkipDim(GetTensorShape(output),
- GetTensorShape(output).DimensionsCount() - 1);
- for (b = 0; b < batches; b++) {
- ret = xa_nn_fully_connected_asym8xasym8_asym8(
- (GetTensorData<uint8_t>(output) + b * out_depth),
- GetTensorData<uint8_t>(filter),
- (GetTensorData<uint8_t>(input) + b * weight_depth),
- GetTensorData<int32_t>(bias), weight_depth, out_depth,
- op_params.input_offset, op_params.weights_offset,
- op_params.output_multiplier, op_params.output_shift,
- op_params.output_offset);
- CHECK_ERR_HIFI_NNLIB_KER(
- ret, "xa_nn_fully_connected_asym8xasym8_asym8 failed");
- }
- ret = xa_nn_vec_activation_min_max_asym8_asym8(
- p_out, p_out, data.output_activation_min, data.output_activation_max,
- batches * out_depth);
-
- CHECK_ERR_HIFI_NNLIB_KER(
- ret, "xa_nn_vec_activation_min_max_asym8_asym8 failed");
- break;
- }
- case kTfLiteInt16:
- TF_LITE_FULLY_CONNECTED(int16_t);
- break;
- default:
- TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
- TfLiteTypeGetName(output->type), output->type);
- return kTfLiteError;
- }
-
- return kTfLiteOk;
-}
-
-TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
- TfLiteFusedActivation activation,
- const TfLiteTensor* input, const TfLiteTensor* filter,
- const TfLiteTensor* bias, TfLiteTensor* output) {
- float output_activation_min, output_activation_max;
- CalculateActivationRange(activation, &output_activation_min,
- &output_activation_max);
- tflite::FullyConnectedParams op_params;
- op_params.float_activation_min = output_activation_min;
- op_params.float_activation_max = output_activation_max;
-#if HIFI_VFPU
- int ret, b, weight_depth, out_depth, batches;
- weight_depth =
- GetTensorShape(filter).Dims(GetTensorShape(filter).DimensionsCount() - 1);
- out_depth =
- GetTensorShape(output).Dims(GetTensorShape(output).DimensionsCount() - 1);
- batches = FlatSizeSkipDim(GetTensorShape(output),
- GetTensorShape(output).DimensionsCount() - 1);
-
- for (b = 0; b < batches; b++) {
- ret = xa_nn_fully_connected_f32(
- (GetTensorData<float>(output) + b * out_depth),
- GetTensorData<float>(filter),
- (GetTensorData<float>(input) + b * weight_depth),
- GetTensorData<float>(bias), weight_depth, out_depth);
- CHECK_ERR_HIFI_NNLIB_KER(ret, "xa_nn_fully_connected_f32 failed.");
- }
- float* p_out = GetTensorData<float>(output);
- ret = xa_nn_vec_activation_min_max_f32_f32(
- p_out, p_out, output_activation_min, output_activation_max,
- batches * out_depth);
- CHECK_ERR_HIFI_NNLIB_KER(ret, "xa_nn_vec_activation_min_max_f32_f32 failed");
-#else
- tflite::reference_ops::FullyConnected(
- op_params, GetTensorShape(input), GetTensorData<float>(input),
- GetTensorShape(filter), GetTensorData<float>(filter),
- GetTensorShape(bias), GetTensorData<float>(bias), GetTensorShape(output),
- GetTensorData<float>(output));
-#endif /* HIFI_VFPU */
- return kTfLiteOk;
-}
-
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- TFLITE_DCHECK(node->builtin_data != nullptr);
- const auto* params =
- static_cast<const TfLiteFullyConnectedParams*>(node->builtin_data);
-
- const TfLiteTensor* input = GetInput(context, node, kInputTensor);
- const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
- const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-
- TFLITE_DCHECK(node->user_data != nullptr);
- const OpData& data = *(static_cast<const OpData*>(node->user_data));
-
- // Checks in Prepare ensure input, output and filter types are all the same.
- switch (input->type) {
- case kTfLiteFloat32:
- return EvalFloat(context, node, params->activation, input, filter, bias,
- output);
- case kTfLiteInt8:
- return EvalQuantizedInt8(context, node, data, input, filter, bias,
- output);
-
- case kTfLiteUInt8:
- return EvalQuantized(context, node, data, input, filter, bias, output);
-
- default:
- TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
- TfLiteTypeGetName(input->type), input->type);
- return kTfLiteError;
- }
- return kTfLiteOk;
-}
-
-} // namespace fully_connected
-
-TfLiteRegistration Register_FULLY_CONNECTED() {
- return {/*init=*/fully_connected::Init,
- /*free=*/nullptr,
- /*prepare=*/fully_connected::Prepare,
- /*invoke=*/fully_connected::Eval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
-}
-
-} // namespace micro
-} // namespace ops
-} // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/logistic.cc b/tensorflow/lite/micro/kernels/xtensa_hifi/logistic.cc
deleted file mode 100644
index 3158a18..0000000
--- a/tensorflow/lite/micro/kernels/xtensa_hifi/logistic.cc
+++ /dev/null
@@ -1,145 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h"
-
-#include "tensorflow/lite/c/builtin_op_data.h"
-#include "tensorflow/lite/c/common.h"
-#include "tensorflow/lite/kernels/internal/common.h"
-#include "tensorflow/lite/kernels/internal/quantization_util.h"
-#include "tensorflow/lite/kernels/internal/reference/logistic.h"
-#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
-#include "tensorflow/lite/kernels/kernel_util.h"
-#include "tensorflow/lite/kernels/op_macros.h"
-#include "tensorflow/lite/micro/kernels/xtensa_hifi/xtensa_tf_micro_common.h"
-
-namespace tflite {
-namespace ops {
-namespace micro {
-namespace activations {
-namespace {
-constexpr int kInputTensor = 0;
-constexpr int kOutputTensor = 0;
-
-struct OpData {
- int32_t input_zero_point;
- int32_t input_range_radius;
- int32_t input_multiplier;
- int input_left_shift;
-};
-
-TfLiteStatus CalculateArithmeticOpData(TfLiteContext* context, TfLiteNode* node,
- OpData* data) {
- const TfLiteTensor* input = GetInput(context, node, kInputTensor);
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-
- TF_LITE_ENSURE_EQ(context, input->type, output->type);
- if (input->type == kTfLiteInt8) {
- TF_LITE_ENSURE_EQ(context, output->params.zero_point,
- std::numeric_limits<int8_t>::min());
-
- static constexpr int kInputIntegerBits = 4;
- const double input_real_multiplier =
- static_cast<double>(input->params.scale) *
- static_cast<double>(1 << (31 - kInputIntegerBits));
-
- const double q = std::frexp(input_real_multiplier, &data->input_left_shift);
- data->input_multiplier = static_cast<int32_t>(TfLiteRound(q * (1ll << 31)));
-
- data->input_range_radius =
- CalculateInputRadius(kInputIntegerBits, data->input_left_shift, 31);
- }
- return kTfLiteOk;
-}
-} // namespace
-
-TfLiteStatus LogisticEval(TfLiteContext* context, TfLiteNode* node) {
- const TfLiteTensor* input = GetInput(context, node, kInputTensor);
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- OpData data;
- CalculateArithmeticOpData(context, node, &data);
-
- if (input->type == kTfLiteFloat32) {
- switch (output->type) {
- case kTfLiteFloat32: {
-#if HIFI_VFPU
- int err;
- const float* inp_data_ptr;
- float* out_data_ptr;
- const RuntimeShape& input_shape = GetTensorShape(input);
- const RuntimeShape& output_shape = GetTensorShape(output);
- const int flat_size = MatchingFlatSize(input_shape, output_shape);
-
- inp_data_ptr = GetTensorData<float>(input);
- out_data_ptr = GetTensorData<float>(output);
-
- err = xa_nn_vec_sigmoid_f32_f32(out_data_ptr, inp_data_ptr, flat_size);
-
- CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_vec_sigmoid_f32_f32 failed");
-#else
- reference_ops::Logistic(
- GetTensorShape(input), GetTensorData<float>(input),
- GetTensorShape(output), GetTensorData<float>(output));
-#endif /* HIFI_VFPU */
- return kTfLiteOk;
- }
- default:
- TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
- TfLiteTypeGetName(input->type),
- TfLiteTypeGetName(output->type));
- return kTfLiteError;
- }
- } else if (input->type == kTfLiteInt8) {
- switch (output->type) {
- case kTfLiteInt8: {
- reference_integer_ops::Logistic(
- input->params.zero_point, data.input_range_radius,
- data.input_multiplier, data.input_left_shift,
- NumElements(input->dims), GetTensorData<int8_t>(input),
- GetTensorData<int8_t>(output));
- return kTfLiteOk;
- }
- default:
- TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
- TfLiteTypeGetName(input->type),
- TfLiteTypeGetName(output->type));
- return kTfLiteError;
- }
- } else {
- // TODO(b/141211002): Also support other data types once we have supported
- // temporary tensors in TFLM.
- TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
- TfLiteTypeGetName(input->type),
- TfLiteTypeGetName(output->type));
- return kTfLiteError;
- }
- return kTfLiteOk;
-}
-
-} // namespace activations
-
-TfLiteRegistration Register_LOGISTIC() {
- return {/*init=*/nullptr,
- /*free=*/nullptr,
- /*prepare=*/nullptr,
- /*invoke=*/activations::LogisticEval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
-}
-} // namespace micro
-} // namespace ops
-} // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/mul.cc b/tensorflow/lite/micro/kernels/xtensa_hifi/mul.cc
deleted file mode 100644
index b4cf2ce..0000000
--- a/tensorflow/lite/micro/kernels/xtensa_hifi/mul.cc
+++ /dev/null
@@ -1,229 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/lite/kernels/internal/reference/mul.h"
-
-#include "tensorflow/lite/c/common.h"
-#include "tensorflow/lite/kernels/internal/quantization_util.h"
-#include "tensorflow/lite/kernels/internal/reference/integer_ops/mul.h"
-#include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
-#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
-#include "tensorflow/lite/kernels/kernel_util.h"
-#include "tensorflow/lite/micro/kernels/xtensa_hifi/xtensa_tf_micro_common.h"
-#include "tensorflow/lite/micro/memory_helpers.h"
-
-namespace tflite {
-namespace ops {
-namespace micro {
-namespace mul {
-
-constexpr int kInput1Tensor = 0;
-constexpr int kInput2Tensor = 1;
-constexpr int kOutputTensor = 0;
-
-struct OpData {
- int32_t output_activation_min;
- int32_t output_activation_max;
-
- int32_t output_multiplier;
- int output_shift;
-};
-
-TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
- TfLiteMulParams* params, OpData* data) {
- const TfLiteTensor* input1 = GetInput(context, node, kInput1Tensor);
- const TfLiteTensor* input2 = GetInput(context, node, kInput2Tensor);
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-
- TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
- TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-
- TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
-
- if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
- TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
- context, params->activation, output, &data->output_activation_min,
- &data->output_activation_max));
-
- double real_multiplier = static_cast<double>(input1->params.scale) *
- static_cast<double>(input2->params.scale) /
- static_cast<double>(output->params.scale);
- QuantizeMultiplier(real_multiplier, &data->output_multiplier,
- &data->output_shift);
- }
-
- return kTfLiteOk;
-}
-
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- const TfLiteTensor* input1 = GetInput(context, node, kInput1Tensor);
- const TfLiteTensor* input2 = GetInput(context, node, kInput2Tensor);
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-
- if (output->dims->size == 0) {
- return AllocateOutputDimensionsFromInput(context, input1, input2, output);
- }
-
- return kTfLiteOk;
-}
-
-TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
- TfLiteMulParams* params, OpData* data,
- const TfLiteTensor* input1,
- const TfLiteTensor* input2, TfLiteTensor* output) {
- if (output->type == kTfLiteInt8 || output->type == kTfLiteUInt8) {
- tflite::ArithmeticParams op_params;
- SetActivationParams(data->output_activation_min,
- data->output_activation_max, &op_params);
- op_params.input1_offset = -input1->params.zero_point;
- op_params.input2_offset = -input2->params.zero_point;
- op_params.output_offset = output->params.zero_point;
- op_params.output_multiplier = data->output_multiplier;
- op_params.output_shift = data->output_shift;
- bool need_broadcast = reference_ops::ProcessBroadcastShapes(
- GetTensorShape(input1), GetTensorShape(input2), &op_params);
-
-#define TF_LITE_MUL(type, opname, dtype) \
- type::opname(op_params, GetTensorShape(input1), \
- GetTensorData<dtype>(input1), GetTensorShape(input2), \
- GetTensorData<dtype>(input2), GetTensorShape(output), \
- GetTensorData<dtype>(output));
-
- if (output->type == kTfLiteInt8) {
- if (need_broadcast) {
- TF_LITE_MUL(reference_integer_ops, BroadcastMul4DSlow, int8_t);
- } else {
- TF_LITE_MUL(reference_integer_ops, Mul, int8_t);
- }
- } else if (output->type == kTfLiteUInt8) {
- if (need_broadcast) {
- TF_LITE_MUL(reference_ops, BroadcastMul4DSlow, uint8_t);
- } else {
- int err;
- const RuntimeShape& input1_shape = GetTensorShape(input1);
- const RuntimeShape& input2_shape = GetTensorShape(input2);
- const RuntimeShape& output_shape = GetTensorShape(output);
- const int flat_size =
- MatchingElementsSize(input1_shape, input2_shape, output_shape);
-
- err = xa_nn_elm_mul_asym8xasym8_asym8(
- GetTensorData<uint8_t>(output), op_params.output_offset,
- op_params.output_shift, op_params.output_multiplier,
- op_params.quantized_activation_min,
- op_params.quantized_activation_max, GetTensorData<uint8_t>(input1),
- op_params.input1_offset, GetTensorData<uint8_t>(input2),
- op_params.input2_offset, flat_size);
-
- CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_elm_mul_asym8xasym8_asym8 failed");
- }
- }
-#undef TF_LITE_MUL
- }
- return kTfLiteOk;
-}
-
-TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
- TfLiteMulParams* params, OpData* data,
- const TfLiteTensor* input1, const TfLiteTensor* input2,
- TfLiteTensor* output) {
- float output_activation_min, output_activation_max;
- CalculateActivationRange(params->activation, &output_activation_min,
- &output_activation_max);
- tflite::ArithmeticParams op_params;
- SetActivationParams(output_activation_min, output_activation_max, &op_params);
-
- bool need_broadcast = reference_ops::ProcessBroadcastShapes(
- GetTensorShape(input1), GetTensorShape(input2), &op_params);
-#define TF_LITE_MUL(opname) \
- reference_ops::opname(op_params, GetTensorShape(input1), \
- GetTensorData<float>(input1), GetTensorShape(input2), \
- GetTensorData<float>(input2), GetTensorShape(output), \
- GetTensorData<float>(output));
-
- if (need_broadcast) {
- TF_LITE_MUL(BroadcastMul4DSlow);
- } else {
-#if HIFI_VFPU
- int err;
- const RuntimeShape& input1_shape = GetTensorShape(input1);
- const RuntimeShape& input2_shape = GetTensorShape(input2);
- const RuntimeShape& output_shape = GetTensorShape(output);
- const int flat_size =
- MatchingElementsSize(input1_shape, input2_shape, output_shape);
-
- err = xa_nn_elm_mul_f32xf32_f32(GetTensorData<float>(output),
- GetTensorData<float>(input1),
- GetTensorData<float>(input2), flat_size);
-
- CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_elm_mul_f32xf32_f32 failed");
-
- err = xa_nn_vec_activation_min_max_f32_f32(
- GetTensorData<float>(output), GetTensorData<float>(output),
- output_activation_min, output_activation_max, flat_size);
-
- CHECK_ERR_HIFI_NNLIB_KER(err,
- "xa_nn_vec_activation_min_max_f32_f32 failed");
-#else
- TF_LITE_MUL(Mul);
-#endif /* HIFI_VFPU */
- }
-#undef TF_LITE_MUL
- return kTfLiteOk;
-}
-
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- auto* params = reinterpret_cast<TfLiteMulParams*>(node->builtin_data);
- OpData data;
-
- const TfLiteTensor* input1 = GetInput(context, node, kInput1Tensor);
- const TfLiteTensor* input2 = GetInput(context, node, kInput2Tensor);
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-
- TF_LITE_ENSURE_STATUS(CalculateOpData(context, node, params, &data));
-
- switch (input1->type) {
- case kTfLiteUInt8:
- case kTfLiteInt8:
- TF_LITE_ENSURE_OK(context, EvalQuantized(context, node, params, &data,
- input1, input2, output));
- break;
- case kTfLiteFloat32:
- TF_LITE_ENSURE_OK(context, EvalFloat(context, node, params, &data, input1,
- input2, output));
- break;
- default:
- TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
- TfLiteTypeGetName(input1->type), input1->type);
- return kTfLiteError;
- }
-
- return kTfLiteOk;
-}
-} // namespace mul
-
-TfLiteRegistration Register_MUL() {
- return {/*init=*/nullptr,
- /*free=*/nullptr,
- /*prepare=*/nullptr,
- /*invoke=*/mul::Eval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
-}
-
-} // namespace micro
-} // namespace ops
-} // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/pooling.cc b/tensorflow/lite/micro/kernels/xtensa_hifi/pooling.cc
deleted file mode 100644
index 7c32b9e..0000000
--- a/tensorflow/lite/micro/kernels/xtensa_hifi/pooling.cc
+++ /dev/null
@@ -1,581 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/lite/kernels/internal/reference/pooling.h"
-
-#include "tensorflow/lite/c/builtin_op_data.h"
-#include "tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h"
-#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
-#include "tensorflow/lite/kernels/kernel_util.h"
-#include "tensorflow/lite/kernels/padding.h"
-#include "tensorflow/lite/micro/kernels/xtensa_hifi/xtensa_tf_micro_common.h"
-
-namespace tflite {
-namespace ops {
-namespace micro {
-namespace pooling {
-
-namespace {
-
-constexpr int kInputTensor = 0;
-constexpr int kOutputTensor = 0;
-
-struct OpData {
- TfLitePaddingValues padding;
-};
-
-TfLiteStatus CalculateOpData(const TfLiteContext* context,
- const TfLitePoolParams* params,
- const TfLiteTensor* input,
- const TfLiteTensor* output, OpData* data) {
- // input: batch, height, width, channel
- int height = SizeOfDimension(input, 1);
- int width = SizeOfDimension(input, 2);
-
- int out_height, out_width;
-
- data->padding = ComputePaddingHeightWidth(
- params->stride_height, params->stride_width,
- /*dilation_rate_height=*/1,
- /*dilation_rate_width=*/1, height, width, params->filter_height,
- params->filter_width, params->padding, &out_height, &out_width);
-
- return kTfLiteOk;
-}
-
-TfLiteStatus AverageEvalFloat(TfLiteContext* context, const TfLiteNode* node,
- const TfLitePoolParams* params,
- const OpData* data, const TfLiteTensor* input,
- TfLiteTensor* output) {
- float activation_min, activation_max;
- CalculateActivationRange(params->activation, &activation_min,
- &activation_max);
-
-#if HIFI_VFPU
- const int stride_height = params->stride_height;
- const int stride_width = params->stride_width;
- const int pad_width = data->padding.width;
- const int pad_height = data->padding.height;
- const int kernel_height = params->filter_height;
- const int kernel_width = params->filter_width;
-
- const RuntimeShape& input_shape = GetTensorShape(input);
- const RuntimeShape& output_shape = GetTensorShape(output);
- TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
- TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
- const int batches = MatchingDim(input_shape, 0, output_shape, 0);
- const int depth = MatchingDim(input_shape, 3, output_shape, 3);
- const int input_height = input_shape.Dims(1);
- const int input_width = input_shape.Dims(2);
- const int output_height = output_shape.Dims(1);
- const int output_width = output_shape.Dims(2);
-
- const float* inp_data_ptr;
- float* out_data_ptr;
- int inp_data_format = 0, out_data_format = 0, out_length;
- int inp_precision = PREC_F32, out_precision = PREC_F32;
- void* p_scratch;
- int err, required_scratch = 0;
-
- ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM;
- p_scratch = (void*)xtensa_nnlib_scratch_buf;
-
- required_scratch = xa_nn_avgpool_getsize(
- depth, inp_precision, out_precision, input_height, input_width,
- kernel_height, kernel_width,
- stride_width, // x_stride,
- stride_height, // y_stride,
- pad_width, // x_padding,
- pad_height, // y_padding,
- output_height, output_width, inp_data_format, out_data_format);
-
- if (required_scratch <= 0) {
- TF_LITE_KERNEL_LOG(context,
- "AveragepoolFloat: xa_nn_avgpool_getsize failed");
- return kTfLiteError;
- }
-
- if (required_scratch > (int)XTENSA_NNLIB_MAX_SCRATCH_SIZE) {
- TF_LITE_KERNEL_LOG(context,
- "AveragepoolFloat: insufficient scratch memory");
- return kTfLiteError;
- }
-
- inp_data_ptr = GetTensorData<float>(input);
- out_data_ptr = GetTensorData<float>(output);
-
- for (int batch = 0; batch < batches; ++batch) {
- err = xa_nn_avgpool_f32(
- &out_data_ptr[output_height * output_width * depth * batch],
- &inp_data_ptr[output_height * output_width * depth * batch],
- input_height, input_width, depth, kernel_height, kernel_width,
- stride_width, stride_height, pad_width, pad_height, output_height,
- output_width, inp_data_format, out_data_format, p_scratch);
-
- CHECK_ERR_HIFI_NNLIB_KER(err, "AveragepoolFloat: xa_nn_avgpool_f32 failed");
- }
-
- out_length = batches * output_height * output_width * depth;
- uint32_t p_unalign_val = (uint32_t)out_data_ptr, p_align_val;
- p_align_val = (p_unalign_val + 7) & (~7);
-
- // pre loop for activation_min_max
- int pre_loop_count = p_align_val - p_unalign_val;
- pre_loop_count = MIN(pre_loop_count, out_length);
-
- for (int i = 0; i < pre_loop_count; i++) {
- ACTIVATION_MIN_MAX(float, out_data_ptr[i], out_data_ptr[i], activation_min,
- activation_max)
- }
-
- out_length = out_length - pre_loop_count;
-
- if (out_length) {
- err = xa_nn_vec_activation_min_max_f32_f32(
- out_data_ptr, out_data_ptr, activation_min, activation_max, out_length);
-
- CHECK_ERR_HIFI_NNLIB_KER(
- err, "AveragepoolFloat: xa_nn_vec_activation_min_max_f32_f32 failed");
- }
-#else
- PoolParams op_params;
- op_params.stride_height = params->stride_height;
- op_params.stride_width = params->stride_width;
- op_params.filter_height = params->filter_height;
- op_params.filter_width = params->filter_width;
- op_params.padding_values.height = data->padding.height;
- op_params.padding_values.width = data->padding.width;
- op_params.float_activation_min = activation_min;
- op_params.float_activation_max = activation_max;
- reference_ops::AveragePool(
- op_params, GetTensorShape(input), GetTensorData<float>(input),
- GetTensorShape(output), GetTensorData<float>(output));
-#endif /* HIFI_VFPU */
- return kTfLiteOk;
-}
-
-TfLiteStatus AverageEvalQuantized(TfLiteContext* context,
- const TfLiteNode* node,
- const TfLitePoolParams* params,
- const OpData* data, const TfLiteTensor* input,
- TfLiteTensor* output) {
- TFLITE_DCHECK(input->type == kTfLiteUInt8 || input->type == kTfLiteInt8);
- int32_t activation_min, activation_max;
- (void)CalculateActivationRangeQuantized(context, params->activation, output,
- &activation_min, &activation_max);
-
- if (input->type == kTfLiteUInt8) {
- const int stride_height = params->stride_height;
- const int stride_width = params->stride_width;
- const int pad_width = data->padding.width;
- const int pad_height = data->padding.height;
- const int kernel_height = params->filter_height;
- const int kernel_width = params->filter_width;
-
- const RuntimeShape& input_shape = GetTensorShape(input);
- const RuntimeShape& output_shape = GetTensorShape(output);
- TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
- TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
- const int batches = MatchingDim(input_shape, 0, output_shape, 0);
- const int depth = MatchingDim(input_shape, 3, output_shape, 3);
- const int input_height = input_shape.Dims(1);
- const int input_width = input_shape.Dims(2);
- const int output_height = output_shape.Dims(1);
- const int output_width = output_shape.Dims(2);
-
- const uint8_t* inp_data_ptr;
- uint8_t* out_data_ptr;
- int inp_data_format = 0, out_data_format = 0, out_length;
- int inp_precision = PREC_ASYM8, out_precision = PREC_ASYM8;
- void* p_scratch;
- int err, required_scratch = 0;
-
- ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM;
- p_scratch = (void*)xtensa_nnlib_scratch_buf;
-
- required_scratch = xa_nn_avgpool_getsize(
- depth, inp_precision, out_precision, input_height, input_width,
- kernel_height, kernel_width,
- stride_width, // x_stride,
- stride_height, // y_stride,
- pad_width, // x_padding,
- pad_height, // y_padding,
- output_height, output_width, inp_data_format, out_data_format);
-
- if (required_scratch <= 0) {
- TF_LITE_KERNEL_LOG(context,
- "AveragepoolAsym8: xa_nn_avgpool_getsize failed");
- return kTfLiteError;
- }
-
- if (required_scratch > (int)XTENSA_NNLIB_MAX_SCRATCH_SIZE) {
- TF_LITE_KERNEL_LOG(context,
- "AveragepoolAsym8: insufficient scratch memory");
- return kTfLiteError;
- }
-
- inp_data_ptr = GetTensorData<uint8_t>(input);
- out_data_ptr = GetTensorData<uint8_t>(output);
-
- for (int batch = 0; batch < batches; ++batch) {
- err = xa_nn_avgpool_asym8(
- &out_data_ptr[output_height * output_width * depth * batch],
- &inp_data_ptr[output_height * output_width * depth * batch],
- input_height, input_width, depth, kernel_height, kernel_width,
- stride_width, stride_height, pad_width, pad_height, output_height,
- output_width, inp_data_format, out_data_format, p_scratch);
-
- CHECK_ERR_HIFI_NNLIB_KER(err,
- "AveragepoolAsym8: xa_nn_avgpool_asym8 failed");
- }
-
- out_length = batches * output_height * output_width * depth;
- uint32_t p_unalign_val = (uint32_t)out_data_ptr, p_align_val;
- p_align_val = (p_unalign_val + 7) & (~7);
-
- // pre loop for activation_min_max
- int pre_loop_count = p_align_val - p_unalign_val;
- pre_loop_count = MIN(pre_loop_count, out_length);
-
- for (int i = 0; i < pre_loop_count; i++) {
- ACTIVATION_MIN_MAX_ASYM8(out_data_ptr[i], out_data_ptr[i], activation_min,
- activation_max)
- }
-
- out_length = out_length - pre_loop_count;
-
- if (out_length > 0) {
- err = xa_nn_vec_activation_min_max_asym8_asym8(
- out_data_ptr, out_data_ptr, activation_min, activation_max,
- out_length);
-
- CHECK_ERR_HIFI_NNLIB_KER(
- err,
- "AveragepoolAsym8: xa_nn_vec_activation_min_max_asym8_asym8 failed");
- }
- } else {
- PoolParams op_params;
- op_params.stride_height = params->stride_height;
- op_params.stride_width = params->stride_width;
- op_params.filter_height = params->filter_height;
- op_params.filter_width = params->filter_width;
- op_params.padding_values.height = data->padding.height;
- op_params.padding_values.width = data->padding.width;
- op_params.quantized_activation_min = activation_min;
- op_params.quantized_activation_max = activation_max;
- reference_integer_ops::AveragePool(
- op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
- GetTensorShape(output), GetTensorData<int8_t>(output));
- }
- return kTfLiteOk;
-}
-
-TfLiteStatus MaxEvalFloat(TfLiteContext* context, TfLiteNode* node,
- TfLitePoolParams* params, OpData* data,
- const TfLiteTensor* input, TfLiteTensor* output) {
- float activation_min, activation_max;
- CalculateActivationRange(params->activation, &activation_min,
- &activation_max);
-
-#if HIFI_VFPU
- const int stride_height = params->stride_height;
- const int stride_width = params->stride_width;
- const int pad_width = data->padding.width;
- const int pad_height = data->padding.height;
- const int kernel_height = params->filter_height;
- const int kernel_width = params->filter_width;
-
- const RuntimeShape& input_shape = GetTensorShape(input);
- const RuntimeShape& output_shape = GetTensorShape(output);
- TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
- TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
- const int batches = MatchingDim(input_shape, 0, output_shape, 0);
- const int depth = MatchingDim(input_shape, 3, output_shape, 3);
- const int input_height = input_shape.Dims(1);
- const int input_width = input_shape.Dims(2);
- const int output_height = output_shape.Dims(1);
- const int output_width = output_shape.Dims(2);
-
- const float* inp_data_ptr;
- float* out_data_ptr;
- int inp_data_format = 0, out_data_format = 0, out_length;
- int inp_precision = PREC_F32, out_precision = PREC_F32;
- void* p_scratch;
- int err, required_scratch = 0;
-
- ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM;
- p_scratch = (void*)xtensa_nnlib_scratch_buf;
-
- required_scratch = xa_nn_maxpool_getsize(
- depth, inp_precision, out_precision, input_height, input_width,
- kernel_height, kernel_width,
- stride_width, // x_stride,
- stride_height, // y_stride,
- pad_width, // x_padding,
- pad_height, // y_padding,
- output_height, output_width, inp_data_format, out_data_format);
-
- if (required_scratch <= 0) {
- TF_LITE_KERNEL_LOG(context, "MaxpoolFloat: xa_nn_maxpool_getsize failed");
- return kTfLiteError;
- }
-
- if (required_scratch > (int)XTENSA_NNLIB_MAX_SCRATCH_SIZE) {
- TF_LITE_KERNEL_LOG(context, "MaxpoolFloat: insufficient scratch memory");
- return kTfLiteError;
- }
-
- inp_data_ptr = GetTensorData<float>(input);
- out_data_ptr = GetTensorData<float>(output);
-
- for (int batch = 0; batch < batches; ++batch) {
- err = xa_nn_maxpool_f32(
- &out_data_ptr[output_height * output_width * depth * batch],
- &inp_data_ptr[output_height * output_width * depth * batch],
- input_height, input_width, depth, kernel_height, kernel_width,
- stride_width, stride_height, pad_width, pad_height, output_height,
- output_width, inp_data_format, out_data_format, p_scratch);
-
- CHECK_ERR_HIFI_NNLIB_KER(err, "MaxpoolFloat: xa_nn_maxpool_f32 failed");
- }
-
- out_length = batches * output_height * output_width * depth;
- uint32_t p_unalign_val = (uint32_t)out_data_ptr, p_align_val;
- p_align_val = (p_unalign_val + 7) & (~7);
-
- // pre loop for activation_min_max
- int pre_loop_count = p_align_val - p_unalign_val;
- pre_loop_count = MIN(pre_loop_count, out_length);
-
- for (int i = 0; i < pre_loop_count; i++) {
- ACTIVATION_MIN_MAX(float, out_data_ptr[i], out_data_ptr[i], activation_min,
- activation_max)
- }
-
- out_length = out_length - pre_loop_count;
-
- if (out_length > 0) {
- err = xa_nn_vec_activation_min_max_f32_f32(
- out_data_ptr, out_data_ptr, activation_min, activation_max, out_length);
-
- CHECK_ERR_HIFI_NNLIB_KER(
- err, "MaxpoolFloat: xa_nn_vec_activation_min_max_f32_f32 failed");
- }
-#else
- tflite::PoolParams op_params;
- op_params.stride_height = params->stride_height;
- op_params.stride_width = params->stride_width;
- op_params.filter_height = params->filter_height;
- op_params.filter_width = params->filter_width;
- op_params.padding_values.height = data->padding.height;
- op_params.padding_values.width = data->padding.width;
- op_params.float_activation_min = activation_min;
- op_params.float_activation_max = activation_max;
- reference_ops::MaxPool(op_params, GetTensorShape(input),
- GetTensorData<float>(input), GetTensorShape(output),
- GetTensorData<float>(output));
-#endif /* HIFI_VFPU */
- return kTfLiteOk;
-}
-
-TfLiteStatus MaxEvalQuantized(TfLiteContext* context, TfLiteNode* node,
- TfLitePoolParams* params, OpData* data,
- const TfLiteTensor* input, TfLiteTensor* output) {
- TFLITE_DCHECK(input->type == kTfLiteUInt8 || input->type == kTfLiteInt8);
-
- int32_t activation_min, activation_max;
- (void)CalculateActivationRangeQuantized(context, params->activation, output,
- &activation_min, &activation_max);
-
- if (input->type == kTfLiteUInt8) {
- const int stride_height = params->stride_height;
- const int stride_width = params->stride_width;
- const int pad_width = data->padding.width;
- const int pad_height = data->padding.height;
- const int kernel_height = params->filter_height;
- const int kernel_width = params->filter_width;
-
- const RuntimeShape& input_shape = GetTensorShape(input);
- const RuntimeShape& output_shape = GetTensorShape(output);
- TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
- TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
- const int batches = MatchingDim(input_shape, 0, output_shape, 0);
- const int depth = MatchingDim(input_shape, 3, output_shape, 3);
- const int input_height = input_shape.Dims(1);
- const int input_width = input_shape.Dims(2);
- const int output_height = output_shape.Dims(1);
- const int output_width = output_shape.Dims(2);
-
- const uint8_t* inp_data_ptr;
- uint8_t* out_data_ptr;
- int inp_data_format = 0, out_data_format = 0, out_length;
- int inp_precision = PREC_ASYM8, out_precision = PREC_ASYM8;
- void* p_scratch;
- int err, required_scratch = 0;
-
- ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM;
- p_scratch = (void*)xtensa_nnlib_scratch_buf;
-
- required_scratch = xa_nn_maxpool_getsize(
- depth, inp_precision, out_precision, input_height, input_width,
- kernel_height, kernel_width,
- stride_width, // x_stride,
- stride_height, // y_stride,
- pad_width, // x_padding,
- pad_height, // y_padding,
- output_height, output_width, inp_data_format, out_data_format);
-
- if (required_scratch <= 0) {
- TF_LITE_KERNEL_LOG(context, "MaxpoolAsym8: xa_nn_maxpool_getsize failed");
- return kTfLiteError;
- }
-
- if (required_scratch > (int)XTENSA_NNLIB_MAX_SCRATCH_SIZE) {
- TF_LITE_KERNEL_LOG(context, "MaxpoolAsym8: insufficient scratch memory");
- return kTfLiteError;
- }
-
- inp_data_ptr = GetTensorData<uint8_t>(input);
- out_data_ptr = GetTensorData<uint8_t>(output);
-
- for (int batch = 0; batch < batches; ++batch) {
- err = xa_nn_maxpool_asym8(
- &out_data_ptr[output_height * output_width * depth * batch],
- &inp_data_ptr[output_height * output_width * depth * batch],
- input_height, input_width, depth, kernel_height, kernel_width,
- stride_width, stride_height, pad_width, pad_height, output_height,
- output_width, inp_data_format, out_data_format, p_scratch);
-
- CHECK_ERR_HIFI_NNLIB_KER(err, "MaxpoolAsym8: xa_nn_maxpool_asym8 failed");
- }
-
- out_length = batches * output_height * output_width * depth;
- uint32_t p_unalign_val = (uint32_t)out_data_ptr, p_align_val;
- p_align_val = (p_unalign_val + 7) & (~7);
-
- // pre loop for activation_min_max
- int pre_loop_count = p_align_val - p_unalign_val;
- pre_loop_count = MIN(pre_loop_count, out_length);
-
- for (int i = 0; i < pre_loop_count; i++) {
- ACTIVATION_MIN_MAX_ASYM8(out_data_ptr[i], out_data_ptr[i], activation_min,
- activation_max)
- }
-
- out_length = out_length - pre_loop_count;
-
- if (out_length > 0) {
- err = xa_nn_vec_activation_min_max_asym8_asym8(
- out_data_ptr, out_data_ptr, activation_min, activation_max,
- out_length);
-
- CHECK_ERR_HIFI_NNLIB_KER(
- err, "MaxpoolAsym8: xa_nn_vec_activation_min_max_asym8_asym8 failed");
- }
- } else {
- tflite::PoolParams op_params;
- op_params.stride_height = params->stride_height;
- op_params.stride_width = params->stride_width;
- op_params.filter_height = params->filter_height;
- op_params.filter_width = params->filter_width;
- op_params.padding_values.height = data->padding.height;
- op_params.padding_values.width = data->padding.width;
- op_params.quantized_activation_min = activation_min;
- op_params.quantized_activation_max = activation_max;
- reference_integer_ops::MaxPool(
- op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
- GetTensorShape(output), GetTensorData<int8_t>(output));
- }
- return kTfLiteOk;
-}
-} // namespace
-
-TfLiteStatus AverageEval(TfLiteContext* context, TfLiteNode* node) {
- auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
- OpData data;
-
- const TfLiteTensor* input = GetInput(context, node, kInputTensor);
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-
- TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, input, output, &data));
-
- // Inputs and outputs share the same type, guaranteed by the converter.
- switch (input->type) {
- case kTfLiteFloat32:
- AverageEvalFloat(context, node, params, &data, input, output);
- break;
- case kTfLiteUInt8:
- case kTfLiteInt8:
- AverageEvalQuantized(context, node, params, &data, input, output);
- break;
- default:
- TF_LITE_KERNEL_LOG(context, "Input type %s is not currently supported",
- TfLiteTypeGetName(input->type));
- return kTfLiteError;
- }
- return kTfLiteOk;
-}
-
-TfLiteStatus MaxEval(TfLiteContext* context, TfLiteNode* node) {
- auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
- OpData data;
-
- const TfLiteTensor* input = GetInput(context, node, kInputTensor);
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-
- TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, input, output, &data));
-
- switch (input->type) {
- case kTfLiteFloat32:
- MaxEvalFloat(context, node, params, &data, input, output);
- break;
- case kTfLiteUInt8:
- case kTfLiteInt8:
- MaxEvalQuantized(context, node, params, &data, input, output);
- break;
- default:
- TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.",
- TfLiteTypeGetName(input->type));
- return kTfLiteError;
- }
- return kTfLiteOk;
-}
-
-} // namespace pooling
-
-TfLiteRegistration Register_AVERAGE_POOL_2D() {
- return {/*init=*/nullptr,
- /*free=*/nullptr,
- /*prepare=*/nullptr,
- /*invoke=*/pooling::AverageEval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
-}
-
-TfLiteRegistration Register_MAX_POOL_2D() {
- return {/*init=*/nullptr,
- /*free=*/nullptr,
- /*prepare=*/nullptr,
- /*invoke=*/pooling::MaxEval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
-}
-
-} // namespace micro
-} // namespace ops
-} // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/softmax.cc b/tensorflow/lite/micro/kernels/xtensa_hifi/softmax.cc
deleted file mode 100755
index 65ead0f..0000000
--- a/tensorflow/lite/micro/kernels/xtensa_hifi/softmax.cc
+++ /dev/null
@@ -1,207 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/lite/kernels/internal/reference/softmax.h"
-
-#include "tensorflow/lite/c/builtin_op_data.h"
-#include "tensorflow/lite/c/common.h"
-#include "tensorflow/lite/kernels/internal/common.h"
-#include "tensorflow/lite/kernels/internal/quantization_util.h"
-#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
-#include "tensorflow/lite/kernels/kernel_util.h"
-#include "tensorflow/lite/kernels/op_macros.h"
-#include "tensorflow/lite/micro/kernels/xtensa_hifi/xtensa_tf_micro_common.h"
-namespace tflite {
-namespace ops {
-namespace micro {
-namespace activations {
-namespace {
-
-TfLiteStatus CalculateSoftmaxParams(TfLiteContext* context,
- const TfLiteTensor* input,
- TfLiteTensor* output,
- const TfLiteSoftmaxParams* params,
- SoftmaxParams* op_data) {
- if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
- if (input->type == kTfLiteUInt8) {
- TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteUInt8);
- TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
- } else {
- TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt8);
- if (output->type == kTfLiteInt16) {
- TF_LITE_ENSURE_EQ(context, output->params.zero_point, -32768);
- // NOTE: Current int16_t softmax output does not require symmetric
- // scaling
- // - so no need to verify scale here.
- } else {
- TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8);
- TF_LITE_ENSURE_EQ(context, output->params.zero_point, -128);
- TF_LITE_ENSURE(context, output->params.scale == 1.f / 256);
- }
- }
-
- static const int kScaledDiffIntegerBits = 5;
-
- int input_left_shift;
- tflite::PreprocessSoftmaxScaling(
- static_cast<double>(params->beta),
- static_cast<double>(input->params.scale), kScaledDiffIntegerBits,
- &op_data->input_multiplier, &input_left_shift);
- op_data->input_left_shift = input_left_shift;
- op_data->diff_min =
- -1.0 * tflite::CalculateInputRadius(kScaledDiffIntegerBits,
- op_data->input_left_shift);
- } else {
- TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
- TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
- op_data->beta = static_cast<double>(params->beta);
- }
- return kTfLiteOk;
-}
-
-} // namespace
-
-TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
- TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
- TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- const TfLiteTensor* input = GetInput(context, node, 0);
- TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
-
- return kTfLiteOk;
-}
-
-// Takes a tensor and performs softmax along the last dimension.
-TfLiteStatus SoftmaxFloat(TfLiteContext* context, const TfLiteTensor* input,
- TfLiteTensor* output, const SoftmaxParams& op_data) {
-#if HIFI_VFPU
- const RuntimeShape& input_shape = GetTensorShape(input);
- const float* input_data = GetTensorData<float>(input);
- const RuntimeShape& output_shape = GetTensorShape(output);
- float* output_data = GetTensorData<float>(output);
- const int trailing_dim = input_shape.DimensionsCount() - 1;
- const int outer_size =
- MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
- const int depth =
- MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
-
- ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM;
- float* p_scratch = (float*)xtensa_nnlib_scratch_buf;
-
- if (depth * sizeof(float) > XTENSA_NNLIB_MAX_SCRATCH_SIZE) {
- TF_LITE_KERNEL_LOG(context, "Softmax: insufficient scratch memory");
- return kTfLiteError;
- }
-
- for (int i = 0; i < outer_size; ++i) {
- for (int c = 0; c < depth; ++c) {
- p_scratch[c] =
- input_data[i * depth + c] * static_cast<float>(op_data.beta);
- }
-
- int err =
- xa_nn_vec_softmax_f32_f32(&output_data[i * depth], p_scratch, depth);
- CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_vec_softmax_f32_f32 failed");
- }
-#else
- tflite::reference_ops::Softmax(
- op_data, GetTensorShape(input), GetTensorData<float>(input),
- GetTensorShape(output), GetTensorData<float>(output));
-#endif /* HIFI_VFPU */
- return kTfLiteOk;
-}
-
-TfLiteStatus SoftmaxQuantized(TfLiteContext* context, const TfLiteTensor* input,
- TfLiteTensor* output,
- const SoftmaxParams& op_data) {
- if (input->type == kTfLiteUInt8) {
- const RuntimeShape& input_shape = GetTensorShape(input);
- const uint8_t* input_data = GetTensorData<uint8_t>(input);
- const RuntimeShape& output_shape = GetTensorShape(output);
- uint8_t* output_data = GetTensorData<uint8_t>(output);
- const int trailing_dim = input_shape.DimensionsCount() - 1;
- const int outer_size =
- MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
- const int depth =
- MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
-
- ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM;
- void* p_scratch = (void*)xtensa_nnlib_scratch_buf;
-
- if (get_softmax_scratch_size(PREC_ASYM8, PREC_ASYM8, depth) >
- XTENSA_NNLIB_MAX_SCRATCH_SIZE) {
- TF_LITE_KERNEL_LOG(context, "Softmax: insufficient scratch memory");
- return kTfLiteError;
- }
-
- for (int i = 0; i < outer_size; ++i) {
- int err = xa_nn_vec_softmax_asym8_asym8(
- &output_data[i * depth], &input_data[i * depth], op_data.diff_min,
- op_data.input_left_shift, op_data.input_multiplier, depth, p_scratch);
- CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_vec_softmax_asym8_asym8 failed");
- }
- } else {
- if (output->type == kTfLiteInt16) {
- tflite::reference_ops::Softmax(
- op_data, GetTensorShape(input), GetTensorData<int8_t>(input),
- GetTensorShape(output), GetTensorData<int16_t>(output));
- } else {
- tflite::reference_ops::Softmax(
- op_data, GetTensorShape(input), GetTensorData<int8_t>(input),
- GetTensorShape(output), GetTensorData<int8_t>(output));
- }
- }
- return kTfLiteOk;
-}
-
-TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
- auto* params = static_cast<TfLiteSoftmaxParams*>(node->builtin_data);
-
- const TfLiteTensor* input = GetInput(context, node, 0);
- TfLiteTensor* output = GetOutput(context, node, 0);
-
- SoftmaxParams op_data;
- TF_LITE_ENSURE_STATUS(
- CalculateSoftmaxParams(context, input, output, params, &op_data));
-
- switch (input->type) {
- case kTfLiteFloat32: {
- return SoftmaxFloat(context, input, output, op_data);
- }
- case kTfLiteInt8:
- case kTfLiteUInt8: {
- return SoftmaxQuantized(context, input, output, op_data);
- }
- default:
- TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
- TfLiteTypeGetName(input->type), input->type);
- return kTfLiteError;
- }
-}
-} // namespace activations
-
-TfLiteRegistration Register_SOFTMAX() {
- return {/*init=*/nullptr,
- /*free=*/nullptr,
- /*prepare=*/activations::SoftmaxPrepare,
- /*invoke=*/activations::SoftmaxEval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
-}
-
-} // namespace micro
-} // namespace ops
-} // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/svdf.cc b/tensorflow/lite/micro/kernels/xtensa_hifi/svdf.cc
deleted file mode 100644
index d8ee6b2..0000000
--- a/tensorflow/lite/micro/kernels/xtensa_hifi/svdf.cc
+++ /dev/null
@@ -1,601 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <math.h>
-
-#include "tensorflow/lite/c/builtin_op_data.h"
-#include "tensorflow/lite/c/common.h"
-#include "tensorflow/lite/kernels/internal/common.h"
-#include "tensorflow/lite/kernels/internal/quantization_util.h"
-#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
-#include "tensorflow/lite/kernels/kernel_util.h"
-#include "tensorflow/lite/kernels/op_macros.h"
-#include "tensorflow/lite/micro/kernels/activation_utils.h"
-#include "tensorflow/lite/micro/kernels/xtensa_hifi/xtensa_tf_micro_common.h"
-#include "tensorflow/lite/micro/micro_utils.h"
-
-namespace tflite {
-namespace ops {
-namespace micro {
-namespace svdf {
-namespace {
-
-struct OpData {
- int32_t effective_scale_1_a;
- int32_t effective_scale_2_a;
- // b versions of each scale are kept at int since the numbers are just the
- // shift value - typically between [-32, 32].
- int effective_scale_1_b;
- int effective_scale_2_b;
- int scratch_tensor_index;
- int scratch_output_tensor_index;
-};
-
-/**
- * This version of SVDF is specific to TFLite Micro. It contains the following
- * differences between the TFLite version:
- *
- * 1.) Scratch tensor allocation - scratch tensors must be known ahead of time
- * for the Micro interpreter.
- * 2.) Output dimensions - the TFLite version determines output size and runtime
- * and resizes the output tensor. Micro runtime does not support tensor
- * resizing.
- */
-
-static inline TfLiteStatus ApplyTimeWeightsBiasAndActivation(
- TfLiteContext* context, int batch_size, int memory_size, int num_filters,
- int num_units, int rank, const float* const __restrict__ weights_time_ptr,
- const float* const __restrict__ bias_ptr, TfLiteFusedActivation activation,
- float* const __restrict__ state_ptr, float* const __restrict__ scratch_ptr,
- float* const __restrict__ output_ptr) {
- // Compute matmul(activation_state, weights_time).
-#if HIFI_VFPU
- float* scratch_bias = scratch_ptr;
- if (bias_ptr) {
- const float* bias_data = bias_ptr;
- for (int j = 0; j < num_units; ++j) {
- scratch_bias[j] = *bias_data++;
- }
- } else {
- for (int j = 0; j < num_units; ++j) {
- scratch_bias[j] = 0.0f;
- }
- }
- int err = 0;
- for (int b = 0; b < batch_size; ++b) {
- const float* weights_time_vec = weights_time_ptr;
- const float* mat_ptr = state_ptr + b * memory_size * num_filters;
- float* output_ptr_batch = output_ptr + b * num_units;
- for (int j = 0; j < num_units; j++) {
- err = xa_nn_matXvec_f32xf32_f32(
- output_ptr_batch, mat_ptr, NULL, weights_time_vec, NULL, scratch_bias,
- 1, memory_size * rank, 0, memory_size * rank, 0);
- CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_vec_matXvec_f32xf32_f32 failed");
-
- output_ptr_batch++;
- mat_ptr += memory_size * rank;
- weights_time_vec += memory_size * rank;
- }
- }
-#else
- for (int b = 0; b < batch_size; ++b) {
- // Perform batched vector dot product:
- float* scratch_ptr_batch = scratch_ptr + b * num_filters;
- const float* vector1_ptr = weights_time_ptr;
- const float* vector2_ptr = state_ptr + b * memory_size * num_filters;
- for (int i = 0; i < num_filters; ++i) {
- *scratch_ptr_batch = 0.f;
- for (int j = 0; j < memory_size; ++j) {
- *scratch_ptr_batch += *vector1_ptr++ * *vector2_ptr++;
- }
- scratch_ptr_batch++;
- }
- }
-
- // Initialize output with bias if provided.
- if (bias_ptr) {
- // VectorBatchVectorAssign
- for (int i = 0; i < batch_size; ++i) {
- float* output_data = output_ptr + i * num_units;
- const float* bias_data = bias_ptr;
- for (int j = 0; j < num_units; ++j) {
- *output_data++ = *bias_data++;
- }
- }
- } else {
- float* output_data = output_ptr;
- for (int i = 0; i < batch_size * num_units; ++i) {
- *output_data++ = 0.0f;
- }
- }
-
- // Reduction sum.
- for (int b = 0; b < batch_size; ++b) {
- float* output_ptr_batch = output_ptr + b * num_units;
- float* scratch_ptr_batch = scratch_ptr + b * num_filters;
-
- // Reduction sum vector
- for (int i = 0; i < num_units; ++i) {
- for (int j = 0; j < rank; j++) {
- output_ptr_batch[i] += *scratch_ptr_batch++;
- }
- }
- }
-#endif /* HIFI_VFPU */
-
- // Apply activation.
- for (int b = 0; b < batch_size; ++b) {
- float* output_ptr_batch = output_ptr + b * num_units;
- for (int i = 0; i < num_units; ++i) {
- *output_ptr_batch = ActivationValFloat(activation, *output_ptr_batch);
- ++output_ptr_batch;
- }
- }
- return kTfLiteOk;
-}
-
-inline TfLiteStatus EvalFloatSVDF(
- TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* input,
- const TfLiteTensor* weights_feature, const TfLiteTensor* weights_time,
- const TfLiteTensor* bias, const TfLiteSVDFParams* params,
- int scratch_tensor_index, TfLiteTensor* activation_state,
- TfLiteTensor* output) {
- const int rank = params->rank;
- const int batch_size = input->dims->data[0];
- const int input_size = input->dims->data[1];
- const int num_filters = weights_feature->dims->data[0];
- const int num_units = num_filters / rank;
- const int memory_size = weights_time->dims->data[1];
-
- const float* weights_feature_ptr = GetTensorData<float>(weights_feature);
- const float* weights_time_ptr = GetTensorData<float>(weights_time);
- const float* bias_ptr = GetTensorData<float>(bias);
- const float* input_ptr = GetTensorData<float>(input);
-
- float* state_ptr = GetTensorData<float>(activation_state);
-
- TFLITE_DCHECK(context != nullptr);
- TFLITE_DCHECK(context->GetScratchBuffer != nullptr);
-
- float* scratch_ptr = static_cast<float*>(
- context->GetScratchBuffer(context, scratch_tensor_index));
-
- float* output_ptr = GetTensorData<float>(output);
-
- // Left shift the activation_state.
- {
- float* new_state_start = state_ptr;
- const float* old_state_start = state_ptr + 1;
- const float* old_state_end =
- state_ptr + batch_size * num_filters * memory_size;
- while (old_state_start != old_state_end) {
- *new_state_start++ = *old_state_start++;
- }
- }
-
- // Note: no need to clear the latest activation, matmul is not accumulative.
-
- // Compute conv1d(inputs, weights_feature).
- // The activation_state's rightmost column is used to save current cycle
- // activation. This is achieved by starting at state_ptr[memory_size - 1] and
- // having the stride equal to memory_size.
-
- // Perform batched matrix vector multiply operation:
- {
- const float* matrix = weights_feature_ptr;
- const float* vector = input_ptr;
- float* result = &state_ptr[memory_size - 1];
- float* result_in_batch = result;
-
-#if HIFI_VFPU
- float* out_scratch = scratch_ptr;
- float* bias_scratch = output_ptr;
- for (int i = 0; i < num_units; i++) bias_scratch[i] = 0.0f;
-
- int err = 0;
- for (int i = 0; i < batch_size; i++) {
- /* We are using output buffer for bias (it is needed by NNLib kernel,
- so only num_units size is guaranteed, so introduced rank loop and
- calling matXvec for num_units rows */
- for (int j = 0; j < rank; j++) {
- err = xa_nn_matXvec_f32xf32_f32(
- &out_scratch[j * num_units], &matrix[j * input_size * num_units],
- NULL, &vector[i * input_size], NULL, bias_scratch, num_units,
- input_size, 0, input_size, 0);
- CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_vec_matXvec_f32xf32_f32 failed");
- }
- for (int j = 0; j < num_filters; ++j) {
- *result_in_batch = out_scratch[j];
- result_in_batch += memory_size;
- }
- }
-#else
- for (int i = 0; i < batch_size; ++i) {
- const float* matrix_ptr = matrix;
- for (int j = 0; j < num_filters; ++j) {
- float dot_prod = 0.0f;
- const float* vector_in_batch = vector + i * input_size;
- for (int k = 0; k < input_size; ++k) {
- dot_prod += *matrix_ptr++ * *vector_in_batch++;
- }
- *result_in_batch = dot_prod;
- result_in_batch += memory_size;
- }
- }
-#endif /* HIFI_VFPU */
- }
-
- return ApplyTimeWeightsBiasAndActivation(
- context, batch_size, memory_size, num_filters, num_units, rank,
- weights_time_ptr, bias_ptr, params->activation, state_ptr, scratch_ptr,
- output_ptr);
-}
-
-void EvalIntegerSVDF(TfLiteContext* context, TfLiteNode* node,
- const TfLiteTensor* input_tensor,
- const TfLiteTensor* weights_feature_tensor,
- const TfLiteTensor* weights_time_tensor,
- const TfLiteTensor* bias_tensor,
- const TfLiteSVDFParams* params,
- TfLiteTensor* activation_state_tensor,
- TfLiteTensor* output_tensor, const OpData& data,
- int32_t input_zp, int32_t output_zp) {
- const int n_rank = params->rank;
- const int n_batch = input_tensor->dims->data[0];
- const int n_input = input_tensor->dims->data[1];
- const int n_filter = weights_feature_tensor->dims->data[0];
- const int n_unit = n_filter / n_rank;
- const int n_memory = weights_time_tensor->dims->data[1];
-
- TFLITE_DCHECK(context != nullptr);
- TFLITE_DCHECK(context->GetScratchBuffer != nullptr);
-
- int32_t* scratch_tensor = static_cast<int32_t*>(
- context->GetScratchBuffer(context, data.scratch_tensor_index));
- int32_t* scratch_output_tensor = static_cast<int32_t*>(
- context->GetScratchBuffer(context, data.scratch_output_tensor_index));
-
- // Shift states.
- int16_t* const state_ptr = GetTensorData<int16_t>(activation_state_tensor);
-
- // Left shift the activation_state.
- {
- int16_t* new_state_start = state_ptr;
- const int16_t* old_state_start = state_ptr + 1;
- const int16_t* old_state_end = state_ptr + n_batch * n_filter * n_memory;
- while (old_state_start != old_state_end) {
- *new_state_start++ = *old_state_start++;
- }
- }
-
- // Note: no need to clear the latest activation, matmul is not accumulative.
-
- // Feature matmul.
- {
- int16_t* state = GetTensorData<int16_t>(activation_state_tensor);
- const int8_t* input = GetTensorData<int8_t>(input_tensor);
- const int8_t* weight_feature =
- GetTensorData<int8_t>(weights_feature_tensor);
- const int32_t output_max = std::numeric_limits<int16_t>::max();
- const int32_t output_min = std::numeric_limits<int16_t>::min();
- int16_t* result_in_batch = state + (n_memory - 1);
- for (int b = 0; b < n_batch; b++) {
- const int8_t* matrix_ptr = weight_feature;
- for (int r = 0; r < n_filter; r++) {
- int32_t dot_prod = 0;
- const int8_t* vector_in_batch = input + b * n_input;
- for (int c = 0; c < n_input; c++) {
- dot_prod += *matrix_ptr++ * (*vector_in_batch++ - input_zp);
- }
- dot_prod = MultiplyByQuantizedMultiplier(
- dot_prod, data.effective_scale_1_a, data.effective_scale_1_b);
- dot_prod = std::min(std::max(output_min, dot_prod), output_max);
- // This assumes state is symmetrically quantized. Otherwise last bit of
- // state should be initialized to its zero point and accumulate the
- // dot_prod.
- // Equivalent as the following:
- // result_in_batch = zero point, which happens to be zero.
- // result_in_batch += dot_prod_56.
- *result_in_batch = dot_prod;
- result_in_batch += n_memory;
- }
- }
- }
-
- // Time.
- {
- for (int b = 0; b < n_batch; ++b) {
- int32_t* scratch_ptr_batch = scratch_tensor + b * n_filter;
-
- // Perform batched vector dot product:
- const int16_t* vector1_ptr = GetTensorData<int16_t>(weights_time_tensor);
- const int16_t* vector2_ptr =
- GetTensorData<int16_t>(activation_state_tensor) +
- b * n_memory * n_filter;
-
- for (int i = 0; i < n_filter; i++) {
- *scratch_ptr_batch = 0;
- for (int j = 0; j < n_memory; j++) {
- *scratch_ptr_batch += *vector1_ptr++ * *vector2_ptr++;
- }
- scratch_ptr_batch++;
- }
- }
- }
-
- // Reduce, add bias, rescale, activation.
- {
- // Add bias.
- if (bias_tensor) {
- // Vector batch assign:
- const int32_t* bias_data = GetTensorData<int32_t>(bias_tensor);
- for (int i = 0; i < n_batch; ++i) {
- int32_t* output_ptr = scratch_output_tensor + i * n_unit;
- const int32_t* bias_ptr = bias_data;
- for (int j = 0; j < n_unit; ++j) {
- *output_ptr++ = *bias_ptr++;
- }
- }
- } else {
- int32_t* output_ptr = scratch_output_tensor;
- for (int i = 0; i < n_batch * n_unit; ++i) {
- *output_ptr++ = 0;
- }
- }
-
- // Reduce.
- for (int b = 0; b < n_batch; ++b) {
- int32_t* output_temp_ptr = scratch_output_tensor + b * n_unit;
- int32_t* scratch_ptr_batch = scratch_tensor + b * n_filter;
-
- // Reduction sum vector
- for (int i = 0; i < n_unit; ++i) {
- for (int j = 0; j < n_rank; ++j) {
- output_temp_ptr[i] += *scratch_ptr_batch++;
- }
- }
- }
-
- // Rescale.
- const int32_t output_max = std::numeric_limits<int8_t>::max();
- const int32_t output_min = std::numeric_limits<int8_t>::min();
- for (int i = 0; i < n_batch * n_unit; ++i) {
- int32_t x1 = scratch_output_tensor[i];
- int32_t x2 = MultiplyByQuantizedMultiplier(x1, data.effective_scale_2_a,
- data.effective_scale_2_b);
- int32_t x3 = x2 + output_zp;
- int32_t x4 = std::min(std::max(output_min, x3), output_max);
- GetTensorData<int8_t>(output_tensor)[i] = static_cast<int8_t>(x4);
- }
- }
-}
-
-} // namespace
-
-// Input tensors.
-constexpr int kInputTensor = 0;
-constexpr int kWeightsFeatureTensor = 1;
-constexpr int kWeightsTimeTensor = 2;
-constexpr int kBiasTensor = 3;
-// This is a variable tensor, and will be modified by this op.
-constexpr int kInputActivationStateTensor = 4;
-
-// Output tensor.
-constexpr int kOutputTensor = 0;
-
-void* Init(TfLiteContext* context, const char* buffer, size_t length) {
- TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
- void* data = nullptr;
- if (context->AllocatePersistentBuffer(context, sizeof(OpData), &data) ==
- kTfLiteError) {
- return nullptr;
- }
- return data;
-}
-
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- TFLITE_DCHECK(node->builtin_data != nullptr);
-
- const auto* params = static_cast<const TfLiteSVDFParams*>(node->builtin_data);
-
- // Validate Tensor Inputs (dtype depends on quantization):
- // [0] = Input, {2, batch_size, input_size}
- // [1] = Weights Feature, {2, num_filters, input_size}
- // [2] = Weights Time, {2, num_filters, memory_size}
- // [3] = Bias (optional), {1, num_units}
- // [4] = Activation State (variable),
- // {2, batch_size, memory_size * num_filters}
- const TfLiteTensor* input = GetInput(context, node, kInputTensor);
- const TfLiteTensor* weights_feature =
- GetInput(context, node, kWeightsFeatureTensor);
- const TfLiteTensor* weights_time =
- GetInput(context, node, kWeightsTimeTensor);
- const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
- const TfLiteTensor* activation_state =
- GetInput(context, node, kInputActivationStateTensor);
-
- // Define input constants based on input tensor definition above:
- const int rank = params->rank;
- const int input_size = input->dims->data[1];
- const int batch_size = input->dims->data[0];
- const int num_filters = weights_feature->dims->data[0];
- TF_LITE_ENSURE_EQ(context, num_filters % rank, 0);
- const int num_units = num_filters / rank;
- const int memory_size = weights_time->dims->data[1];
-
- // Validate Input Tensor:
- TF_LITE_ENSURE(context,
- input->type == kTfLiteFloat32 || input->type == kTfLiteInt8);
- TF_LITE_ENSURE_EQ(context, NumDimensions(input), 2);
-
- // Validate Tensor Output:
- // [0] = float/int8_t, {2, batch_size, num_units}
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- TF_LITE_ENSURE_EQ(context, NumDimensions(output), 2);
- TF_LITE_ENSURE_EQ(context, output->dims->data[0], batch_size);
- TF_LITE_ENSURE_EQ(context, output->dims->data[1], num_units);
-
- // Validate Weights Feature Input Tensor:
- TF_LITE_ENSURE_EQ(context, NumDimensions(weights_feature), 2);
- TF_LITE_ENSURE_EQ(context, weights_feature->dims->data[1], input_size);
-
- // Validate Weights Time Input Tensor:
- TF_LITE_ENSURE_EQ(context, NumDimensions(weights_time), 2);
- TF_LITE_ENSURE_EQ(context, weights_time->dims->data[0], num_filters);
- TF_LITE_ENSURE_EQ(context, weights_time->dims->data[1], memory_size);
-
- // Validate Optional Bias Input Tensor:
- if (bias != nullptr) {
- TF_LITE_ENSURE_EQ(context, bias->dims->data[0], num_units);
- }
-
- // Validate Activation State Input Tensor:
- TF_LITE_ENSURE_EQ(context, NumDimensions(activation_state), 2);
- TF_LITE_ENSURE_EQ(context, activation_state->dims->data[0], batch_size);
- TF_LITE_ENSURE_EQ(context, activation_state->dims->data[1],
- memory_size * num_filters);
-
- TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
-
- if (input->type == kTfLiteInt8) {
- TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteInt8);
- TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteInt16);
- TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteInt16);
- if (bias != nullptr) {
- TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32);
- }
-
- TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8);
-
- const auto* input_params =
- reinterpret_cast<TfLiteAffineQuantization*>(input->quantization.params);
- const auto* weights_feature_params =
- static_cast<const TfLiteAffineQuantization*>(
- weights_feature->quantization.params);
- const auto* state_params = static_cast<const TfLiteAffineQuantization*>(
- activation_state->quantization.params);
- const auto* weight_time_params =
- static_cast<const TfLiteAffineQuantization*>(
- weights_time->quantization.params);
- const auto* output_params = static_cast<const TfLiteAffineQuantization*>(
- output->quantization.params);
- const double effective_scale_1 = static_cast<double>(
- input_params->scale->data[0] * weights_feature_params->scale->data[0] /
- state_params->scale->data[0]);
- const double effective_scale_2 = static_cast<double>(
- state_params->scale->data[0] * weight_time_params->scale->data[0] /
- output_params->scale->data[0]);
-
- TFLITE_DCHECK(node->user_data != nullptr);
- OpData* data = static_cast<OpData*>(node->user_data);
-
- QuantizeMultiplier(effective_scale_1, &(data->effective_scale_1_a),
- &(data->effective_scale_1_b));
- QuantizeMultiplier(effective_scale_2, &(data->effective_scale_2_a),
- &(data->effective_scale_2_b));
-
- TFLITE_DCHECK(context->RequestScratchBufferInArena != nullptr);
-
- const TfLiteStatus scratch_status = context->RequestScratchBufferInArena(
- context, batch_size * num_filters * sizeof(int32_t),
- &(data->scratch_tensor_index));
- TF_LITE_ENSURE_OK(context, scratch_status);
-
- const TfLiteStatus scratch_output_status =
- context->RequestScratchBufferInArena(
- context, batch_size * num_units * sizeof(int32_t),
- &(data->scratch_output_tensor_index));
- TF_LITE_ENSURE_OK(context, scratch_output_status);
- } else {
- TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteFloat32);
- TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteFloat32);
- TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteFloat32);
- if (bias != nullptr) {
- TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteFloat32);
- }
- TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
-
- TFLITE_DCHECK(node->user_data != nullptr);
- OpData* data = static_cast<OpData*>(node->user_data);
-
- TFLITE_DCHECK(context->RequestScratchBufferInArena != nullptr);
- const TfLiteStatus scratch_status = context->RequestScratchBufferInArena(
- context, batch_size * num_filters * sizeof(float),
- &(data->scratch_tensor_index));
- TF_LITE_ENSURE_OK(context, scratch_status);
- }
-
- return kTfLiteOk;
-}
-
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);
-
- const TfLiteTensor* input = GetInput(context, node, kInputTensor);
- const TfLiteTensor* weights_feature =
- GetInput(context, node, kWeightsFeatureTensor);
- const TfLiteTensor* weights_time =
- GetInput(context, node, kWeightsTimeTensor);
- const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
- TfLiteTensor* activation_state =
- GetVariableInput(context, node, kInputActivationStateTensor);
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-
- TFLITE_DCHECK(node->user_data != nullptr);
- const OpData& data = *(static_cast<const OpData*>(node->user_data));
-
- switch (weights_feature->type) {
- case kTfLiteFloat32: {
- return EvalFloatSVDF(context, node, input, weights_feature, weights_time,
- bias, params, data.scratch_tensor_index,
- activation_state, output);
- break;
- }
-
- case kTfLiteInt8: {
- TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActRelu);
-
- EvalIntegerSVDF(context, node, input, weights_feature, weights_time, bias,
- params, activation_state, output, data,
- input->params.zero_point, output->params.zero_point);
- return kTfLiteOk;
- break;
- }
-
- default:
- TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.",
- TfLiteTypeGetName(weights_feature->type));
- return kTfLiteError;
- }
- return kTfLiteOk;
-}
-
-} // namespace svdf
-
-TfLiteRegistration Register_SVDF() {
- return {/*init=*/svdf::Init,
- /*free=*/nullptr,
- /*prepare=*/svdf::Prepare,
- /*invoke=*/svdf::Eval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
-}
-
-} // namespace micro
-} // namespace ops
-} // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/xtensa_tf_micro_common.h b/tensorflow/lite/micro/kernels/xtensa_hifi/xtensa_tf_micro_common.h
deleted file mode 100755
index 6fe6bae..0000000
--- a/tensorflow/lite/micro/kernels/xtensa_hifi/xtensa_tf_micro_common.h
+++ /dev/null
@@ -1,59 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef __XTENSA_TF_MICRO_COMMON__
-#define __XTENSA_TF_MICRO_COMMON__
-
-#include "xa_nnlib_api.h"
-#include "xa_nnlib_standards.h"
-
-#define CHECK_ERR_HIFI_NNLIB_KER(ret, err_msg) \
- if (ret != 0) { \
- TF_LITE_KERNEL_LOG(context, err_msg); \
- return kTfLiteError; \
- }
-
-#ifndef XTENSA_NNLIB_MAX_SCRATCH_SIZE
-#define XTENSA_NNLIB_MAX_SCRATCH_SIZE (70 * 1024)
-#endif
-
-#define ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM \
- uint8_t xtensa_nnlib_scratch_buf[XTENSA_NNLIB_MAX_SCRATCH_SIZE];
-
-#define MIN(a, b) (a) < (b) ? (a) : (b);
-#define MAX(a, b) (a) > (b) ? (a) : (b);
-
-#define ACTIVATION_MIN_MAX(data_type, out, inp, min, max) \
- { \
- data_type temp = MAX(inp, min); \
- out = MIN(temp, max); \
- }
-
-#define ACTIVATION_MIN_MAX_F32(out, inp, min, max) \
- { \
- float temp = MAX(inp, min); \
- out = MIN(temp, max); \
- }
-
-#define ACTIVATION_MIN_MAX_ASYM8(out, inp, min, max) \
- { \
- int32_t temp = MAX((int32_t)inp, min); \
- out = (uint8_t)MIN(temp, max); \
- }
-
-#define ALIGNED_SIZE(x, bytes) (((x) + (bytes - 1)) & (~(bytes - 1)))
-#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1)))
-
-#endif /* __XTENSA_TF_MICRO_COMMON__ */
diff --git a/tensorflow/lite/micro/memory_helpers.cc b/tensorflow/lite/micro/memory_helpers.cc
index 08bd9a8..2d8f759 100644
--- a/tensorflow/lite/micro/memory_helpers.cc
+++ b/tensorflow/lite/micro/memory_helpers.cc
@@ -63,6 +63,9 @@
case kTfLiteInt32:
*size = sizeof(int32_t);
break;
+ case kTfLiteUInt32:
+ *size = sizeof(uint32_t);
+ break;
case kTfLiteUInt8:
*size = sizeof(uint8_t);
break;
diff --git a/tensorflow/lite/micro/memory_helpers_test.cc b/tensorflow/lite/micro/memory_helpers_test.cc
index 5f28dea..230539c 100644
--- a/tensorflow/lite/micro/memory_helpers_test.cc
+++ b/tensorflow/lite/micro/memory_helpers_test.cc
@@ -137,6 +137,10 @@
TF_LITE_MICRO_EXPECT_EQ(sizeof(int32_t), size);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
+ tflite::TfLiteTypeSizeOf(kTfLiteUInt32, &size));
+ TF_LITE_MICRO_EXPECT_EQ(sizeof(uint32_t), size);
+
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
tflite::TfLiteTypeSizeOf(kTfLiteUInt8, &size));
TF_LITE_MICRO_EXPECT_EQ(sizeof(uint8_t), size);
diff --git a/tensorflow/lite/micro/memory_planner/BUILD b/tensorflow/lite/micro/memory_planner/BUILD
index e524e85..a190890 100644
--- a/tensorflow/lite/micro/memory_planner/BUILD
+++ b/tensorflow/lite/micro/memory_planner/BUILD
@@ -1,8 +1,4 @@
load(
- "//tensorflow/lite/micro/testing:micro_test.bzl",
- "tflite_micro_cc_test",
-)
-load(
"//tensorflow/lite/micro:build_def.bzl",
"micro_copts",
)
@@ -58,7 +54,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "linear_memory_planner_test",
srcs = [
"linear_memory_planner_test.cc",
@@ -69,7 +65,7 @@
],
)
-tflite_micro_cc_test(
+cc_test(
name = "greedy_memory_planner_test",
srcs = [
"greedy_memory_planner_test.cc",
diff --git a/tensorflow/lite/micro/sparkfun_edge/debug_log.cc b/tensorflow/lite/micro/sparkfun_edge/debug_log.cc
index 984d2a9..f1babc1 100644
--- a/tensorflow/lite/micro/sparkfun_edge/debug_log.cc
+++ b/tensorflow/lite/micro/sparkfun_edge/debug_log.cc
@@ -1,4 +1,4 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -12,24 +12,10 @@
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-
-// Implementation for the DebugLog() function that prints to the UART on the
-// SparkFun Edge microcontroller. The same should work for other targets using
-// the Ambiq Apollo 3.
-
-#include "tensorflow/lite/micro/debug_log.h"
-
-#include "am_bsp.h" // NOLINT
-#include "am_util.h" // NOLINT
-
-extern "C" void DebugLog(const char* s) {
-#ifndef TF_LITE_STRIP_ERROR_STRINGS
- static bool is_initialized = false;
- if (!is_initialized) {
- am_bsp_uart_printf_enable();
- is_initialized = true;
- }
-
- am_util_stdio_printf("%s", s);
-#endif
-}
+// This file is empty to ensure that a specialized implementation of
+// debug_log.h is used (instead of the default implementation from
+// tensorflow/lite/micro/debug_log.cc).
+//
+// The actual target-specific implementation of debug_log.h is in
+// system_setup.cc since that allows us to consolidate all the target-specific
+// specializations into one source file.
diff --git a/tensorflow/lite/micro/sparkfun_edge/micro_time.cc b/tensorflow/lite/micro/sparkfun_edge/micro_time.cc
index 9987a3b..a7db6e4 100644
--- a/tensorflow/lite/micro/sparkfun_edge/micro_time.cc
+++ b/tensorflow/lite/micro/sparkfun_edge/micro_time.cc
@@ -1,4 +1,4 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -12,91 +12,10 @@
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-
-// Reference implementation of timer functions. Platforms are not required to
-// implement these timer methods, but they are required to enable profiling.
-
-// On platforms that have a POSIX stack or C library, it can be written using
-// methods from <sys/time.h> or clock() from <time.h>.
-
-// To add an equivalent function for your own platform, create your own
-// implementation file, and place it in a subfolder with named after the OS
-// you're targeting. For example, see the Cortex M bare metal version in
-// tensorflow/lite/micro/bluepill/micro_timer.cc or the mbed one on
-// tensorflow/lite/micro/mbed/micro_timer.cc.
-
-#include "tensorflow/lite/micro/micro_time.h"
-
-#include "tensorflow/lite/micro/debug_log.h"
-
-// These are headers from Ambiq's Apollo3 SDK.
-#include "am_bsp.h" // NOLINT
-#include "am_mcu_apollo.h" // NOLINT
-#include "am_util.h" // NOLINT
-
-namespace tflite {
-namespace {
-
-// Select CTIMER 1 as benchmarking timer on Sparkfun Edge. This timer must not
-// be used elsewhere.
-constexpr int kTimerNum = 1;
-
-// Clock set to operate at 12MHz.
-constexpr int kClocksPerSecond = 12e6;
-
-// Enables 96MHz burst mode on Sparkfun Edge. Enable in timer since most
-// benchmarks and profilers want maximum performance for debugging.
-void BurstModeEnable() {
- am_hal_clkgen_control(AM_HAL_CLKGEN_CONTROL_SYSCLK_MAX, 0);
-
- // Set the default cache configuration
- am_hal_cachectrl_config(&am_hal_cachectrl_defaults);
- am_hal_cachectrl_enable();
-
- am_hal_burst_avail_e eBurstModeAvailable;
- am_hal_burst_mode_e eBurstMode;
-
- // Check that the Burst Feature is available.
- int status = am_hal_burst_mode_initialize(&eBurstModeAvailable);
- if (status != AM_HAL_STATUS_SUCCESS ||
- eBurstModeAvailable != AM_HAL_BURST_AVAIL) {
- DebugLog("Failed to initialize burst mode.");
- return;
- }
-
- status = am_hal_burst_mode_enable(&eBurstMode);
-
- if (status != AM_HAL_STATUS_SUCCESS || eBurstMode != AM_HAL_BURST_MODE) {
- DebugLog("Failed to Enable Burst Mode operation\n");
- }
-}
-
-} // namespace
-
-int32_t ticks_per_second() { return kClocksPerSecond; }
-
-// Calling this method enables a timer that runs for eternity. The user is
-// responsible for avoiding trampling on this timer's config, otherwise timing
-// measurements may no longer be valid.
-int32_t GetCurrentTimeTicks() {
- // TODO(b/150808076): Split out initialization, intialize in interpreter.
- static bool is_initialized = false;
- if (!is_initialized) {
- BurstModeEnable();
- am_hal_ctimer_config_t timer_config;
- // Operate as a 32-bit timer.
- timer_config.ui32Link = 1;
- // Set timer A to continuous mode at 12MHz.
- timer_config.ui32TimerAConfig =
- AM_HAL_CTIMER_FN_CONTINUOUS | AM_HAL_CTIMER_HFRC_12MHZ;
-
- am_hal_ctimer_stop(kTimerNum, AM_HAL_CTIMER_BOTH);
- am_hal_ctimer_clear(kTimerNum, AM_HAL_CTIMER_BOTH);
- am_hal_ctimer_config(kTimerNum, &timer_config);
- am_hal_ctimer_start(kTimerNum, AM_HAL_CTIMER_TIMERA);
- is_initialized = true;
- }
- return CTIMERn(kTimerNum)->TMR0;
-}
-
-} // namespace tflite
+// This file is empty to ensure that a specialized implementation of
+// micro_time.h is used (instead of the default implementation from
+// tensorflow/lite/micro/micro_time.cc).
+//
+// The actual target-specific implementation of micro_time.h is in
+// system_setup.cc since that allows us to consolidate all the target-specific
+// specializations into one source file.
diff --git a/tensorflow/lite/micro/sparkfun_edge/system_setup.cc b/tensorflow/lite/micro/sparkfun_edge/system_setup.cc
new file mode 100644
index 0000000..995a3bb
--- /dev/null
+++ b/tensorflow/lite/micro/sparkfun_edge/system_setup.cc
@@ -0,0 +1,99 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/micro/system_setup.h"
+
+#include "tensorflow/lite/micro/debug_log.h"
+#include "tensorflow/lite/micro/micro_time.h"
+
+// These are headers from Ambiq's Apollo3 SDK.
+#include "am_bsp.h" // NOLINT
+#include "am_mcu_apollo.h" // NOLINT
+#include "am_util.h" // NOLINT
+
+namespace {
+
+// Select CTIMER 1 as benchmarking timer on Sparkfun Edge. This timer must not
+// be used elsewhere.
+constexpr int kTimerNum = 1;
+
+// Clock set to operate at 12MHz.
+constexpr int kClocksPerSecond = 12e6;
+
+// Enables 96MHz burst mode on Sparkfun Edge. Enable in timer since most
+// benchmarks and profilers want maximum performance for debugging.
+void BurstModeEnable() {
+ am_hal_clkgen_control(AM_HAL_CLKGEN_CONTROL_SYSCLK_MAX, 0);
+
+ // Set the default cache configuration
+ am_hal_cachectrl_config(&am_hal_cachectrl_defaults);
+ am_hal_cachectrl_enable();
+
+ am_hal_burst_avail_e eBurstModeAvailable;
+ am_hal_burst_mode_e eBurstMode;
+
+ // Check that the Burst Feature is available.
+ int status = am_hal_burst_mode_initialize(&eBurstModeAvailable);
+ if (status != AM_HAL_STATUS_SUCCESS ||
+ eBurstModeAvailable != AM_HAL_BURST_AVAIL) {
+ DebugLog("Failed to initialize burst mode.\n");
+ return;
+ }
+
+ status = am_hal_burst_mode_enable(&eBurstMode);
+
+ if (status != AM_HAL_STATUS_SUCCESS || eBurstMode != AM_HAL_BURST_MODE) {
+ DebugLog("Failed to Enable Burst Mode operation\n");
+ }
+}
+
+} // namespace
+
+// Implementation for the DebugLog() function that prints to the UART on the
+// SparkFun Edge microcontroller. The same should work for other targets using
+// the Ambiq Apollo 3.
+extern "C" void DebugLog(const char* s) {
+#ifndef TF_LITE_STRIP_ERROR_STRINGS
+ am_util_stdio_printf("%s", s);
+#endif
+}
+
+namespace tflite {
+
+// Calling this method enables a timer that runs for eternity. The user is
+// responsible for avoiding trampling on this timer's config, otherwise timing
+// measurements may no longer be valid.
+void InitializeTarget() {
+ am_bsp_uart_printf_enable();
+
+ BurstModeEnable();
+ am_hal_ctimer_config_t timer_config;
+ // Operate as a 32-bit timer.
+ timer_config.ui32Link = 1;
+ // Set timer A to continuous mode at 12MHz.
+ timer_config.ui32TimerAConfig =
+ AM_HAL_CTIMER_FN_CONTINUOUS | AM_HAL_CTIMER_HFRC_12MHZ;
+
+ am_hal_ctimer_stop(kTimerNum, AM_HAL_CTIMER_BOTH);
+ am_hal_ctimer_clear(kTimerNum, AM_HAL_CTIMER_BOTH);
+ am_hal_ctimer_config(kTimerNum, &timer_config);
+ am_hal_ctimer_start(kTimerNum, AM_HAL_CTIMER_TIMERA);
+}
+
+int32_t ticks_per_second() { return kClocksPerSecond; }
+
+int32_t GetCurrentTimeTicks() { return CTIMERn(kTimerNum)->TMR0; }
+
+} // namespace tflite
diff --git a/tensorflow/lite/micro/system_setup.cc b/tensorflow/lite/micro/system_setup.cc
new file mode 100644
index 0000000..db4a100
--- /dev/null
+++ b/tensorflow/lite/micro/system_setup.cc
@@ -0,0 +1,25 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/micro/system_setup.h"
+
+namespace tflite {
+
+// To add an equivalent function for your own platform, create your own
+// implementation file, and place it in a subfolder named after the target. See
+// tensorflow/lite/micro/debug_log.cc for a similar example.
+void InitializeTarget() {}
+
+} // namespace tflite
diff --git a/tensorflow/lite/micro/system_setup.h b/tensorflow/lite/micro/system_setup.h
new file mode 100644
index 0000000..71ab13a
--- /dev/null
+++ b/tensorflow/lite/micro/system_setup.h
@@ -0,0 +1,27 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_MICRO_SYSTEM_SETUP_H_
+#define TENSORFLOW_LITE_MICRO_SYSTEM_SETUP_H_
+
+namespace tflite {
+
+// This should called during initialization of TFLM binaries and tests. It can
+// be specialized if there is a need for custom target-specific intialization.
+// For more information, see tensorflow/lite/micro/system_setup.cc.
+void InitializeTarget();
+
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_MICRO_SYSTEM_SETUP_H_
diff --git a/tensorflow/lite/micro/testing/BUILD b/tensorflow/lite/micro/testing/BUILD
index d07e965..335953d 100644
--- a/tensorflow/lite/micro/testing/BUILD
+++ b/tensorflow/lite/micro/testing/BUILD
@@ -1,9 +1,3 @@
-load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
-load(
- "//tensorflow/lite/micro/testing:micro_test.bzl",
- "tflite_micro_cc_test",
-)
-
package(
features = ["-layering_check"],
licenses = ["notice"], # Apache 2.0
@@ -36,11 +30,12 @@
"//tensorflow/lite/micro:micro_error_reporter",
"//tensorflow/lite/micro:micro_framework",
"//tensorflow/lite/micro:micro_utils",
+ "//tensorflow/lite/micro:system_setup",
"//tensorflow/lite/micro:test_helpers",
],
)
-tflite_micro_cc_test(
+cc_test(
name = "util_test",
srcs = [
"util_test.cc",
@@ -79,10 +74,3 @@
"@absl_py//absl:app",
],
)
-
-bzl_library(
- name = "micro_test_bzl",
- srcs = ["micro_test.bzl"],
- visibility = ["//visibility:private"],
- deps = ["//tensorflow/lite/micro:build_def_bzl"],
-)
diff --git a/tensorflow/lite/micro/testing/micro_test.bzl b/tensorflow/lite/micro/testing/micro_test.bzl
deleted file mode 100644
index 5e1a56f..0000000
--- a/tensorflow/lite/micro/testing/micro_test.bzl
+++ /dev/null
@@ -1,73 +0,0 @@
-"""Rules for simple testing without dependencies by parsing output logs."""
-
-load(
- "//tensorflow/lite/micro:build_def.bzl",
- "micro_copts",
-)
-
-def tflite_micro_cc_test(
- name,
- size = "medium",
- expected_in_logs = "~~~ALL TESTS PASSED~~~",
- srcs = [],
- includes = [],
- defines = [],
- copts = micro_copts(),
- nocopts = "",
- linkopts = [],
- deps = [],
- tags = [],
- visibility = None):
- """Tests a C/C++ binary without testing framework dependencies`.
-
- Runs a C++ binary, and tests that the output logs contain the
- expected value. This is a deliberately spartan way of testing, to match
- what's available when testing microcontroller binaries.
-
- Args:
- name: a unique name for this rule.
- expected_in_logs: A regular expression that is required to be
- present in the binary's logs for the test to pass.
- srcs: sources to compile (C, C++, ld scripts).
- includes: include paths to add to this rule and its dependents.
- defines: list of `VAR` or `VAR=VAL` to pass to CPP for this rule and
- its dependents.
- copts: gcc compilation flags for this rule only.
- nocopts: list of gcc compilation flags to remove for this rule
- only. No regexp like for `cc_library`.
- linkopts: `gcc` flags to add to the linking phase. For "pure" ld flags,
- prefix them with the `-Wl,` prefix here.
- deps: dependencies. only `tflite_bare_metal_cc_library()` dependencies
- allowed.
- visibility: visibility.
- """
- native.cc_binary(
- name = name + "_binary",
- srcs = srcs,
- includes = includes,
- defines = defines,
- copts = copts,
- nocopts = nocopts,
- linkopts = linkopts,
- deps = deps,
- tags = tags,
- visibility = visibility,
- )
- native.sh_test(
- name = name,
- size = size,
- srcs = [
- "//tensorflow/lite/micro/testing:test_linux_binary.sh",
- ],
- args = [
- native.package_name() + "/" + name + "_binary",
- "'" + expected_in_logs + "'",
- ],
- data = [
- name + "_binary",
- # Internal test dependency placeholder
- ],
- deps = [
- ],
- tags = tags,
- )
diff --git a/tensorflow/lite/micro/testing/micro_test.h b/tensorflow/lite/micro/testing/micro_test.h
index b751876..229dfa6 100644
--- a/tensorflow/lite/micro/testing/micro_test.h
+++ b/tensorflow/lite/micro/testing/micro_test.h
@@ -56,6 +56,7 @@
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
+#include "tensorflow/lite/micro/system_setup.h"
namespace micro_test {
extern int tests_passed;
@@ -64,6 +65,19 @@
extern bool did_test_fail;
} // namespace micro_test
+namespace tflite {
+
+// This additional helper function is used (instead of directly calling
+// tflite::InitializeTarget from the TF_LITE_MICRO_TESTS_BEGIN macro) to avoid
+// adding a dependency from every bazel test target to micro:system_setp (which
+// is the target that implements InitializeTarget().
+//
+// The underlying issue here is that the use of the macros results in
+// dependencies that can be containted within the micro/testing:micro_test
+// target bleeding on to all the tests.
+inline void InitializeTest() { InitializeTarget(); }
+} // namespace tflite
+
#define TF_LITE_MICRO_TESTS_BEGIN \
namespace micro_test { \
int tests_passed; \
@@ -74,7 +88,8 @@
\
int main(int argc, char** argv) { \
micro_test::tests_passed = 0; \
- micro_test::tests_failed = 0;
+ micro_test::tests_failed = 0; \
+ tflite::InitializeTest();
#define TF_LITE_MICRO_TESTS_END \
MicroPrintf("%d/%d tests passed", micro_test::tests_passed, \
diff --git a/tensorflow/lite/micro/testing/test_hexagon_binary.sh b/tensorflow/lite/micro/testing/test_hexagon_binary.sh
index a3ea244..98b3c50 100755
--- a/tensorflow/lite/micro/testing/test_hexagon_binary.sh
+++ b/tensorflow/lite/micro/testing/test_hexagon_binary.sh
@@ -20,7 +20,6 @@
# Second argument is a regular expression that's required to be in the output
# logs for the test to pass.
-declare -r ROOT_DIR=`pwd`
declare -r TEST_TMPDIR=/tmp/test_hexagon_binary/
declare -r MICRO_LOG_PATH=${TEST_TMPDIR}/$1
declare -r MICRO_LOG_FILENAME=${MICRO_LOG_PATH}/logs.txt
@@ -29,11 +28,14 @@
hexagon-elfcopy $1 $1.elf
hexagon-sim $1.elf 2>&1 | tee ${MICRO_LOG_FILENAME}
-if grep -q "$2" ${MICRO_LOG_FILENAME}
+if [[ ${2} != "non_test_binary" ]]
then
- echo "$1: PASS"
- exit 0
-else
- echo "$1: FAIL - '$2' not found in logs."
- exit 1
+ if grep -q "$2" ${MICRO_LOG_FILENAME}
+ then
+ echo "$1: PASS"
+ exit 0
+ else
+ echo "$1: FAIL - '$2' not found in logs."
+ exit 1
+ fi
fi
diff --git a/tensorflow/lite/micro/testing/test_linux_binary.sh b/tensorflow/lite/micro/testing/test_linux_binary.sh
deleted file mode 100755
index 30cf041..0000000
--- a/tensorflow/lite/micro/testing/test_linux_binary.sh
+++ /dev/null
@@ -1,55 +0,0 @@
-#!/bin/bash
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-#
-# Tests a Linux binary by parsing the log output.
-#
-# First argument is the binary location.
-# Second argument is a regular expression that's required to be in the output logs
-# for the test to pass.
-
-declare -r ROOT_DIR=`pwd`
-declare -r TEST_TMPDIR=/tmp/test_linux_binary/
-declare -r MICRO_LOG_PATH=${TEST_TMPDIR}/$1
-declare -r MICRO_LOG_FILENAME=${MICRO_LOG_PATH}/logs.txt
-mkdir -p ${MICRO_LOG_PATH}
-
-ERROR_MSG="$1: FAIL - '$2' not found in logs."
-print_error_and_exit() {
- echo ${ERROR_MSG}
- cat ${MICRO_LOG_FILENAME}
- exit 1
-}
-
-# This traps the signal from the test binary ($1) and checks if there was a
-# segfault and adds that to the error log (which would otherwise be missing).
-trap 'if [[ $? -eq 139 ]]; then echo "Segmentation fault" >> ${MICRO_LOG_FILENAME}; print_error_and_exit; fi' CHLD
-
-# This trap statement prevents the bash script from segfaulting with a cryptic
-# message like:
-# tensorflow/lite/micro/testing/test_linux_binary.sh: line 44: 210514 Segmentation fault $1 > ${MICRO_LOG_FILENAME} 2>&1
-# What we get instead is purely another Segmentation fault text in the output.
-trap '' SEGV
-
-$1 > ${MICRO_LOG_FILENAME} 2>&1
-
-if grep -q "$2" ${MICRO_LOG_FILENAME}
-then
- echo "$1: PASS"
- exit 0
-else
- print_error_and_exit
-fi
-
diff --git a/tensorflow/lite/micro/testing/test_with_arm_corstone_300.sh b/tensorflow/lite/micro/testing/test_with_arm_corstone_300.sh
new file mode 100755
index 0000000..c5293e5
--- /dev/null
+++ b/tensorflow/lite/micro/testing/test_with_arm_corstone_300.sh
@@ -0,0 +1,48 @@
+#!/bin/bash -e
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+#
+# Parameters:
+# ${1} - path to a binary to test or directory (all *_test will be run).
+# ${2} - String that is checked for pass/fail.
+# ${3} - target (e.g. cortex_m_generic.)
+
+set -e
+
+BINARY_TO_TEST=${1}
+PASS_STRING=${2}
+TARGET=${3}
+
+RESULTS_DIRECTORY=/tmp/${TARGET}_logs
+MICRO_LOG_FILENAME=${RESULTS_DIRECTORY}/logs.txt
+mkdir -p ${RESULTS_DIRECTORY}
+
+FVP="FVP_Corstone_SSE-300_Ethos-U55 "
+FVP+="--cpulimit 1 "
+FVP+="-C mps3_board.visualisation.disable-visualisation=1 "
+FVP+="-C mps3_board.telnetterminal0.start_telnet=0 "
+FVP+='-C mps3_board.uart0.out_file="-" '
+FVP+='-C mps3_board.uart0.unbuffered_output=1'
+${FVP} ${BINARY_TO_TEST} | tee ${MICRO_LOG_FILENAME}
+
+if grep -q "$PASS_STRING" ${MICRO_LOG_FILENAME}
+then
+ echo "$BINARY_TO_TEST: PASS"
+ exit 0
+else
+ echo "$BINARY_TO_TEST: FAIL - '$PASS_STRING' not found in logs."
+ exit 1
+fi
diff --git a/tensorflow/lite/micro/testing/test_xtensa_binary.sh b/tensorflow/lite/micro/testing/test_xtensa_binary.sh
index fb9ca9c..9141d2f 100755
--- a/tensorflow/lite/micro/testing/test_xtensa_binary.sh
+++ b/tensorflow/lite/micro/testing/test_xtensa_binary.sh
@@ -21,7 +21,6 @@
# Second argument is a regular expression that's required to be in the output
# logs for the test to pass.
-declare -r ROOT_DIR=`pwd`
declare -r TEST_TMPDIR=/tmp/test_xtensa_binary/
declare -r MICRO_LOG_PATH=${TEST_TMPDIR}/$1
declare -r MICRO_LOG_FILENAME=${MICRO_LOG_PATH}/logs.txt
@@ -29,11 +28,13 @@
xt-run $1 2>&1 | tee ${MICRO_LOG_FILENAME}
-if grep -q "$2" ${MICRO_LOG_FILENAME}
+if [[ ${2} != "non_test_binary" ]]
then
- echo "$1: PASS"
- exit 0
-else
- echo "$1: FAIL - '$2' not found in logs."
- exit 1
+ if grep -q "$2" ${MICRO_LOG_FILENAME}
+ then
+ exit 0
+ else
+ exit 1
+ fi
fi
+
diff --git a/tensorflow/lite/micro/testing/test_xtensa_hifi_binary.sh b/tensorflow/lite/micro/testing/test_xtensa_hifi_binary.sh
deleted file mode 100755
index 403b39f..0000000
--- a/tensorflow/lite/micro/testing/test_xtensa_hifi_binary.sh
+++ /dev/null
@@ -1,38 +0,0 @@
-#!/bin/bash -e
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-# Tests an Xtensa binary by parsing the log output.
-#
-# First argument is the binary location.
-# Second argument is a regular expression that's required to be in the output
-# logs for the test to pass.
-
-declare -r ROOT_DIR=`pwd`
-declare -r TEST_TMPDIR=/tmp/test_xtensa_hifi_binary/
-declare -r MICRO_LOG_PATH=${TEST_TMPDIR}/$1
-declare -r MICRO_LOG_FILENAME=${MICRO_LOG_PATH}/logs.txt
-mkdir -p ${MICRO_LOG_PATH}
-
-xt-run $1 2>&1 | tee ${MICRO_LOG_FILENAME}
-
-if grep -q "$2" ${MICRO_LOG_FILENAME}
-then
- echo "$1: PASS"
- exit 0
-else
- echo "$1: FAIL - '$2' not found in logs."
- exit 1
-fi
diff --git a/tensorflow/lite/micro/testing/test_xtensa_hifimini_binary.sh b/tensorflow/lite/micro/testing/test_xtensa_hifimini_binary.sh
deleted file mode 100755
index 3272562..0000000
--- a/tensorflow/lite/micro/testing/test_xtensa_hifimini_binary.sh
+++ /dev/null
@@ -1,38 +0,0 @@
-#!/bin/bash -e
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-#
-# Tests an Xtensa XPG binary by parsing the log output.
-#
-# First argument is the binary location.
-# Second argument is a regular expression that's required to be in the output
-# logs for the test to pass.
-
-declare -r ROOT_DIR=`pwd`
-declare -r TEST_TMPDIR=/tmp/test_xtensa_hifimini_binary/
-declare -r MICRO_LOG_PATH=${TEST_TMPDIR}/$1
-declare -r MICRO_LOG_FILENAME=${MICRO_LOG_PATH}/logs.txt
-mkdir -p ${MICRO_LOG_PATH}
-
-xt-run --xtensa-core=${XTENSA_CORE} $1 2>&1 | tee ${MICRO_LOG_FILENAME}
-
-if grep -q "$2" ${MICRO_LOG_FILENAME}
-then
- echo "$1: PASS"
- exit 0
-else
- echo "$1: FAIL - '$2' not found in logs."
- exit 1
-fi
diff --git a/tensorflow/lite/micro/testing/test_xtensa_hifimini_staging_binary.sh b/tensorflow/lite/micro/testing/test_xtensa_hifimini_staging_binary.sh
deleted file mode 100755
index 1844f09..0000000
--- a/tensorflow/lite/micro/testing/test_xtensa_hifimini_staging_binary.sh
+++ /dev/null
@@ -1,38 +0,0 @@
-#!/bin/bash -e
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-#
-# Tests an Xtensa XPG binary by parsing the log output.
-#
-# First argument is the binary location.
-# Second argument is a regular expression that's required to be in the output
-# logs for the test to pass.
-
-declare -r ROOT_DIR=`pwd`
-declare -r TEST_TMPDIR=/tmp/test_xtensa_hifimini_staging_binary/
-declare -r MICRO_LOG_PATH=${TEST_TMPDIR}/$1
-declare -r MICRO_LOG_FILENAME=${MICRO_LOG_PATH}/logs.txt
-mkdir -p ${MICRO_LOG_PATH}
-
-xt-run --xtensa-core=${XTENSA_CORE} $1 2>&1 | tee ${MICRO_LOG_FILENAME}
-
-if grep -q "$2" ${MICRO_LOG_FILENAME}
-then
- echo "$1: PASS"
- exit 0
-else
- echo "$1: FAIL - '$2' not found in logs."
- exit 1
-fi
diff --git a/tensorflow/lite/micro/tools/ci_build/test_bazel.sh b/tensorflow/lite/micro/tools/ci_build/test_bazel.sh
index 77bd919..92732b7 100755
--- a/tensorflow/lite/micro/tools/ci_build/test_bazel.sh
+++ b/tensorflow/lite/micro/tools/ci_build/test_bazel.sh
@@ -48,9 +48,25 @@
# Now that we are set up to download fewer external deps as part of a bazel
# build, we can go ahead and invoke bazel.
-CC=clang readable_run bazel test tensorflow/lite/micro/... --test_tag_filters=-no_oss --build_tag_filters=-no_oss
-CC=clang readable_run bazel test tensorflow/lite/micro/... --config=msan --test_tag_filters=-no_oss,-nomsan --build_tag_filters=-no_oss,-nomsan
-CC=clang readable_run bazel test tensorflow/lite/micro/... --config=asan --test_tag_filters=-no_oss,-noasan --build_tag_filters=-no_oss,-noasan
+CC=clang readable_run bazel test tensorflow/lite/micro/... \
+ --test_tag_filters=-no_oss --build_tag_filters=-no_oss \
+ --test_output=errors
+
+CC=clang readable_run bazel test tensorflow/lite/micro/... \
+ --config=msan \
+ --test_tag_filters=-no_oss,-nomsan --build_tag_filters=-no_oss,-nomsan \
+ --test_output=errors
+
+CC=clang readable_run bazel test tensorflow/lite/micro/... \
+ --config=asan \
+ --test_tag_filters=-no_oss,-noasan --build_tag_filters=-no_oss,-noasan \
+ --test_output=errors
+
# TODO(b/178621680): enable ubsan once bazel + clang + ubsan errors are fixed.
#CC=clang readable_run bazel test tensorflow/lite/micro/... --config=ubsan --test_tag_filters=-no_oss,-noubsan --build_tag_filters=-no_oss,-noubsan
-CC=clang readable_run bazel test tensorflow/lite/micro/... --test_tag_filters=-no_oss --build_tag_filters=-no_oss --copt=-DTF_LITE_STATIC_MEMORY
+
+CC=clang readable_run bazel test tensorflow/lite/micro/... \
+ --test_tag_filters=-no_oss --build_tag_filters=-no_oss \
+ --copt=-DTF_LITE_STATIC_MEMORY \
+ --test_output=errors
+
diff --git a/tensorflow/lite/micro/tools/ci_build/test_bluepill.sh b/tensorflow/lite/micro/tools/ci_build/test_bluepill.sh
index 90b52f2..5f5d7c1 100755
--- a/tensorflow/lite/micro/tools/ci_build/test_bluepill.sh
+++ b/tensorflow/lite/micro/tools/ci_build/test_bluepill.sh
@@ -40,3 +40,8 @@
# debugging info on failures.
readable_run make -f tensorflow/lite/micro/tools/make/Makefile clean
readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile TARGET=${TARGET} OPTIMIZATION_LEVEL=-Os test
+
+# We use Renode differently when running the full test suite (make test) vs an
+# individual test. So, we test only of the kernels individually as well to have
+# both of the Renode variations be part of the CI.
+readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile TARGET=${TARGET} test_kernel_add_test
diff --git a/tensorflow/lite/micro/tools/ci_build/test_cortex_m_corstone_300.sh b/tensorflow/lite/micro/tools/ci_build/test_cortex_m_corstone_300.sh
new file mode 100755
index 0000000..6a0c817
--- /dev/null
+++ b/tensorflow/lite/micro/tools/ci_build/test_cortex_m_corstone_300.sh
@@ -0,0 +1,37 @@
+#!/usr/bin/env bash
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Tests Arm Cortex-M55 microprocessor code with CMSIS-NN optimizied kernels using FVP based on Arm Corstone-300 software.
+
+set -e
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+ROOT_DIR=${SCRIPT_DIR}/../../../../..
+cd "${ROOT_DIR}"
+
+source tensorflow/lite/micro/tools/ci_build/helper_functions.sh
+
+TARGET=cortex_m_corstone_300
+TARGET_ARCH=cortex-m55
+OPTIMIZED_KERNEL_DIR=cmsis_nn
+
+# TODO(b/143715361): downloading first to allow for parallel builds.
+readable_run make -f tensorflow/lite/micro/tools/make/Makefile OPTIMIZED_KERNEL_DIR=${OPTIMIZED_KERNEL_DIR} TARGET=${TARGET} TARGET_ARCH=${TARGET_ARCH} third_party_downloads
+
+# Avoid running tests in parallel.
+readable_run make -f tensorflow/lite/micro/tools/make/Makefile clean
+readable_run make -j -f tensorflow/lite/micro/tools/make/Makefile OPTIMIZED_KERNEL_DIR=${OPTIMIZED_KERNEL_DIR} TARGET=${TARGET} TARGET_ARCH=${TARGET_ARCH} build
+readable_run make -f tensorflow/lite/micro/tools/make/Makefile OPTIMIZED_KERNEL_DIR=${OPTIMIZED_KERNEL_DIR} TARGET=${TARGET} TARGET_ARCH=${TARGET_ARCH} test
diff --git a/tensorflow/lite/micro/tools/ci_build/tflm_bazel/workspace.bzl b/tensorflow/lite/micro/tools/ci_build/tflm_bazel/workspace.bzl
index 5ad6d75..3484d6f 100644
--- a/tensorflow/lite/micro/tools/ci_build/tflm_bazel/workspace.bzl
+++ b/tensorflow/lite/micro/tools/ci_build/tflm_bazel/workspace.bzl
@@ -58,12 +58,11 @@
tf_http_archive(
name = "eigen_archive",
build_file = clean_dep("//third_party:eigen.BUILD"),
- patch_file = clean_dep("//third_party/eigen3:gpu_packet_math.patch"),
- sha256 = "768b744d98505db4d73562b7813ee1e102dd185cf79a7ef1d5dbcc6e7e918eaf", # SHARED_EIGEN_SHA
- strip_prefix = "eigen-352f1422d3ceea19a04cab297c6339e0870e1c6c",
+ sha256 = "d76992f1972e4ff270221c7ee8125610a8e02bb46708a7295ee646e99287083b", # SHARED_EIGEN_SHA
+ strip_prefix = "eigen-90ee821c563fa20db4d64d6991ddca256d5c52f2",
urls = [
- "https://storage.googleapis.com/mirror.tensorflow.org/gitlab.com/libeigen/eigen/-/archive/352f1422d3ceea19a04cab297c6339e0870e1c6c/eigen-352f1422d3ceea19a04cab297c6339e0870e1c6c.tar.gz",
- "https://gitlab.com/libeigen/eigen/-/archive/352f1422d3ceea19a04cab297c6339e0870e1c6c/eigen-352f1422d3ceea19a04cab297c6339e0870e1c6c.tar.gz",
+ "https://storage.googleapis.com/mirror.tensorflow.org/gitlab.com/libeigen/eigen/-/archive/90ee821c563fa20db4d64d6991ddca256d5c52f2/eigen-90ee821c563fa20db4d64d6991ddca256d5c52f2.tar.gz",
+ "https://gitlab.com/libeigen/eigen/-/archive/90ee821c563fa20db4d64d6991ddca256d5c52f2/eigen-90ee821c563fa20db4d64d6991ddca256d5c52f2.tar.gz",
],
)
diff --git a/tensorflow/lite/micro/tools/make/Makefile b/tensorflow/lite/micro/tools/make/Makefile
index 4820f9d..5ee2586 100644
--- a/tensorflow/lite/micro/tools/make/Makefile
+++ b/tensorflow/lite/micro/tools/make/Makefile
@@ -93,7 +93,7 @@
third_party/flatbuffers/include \
third_party/ruy
-TEST_SCRIPT := tensorflow/lite/micro/testing/test_linux_binary.sh
+TEST_SCRIPT :=
MICROLITE_LIBS := -lm
@@ -316,6 +316,7 @@
tensorflow/lite/micro/kernels/comparisons.cc \
tensorflow/lite/micro/kernels/concatenation.cc \
tensorflow/lite/micro/kernels/conv.cc \
+tensorflow/lite/micro/kernels/conv_common.cc \
tensorflow/lite/micro/kernels/conv_test_common.cc \
tensorflow/lite/micro/kernels/depthwise_conv.cc \
tensorflow/lite/micro/kernels/dequantize.cc \
diff --git a/tensorflow/lite/micro/tools/make/corstone_300_download.sh b/tensorflow/lite/micro/tools/make/corstone_300_download.sh
new file mode 100755
index 0000000..4ac60bb
--- /dev/null
+++ b/tensorflow/lite/micro/tools/make/corstone_300_download.sh
@@ -0,0 +1,70 @@
+#!/bin/bash
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Called with following arguments:
+# 1 - Path to the downloads folder which is typically
+# tensorflow/lite/micro/tools/make/downloads
+#
+# This script is called from the Makefile and uses the following convention to
+# enable determination of sucess/failure:
+#
+# - If the script is successful, the only output on stdout should be SUCCESS.
+# The makefile checks for this particular string.
+#
+# - Any string on stdout that is not SUCCESS will be shown in the makefile as
+# the cause for the script to have failed.
+#
+# - Any other informational prints should be on stderr.
+
+set -e
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+ROOT_DIR=${SCRIPT_DIR}/../../../../..
+cd "${ROOT_DIR}"
+
+source tensorflow/lite/micro/tools/make/bash_helpers.sh
+
+DOWNLOADS_DIR=${1}
+if [ ! -d ${DOWNLOADS_DIR} ]; then
+ echo "The top-level downloads directory: ${DOWNLOADS_DIR} does not exist."
+ exit 1
+fi
+
+DOWNLOADED_CORSTONE_PATH=${DOWNLOADS_DIR}/corstone300
+
+if [ -d ${DOWNLOADED_CORSTONE_PATH} ]; then
+ echo >&2 "${DOWNLOADED_CORSTONE_PATH} already exists, skipping the download."
+else
+ UNAME_S=`uname -s`
+ if [ ${UNAME_S} == Linux ]; then
+ CORSTONE_URL=https://developer.arm.com/-/media/Arm%20Developer%20Community/Downloads/OSS/FVP/Corstone-300/FVP_Corstone_SSE-300_Ethos-U55_11.12_57.tgz
+ EXPECTED_MD5=08cc89b02a41917c2224f390f3ac0b47
+ else
+ echo "OS type ${UNAME_S} not supported."
+ exit 1
+ fi
+
+ TEMPFILE=$(mktemp -d)/temp_file
+ wget ${CORSTONE_URL} -O ${TEMPFILE} >&2
+ check_md5 ${TEMPFILE} ${EXPECTED_MD5}
+
+ TEMPDIR=$(mktemp -d)
+ tar -C ${TEMPDIR} -xvzf ${TEMPFILE} >&2
+ mkdir ${DOWNLOADED_CORSTONE_PATH}
+ ${TEMPDIR}/FVP_Corstone_SSE-300_Ethos-U55.sh --i-agree-to-the-contained-eula --no-interactive -d ${DOWNLOADED_CORSTONE_PATH} >&2
+fi
+
+echo "SUCCESS"
diff --git a/tensorflow/lite/micro/tools/make/ethos_u_core_platform_download.sh b/tensorflow/lite/micro/tools/make/ethos_u_core_platform_download.sh
new file mode 100755
index 0000000..d00800a
--- /dev/null
+++ b/tensorflow/lite/micro/tools/make/ethos_u_core_platform_download.sh
@@ -0,0 +1,80 @@
+#!/bin/bash
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Called with following arguments:
+# 1 - Path to the downloads folder which is typically
+# tensorflow/lite/micro/tools/make/downloads
+#
+# This script is called from the Makefile and uses the following convention to
+# enable determination of sucess/failure:
+#
+# - If the script is successful, the only output on stdout should be SUCCESS.
+# The makefile checks for this particular string.
+#
+# - Any string on stdout that is not SUCCESS will be shown in the makefile as
+# the cause for the script to have failed.
+#
+# - Any other informational prints should be on stderr.
+
+set -e
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+ROOT_DIR=${SCRIPT_DIR}/../../../../..
+cd "${ROOT_DIR}"
+
+source tensorflow/lite/micro/tools/make/bash_helpers.sh
+
+DOWNLOADS_DIR=${1}
+if [ ! -d ${DOWNLOADS_DIR} ]; then
+ echo "The top-level downloads directory: ${DOWNLOADS_DIR} does not exist."
+ exit 1
+fi
+
+DOWNLOADED_ETHOS_U_CORE_PLATFORM_PATH=${DOWNLOADS_DIR}/ethos_u_core_platform
+
+if [ -d ${DOWNLOADED_ETHOS_U_CORE_PLATFORM_PATH} ]; then
+ echo >&2 "${DOWNLOADED_ETHOS_U_CORE_PLATFORM_PATH} already exists, skipping the download."
+else
+ UNAME_S=`uname -s`
+ if [ ${UNAME_S} == Linux ]; then
+ ETHOS_U_CORE_PLATFORM_URL=https://git.mlplatform.org/ml/ethos-u/ethos-u-core-platform.git/snapshot/ethos-u-core-platform-6663630bb3feea222fd38278a962297c08d0b320.tar.gz
+ EXPECTED_MD5=11683ce5cbf4e4d1003ca93a85ad0b08
+ else
+ echo "OS type ${UNAME_S} not supported."
+ exit 1
+ fi
+
+ TEMPFILE=$(mktemp -d)/temp_file
+ wget ${ETHOS_U_CORE_PLATFORM_URL} -O ${TEMPFILE} >&2
+ check_md5 ${TEMPFILE} ${EXPECTED_MD5}
+
+ mkdir ${DOWNLOADED_ETHOS_U_CORE_PLATFORM_PATH}
+ tar xzf ${TEMPFILE} --strip-components=1 -C ${DOWNLOADED_ETHOS_U_CORE_PLATFORM_PATH} >&2
+
+ # Run C preprocessor on linker file to get rid of ifdefs and make sure compiler is downloaded first.
+ COMPILER=${DOWNLOADS_DIR}/gcc_embedded/bin/arm-none-eabi-gcc
+ if [ ! -f ${COMPILER} ]; then
+ RETURN_VALUE=`./tensorflow/lite/micro/tools/make/arm_gcc_download.sh ${DOWNLOADS_DIR}`
+ if [ "SUCCESS" != "${RETURN_VALUE}" ]; then
+ echo "The script ./tensorflow/lite/micro/tools/make/arm_gcc_download.sh failed."
+ exit 1
+ fi
+ fi
+ LINKER_PATH=${DOWNLOADED_ETHOS_U_CORE_PLATFORM_PATH}/targets/corstone-300
+ ${COMPILER} -E -x c -P -o ${LINKER_PATH}/platform_parsed.ld ${LINKER_PATH}/platform.ld
+fi
+
+echo "SUCCESS"
diff --git a/tensorflow/lite/micro/tools/make/ext_libs/arc_mli.inc b/tensorflow/lite/micro/tools/make/ext_libs/arc_mli.inc
index 5dbb91d..28a42b4 100644
--- a/tensorflow/lite/micro/tools/make/ext_libs/arc_mli.inc
+++ b/tensorflow/lite/micro/tools/make/ext_libs/arc_mli.inc
@@ -18,8 +18,8 @@
# MLI Library is used by default for ARC platform whenever it is possible.
# To use TFLM reference implementation MLI should be intentionally turned off
-# by passing 'no_arc_mli' tag (make -f <tflm_main_makefile> TAGS=no_arc_mli ...)
-ifeq ($(filter no_arc_mli,$(ALL_TAGS)),)
+# by passing 'no_arc_mli' tag (make -f <tflm_main_makefile> ARC_TAGS=no_arc_mli ...)
+ifeq ($(filter no_arc_mli,$(ARC_TAGS)),)
ALL_TAGS += arc_mli
diff --git a/tensorflow/lite/micro/tools/make/ext_libs/cmsis_download.sh b/tensorflow/lite/micro/tools/make/ext_libs/cmsis_download.sh
index b6745d8..2a4f3eb 100755
--- a/tensorflow/lite/micro/tools/make/ext_libs/cmsis_download.sh
+++ b/tensorflow/lite/micro/tools/make/ext_libs/cmsis_download.sh
@@ -49,9 +49,9 @@
echo >&2 "${DOWNLOADED_CMSIS_PATH} already exists, skipping the download."
else
- ZIP_PREFIX="01f5b32badf7b78c85a24a7149b56400fa6a2999"
+ ZIP_PREFIX="71627bc91534ed9eec2361c0ef6442cd057653e0"
CMSIS_URL="http://github.com/ARM-software/CMSIS_5/archive/${ZIP_PREFIX}.zip"
- CMSIS_MD5="823916c6f1749c65fd0bfdeec20b30ed"
+ CMSIS_MD5="207c49970758c663e2ce1cc0245972a9"
# wget is much faster than git clone of the entire repo. So we wget a specific
# version and can then apply a patch, as needed.
diff --git a/tensorflow/lite/micro/tools/make/ext_libs/ethos_u.inc b/tensorflow/lite/micro/tools/make/ext_libs/ethos_u.inc
index e306c23e..67f5a8e 100644
--- a/tensorflow/lite/micro/tools/make/ext_libs/ethos_u.inc
+++ b/tensorflow/lite/micro/tools/make/ext_libs/ethos_u.inc
@@ -13,7 +13,7 @@
# Unless an external path is provided we force a download during the first phase of make so
# that the files exist prior to the call to recursive_find below. add_third_party_download
# prevents the use of wildcards and recursive_find in selecting which files to add to THIRD_PARTY_SRCS.
-ETHOSU_DEFAULT_DOWNLOAD_DRIVER_PATH := $(MAKEFILE_DIR)/downloads/ethosu
+ETHOSU_DEFAULT_DOWNLOAD_DRIVER_PATH := $(MAKEFILE_DIR)/downloads/ethos_u_core_driver
ETHOSU_DRIVER_PATH := $(ETHOSU_DEFAULT_DOWNLOAD_DRIVER_PATH)
ifeq ($(ETHOSU_DRIVER_PATH), $(ETHOSU_DEFAULT_DOWNLOAD_DRIVER_PATH))
$(call $(or $(shell $(DOWNLOAD_SCRIPT) $(ETHOSU_URL) $(ETHOSU_MD5) $(ETHOSU_DRIVER_PATH) >&2 && echo SUCCESS), $(error $(DOWNLOAD_SCRIPT) failed)))
diff --git a/tensorflow/lite/micro/tools/make/ext_libs/xtensa_download.sh b/tensorflow/lite/micro/tools/make/ext_libs/xtensa_download.sh
index 0ab9af2..2cc5115 100755
--- a/tensorflow/lite/micro/tools/make/ext_libs/xtensa_download.sh
+++ b/tensorflow/lite/micro/tools/make/ext_libs/xtensa_download.sh
@@ -41,9 +41,9 @@
fi
if [[ ${2} == "hifi4" ]]; then
- LIBRARY_URL="http://github.com/foss-xtensa/nnlib-hifi4/raw/master/archive/xa_nnlib_hifi4_12_22.zip"
+ LIBRARY_URL="http://github.com/foss-xtensa/nnlib-hifi4/raw/master/archive/xa_nnlib_hifi4_02_11_2021.zip"
LIBRARY_DIRNAME="xa_nnlib_hifi4"
- LIBRARY_MD5="bb4aa8bd589ee1b4b9fd71349a1e7317"
+ LIBRARY_MD5="8b934f61ffe0a966644849602810fb1b"
else
echo "Attempting to download an unsupported xtensa variant: ${2}"
exit 1
diff --git a/tensorflow/lite/micro/tools/make/ext_libs/xtensa_hifi_nn_library.inc b/tensorflow/lite/micro/tools/make/ext_libs/xtensa_hifi_nn_library.inc
deleted file mode 100644
index 7e8fe2b..0000000
--- a/tensorflow/lite/micro/tools/make/ext_libs/xtensa_hifi_nn_library.inc
+++ /dev/null
@@ -1,73 +0,0 @@
-ifneq ($(filter xtensa_hifi, $(ALL_TAGS)),)
-
- XTENSA_PATH = $(MAKEFILE_DIR)/downloads
-
- ifneq (,$(filter hifi4%, $(TARGET_ARCH)))
-
- NNLIB = xa_nnlib_hifi4
-
- CCFLAGS += -DNNLIB_V2 \
- -DXTENSA_NNLIB_MAX_SCRATCH_SIZE=70*1024
-
- CXXFLAGS += -DNNLIB_V2 \
- -DXTENSA_NNLIB_MAX_SCRATCH_SIZE=70*1024
-
- MICROLITE_CC_SRCS += \
- $(XTENSA_PATH)/$(NNLIB)/algo/kernels/activations/hifi4/xa_nn_activations_f32_f32.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/kernels/activations/hifi4/xa_nn_activations_asym8_asym8.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/kernels/activations/hifi4/xa_nn_activations_32_16.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/kernels/activations/hifi4/xa_nn_activations_32_8.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/kernels/activations/hifi4/xa_nn_softmax_asym8_asym8.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/kernels/basic/hifi4/xa_nn_floor_f32.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/kernels/basic/hifi4/xa_nn_elm_add_f32.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/kernels/basic/hifi4/xa_nn_elm_add_quant8.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/kernels/basic/hifi4/xa_nn_elm_mul_f32.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/kernels/basic/hifi4/xa_nn_elm_mul_quant8.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/kernels/cnn/hifi4/xa_nn_conv2d_std_circ_buf.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/kernels/cnn/hifi4/xa_nn_conv2d_std_asym8xasym8.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/kernels/cnn/hifi4/xa_nn_conv2d_std_f32.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/kernels/cnn/hifi4/xa_nn_matXvec_asym8xasym8_asym8_circ.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/kernels/cnn/hifi4/xa_nn_matXvec_f32_circ.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/kernels/cnn/hifi4/xa_nn_conv2d_depthwise.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/kernels/cnn/hifi4/xa_nn_conv2d_depthwise_f32.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/kernels/cnn/hifi4/xa_nn_conv2d_depthwise_asym8xasym8.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/kernels/cnn/hifi4/xa_nn_circ_buf.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/kernels/fc/hifi4/xa_nn_fully_connected.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/kernels/matXvec/hifi4/xa_nn_matXvec_f32.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/kernels/matXvec/hifi4/xa_nn_matXvec_16x16.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/kernels/matXvec/hifi4/xa_nn_matXvec_8x16.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/kernels/matXvec/hifi4/xa_nn_matXvec_8x8.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/kernels/matXvec/hifi4/xa_nn_matXvec_asym8xasym8.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/kernels/pool/hifi4/xa_nn_avgpool.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/kernels/pool/hifi4/xa_nn_avgpool_f32.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/kernels/pool/hifi4/xa_nn_avgpool_asym8.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/kernels/pool/hifi4/xa_nn_maxpool.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/kernels/pool/hifi4/xa_nn_maxpool_f32.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/kernels/pool/hifi4/xa_nn_maxpool_asym8.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/kernels/pool/hifi4/xa_nn_avgpool_f32_nhwc.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/kernels/pool/hifi4/xa_nn_avgpool_asym8_nhwc.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/kernels/pool/hifi4/xa_nn_maxpool_f32_nhwc.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/kernels/pool/hifi4/xa_nn_maxpool_asym8_nhwc.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/kernels/pool/hifi4/xa_nn_inv_256_tbl.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/ndsp/hifi4/src/vec_sigmoidf_hifi4.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/ndsp/hifi4/src/vec_tanhf_hifi4.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/ndsp/hifi4/src/vec_reluf_hifi4.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/ndsp/hifi4/src/vec_softmaxf_hifi4.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/ndsp/hifi4/src/vec_alognf_hifi4.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/ndsp/hifi4/src/scl_sigmoidf_hifi4.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/ndsp/hifi4/src/scl_tanhf_hifi4.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/ndsp/hifi4/src/expf_tbl.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/ndsp/hifi4/src/pow2f_tbl.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/ndsp/hifi4/src/inff_tbl.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/ndsp/hifi4/src/tanhf_tbl.c \
- $(XTENSA_PATH)/$(NNLIB)/algo/ndsp/hifi4/src/nanf_tbl.c \
-
- INCLUDES += -I$(XTENSA_PATH)/$(NNLIB)/algo/kernels/ \
- -I$(XTENSA_PATH)/$(NNLIB)/include/nnlib/ \
- -I$(XTENSA_PATH)/$(NNLIB)/include/ \
- -I$(XTENSA_PATH)/$(NNLIB)/algo/common/include/ \
- -I$(XTENSA_PATH)/$(NNLIB)/algo/ndsp/hifi4/include/ \
-
- endif
-
-endif
diff --git a/tensorflow/lite/micro/tools/make/helper_functions.inc b/tensorflow/lite/micro/tools/make/helper_functions.inc
index f9c0c29..efbda33 100644
--- a/tensorflow/lite/micro/tools/make/helper_functions.inc
+++ b/tensorflow/lite/micro/tools/make/helper_functions.inc
@@ -437,10 +437,20 @@
# Handles the details of generating a binary target, including specializing
# for the current platform, and generating project file targets.
+#
+# Note that while the function is called microlite_test, it is used for both
+# test and non-test binaries.
+
+# Files that end with _test are added as test targets (i.e. can be executed with
+# make test_<target>. ALl others can be executed with make run_<target>
+#
# Arguments are:
-# 1 - Name of test.
-# 2 - C/C++ source files implementing the test.
-# 3 - C/C++ header files needed for the test.
+# 1 - Name of target.
+# 2 - C/C++ source files
+# 3 - C/C++ header files
+# 4 - if "exclude", then the non-test target will be excluded from
+# MICROLITE_BUILD_TARGETS. This exception is needed because not all the
+# microlite_test targets (e.g. the examples) are buildable on all platforms.
# Calling eval on the output will create the targets that you need.
define microlite_test
ifeq (,$(findstring _test, $(1)))
@@ -461,24 +471,21 @@
$$(MICROLITE_LIB_PATH) $$(LDFLAGS) $$(MICROLITE_LIBS)
$(1): $$($(1)_BINARY)
$(1)_bin: $$($(1)_BINARY).bin
-test_$(1): $$($(1)_BINARY)
- @test -f $$(TEST_SCRIPT) || (echo 'Unable to find the test script. Is the software emulation available in $$(TARGET)?'; exit 1)
- $$(TEST_SCRIPT) $$($(1)_BINARY) $$(TEST_PASS_STRING) $$(TARGET)
ifneq (,$(findstring _test,$(1)))
MICROLITE_TEST_TARGETS += test_$(1)
MICROLITE_BUILD_TARGETS += $$($(1)_BINARY)
-endif
-# The ifneq can make is seem that the body of the if block is executed when
-# _benchmark is not found in $(1). Actually, the check is saying that if
-# findstring does not return empty, i.e. if _benchmark is found in $(1), we
-# should add something to the MICROLITE_BUILD_TARGETS.
-#
-# This ensures that a `make build` command will builds all the tests and
-# benchmarks, though `make test` will only run the tests.
-ifneq (,$(findstring _benchmark,$(1)))
- MICROLITE_BUILD_TARGETS += $$($(1)_BINARY)
+test_$(1): $$($(1)_BINARY)
+ $$(TEST_SCRIPT) $$($(1)_BINARY) $$(TEST_PASS_STRING) $$(TARGET)
+
+else
+ ifeq ($(findstring exclude,$(4)),)
+ MICROLITE_BUILD_TARGETS += $$($(1)_BINARY)
+ endif
+
+run_$(1): $$($(1)_BINARY)
+ $$(TEST_SCRIPT) $$($(1)_BINARY) non_test_binary $$(TARGET)
endif
$(eval $(call generate_microlite_projects,$(1),$(call specialize,$(2)),$(3)))
diff --git a/tensorflow/lite/micro/tools/make/targets/arc/README.md b/tensorflow/lite/micro/tools/make/targets/arc/README.md
index 7f61025..420f06e 100644
--- a/tensorflow/lite/micro/tools/make/targets/arc/README.md
+++ b/tensorflow/lite/micro/tools/make/targets/arc/README.md
@@ -149,7 +149,7 @@
TensorFlow repo:
```
-make -f tensorflow/lite/micro/tools/make/Makefile generate_person_detection_test_int8_make_project TARGET=arc_emsdp
+make -f tensorflow/lite/micro/tools/make/Makefile generate_person_detection_test_int8_make_project TARGET=arc_emsdp OPTIMIZED_KERNEL_DIR=arc_mli
```
The application project will be generated into
@@ -166,8 +166,8 @@
quantized layers. Kernels which use MLI-based implementations are kept in the
*tensorflow/lite/micro/kernels/arc_mli* folder. For applications which may not
benefit from MLI library, the project can be generated without these
-implementations by adding `TAGS=no_arc_mli` in the command line. This can reduce
-code size when the optimized kernels are not required.
+implementations by adding `ARC_TAGS=no_arc_mli` in the command line. This can
+reduce code size when the optimized kernels are not required.
For more options on embARC MLI usage see
[kernels/arc_mli/README.md](/tensorflow/lite/micro/kernels/arc_mli/README.md).
@@ -279,7 +279,7 @@
command from the root directory of the TensorFlow repo:
```
-make -f tensorflow/lite/micro/tools/make/Makefile generate_person_detection_test_int8_make_project TARGET=arc_custom TCF_FILE=<path_to_tcf_file> LCF_FILE=<path_to_lcf_file>
+make -f tensorflow/lite/micro/tools/make/Makefile generate_person_detection_test_int8_make_project TARGET=arc_custom OPTIMIZED_KERNEL_DIR=arc_mli TCF_FILE=<path_to_tcf_file> LCF_FILE=<path_to_lcf_file>
```
The application project will be generated into
@@ -291,8 +291,8 @@
quantized layers. Kernels which use MLI-based implementations are kept in the
*tensorflow/lite/micro/kernels/arc_mli* folder. For applications which may not
benefit from MLI library, the project can be generated without these
-implementations by adding `TAGS=no_arc_mli` in the command line. This can reduce
-code size when the optimized kernels are not required.
+implementations by adding `ARC_TAGS=no_arc_mli` in the command line. This can
+reduce code size when the optimized kernels are not required.
For more options on embARC MLI usage see
[kernels/arc_mli/README.md](/tensorflow/lite/micro/kernels/arc_mli/README.md).
diff --git a/tensorflow/lite/micro/tools/make/targets/arc_emsdp_makefile.inc b/tensorflow/lite/micro/tools/make/targets/arc_emsdp_makefile.inc
index c7b5c12..b83f9aa 100644
--- a/tensorflow/lite/micro/tools/make/targets/arc_emsdp_makefile.inc
+++ b/tensorflow/lite/micro/tools/make/targets/arc_emsdp_makefile.inc
@@ -21,7 +21,7 @@
BUILD_ARC_MLI := false
ARC_MLI_PRE_COMPILED_TARGET := emsdp_em11d_em9d_dfss
-ifneq ($(filter no_arc_mli,$(ALL_TAGS)),)
+ifneq ($(filter no_arc_mli,$(ARC_TAGS)),)
MLI_LIB_DIR = arc_mli_package
$(eval $(call add_third_party_download,$(EMBARC_MLI_PRE_COMPILED_URL),$(EMBARC_MLI_PRE_COMPILED_MD5),$(MLI_LIB_DIR),))
else ifeq ($(BUILD_ARC_MLI), true)
diff --git a/tensorflow/lite/micro/tools/make/targets/cortex_m_corstone_300_makefile.inc b/tensorflow/lite/micro/tools/make/targets/cortex_m_corstone_300_makefile.inc
new file mode 100644
index 0000000..435694f
--- /dev/null
+++ b/tensorflow/lite/micro/tools/make/targets/cortex_m_corstone_300_makefile.inc
@@ -0,0 +1,169 @@
+# ARM Cortex M makefile targeted for a FVP based on Arm Corstone-300 software.
+# For more info see: tensorflow/lite/micro/cortex_m_corstone_300/README.md
+
+export PATH := $(MAKEFILE_DIR)/downloads/corstone300/models/Linux64_GCC-6.4:$(PATH)
+DOWNLOAD_RESULT := $(shell $(MAKEFILE_DIR)/corstone_300_download.sh ${MAKEFILE_DIR}/downloads)
+ifneq ($(DOWNLOAD_RESULT), SUCCESS)
+ $(error Something went wrong with the Arm Corstone-300 software download: $(DOWNLOAD_RESULT))
+endif
+
+ETHOS_U_CORE_PLATFORM := ${PWD}/$(MAKEFILE_DIR)/downloads/ethos_u_core_platform/targets/corstone-300
+DOWNLOAD_RESULT := $(shell $(MAKEFILE_DIR)/ethos_u_core_platform_download.sh ${MAKEFILE_DIR}/downloads)
+ifneq ($(DOWNLOAD_RESULT), SUCCESS)
+ $(error Something went wrong with the Ethos-U Core Platform software download: $(DOWNLOAD_RESULT))
+endif
+
+# This target has dependencies to CMSIS-Device so just in case running without OPTIMIZED_KERNEL_DIR=cmsis_nn.
+CMSIS_DEFAULT_DOWNLOAD_PATH := $(MAKEFILE_DIR)/downloads/cmsis
+CMSIS_PATH := $(CMSIS_DEFAULT_DOWNLOAD_PATH)
+ifeq ($(CMSIS_PATH), $(CMSIS_DEFAULT_DOWNLOAD_PATH))
+ DOWNLOAD_RESULT := $(shell $(MAKEFILE_DIR)/ext_libs/cmsis_download.sh ${MAKEFILE_DIR}/downloads)
+ ifneq ($(DOWNLOAD_RESULT), SUCCESS)
+ $(error Something went wrong with the CMSIS download: $(DOWNLOAD_RESULT))
+ endif
+endif
+
+FLOAT := soft
+GCC_TARGET_ARCH := $(TARGET_ARCH)
+
+ifeq ($(TARGET_ARCH), cortex-m0)
+ CORE=M0
+
+else ifeq ($(TARGET_ARCH), cortex-m3)
+ CORE=M3
+
+else ifeq ($(TARGET_ARCH), cortex-m33)
+ CORE=M33
+ FLOAT=hard
+ CMSIS_ARM_FEATURES := _DSP_DP
+
+else ifeq ($(TARGET_ARCH), cortex-m33+nodsp)
+ CORE=M33
+
+else ifeq ($(TARGET_ARCH), cortex-m4)
+ CORE=M4
+ GCC_TARGET_ARCH := cortex-m4+nofp
+
+else ifeq ($(TARGET_ARCH), cortex-m4+fp)
+ CORE=M4
+ FLOAT=hard
+ GCC_TARGET_ARCH := cortex-m4
+ CMSIS_ARM_FEATURES := _FP
+
+else ifeq ($(TARGET_ARCH), cortex-m55)
+ CORE=M55
+ FLOAT=hard
+
+else ifeq ($(TARGET_ARCH), cortex-m55+nodsp+nofp)
+ CORE=M55
+
+else ifeq ($(TARGET_ARCH), cortex-m55+nofp)
+ CORE=M55
+
+else ifeq ($(TARGET_ARCH), cortex-m7)
+ CORE=M7
+ GCC_TARGET_ARCH := cortex-m7+nofp
+
+else ifeq ($(TARGET_ARCH), cortex-m7+fp)
+ CORE=M7
+ FLOAT=hard
+ GCC_TARGET_ARCH := cortex-m7
+ CMSIS_ARM_FEATURES := _DP
+
+else
+ $(error "TARGET_ARCH=$(TARGET_ARCH) is not supported")
+endif
+
+ifneq ($(filter cortex-m55%,$(TARGET_ARCH)),)
+ # soft-abi=soft disables MVE - use softfp instead for M55.
+ ifeq ($(FLOAT),soft)
+ FLOAT=softfp
+ endif
+endif
+
+ifeq ($(TOOLCHAIN), gcc)
+ export PATH := $(MAKEFILE_DIR)/downloads/gcc_embedded/bin/:$(PATH)
+ DOWNLOAD_RESULT := $(shell $(MAKEFILE_DIR)/arm_gcc_download.sh ${MAKEFILE_DIR}/downloads)
+ ifneq ($(DOWNLOAD_RESULT), SUCCESS)
+ $(error Something went wrong with the GCC download: $(DOWNLOAD_RESULT))
+ endif
+ TARGET_TOOLCHAIN_PREFIX := arm-none-eabi-
+
+ FLAGS_GCC = -mcpu=$(GCC_TARGET_ARCH) -mfpu=auto
+ CXXFLAGS += $(FLAGS_GCC)
+ CCFLAGS += $(FLAGS_GCC)
+
+ LDFLAGS += \
+ --specs=nosys.specs \
+ -T $(ETHOS_U_CORE_PLATFORM)/platform_parsed.ld \
+ -Wl,-Map=${TENSORFLOW_ROOT}$(MAKEFILE_DIR)/gen/$(TARGET).map,--cref \
+ -Wl,--gc-sections \
+ --entry Reset_Handler
+
+else
+ $(error "TOOLCHAIN=$(TOOLCHAIN) is not supported.")
+endif
+
+# TODO: fix warnings.
+OMIT_ERRORS = \
+ -Wno-implicit-fallthrough \
+ -Wno-strict-aliasing
+
+PLATFORM_FLAGS = \
+ -DTF_LITE_MCU_DEBUG_LOG \
+ -mthumb \
+ -mfloat-abi=$(FLOAT) \
+ -funsigned-char \
+ -mlittle-endian \
+ ${OMIT_ERRORS} \
+ -fomit-frame-pointer \
+ -MD \
+ -DCPU_$(CORE)=1
+
+# Common + C/C++ flags
+CXXFLAGS += $(PLATFORM_FLAGS)
+CCFLAGS += $(PLATFORM_FLAGS)
+
+ARM_CPU := $(subst cortex-m,ARMCM,$(GCC_TARGET_ARCH))
+ARM_CPU := $(subst +nofp,,$(ARM_CPU))
+CXXFLAGS += -D$(ARM_CPU)$(CMSIS_ARM_FEATURES)
+CCFLAGS += -D$(ARM_CPU)$(CMSIS_ARM_FEATURES)
+
+THIRD_PARTY_CC_SRCS += \
+ $(ETHOS_U_CORE_PLATFORM)/retarget.c \
+ $(ETHOS_U_CORE_PLATFORM)/uart.c
+
+CMSIS_DEFAULT_DOWNLOAD_PATH := $(MAKEFILE_DIR)/downloads/cmsis
+CMSIS_PATH := $(CMSIS_DEFAULT_DOWNLOAD_PATH)
+THIRD_PARTY_CC_SRCS += \
+ $(CMSIS_PATH)/Device/ARM/$(ARM_CPU)/Source/system_$(ARM_CPU).c \
+ $(CMSIS_PATH)/Device/ARM/$(ARM_CPU)/Source/startup_$(ARM_CPU).c
+INCLUDES += \
+ -I$(CMSIS_PATH)/Device/ARM/$(ARM_CPU)/Include \
+ -I$(CMSIS_PATH)/CMSIS/Core/Include
+
+# TODO(#47071): Examine why Micro benchmarks fails.
+MICRO_LITE_BENCHMARKS := $(filter-out tensorflow/lite/micro/benchmarks/Makefile.inc, $(MICRO_LITE_BENCHMARKS))
+
+# TODO(#47070): Examine why some tests fail here.
+EXCLUDED_TESTS := \
+ tensorflow/lite/micro/micro_interpreter_test.cc \
+ tensorflow/lite/micro/micro_allocator_test.cc \
+ tensorflow/lite/micro/memory_helpers_test.cc \
+ tensorflow/lite/micro/micro_error_reporter_test.cc \
+ tensorflow/lite/micro/output_handler_test.cc \
+ tensorflow/lite/micro/memory_arena_threshold_test.cc \
+ tensorflow/lite/micro/recording_micro_allocator_test.cc \
+ tensorflow/lite/micro/kernels/circular_buffer_test.cc \
+ tensorflow/lite/micro/kernels/pooling_test.cc
+MICROLITE_TEST_SRCS := $(filter-out $(EXCLUDED_TESTS), $(MICROLITE_TEST_SRCS))
+EXCLUDED_EXAMPLE_TESTS := \
+ tensorflow/lite/micro/examples/magic_wand/Makefile.inc \
+ tensorflow/lite/micro/examples/micro_speech/Makefile.inc \
+ tensorflow/lite/micro/examples/person_detection/Makefile.inc \
+ tensorflow/lite/micro/examples/hello_world/Makefile.inc \
+ tensorflow/lite/micro/examples/network_tester/Makefile.inc \
+ tensorflow/lite/micro/examples/image_recognition_experimental/Makefile.inc
+MICRO_LITE_EXAMPLE_TESTS := $(filter-out $(EXCLUDED_EXAMPLE_TESTS), $(MICRO_LITE_EXAMPLE_TESTS))
+
+TEST_SCRIPT := tensorflow/lite/micro/testing/test_with_arm_corstone_300.sh
diff --git a/tensorflow/lite/micro/tools/make/targets/xtensa_hifi/README.md b/tensorflow/lite/micro/tools/make/targets/xtensa_hifi/README.md
deleted file mode 100644
index 6c88ce3..0000000
--- a/tensorflow/lite/micro/tools/make/targets/xtensa_hifi/README.md
+++ /dev/null
@@ -1,35 +0,0 @@
-# Building TensorFlow Lite for Microcontrollers for Cadence Tensilica HiFi DSPs
-
-This document describes the steps to build and run the Tensorflow Lite Micro on
-the Cadence HiFi DSPs.
-
-## Pre-requisites
-
-The Xtensa development tools and the target processor configurations should be
-installed on the system. Please check [https://tensilicatools.com] for more
-information about downloading and installing the required tools.
-
-The PATH variable should be set to include the <xtensa_tools_root>/bin
-directory. The XTENSA_SYSTEM and XTENSA_CORE environment variables should be set
-to the required tools version and the required processor configuration.
-
-## Building for HiFi Processors
-
-To build the code using Xtensa tools for the processor configuration selected by
-XTENSA_CORE , set TARGET=xtensa_hifi. Additionally TARGET_ARCH can be used to
-select optimized HiFi NN kernels specific to the processor configuration.
-Currently the HiFi4 NN kernels are provided which can be enabled as follows:
-
-make -f tensorflow/lite/micro/tools/make/Makefile test_micro_speech_test
-TARGET=xtensa_hifi TARGET_ARCH=hifi4
-
-Xtensa specific TF Lite Micro kernels are implemented in this folder:
-tensorflow/lite/micro/kernels/xtensa_hifi/
-
-A scratch memory allocation is needed for the HiFi optimized kernels. This
-allocation is currently done on stack and it's size can be controlled by
-defining 'XTENSA_NNLIB_MAX_SCRATCH_SIZE' appropriately in the file
-'tensorflow/lite/micro/tools/make/ext_libs/xtensa_hifi_nn_library.inc
-
-The files containing the HiFi optimized NN kernels are present in this folder:
-tensorflow/lite/micro/kernels/xtensa_hifi/xa_nnlib/
diff --git a/tensorflow/lite/micro/tools/make/targets/xtensa_hifi_makefile.inc b/tensorflow/lite/micro/tools/make/targets/xtensa_hifi_makefile.inc
deleted file mode 100644
index e3850c3..0000000
--- a/tensorflow/lite/micro/tools/make/targets/xtensa_hifi_makefile.inc
+++ /dev/null
@@ -1,61 +0,0 @@
-# Settings for Xtensa toolchain for the hifi kernels.
-# REQUIRED:
-# Environment variables:
-# - XTENSA_BASE must be set to location of
-# the Xtensa developer tools installation directory.
-# Command line arguments:
-# - XTENSA_TOOLS_VERSION: For example: RI-2019.2-linux
-# - XTENSA_CORE: The name of the Xtensa core to use
-# For example: hifi3
-
-ifeq ($(TARGET), xtensa_hifi)
- TARGET_ARCH := hifi3_bd5
-
- ifndef XTENSA_BASE
- $(error XTENSA_BASE is undefined)
- endif
-
- ifndef XTENSA_TOOLS_VERSION
- $(error XTENSA_TOOLS_VERSION is undefined)
- endif
-
- ifndef XTENSA_CORE
- $(error XTENSA_CORE is undefined)
- endif
-
- PLATFORM_ARGS = \
- -mno-mul16 \
- -mno-mul32 \
- -mno-div32 \
- -fsigned-char \
- -fno-exceptions \
- -mlongcalls \
- -INLINE:requested \
- -mcoproc \
- -fno-zero-initialized-in-bss \
- -mtext-section-literals \
- -fno-unsafe-math-optimizations \
-
- TF_LITE_MICRO_FLAGS = \
- -DTF_LITE_STATIC_MEMORY\
-
- export PATH := $(XTENSA_BASE)/tools/$(XTENSA_TOOLS_VERSION)/XtensaTools/bin:$(PATH)
- TARGET_TOOLCHAIN_PREFIX := xt-
- CXX_TOOL := clang++
- CC_TOOL := clang
-
- CXXFLAGS = -O0 $(PLATFORM_ARGS) -std=c++11 $(TF_LITE_MICRO_FLAGS)
- #TODO: Use -std=c11 ?
- CCFLAGS = -O3 $(PLATFORM_ARGS) $(TF_LITE_MICRO_FLAGS)
-
- TEST_SCRIPT := tensorflow/lite/micro/testing/test_xtensa_hifi_binary.sh
-
- # These are microcontroller-specific rules for converting the ELF output
- # of the linker into a binary image that can be loaded directly.
- OBJCOPY := $(TARGET_TOOLCHAIN_PREFIX)objcopy
-
- $(BINDIR)/%.bin: $(BINDIR)/%
- echo "here"
- @mkdir -p $(dir $@)
- $(OBJCOPY) $< $@ -O binary
-endif
diff --git a/tensorflow/lite/objc/sources/TFLInterpreter.mm b/tensorflow/lite/objc/sources/TFLInterpreter.mm
index 03a20f0..58b009d 100644
--- a/tensorflow/lite/objc/sources/TFLInterpreter.mm
+++ b/tensorflow/lite/objc/sources/TFLInterpreter.mm
@@ -421,6 +421,7 @@
case kTfLiteString:
case kTfLiteComplex64:
case kTfLiteComplex128:
+ case kTfLiteUInt32:
case kTfLiteUInt64:
case kTfLiteResource:
case kTfLiteVariant:
diff --git a/tensorflow/lite/optional_debug_tools.cc b/tensorflow/lite/optional_debug_tools.cc
index 1ec0d1a..d02d2d2 100644
--- a/tensorflow/lite/optional_debug_tools.cc
+++ b/tensorflow/lite/optional_debug_tools.cc
@@ -51,6 +51,8 @@
return "kTfLiteFloat32";
case kTfLiteInt32:
return "kTfLiteInt32";
+ case kTfLiteUInt32:
+ return "kTfLiteUInt32";
case kTfLiteUInt8:
return "kTfLiteUInt8";
case kTfLiteInt8:
diff --git a/tensorflow/lite/portable_type_to_tflitetype.h b/tensorflow/lite/portable_type_to_tflitetype.h
index 9fbcfb8..83a0ac6 100644
--- a/tensorflow/lite/portable_type_to_tflitetype.h
+++ b/tensorflow/lite/portable_type_to_tflitetype.h
@@ -59,6 +59,7 @@
// No string mapping is included here, since the TF Lite packed representation
// doesn't correspond to a C++ type well.
MATCH_TYPE_AND_TFLITE_TYPE(int32_t, kTfLiteInt32);
+MATCH_TYPE_AND_TFLITE_TYPE(uint32_t, kTfLiteUInt32);
MATCH_TYPE_AND_TFLITE_TYPE(int16_t, kTfLiteInt16);
MATCH_TYPE_AND_TFLITE_TYPE(int64_t, kTfLiteInt64);
MATCH_TYPE_AND_TFLITE_TYPE(float, kTfLiteFloat32);
diff --git a/tensorflow/lite/python/interpreter.py b/tensorflow/lite/python/interpreter.py
index 4ea9d9f..f7ef3b3 100644
--- a/tensorflow/lite/python/interpreter.py
+++ b/tensorflow/lite/python/interpreter.py
@@ -817,7 +817,7 @@
Raises:
ValueError: If the interpreter was unable to create.
"""
- self._custom_op_registerers = custom_op_registerers
+ self._custom_op_registerers = custom_op_registerers or []
super(InterpreterWithCustomOps, self).__init__(
model_path=model_path,
model_content=model_content,
diff --git a/tensorflow/lite/python/interpreter_test.py b/tensorflow/lite/python/interpreter_test.py
index 62bd971..2c43ba2 100644
--- a/tensorflow/lite/python/interpreter_test.py
+++ b/tensorflow/lite/python/interpreter_test.py
@@ -67,6 +67,12 @@
'testdata/permute_float.tflite'),
custom_op_registerers=[bogus_name])
+ def testNoCustomOps(self):
+ interpreter = interpreter_wrapper.InterpreterWithCustomOps(
+ model_path=resource_loader.get_path_to_datafile(
+ 'testdata/permute_float.tflite'))
+ self.assertTrue(interpreter._safe_to_run())
+
class InterpreterTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/lite/python/interpreter_wrapper/numpy.cc b/tensorflow/lite/python/interpreter_wrapper/numpy.cc
index 1785aa0..5fabf66 100644
--- a/tensorflow/lite/python/interpreter_wrapper/numpy.cc
+++ b/tensorflow/lite/python/interpreter_wrapper/numpy.cc
@@ -42,6 +42,8 @@
return NPY_FLOAT64;
case kTfLiteInt32:
return NPY_INT32;
+ case kTfLiteUInt32:
+ return NPY_UINT32;
case kTfLiteInt16:
return NPY_INT16;
case kTfLiteUInt8:
@@ -80,6 +82,8 @@
return kTfLiteFloat64;
case NPY_INT32:
return kTfLiteInt32;
+ case NPY_UINT32:
+ return kTfLiteUInt32;
case NPY_INT16:
return kTfLiteInt16;
case NPY_UINT8:
diff --git a/tensorflow/lite/python/optimize/calibration_wrapper.cc b/tensorflow/lite/python/optimize/calibration_wrapper.cc
index a70ea5b..2a744fb 100644
--- a/tensorflow/lite/python/optimize/calibration_wrapper.cc
+++ b/tensorflow/lite/python/optimize/calibration_wrapper.cc
@@ -73,6 +73,8 @@
return TensorType_FLOAT64;
case kTfLiteInt32:
return TensorType_INT32;
+ case kTfLiteUInt32:
+ return TensorType_UINT32;
case kTfLiteUInt8:
return TensorType_UINT8;
case kTfLiteInt8:
diff --git a/tensorflow/lite/python/test_util.py b/tensorflow/lite/python/test_util.py
index da9453b..3da1e80 100644
--- a/tensorflow/lite/python/test_util.py
+++ b/tensorflow/lite/python/test_util.py
@@ -24,7 +24,7 @@
def get_ops_list(model_data):
- """Return a set of ops in the tflite model data."""
+ """Returns a set of ops in the tflite model data."""
model = schema_fb.Model.GetRootAsModel(model_data, 0)
op_set = set()
@@ -40,3 +40,18 @@
else:
op_set.add(visualize.BuiltinCodeToName(builtin_code))
return op_set
+
+
+def get_output_shapes(model_data):
+ """Returns a list of output shapes in the tflite model data."""
+ model = schema_fb.Model.GetRootAsModel(model_data, 0)
+
+ output_shapes = []
+ for subgraph_idx in range(model.SubgraphsLength()):
+ subgraph = model.Subgraphs(subgraph_idx)
+ for output_idx in range(subgraph.OutputsLength()):
+ output_tensor_idx = subgraph.Outputs(output_idx)
+ output_tensor = subgraph.Tensors(output_tensor_idx)
+ output_shapes.append(output_tensor.ShapeAsNumpy().tolist())
+
+ return output_shapes
diff --git a/tensorflow/lite/python/tflite_convert_test.py b/tensorflow/lite/python/tflite_convert_test.py
index 7c66e31..71810cf 100644
--- a/tensorflow/lite/python/tflite_convert_test.py
+++ b/tensorflow/lite/python/tflite_convert_test.py
@@ -54,7 +54,8 @@
def _run(self,
flags_str,
should_succeed,
- expected_ops_in_converted_model=None):
+ expected_ops_in_converted_model=None,
+ expected_output_shapes=None):
output_file = os.path.join(self.get_temp_dir(), 'model.tflite')
tflite_bin = resource_loader.get_path_to_datafile('tflite_convert')
cmdline = '{0} --output_file={1} {2}'.format(tflite_bin, output_file,
@@ -69,6 +70,9 @@
op_set = tflite_test_util.get_ops_list(content)
for opname in expected_ops_in_converted_model:
self.assertIn(opname, op_set)
+ if expected_output_shapes:
+ output_shapes = tflite_test_util.get_output_shapes(content)
+ self.assertEqual(output_shapes, expected_output_shapes)
os.remove(output_file)
else:
self.assertFalse(should_succeed)
@@ -88,6 +92,17 @@
keras.models.save_model(model, keras_file)
return keras_file
+ def _getKerasFunctionalModelFile(self):
+ """Returns a functional Keras model with output shapes [[1, 1], [1, 2]]."""
+ input_tensor = keras.layers.Input(shape=(1,))
+ output1 = keras.layers.Dense(1, name='b')(input_tensor)
+ output2 = keras.layers.Dense(2, name='a')(input_tensor)
+ model = keras.models.Model(inputs=input_tensor, outputs=[output1, output2])
+
+ keras_file = self._getFilepath('functional_model.h5')
+ keras.models.save_model(model, keras_file)
+ return keras_file
+
class TfLiteConvertV1Test(TestModels):
@@ -482,6 +497,25 @@
self._run(flags_str, should_succeed=True)
os.remove(keras_file)
+ @test_util.run_v2_only
+ def testFunctionalKerasModel(self):
+ keras_file = self._getKerasFunctionalModelFile()
+
+ flags_str = '--keras_model_file={}'.format(keras_file)
+ self._run(flags_str, should_succeed=True,
+ expected_output_shapes=[[1, 1], [1, 2]])
+ os.remove(keras_file)
+
+ @test_util.run_v2_only
+ def testFunctionalKerasModelMLIR(self):
+ keras_file = self._getKerasFunctionalModelFile()
+
+ flags_str = (
+ '--keras_model_file={} --experimental_new_converter'.format(keras_file))
+ self._run(flags_str, should_succeed=True,
+ expected_output_shapes=[[1, 1], [1, 2]])
+ os.remove(keras_file)
+
def testMissingRequired(self):
self._run('--invalid_args', should_succeed=False)
diff --git a/tensorflow/lite/python/tflite_keras_util.py b/tensorflow/lite/python/tflite_keras_util.py
index c9f5b40..21f8873 100644
--- a/tensorflow/lite/python/tflite_keras_util.py
+++ b/tensorflow/lite/python/tflite_keras_util.py
@@ -183,11 +183,6 @@
model, inputs=inputs, build_graph=False, training=False, saving=True):
outputs = model(inputs, training=False)
- # Outputs always has to be a flat dict.
- output_names = model.output_names # Functional Model.
- if output_names is None: # Subclassed Model.
- output_names = create_pseudo_output_names(outputs)
- outputs = nest.flatten(outputs)
- return {name: output for name, output in zip(output_names, outputs)}
+ return outputs
return _wrapped_model
diff --git a/tensorflow/lite/python/util.py b/tensorflow/lite/python/util.py
index f69bd32..1ed3907 100644
--- a/tensorflow/lite/python/util.py
+++ b/tensorflow/lite/python/util.py
@@ -66,6 +66,7 @@
dtypes.complex128: _types_pb2.COMPLEX128,
dtypes.resource: _types_pb2.RESOURCE,
dtypes.variant: _types_pb2.VARIANT,
+ dtypes.uint32: _types_pb2.UINT32,
}
_MAP_TFLITE_ENUM_TO_TF_TYPES = {
@@ -81,6 +82,7 @@
9: dtypes.int8,
10: dtypes.float64,
11: dtypes.complex128,
+ 16: dtypes.uint32,
}
_TFLITE_FILE_IDENTIFIER = b"TFL3"
diff --git a/tensorflow/lite/python/util_test.py b/tensorflow/lite/python/util_test.py
index 7e0f2c9..0fd2be1 100644
--- a/tensorflow/lite/python/util_test.py
+++ b/tensorflow/lite/python/util_test.py
@@ -48,6 +48,8 @@
self.assertEqual(
util.convert_dtype_to_tflite_type(dtypes.int32), _types_pb2.INT32)
self.assertEqual(
+ util.convert_dtype_to_tflite_type(dtypes.uint32), _types_pb2.UINT32)
+ self.assertEqual(
util.convert_dtype_to_tflite_type(dtypes.uint8),
_types_pb2.QUANTIZED_UINT8)
self.assertEqual(
@@ -89,13 +91,16 @@
util._convert_tflite_enum_type_to_tf_type(10), dtypes.float64)
self.assertEqual(
util._convert_tflite_enum_type_to_tf_type(11), dtypes.complex128)
+ self.assertEqual(
+ util._convert_tflite_enum_type_to_tf_type(16), dtypes.uint32)
with self.assertRaises(ValueError) as error:
util._convert_tflite_enum_type_to_tf_type(20)
self.assertEqual(
"Unsupported enum 20. The valid map of enum to tf types is : "
"{0: tf.float32, 1: tf.float16, 2: tf.int32, 3: tf.uint8, 4: tf.int64, "
"5: tf.string, 6: tf.bool, 7: tf.int16, 8: tf.complex64, 9: tf.int8, "
- "10: tf.float64, 11: tf.complex128}", str(error.exception))
+ "10: tf.float64, 11: tf.complex128, 16: tf.uint32}",
+ str(error.exception))
def testTensorName(self):
with ops.Graph().as_default():
@@ -108,6 +113,30 @@
got_name = util.get_tensor_name(out_tensors[i])
self.assertEqual(got_name, expect_names[i])
+ def testUint32PassThrough(self):
+ model = tf.keras.Sequential([
+ tf.keras.layers.InputLayer(input_shape=(4,), dtype=tf.uint32),
+ tf.keras.layers.Reshape(target_shape=(2, 2))
+ ])
+ converter = tf.lite.TFLiteConverter.from_keras_model(model)
+ tflite_model = converter.convert()
+ interpreter = tf.lite.Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+ input_details = interpreter.get_input_details()[0]
+ output_details = interpreter.get_output_details()[0]
+
+ self.assertEqual(input_details["dtype"], np.uint32)
+ self.assertEqual(output_details["dtype"], np.uint32)
+
+ in_array = np.array([[1, 1, 1, 1]], dtype="uint32") * ((1 << 32) - 1)
+ expected_out = np.reshape(in_array, (2, 2))
+
+ interpreter.set_tensor(input_details["index"], in_array)
+ interpreter.invoke()
+
+ output_data = interpreter.get_tensor(output_details["index"])[0]
+ self.assertAllEqual(expected_out, output_data)
+
@test_util.enable_control_flow_v2
def testRemoveLowerUsingSwitchMerge(self):
with ops.Graph().as_default():
diff --git a/tensorflow/lite/schema/schema.fbs b/tensorflow/lite/schema/schema.fbs
index 278afef..dae3a46 100644
--- a/tensorflow/lite/schema/schema.fbs
+++ b/tensorflow/lite/schema/schema.fbs
@@ -47,6 +47,7 @@
UINT64 = 12,
RESOURCE = 13,
VARIANT = 14,
+ UINT32 = 15,
}
// Custom quantization parameters for experimenting with new quantization
diff --git a/tensorflow/lite/schema/schema_generated.h b/tensorflow/lite/schema/schema_generated.h
index 23e055b..c96049c 100755
--- a/tensorflow/lite/schema/schema_generated.h
+++ b/tensorflow/lite/schema/schema_generated.h
@@ -404,11 +404,12 @@
TensorType_UINT64 = 12,
TensorType_RESOURCE = 13,
TensorType_VARIANT = 14,
+ TensorType_UINT32 = 15,
TensorType_MIN = TensorType_FLOAT32,
- TensorType_MAX = TensorType_VARIANT
+ TensorType_MAX = TensorType_UINT32
};
-inline const TensorType (&EnumValuesTensorType())[15] {
+inline const TensorType (&EnumValuesTensorType())[16] {
static const TensorType values[] = {
TensorType_FLOAT32,
TensorType_FLOAT16,
@@ -424,13 +425,14 @@
TensorType_COMPLEX128,
TensorType_UINT64,
TensorType_RESOURCE,
- TensorType_VARIANT
+ TensorType_VARIANT,
+ TensorType_UINT32
};
return values;
}
inline const char * const *EnumNamesTensorType() {
- static const char * const names[16] = {
+ static const char * const names[17] = {
"FLOAT32",
"FLOAT16",
"INT32",
@@ -446,13 +448,14 @@
"UINT64",
"RESOURCE",
"VARIANT",
+ "UINT32",
nullptr
};
return names;
}
inline const char *EnumNameTensorType(TensorType e) {
- if (flatbuffers::IsOutRange(e, TensorType_FLOAT32, TensorType_VARIANT)) return "";
+ if (flatbuffers::IsOutRange(e, TensorType_FLOAT32, TensorType_UINT32)) return "";
const size_t index = static_cast<size_t>(e);
return EnumNamesTensorType()[index];
}
diff --git a/tensorflow/lite/testing/op_tests/fused_batch_norm.py b/tensorflow/lite/testing/op_tests/fused_batch_norm.py
index f0d7b4f..ee33b78 100644
--- a/tensorflow/lite/testing/op_tests/fused_batch_norm.py
+++ b/tensorflow/lite/testing/op_tests/fused_batch_norm.py
@@ -31,8 +31,18 @@
"dtype": [tf.float32],
"input_shape": [[1, 1, 6, 2]],
"epsilon": [0.001, 0.1],
+ "is_training": [False],
}]
+ # Training support in MLIR converter.
+ if options.use_experimental_converter:
+ test_parameters = test_parameters + [{
+ "dtype": [tf.float32],
+ "input_shape": [[1, 1, 6, 2]],
+ "epsilon": [0.001, 0.1],
+ "is_training": [True],
+ }]
+
def build_graph(parameters):
"""Build the testing graph for fused batch normalization."""
input_shape = parameters["input_shape"]
@@ -43,7 +53,8 @@
mean = create_tensor_data(parameters["dtype"], scale_shape)
variance = create_tensor_data(parameters["dtype"], scale_shape)
- x = create_tensor_data(parameters["dtype"], parameters["input_shape"])
+ x = tf.compat.v1.placeholder(
+ dtype=parameters["dtype"], name="x", shape=parameters["input_shape"])
[x_norm, _, _] = tf.compat.v1.nn.fused_batch_norm(
x,
scale,
@@ -52,19 +63,22 @@
variance,
parameters["epsilon"],
data_format="NHWC",
- is_training=False)
+ is_training=parameters["is_training"])
input_tensor = tf.compat.v1.placeholder(
dtype=parameters["dtype"],
name="input",
shape=parameters["input_shape"])
out = tf.add(input_tensor, x_norm)
- return [input_tensor], [out]
+ return [x, input_tensor], [out]
def build_inputs(parameters, sess, inputs, outputs):
- input_value = create_tensor_data(parameters["dtype"],
- parameters["input_shape"])
- return [input_value], sess.run(
- outputs, feed_dict=dict(zip(inputs, [input_value])))
+ input_values = [
+ create_tensor_data(parameters["dtype"], parameters["input_shape"]),
+ create_tensor_data(parameters["dtype"], parameters["input_shape"])
+ ]
+
+ return input_values, sess.run(
+ outputs, feed_dict=dict(zip(inputs, input_values)))
make_zip_of_tests(options, test_parameters, build_graph, build_inputs)
diff --git a/tensorflow/lite/testing/selective_build_test.cc b/tensorflow/lite/testing/selective_build_test.cc
index f614d2e..83b1fa6 100644
--- a/tensorflow/lite/testing/selective_build_test.cc
+++ b/tensorflow/lite/testing/selective_build_test.cc
@@ -64,12 +64,12 @@
}
TEST(SelectiveBuiltTest, AddModel) {
- std::string model = "third_party/tensorflow/lite/testdata/add.bin";
+ std::string model = "tensorflow/lite/testdata/add.bin";
EXPECT_THAT(RunWithRandomInputs(model), true);
}
TEST(SelectiveBuiltTest, LSTMModel) {
- std::string model = "third_party/tensorflow/lite/testdata/lstm.bin";
+ std::string model = "tensorflow/lite/testdata/lstm.bin";
EXPECT_THAT(RunWithRandomInputs(model), true);
}
} // namespace tflite
diff --git a/tensorflow/lite/testing/split.h b/tensorflow/lite/testing/split.h
index c23f6f9..d70ed28 100644
--- a/tensorflow/lite/testing/split.h
+++ b/tensorflow/lite/testing/split.h
@@ -59,6 +59,16 @@
}
template <>
+inline std::vector<uint32_t> Split(const string& s, const string& delimiter) {
+ std::vector<uint32_t> fields;
+ for (const auto& p : SplitToPos(s, delimiter)) {
+ // NOLINTNEXTLINE(runtime/deprecated_fn)
+ fields.push_back(strtol(s.data() + p.first, nullptr, 10));
+ }
+ return fields;
+}
+
+template <>
inline std::vector<int64_t> Split(const string& s, const string& delimiter) {
std::vector<int64_t> fields;
for (const auto& p : SplitToPos(s, delimiter)) {
diff --git a/tensorflow/lite/testing/tf_driver.cc b/tensorflow/lite/testing/tf_driver.cc
index b63aecc..481030c 100644
--- a/tensorflow/lite/testing/tf_driver.cc
+++ b/tensorflow/lite/testing/tf_driver.cc
@@ -162,6 +162,10 @@
num_values_available =
FillTensorWithData<int32_t>(tensor, values_as_string);
break;
+ case tensorflow::DT_UINT32:
+ num_values_available =
+ FillTensorWithData<uint32_t>(tensor, values_as_string);
+ break;
case tensorflow::DT_UINT8:
num_values_available =
FillTensorWithData<uint8_t>(tensor, values_as_string);
@@ -224,6 +228,8 @@
return TensorDataToCsvString<float>(tensor);
case tensorflow::DT_INT32:
return TensorDataToCsvString<int32_t>(tensor);
+ case tensorflow::DT_UINT32:
+ return TensorDataToCsvString<uint32_t>(tensor);
case tensorflow::DT_INT64:
return TensorDataToCsvString<tensorflow::int64>(tensor);
case tensorflow::DT_UINT8:
diff --git a/tensorflow/lite/testing/tflite_driver.cc b/tensorflow/lite/testing/tflite_driver.cc
index c858bf0..7ac7e07 100644
--- a/tensorflow/lite/testing/tflite_driver.cc
+++ b/tensorflow/lite/testing/tflite_driver.cc
@@ -333,6 +333,8 @@
return TypedCheck<float, float>(verbose, tensor);
case kTfLiteInt32:
return TypedCheck<int32_t, float>(verbose, tensor);
+ case kTfLiteUInt32:
+ return TypedCheck<uint32_t, float>(verbose, tensor);
case kTfLiteInt64:
return TypedCheck<int64_t, float>(verbose, tensor);
case kTfLiteUInt64:
@@ -485,6 +487,12 @@
SetTensorData(values, tensor->data.raw);
break;
}
+ case kTfLiteUInt32: {
+ const auto& values = testing::Split<uint32_t>(csv_values, ",");
+ if (!CheckSizes<uint32_t>(tensor->bytes, values.size())) return;
+ SetTensorData(values, tensor->data.raw);
+ break;
+ }
case kTfLiteInt64: {
const auto& values = testing::Split<int64_t>(csv_values, ",");
if (!CheckSizes<int64_t>(tensor->bytes, values.size())) return;
@@ -586,6 +594,9 @@
case kTfLiteInt32:
expected_output_[id]->SetData<int32_t>(csv_values);
break;
+ case kTfLiteUInt32:
+ expected_output_[id]->SetData<uint32_t>(csv_values);
+ break;
case kTfLiteInt64:
expected_output_[id]->SetData<int64_t>(csv_values);
break;
@@ -692,6 +703,8 @@
return JoinDefault(tensor->data.f, num_elements, ",");
case kTfLiteInt32:
return JoinDefault(tensor->data.i32, num_elements, ",");
+ case kTfLiteUInt32:
+ return JoinDefault(tensor->data.u32, num_elements, ",");
case kTfLiteInt64:
return JoinDefault(tensor->data.i64, num_elements, ",");
case kTfLiteUInt64:
diff --git a/tensorflow/lite/toco/export_tensorflow.cc b/tensorflow/lite/toco/export_tensorflow.cc
index 7ecf6cc..f3bdc9b 100644
--- a/tensorflow/lite/toco/export_tensorflow.cc
+++ b/tensorflow/lite/toco/export_tensorflow.cc
@@ -41,6 +41,7 @@
using tensorflow::DT_INT16;
using tensorflow::DT_INT32;
using tensorflow::DT_INT64;
+using tensorflow::DT_UINT32;
using tensorflow::DT_UINT8;
using tensorflow::GraphDef;
using tensorflow::TensorProto;
@@ -59,6 +60,8 @@
return tensorflow::DT_UINT8;
case ArrayDataType::kInt32:
return tensorflow::DT_INT32;
+ case ArrayDataType::kUint32:
+ return tensorflow::DT_UINT32;
case ArrayDataType::kInt64:
return tensorflow::DT_INT64;
case ArrayDataType::kString:
@@ -2438,6 +2441,9 @@
case ArrayDataType::kInt32:
(*placeholder->mutable_attr())["dtype"].set_type(DT_INT32);
break;
+ case ArrayDataType::kUint32:
+ (*placeholder->mutable_attr())["dtype"].set_type(DT_UINT32);
+ break;
case ArrayDataType::kInt64:
(*placeholder->mutable_attr())["dtype"].set_type(DT_INT64);
break;
diff --git a/tensorflow/lite/toco/import_tensorflow.cc b/tensorflow/lite/toco/import_tensorflow.cc
index 2adfe83..27e0047 100644
--- a/tensorflow/lite/toco/import_tensorflow.cc
+++ b/tensorflow/lite/toco/import_tensorflow.cc
@@ -57,6 +57,7 @@
using tensorflow::DT_INT64;
using tensorflow::DT_QUINT8;
using tensorflow::DT_STRING;
+using tensorflow::DT_UINT32;
using tensorflow::DT_UINT8;
using tensorflow::GraphDef;
using tensorflow::NodeDef;
@@ -185,6 +186,8 @@
return ArrayDataType::kBool;
else if (dtype == DT_INT32)
return ArrayDataType::kInt32;
+ else if (dtype == DT_UINT32)
+ return ArrayDataType::kUint32;
else if (dtype == DT_INT64)
return ArrayDataType::kInt64;
else if (dtype == DT_STRING)
@@ -296,6 +299,18 @@
};
template <>
+struct TensorTraits<uint32> {
+ static int size(const TensorProto& p) { return p.uint32_val_size(); }
+ static int32 get(const TensorProto& p, int i) { return p.uint32_val(i); }
+ static std::string accessor_name() { return "uint32_val"; }
+ static std::string type_name() { return "uint32"; }
+ static void CopyFromContent(const TensorProto& p, std::vector<uint32>* data) {
+ toco::port::CopyToBuffer(p.tensor_content(),
+ reinterpret_cast<char*>(data->data()));
+ }
+};
+
+template <>
struct TensorTraits<int64> {
static int size(const TensorProto& p) { return p.int64_val_size(); }
static int64 get(const TensorProto& p, int i) { return p.int64_val(i); }
@@ -432,6 +447,23 @@
&output_int_data);
}
+tensorflow::Status ImportUint32Array(const TensorProto& input_tensor,
+ Array* output_array) {
+ CHECK_EQ(input_tensor.dtype(), DT_UINT32);
+ const auto& input_shape = input_tensor.tensor_shape();
+ CHECK_LE(input_shape.dim_size(), 6);
+ int input_flat_size;
+ auto status = ImportShape(input_shape.dim(), &input_flat_size,
+ output_array->mutable_shape());
+ if (!status.ok()) return status;
+
+ auto& output_int_data =
+ output_array->GetMutableBuffer<ArrayDataType::kUint32>().data;
+ output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
+ return ImportTensorData<uint32>(input_tensor, input_flat_size,
+ &output_int_data);
+}
+
tensorflow::Status ImportInt64Array(const TensorProto& input_tensor,
Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_INT64);
@@ -757,6 +789,10 @@
array.data_type = ArrayDataType::kInt32;
status = ImportInt32Array(tensor, &array);
break;
+ case DT_UINT32:
+ array.data_type = ArrayDataType::kUint32;
+ status = ImportUint32Array(tensor, &array);
+ break;
case DT_QUINT8:
array.data_type = ArrayDataType::kUint8;
status = ImportQuint8Array(tensor, &array);
@@ -1473,7 +1509,6 @@
model);
}
}
-
switch (GetDataTypeAttr(node, "dtype")) {
case DT_FLOAT:
case DT_INT32:
diff --git a/tensorflow/lite/toco/import_tensorflow_test.cc b/tensorflow/lite/toco/import_tensorflow_test.cc
index 98ce18b..ef5a077 100644
--- a/tensorflow/lite/toco/import_tensorflow_test.cc
+++ b/tensorflow/lite/toco/import_tensorflow_test.cc
@@ -37,6 +37,7 @@
using tensorflow::DT_INVALID;
using tensorflow::DT_QUINT8;
using tensorflow::DT_STRING;
+using tensorflow::DT_UINT32;
using tensorflow::NodeDef;
using tensorflow::Status;
using ::testing::ElementsAre;
@@ -127,6 +128,11 @@
t.add_int_val(i % std::numeric_limits<int>::max() + 1);
}
break;
+ case DT_UINT32:
+ for (int64_t i = 0; i < num_elements; ++i) {
+ t.add_int_val(i % std::numeric_limits<uint32_t>::max() + 1);
+ }
+ break;
case DT_QUINT8:
for (int64_t i = 0; i < num_elements; ++i) {
t.add_int_val(i % std::numeric_limits<uint8_t>::max() + 1);
diff --git a/tensorflow/lite/toco/tflite/operator.cc b/tensorflow/lite/toco/tflite/operator.cc
index 077206d..cf76e62 100644
--- a/tensorflow/lite/toco/tflite/operator.cc
+++ b/tensorflow/lite/toco/tflite/operator.cc
@@ -48,6 +48,7 @@
{ArrayDataType::kUint8, ::tflite::TensorType_UINT8},
{ArrayDataType::kInt16, ::tflite::TensorType_INT16},
{ArrayDataType::kInt32, ::tflite::TensorType_INT32},
+ {ArrayDataType::kUint32, ::tflite::TensorType_UINT32},
{ArrayDataType::kInt64, ::tflite::TensorType_INT64},
{ArrayDataType::kUint64, ::tflite::TensorType_UINT64},
{ArrayDataType::kString, ::tflite::TensorType_STRING},
diff --git a/tensorflow/lite/toco/tflite/types.cc b/tensorflow/lite/toco/tflite/types.cc
index 9d4ab84..d241b56 100644
--- a/tensorflow/lite/toco/tflite/types.cc
+++ b/tensorflow/lite/toco/tflite/types.cc
@@ -92,6 +92,8 @@
return ::tflite::TensorType_INT16;
case ArrayDataType::kInt32:
return ::tflite::TensorType_INT32;
+ case ArrayDataType::kUint32:
+ return ::tflite::TensorType_UINT32;
case ArrayDataType::kInt64:
return ::tflite::TensorType_INT64;
case ArrayDataType::kUint8:
@@ -117,6 +119,8 @@
return ArrayDataType::kInt16;
case ::tflite::TensorType_INT32:
return ArrayDataType::kInt32;
+ case ::tflite::TensorType_UINT32:
+ return ArrayDataType::kUint32;
case ::tflite::TensorType_INT64:
return ArrayDataType::kInt64;
case ::tflite::TensorType_STRING:
@@ -143,6 +147,8 @@
return CopyBuffer<ArrayDataType::kInt16>(array, builder);
case ArrayDataType::kInt32:
return CopyBuffer<ArrayDataType::kInt32>(array, builder);
+ case ArrayDataType::kUint32:
+ return CopyBuffer<ArrayDataType::kUint32>(array, builder);
case ArrayDataType::kInt64:
return CopyBuffer<ArrayDataType::kInt64>(array, builder);
case ArrayDataType::kString:
@@ -170,6 +176,8 @@
return CopyBuffer<ArrayDataType::kInt16>(buffer, array);
case ::tflite::TensorType_INT32:
return CopyBuffer<ArrayDataType::kInt32>(buffer, array);
+ case ::tflite::TensorType_UINT32:
+ return CopyBuffer<ArrayDataType::kUint32>(buffer, array);
case ::tflite::TensorType_INT64:
return CopyBuffer<ArrayDataType::kInt64>(buffer, array);
case ::tflite::TensorType_STRING:
diff --git a/tensorflow/lite/toco/tflite/types_test.cc b/tensorflow/lite/toco/tflite/types_test.cc
index efa2911..e1f4a65 100644
--- a/tensorflow/lite/toco/tflite/types_test.cc
+++ b/tensorflow/lite/toco/tflite/types_test.cc
@@ -71,6 +71,7 @@
std::vector<std::pair<ArrayDataType, ::tflite::TensorType>> testdata = {
{ArrayDataType::kUint8, ::tflite::TensorType_UINT8},
{ArrayDataType::kInt32, ::tflite::TensorType_INT32},
+ {ArrayDataType::kUint32, ::tflite::TensorType_UINT32},
{ArrayDataType::kInt64, ::tflite::TensorType_INT64},
{ArrayDataType::kFloat, ::tflite::TensorType_FLOAT32},
{ArrayDataType::kBool, ::tflite::TensorType_BOOL},
@@ -154,6 +155,12 @@
::testing::ElementsAre(1, 1 << 30));
}
+TEST(DataBuffer, Uint32) {
+ Array recovered = ToFlatBufferAndBack<ArrayDataType::kUint32>({1, 1U << 31});
+ EXPECT_THAT(recovered.GetBuffer<ArrayDataType::kUint32>().data,
+ ::testing::ElementsAre(1, 1U << 31));
+}
+
TEST(DataBuffer, Int16) {
Array recovered = ToFlatBufferAndBack<ArrayDataType::kInt16>({1, 1 << 14});
EXPECT_THAT(recovered.GetBuffer<ArrayDataType::kInt16>().data,
diff --git a/tensorflow/lite/toco/tooling_util.cc b/tensorflow/lite/toco/tooling_util.cc
index b34f492..35a4229 100644
--- a/tensorflow/lite/toco/tooling_util.cc
+++ b/tensorflow/lite/toco/tooling_util.cc
@@ -2307,6 +2307,8 @@
return ArrayDataType::kInt16;
case INT32:
return ArrayDataType::kInt32;
+ case UINT32:
+ return ArrayDataType::kUint32;
case INT64:
return ArrayDataType::kInt64;
case UINT64:
diff --git a/tensorflow/lite/toco/types.proto b/tensorflow/lite/toco/types.proto
index 4548998..7e886b4 100644
--- a/tensorflow/lite/toco/types.proto
+++ b/tensorflow/lite/toco/types.proto
@@ -64,4 +64,7 @@
// Variant type
VARIANT = 15;
+
+ // Uint32
+ UINT32 = 16;
}
diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
index 036af73..cb08bf3 100644
--- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
+++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
@@ -485,6 +485,12 @@
return CreateInputTensorData<int32_t>(
num_elements, std::uniform_int_distribution<int32_t>(low, high));
}
+ case kTfLiteUInt32: {
+ int low = has_value_range ? low_range : 0;
+ int high = has_value_range ? high_range : 99;
+ return CreateInputTensorData<uint32_t>(
+ num_elements, std::uniform_int_distribution<uint32_t>(low, high));
+ }
case kTfLiteInt16: {
int low = has_value_range ? low_range : 0;
int high = has_value_range ? high_range : 99;
diff --git a/tensorflow/lite/tools/serialization/enum_mapping.h b/tensorflow/lite/tools/serialization/enum_mapping.h
index 721ce3b..a21271a 100644
--- a/tensorflow/lite/tools/serialization/enum_mapping.h
+++ b/tensorflow/lite/tools/serialization/enum_mapping.h
@@ -68,6 +68,8 @@
return TensorType_FLOAT64;
case kTfLiteInt32:
return TensorType_INT32;
+ case kTfLiteUInt32:
+ return TensorType_UINT32;
case kTfLiteUInt8:
return TensorType_UINT8;
case kTfLiteInt8:
diff --git a/tensorflow/lite/tools/verifier.cc b/tensorflow/lite/tools/verifier.cc
index dcb154a..e23b36e 100644
--- a/tensorflow/lite/tools/verifier.cc
+++ b/tensorflow/lite/tools/verifier.cc
@@ -422,6 +422,9 @@
case TensorType_INT32:
bytes_required *= sizeof(int32_t);
break;
+ case TensorType_UINT32:
+ bytes_required *= sizeof(uint32_t);
+ break;
case TensorType_UINT8:
bytes_required *= sizeof(uint8_t);
break;
diff --git a/tensorflow/lite/type_to_tflitetype_test.cc b/tensorflow/lite/type_to_tflitetype_test.cc
index da6d7a6..30bc2e5 100644
--- a/tensorflow/lite/type_to_tflitetype_test.cc
+++ b/tensorflow/lite/type_to_tflitetype_test.cc
@@ -30,6 +30,8 @@
typeToTfLiteType<TfLiteTypeToType<kTfLiteInt16>::Type>());
EXPECT_EQ(kTfLiteInt32,
typeToTfLiteType<TfLiteTypeToType<kTfLiteInt32>::Type>());
+ EXPECT_EQ(kTfLiteUInt32,
+ typeToTfLiteType<TfLiteTypeToType<kTfLiteUInt32>::Type>());
EXPECT_EQ(kTfLiteFloat32,
typeToTfLiteType<TfLiteTypeToType<kTfLiteFloat32>::Type>());
EXPECT_EQ(kTfLiteUInt8,
diff --git a/tensorflow/lite/util.cc b/tensorflow/lite/util.cc
index 1395a4e..995d52b 100644
--- a/tensorflow/lite/util.cc
+++ b/tensorflow/lite/util.cc
@@ -96,7 +96,10 @@
*bytes = sizeof(float);
break;
case kTfLiteInt32:
- *bytes = sizeof(int);
+ *bytes = sizeof(int32_t);
+ break;
+ case kTfLiteUInt32:
+ *bytes = sizeof(uint32_t);
break;
case kTfLiteUInt8:
*bytes = sizeof(uint8_t);
diff --git a/tensorflow/opensource_only.files b/tensorflow/opensource_only.files
index 1d7072f..2fdb894 100644
--- a/tensorflow/opensource_only.files
+++ b/tensorflow/opensource_only.files
@@ -23,6 +23,7 @@
tensorflow/lite/delegates/gpu/cl/serialization_generated.h
tensorflow/lite/delegates/gpu/common/task/serialization_base_generated.h
tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h
+tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
tensorflow/lite/micro/build_def.bzl
tensorflow/lite/schema/schema_generated.h
tensorflow/opensource_only/BUILD
@@ -72,7 +73,6 @@
tensorflow/third_party/eigen3/Eigen/SparseCholesky
tensorflow/third_party/eigen3/Eigen/SparseCore
tensorflow/third_party/eigen3/LICENSE
-tensorflow/third_party/eigen3/gpu_packet_math.patch
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool
@@ -278,10 +278,7 @@
tensorflow/third_party/toolchains/remote/BUILD.tpl
tensorflow/third_party/toolchains/remote/configure.bzl
tensorflow/third_party/toolchains/remote/execution.bzl.tpl
-tensorflow/third_party/toolchains/remote_config/BUILD
tensorflow/third_party/toolchains/remote_config/configs.bzl
-tensorflow/third_party/toolchains/remote_config/containers.bzl
-tensorflow/third_party/toolchains/remote_config/rbe_config.bzl
tensorflow/third_party/typing_extensions.BUILD
tensorflow/third_party/wrapt.BUILD
tensorflow/third_party/zlib.BUILD
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 0e20815..8f041e6 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -237,6 +237,7 @@
"//tensorflow/python/util",
"//tensorflow/python/util:_pywrap_checkpoint_reader",
"//tensorflow/python/util:_pywrap_kernel_registry",
+ "//tensorflow/python/util:_pywrap_nest",
"//tensorflow/python/util:_pywrap_stat_summarizer",
"//tensorflow/python/util:_pywrap_tfprof",
"//tensorflow/python/util:_pywrap_transform_graph",
@@ -751,6 +752,7 @@
deps = [
":_pywrap_debug_events_writer",
":_pywrap_events_writer",
+ "//tensorflow/python/util:_pywrap_nest",
"//tensorflow/python/util:_pywrap_kernel_registry",
":_pywrap_py_exception_registry",
"//tensorflow/python/lib/core:_pywrap_py_func", # TODO(b/142001480): remove once the bug is fixed.
@@ -5212,6 +5214,7 @@
":model_analyzer_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_session",
+ "//tensorflow/python/util:cpp_nest",
"//tensorflow/python/util:cpp_python_util",
"//tensorflow/python/util:function_parameter_canonicalizer",
"//tensorflow/python/util:kernel_registry",
@@ -5289,6 +5292,7 @@
srcs = [
":bfloat16_lib", # bfloat16
":cost_analyzer_lib", # cost_analyzer
+ "//tensorflow/python/util:cpp_nest",
"//tensorflow/python/util:cpp_python_util",
"//tensorflow/python/util:kernel_registry",
":model_analyzer_lib", # model_analyzer
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index a480dac..61b2e2b 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -33,7 +33,7 @@
# This value changes every day with an automatic CL. It can be modified in code
# via `forward_compatibility_horizon()` or with the environment variable
# TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2021, 2, 10)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2021, 2, 17)
_FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
_FORWARD_COMPATIBILITY_DATE_NUMBER = None
diff --git a/tensorflow/python/data/experimental/benchmarks/autotune_benchmark.py b/tensorflow/python/data/experimental/benchmarks/autotune_benchmark.py
index 9753c00..b2c9227 100644
--- a/tensorflow/python/data/experimental/benchmarks/autotune_benchmark.py
+++ b/tensorflow/python/data/experimental/benchmarks/autotune_benchmark.py
@@ -41,9 +41,9 @@
if autotune_buffers else "parallelism_only")
wall_time = self.run_and_report_benchmark(
dataset=dataset,
- num_elements=1,
+ num_elements=benchmark_iters,
warmup=True,
- iters=benchmark_iters,
+ iters=1,
name=benchmark_label + (autotune_string if autotune else ""))
return wall_time
diff --git a/tensorflow/python/data/experimental/kernel_tests/BUILD b/tensorflow/python/data/experimental/kernel_tests/BUILD
index be15d40..821e644 100644
--- a/tensorflow/python/data/experimental/kernel_tests/BUILD
+++ b/tensorflow/python/data/experimental/kernel_tests/BUILD
@@ -157,6 +157,9 @@
srcs = ["data_service_ops_test.py"],
shard_count = 10,
srcs_version = "PY3",
+ tags = [
+ "notsan", # TODO(b/180454113)
+ ],
deps = [
":data_service_test_base",
"//tensorflow:tensorflow_py",
diff --git a/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py b/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py
index e0e9678..e456877 100644
--- a/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py
@@ -238,6 +238,21 @@
self.evaluate(next_fn())
@combinations.generate(test_base.default_test_combinations())
+ def testRoundtripEmptySnapshot(self):
+ dataset = dataset_ops.Dataset.range(0)
+ dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
+ self.assertDatasetProduces(dataset, [])
+ self.assertSnapshotDirectoryContains(
+ self._snapshot_dir,
+ num_fingerprints=1,
+ num_runs_per_fingerprint=1,
+ num_snapshot_shards_per_run=0)
+
+ dataset2 = dataset_ops.Dataset.range(0)
+ dataset2 = dataset.apply(snapshot.snapshot(self._snapshot_dir))
+ self.assertDatasetProduces(dataset2, [])
+
+ @combinations.generate(test_base.default_test_combinations())
def testWriteSnapshotDatasetSimple(self):
dataset = dataset_ops.Dataset.range(1000)
dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
diff --git a/tensorflow/python/data/experimental/ops/interleave_ops.py b/tensorflow/python/data/experimental/ops/interleave_ops.py
index 4c16d35..f272942 100644
--- a/tensorflow/python/data/experimental/ops/interleave_ops.py
+++ b/tensorflow/python/data/experimental/ops/interleave_ops.py
@@ -111,11 +111,17 @@
first_output_types = dataset_ops.get_legacy_output_types(data_inputs[0])
first_output_classes = dataset_ops.get_legacy_output_classes(data_inputs[0])
- for data_input in data_inputs[1:]:
+ for i, data_input in enumerate(data_inputs[1:]):
if (dataset_ops.get_legacy_output_types(data_input) != first_output_types
or dataset_ops.get_legacy_output_classes(data_input)
!= first_output_classes):
- raise TypeError("All datasets must have the same type and class.")
+ raise TypeError("All datasets must have the same type and class.\n"
+ "dataset 0 vs dataset %s types: %s ; %s\n"
+ "classes: %s ; %s" %
+ (i + 1, first_output_types,
+ dataset_ops.get_legacy_output_types(data_input),
+ first_output_classes,
+ dataset_ops.get_legacy_output_classes(data_input)))
output_shapes = dataset_ops.get_legacy_output_shapes(self._data_inputs[0])
for data_input in self._data_inputs[1:]:
diff --git a/tensorflow/python/data/experimental/ops/io.py b/tensorflow/python/data/experimental/ops/io.py
index b91899a..d102f17 100644
--- a/tensorflow/python/data/experimental/ops/io.py
+++ b/tensorflow/python/data/experimental/ops/io.py
@@ -25,6 +25,7 @@
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_experimental_dataset_ops
+from tensorflow.python.platform import gfile
from tensorflow.python.util import lazy_loader
from tensorflow.python.util.tf_export import tf_export
@@ -98,8 +99,8 @@
coder = nested_structure_coder.StructureCoder()
encoded = coder.encode_structure(dataset.element_spec)
- os.makedirs(path, exist_ok=True)
- with open(os.path.join(path, DATASET_SPEC_FILENAME), "wb") as f:
+ gfile.MakeDirs(path)
+ with gfile.GFile(os.path.join(path, DATASET_SPEC_FILENAME), "wb") as f:
f.write(encoded.SerializeToString())
path = ops.convert_to_tensor(path, dtype=dtypes.string, name="path")
@@ -131,7 +132,7 @@
self._path = path
if element_spec is None:
- with open(os.path.join(path, DATASET_SPEC_FILENAME), "rb") as f:
+ with gfile.GFile(os.path.join(path, DATASET_SPEC_FILENAME), "rb") as f:
encoded_spec = f.read()
struct_pb = nested_structure_coder.struct_pb2.StructuredValue()
struct_pb.ParseFromString(encoded_spec)
diff --git a/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py b/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py
index 23df8a0..8dc6c84 100644
--- a/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py
+++ b/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py
@@ -33,6 +33,7 @@
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
@@ -56,11 +57,17 @@
for _ in range(num_inits):
self.evaluate(multi_device_iterator.initializer)
- @combinations.generate(test_base.v1_only_combinations())
- def testBasic(self):
+ @combinations.generate(
+ combinations.times(
+ test_base.v1_only_combinations(),
+ combinations.combine(
+ max_buffer_size=[0, 1, 10], prefetch_buffer_size=[0, 1, 10])))
+ def testBasic(self, prefetch_buffer_size, max_buffer_size):
dataset = dataset_ops.Dataset.range(10)
multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
- dataset, ["/cpu:1", "/cpu:2"])
+ dataset, ["/cpu:1", "/cpu:2"],
+ max_buffer_size=max_buffer_size,
+ prefetch_buffer_size=prefetch_buffer_size)
config = config_pb2.ConfigProto(device_count={"CPU": 3})
with self.test_session(config=config):
@@ -346,8 +353,12 @@
class OwnedMultiDeviceIteratorTest(test_base.DatasetTestBase,
parameterized.TestCase):
- @combinations.generate(test_base.v2_eager_only_combinations())
- def testBasic(self):
+ @combinations.generate(
+ combinations.times(
+ test_base.v2_eager_only_combinations(),
+ combinations.combine(
+ max_buffer_size=[0, 1, 10], prefetch_buffer_size=[0, 1, 10])))
+ def testBasic(self, max_buffer_size, prefetch_buffer_size):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
@@ -355,7 +366,9 @@
dataset = dataset_ops.Dataset.range(1000)
mdi = multi_device_iterator_ops.OwnedMultiDeviceIterator(
- dataset, ["/cpu:0", "/gpu:0"])
+ dataset, ["/cpu:0", "/gpu:0"],
+ max_buffer_size=max_buffer_size,
+ prefetch_buffer_size=prefetch_buffer_size)
for i, el in enumerate(mdi):
self.assertEqual([i * 2, i * 2 + 1], [el[0].numpy(), el[1].numpy()])
@@ -407,7 +420,12 @@
@def_function.function
def fn():
- dataset = dataset_ops._GeneratorDataset(1, init_fn, next_fn, finalize_fn)
+ dataset = dataset_ops._GeneratorDataset(
+ 1,
+ init_fn,
+ next_fn,
+ finalize_fn,
+ output_signature=tensor_spec.TensorSpec([], dtypes.int64))
iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator(
dataset, ["/cpu:0", "/gpu:0"])
next(iterator)
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 6497cb2..fea5a0e 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -1808,16 +1808,20 @@
If not specified, elements will be processed sequentially. If the value
`tf.data.AUTOTUNE` is used, then the number of parallel
calls is set dynamically based on available CPU.
- deterministic: (Optional.) A boolean controlling whether determinism
- should be traded for performance by allowing elements to be yielded out
- of order. If `deterministic` is `None`, the
- `tf.data.Options.experimental_deterministic` dataset option (`True` by
- default) is used to decide whether to run deterministically.
+ deterministic: (Optional.) When `num_parallel_calls` is specified, this
+ boolean controls the order in which the transformation produces
+ elements. If set to `False`, the transformation is allowed to yield
+ elements out of order to trade determinism for performance. If not
+ specified, the `tf.data.Options.experimental_deterministic` option
+ (`True` by default) controls the behavior.
Returns:
Dataset: A `Dataset`.
"""
if num_parallel_calls is None:
+ if deterministic is not None:
+ warnings.warn("The `deterministic` argument has no effect unless the "
+ "`num_parallel_calls` argument is specified.")
return MapDataset(self, map_func, preserve_cardinality=True)
else:
return ParallelMapDataset(
@@ -1937,11 +1941,12 @@
from cycle elements synchronously with no parallelism. If the value
`tf.data.AUTOTUNE` is used, then the number of parallel
calls is set dynamically based on available CPU.
- deterministic: (Optional.) A boolean controlling whether determinism
- should be traded for performance by allowing elements to be produced out
- of order. If `deterministic` is `None`, the
- `tf.data.Options.experimental_deterministic` dataset option (`True` by
- default) is used to decide whether to run deterministically.
+ deterministic: (Optional.) When `num_parallel_calls` is specified, this
+ boolean controls the order in which the transformation produces
+ elements. If set to `False`, the transformation is allowed to yield
+ elements out of order to trade determinism for performance. If not
+ specified, the `tf.data.Options.experimental_deterministic` option
+ (`True` by default) controls the behavior.
Returns:
Dataset: A `Dataset`.
@@ -1953,6 +1958,9 @@
cycle_length = AUTOTUNE
if num_parallel_calls is None:
+ if deterministic is not None:
+ warnings.warn("The `deterministic` argument has no effect unless the "
+ "`num_parallel_calls` argument is specified.")
return InterleaveDataset(self, map_func, cycle_length, block_length)
else:
return ParallelInterleaveDataset(
@@ -2673,17 +2681,20 @@
If not specified, elements will be processed sequentially. If the value
`tf.data.AUTOTUNE` is used, then the number of parallel
calls is set dynamically based on available CPU.
- deterministic: (Optional.) A boolean controlling whether determinism
- should be traded for performance by allowing elements to be produced out
- of order. If `deterministic` is `None`, the
- `tf.data.Options.experimental_deterministic` dataset option (`True` by
- default) is used to decide whether to produce elements
- deterministically.
+ deterministic: (Optional.) When `num_parallel_calls` is specified, this
+ boolean controls the order in which the transformation produces
+ elements. If set to `False`, the transformation is allowed to yield
+ elements out of order to trade determinism for performance. If not
+ specified, the `tf.data.Options.experimental_deterministic` option
+ (`True` by default) controls the behavior.
Returns:
Dataset: A `Dataset`.
"""
if num_parallel_calls is None:
+ if deterministic is not None:
+ warnings.warn("The `deterministic` argument has no effect unless the "
+ "`num_parallel_calls` argument is specified.")
return DatasetV1Adapter(
MapDataset(
self,
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 6d858a8..4347225 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -450,7 +450,6 @@
],
python_version = "PY3",
tags = [
- "notap", # TODO(b/171355671)
"notsan", # TODO(b/151841995)
],
deps = [
diff --git a/tensorflow/python/distribute/coordinator/cluster_coordinator.py b/tensorflow/python/distribute/coordinator/cluster_coordinator.py
index 07adc63..655d472 100644
--- a/tensorflow/python/distribute/coordinator/cluster_coordinator.py
+++ b/tensorflow/python/distribute/coordinator/cluster_coordinator.py
@@ -1201,6 +1201,9 @@
assert remote_value.fetch() == 3
```
+ NOTE: A known limitation is `tf.data.Options` is ignored in dataset created
+ by `create_per_worker_dataset`.
+
Args:
dataset_fn: The dataset function that returns a dataset. This is to be
executed on the workers.
diff --git a/tensorflow/python/distribute/strategy_common_test.py b/tensorflow/python/distribute/strategy_common_test.py
index d70f2c8..df97a01 100644
--- a/tensorflow/python/distribute/strategy_common_test.py
+++ b/tensorflow/python/distribute/strategy_common_test.py
@@ -131,7 +131,7 @@
def fn():
def replica_fn():
- value = constant_op.constant(1.0)
+ value = array_ops.identity(1.0)
reduced = strategy.extended._replica_ctx_all_reduce('SUM', value)
return reduced
@@ -152,7 +152,7 @@
def fn():
def replica_fn():
- value = (constant_op.constant(1.0), constant_op.constant(2.0))
+ value = (array_ops.identity(1.0), array_ops.identity(1.0))
reduced = strategy.extended._replica_ctx_all_reduce('SUM', value)
return reduced
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index 66b0771..6e8d436 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -1046,6 +1046,7 @@
"//tensorflow/python:functional_ops",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:tensor_spec",
+ "//tensorflow/python:test_ops",
"@absl_py//absl/testing:parameterized",
"@six_archive//:six",
],
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 4681b9c..788ce6e 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -873,6 +873,18 @@
tape.pop_tape(self._tape)
self._recording = False
+ @tf_contextlib.contextmanager
+ def _ensure_recording(self):
+ """Ensures that this tape is recording."""
+ if not self._recording:
+ try:
+ self._push_tape()
+ yield
+ finally:
+ self._pop_tape()
+ else:
+ yield
+
def watch(self, tensor):
"""Ensures that `tensor` is being traced by this tape.
@@ -1144,14 +1156,12 @@
target_shape = array_ops.shape(target)
# Note that we push and pop the tape here and below. This is needed since we
# need gradients through the enclosed operations.
- self._push_tape()
- target = array_ops.reshape(target, [-1])
- self._pop_tape()
+ with self._ensure_recording():
+ target = array_ops.reshape(target, [-1])
def loop_fn(i):
- self._push_tape()
- y = array_ops.gather(target, i)
- self._pop_tape()
+ with self._ensure_recording():
+ y = array_ops.gather(target, i)
return self.gradient(y, flat_sources,
unconnected_gradients=unconnected_gradients)
@@ -1285,16 +1295,14 @@
# Flatten target to 2-D.
# Note that we push and pop the tape here and below. This is needed since we
# need gradients through the enclosed operations.
- self._push_tape()
- with ops.control_dependencies(
- [check_ops.assert_equal(batch_size, source_shape[0])]):
- target = array_ops.reshape(target, [batch_size, target_row_size])
- self._pop_tape()
+ with self._ensure_recording():
+ with ops.control_dependencies(
+ [check_ops.assert_equal(batch_size, source_shape[0])]):
+ target = array_ops.reshape(target, [batch_size, target_row_size])
def loop_fn(i):
- self._push_tape()
- y = array_ops.gather(target, i, axis=1)
- self._pop_tape()
+ with self._ensure_recording():
+ y = array_ops.gather(target, i, axis=1)
return self.gradient(y, source,
unconnected_gradients=unconnected_gradients)
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index bdc2bae..9b6d001 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -855,6 +855,24 @@
g.jacobian(y, [x])
@test_util.assert_no_new_tensors
+ def testJacobianInsideGradientTapeScope(self):
+ with backprop.GradientTape() as g:
+ x = constant_op.constant(3.0)
+ g.watch(x)
+ y = x * x
+ z = y * y
+ self.assertAllClose(4. * 3. ** 3., g.jacobian(z, x))
+
+ @test_util.assert_no_new_tensors
+ def testBatchJacobianInsideGradientTapeScope(self):
+ with backprop.GradientTape(persistent=True) as g:
+ x = constant_op.constant([[3.0]])
+ g.watch(x)
+ y = x * x
+ z = y * y
+ self.assertAllClose([[[4. * 3. ** 3.]]], g.batch_jacobian(z, x))
+
+ @test_util.assert_no_new_tensors
def testGradientTapeBatchJacobianCalledMultipleTimes(self):
with backprop.GradientTape() as g:
x = constant_op.constant([[3.0]])
diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py
index 3287a15..1976026 100644
--- a/tensorflow/python/eager/benchmarks_test.py
+++ b/tensorflow/python/eager/benchmarks_test.py
@@ -1,4 +1,4 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1451,6 +1451,14 @@
self._run(fn, 100000)
+ def benchmark_tf_flatten_dict_items(self):
+ nested = {(4, 5, (6, 8)): ("a", "b", ("c", "d"))}
+
+ def fn():
+ nest.flatten_dict_items(nested)
+
+ self._run(fn, 100000)
+
def benchmark_tf_nn_convolution_overhead(self):
inputs = array_ops.ones((1, 1, 1, 1))
filters = array_ops.ones((1, 1, 1, 1))
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 5c8ee14..92e9c17 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -1796,6 +1796,18 @@
self.assertIn('CPU', has_device.v.device)
@test_util.run_in_graph_and_eager_modes
+ def testMultipleDeviceCheck(self):
+
+ def f():
+ with ops.device('cpu'):
+ return test_ops.device_placement_op()
+
+ func = function.defun(f)
+ with ops.device('cpu:0'):
+ output = self.evaluate(func())
+ self.assertIn(compat.as_bytes('CPU:0'), output)
+
+ @test_util.run_in_graph_and_eager_modes
def testDeviceAnnotationsRespected(self):
def multi_device_fn():
diff --git a/tensorflow/python/eager/remote_test.py b/tensorflow/python/eager/remote_test.py
index ebe70d9..2b23a15 100644
--- a/tensorflow/python/eager/remote_test.py
+++ b/tensorflow/python/eager/remote_test.py
@@ -38,6 +38,7 @@
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
+from tensorflow.python.framework import test_ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -45,9 +46,11 @@
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.training import server_lib
from tensorflow.python.training.server_lib import ClusterSpec
+from tensorflow.python.util import compat
class SingleWorkerTest(test.TestCase, parameterized.TestCase):
@@ -112,20 +115,6 @@
self.assertEqual(rets[1].backing_device,
'/job:worker/replica:0/task:0/device:CPU:0')
- def testMultiDeviceFunctionAmbiguousDevice(self):
-
- @def_function.function
- def ambiguous_device(i):
- with ops.device('cpu:0'):
- return i + constant_op.constant([2])
-
- with self.assertRaises(errors.InvalidArgumentError) as cm:
- with ops.device('/job:worker/replica:0/task:0/cpu:0'):
- ambiguous_device(constant_op.constant([2])).numpy()
-
- self.assertIn('the output node must match exactly one device',
- cm.exception.message)
-
def testStreaming(self):
"""A mini stress test for streaming - issuing many RPCs back to back."""
with ops.device('job:worker/replica:0/task:0/device:CPU:0'):
@@ -318,6 +307,21 @@
with ops.device('/job:worker/replica:0/task:1'):
self.assertAllEqual(local_func(x), [2, 1])
+ def testMultiDeviceFunctionAmbiguousDevice(self):
+
+ @def_function.function
+ def ambiguous_device(i):
+ with ops.device('/job:worker'):
+ # Multiple worker tasks, thus ambiguous device found error will be
+ # raised.
+ return i + constant_op.constant([2])
+
+ with self.assertRaises(errors.InvalidArgumentError) as cm:
+ ambiguous_device(constant_op.constant([2])).numpy()
+
+ self.assertIn('the output node must match exactly one device',
+ cm.exception.message)
+
# Note that the following tests for remote function cancellation only works
# when non-streaming RPC. We need to disable streaming explicitly and restore
# this config to its initial value at the end of each test case.
@@ -579,6 +583,32 @@
# Reset the context to avoid polluting other test cases.
context._reset_context()
+ def testMultipleDeviceFoundCheck(self):
+ remote.connect_to_cluster(self._cluster)
+
+ @def_function.function
+ def func():
+ with ops.device('cpu:0'):
+ # Multiple CPU:0 devices match would be found, but the CPU:0 from the
+ # parent device scope should be picked.
+ x = test_ops.device_placement_op()
+ y = string_ops.string_upper(x)
+ packed_var_0 = array_ops.stack([x, y], 0)
+ return packed_var_0
+
+ with ops.device('/job:my_worker/task:1'):
+ output = self.evaluate(func())
+ self.assertEqual(
+ compat.as_bytes('/job:my_worker/replica:0/task:1/device:CPU:0'),
+ output[0])
+ self.assertIn(compat.as_bytes('/JOB:MY_WORKER'), output[1])
+ with ops.device('/job:my_ps/task:1'):
+ output = self.evaluate(func())
+ self.assertEqual(
+ compat.as_bytes('/job:my_ps/replica:0/task:1/device:CPU:0'),
+ output[0])
+ self.assertIn(compat.as_bytes('/JOB:MY_PS'), output[1])
+
def testSimpleParameterServer(self):
remote.connect_to_cluster(self._cluster)
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 2b3e753..1bef551 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -462,7 +462,7 @@
@test_util.run_deprecated_v1
def testWhileLoopCallsFunc(self):
- with self.session(use_gpu=True) as sess:
+ with self.session():
@function.Defun(dtypes.float32)
def Times2(x):
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 893351b..e56064b 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -2289,7 +2289,7 @@
``` python
class MyOperatorTest(test_util.TensorFlowTestCase):
def testMyOperator(self):
- with self.session(use_gpu=True):
+ with self.session():
valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
result = MyOperator(valid_input).eval()
self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0]
@@ -2339,7 +2339,7 @@
```python
class MyOperatorTest(test_util.TensorFlowTestCase):
def testMyOperator(self):
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
result = MyOperator(valid_input).eval()
self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0]
diff --git a/tensorflow/python/keras/benchmarks/BUILD b/tensorflow/python/keras/benchmarks/BUILD
index dbdeb60..7486c7c 100644
--- a/tensorflow/python/keras/benchmarks/BUILD
+++ b/tensorflow/python/keras/benchmarks/BUILD
@@ -58,13 +58,11 @@
size = "large",
srcs = ["keras_cpu_benchmark_test.py"],
python_version = "PY3",
- tags = COMMON_TAGS + [
- "noguitar", # b/179813008
- ],
+ tags = COMMON_TAGS,
deps = [
":benchmark_util",
":profiler_lib",
- "//tensorflow:tensorflow_py",
+ "//tensorflow:tensorflow_py_no_contrib",
"//third_party/py/numpy",
],
)
@@ -79,7 +77,7 @@
],
deps = [
":profiler_lib",
- "//tensorflow:tensorflow_py",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/keras/utils:tf_inspect",
],
)
@@ -90,7 +88,7 @@
python_version = "PY3",
deps = [
":profiler_lib",
- "//tensorflow:tensorflow_py",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -100,7 +98,7 @@
srcs_version = "PY3",
deps = [
":distribution_util",
- "//tensorflow:tensorflow_py",
+ "//tensorflow:tensorflow_py_no_contrib",
"//third_party/py/numpy",
],
)
@@ -112,7 +110,7 @@
tags = COMMON_TAGS,
deps = [
":benchmark_util",
- "//tensorflow:tensorflow_py",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -124,7 +122,7 @@
deps = [
":benchmark_util",
":profiler_lib",
- "//tensorflow:tensorflow_py",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -136,7 +134,7 @@
deps = [
":benchmark_util",
":profiler_lib",
- "//tensorflow:tensorflow_py",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -148,7 +146,7 @@
deps = [
":benchmark_util",
":profiler_lib",
- "//tensorflow:tensorflow_py",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -160,7 +158,7 @@
deps = [
":benchmark_util",
":profiler_lib",
- "//tensorflow:tensorflow_py",
+ "//tensorflow:tensorflow_py_no_contrib",
"//third_party/py/numpy",
],
)
@@ -173,7 +171,7 @@
deps = [
":benchmark_util",
":profiler_lib",
- "//tensorflow:tensorflow_py",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -185,7 +183,7 @@
deps = [
":benchmark_util",
":profiler_lib",
- "//tensorflow:tensorflow_py",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -197,7 +195,7 @@
deps = [
":benchmark_util",
":profiler_lib",
- "//tensorflow:tensorflow_py",
+ "//tensorflow:tensorflow_py_no_contrib",
"//third_party/py/numpy",
],
)
@@ -210,7 +208,7 @@
deps = [
":benchmark_util",
":profiler_lib",
- "//tensorflow:tensorflow_py",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -223,7 +221,7 @@
":benchmark_util",
":distribution_util",
":profiler_lib",
- "//tensorflow:tensorflow_py",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -232,7 +230,7 @@
srcs = ["distribution_util.py"],
srcs_version = "PY3",
deps = [
- "//tensorflow:tensorflow_py",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -246,7 +244,7 @@
deps = [
":benchmark_util",
":profiler_lib",
- "//tensorflow:tensorflow_py",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/keras/optimizer_v2",
],
)
@@ -263,5 +261,5 @@
srcs = ["model_memory_profile.py"],
python_version = "PY3",
tags = ["no_oss"],
- deps = ["//tensorflow:tensorflow_py"],
+ deps = ["//tensorflow:tensorflow_py_no_contrib"],
)
diff --git a/tensorflow/python/keras/benchmarks/layer_benchmarks/BUILD b/tensorflow/python/keras/benchmarks/layer_benchmarks/BUILD
index 54d4474..f76eaa0 100644
--- a/tensorflow/python/keras/benchmarks/layer_benchmarks/BUILD
+++ b/tensorflow/python/keras/benchmarks/layer_benchmarks/BUILD
@@ -51,7 +51,7 @@
visibility = ["//tensorflow/python/keras:__subpackages__"],
deps = [
":run_xprof",
- "//tensorflow:tensorflow_py",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/keras/benchmarks:profiler_lib",
],
)
@@ -63,7 +63,7 @@
tags = BECHMARK_TAGS,
deps = [
":layer_benchmarks_test_base",
- "//tensorflow:tensorflow_py",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/keras/benchmarks:benchmark_util",
],
)
diff --git a/tensorflow/python/keras/benchmarks/saved_model_benchmarks/BUILD b/tensorflow/python/keras/benchmarks/saved_model_benchmarks/BUILD
index ade9d1d..a71d256 100644
--- a/tensorflow/python/keras/benchmarks/saved_model_benchmarks/BUILD
+++ b/tensorflow/python/keras/benchmarks/saved_model_benchmarks/BUILD
@@ -32,7 +32,7 @@
srcs = ["saved_model_benchmark_util.py"],
srcs_version = "PY3",
deps = [
- "//tensorflow:tensorflow_py",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/keras/benchmarks:profiler_lib",
],
)
@@ -46,7 +46,7 @@
],
deps = [
":saved_model_benchmark_util",
- "//tensorflow:tensorflow_py",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/keras/benchmarks:profiler_lib",
],
)
@@ -60,7 +60,7 @@
],
deps = [
":saved_model_benchmark_util",
- "//tensorflow:tensorflow_py",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/keras/benchmarks:profiler_lib",
],
)
@@ -74,7 +74,7 @@
],
deps = [
":saved_model_benchmark_util",
- "//tensorflow:tensorflow_py",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/keras/benchmarks:profiler_lib",
],
)
@@ -88,7 +88,7 @@
],
deps = [
":saved_model_benchmark_util",
- "//tensorflow:tensorflow_py",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/keras/benchmarks:profiler_lib",
],
)
@@ -102,7 +102,7 @@
],
deps = [
":saved_model_benchmark_util",
- "//tensorflow:tensorflow_py",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/keras/benchmarks:profiler_lib",
],
)
@@ -116,7 +116,7 @@
],
deps = [
":saved_model_benchmark_util",
- "//tensorflow:tensorflow_py",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/keras/benchmarks:profiler_lib",
],
)
@@ -130,7 +130,7 @@
],
deps = [
":saved_model_benchmark_util",
- "//tensorflow:tensorflow_py",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/keras/benchmarks:profiler_lib",
],
)
@@ -144,7 +144,7 @@
],
deps = [
":saved_model_benchmark_util",
- "//tensorflow:tensorflow_py",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/keras/benchmarks:profiler_lib",
],
)
diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD
index b68ff7e..92f8ea4 100644
--- a/tensorflow/python/keras/distribute/BUILD
+++ b/tensorflow/python/keras/distribute/BUILD
@@ -860,6 +860,9 @@
srcs = ["parameter_server_training_test.py"],
python_version = "PY3",
shard_count = 1,
+ tags = [
+ "no_tfrt", # TODO(b/180537361): Reenable TFRT after the issue is resolved.
+ ],
deps = [
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
@@ -878,6 +881,7 @@
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:test",
"//tensorflow/python/keras",
+ "//tensorflow/python/keras/utils:dataset_creator",
],
)
diff --git a/tensorflow/python/keras/distribute/parameter_server_training_test.py b/tensorflow/python/keras/distribute/parameter_server_training_test.py
index 4feb748..eedf808 100644
--- a/tensorflow/python/keras/distribute/parameter_server_training_test.py
+++ b/tensorflow/python/keras/distribute/parameter_server_training_test.py
@@ -21,7 +21,9 @@
import random
import tempfile
+from absl import logging
from absl.testing import parameterized
+import numpy as np
from tensorflow.python import keras
from tensorflow.python.compat import v2_compat
@@ -36,12 +38,18 @@
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec
+from tensorflow.python.keras import callbacks as callbacks_lib
+from tensorflow.python.keras.engine import sequential
+from tensorflow.python.keras.layers import core as core_layers
from tensorflow.python.keras.layers.preprocessing import string_lookup
+from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.keras.optimizer_v2 import rmsprop
+from tensorflow.python.keras.utils import dataset_creator
from tensorflow.python.keras.utils import losses_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
+from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
from tensorflow.python.training.server_lib import ClusterSpec
@@ -54,16 +62,19 @@
LABEL_VOCAB = ["yes", "no"]
-def make_coordinator(num_workers, num_ps):
+def make_cluster(num_workers, num_ps):
cluster_def = multi_worker_test_base.create_in_process_cluster(
num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
cluster_def["chief"] = [
"localhost:%d" % multi_worker_test_base.pick_unused_port()
]
- cluster_resolver = SimpleClusterResolver(
- ClusterSpec(cluster_def), rpc_layer="grpc")
+ return SimpleClusterResolver(ClusterSpec(cluster_def), rpc_layer="grpc")
+
+
+def make_coordinator(num_workers, num_ps):
return coordinator_lib.ClusterCoordinator(
- parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver))
+ parameter_server_strategy_v2.ParameterServerStrategyV2(
+ make_cluster(num_workers, num_ps)))
# TODO(yuefengz): move this to keras/integration_tests.
@@ -178,7 +189,7 @@
actual_pred = math_ops.cast(math_ops.greater(pred, 0.5), dtypes.int64)
accuracy.update_state(labels, actual_pred)
- self.coordinator._strategy.run(replica_fn, args=(iterator,))
+ self.coordinator.strategy.run(replica_fn, args=(iterator,))
distributed_dataset = self.coordinator.create_per_worker_dataset(dataset_fn)
distributed_iterator = iter(distributed_dataset)
@@ -230,6 +241,124 @@
self.assertIn(prediction1, ("yes", "no"))
+class ModelFitTest(test.TestCase, parameterized.TestCase):
+
+ def _model_compile(self, steps_per_execution=1, run_eagerly=False):
+
+ class ResultAssertingCallback(callbacks_lib.Callback):
+
+ def __init__(self):
+ self._prev_epoch = -1
+
+ def on_epoch_end(self, epoch, logs=None):
+ logging.info("testModelFit: epoch=%r, logs=%r", epoch, logs)
+ if epoch <= self._prev_epoch:
+ raise RuntimeError("Epoch is supposed to be larger than previous.")
+ self._prev_epoch = epoch
+ if (logs.get("loss", None) is None or
+ not isinstance(logs["loss"], np.floating)):
+ raise RuntimeError("loss is supposed to be in the logs and float.")
+
+ strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
+ make_cluster(3, 2))
+ with strategy.scope():
+ model = sequential.Sequential([core_layers.Dense(10)])
+ model.compile(
+ gradient_descent.SGD(),
+ loss="mse",
+ steps_per_execution=steps_per_execution,
+ run_eagerly=run_eagerly)
+ return model, [ResultAssertingCallback()]
+
+ def _model_fit(self,
+ steps_per_execution=1,
+ validation_data=None,
+ x=None,
+ steps_per_epoch=10,
+ run_eagerly=False):
+ model, callbacks = self._model_compile(steps_per_execution, run_eagerly)
+
+ def dataset_fn(input_context):
+ del input_context
+ x = random_ops.random_uniform((10, 10))
+ y = random_ops.random_uniform((10,))
+ return dataset_ops.DatasetV2.from_tensor_slices(
+ (x, y)).shuffle(10).repeat().batch(2)
+
+ x = x or dataset_creator.DatasetCreator(dataset_fn)
+
+ model.fit(
+ x,
+ epochs=10,
+ steps_per_epoch=steps_per_epoch,
+ verbose=0,
+ callbacks=callbacks,
+ validation_data=validation_data)
+ return model
+
+ @combinations.generate(combinations.combine(mode=["eager"]))
+ def testModelFit(self):
+ model = self._model_fit()
+ self.assertEqual(model.optimizer.iterations, 100)
+
+ @combinations.generate(combinations.combine(mode=["eager"]))
+ def testModelFitWithStepsPerExecution(self):
+ model = self._model_fit(steps_per_execution=10)
+ self.assertEqual(model.optimizer.iterations, 100)
+
+ @combinations.generate(combinations.combine(mode=["eager"]))
+ def testModelFitWithNoStepsPerEpoch(self):
+ with self.assertRaisesRegex(
+ ValueError, "`steps_per_epoch` must be specified with "
+ "`ParameterServerStrategy`."):
+ self._model_fit(steps_per_epoch=None)
+
+ @combinations.generate(combinations.combine(mode=["eager"]))
+ def testModelFitWithRunEagerly(self):
+ with self.assertRaisesRegex(
+ ValueError, "When using `Model` with `ParameterServerStrategy`, "
+ "`run_eagerly` is not supported."):
+ self._model_fit(run_eagerly=True)
+
+ @combinations.generate(combinations.combine(mode=["eager"]))
+ def testModelFitWithValidationData(self):
+ with self.assertRaisesRegex(
+ NotImplementedError, "Evaluation in `model.fit` with "
+ "`ParameterServerStrategy` is not yet supported."):
+ self._model_fit(
+ validation_data=dataset_ops.DatasetV2.from_tensor_slices([1, 1]))
+
+ @combinations.generate(combinations.combine(mode=["eager"]))
+ def testModelFitWithDatasetInstance(self):
+ with self.assertRaisesRegex(
+ NotImplementedError, "Only `DatasetCreator` input is supported in "
+ "`ParameterServerStrategy` at this time."):
+ self._model_fit(x=dataset_ops.DatasetV2.from_tensor_slices([1, 1]))
+
+ @combinations.generate(combinations.combine(mode=["eager"]))
+ def testModelEvaluate(self):
+ model, _ = self._model_compile()
+ with self.assertRaisesRegex(
+ NotImplementedError, "`model.evaluate` is not yet supported with "
+ "`ParameterServerStrategy`."):
+ model.evaluate(x=dataset_ops.DatasetV2.from_tensor_slices([1, 1]))
+
+ @combinations.generate(combinations.combine(mode=["eager"]))
+ def testModelPredict(self):
+ model, _ = self._model_compile()
+ with self.assertRaisesRegex(
+ NotImplementedError, "`model.predict` is not yet supported with "
+ "`ParameterServerStrategy`."):
+ model.predict(x=dataset_ops.DatasetV2.from_tensor_slices([1, 1]))
+
+ @combinations.generate(combinations.combine(mode=["eager"]))
+ def testClusterCoordinatorSingleInstance(self):
+ model = self._model_fit()
+ strategy = model.distribute_strategy
+ self.assertIs(strategy._cluster_coordinator,
+ coordinator_lib.ClusterCoordinator(strategy))
+
+
if __name__ == "__main__":
v2_compat.enable_v2_behavior()
test.main()
diff --git a/tensorflow/python/keras/engine/BUILD b/tensorflow/python/keras/engine/BUILD
index 7937382..ade3cfa 100644
--- a/tensorflow/python/keras/engine/BUILD
+++ b/tensorflow/python/keras/engine/BUILD
@@ -58,6 +58,7 @@
"//tensorflow/python/distribute:parameter_server_strategy",
"//tensorflow/python/distribute:parameter_server_strategy_v2",
"//tensorflow/python/distribute:reduce_util",
+ "//tensorflow/python/distribute/coordinator:cluster_coordinator",
"//tensorflow/python/eager:monitoring",
"//tensorflow/python/keras:activations",
"//tensorflow/python/keras:backend",
@@ -184,6 +185,7 @@
"//tensorflow/python:framework_ops",
"//tensorflow/python:util",
"//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/keras/utils:dataset_creator",
"//tensorflow/python/keras/utils:engine_utils",
"//tensorflow/python/keras/utils:tf_utils",
],
diff --git a/tensorflow/python/keras/engine/base_layer_utils.py b/tensorflow/python/keras/engine/base_layer_utils.py
index a098cdb..c9c0499 100644
--- a/tensorflow/python/keras/engine/base_layer_utils.py
+++ b/tensorflow/python/keras/engine/base_layer_utils.py
@@ -802,35 +802,46 @@
# TODO(b/141682913): Figure out why this is private and fix it.
saveables = trackable._gather_saveables_for_checkpoint().values() # pylint: disable=protected-access
- if len(saveables) != 1:
- raise ValueError('Only Trackables with one Saveable are supported.')
- saveable = list(saveables)[0]
+ # 'Saveables' won't exist when we're passed a legacy TF1 table like
+ # a StaticHashTable.
+ if not saveables:
+ self._num_tensors = 0
+ self._setter = lambda weights: None
+ self._getter = lambda: []
- if ops.executing_eagerly_outside_functions():
- # If we're in eager mode, we need to defer calling the Trackable's
- # saveable() callable until data export time.
- # However, it is safe to call the saveable as many times as we want, so
- # we will call it now to figure out how many tensors this Trackable will
- # produce.
- self._saveable = saveable
- self._num_tensors = len(self._saveable().specs)
- self._setter = lambda weights: self._saveable().restore(weights, None)
- self._getter = lambda: [spec.tensor for spec in self._saveable().specs]
+ elif len(saveables) == 1:
+ saveable = list(saveables)[0]
+
+ if ops.executing_eagerly_outside_functions():
+ # If we're in eager mode, we need to defer calling the Trackable's
+ # saveable() callable until data export time.
+ # However, it is safe to call the saveable as many times as we want, so
+ # we will call it now to figure out how many tensors this Trackable will
+ # produce.
+ self._saveable = saveable
+ self._num_tensors = len(self._saveable().specs)
+ self._setter = lambda weights: self._saveable().restore(weights, None)
+ self._getter = lambda: [spec.tensor for spec in self._saveable().specs]
+ else:
+ # If we're in Graph mode, we need to evaluate the Saveable only once and
+ # cache the resulting restore graph. Failing to do this will result in
+ # new assignment ops being added to the graph each time set_weights() is
+ # called.
+ self._placeholder_tensors = []
+ self._saveable = saveable()
+ self._num_tensors = len(self._saveable.specs)
+ for spec in self._saveable.specs:
+ tensor = spec.tensor
+ self._placeholder_tensors.append(
+ array_ops.placeholder(tensor.dtype, tensor.shape))
+ self._assign_op = self._saveable.restore(self._placeholder_tensors,
+ None)
+ self._setter = self._set_weights_v1
+ self._getter = lambda: [spec.tensor for spec in self._saveable.specs]
else:
- # If we're in Graph mode, we need to evaluate the Saveable only once and
- # cache the resulting restore graph. Failing to do this will result in
- # new assignment ops being added to the graph each time set_weights() is
- # called.
- self._placeholder_tensors = []
- self._saveable = saveable()
- self._num_tensors = len(self._saveable.specs)
- for spec in self._saveable.specs:
- tensor = spec.tensor
- self._placeholder_tensors.append(
- array_ops.placeholder(tensor.dtype, tensor.shape))
- self._assign_op = self._saveable.restore(self._placeholder_tensors, None)
- self._setter = self._set_weights_v1
- self._getter = lambda: [spec.tensor for spec in self._saveable.specs]
+ raise ValueError('Only Trackables with one Saveable are supported. '
+ 'The Trackable %s has %d Saveables.' %
+ (trackable, len(saveables)))
@property
def num_tensors(self):
diff --git a/tensorflow/python/keras/engine/data_adapter.py b/tensorflow/python/keras/engine/data_adapter.py
index 2eccee8..eef4d1f 100644
--- a/tensorflow/python/keras/engine/data_adapter.py
+++ b/tensorflow/python/keras/engine/data_adapter.py
@@ -45,6 +45,7 @@
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.utils import data_utils
+from tensorflow.python.keras.utils import dataset_creator
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
@@ -514,6 +515,40 @@
return dataset
+class DatasetCreatorAdapter(DataAdapter):
+ """Adapter that handles dataset functions."""
+
+ def __init__(self, *args, **kwargs):
+ super(DatasetCreatorAdapter, self).__init__(*args, **kwargs)
+
+ @staticmethod
+ def can_handle(x, y=None):
+ if isinstance(x, dataset_creator.DatasetCreator):
+ assert y is None
+ return True
+
+ def should_recreate_iterator(self):
+ # We expect users to shuffle the dataset in their `dataset_fn` supplied to
+ # `DatasetCreator`. Since that is a buffered shuffle, we intend to not reset
+ # the dataset so the batches that are not shuffled can still be pulled.
+ return False
+
+ def get_size(self):
+ raise NotImplementedError()
+
+ def get_dataset(self):
+ raise NotImplementedError()
+
+ def batch_size(self):
+ raise NotImplementedError()
+
+ def has_partial_batch(self):
+ raise NotImplementedError()
+
+ def partial_batch_size(self):
+ raise NotImplementedError()
+
+
class CompositeTensorDataAdapter(DataAdapter):
"""Adapter that handles composite tensor."""
@@ -948,8 +983,8 @@
ALL_ADAPTER_CLS = [
ListsOfScalarsDataAdapter, TensorLikeDataAdapter,
- GenericArrayLikeDataAdapter, DatasetAdapter,
- GeneratorDataAdapter, KerasSequenceAdapter, CompositeTensorDataAdapter,
+ GenericArrayLikeDataAdapter, DatasetAdapter, GeneratorDataAdapter,
+ KerasSequenceAdapter, CompositeTensorDataAdapter, DatasetCreatorAdapter
]
@@ -1120,6 +1155,7 @@
self._steps_per_execution_value = steps_per_execution.numpy().item()
adapter_cls = select_data_adapter(x, y)
+ self._verify_data_adapter_compatibility(adapter_cls)
self._adapter = adapter_cls(
x,
y,
@@ -1135,6 +1171,23 @@
model=model)
strategy = ds_context.get_strategy()
+
+ self._current_step = 0
+ self._step_increment = self._steps_per_execution_value - 1
+ self._insufficient_data = False
+
+ self._configure_dataset_and_inferred_steps(strategy, x, steps_per_epoch,
+ class_weight, distribute)
+
+ def _verify_data_adapter_compatibility(self, adapter_cls):
+ if adapter_cls == DatasetCreatorAdapter:
+ raise NotImplementedError("`DatasetCreator` input is only supported in "
+ "`ParameterServerStrategy` at this time.")
+
+ def _configure_dataset_and_inferred_steps(self, strategy, x, steps_per_epoch,
+ class_weight, distribute):
+ """Configure the `_dataset` and `_inferred_steps` attributes."""
+ del x
dataset = self._adapter.get_dataset()
if class_weight:
dataset = dataset.map(_make_class_weight_map_fn(class_weight))
@@ -1145,11 +1198,6 @@
if distribute and not _is_distributed_dataset(dataset):
dataset = strategy.experimental_distribute_dataset(dataset)
self._dataset = dataset
-
- self._current_step = 0
- self._step_increment = self._steps_per_execution_value - 1
- self._insufficient_data = False
-
self._validate_data_handler()
def enumerate_epochs(self):
@@ -1181,12 +1229,15 @@
self._steps_per_execution.assign(original_value)
self._steps_per_execution_value = original_value
+ def sync(self):
+ context.async_wait()
+
@contextlib.contextmanager
def catch_stop_iteration(self):
"""Catches errors when an iterator runs out of data."""
try:
yield
- context.async_wait()
+ self.sync()
except (StopIteration, errors.OutOfRangeError):
if self._inferred_steps is None:
self._inferred_steps = self._current_step
@@ -1285,6 +1336,46 @@
"`steps_per_execution > 1`, you must specify the number of steps "
"to run.")
+ def resolve_logs(self, logs):
+ return logs
+
+
+class _ClusterCoordinatorDataHandler(DataHandler):
+ """A `DataHandler` that is compatible with `ClusterCoordinator`."""
+
+ def _verify_data_adapter_compatibility(self, adapter_cls):
+ if adapter_cls != DatasetCreatorAdapter:
+ raise NotImplementedError("Only `DatasetCreator` input is supported in "
+ "`ParameterServerStrategy` at this time.")
+
+ def _configure_dataset_and_inferred_steps(self, strategy, x, steps_per_epoch,
+ class_weight, distribute):
+ if not isinstance(x, dataset_creator.DatasetCreator):
+ raise TypeError("When using `ParameterServerStrategy`, `x` must be a "
+ "`DatasetCreator`.")
+
+ def per_worker_dataset_fn():
+ return strategy.distribute_datasets_from_function(x)
+
+ self._dataset = self._model._cluster_coordinator.create_per_worker_dataset( # pylint: disable=protected-access
+ per_worker_dataset_fn)
+ if steps_per_epoch is None:
+ raise ValueError(
+ "`steps_per_epoch` must be specified with `ParameterServerStrategy`.")
+ self._inferred_steps = steps_per_epoch
+
+ def sync(self):
+ self._model._cluster_coordinator.join() # pylint: disable=protected-access
+
+ def resolve_logs(self, logs):
+ return logs.fetch()
+
+
+def get_data_handler(*args, **kwargs):
+ if getattr(kwargs["model"], "_cluster_coordinator", None):
+ return _ClusterCoordinatorDataHandler(*args, **kwargs)
+ return DataHandler(*args, **kwargs)
+
def _make_class_weight_map_fn(class_weight):
"""Applies class weighting to a `Dataset`.
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 9d1bbaa..c5f9b79 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -32,6 +32,7 @@
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import values as ds_values
+from tensorflow.python.distribute.coordinator import cluster_coordinator
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
@@ -305,6 +306,9 @@
self._distribution_strategy = ds_context.get_strategy()
else:
self._distribution_strategy = None
+
+ self._cluster_coordinator = None
+
# Defaults to value of `tf.config.experimental_functions_run_eagerly`.
self._run_eagerly = None
# Initialize cache attrs.
@@ -743,6 +747,10 @@
'constructed with `dynamic=True`). '
'You cannot set `run_eagerly=False`.')
+ if self._cluster_coordinator and self._run_eagerly:
+ raise ValueError('When using `Model` with `ParameterServerStrategy`, '
+ '`run_eagerly` is not supported.')
+
# Run eagerly logic, by priority:
# (1) Dynamic models must be run eagerly.
# (2) Explicitly setting run_eagerly causes a Model to be run eagerly.
@@ -851,6 +859,11 @@
train_function, experimental_relax_shapes=True)
self.train_function = train_function
+
+ if self._cluster_coordinator:
+ self.train_function = lambda iterator: self._cluster_coordinator.schedule( # pylint: disable=g-long-lambda
+ train_function, args=(iterator,))
+
return self.train_function
def fit(self,
@@ -1079,10 +1092,14 @@
val_x, val_y, val_sample_weight = (
data_adapter.unpack_x_y_sample_weight(validation_data))
+ if self.distribute_strategy._should_use_with_coordinator: # pylint: disable=protected-access
+ self._cluster_coordinator = cluster_coordinator.ClusterCoordinator(
+ self.distribute_strategy)
+
with self.distribute_strategy.scope(), \
training_utils.RespectCompiledTrainableState(self):
# Creates a `tf.data.Dataset` and handles batch and epoch iteration.
- data_handler = data_adapter.DataHandler(
+ data_handler = data_adapter.get_data_handler(
x=x,
y=y,
sample_weight=sample_weight,
@@ -1141,6 +1158,7 @@
if self.stop_training:
break
+ logs = data_handler.resolve_logs(logs)
if logs is None:
raise ValueError('Expect x to be a non-empty array or dataset.')
epoch_logs = copy.copy(logs)
@@ -1150,7 +1168,7 @@
# Create data_handler for evaluation and cache it.
if getattr(self, '_eval_data_handler', None) is None:
self._fit_frame = tf_inspect.currentframe()
- self._eval_data_handler = data_adapter.DataHandler(
+ self._eval_data_handler = data_adapter.get_data_handler(
x=val_x,
y=val_y,
sample_weight=val_sample_weight,
@@ -1378,6 +1396,10 @@
self._check_call_args('evaluate')
_disallow_inside_tf_function('evaluate')
+ if self.distribute_strategy._should_use_with_coordinator: # pylint: disable=protected-access
+ raise NotImplementedError('`model.evaluate` is not yet supported with '
+ '`ParameterServerStrategy`.')
+
with self.distribute_strategy.scope():
# Use cached evaluation data only when it's called in `Model.fit`
if (getattr(self, '_fit_frame', None) is not None
@@ -1386,7 +1408,7 @@
data_handler = self._eval_data_handler
else:
# Creates a `tf.data.Dataset` and handles batch and epoch iteration.
- data_handler = data_adapter.DataHandler(
+ data_handler = data_adapter.get_data_handler(
x=x,
y=y,
sample_weight=sample_weight,
@@ -1613,6 +1635,10 @@
self._check_call_args('predict')
_disallow_inside_tf_function('predict')
+ if self.distribute_strategy._should_use_with_coordinator: # pylint: disable=protected-access
+ raise NotImplementedError('`model.predict` is not yet supported with '
+ '`ParameterServerStrategy`.')
+
outputs = None
with self.distribute_strategy.scope():
# Creates a `tf.data.Dataset` and handles batch and epoch iteration.
@@ -1630,7 +1656,7 @@
'AutoShardPolicy.FILE might lead to out-of-order result'
'. Consider setting it to AutoShardPolicy.DATA.')
- data_handler = data_adapter.DataHandler(
+ data_handler = data_adapter.get_data_handler(
x=x,
batch_size=batch_size,
steps_per_epoch=steps,
@@ -2648,6 +2674,10 @@
return functions
def _should_eval(self, epoch, validation_freq):
+ if self._cluster_coordinator:
+ raise NotImplementedError(
+ 'Evaluation in `model.fit` with '
+ '`ParameterServerStrategy` is not yet supported.')
epoch = epoch + 1 # one-index the user-facing epoch.
if isinstance(validation_freq, int):
return epoch % validation_freq == 0
diff --git a/tensorflow/python/keras/integration_test/BUILD b/tensorflow/python/keras/integration_test/BUILD
index 9d3cd16..2df01f3 100644
--- a/tensorflow/python/keras/integration_test/BUILD
+++ b/tensorflow/python/keras/integration_test/BUILD
@@ -17,7 +17,7 @@
srcs = ["forwardprop_test.py"],
python_version = "PY3",
deps = [
- "//tensorflow:tensorflow_py",
+ "//tensorflow:tensorflow_py_no_contrib",
"@absl_py//absl/testing:parameterized",
],
)
@@ -27,7 +27,7 @@
srcs = ["function_test.py"],
python_version = "PY3",
deps = [
- "//tensorflow:tensorflow_py",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -36,7 +36,7 @@
srcs = ["gradients_test.py"],
python_version = "PY3",
deps = [
- "//tensorflow:tensorflow_py",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -45,7 +45,7 @@
srcs = ["saved_model_test.py"],
python_version = "PY3",
deps = [
- "//tensorflow:tensorflow_py",
+ "//tensorflow:tensorflow_py_no_contrib",
"@absl_py//absl/testing:parameterized",
],
)
@@ -55,7 +55,7 @@
srcs = ["legacy_rnn_test.py"],
python_version = "PY3",
deps = [
- "//tensorflow:tensorflow_py",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -63,7 +63,7 @@
name = "module_test",
srcs = ["module_test.py"],
deps = [
- "//tensorflow:tensorflow_py",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -72,7 +72,7 @@
srcs = ["vectorized_map_test.py"],
python_version = "PY3",
deps = [
- "//tensorflow:tensorflow_py",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -81,7 +81,7 @@
srcs = ["gradient_checkpoint_test.py"],
python_version = "PY3",
deps = [
- "//tensorflow:tensorflow_py",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -93,7 +93,7 @@
python_version = "PY3",
tags = ["no_oss"],
deps = [
- "//tensorflow:tensorflow_py",
+ "//tensorflow:tensorflow_py_no_contrib",
"@absl_py//absl/testing:parameterized",
],
)
@@ -110,7 +110,7 @@
"notsan", # TODO(b/156029134)
],
deps = [
- "//tensorflow:tensorflow_py",
+ "//tensorflow:tensorflow_py_no_contrib",
"@absl_py//absl/testing:parameterized",
],
)
diff --git a/tensorflow/python/keras/keras_parameterized.py b/tensorflow/python/keras/keras_parameterized.py
index bc153dc..a5392e3 100644
--- a/tensorflow/python/keras/keras_parameterized.py
+++ b/tensorflow/python/keras/keras_parameterized.py
@@ -420,7 +420,7 @@
def _v1_session_test(f, test_or_class, config, *args, **kwargs):
with ops.get_default_graph().as_default():
with testing_utils.run_eagerly_scope(False):
- with test_or_class.test_session(use_gpu=True, config=config):
+ with test_or_class.test_session(config=config):
f(test_or_class, *args, **kwargs)
diff --git a/tensorflow/python/keras/layers/__init__.py b/tensorflow/python/keras/layers/__init__.py
index b07773a..a8bc647 100644
--- a/tensorflow/python/keras/layers/__init__.py
+++ b/tensorflow/python/keras/layers/__init__.py
@@ -44,9 +44,6 @@
# Preprocessing layers.
if tf2.enabled():
- from tensorflow.python.keras.layers.preprocessing.category_encoding import CategoryEncoding
- from tensorflow.python.keras.layers.preprocessing.category_encoding_v1 import CategoryEncoding as CategoryEncodingV1
- CategoryEncodingV2 = CategoryEncoding
from tensorflow.python.keras.layers.preprocessing.integer_lookup import IntegerLookup
from tensorflow.python.keras.layers.preprocessing.integer_lookup_v1 import IntegerLookup as IntegerLookupV1
IntegerLookupV2 = IntegerLookup
@@ -63,9 +60,6 @@
from tensorflow.python.keras.layers.preprocessing.integer_lookup_v1 import IntegerLookup
from tensorflow.python.keras.layers.preprocessing.integer_lookup import IntegerLookup as IntegerLookupV2
IntegerLookupV1 = IntegerLookup
- from tensorflow.python.keras.layers.preprocessing.category_encoding_v1 import CategoryEncoding
- from tensorflow.python.keras.layers.preprocessing.category_encoding import CategoryEncoding as CategoryEncodingV2
- CategoryEncodingV1 = CategoryEncoding
from tensorflow.python.keras.layers.preprocessing.normalization_v1 import Normalization
from tensorflow.python.keras.layers.preprocessing.normalization import Normalization as NormalizationV2
NormalizationV1 = Normalization
@@ -76,6 +70,7 @@
from tensorflow.python.keras.layers.preprocessing.text_vectorization import TextVectorization as TextVectorizationV2
TextVectorizationV1 = TextVectorization
from tensorflow.python.keras.layers.preprocessing.category_crossing import CategoryCrossing
+from tensorflow.python.keras.layers.preprocessing.category_encoding import CategoryEncoding
from tensorflow.python.keras.layers.preprocessing.discretization import Discretization
from tensorflow.python.keras.layers.preprocessing.hashing import Hashing
diff --git a/tensorflow/python/keras/layers/convolutional.py b/tensorflow/python/keras/layers/convolutional.py
index 99edff6..bcfb8b0 100644
--- a/tensorflow/python/keras/layers/convolutional.py
+++ b/tensorflow/python/keras/layers/convolutional.py
@@ -93,9 +93,10 @@
activation: Activation function to use.
If you don't specify anything, no activation is applied.
use_bias: Boolean, whether the layer uses a bias.
- kernel_initializer: An initializer for the convolution kernel.
+ kernel_initializer: An initializer for the convolution kernel. If None, the
+ default initializer (glorot_uniform) will be used.
bias_initializer: An initializer for the bias vector. If None, the default
- initializer will be used.
+ initializer (zeros) will be used.
kernel_regularizer: Optional regularizer for the convolution kernel.
bias_regularizer: Optional regularizer for the bias vector.
activity_regularizer: Optional regularizer function for the output.
@@ -450,9 +451,9 @@
see `keras.activations`).
use_bias: Boolean, whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix (
- see `keras.initializers`).
+ see `keras.initializers`). Defaults to 'glorot_uniform'.
bias_initializer: Initializer for the bias vector (
- see `keras.initializers`).
+ see `keras.initializers`). Defaults to 'zeros'.
kernel_regularizer: Regularizer function applied to
the `kernel` weights matrix (see `keras.regularizers`).
bias_regularizer: Regularizer function applied to the bias vector (
@@ -606,13 +607,13 @@
activation is applied (see `keras.activations`).
use_bias: Boolean, whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix (see
- `keras.initializers`).
+ `keras.initializers`). Defaults to 'glorot_uniform'.
bias_initializer: Initializer for the bias vector (see
- `keras.initializers`).
+ `keras.initializers`). Defaults to 'zeros'.
kernel_regularizer: Regularizer function applied to the `kernel` weights
- matrix (see `keras.regularizers`).
+ matrix (see `keras.regularizers`).
bias_regularizer: Regularizer function applied to the bias vector (see
- `keras.regularizers`).
+ `keras.regularizers`).
activity_regularizer: Regularizer function applied to the output of the
layer (its "activation") (see `keras.regularizers`).
kernel_constraint: Constraint function applied to the kernel matrix (see
@@ -751,9 +752,9 @@
activation is applied (see `keras.activations`).
use_bias: Boolean, whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix (see
- `keras.initializers`).
+ `keras.initializers`). Defaults to 'glorot_uniform'.
bias_initializer: Initializer for the bias vector (see
- `keras.initializers`).
+ `keras.initializers`). Defaults to 'zeros'.
kernel_regularizer: Regularizer function applied to the `kernel` weights
matrix (see `keras.regularizers`).
bias_regularizer: Regularizer function applied to the bias vector (see
@@ -872,9 +873,9 @@
see `keras.activations`).
use_bias: Boolean, whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix (
- see `keras.initializers`).
+ see `keras.initializers`). Defaults to 'glorot_uniform'.
bias_initializer: Initializer for the bias vector (
- see `keras.initializers`).
+ see `keras.initializers`). Defaults to 'zeros'.
kernel_regularizer: Regularizer function applied to
the `kernel` weights matrix (see `keras.regularizers`).
bias_regularizer: Regularizer function applied to the bias vector (
@@ -1136,9 +1137,9 @@
see `keras.activations`).
use_bias: Boolean, whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix (
- see `keras.initializers`).
+ see `keras.initializers`). Defaults to 'glorot_uniform'.
bias_initializer: Initializer for the bias vector (
- see `keras.initializers`).
+ see `keras.initializers`). Defaults to 'zeros'.
kernel_regularizer: Regularizer function applied to
the `kernel` weights matrix (see `keras.regularizers`).
bias_regularizer: Regularizer function applied to the bias vector (
@@ -1439,8 +1440,10 @@
If you don't specify anything, no activation is applied (
see `keras.activations`).
use_bias: Boolean, whether the layer uses a bias vector.
- kernel_initializer: Initializer for the `kernel` weights matrix.
- bias_initializer: Initializer for the bias vector.
+ kernel_initializer: Initializer for the `kernel` weights matrix (
+ see `keras.initializers`). Defaults to 'glorot_uniform'.
+ bias_initializer: Initializer for the bias vector (
+ see `keras.initializers`). Defaults to 'zeros'.
kernel_regularizer: Regularizer function applied to
the `kernel` weights matrix (
see `keras.regularizers`).
@@ -1729,10 +1732,14 @@
If you don't specify anything, no activation is applied (
see `keras.activations`).
use_bias: Boolean, whether the layer uses a bias.
- depthwise_initializer: An initializer for the depthwise convolution kernel.
- pointwise_initializer: An initializer for the pointwise convolution kernel.
+ depthwise_initializer: An initializer for the depthwise convolution kernel (
+ see `keras.initializers`). If None, then the default initializer (
+ 'glorot_uniform') will be used.
+ pointwise_initializer: An initializer for the pointwise convolution kernel (
+ see `keras.initializers`). If None, then the default initializer
+ ('glorot_uniform') will be used.
bias_initializer: An initializer for the bias vector. If None, the default
- initializer will be used.
+ initializer ('zeros') will be used (see `keras.initializers`).
depthwise_regularizer: Optional regularizer for the depthwise
convolution kernel.
pointwise_regularizer: Optional regularizer for the pointwise
@@ -1935,11 +1942,13 @@
see `keras.activations`).
use_bias: Boolean, whether the layer uses a bias.
depthwise_initializer: An initializer for the depthwise convolution kernel (
- see `keras.initializers`).
+ see `keras.initializers`). If None, then the default initializer (
+ 'glorot_uniform') will be used.
pointwise_initializer: An initializer for the pointwise convolution kernel (
- see `keras.initializers`).
+ see `keras.initializers`). If None, then the default initializer
+ ('glorot_uniform') will be used.
bias_initializer: An initializer for the bias vector. If None, the default
- initializer will be used (see `keras.initializers`).
+ initializer ('zeros') will be used (see `keras.initializers`).
depthwise_regularizer: Optional regularizer for the depthwise
convolution kernel (see `keras.regularizers`).
pointwise_regularizer: Optional regularizer for the pointwise
@@ -2127,12 +2136,14 @@
If you don't specify anything, no activation is applied (
see `keras.activations`).
use_bias: Boolean, whether the layer uses a bias vector.
- depthwise_initializer: Initializer for the depthwise kernel matrix (
- see `keras.initializers`).
- pointwise_initializer: Initializer for the pointwise kernel matrix (
- see `keras.initializers`).
- bias_initializer: Initializer for the bias vector (
- see `keras.initializers`).
+ depthwise_initializer: An initializer for the depthwise convolution kernel (
+ see `keras.initializers`). If None, then the default initializer (
+ 'glorot_uniform') will be used.
+ pointwise_initializer: An initializer for the pointwise convolution kernel (
+ see `keras.initializers`). If None, then the default initializer
+ ('glorot_uniform') will be used.
+ bias_initializer: An initializer for the bias vector. If None, the default
+ initializer ('zeros') will be used (see `keras.initializers`).
depthwise_regularizer: Regularizer function applied to
the depthwise kernel matrix (see `keras.regularizers`).
pointwise_regularizer: Regularizer function applied to
@@ -2291,9 +2302,11 @@
see `keras.activations`).
use_bias: Boolean, whether the layer uses a bias vector.
depthwise_initializer: Initializer for the depthwise kernel matrix (
- see `keras.initializers`).
+ see `keras.initializers`). If None, the default initializer (
+ 'glorot_uniform') will be used.
bias_initializer: Initializer for the bias vector (
- see `keras.initializers`).
+ see `keras.initializers`). If None, the default initializer (
+ 'zeros') will bs used.
depthwise_regularizer: Regularizer function applied to
the depthwise kernel matrix (see `keras.regularizers`).
bias_regularizer: Regularizer function applied to the bias vector (
diff --git a/tensorflow/python/keras/layers/convolutional_test.py b/tensorflow/python/keras/layers/convolutional_test.py
index 3c09963..0496f51 100644
--- a/tensorflow/python/keras/layers/convolutional_test.py
+++ b/tensorflow/python/keras/layers/convolutional_test.py
@@ -42,7 +42,7 @@
stack_size = 3
length = 7
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
testing_utils.layer_test(
keras.layers.Conv1D,
kwargs=kwargs,
@@ -54,7 +54,7 @@
stack_size = 3
length = 7
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
if expected_output_shape is not None:
expected_output_shape = (None,) + expected_output_shape
@@ -112,7 +112,7 @@
'activity_regularizer': 'l2',
'strides': 1
}
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
layer = keras.layers.Conv1D(**kwargs)
layer.build((None, 5, 2))
self.assertEqual(len(layer.losses), 2)
@@ -131,14 +131,14 @@
'bias_constraint': b_constraint,
'strides': 1
}
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
layer = keras.layers.Conv1D(**kwargs)
layer.build((None, 5, 2))
self.assertEqual(layer.kernel.constraint, k_constraint)
self.assertEqual(layer.bias.constraint, b_constraint)
def test_conv1d_recreate_conv(self):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
layer = keras.layers.Conv1D(filters=1,
kernel_size=3,
strides=1,
@@ -151,7 +151,7 @@
self.assertEqual(outp1_shape, layer(inpt1).shape)
def test_conv1d_recreate_conv_unknown_dims(self):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
layer = keras.layers.Conv1D(filters=1,
kernel_size=3,
strides=1,
@@ -184,7 +184,7 @@
input_data_shape = (num_samples, num_row or 7, num_col or 6, stack_size)
input_data = 10 * np.random.random(input_data_shape).astype(np.float32)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
testing_utils.layer_test(
keras.layers.Conv2D,
kwargs=kwargs,
@@ -205,7 +205,7 @@
input_data_shape = batch_shape + (num_row or 7, num_col or 6, stack_size)
input_data = 10 * np.random.random(input_data_shape).astype(np.float32)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
if expected_output_shape is not None:
expected_output_shape = (None,) + expected_output_shape
testing_utils.layer_test(
@@ -272,7 +272,7 @@
'activity_regularizer': 'l2',
'strides': 1
}
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
layer = keras.layers.Conv2D(**kwargs)
layer.build((None, 5, 5, 2))
self.assertEqual(len(layer.losses), 2)
@@ -291,7 +291,7 @@
'bias_constraint': b_constraint,
'strides': 1
}
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
layer = keras.layers.Conv2D(**kwargs)
layer.build((None, 5, 5, 2))
self.assertEqual(layer.kernel.constraint, k_constraint)
@@ -313,7 +313,7 @@
num_col = 6
depth = 5
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
testing_utils.layer_test(
keras.layers.Conv3D,
kwargs=kwargs,
@@ -331,7 +331,7 @@
num_col = 6
depth = 5
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
if expected_output_shape is not None:
expected_output_shape = (None,) + expected_output_shape
@@ -387,7 +387,7 @@
'activity_regularizer': 'l2',
'strides': 1
}
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
layer = keras.layers.Conv3D(**kwargs)
layer.build((None, 5, 5, 5, 2))
self.assertEqual(len(layer.losses), 2)
@@ -407,7 +407,7 @@
'bias_constraint': b_constraint,
'strides': 1
}
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
layer = keras.layers.Conv3D(**kwargs)
layer.build((None, 5, 5, 5, 2))
self.assertEqual(layer.kernel.constraint, k_constraint)
@@ -415,7 +415,7 @@
def test_conv3d_dynamic_shape(self):
input_data = np.random.random((1, 3, 3, 3, 3)).astype(np.float32)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
# Won't raise error here.
testing_utils.layer_test(
keras.layers.Conv3D,
@@ -564,7 +564,7 @@
kwargs['filters'] = 1
kwargs['kernel_size'] = 3
kwargs['dilation_rate'] = 2
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
layer = conv_layer_cls(**kwargs)
output1 = layer(np.zeros(input_shape1))
self.assertEqual(output1.shape, expected_output_shape1)
@@ -607,7 +607,7 @@
expected_output_shape1, expected_output_shape2)
def test_dynamic_shape(self):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
layer = keras.layers.Conv3D(2, 3)
input_shape = (5, None, None, 2)
inputs = keras.Input(shape=input_shape)
@@ -626,7 +626,7 @@
shape = (num_samples, num_steps, input_dim)
inputs = np.ones(shape)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
# basic test
testing_utils.layer_test(
keras.layers.ZeroPadding1D,
@@ -682,7 +682,7 @@
inputs = np.ones((num_samples, input_num_row, input_num_col, stack_size))
# basic test
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
testing_utils.layer_test(
keras.layers.ZeroPadding2D,
kwargs={
@@ -699,7 +699,7 @@
input_shape=inputs.shape)
# correctness test
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
layer = keras.layers.ZeroPadding2D(
padding=(2, 2), data_format=data_format)
layer.build(inputs.shape)
@@ -770,7 +770,7 @@
inputs = np.ones((num_samples, input_len_dim1, input_len_dim2,
input_len_dim3, stack_size))
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
# basic test
testing_utils.layer_test(
keras.layers.ZeroPadding3D,
@@ -787,7 +787,7 @@
},
input_shape=inputs.shape)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
# correctness test
layer = keras.layers.ZeroPadding3D(
padding=(2, 2, 2), data_format=data_format)
@@ -856,7 +856,7 @@
class UpSamplingTest(keras_parameterized.TestCase):
def test_upsampling_1d(self):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
testing_utils.layer_test(
keras.layers.UpSampling1D, kwargs={'size': 2}, input_shape=(3, 5, 4))
@@ -875,7 +875,7 @@
stack_size)
# basic test
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
testing_utils.layer_test(
keras.layers.UpSampling2D,
kwargs={'size': (2, 2),
@@ -960,7 +960,7 @@
input_len_dim3, stack_size)
# basic test
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
testing_utils.layer_test(
keras.layers.UpSampling3D,
kwargs={'size': (2, 2, 2),
@@ -1010,7 +1010,7 @@
input_len_dim1 = 2
inputs = np.random.rand(num_samples, time_length, input_len_dim1)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
testing_utils.layer_test(
keras.layers.Cropping1D,
kwargs={'cropping': (2, 2)},
@@ -1036,7 +1036,7 @@
else:
inputs = np.random.rand(num_samples, input_len_dim1, input_len_dim2,
stack_size)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
# basic test
testing_utils.layer_test(
keras.layers.Cropping2D,
@@ -1069,7 +1069,7 @@
inputs = np.random.rand(num_samples, input_len_dim1, input_len_dim2,
stack_size)
# another correctness test (no cropping)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
cropping = ((0, 0), (0, 0))
layer = keras.layers.Cropping2D(
cropping=cropping, data_format=data_format)
@@ -1105,7 +1105,7 @@
inputs = np.random.rand(num_samples, input_len_dim1, input_len_dim2,
input_len_dim3, stack_size)
# basic test
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
testing_utils.layer_test(
keras.layers.Cropping3D,
kwargs={'cropping': cropping,
@@ -1114,7 +1114,7 @@
if len(croppings) == 3 and len(croppings[0]) == 2:
# correctness test
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
layer = keras.layers.Cropping3D(
cropping=cropping, data_format=data_format)
layer.build(inputs.shape)
@@ -1152,7 +1152,7 @@
num_row = 7
num_col = 6
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
testing_utils.layer_test(
keras.layers.DepthwiseConv2D,
kwargs=kwargs,
diff --git a/tensorflow/python/keras/layers/convolutional_transpose_test.py b/tensorflow/python/keras/layers/convolutional_transpose_test.py
index 4326044..e9adef5 100644
--- a/tensorflow/python/keras/layers/convolutional_transpose_test.py
+++ b/tensorflow/python/keras/layers/convolutional_transpose_test.py
@@ -36,7 +36,7 @@
num_row = 7
num_col = 6
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
testing_utils.layer_test(
keras.layers.Conv2DTranspose,
kwargs=kwargs,
@@ -67,7 +67,7 @@
'activity_regularizer': 'l2',
'strides': 1
}
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
layer = keras.layers.Conv2DTranspose(**kwargs)
layer.build((None, 5, 5, 2))
self.assertEqual(len(layer.losses), 2)
@@ -86,7 +86,7 @@
'bias_constraint': b_constraint,
'strides': 1
}
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
layer = keras.layers.Conv2DTranspose(**kwargs)
layer.build((None, 5, 5, 2))
self.assertEqual(layer.kernel.constraint, k_constraint)
@@ -127,7 +127,7 @@
num_col = 6
depth = 5
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
testing_utils.layer_test(
keras.layers.Conv3DTranspose,
kwargs=kwargs,
@@ -159,7 +159,7 @@
'activity_regularizer': 'l2',
'strides': 1
}
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
layer = keras.layers.Conv3DTranspose(**kwargs)
layer.build((None, 5, 5, 5, 2))
self.assertEqual(len(layer.losses), 2)
@@ -178,7 +178,7 @@
'bias_constraint': b_constraint,
'strides': 1
}
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
layer = keras.layers.Conv3DTranspose(**kwargs)
layer.build((None, 5, 5, 5, 2))
self.assertEqual(layer.kernel.constraint, k_constraint)
@@ -186,7 +186,7 @@
def test_conv3d_transpose_dynamic_shape(self):
input_data = np.random.random((1, 3, 3, 3, 3)).astype(np.float32)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
# Won't raise error here.
testing_utils.layer_test(
keras.layers.Conv3DTranspose,
diff --git a/tensorflow/python/keras/layers/cudnn_recurrent_test.py b/tensorflow/python/keras/layers/cudnn_recurrent_test.py
index 3bb392c..fcc9dd1 100644
--- a/tensorflow/python/keras/layers/cudnn_recurrent_test.py
+++ b/tensorflow/python/keras/layers/cudnn_recurrent_test.py
@@ -205,7 +205,7 @@
units = 2
num_samples = 32
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
model = keras.models.Sequential()
model.add(
keras.layers.Embedding(
diff --git a/tensorflow/python/keras/layers/normalization_test.py b/tensorflow/python/keras/layers/normalization_test.py
index d468e5d..a9f6856 100644
--- a/tensorflow/python/keras/layers/normalization_test.py
+++ b/tensorflow/python/keras/layers/normalization_test.py
@@ -104,7 +104,7 @@
@keras_parameterized.run_all_keras_modes
def test_batchnorm_convnet(self):
if test.is_gpu_available(cuda_only=True):
- with self.session(use_gpu=True):
+ with self.session():
model = keras.models.Sequential()
norm = keras.layers.BatchNormalization(
axis=1, input_shape=(3, 4, 4), momentum=0.8)
diff --git a/tensorflow/python/keras/layers/pooling.py b/tensorflow/python/keras/layers/pooling.py
index dcf6dd8..0a77a29 100644
--- a/tensorflow/python/keras/layers/pooling.py
+++ b/tensorflow/python/keras/layers/pooling.py
@@ -366,26 +366,22 @@
... [9., 10., 11., 12.]])
>>> x = tf.reshape(x, [1, 3, 4, 1])
>>> max_pool_2d = tf.keras.layers.MaxPooling2D(pool_size=(2, 2),
- ... strides=(1, 1), padding='valid')
+ ... strides=(2, 2), padding='valid')
>>> max_pool_2d(x)
- <tf.Tensor: shape=(1, 2, 3, 1), dtype=float32, numpy=
- array([[[[ 6.],
- [ 7.],
- [ 8.]],
- [[10.],
- [11.],
- [12.]]]], dtype=float32)>
-
+ <tf.Tensor: shape=(1, 1, 2, 1), dtype=float32, numpy=
+ array([[[[6.],
+ [8.]]]], dtype=float32)>
+
Usage Example:
-
+
>>> input_image = tf.constant([[[[1.], [1.], [2.], [4.]],
... [[2.], [2.], [3.], [2.]],
... [[4.], [1.], [1.], [1.]],
- ... [[2.], [2.], [1.], [4.]]]])
+ ... [[2.], [2.], [1.], [4.]]]])
>>> output = tf.constant([[[[1], [0]],
- ... [[0], [1]]]])
+ ... [[0], [1]]]])
>>> model = tf.keras.models.Sequential()
- >>> model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2),
+ >>> model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2),
... input_shape=(4,4,1)))
>>> model.compile('adam', 'mean_squared_error')
>>> model.predict(input_image, steps=1)
diff --git a/tensorflow/python/keras/layers/preprocessing/BUILD b/tensorflow/python/keras/layers/preprocessing/BUILD
index 4abc6a7..4e5bb1b 100644
--- a/tensorflow/python/keras/layers/preprocessing/BUILD
+++ b/tensorflow/python/keras/layers/preprocessing/BUILD
@@ -259,7 +259,6 @@
name = "category_encoding",
srcs = [
"category_encoding.py",
- "category_encoding_v1.py",
],
srcs_version = "PY3",
deps = [
diff --git a/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD b/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD
index f46c06d..1177660 100644
--- a/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD
+++ b/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD
@@ -81,6 +81,21 @@
)
tf_py_test(
+ name = "index_lookup_forward_benchmark",
+ srcs = ["index_lookup_forward_benchmark.py"],
+ python_version = "PY3",
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:extra_py_tests_deps",
+ "//tensorflow/python:platform_benchmark",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/compat:v2_compat",
+ "//tensorflow/python/keras/layers/preprocessing:index_lookup",
+ ],
+)
+
+tf_py_test(
name = "normalization_adapt_benchmark",
srcs = ["normalization_adapt_benchmark.py"],
python_version = "PY3",
diff --git a/tensorflow/python/keras/layers/preprocessing/benchmarks/index_lookup_forward_benchmark.py b/tensorflow/python/keras/layers/preprocessing/benchmarks/index_lookup_forward_benchmark.py
new file mode 100644
index 0000000..0e264fb
--- /dev/null
+++ b/tensorflow/python/keras/layers/preprocessing/benchmarks/index_lookup_forward_benchmark.py
@@ -0,0 +1,146 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Benchmark for Keras text vectorization preprocessing layer's adapt method."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import random
+import string
+import time
+
+import numpy as np
+
+from tensorflow.python import keras
+from tensorflow.python.compat import v2_compat
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.keras.layers.preprocessing import index_lookup
+from tensorflow.python.ops import lookup_ops
+from tensorflow.python.platform import benchmark
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import test
+
+v2_compat.enable_v2_behavior()
+
+
+# word_gen creates random sequences of ASCII letters (both lowercase and upper).
+# The number of unique strings is ~2,700.
+def tensor_gen(batch, num_elements):
+ data = []
+ for _ in range(batch):
+ batch_element = []
+ for _ in range(num_elements - 1):
+ tok = "".join(random.choice(string.ascii_letters) for i in range(2))
+ batch_element.append(tok)
+ batch_element.append("") # Explicitly test the empty string.
+ data.append(batch_element)
+ return constant_op.constant(data)
+
+
+def get_vocab():
+ vocab = list(
+ set([a + b for a in string.ascii_letters for b in string.ascii_letters])) # pylint:disable=g-complex-comprehension
+ vocab.sort()
+ return vocab
+
+
+# This class uses TestCase for get_temp_dir().
+class BenchmarkLookup(benchmark.TensorFlowBenchmark):
+ """Benchmark the index lookup layer's forward pass."""
+
+ def _write_to_temp_file(self, file_name, vocab_list):
+ vocab_path = os.path.join(self.get_temp_dir(), file_name + ".txt")
+ with gfile.GFile(vocab_path, "w") as writer:
+ for vocab in vocab_list:
+ writer.write(vocab + "\n")
+ writer.flush()
+ writer.close()
+ return vocab_path
+
+ def run_numpy_implementation(self, data, vocab):
+ """Test the python implementation."""
+ input_t = keras.Input(shape=(), dtype=dtypes.string)
+ layer = index_lookup.IndexLookup(
+ vocabulary=vocab,
+ max_tokens=None,
+ num_oov_indices=1,
+ mask_token="",
+ oov_token="OOV",
+ dtype=dtypes.string)
+ out_t = layer(input_t)
+ model = keras.Model(input_t, out_t)
+ num_repeats = 5
+ starts = []
+ ends = []
+ _ = model(data)
+ for _ in range(num_repeats):
+ starts.append(time.time())
+ out = model(data)
+ ends.append(time.time())
+ avg_time = np.mean(np.array(ends) - np.array(starts))
+ return avg_time, out
+
+ def bm_adapt_implementation(self, num_elements, batch_size):
+ """Test the KPL adapt implementation."""
+ vocab = get_vocab()
+ vocab_file = self._write_to_temp_file("vocab", vocab)
+ vocabulary_initializer = lookup_ops.TextFileInitializer(
+ filename=vocab_file,
+ key_dtype=dtypes.string,
+ key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
+ value_dtype=dtypes.int64,
+ value_index=lookup_ops.TextFileIndex.LINE_NUMBER,
+ value_index_offset=2)
+ input_t = keras.Input(shape=(), dtype=dtypes.string)
+ layer = index_lookup.IndexLookup(
+ vocabulary=vocabulary_initializer,
+ max_tokens=None,
+ num_oov_indices=1,
+ mask_token="",
+ oov_token="OOV",
+ dtype=dtypes.string)
+ out_t = layer(input_t)
+ model = keras.Model(input_t, out_t)
+ num_repeats = 5
+ starts = []
+ ends = []
+ data = tensor_gen(batch_size, num_elements)
+ _ = model(data)
+ for _ in range(num_repeats):
+ starts.append(time.time())
+ _ = model(data)
+ ends.append(time.time())
+ avg_time = np.mean(np.array(ends) - np.array(starts))
+ baseline, _ = self.run_numpy_implementation(data, vocab)
+ extras = {
+ "numpy implementation baseline": baseline,
+ "delta seconds": (baseline - avg_time),
+ "delta percent": ((baseline - avg_time) / baseline) * 100
+ }
+ name = "index_lookup_forward|%s_elements|batch_%s" % (num_elements,
+ batch_size)
+ self.report_benchmark(
+ iters=num_repeats, wall_time=avg_time, extras=extras, name=name)
+
+ def benchmark_vocab_size_by_batch(self):
+ for tensor_size in [100, 1000, 10000]:
+ for batch in [1, 16, 2048]:
+ self.bm_adapt_implementation(tensor_size, batch)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/keras/layers/preprocessing/category_encoding.py b/tensorflow/python/keras/layers/preprocessing/category_encoding.py
index 4d7573d..cba28c5 100644
--- a/tensorflow/python/keras/layers/preprocessing/category_encoding.py
+++ b/tensorflow/python/keras/layers/preprocessing/category_encoding.py
@@ -18,10 +18,6 @@
from __future__ import division
from __future__ import print_function
-import collections
-import json
-import numbers
-
import numpy as np
from tensorflow.python.framework import dtypes
@@ -31,16 +27,12 @@
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine import base_preprocessing_layer
-from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bincount_ops
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import sparse_ops
-from tensorflow.python.ops.ragged import ragged_tensor
-from tensorflow.python.util import compat
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export
TFIDF = "tf-idf"
@@ -48,29 +40,27 @@
BINARY = "binary"
COUNT = "count"
-# The string tokens in the extracted vocabulary
-_NUM_ELEMENTS_NAME = "num_elements"
# The inverse-document-frequency weights
_IDF_NAME = "idf"
-@keras_export("keras.layers.experimental.preprocessing.CategoryEncoding", v1=[])
-class CategoryEncoding(base_preprocessing_layer.CombinerPreprocessingLayer):
+@keras_export("keras.layers.experimental.preprocessing.CategoryEncoding")
+class CategoryEncoding(base_preprocessing_layer.PreprocessingLayer):
"""Category encoding layer.
- This layer provides options for condensing data into a categorical encoding.
- It accepts integer values as inputs and outputs a dense representation
- (one sample = 1-index tensor of float values representing data about the
- sample's tokens) of those inputs.
+ This layer provides options for condensing data into a categorical encoding
+ when the total number of tokens are known in advance. It accepts integer
+ values as inputs and outputs a dense representation (one sample = 1-index
+ tensor of float values representing data about the sample's tokens) of those
+ inputs. For integer inputs where the total number of tokens is not known, see
+ `tf.keras.layers.experimental.preprocessing.IntegerLookup`.
Examples:
- **Multi-hot encoding data if you know in advance the number of tokens**
-
- In this case, you can pass the `max_tokens` argument to the constructor.
+ **Multi-hot encoding data**
>>> layer = tf.keras.layers.experimental.preprocessing.CategoryEncoding(
- ... max_tokens=4, output_mode="binary")
+ ... num_tokens=4, output_mode="binary")
>>> layer([[0, 1], [0, 0], [1, 2], [3, 1]])
<tf.Tensor: shape=(4, 4), dtype=float32, numpy=
array([[1., 1., 0., 0.],
@@ -78,20 +68,10 @@
[0., 1., 1., 0.],
[0., 1., 0., 1.]], dtype=float32)>
- **Multi-hot encoding data where the number of tokens is unknown**
-
- In this case, you should `adapt()` the layer on a sample dataset.
-
- ```python
- layer = CategoryEncoding(output_mode="binary")
- layer.adapt(sample_dataset) # Indexes the vocabulary of the data
- outputs = layer(inputs)
- ```
-
**Using weighted inputs in `count` mode**
>>> layer = tf.keras.layers.experimental.preprocessing.CategoryEncoding(
- ... max_tokens=4, output_mode="count")
+ ... num_tokens=4, output_mode="count")
>>> count_weights = np.array([[.1, .2], [.1, .1], [.2, .3], [.4, .2]])
>>> layer([[0, 1], [0, 0], [1, 2], [3, 1]], count_weights=count_weights)
<tf.Tensor: shape=(4, 4), dtype=float64, numpy=
@@ -101,18 +81,17 @@
[0. , 0.2, 0. , 0.4]])>
Args:
- max_tokens: The maximum size of the vocabulary for this layer. If None,
- there is no cap on the size of the vocabulary.
+ num_tokens: The total number of tokens the layer should support. All inputs
+ to the layer must integers in the range 0 <= value < num_tokens or an
+ error will be thrown.
output_mode: Specification for the output of the layer.
Defaults to "binary". Values can
be "binary", "count" or "tf-idf", configuring the layer as follows:
"binary": Outputs a single int array per batch, of either vocab_size or
- max_tokens size, containing 1s in all elements where the token mapped
+ num_tokens size, containing 1s in all elements where the token mapped
to that index exists at least once in the batch item.
"count": As "binary", but the int array contains a count of the number
of times the token at that index appeared in the batch item.
- "tf-idf": As "binary", but the TF-IDF algorithm is applied to find the
- value in each token slot.
sparse: Boolean. If true, returns a `SparseTensor` instead of a dense
`Tensor`. Defaults to `False`.
@@ -124,59 +103,40 @@
"""
def __init__(self,
- max_tokens=None,
+ num_tokens=None,
output_mode=BINARY,
sparse=False,
**kwargs):
- # 'output_mode' must be one of (COUNT, BINARY, TFIDF)
+ # max_tokens is an old name for the num_tokens arg we continue to support
+ # because of usage.
+ if "max_tokens" in kwargs:
+ logging.warning(
+ "max_tokens is deprecated, please use num_tokens instead.")
+ num_tokens = kwargs["max_tokens"]
+ del kwargs["max_tokens"]
+
+ super(CategoryEncoding, self).__init__(**kwargs)
+
+ # 'output_mode' must be one of (COUNT, BINARY)
layer_utils.validate_string_arg(
output_mode,
- allowable_strings=(COUNT, BINARY, TFIDF),
+ allowable_strings=(COUNT, BINARY),
layer_name="CategoryEncoding",
arg_name="output_mode")
- # If max_tokens is set, the value must be greater than 1 - otherwise we
- # are creating a 0-element vocab, which doesn't make sense.
- if max_tokens is not None and max_tokens < 1:
- raise ValueError("max_tokens must be > 1.")
+ if num_tokens is None:
+ raise ValueError("num_tokens must be set to use this layer. If the "
+ "number of tokens is not known beforehand, use the "
+ "IntegerLookup layer instead.")
+ if num_tokens < 1:
+ raise ValueError("num_tokens must be >= 1.")
- # We need to call super() before we call _add_state_variable().
- combiner = _CategoryEncodingCombiner(
- max_tokens=max_tokens,
- compute_idf=output_mode == TFIDF)
- super(CategoryEncoding, self).__init__(combiner=combiner, **kwargs)
- base_preprocessing_layer.keras_kpl_gauge.get_cell(
- "CategoryEncoding").set(True)
-
- self.max_tokens = max_tokens
+ self.num_tokens = num_tokens
self.output_mode = output_mode
self.sparse = sparse
- self._called = False
-
- if self.output_mode == TFIDF:
- # The TF-IDF weight may have a (None,) tensorshape. This creates
- # a 1D variable with arbitrary shape, which we can assign any weight to
- # so long as it has 1 dimension. In order to properly initialize this
- # weight in Keras, we need to provide a custom callable initializer which
- # does not depend on the shape of the weight (as all other initializers
- # do) since the weight is not known. Hence the lambda shape, dtype: [0].
- if max_tokens is None:
- initializer = lambda shape, dtype: [0]
- else:
- initializer = init_ops.zeros_initializer
-
- # We are adding these here instead of in build() since they do not depend
- # on the input shape at all.
- self.tf_idf_weights = self._add_state_variable(
- name=_IDF_NAME,
- shape=tensor_shape.TensorShape((max_tokens,)),
- dtype=K.floatx(),
- initializer=initializer)
-
- self.input_spec = InputSpec(ndim=2)
def compute_output_shape(self, input_shape):
- return tensor_shape.TensorShape([input_shape[0], self.max_tokens])
+ return tensor_shape.TensorShape([input_shape[0], self.num_tokens])
def compute_output_signature(self, input_spec):
output_shape = self.compute_output_shape(input_spec.shape.as_list())
@@ -187,95 +147,15 @@
else:
return tensor_spec.TensorSpec(shape=output_shape, dtype=output_dtype)
- def adapt(self, data, reset_state=True):
- """Fits the state of the preprocessing layer to the dataset.
-
- Overrides the default adapt method to apply relevant preprocessing to the
- inputs before passing to the combiner.
-
- Args:
- data: The data to train on. It can be passed either as a tf.data Dataset,
- or as a numpy array.
- reset_state: Optional argument specifying whether to clear the state of
- the layer at the start of the call to `adapt`. This must be True for
- this layer, which does not support repeated calls to `adapt`.
-
- Raises:
- RuntimeError: if the layer cannot be adapted at this time.
- """
- if not reset_state:
- raise ValueError("CategoryEncoding does not support streaming adapts.")
-
- super(CategoryEncoding, self).adapt(data, reset_state)
-
- def _set_state_variables(self, updates):
- if not self.built:
- raise RuntimeError("_set_state_variables() must be called after build().")
- if _NUM_ELEMENTS_NAME in updates:
- if self.max_tokens is None:
- self.set_num_elements(updates[_NUM_ELEMENTS_NAME])
- elif self.max_tokens != updates[_NUM_ELEMENTS_NAME]:
- raise RuntimeError("Cannot update states if you construct the layer "
- "with `max_tokens`={}".format(self.max_tokens))
- if self.output_mode == TFIDF:
- self.set_tfidf_data(updates[_IDF_NAME])
-
def get_config(self):
config = {
- "max_tokens": self.max_tokens,
+ "num_tokens": self.num_tokens,
"output_mode": self.output_mode,
"sparse": self.sparse,
}
base_config = super(CategoryEncoding, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
- def _convert_to_ndarray(self, x):
- if isinstance(x, ops.Tensor):
- return x
- else:
- return np.array(x)
-
- def _convert_to_sparse_inputs(self, inputs):
- if isinstance(inputs, sparse_tensor.SparseTensor):
- return inputs
- elif isinstance(inputs, ragged_tensor.RaggedTensor):
- return inputs.to_sparse()
- else:
- indices = array_ops.where_v2(
- math_ops.greater_equal(inputs, array_ops.constant(0, inputs.dtype)))
- values = array_ops.gather_nd(inputs, indices)
- shape = array_ops.shape(inputs, out_type=dtypes.int64)
- return sparse_tensor.SparseTensor(indices, values, shape)
-
- def set_num_elements(self, num_elements):
- if self.max_tokens is not None:
- raise RuntimeError(
- "In order to dynamically set the number of elements, the "
- "layer's 'max_tokens' arg must be set to None.")
- if not isinstance(num_elements, numbers.Integral):
- raise ValueError("num_elements must be a scalar integer.")
- if self._called:
- raise RuntimeError("num_elements cannot be changed after the layer is "
- "called.")
- self.max_tokens = num_elements
-
- def set_tfidf_data(self, tfidf_data):
- tfidf_data = self._convert_to_ndarray(tfidf_data)
- if self.output_mode != TFIDF:
- raise RuntimeError(
- "In order to set TF-IDF data, the output mode must be 'tf-idf'.")
- if tfidf_data.ndim != 1:
- raise ValueError("TF-IDF data must be a 1-index array.")
- if self.max_tokens is not None:
- input_data_length = tfidf_data.shape[0]
- if input_data_length > self.max_tokens:
- raise ValueError("The array provided has %d elements. This layer is "
- "configured to only allow %d elements." %
- (input_data_length, self.max_tokens))
- if input_data_length < self.max_tokens:
- tfidf_data = np.resize(tfidf_data, (self.max_tokens,))
- K.set_value(self.tf_idf_weights, tfidf_data)
-
def call(self, inputs, count_weights=None):
if isinstance(inputs, (list, np.ndarray)):
inputs = ops.convert_to_tensor_v2_with_dispatch(inputs)
@@ -285,30 +165,8 @@
if count_weights is not None and self.output_mode != COUNT:
raise ValueError("count_weights is not used in `output_mode='tf-idf'`, "
"or `output_mode='binary'`. Please pass a single input.")
- self._called = True
- if self.max_tokens is None:
- raise RuntimeError(
- "If you construct a `CategoryEncoding` layer with "
- "`max_tokens=None`, you need to call `adapt()` "
- "on it before using it")
- else:
- out_depth = self.max_tokens
- if self.output_mode == TFIDF:
- # If the input is a sparse tensor, we densify it with the default value of
- # -1. Because -1 is ignored by one_hot, this effectively drops the non-set
- # positions from the output encoding.
- if self.sparse:
- raise ValueError("`sparse=True` with `output_mode=tfidf` "
- "is not supported.")
- if isinstance(inputs, sparse_tensor.SparseTensor):
- inputs = sparse_ops.sparse_tensor_to_dense(inputs, default_value=-1)
- one_hot_data = array_ops.one_hot(inputs, depth=out_depth)
- counts = math_ops.reduce_sum(one_hot_data, axis=1)
- tf_idf_data = math_ops.multiply(counts, self.tf_idf_weights)
- tf_idf_data.set_shape(tensor_shape.TensorShape((None, out_depth)))
- return tf_idf_data
-
+ out_depth = self.num_tokens
binary_output = (self.output_mode == BINARY)
if isinstance(inputs, sparse_tensor.SparseTensor):
max_value = math_ops.reduce_max(inputs.values)
@@ -321,191 +179,16 @@
math_ops.cast(out_depth, max_value.dtype), max_value),
math_ops.greater_equal(
min_value, math_ops.cast(0, min_value.dtype)))
- control_flow_ops.Assert(
- condition, ["Input values must be in the range 0 <= values < max_tokens"
- " with max_tokens={}".format(out_depth)])
+ control_flow_ops.Assert(condition, [
+ "Input values must be in the range 0 <= values < num_tokens"
+ " with num_tokens={}".format(out_depth)
+ ])
if self.sparse:
return sparse_bincount(inputs, out_depth, binary_output, count_weights)
else:
return dense_bincount(inputs, out_depth, binary_output, count_weights)
-class _CategoryEncodingAccumulator(
- collections.namedtuple("Accumulator", ["data", "per_doc_count_dict"])):
- pass
-
-
-class _CategoryEncodingCombiner(base_preprocessing_layer.Combiner):
- """Combiner for the CategoryEncoding preprocessing layer.
-
- This class encapsulates the logic for computing the number of elements in the
- input dataset and the document frequency for each element.
-
- Attributes:
- compute_max_element: (Optional) If set, this combiner will return the
- maximum element in this set as part of its `extract()` call.
- compute_idf: (Optional) If set, the inverse document frequency will be
- computed for each value.
- """
- # These are indices into the accumulator's `data` array.
- MAX_VALUE_IDX = 0
- DOC_ID_IDX = 1
-
- def __init__(self, max_tokens=None, compute_idf=False):
- self.max_tokens = max_tokens
- self._compute_idf = compute_idf
-
- def compute(self, values, accumulator=None):
- """Computes a step in this computation, returning a new accumulator."""
- values = base_preprocessing_layer.convert_to_list(values)
-
- if accumulator is None:
- accumulator = self._create_accumulator()
-
- # TODO(momernick): Benchmark improvements to this algorithm.
- for element in values:
- if not isinstance(element, list):
- element = [element]
- current_doc_id = accumulator.data[self.DOC_ID_IDX]
- for value in element:
- if self.max_tokens is None:
- current_max_value = accumulator.data[self.MAX_VALUE_IDX]
- if value > current_max_value:
- accumulator.data[self.MAX_VALUE_IDX] = value
- if self._compute_idf:
- doc_count = accumulator.per_doc_count_dict[value]
- if doc_count["last_doc_id"] != current_doc_id:
- doc_count["count"] += 1
- doc_count["last_doc_id"] = current_doc_id
- accumulator.data[self.DOC_ID_IDX] += 1
-
- return accumulator
-
- def merge(self, accumulators):
- """Merges several accumulators to a single accumulator."""
- if not accumulators:
- return accumulators
-
- base_accumulator = accumulators[0]
-
- for accumulator in accumulators[1:]:
- base_accumulator.data[self.DOC_ID_IDX] += accumulator.data[
- self.DOC_ID_IDX]
- if self.max_tokens is None:
- base_accumulator.data[self.MAX_VALUE_IDX] = max(
- base_accumulator.data[self.MAX_VALUE_IDX],
- accumulator.data[self.MAX_VALUE_IDX])
- if self._compute_idf:
- for token, value in accumulator.per_doc_count_dict.items():
- # Any newly created token counts in 'base_accumulator''s
- # per_doc_count_dict will have a last_doc_id of -1. This is always
- # less than the next doc id (which are strictly positive), so any
- # future occurrences are guaranteed to be counted.
- base_accumulator.per_doc_count_dict[token]["count"] += value["count"]
-
- return base_accumulator
-
- def _inverse_document_frequency(self, document_counts, num_documents):
- """Computes the inverse-document-frequency (IDF) component of TFIDF.
-
- Uses the default weighting scheme described in
- https://en.wikipedia.org/wiki/Tf%E2%80%93idf.
-
- Args:
- document_counts: An array of the # of documents each token appears in.
- num_documents: An int representing the total number of documents
-
- Returns:
- An array of "inverse document frequency" weights.
- """
- return np.log(1 + num_documents / (1 + np.array(document_counts)))
-
- def extract(self, accumulator):
- """Converts an accumulator into a dict of output values.
-
- Args:
- accumulator: An accumulator aggregating over the full dataset.
-
- Returns:
- A dict of:
- "num_elements": The number of unique elements in the data set. Only
- returned if `compute_max_element` is True.
- "idf": The inverse-document-frequency for each index, where idf[i] is
- the IDF value for index i. Only returned if `compute_idf` is True.
- """
- data, document_counts = accumulator
- if data[self.MAX_VALUE_IDX] is not None:
- max_element = data[self.MAX_VALUE_IDX] + 1
- else:
- max_element = self.max_tokens
- output_dict = {}
- if self.max_tokens is None:
- output_dict[_NUM_ELEMENTS_NAME] = max_element
-
- if self._compute_idf:
- num_documents = data[self.DOC_ID_IDX]
- # Here, we need to get the doc_counts for every token value, including
- # values we have not yet seen (and are not in the document_counts dict).
- # However, because document_counts is a defaultdict (see below), querying
- # the dict directly for those values gives us meaningful counts (of 0).
- # However, this also means we can't just extract the values in
- # document_counts - we need to do a deliberate indexing using range().
- doc_counts = [document_counts[i]["count"] for i in range(max_element)]
- idf = self._inverse_document_frequency(doc_counts, num_documents)
- output_dict[_IDF_NAME] = idf
-
- return output_dict
-
- def restore(self, output):
- """Creates an accumulator based on 'output'."""
- raise NotImplementedError(
- "CategoryEncoding does not restore or support streaming updates.")
-
- def serialize(self, accumulator):
- """Serializes an accumulator for a remote call."""
- output_dict = {}
- output_dict["data"] = accumulator.data
- if self._compute_idf:
- output_dict["idf_vocab"] = list(accumulator.per_doc_count_dict.keys())
- output_dict["idf_counts"] = [
- counter["count"]
- for counter in accumulator.per_doc_count_dict.values()
- ]
- return compat.as_bytes(json.dumps(output_dict))
-
- def deserialize(self, encoded_accumulator):
- """Deserializes an accumulator received from 'serialize()'."""
- accumulator_dict = json.loads(compat.as_text(encoded_accumulator))
-
- accumulator = self._create_accumulator()
- for i, value in enumerate(accumulator_dict["data"]):
- accumulator.data[i] = value
-
- if self._compute_idf:
- create_dict = lambda x: {"count": x, "last_doc_id": -1}
- idf_count_dicts = [
- create_dict(count) for count in accumulator_dict["idf_counts"]
- ]
- idf_dict = dict(zip(accumulator_dict["idf_vocab"], idf_count_dicts))
- accumulator.per_doc_count_dict.update(idf_dict)
-
- return accumulator
-
- def _create_accumulator(self):
- """Accumulates a sorted array of vocab tokens and corresponding counts."""
-
- if self._compute_idf:
- create_default_dict = lambda: {"count": 0, "last_doc_id": -1}
- per_doc_count_dict = collections.defaultdict(create_default_dict)
- else:
- per_doc_count_dict = None
- if self.max_tokens is None:
- data = [0, 0]
- else:
- data = [None, 0]
- return _CategoryEncodingAccumulator(data, per_doc_count_dict)
-
-
def sparse_bincount(inputs, out_depth, binary_output, count_weights=None):
"""Apply binary or count encoding to an input and return a sparse tensor."""
result = bincount_ops.sparse_bincount(
diff --git a/tensorflow/python/keras/layers/preprocessing/category_encoding_distribution_test.py b/tensorflow/python/keras/layers/preprocessing/category_encoding_distribution_test.py
index c95be0a..603bb0e 100644
--- a/tensorflow/python/keras/layers/preprocessing/category_encoding_distribution_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/category_encoding_distribution_test.py
@@ -64,13 +64,13 @@
expected_output = [[0, 1, 1, 1, 0, 0],
[1, 1, 0, 1, 0, 0]]
# pyformat: enable
- max_tokens = 6
+ num_tokens = 6
config.set_soft_device_placement(True)
with distribution.scope():
input_data = keras.Input(shape=(4,), dtype=dtypes.int32)
layer = category_encoding.CategoryEncoding(
- max_tokens=max_tokens, output_mode=category_encoding.BINARY)
+ num_tokens=num_tokens, output_mode=category_encoding.BINARY)
int_data = layer(input_data)
model = keras.Model(inputs=input_data, outputs=int_data)
output_dataset = model.predict(inp_dataset)
diff --git a/tensorflow/python/keras/layers/preprocessing/category_encoding_test.py b/tensorflow/python/keras/layers/preprocessing/category_encoding_test.py
index 8c7516b..2e704d9 100644
--- a/tensorflow/python/keras/layers/preprocessing/category_encoding_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/category_encoding_test.py
@@ -24,8 +24,6 @@
from tensorflow.python import keras
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -34,7 +32,6 @@
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.layers import core
from tensorflow.python.keras.layers.preprocessing import category_encoding
-from tensorflow.python.keras.layers.preprocessing import category_encoding_v1
from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
@@ -42,10 +39,7 @@
def get_layer_class():
- if context.executing_eagerly():
- return category_encoding.CategoryEncoding
- else:
- return category_encoding_v1.CategoryEncoding
+ return category_encoding.CategoryEncoding
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
@@ -62,11 +56,11 @@
# [X, X, X, 2]]
expected_indices = [[0, 1], [0, 2], [0, 3], [1, 0], [1, 3]]
expected_values = [1, 1, 1, 1, 2]
- max_tokens = 6
+ num_tokens = 6
input_data = keras.Input(shape=(None,), dtype=dtypes.int32)
layer = get_layer_class()(
- max_tokens=max_tokens, output_mode=category_encoding.COUNT, sparse=True)
+ num_tokens=num_tokens, output_mode=category_encoding.COUNT, sparse=True)
int_data = layer(input_data)
model = keras.Model(inputs=input_data, outputs=int_data)
@@ -76,7 +70,7 @@
# Assert sparse output is same as dense output.
layer = get_layer_class()(
- max_tokens=max_tokens,
+ num_tokens=num_tokens,
output_mode=category_encoding.COUNT,
sparse=False)
int_data = layer(input_data)
@@ -94,13 +88,13 @@
expected_output = [[0, 1, 1, 1, 0, 0],
[0, 1, 0, 1, 0, 0]]
# pyformat: enable
- max_tokens = 6
- expected_output_shape = [None, max_tokens]
+ num_tokens = 6
+ expected_output_shape = [None, num_tokens]
input_data = keras.Input(shape=(None,), dtype=dtypes.int64, sparse=True)
layer = get_layer_class()(
- max_tokens=max_tokens, output_mode=category_encoding.BINARY)
+ num_tokens=num_tokens, output_mode=category_encoding.BINARY)
int_data = layer(input_data)
self.assertAllEqual(expected_output_shape, int_data.shape.as_list())
@@ -118,14 +112,14 @@
expected_output = [[0, .1, .2, .3, .4, 0],
[0, .4, 0, .1, .5, 0]]
# pyformat: enable
- max_tokens = 6
- expected_output_shape = [None, max_tokens]
+ num_tokens = 6
+ expected_output_shape = [None, num_tokens]
input_data = keras.Input(shape=(None,), dtype=dtypes.int64, sparse=True)
weight_data = keras.Input(shape=(None,), dtype=dtypes.float32, sparse=True)
layer = get_layer_class()(
- max_tokens=max_tokens, output_mode=category_encoding.COUNT)
+ num_tokens=num_tokens, output_mode=category_encoding.COUNT)
int_data = layer(input_data, count_weights=weight_data)
self.assertAllEqual(expected_output_shape, int_data.shape.as_list())
@@ -148,10 +142,10 @@
# [1, X, X, X]]
expected_indices = [[0, 0], [1, 2], [2, 1], [3, 0]]
expected_values = [1, 1, 2, 1]
- max_tokens = 6
+ num_tokens = 6
layer = get_layer_class()(
- max_tokens=max_tokens, output_mode=category_encoding.COUNT, sparse=True)
+ num_tokens=num_tokens, output_mode=category_encoding.COUNT, sparse=True)
int_data = layer(input_data)
model = keras.Model(inputs=input_data, outputs=int_data)
@@ -161,7 +155,7 @@
# Assert sparse output is same as dense output.
layer = get_layer_class()(
- max_tokens=max_tokens,
+ num_tokens=num_tokens,
output_mode=category_encoding.COUNT,
sparse=False)
int_data = layer(input_data)
@@ -187,10 +181,10 @@
# [1, X, X, X]]
expected_indices = [[0, 0], [1, 2], [2, 1], [3, 0]]
expected_values = [.1, .2, .7, .2]
- max_tokens = 6
+ num_tokens = 6
layer = get_layer_class()(
- max_tokens=max_tokens, output_mode=category_encoding.COUNT, sparse=True)
+ num_tokens=num_tokens, output_mode=category_encoding.COUNT, sparse=True)
int_data = layer(input_data, count_weights=weight_data)
model = keras.Model(inputs=[input_data, weight_data], outputs=int_data)
@@ -205,13 +199,13 @@
expected_output = [[0, 1, 1, 1, 0, 0],
[0, 1, 0, 1, 0, 0]]
# pyformat: enable
- max_tokens = 6
- expected_output_shape = [None, max_tokens]
+ num_tokens = 6
+ expected_output_shape = [None, num_tokens]
input_data = keras.Input(shape=(None,), dtype=dtypes.int32, ragged=True)
layer = get_layer_class()(
- max_tokens=max_tokens, output_mode=category_encoding.BINARY)
+ num_tokens=num_tokens, output_mode=category_encoding.BINARY)
int_data = layer(input_data)
self.assertAllEqual(expected_output_shape, int_data.shape.as_list())
@@ -228,11 +222,11 @@
# [X, X, X, 2]]
expected_indices = [[0, 1], [0, 2], [0, 3], [1, 3]]
expected_values = [1, 1, 1, 2]
- max_tokens = 6
+ num_tokens = 6
input_data = keras.Input(shape=(None,), dtype=dtypes.int32, ragged=True)
layer = get_layer_class()(
- max_tokens=max_tokens, output_mode=category_encoding.COUNT, sparse=True)
+ num_tokens=num_tokens, output_mode=category_encoding.COUNT, sparse=True)
int_data = layer(input_data)
model = keras.Model(inputs=input_data, outputs=int_data)
@@ -242,7 +236,7 @@
# Assert sparse output is same as dense output.
layer = get_layer_class()(
- max_tokens=max_tokens,
+ num_tokens=num_tokens,
output_mode=category_encoding.COUNT,
sparse=False)
int_data = layer(input_data)
@@ -255,12 +249,11 @@
def test_sparse_output_and_dense_layer(self):
input_array = constant_op.constant([[1, 2, 3], [3, 3, 0]])
- max_tokens = 4
+ num_tokens = 4
input_data = keras.Input(shape=(None,), dtype=dtypes.int32)
encoding_layer = get_layer_class()(
- max_tokens=max_tokens, output_mode=category_encoding.COUNT,
- sparse=True)
+ num_tokens=num_tokens, output_mode=category_encoding.COUNT, sparse=True)
int_data = encoding_layer(input_data)
dense_layer = keras.layers.Dense(units=1)
output_data = dense_layer(int_data)
@@ -270,126 +263,41 @@
def test_dense_oov_input(self):
input_array = constant_op.constant([[0, 1, 2], [2, 3, 1]])
- max_tokens = 3
- expected_output_shape = [None, max_tokens]
- encoder_layer = get_layer_class()(max_tokens)
+ num_tokens = 3
+ expected_output_shape = [None, num_tokens]
+ encoder_layer = get_layer_class()(num_tokens)
input_data = keras.Input(shape=(3,), dtype=dtypes.int32)
int_data = encoder_layer(input_data)
self.assertAllEqual(expected_output_shape, int_data.shape.as_list())
model = keras.Model(inputs=input_data, outputs=int_data)
with self.assertRaisesRegex(
errors.InvalidArgumentError,
- ".*must be in the range 0 <= values < max_tokens.*"):
+ ".*must be in the range 0 <= values < num_tokens.*"):
_ = model.predict(input_array, steps=1)
def test_dense_negative(self):
input_array = constant_op.constant([[1, 2, 0], [2, 2, -1]])
- max_tokens = 3
- expected_output_shape = [None, max_tokens]
- encoder_layer = get_layer_class()(max_tokens)
+ num_tokens = 3
+ expected_output_shape = [None, num_tokens]
+ encoder_layer = get_layer_class()(num_tokens)
input_data = keras.Input(shape=(3,), dtype=dtypes.int32)
int_data = encoder_layer(input_data)
self.assertAllEqual(expected_output_shape, int_data.shape.as_list())
model = keras.Model(inputs=input_data, outputs=int_data)
with self.assertRaisesRegex(
errors.InvalidArgumentError,
- ".*must be in the range 0 <= values < max_tokens.*"):
+ ".*must be in the range 0 <= values < num_tokens.*"):
_ = model.predict(input_array, steps=1)
-
-@keras_parameterized.run_all_keras_modes
-class CategoryEncodingAdaptTest(keras_parameterized.TestCase,
- preprocessing_test_utils.PreprocessingLayerTest
- ):
-
- def test_sparse_adapt(self):
- vocab_data = sparse_ops.from_dense(
- np.array([[1, 1, 0, 1, 1, 2, 2, 0, 2, 3, 3, 0, 4]], dtype=np.int64))
- vocab_dataset = dataset_ops.Dataset.from_tensors(vocab_data)
- input_array = sparse_ops.from_dense(
- np.array([[1, 2, 3, 0], [0, 3, 1, 0]], dtype=np.int64))
-
- # pyformat: disable
- expected_output = [[0, 1, 1, 1, 0],
- [0, 1, 0, 1, 0]]
- # pyformat: enable
- max_tokens = 5
- expected_output_shape = [None, max_tokens]
-
- input_data = keras.Input(shape=(None,), dtype=dtypes.int64, sparse=True)
- layer = get_layer_class()(
- max_tokens=None, output_mode=category_encoding.BINARY)
- layer.adapt(vocab_dataset)
- int_data = layer(input_data)
- self.assertAllEqual(expected_output_shape, int_data.shape.as_list())
-
- model = keras.Model(inputs=input_data, outputs=int_data)
- output_dataset = model.predict(input_array, steps=1)
- self.assertAllEqual(expected_output, output_dataset)
-
- def test_ragged_adapt(self):
- vocab_data = ragged_factory_ops.constant(
- np.array([[1, 1, 0, 1, 1], [2, 2], [0, 2, 3], [0, 4]]))
- vocab_dataset = dataset_ops.Dataset.from_tensors(vocab_data)
- input_array = ragged_factory_ops.constant([[1, 2, 3], [3, 1]])
-
- # pyformat: disable
- expected_output = [[0, 1, 1, 1, 0],
- [0, 1, 0, 1, 0]]
- # pyformat: enable
- max_tokens = 5
- expected_output_shape = [None, max_tokens]
-
- input_data = keras.Input(shape=(None,), dtype=dtypes.int32, ragged=True)
-
- layer = get_layer_class()(
- max_tokens=None, output_mode=category_encoding.BINARY)
- layer.adapt(vocab_dataset)
- int_data = layer(input_data)
-
- self.assertAllEqual(expected_output_shape, int_data.shape.as_list())
-
- model = keras.Model(inputs=input_data, outputs=int_data)
- output_dataset = model.predict(input_array, steps=1)
- self.assertAllEqual(expected_output, output_dataset)
-
- def test_hard_maximum_set_state_variables_after_build(self):
- state_variables = {category_encoding._NUM_ELEMENTS_NAME: 5}
- input_array = np.array([[1, 2, 3, 1], [0, 3, 1, 0]])
-
- # pyformat: disable
- expected_output = [[0, 1, 1, 1, 0],
- [1, 1, 0, 1, 0]]
- # pyformat: enable
- max_tokens = 5
- expected_output_shape = [None, max_tokens]
+ def test_legacy_max_tokens_arg(self):
+ input_array = np.array([[1, 2, 3, 1]])
+ expected_output = [[0, 1, 1, 1, 0, 0]]
+ num_tokens = 6
+ expected_output_shape = [None, num_tokens]
input_data = keras.Input(shape=(None,), dtype=dtypes.int32)
layer = get_layer_class()(
- max_tokens=max_tokens, output_mode=category_encoding.BINARY)
- int_data = layer(input_data)
- layer._set_state_variables(state_variables)
- self.assertAllEqual(expected_output_shape, int_data.shape.as_list())
-
- model = keras.Model(inputs=input_data, outputs=int_data)
- output_dataset = model.predict(input_array)
- self.assertAllEqual(expected_output, output_dataset)
-
- def test_soft_maximum_set_state_after_build(self):
- input_array = np.array([[1, 2, 3, 1], [0, 3, 1, 0]])
-
- # pyformat: disable
- expected_output = [[0, 1, 1, 1, 0],
- [1, 1, 0, 1, 0]]
- # pyformat: enable
- max_tokens = 5
- expected_output_shape = [None, max_tokens]
-
- input_data = keras.Input(shape=(None,), dtype=dtypes.int32)
- layer = get_layer_class()(
- max_tokens=None, output_mode=category_encoding.BINARY)
- layer.build(input_data.shape)
- layer.set_num_elements(max_tokens)
+ max_tokens=num_tokens, output_mode=category_encoding.BINARY)
int_data = layer(input_data)
self.assertAllEqual(expected_output_shape, int_data.shape.as_list())
@@ -397,34 +305,6 @@
output_dataset = model.predict(input_array)
self.assertAllEqual(expected_output, output_dataset)
- def test_set_weights_fails_on_wrong_size_weights(self):
- tfidf_data = [.05, .5, .25, .2, .125]
- layer = get_layer_class()(max_tokens=6, output_mode=category_encoding.TFIDF)
-
- with self.assertRaisesRegex(ValueError, ".*Layer weight shape.*"):
- layer.set_weights([np.array(tfidf_data)])
-
- def test_set_num_elements_after_call_fails(self):
- input_data = keras.Input(shape=(None,), dtype=dtypes.int32)
- layer = get_layer_class()(
- max_tokens=None, output_mode=category_encoding.BINARY)
- layer.adapt([1, 2])
- _ = layer(input_data)
- with self.assertRaisesRegex(
- RuntimeError, ".*'max_tokens' arg must be set to None."):
- layer.set_num_elements(5)
-
- def test_set_state_variables_after_call_fails(self):
- state_variables = {category_encoding._NUM_ELEMENTS_NAME: 5}
-
- input_data = keras.Input(shape=(None,), dtype=dtypes.int32)
- layer = get_layer_class()(
- max_tokens=None, output_mode=category_encoding.BINARY)
- layer.adapt([1, 2])
- _ = layer(input_data)
- with self.assertRaisesRegex(RuntimeError, "Cannot update states.*"):
- layer._set_state_variables(state_variables)
-
@keras_parameterized.run_all_keras_modes
@keras_parameterized.run_all_keras_modes
@@ -432,19 +312,19 @@
preprocessing_test_utils.PreprocessingLayerTest
):
- def test_binary_output_hard_maximum(self):
+ def test_binary_output(self):
input_array = np.array([[1, 2, 3, 1], [0, 3, 1, 0]])
# pyformat: disable
expected_output = [[0, 1, 1, 1, 0, 0],
[1, 1, 0, 1, 0, 0]]
# pyformat: enable
- max_tokens = 6
- expected_output_shape = [None, max_tokens]
+ num_tokens = 6
+ expected_output_shape = [None, num_tokens]
input_data = keras.Input(shape=(None,), dtype=dtypes.int32)
layer = get_layer_class()(
- max_tokens=max_tokens, output_mode=category_encoding.BINARY)
+ num_tokens=num_tokens, output_mode=category_encoding.BINARY)
int_data = layer(input_data)
self.assertAllEqual(expected_output_shape, int_data.shape.as_list())
@@ -452,39 +332,18 @@
output_dataset = model.predict(input_array)
self.assertAllEqual(expected_output, output_dataset)
- def test_binary_output_soft_maximum(self):
- input_array = np.array([[1, 2, 3, 1], [0, 3, 1, 0]])
-
- # pyformat: disable
- expected_output = [[0, 1, 1, 1, 0],
- [1, 1, 0, 1, 0]]
- # pyformat: enable
- max_tokens = 5
- expected_output_shape = [None, max_tokens]
-
- input_data = keras.Input(shape=(None,), dtype=dtypes.int32)
- layer = get_layer_class()(
- max_tokens=None, output_mode=category_encoding.BINARY)
- layer.set_num_elements(max_tokens)
- int_data = layer(input_data)
- self.assertAllEqual(expected_output_shape, int_data.shape.as_list())
-
- model = keras.Model(inputs=input_data, outputs=int_data)
- output_dataset = model.predict(input_array)
- self.assertAllEqual(expected_output, output_dataset)
-
- def test_count_output_hard_maximum(self):
+ def test_count_output(self):
input_array = np.array([[1, 2, 3, 1], [0, 3, 1, 0]])
# pyformat: disable
expected_output = [[0, 2, 1, 1, 0, 0],
[2, 1, 0, 1, 0, 0]]
# pyformat: enable
- max_tokens = 6
- expected_output_shape = [None, max_tokens]
+ num_tokens = 6
+ expected_output_shape = [None, num_tokens]
input_data = keras.Input(shape=(None,), dtype=dtypes.int32)
- layer = get_layer_class()(max_tokens=6, output_mode=category_encoding.COUNT)
+ layer = get_layer_class()(num_tokens=6, output_mode=category_encoding.COUNT)
int_data = layer(input_data)
self.assertAllEqual(expected_output_shape, int_data.shape.as_list())
@@ -492,75 +351,6 @@
output_dataset = model.predict(input_array)
self.assertAllEqual(expected_output, output_dataset)
- def test_count_output_soft_maximum(self):
- input_array = np.array([[1, 2, 3, 1], [0, 3, 1, 0]])
-
- # pyformat: disable
- expected_output = [[0, 2, 1, 1, 0],
- [2, 1, 0, 1, 0]]
- # pyformat: enable
- max_tokens = 5
- expected_output_shape = [None, max_tokens]
-
- input_data = keras.Input(shape=(None,), dtype=dtypes.int32)
- layer = get_layer_class()(
- max_tokens=None, output_mode=category_encoding.COUNT)
- layer.set_num_elements(max_tokens)
- int_data = layer(input_data)
- self.assertAllEqual(expected_output_shape, int_data.shape.as_list())
-
- model = keras.Model(inputs=input_data, outputs=int_data)
- output_dataset = model.predict(input_array)
- self.assertAllEqual(expected_output, output_dataset)
-
- def test_tfidf_output_hard_maximum(self):
- tfidf_data = [.05, .5, .25, .2, .125]
- input_array = np.array([[1, 2, 3, 1], [0, 4, 1, 0]])
-
- # pyformat: disable
- # pylint: disable=bad-whitespace
- expected_output = [[ 0, 1, .25, .2, 0, 0],
- [.1, .5, 0, 0, .125, 0]]
- # pylint: enable=bad-whitespace
- # pyformat: enable
- max_tokens = 6
- expected_output_shape = [None, max_tokens]
-
- input_data = keras.Input(shape=(None,), dtype=dtypes.int32)
- layer = get_layer_class()(max_tokens=6, output_mode=category_encoding.TFIDF)
- layer.set_tfidf_data(tfidf_data)
- int_data = layer(input_data)
- self.assertAllEqual(expected_output_shape, int_data.shape.as_list())
-
- model = keras.Model(inputs=input_data, outputs=int_data)
- output_dataset = model.predict(input_array)
- self.assertAllClose(expected_output, output_dataset)
-
- def test_tfidf_output_soft_maximum(self):
- tfidf_data = [.05, .5, .25, .2, .125]
- input_array = np.array([[1, 2, 3, 1], [0, 4, 1, 0]])
-
- # pyformat: disable
- # pylint: disable=bad-whitespace
- expected_output = [[ 0, 1, .25, .2, 0],
- [.1, .5, 0, 0, .125]]
- # pylint: enable=bad-whitespace
- # pyformat: enable
- max_tokens = 5
- expected_output_shape = [None, max_tokens]
-
- input_data = keras.Input(shape=(None,), dtype=dtypes.int32)
- layer = get_layer_class()(
- max_tokens=None, output_mode=category_encoding.TFIDF)
- layer.set_num_elements(max_tokens)
- layer.set_tfidf_data(tfidf_data)
- int_data = layer(input_data)
- self.assertAllEqual(expected_output_shape, int_data.shape.as_list())
-
- model = keras.Model(inputs=input_data, outputs=int_data)
- output_dataset = model.predict(input_array)
- self.assertAllClose(expected_output, output_dataset)
-
class CategoryEncodingModelBuildingTest(
keras_parameterized.TestCase,
@@ -568,43 +358,23 @@
@parameterized.named_parameters(
{
- "testcase_name": "count_hard_max",
- "max_tokens": 5,
+ "testcase_name": "count_output",
+ "num_tokens": 5,
"output_mode": category_encoding.COUNT
}, {
- "testcase_name": "count_soft_max",
- "max_tokens": None,
- "output_mode": category_encoding.COUNT
- }, {
- "testcase_name": "binary_hard_max",
- "max_tokens": 5,
+ "testcase_name": "binary_output",
+ "num_tokens": 5,
"output_mode": category_encoding.BINARY
- }, {
- "testcase_name": "binary_soft_max",
- "max_tokens": None,
- "output_mode": category_encoding.BINARY
- }, {
- "testcase_name": "tfidf_hard_max",
- "max_tokens": 5,
- "output_mode": category_encoding.TFIDF
- }, {
- "testcase_name": "tfidf_soft_max",
- "max_tokens": None,
- "output_mode": category_encoding.TFIDF
})
- def test_end_to_end_bagged_modeling(self, output_mode, max_tokens):
- tfidf_data = np.array([.03, .5, .25, .2, .125])
+ def test_end_to_end_bagged_modeling(self, output_mode, num_tokens):
input_array = np.array([[1, 2, 3, 1], [0, 3, 1, 0]])
input_data = keras.Input(shape=(None,), dtype=dtypes.int32)
- layer = get_layer_class()(max_tokens=max_tokens, output_mode=output_mode)
+ layer = get_layer_class()(num_tokens=num_tokens, output_mode=output_mode)
weights = []
- if max_tokens is None:
+ if num_tokens is None:
layer.set_num_elements(5)
- if output_mode == category_encoding.TFIDF:
- weights.append(tfidf_data)
-
layer.set_weights(weights)
int_data = layer(input_data)
@@ -614,160 +384,5 @@
_ = model.predict(input_array)
-@keras_parameterized.run_all_keras_modes
-class CategoryEncodingCombinerTest(
- keras_parameterized.TestCase,
- preprocessing_test_utils.PreprocessingLayerTest):
-
- def compare_idf_accumulators(self, a, b, msg=None):
- if a is None or b is None:
- self.assertAllEqual(a, b, msg=msg)
-
- self.assertAllEqual(a.data, b.data, msg=msg)
-
- if a.per_doc_count_dict is not None:
-
- def per_doc_counts(accumulator):
- count_values = [
- count_dict["count"]
- for count_dict in accumulator.per_doc_count_dict.values()
- ]
- return dict(zip(accumulator.per_doc_count_dict.keys(), count_values))
-
- self.assertAllEqual(per_doc_counts(a), per_doc_counts(b), msg=msg)
-
- compare_accumulators = compare_idf_accumulators
-
- def update_accumulator(self, accumulator, data):
- accumulator.data[1] = data["num_documents"]
- accumulator.data[0] = data["max_element"]
-
- if "document_counts" in data:
- create_dict = lambda x: {"count": x, "last_doc_id": -1}
- idf_dict = {}
- for i, count in enumerate(data["document_counts"]):
- if count > 0:
- idf_dict[i] = create_dict(count)
-
- accumulator.per_doc_count_dict.update(idf_dict)
-
- return accumulator
-
- def test_combiner_api_compatibility_int_mode(self):
- data = np.array([[1, 2, 3, 4], [1, 2, 3, 0]])
- combiner = category_encoding._CategoryEncodingCombiner(compute_idf=False)
- expected_accumulator_output = {
- "max_element": np.array(4),
- "num_documents": np.array(2),
- }
- expected_extract_output = {
- "num_elements": np.array(5),
- }
- expected_accumulator = combiner._create_accumulator()
- expected_accumulator = self.update_accumulator(expected_accumulator,
- expected_accumulator_output)
- self.validate_accumulator_serialize_and_deserialize(combiner, data,
- expected_accumulator)
- self.validate_accumulator_uniqueness(combiner, data)
- self.validate_accumulator_extract(combiner, data, expected_extract_output)
-
- def test_combiner_api_compatibility_tfidf_mode(self):
- data = np.array([[1, 2, 3, 4], [1, 2, 3, 0]])
- combiner = category_encoding._CategoryEncodingCombiner(compute_idf=True)
- expected_accumulator_output = {
- "max_element": np.array(4),
- "document_counts": np.array([1, 2, 2, 2, 1]),
- "num_documents": np.array(2),
- }
- expected_extract_output = {
- "num_elements": np.array(5),
- "idf": np.array([0.693147, 0.510826, 0.510826, 0.510826, 0.693147]),
- }
-
- expected_accumulator = combiner._create_accumulator()
- expected_accumulator = self.update_accumulator(expected_accumulator,
- expected_accumulator_output)
- self.validate_accumulator_serialize_and_deserialize(combiner, data,
- expected_accumulator)
- self.validate_accumulator_uniqueness(combiner, data)
- self.validate_accumulator_extract(combiner, data, expected_extract_output)
-
- # TODO(askerryryan): Add tests confirming equivalence to behavior of
- # existing tf.keras.preprocessing.text.Tokenizer.
- @parameterized.named_parameters(
- {
- "testcase_name": "no_top_k",
- "data": np.array([[1, 2], [4, 2], [3], [4, 2]]),
- "expected_accumulator_output": {
- "max_element": np.array(4),
- "document_counts": np.array([0, 1, 3, 1, 2]),
- "num_documents": np.array(4),
- },
- "expected_extract_output": {
- "num_elements":
- np.array(5),
- "idf":
- np.array([1.609438, 1.098612, 0.693147, 1.098612, 0.847298]),
- },
- }, {
- "testcase_name": "single_element_per_row",
- "data": np.array([[1], [2], [4], [2], [3]]),
- "expected_accumulator_output": {
- "max_element": np.array(4),
- "document_counts": np.array([0, 1, 2, 1, 1]),
- "num_documents": np.array(5),
- },
- "expected_extract_output": {
- "num_elements":
- np.array(5),
- "idf":
- np.array([1.791759, 1.252763, 0.980829, 1.252763, 1.252763]),
- },
- })
- def test_combiner_computation(self,
- data,
- expected_accumulator_output,
- expected_extract_output,
- compute_idf=True):
- combiner = category_encoding._CategoryEncodingCombiner(
- compute_idf=compute_idf)
- expected_accumulator = combiner._create_accumulator()
- expected_accumulator = self.update_accumulator(expected_accumulator,
- expected_accumulator_output)
- self.validate_accumulator_computation(combiner, data, expected_accumulator)
- self.validate_accumulator_extract(combiner, data, expected_extract_output)
-
- def test_1d_data(self):
- data = [1, 2, 3]
- cls = get_layer_class()
- layer = cls()
- layer.adapt(data)
- output = layer(data)
- self.assertListEqual(output.shape.as_list(), [3, 4])
-
- def test_no_adapt_exception(self):
- cls = get_layer_class()
- layer = cls()
- with self.assertRaisesRegex(
- RuntimeError, r".*you need to call.*"):
- _ = layer([1, 2, 3])
-
- def test_saving_loading(self):
- cls = get_layer_class()
- encoder = cls()
- encoder.adapt([1, 2, 3])
- model = keras.Sequential([encoder])
- model.save("/tmp/model", save_format="tf")
- loaded_model = keras.models.load_model("/tmp/model")
- self.assertAllClose(model.predict([[1]]), loaded_model.predict([[1]]))
-
- def test_serialize(self):
- cls = get_layer_class()
- encoder = cls()
- encoder.adapt([1, 2, 3])
- model = keras.Sequential([encoder])
- _ = keras.models.clone_model(model)
-
-
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/keras/layers/preprocessing/category_encoding_v1.py b/tensorflow/python/keras/layers/preprocessing/category_encoding_v1.py
deleted file mode 100644
index 3afb86b..0000000
--- a/tensorflow/python/keras/layers/preprocessing/category_encoding_v1.py
+++ /dev/null
@@ -1,69 +0,0 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tensorflow V1 version of the text category_encoding preprocessing layer."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.keras.engine import base_preprocessing_layer_v1
-from tensorflow.python.keras.layers.preprocessing import category_encoding
-from tensorflow.python.util.tf_export import keras_export
-
-
-@keras_export(v1=["keras.layers.experimental.preprocessing.CategoryEncoding"])
-class CategoryEncoding(category_encoding.CategoryEncoding,
- base_preprocessing_layer_v1.CombinerPreprocessingLayer):
- """CategoryEncoding layer.
-
- This layer provides options for condensing input data into denser
- representations. It accepts either integer values or strings as inputs,
- allows users to map those inputs into a contiguous integer space, and
- outputs either those integer values (one sample = 1D tensor of integer token
- indices) or a dense representation (one sample = 1D tensor of float values
- representing data about the sample's tokens).
-
- If desired, the user can call this layer's adapt() method on a dataset.
- When this layer is adapted, it will analyze the dataset, determine the
- frequency of individual integer or string values, and create a 'vocabulary'
- from them. This vocabulary can have unlimited size or be capped, depending
- on the configuration options for this layer; if there are more unique
- values in the input than the maximum vocabulary size, the most frequent
- terms will be used to create the vocabulary.
-
- Attributes:
- max_elements: The maximum size of the vocabulary for this layer. If None,
- there is no cap on the size of the vocabulary.
- output_mode: Optional specification for the output of the layer. Values can
- be "int", "binary", "count" or "tf-idf", configuring the layer as follows:
- "int": Outputs integer indices, one integer index per split string
- token.
- "binary": Outputs a single int array per batch, of either vocab_size or
- max_elements size, containing 1s in all elements where the token
- mapped to that index exists at least once in the batch item.
- "count": As "binary", but the int array contains a count of the number
- of times the token at that index appeared in the batch item.
- "tf-idf": As "binary", but the TF-IDF algorithm is applied to find the
- value in each token slot.
- output_sequence_length: Only valid in INT mode. If set, the output will have
- its time dimension padded or truncated to exactly `output_sequence_length`
- values, resulting in a tensor of shape [batch_size,
- output_sequence_length] regardless of the input shape.
- pad_to_max_elements: Only valid in "binary", "count", and "tf-idf" modes.
- If True, the output will have its feature axis padded to `max_elements`
- even if the number of unique values in the vocabulary is less than
- max_elements, resulting in a tensor of shape [batch_size, max_elements]
- regardless of vocabulary size. Defaults to False.
- """
diff --git a/tensorflow/python/keras/layers/preprocessing/hashing.py b/tensorflow/python/keras/layers/preprocessing/hashing.py
index 925e1ca..762b77c 100644
--- a/tensorflow/python/keras/layers/preprocessing/hashing.py
+++ b/tensorflow/python/keras/layers/preprocessing/hashing.py
@@ -24,16 +24,12 @@
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras.engine import base_preprocessing_layer
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gen_sparse_ops
from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import string_ops
-from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util.tf_export import keras_export
# Default key from tf.sparse.cross_hashed
@@ -84,20 +80,6 @@
[2],
[2]])>
-
- Example (FarmHash64) with list of inputs:
- >>> layer = tf.keras.layers.experimental.preprocessing.Hashing(num_bins=3)
- >>> inp_1 = [['A'], ['B'], ['C'], ['D'], ['E']]
- >>> inp_2 = np.asarray([[5], [4], [3], [2], [1]])
- >>> layer([inp_1, inp_2])
- <tf.Tensor: shape=(5, 1), dtype=int64, numpy=
- array([[1],
- [1],
- [0],
- [2],
- [0]])>
-
-
Example (SipHash64):
>>> layer = tf.keras.layers.experimental.preprocessing.Hashing(num_bins=3,
@@ -187,47 +169,13 @@
def call(self, inputs):
inputs = self._preprocess_inputs(inputs)
- if isinstance(inputs, (tuple, list)):
- return self._process_input_list(inputs)
- elif isinstance(inputs, sparse_tensor.SparseTensor):
+ if isinstance(inputs, sparse_tensor.SparseTensor):
return sparse_tensor.SparseTensor(
indices=inputs.indices,
values=self._hash_values_to_bins(inputs.values),
dense_shape=inputs.dense_shape)
return self._hash_values_to_bins(inputs)
- def _process_input_list(self, inputs):
- # TODO(momernick): support ragged_cross_hashed with corrected fingerprint
- # and siphash.
- if any(isinstance(inp, ragged_tensor.RaggedTensor) for inp in inputs):
- raise ValueError('Hashing with ragged input is not supported yet.')
- if self.mask_value is not None:
- raise ValueError(
- 'Cross hashing with a mask_value is not supported yet, mask_value is '
- '{}.'.format(self.mask_value))
- sparse_inputs = [
- inp for inp in inputs if isinstance(inp, sparse_tensor.SparseTensor)
- ]
- dense_inputs = [
- inp for inp in inputs if not isinstance(inp, sparse_tensor.SparseTensor)
- ]
- all_dense = True if not sparse_inputs else False
- indices = [sp_inp.indices for sp_inp in sparse_inputs]
- values = [sp_inp.values for sp_inp in sparse_inputs]
- shapes = [sp_inp.dense_shape for sp_inp in sparse_inputs]
- indices_out, values_out, shapes_out = gen_sparse_ops.SparseCrossHashed(
- indices=indices,
- values=values,
- shapes=shapes,
- dense_inputs=dense_inputs,
- num_buckets=self.num_bins,
- strong_hash=self.strong_hash,
- salt=self.salt)
- sparse_out = sparse_tensor.SparseTensor(indices_out, values_out, shapes_out)
- if all_dense:
- return sparse_ops.sparse_tensor_to_dense(sparse_out)
- return sparse_out
-
def _hash_values_to_bins(self, values):
"""Converts a non-sparse tensor of values to bin indices."""
str_to_hash_bucket = self._get_string_to_hash_bucket_fn()
@@ -257,41 +205,16 @@
string_ops.string_to_hash_bucket_strong, key=self.salt)
def compute_output_shape(self, input_shape):
- if not isinstance(input_shape, (tuple, list)):
- return input_shape
- input_shapes = input_shape
- batch_size = None
- for inp_shape in input_shapes:
- inp_tensor_shape = tensor_shape.TensorShape(inp_shape).as_list()
- if len(inp_tensor_shape) != 2:
- raise ValueError('Inputs must be rank 2, get {}'.format(input_shapes))
- if batch_size is None:
- batch_size = inp_tensor_shape[0]
- # The second dimension is dynamic based on inputs.
- output_shape = [batch_size, None]
- return tensor_shape.TensorShape(output_shape)
+ return input_shape
def compute_output_signature(self, input_spec):
- if not isinstance(input_spec, (tuple, list)):
- output_shape = self.compute_output_shape(input_spec.shape)
- output_dtype = dtypes.int64
- if isinstance(input_spec, sparse_tensor.SparseTensorSpec):
- return sparse_tensor.SparseTensorSpec(
- shape=output_shape, dtype=output_dtype)
- else:
- return tensor_spec.TensorSpec(shape=output_shape, dtype=output_dtype)
- input_shapes = [x.shape for x in input_spec]
- output_shape = self.compute_output_shape(input_shapes)
- if any(
- isinstance(inp_spec, ragged_tensor.RaggedTensorSpec)
- for inp_spec in input_spec):
- return tensor_spec.TensorSpec(shape=output_shape, dtype=dtypes.int64)
- elif any(
- isinstance(inp_spec, sparse_tensor.SparseTensorSpec)
- for inp_spec in input_spec):
+ output_shape = self.compute_output_shape(input_spec.shape)
+ output_dtype = dtypes.int64
+ if isinstance(input_spec, sparse_tensor.SparseTensorSpec):
return sparse_tensor.SparseTensorSpec(
- shape=output_shape, dtype=dtypes.int64)
- return tensor_spec.TensorSpec(shape=output_shape, dtype=dtypes.int64)
+ shape=output_shape, dtype=output_dtype)
+ else:
+ return tensor_spec.TensorSpec(shape=output_shape, dtype=output_dtype)
def get_config(self):
config = {
diff --git a/tensorflow/python/keras/layers/preprocessing/hashing_test.py b/tensorflow/python/keras/layers/preprocessing/hashing_test.py
index 712a78e..0844f71 100644
--- a/tensorflow/python/keras/layers/preprocessing/hashing_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/hashing_test.py
@@ -64,23 +64,6 @@
# 'omar' should map to 0.
self.assertAllClose([[0], [1], [2], [1], [1]], omar_mask_output)
- def test_hash_dense_multi_inputs_mask_value_farmhash(self):
- layer = hashing.Hashing(num_bins=3, mask_value='omar')
- inp_1 = np.asarray([['omar'], ['stringer'], ['marlo'], ['wire'],
- ['skywalker']])
- inp_2 = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']])
- with self.assertRaisesRegex(ValueError, 'not supported yet'):
- _ = layer([inp_1, inp_2])
-
- def test_hash_dense_multi_inputs_farmhash(self):
- layer = hashing.Hashing(num_bins=2)
- inp_1 = np.asarray([['omar'], ['stringer'], ['marlo'], ['wire'],
- ['skywalker']])
- inp_2 = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']])
- output = layer([inp_1, inp_2])
- # Assert equal for hashed output that should be true on all platforms.
- self.assertAllClose([[0], [0], [1], [1], [0]], output)
-
def test_hash_dense_list_input_farmhash(self):
layer = hashing.Hashing(num_bins=2)
inp = [['omar'], ['stringer'], ['marlo'], ['wire'], ['skywalker']]
@@ -93,15 +76,6 @@
# Assert equal for hashed output that should be true on all platforms.
self.assertAllClose([0, 0, 1, 0, 0], output)
- def test_hash_dense_list_inputs_mixed_int_string_farmhash(self):
- layer = hashing.Hashing(num_bins=2)
- inp_1 = np.asarray([['omar'], ['stringer'], ['marlo'], ['wire'],
- ['skywalker']])
- inp_2 = np.asarray([[1], [2], [3], [4], [5]]).astype(np.int64)
- output = layer([inp_1, inp_2])
- # Assert equal for hashed output that should be true on all platforms.
- self.assertAllClose([[0], [1], [1], [1], [0]], output)
-
def test_hash_dense_int_input_farmhash(self):
layer = hashing.Hashing(num_bins=3)
inp = np.asarray([[0], [1], [2], [3], [4]])
@@ -123,21 +97,6 @@
# Note the result is different from (133, 137).
self.assertAllClose([[1], [0], [1], [0], [1]], output_2)
- def test_hash_dense_multi_inputs_siphash(self):
- layer = hashing.Hashing(num_bins=2, salt=[133, 137])
- inp_1 = np.asarray([['omar'], ['stringer'], ['marlo'], ['wire'],
- ['skywalker']])
- inp_2 = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']])
- output = layer([inp_1, inp_2])
- # Assert equal for hashed output that should be true on all platforms.
- # Note the result is different from FarmHash.
- self.assertAllClose([[0], [1], [0], [0], [1]], output)
-
- layer_2 = hashing.Hashing(num_bins=2, salt=[211, 137])
- output_2 = layer_2([inp_1, inp_2])
- # Note the result is different from (133, 137).
- self.assertAllClose([[1], [1], [1], [0], [1]], output_2)
-
def test_hash_dense_int_input_siphash(self):
layer = hashing.Hashing(num_bins=3, salt=[133, 137])
inp = np.asarray([[0], [1], [2], [3], [4]])
@@ -174,19 +133,6 @@
# 'omar' should map to 0.
self.assertAllClose([0, 1, 2, 1, 1], omar_mask_output.values)
- def test_hash_sparse_multi_inputs_farmhash(self):
- layer = hashing.Hashing(num_bins=2)
- indices = [[0, 0], [1, 0], [2, 0]]
- inp_1 = sparse_tensor.SparseTensor(
- indices=indices,
- values=['omar', 'stringer', 'marlo'],
- dense_shape=[3, 1])
- inp_2 = sparse_tensor.SparseTensor(
- indices=indices, values=['A', 'B', 'C'], dense_shape=[3, 1])
- output = layer([inp_1, inp_2])
- self.assertAllClose(indices, output.indices)
- self.assertAllClose([0, 0, 1], output.values)
-
def test_hash_sparse_int_input_farmhash(self):
layer = hashing.Hashing(num_bins=3)
indices = [[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]]
@@ -213,25 +159,6 @@
# The result should be same with test_hash_dense_input_siphash.
self.assertAllClose([1, 0, 1, 0, 1], output.values)
- def test_hash_sparse_multi_inputs_siphash(self):
- layer = hashing.Hashing(num_bins=2, salt=[133, 137])
- indices = [[0, 0], [1, 0], [2, 0]]
- inp_1 = sparse_tensor.SparseTensor(
- indices=indices,
- values=['omar', 'stringer', 'marlo'],
- dense_shape=[3, 1])
- inp_2 = sparse_tensor.SparseTensor(
- indices=indices, values=['A', 'B', 'C'], dense_shape=[3, 1])
- output = layer([inp_1, inp_2])
- # The result should be same with test_hash_dense_input_siphash.
- self.assertAllClose(indices, output.indices)
- self.assertAllClose([0, 1, 0], output.values)
-
- layer_2 = hashing.Hashing(num_bins=2, salt=[211, 137])
- output = layer_2([inp_1, inp_2])
- # The result should be same with test_hash_dense_input_siphash.
- self.assertAllClose([1, 1, 1], output.values)
-
def test_hash_sparse_int_input_siphash(self):
layer = hashing.Hashing(num_bins=3, salt=[133, 137])
indices = [[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]]
@@ -272,17 +199,6 @@
expected_output = [[0, 1, 2, 1], [2, 1, 1]]
self.assertAllClose(expected_output, omar_mask_output)
- def test_hash_ragged_string_multi_inputs_farmhash(self):
- layer = hashing.Hashing(num_bins=2)
- inp_data_1 = ragged_factory_ops.constant(
- [['omar', 'stringer', 'marlo', 'wire'], ['marlo', 'skywalker', 'wire']],
- dtype=dtypes.string)
- inp_data_2 = ragged_factory_ops.constant(
- [['omar', 'stringer', 'marlo', 'wire'], ['marlo', 'skywalker', 'wire']],
- dtype=dtypes.string)
- with self.assertRaisesRegex(ValueError, 'not supported yet'):
- _ = layer([inp_data_1, inp_data_2])
-
def test_hash_ragged_int_input_farmhash(self):
layer = hashing.Hashing(num_bins=3)
inp_data = ragged_factory_ops.constant([[0, 1, 3, 4], [2, 1, 0]],
@@ -321,17 +237,6 @@
model = training.Model(inputs=inp_t, outputs=out_t)
self.assertAllClose(out_data, model.predict(inp_data))
- def test_hash_ragged_string_multi_inputs_siphash(self):
- layer = hashing.Hashing(num_bins=2, salt=[133, 137])
- inp_data_1 = ragged_factory_ops.constant(
- [['omar', 'stringer', 'marlo', 'wire'], ['marlo', 'skywalker', 'wire']],
- dtype=dtypes.string)
- inp_data_2 = ragged_factory_ops.constant(
- [['omar', 'stringer', 'marlo', 'wire'], ['marlo', 'skywalker', 'wire']],
- dtype=dtypes.string)
- with self.assertRaisesRegex(ValueError, 'not supported yet'):
- _ = layer([inp_data_1, inp_data_2])
-
def test_hash_ragged_int_input_siphash(self):
layer = hashing.Hashing(num_bins=3, salt=[133, 137])
inp_data = ragged_factory_ops.constant([[0, 1, 3, 4], [2, 1, 0]],
diff --git a/tensorflow/python/keras/layers/preprocessing/image_preprocessing_test.py b/tensorflow/python/keras/layers/preprocessing/image_preprocessing_test.py
index 6a44692..1e47462 100644
--- a/tensorflow/python/keras/layers/preprocessing/image_preprocessing_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/image_preprocessing_test.py
@@ -412,7 +412,7 @@
mock_random = np.reshape(mock_random, [2, 1, 1, 1])
with test.mock.patch.object(
random_ops, 'random_uniform', return_value=mock_random):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
layer = image_preprocessing.RandomFlip()
actual_output = layer(input_images, training=1)
self.assertAllClose(expected_output, actual_output)
@@ -698,7 +698,7 @@
fill_value=0.0,
interpolation='bilinear'):
inp = np.arange(15).reshape((1, 5, 3, 1)).astype(np.float32)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
output = image_preprocessing.transform(
inp,
transform_matrix,
diff --git a/tensorflow/python/keras/layers/preprocessing/index_lookup.py b/tensorflow/python/keras/layers/preprocessing/index_lookup.py
index f985769..46c37db 100644
--- a/tensorflow/python/keras/layers/preprocessing/index_lookup.py
+++ b/tensorflow/python/keras/layers/preprocessing/index_lookup.py
@@ -184,18 +184,6 @@
self._value_dtype = dtypes.int64
oov_value = self._oov_value
- self._table = lookup_ops.MutableHashTable(
- key_dtype=self._key_dtype,
- value_dtype=self._value_dtype,
- default_value=oov_value,
- name=(self._name + "_index_table"))
- tracked_table = self._add_trackable(self._table, trainable=False)
- # This is a workaround for summary() on this layer. Because the table is
- # not mutable during training, the effective number of parameters (and so
- # the weight shape) is 0; we add this as an attr so that the parameter
- # counting code in the Model object doesn't throw an attribute error.
- tracked_table.shape = tensor_shape.TensorShape((0,))
-
if self.num_oov_indices <= 1:
oov_indices = None
else:
@@ -203,13 +191,30 @@
oov_end = oov_start + num_oov_indices
oov_indices = list(range(oov_start, oov_end))
- self._table_handler = table_utils.TableHandler(
- table=self._table,
- oov_tokens=oov_indices,
- use_v1_apis=self._use_v1_apis())
-
- if vocabulary is not None:
- self.set_vocabulary(vocabulary)
+ if vocabulary is not None and isinstance(vocabulary,
+ lookup_ops.TextFileInitializer):
+ self._table = self._static_table_class()(
+ vocabulary, default_value=oov_value)
+ self._table_handler = table_utils.TableHandler(
+ table=self._table,
+ mask_token=mask_token,
+ oov_tokens=oov_indices,
+ use_v1_apis=self._use_v1_apis())
+ self.max_tokens = (
+ self._table_handler.vocab_size() + self.num_oov_indices +
+ (0 if mask_token is None else 1))
+ else:
+ self._table = lookup_ops.MutableHashTable(
+ key_dtype=self._key_dtype,
+ value_dtype=self._value_dtype,
+ default_value=oov_value,
+ name=(self._name + "_index_table"))
+ self._table_handler = table_utils.TableHandler(
+ table=self._table,
+ oov_tokens=oov_indices,
+ use_v1_apis=self._use_v1_apis())
+ if vocabulary is not None:
+ self.set_vocabulary(vocabulary)
if self.output_mode == TFIDF:
# The TF-IDF weight may have a (None,) tensorshape. This creates
@@ -232,6 +237,13 @@
dtype=K.floatx(),
initializer=initializer)
+ tracked_table = self._add_trackable(self._table, trainable=False)
+ # This is a workaround for summary() on this layer. Because the table is
+ # not mutable during training, the effective number of parameters (and so
+ # the weight shape) is 0; we add this as an attr so that the parameter
+ # counting code in the Model object doesn't throw an attribute error.
+ tracked_table.shape = tensor_shape.TensorShape((0,))
+
def compute_output_shape(self, input_shape):
if self.output_mode != INT:
return tensor_shape.TensorShape([input_shape[0], self.max_tokens])
@@ -538,6 +550,9 @@
def _use_v1_apis(self):
return False
+ def _static_table_class(self):
+ return lookup_ops.StaticHashTable
+
class _IndexLookupAccumulator(
collections.namedtuple("Accumulator",
diff --git a/tensorflow/python/keras/layers/preprocessing/index_lookup_test.py b/tensorflow/python/keras/layers/preprocessing/index_lookup_test.py
index b845dd4..eb60963 100644
--- a/tensorflow/python/keras/layers/preprocessing/index_lookup_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/index_lookup_test.py
@@ -41,7 +41,9 @@
from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
from tensorflow.python.keras.saving import save
from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
+from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
+from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
@@ -703,6 +705,15 @@
class IndexLookupOutputTest(keras_parameterized.TestCase,
preprocessing_test_utils.PreprocessingLayerTest):
+ def _write_to_temp_file(self, file_name, vocab_list):
+ vocab_path = os.path.join(self.get_temp_dir(), file_name + ".txt")
+ with gfile.GFile(vocab_path, "w") as writer:
+ for vocab in vocab_list:
+ writer.write(vocab + "\n")
+ writer.flush()
+ writer.close()
+ return vocab_path
+
def test_int_output(self):
vocab_data = ["earth", "wind", "and", "fire"]
input_array = np.array([["earth", "wind", "and", "fire"],
@@ -958,7 +969,60 @@
layer_output = layer(input_data)
self.assertAllEqual(layer_output.shape.as_list(), [16, 2])
+ def test_int_output_file_vocab(self):
+ vocab_data = ["earth", "wind", "and", "fire"]
+ input_array = np.array([["earth", "wind", "and", "fire"],
+ ["fire", "", "earth", "michigan"]])
+ expected_output = [[2, 3, 4, 5], [5, 0, 2, 1]]
+ vocab_file = self._write_to_temp_file("temp", vocab_data)
+ vocabulary_initializer = lookup_ops.TextFileInitializer(
+ filename=vocab_file,
+ key_dtype=dtypes.string,
+ key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
+ value_dtype=dtypes.int64,
+ value_index=lookup_ops.TextFileIndex.LINE_NUMBER,
+ value_index_offset=2)
+
+ input_data = keras.Input(shape=(None,), dtype=dtypes.string)
+ layer = get_layer_class()(
+ vocabulary=vocabulary_initializer,
+ max_tokens=None,
+ num_oov_indices=1,
+ mask_token="",
+ oov_token="[OOV]",
+ dtype=dtypes.string)
+ int_data = layer(input_data)
+ model = keras.Model(inputs=input_data, outputs=int_data)
+ output_dataset = model.predict(input_array)
+ self.assertAllEqual(expected_output, output_dataset)
+
+ def test_int_output_int_file_vocab(self):
+ vocab_data = ["10", "20", "30", "40"]
+ input_array = np.array([[10, 20, 30, 40], [40, 0, 10, 42]])
+ expected_output = [[2, 3, 4, 5], [5, 0, 2, 1]]
+
+ vocab_file = self._write_to_temp_file("temp", vocab_data)
+ vocabulary_initializer = lookup_ops.TextFileInitializer(
+ filename=vocab_file,
+ key_dtype=dtypes.int64,
+ key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
+ value_dtype=dtypes.int64,
+ value_index=lookup_ops.TextFileIndex.LINE_NUMBER,
+ value_index_offset=2)
+
+ input_data = keras.Input(shape=(None,), dtype=dtypes.int64)
+ layer = get_layer_class()(
+ vocabulary=vocabulary_initializer,
+ max_tokens=None,
+ num_oov_indices=1,
+ mask_token=0,
+ oov_token=-1,
+ dtype=dtypes.int64)
+ int_data = layer(input_data)
+ model = keras.Model(inputs=input_data, outputs=int_data)
+ output_dataset = model.predict(input_array)
+ self.assertAllEqual(expected_output, output_dataset)
@keras_parameterized.run_all_keras_modes
class IndexLookupVocabularyTest(keras_parameterized.TestCase,
preprocessing_test_utils.PreprocessingLayerTest
diff --git a/tensorflow/python/keras/layers/preprocessing/index_lookup_v1.py b/tensorflow/python/keras/layers/preprocessing/index_lookup_v1.py
index 47fea11..c710108 100644
--- a/tensorflow/python/keras/layers/preprocessing/index_lookup_v1.py
+++ b/tensorflow/python/keras/layers/preprocessing/index_lookup_v1.py
@@ -21,6 +21,7 @@
from tensorflow.python.keras.engine import base_preprocessing_layer_v1
from tensorflow.python.keras.layers.preprocessing import index_lookup
+from tensorflow.python.ops import lookup_ops
class IndexLookup(index_lookup.IndexLookup,
@@ -58,3 +59,6 @@
def _use_v1_apis(self):
return True
+
+ def _static_table_class(self):
+ return lookup_ops.StaticHashTableV1
diff --git a/tensorflow/python/keras/layers/preprocessing/table_utils.py b/tensorflow/python/keras/layers/preprocessing/table_utils.py
index 56b07fb..e7fe917 100644
--- a/tensorflow/python/keras/layers/preprocessing/table_utils.py
+++ b/tensorflow/python/keras/layers/preprocessing/table_utils.py
@@ -27,6 +27,7 @@
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops.ragged import ragged_functional_ops
@@ -38,8 +39,24 @@
class TableHandler(object):
"""Wrapper object that holds a lookup table and provides accessors."""
- def __init__(self, table, oov_tokens=None, use_v1_apis=False):
+ def __init__(self,
+ table,
+ oov_tokens=None,
+ mask_token=None,
+ use_v1_apis=False):
self.table = table
+
+ # If we are using V1 APIs, and the table has an initializer, we need to run
+ # it. However, not all tables have initializers, so we try-except here.
+ if use_v1_apis:
+ try:
+ K.get_session().run(self.table.initializer)
+ except AttributeError:
+ pass
+
+ self.mutable = isinstance(table, lookup_ops.MutableHashTable)
+ self.mask_token = mask_token
+
self.use_v1_apis = use_v1_apis
if oov_tokens is None:
self.oov_tokens = oov_tokens
@@ -56,10 +73,17 @@
return self._eval(self.table.size())
def clear(self):
+ if not self.mutable:
+ return RuntimeError("Unable to clear a statically-backed table.")
+
keys, _ = self.table.export()
self._run(self.table.remove(keys))
def insert(self, keys, values):
+ """Insert values into the backed table."""
+ if not self.mutable:
+ raise RuntimeError("Unable to insert into a statically-backed table.")
+
if len(values) != len(keys):
raise RuntimeError("Size mismatch between values and key arrays. "
"Keys had size %s, values had size %s." %
@@ -90,12 +114,35 @@
return array_ops.where(oov_locations, oov_values, lookups)
+ def _lookup_and_mask(self, inputs):
+ """Return a lookup with any location with the mask_token masked to 0."""
+ lookups = self.table.lookup(inputs)
+ # If we don't need to handle masking, return the lookup values directly.
+ if self.mask_token is None:
+ return lookups
+
+ # If we do need to handle masking, increment all the lookup values by 1
+ # to account for the mask value at location 0. This also increments the
+ # OOV value, so replace that. (This is inefficient, but we can't adjust
+ # the table safely, so we don't have a choice.)
+ oov_locations = math_ops.equal(lookups, self.table._default_value) # pylint: disable=protected-access
+ oov_values = array_ops.ones_like(
+ lookups, dtype=self.table._value_dtype) * self.table._default_value # pylint: disable=protected-access
+ adjusted_lookups = array_ops.where(oov_locations, oov_values, lookups)
+
+ # Inject 0s wherever the mask token was in the inputs.
+ mask_locations = math_ops.equal(inputs, self.mask_token)
+ return array_ops.where(
+ mask_locations,
+ array_ops.zeros_like(lookups, dtype=self.table._value_dtype), # pylint: disable=protected-access
+ adjusted_lookups) # pylint: disable=protected-access
+
def _ragged_lookup(self, inputs):
"""Perform a table lookup on a ragged tensor."""
# The table lookup ops don't natively support ragged tensors, so if we have
# a RT we need to use map_flat_values to look up every element.
indexed_data = ragged_functional_ops.map_flat_values(
- self.table.lookup, inputs)
+ self._lookup_and_mask, inputs)
indexed_data = ragged_functional_ops.map_flat_values(
self._replace_oov_buckets, inputs, indexed_data)
# table.lookup is not shape-preserving, so we need to set the shape here.
@@ -107,7 +154,7 @@
def _sparse_lookup(self, inputs):
"""Perform a table lookup on a sparse tensor."""
- values = self.table.lookup(inputs.values)
+ values = self._lookup_and_mask(inputs.values)
values = self._replace_oov_buckets(inputs.values, values)
indexed_data = sparse_tensor.SparseTensor(inputs.indices, values,
inputs.dense_shape)
@@ -118,7 +165,7 @@
def _tensor_lookup(self, inputs):
"""Perform a table lookup on a tf.tensor."""
- values = self.table.lookup(inputs)
+ values = self._lookup_and_mask(inputs)
indexed_data = self._replace_oov_buckets(inputs, values)
# (b/149446477): output does not preserve input shape.
indexed_data.set_shape(inputs.shape)
diff --git a/tensorflow/python/keras/layers/preprocessing/table_utils_test.py b/tensorflow/python/keras/layers/preprocessing/table_utils_test.py
index 05b18d1..d23eb97 100644
--- a/tensorflow/python/keras/layers/preprocessing/table_utils_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/table_utils_test.py
@@ -45,6 +45,41 @@
table, oov_tokens, use_v1_apis=(not context.executing_eagerly()))
+def get_static_table(tmpdir,
+ vocab_list,
+ mask_token=None,
+ dtype=dtypes.string,
+ oov_tokens=None):
+ vocabulary_file = os.path.join(tmpdir, "tmp_vocab.txt")
+
+ if dtype == dtypes.string:
+ with open(vocabulary_file, "w") as f:
+ f.write("\n".join(vocab_list) + "\n")
+ else:
+ with open(vocabulary_file, "w") as f:
+ f.write("\n".join([str(v) for v in vocab_list]) + "\n")
+
+ offset = ((0 if mask_token is None else 1) +
+ (len(oov_tokens) if oov_tokens is not None else 0))
+ init = lookup_ops.TextFileInitializer(
+ vocabulary_file,
+ dtype,
+ lookup_ops.TextFileIndex.WHOLE_LINE,
+ dtypes.int64,
+ lookup_ops.TextFileIndex.LINE_NUMBER,
+ value_index_offset=offset)
+ if context.executing_eagerly():
+ table = lookup_ops.StaticHashTable(init, default_value=-7)
+ else:
+ table = lookup_ops.StaticHashTableV1(init, default_value=-7)
+
+ return table_utils.TableHandler(
+ table,
+ oov_tokens,
+ mask_token=mask_token,
+ use_v1_apis=(not context.executing_eagerly()))
+
+
@keras_parameterized.run_all_keras_modes
class CategoricalEncodingInputTest(
keras_parameterized.TestCase,
@@ -252,6 +287,132 @@
self.assertAllEqual(expected_output, output_data)
+@keras_parameterized.run_all_keras_modes
+class StaticIndexLookupOutputTest(
+ keras_parameterized.TestCase,
+ preprocessing_test_utils.PreprocessingLayerTest):
+
+ def test_int_output_default_lookup_value(self):
+ vocab_data = ["earth", "wind", "and", "fire"]
+ input_array = np.array([["earth", "wind", "and", "fire"],
+ ["fire", "and", "earth", "michigan"]])
+ expected_output = [[1, 2, 3, 4], [4, 3, 1, -7]]
+
+ table = get_static_table(
+ tmpdir=self.get_temp_dir(),
+ vocab_list=vocab_data,
+ mask_token="",
+ oov_tokens=None)
+ output_data = table.lookup(input_array)
+
+ self.assertAllEqual(expected_output, output_data)
+
+ def test_output_shape(self):
+ vocab_data = ["earth", "wind", "and", "fire"]
+ input_array = np.array([["earth", "wind", "and", "fire"],
+ ["fire", "and", "earth", "michigan"]])
+
+ table = get_static_table(
+ tmpdir=self.get_temp_dir(), vocab_list=vocab_data, oov_tokens=None)
+ output_data = table.lookup(input_array)
+
+ self.assertAllEqual(input_array.shape[1:], output_data.shape[1:])
+
+ def test_int_output_no_reserved_zero_default_lookup_value(self):
+ vocab_data = ["earth", "wind", "and", "fire"]
+ input_array = np.array([["earth", "wind", "and", "fire"],
+ ["fire", "and", "earth", "michigan"]])
+ expected_output = [[0, 1, 2, 3], [3, 2, 0, -7]]
+
+ table = get_static_table(
+ tmpdir=self.get_temp_dir(), vocab_list=vocab_data, oov_tokens=None)
+ output_data = table.lookup(input_array)
+
+ self.assertAllEqual(expected_output, output_data)
+
+
+@keras_parameterized.run_all_keras_modes
+class CategoricalEncodingStaticInputTest(
+ keras_parameterized.TestCase,
+ preprocessing_test_utils.PreprocessingLayerTest):
+
+ def test_sparse_string_input(self):
+ vocab_data = ["earth", "wind", "and", "fire"]
+ input_array = sparse_tensor.SparseTensor(
+ indices=[[0, 0], [1, 2]],
+ values=["fire", "michigan"],
+ dense_shape=[3, 4])
+
+ expected_indices = [[0, 0], [1, 2]]
+ expected_values = [5, 1]
+ expected_dense_shape = [3, 4]
+
+ table = get_static_table(
+ tmpdir=self.get_temp_dir(),
+ vocab_list=vocab_data,
+ mask_token="",
+ oov_tokens=[1])
+ output_data = table.lookup(input_array)
+
+ self.assertAllEqual(expected_indices, output_data.indices)
+ self.assertAllEqual(expected_values, output_data.values)
+ self.assertAllEqual(expected_dense_shape, output_data.dense_shape)
+
+ def test_sparse_int_input(self):
+ vocab_data = np.array([10, 11, 12, 13], dtype=np.int64)
+ input_array = sparse_tensor.SparseTensor(
+ indices=[[0, 0], [1, 2]],
+ values=np.array([13, 32], dtype=np.int64),
+ dense_shape=[3, 4])
+
+ expected_indices = [[0, 0], [1, 2]]
+ expected_values = [5, 1]
+ expected_dense_shape = [3, 4]
+
+ table = get_static_table(
+ tmpdir=self.get_temp_dir(),
+ vocab_list=vocab_data,
+ dtype=dtypes.int64,
+ mask_token=0,
+ oov_tokens=[1])
+ output_data = table.lookup(input_array)
+
+ self.assertAllEqual(expected_indices, output_data.indices)
+ self.assertAllEqual(expected_values, output_data.values)
+ self.assertAllEqual(expected_dense_shape, output_data.dense_shape)
+
+ def test_ragged_string_input(self):
+ vocab_data = ["earth", "wind", "and", "fire"]
+ input_array = ragged_factory_ops.constant(
+ [["earth", "wind", "fire"], ["fire", "and", "earth", "michigan"]])
+ expected_output = [[2, 3, 5], [5, 4, 2, 1]]
+
+ table = get_static_table(
+ tmpdir=self.get_temp_dir(),
+ vocab_list=vocab_data,
+ mask_token="",
+ oov_tokens=[1])
+ output_data = table.lookup(input_array)
+
+ self.assertAllEqual(expected_output, output_data)
+
+ def test_ragged_int_input(self):
+ vocab_data = np.array([10, 11, 12, 13], dtype=np.int64)
+ input_array = ragged_factory_ops.constant([[10, 11, 13], [13, 12, 10, 42]],
+ dtype=np.int64)
+ expected_output = [[2, 3, 5], [5, 4, 2, 1]]
+
+ table = get_static_table(
+ tmpdir=self.get_temp_dir(),
+ vocab_list=vocab_data,
+ dtype=dtypes.int64,
+ mask_token=0,
+ oov_tokens=[1])
+ output_data = table.lookup(input_array)
+
+ self.assertAllEqual(expected_output, output_data)
+
+
class GetVocabularyFromFileTest(test.TestCase):
def setUp(self):
diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py
index cc20fe8..f0e7fe7 100644
--- a/tensorflow/python/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/layers/recurrent.py
@@ -236,7 +236,7 @@
`batch_size` is a scalar tensor that represents the batch size
of the inputs. `dtype` is `tf.DType` that represents the dtype of
the inputs.
- For backward compatible reason, if this method is not implemented
+ For backward compatibility, if this method is not implemented
by the cell, the RNN layer will create a zero filled tensor with the
size of [batch_size, cell.state_size].
In the case that `cell` is a list of RNN cell instances, the cells
diff --git a/tensorflow/python/keras/layers/separable_convolutional_test.py b/tensorflow/python/keras/layers/separable_convolutional_test.py
index 8234bfe..8fdaccc 100644
--- a/tensorflow/python/keras/layers/separable_convolutional_test.py
+++ b/tensorflow/python/keras/layers/separable_convolutional_test.py
@@ -35,7 +35,7 @@
stack_size = 3
length = 7
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
testing_utils.layer_test(
keras.layers.SeparableConv1D,
kwargs=kwargs,
@@ -66,7 +66,7 @@
'activity_regularizer': 'l2',
'strides': 1
}
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
layer = keras.layers.SeparableConv1D(**kwargs)
layer.build((None, 5, 2))
self.assertEqual(len(layer.losses), 3)
@@ -87,7 +87,7 @@
'bias_constraint': b_constraint,
'strides': 1
}
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
layer = keras.layers.SeparableConv1D(**kwargs)
layer.build((None, 5, 2))
self.assertEqual(layer.depthwise_kernel.constraint, d_constraint)
@@ -104,7 +104,7 @@
num_row = 7
num_col = 6
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
testing_utils.layer_test(
keras.layers.SeparableConv2D,
kwargs=kwargs,
@@ -138,7 +138,7 @@
'activity_regularizer': 'l2',
'strides': 1
}
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
layer = keras.layers.SeparableConv2D(**kwargs)
layer.build((None, 5, 5, 2))
self.assertEqual(len(layer.losses), 3)
@@ -159,7 +159,7 @@
'bias_constraint': b_constraint,
'strides': 1
}
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
layer = keras.layers.SeparableConv2D(**kwargs)
layer.build((None, 5, 5, 2))
self.assertEqual(layer.depthwise_kernel.constraint, d_constraint)
diff --git a/tensorflow/python/keras/layers/serialization.py b/tensorflow/python/keras/layers/serialization.py
index b748b3b..aa74302 100644
--- a/tensorflow/python/keras/layers/serialization.py
+++ b/tensorflow/python/keras/layers/serialization.py
@@ -48,7 +48,6 @@
from tensorflow.python.keras.layers import wrappers
from tensorflow.python.keras.layers.preprocessing import category_crossing
from tensorflow.python.keras.layers.preprocessing import category_encoding
-from tensorflow.python.keras.layers.preprocessing import category_encoding_v1
from tensorflow.python.keras.layers.preprocessing import discretization
from tensorflow.python.keras.layers.preprocessing import hashing
from tensorflow.python.keras.layers.preprocessing import image_preprocessing
@@ -71,12 +70,11 @@
pooling, image_preprocessing, preprocessing_integer_lookup_v1,
preprocessing_normalization_v1, preprocessing_string_lookup_v1,
preprocessing_text_vectorization_v1, recurrent, wrappers,
- hashing, category_crossing, category_encoding_v1, discretization,
+ hashing, category_crossing, category_encoding, discretization,
multi_head_attention)
ALL_V2_MODULES = (rnn_cell_wrapper_v2, normalization_v2, recurrent_v2,
preprocessing_integer_lookup, preprocessing_normalization,
- preprocessing_string_lookup, preprocessing_text_vectorization,
- category_encoding)
+ preprocessing_string_lookup, preprocessing_text_vectorization)
# ALL_OBJECTS is meant to be a global mutable. Hence we need to make it
# thread-local to avoid concurrent mutations.
LOCAL = threading.local()
diff --git a/tensorflow/python/keras/legacy_tf_layers/normalization_test.py b/tensorflow/python/keras/legacy_tf_layers/normalization_test.py
index 6b8d4ca..0386e1e 100644
--- a/tensorflow/python/keras/legacy_tf_layers/normalization_test.py
+++ b/tensorflow/python/keras/legacy_tf_layers/normalization_test.py
@@ -407,7 +407,7 @@
training = array_ops.placeholder(dtype='bool')
outputs = bn.apply(inputs, training=training)
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
# Test training with placeholder learning phase.
self.evaluate(variables.global_variables_initializer())
np_gamma, np_beta = self.evaluate([bn.gamma, bn.beta])
@@ -898,7 +898,7 @@
moving_stddev = 1.
renorm_mean = 0.
renorm_stddev = 1.
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
self.evaluate(variables.global_variables_initializer())
for _ in range(5):
x = np.random.random(shape)
@@ -948,7 +948,7 @@
moving_stddev = 1.
renorm_mean = 0.
renorm_stddev = 1.
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
self.evaluate(variables.global_variables_initializer())
for step in range(6):
x = np.random.random(shape)
@@ -1002,7 +1002,7 @@
moving_mean = 0.
moving_variance = 1.
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
self.evaluate(variables.global_variables_initializer())
for _ in range(5):
x = np.random.random(shape)
@@ -1055,7 +1055,7 @@
moving_stddev = 1.
renorm_mean = 0.
renorm_stddev = 1.
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
self.evaluate(variables.global_variables_initializer())
for _ in range(5):
x = np.random.random(shape)
@@ -1101,7 +1101,7 @@
self.assertListEqual(
out1.shape.as_list(), out2.shape.as_list())
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
self.evaluate(variables.global_variables_initializer())
x = np.random.random(shape)
@@ -1123,7 +1123,7 @@
out = normalization_layers.batch_normalization(
inp, virtual_batch_size=2)
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
self.evaluate(variables.global_variables_initializer())
x = np.random.random(np_shape)
@@ -1154,7 +1154,7 @@
shape[0] // virtual_batch_size,
shape[1]])
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
self.evaluate(variables.global_variables_initializer())
for _ in range(5):
x = np.random.random(shape)
@@ -1207,7 +1207,7 @@
ghost_shape = ([virtual_batch_size, shape[0] // virtual_batch_size] +
shape[1:])
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
self.evaluate(variables.global_variables_initializer())
for _ in range(5):
x = np.random.random(shape)
@@ -1261,7 +1261,7 @@
ghost_shape = ([virtual_batch_size, shape[0] // virtual_batch_size] +
shape[1:])
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
self.evaluate(variables.global_variables_initializer())
for _ in range(5):
x = np.random.random(shape)
@@ -1413,7 +1413,7 @@
ghost_shape = ([virtual_batch_size, shape[0] // virtual_batch_size] +
shape[1:])
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
self.evaluate(variables.global_variables_initializer())
for _ in range(5):
x = np.random.random(shape)
diff --git a/tensorflow/python/keras/losses.py b/tensorflow/python/keras/losses.py
index 1ad806d..137933d 100644
--- a/tensorflow/python/keras/losses.py
+++ b/tensorflow/python/keras/losses.py
@@ -1728,10 +1728,12 @@
return K.mean(
K.binary_crossentropy(y_true, y_pred, from_logits=from_logits), axis=-1)
-@dispatch.dispatch_for_types(
- binary_crossentropy, ragged_tensor.RaggedTensor)
-def _ragged_tensor_binary_crossentropy(y_true, y_pred,
- from_logits=False, label_smoothing=0):
+
+@dispatch.dispatch_for_types(binary_crossentropy, ragged_tensor.RaggedTensor)
+def _ragged_tensor_binary_crossentropy(y_true,
+ y_pred,
+ from_logits=False,
+ label_smoothing=0):
""" Implements support for handling RaggedTensors.
Expected shape: (batch, sequence_len) with sequence_len being variable
@@ -1742,8 +1744,10 @@
(SUM_OVER_BATCH_SIZE), the reduction averages the per batch losses over
the number of batches.
"""
- fn = functools.partial(binary_crossentropy, from_logits=from_logits,
- label_smoothing=label_smoothing)
+ fn = functools.partial(
+ binary_crossentropy,
+ from_logits=from_logits,
+ label_smoothing=label_smoothing)
return _ragged_tensor_apply_loss(fn, y_true, y_pred)
diff --git a/tensorflow/python/keras/losses_test.py b/tensorflow/python/keras/losses_test.py
index 0933673..44c890f 100644
--- a/tensorflow/python/keras/losses_test.py
+++ b/tensorflow/python/keras/losses_test.py
@@ -897,8 +897,7 @@
def test_ragged_tensors(self):
bce_obj = losses.BinaryCrossentropy()
y_true = ragged_factory_ops.constant([[1, 0, 1], [0]])
- y_pred = ragged_factory_ops.constant([[1, 1, 1], [0]],
- dtype=dtypes.float32)
+ y_pred = ragged_factory_ops.constant([[1, 1, 1], [0]], dtype=dtypes.float32)
sample_weight = constant_op.constant([1.2, 3.4], shape=(2, 1))
loss = bce_obj(y_true, y_pred, sample_weight=sample_weight)
@@ -910,8 +909,8 @@
# Test with logits.
y_true = ragged_factory_ops.constant([[1, 0, 1], [0, 1]])
- logits = ragged_factory_ops.constant(
- [[100.0, -100.0, 100.0], [100.0, 100.0]])
+ logits = ragged_factory_ops.constant([[100.0, -100.0, 100.0],
+ [100.0, 100.0]])
weights = constant_op.constant([4, 3])
bce_obj = losses.BinaryCrossentropy(from_logits=True)
loss = bce_obj(y_true, logits, sample_weight=weights)
diff --git a/tensorflow/python/keras/mixed_precision/BUILD b/tensorflow/python/keras/mixed_precision/BUILD
index 6f2bc97..baf126e 100644
--- a/tensorflow/python/keras/mixed_precision/BUILD
+++ b/tensorflow/python/keras/mixed_precision/BUILD
@@ -289,7 +289,10 @@
srcs = ["layer_correctness_test.py"],
python_version = "PY3",
shard_count = 10,
- tags = ["no_rocm"],
+ tags = [
+ "no_rocm",
+ "no_tfrt", # TODO(b/179863362)
+ ],
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python/compat:v2_compat",
diff --git a/tensorflow/python/keras/optimizer_v2/adam_test.py b/tensorflow/python/keras/optimizer_v2/adam_test.py
index 9cf5817..85958bb 100644
--- a/tensorflow/python/keras/optimizer_v2/adam_test.py
+++ b/tensorflow/python/keras/optimizer_v2/adam_test.py
@@ -113,7 +113,7 @@
def testSparse(self):
# TODO(tanzheny, omalleyt): Fix test in eager mode.
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with ops.Graph().as_default(), self.cached_session(use_gpu=True):
+ with ops.Graph().as_default(), self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -203,7 +203,7 @@
def doTestBasic(self, use_callable_params=False):
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -261,7 +261,7 @@
@combinations.generate(combinations.combine(mode=["graph", "eager"]))
def testBasicWithAmsgrad(self):
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, v0hat, m1, v1, v1hat = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -353,7 +353,7 @@
def testBasicWithLearningRateDecay(self):
# TODO(tanzheny, omalleyt): Fix test in eager mode.
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
- with ops.Graph().as_default(), self.cached_session(use_gpu=True):
+ with ops.Graph().as_default(), self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -398,7 +398,7 @@
def testBasicWithLearningRateInverseTimeDecay(self):
# TODO(tanzheny, omalleyt): Fix test in eager mode.
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
- with ops.Graph().as_default(), self.cached_session(use_gpu=True):
+ with ops.Graph().as_default(), self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -445,7 +445,7 @@
def testTensorLearningRate(self):
# TODO(tanzheny, omalleyt): Fix test in eager mode.
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with ops.Graph().as_default(), self.cached_session(use_gpu=True):
+ with ops.Graph().as_default(), self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -484,7 +484,7 @@
def testSharing(self):
# TODO(tanzheny, omalleyt): Fix test in eager mode.
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with ops.Graph().as_default(), self.cached_session(use_gpu=True):
+ with ops.Graph().as_default(), self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -565,7 +565,7 @@
def testSparse(self):
# TODO(tanzheny, omalleyt): Fix test in eager mode.
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with ops.Graph().as_default(), self.cached_session(use_gpu=True):
+ with ops.Graph().as_default(), self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -655,7 +655,7 @@
def doTestBasic(self, use_callable_params=False):
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -715,7 +715,7 @@
@combinations.generate(combinations.combine(mode=["graph", "eager"]))
def testBasicWithAmsgrad(self):
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, v0hat, m1, v1, v1hat = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -809,7 +809,7 @@
def testBasicWithLearningRateDecay(self):
# TODO(tanzheny, omalleyt): Fix test in eager mode.
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
- with ops.Graph().as_default(), self.cached_session(use_gpu=True):
+ with ops.Graph().as_default(), self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -854,7 +854,7 @@
def testBasicWithLearningRateInverseTimeDecay(self):
# TODO(tanzheny, omalleyt): Fix test in eager mode.
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
- with ops.Graph().as_default(), self.cached_session(use_gpu=True):
+ with ops.Graph().as_default(), self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -901,7 +901,7 @@
def testTensorLearningRate(self):
# TODO(tanzheny, omalleyt): Fix test in eager mode.
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with ops.Graph().as_default(), self.cached_session(use_gpu=True):
+ with ops.Graph().as_default(), self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -940,7 +940,7 @@
def testSharing(self):
# TODO(tanzheny, omalleyt): Fix test in eager mode.
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with ops.Graph().as_default(), self.cached_session(use_gpu=True):
+ with ops.Graph().as_default(), self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
diff --git a/tensorflow/python/keras/optimizer_v2/adamax_test.py b/tensorflow/python/keras/optimizer_v2/adamax_test.py
index f955df8..9a73fad 100644
--- a/tensorflow/python/keras/optimizer_v2/adamax_test.py
+++ b/tensorflow/python/keras/optimizer_v2/adamax_test.py
@@ -81,7 +81,7 @@
def testResourceSparse(self):
# TODO(tanzheny, omalleyt): Fix test in eager mode.
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with ops.Graph().as_default(), self.cached_session(use_gpu=True):
+ with ops.Graph().as_default(), self.cached_session():
# Initialize variables for numpy implementation.
zero_slots = lambda: np.zeros((3), dtype=dtype.as_numpy_dtype) # pylint: disable=cell-var-from-loop
m0, v0, m1, v1 = zero_slots(), zero_slots(), zero_slots(), zero_slots()
@@ -275,7 +275,7 @@
def testTensorLearningRate(self):
# TODO(tanzheny, omalleyt): Fix test in eager mode.
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with ops.Graph().as_default(), self.cached_session(use_gpu=True):
+ with ops.Graph().as_default(), self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -312,7 +312,7 @@
def testSharing(self):
# TODO(tanzheny, omalleyt): Fix test in eager mode.
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with ops.Graph().as_default(), self.cached_session(use_gpu=True):
+ with ops.Graph().as_default(), self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
diff --git a/tensorflow/python/keras/optimizer_v2/ftrl_test.py b/tensorflow/python/keras/optimizer_v2/ftrl_test.py
index 6627fc0..9e17462 100644
--- a/tensorflow/python/keras/optimizer_v2/ftrl_test.py
+++ b/tensorflow/python/keras/optimizer_v2/ftrl_test.py
@@ -37,7 +37,7 @@
def doTestFtrlwithoutRegularization(self, use_resource=False):
# TODO(tanzheny, omalleyt): Fix test in eager mode.
for dtype in [dtypes.float32]:
- with ops.Graph().as_default(), self.cached_session(use_gpu=True):
+ with ops.Graph().as_default(), self.cached_session():
if use_resource:
var0 = variables.Variable([0.0, 0.0], dtype=dtype)
var1 = variables.Variable([0.0, 0.0], dtype=dtype)
@@ -77,7 +77,7 @@
def testFtrlwithoutRegularization2(self):
# TODO(tanzheny, omalleyt): Fix test in eager mode.
for dtype in [dtypes.half, dtypes.float32]:
- with ops.Graph().as_default(), self.cached_session(use_gpu=True):
+ with ops.Graph().as_default(), self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([4.0, 3.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
@@ -107,7 +107,7 @@
def testMinimizeSparseResourceVariable(self):
# TODO(tanzheny, omalleyt): Fix test in eager mode.
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with ops.Graph().as_default(), self.cached_session(use_gpu=True):
+ with ops.Graph().as_default(), self.cached_session():
var0 = variables.Variable([[1.0, 2.0]], dtype=dtype)
x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
@@ -129,7 +129,7 @@
def testFtrlWithL1(self):
# TODO(tanzheny, omalleyt): Fix test in eager mode.
for dtype in [dtypes.half, dtypes.float32]:
- with ops.Graph().as_default(), self.cached_session(use_gpu=True):
+ with ops.Graph().as_default(), self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([4.0, 3.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
@@ -159,7 +159,7 @@
def testFtrlWithBeta(self):
# TODO(tanzheny, omalleyt): Fix test in eager mode.
for dtype in [dtypes.half, dtypes.float32]:
- with ops.Graph().as_default(), self.cached_session(use_gpu=True):
+ with ops.Graph().as_default(), self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([4.0, 3.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
@@ -185,7 +185,7 @@
def testFtrlWithL2_Beta(self):
# TODO(tanzheny, omalleyt): Fix test in eager mode.
for dtype in [dtypes.half, dtypes.float32]:
- with ops.Graph().as_default(), self.cached_session(use_gpu=True):
+ with ops.Graph().as_default(), self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([4.0, 3.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
@@ -216,7 +216,7 @@
def testFtrlWithL1_L2(self):
# TODO(tanzheny, omalleyt): Fix test in eager mode.
for dtype in [dtypes.half, dtypes.float32]:
- with ops.Graph().as_default(), self.cached_session(use_gpu=True):
+ with ops.Graph().as_default(), self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([4.0, 3.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
@@ -253,7 +253,7 @@
"""
# TODO(tanzheny, omalleyt): Fix test in eager mode.
for dtype in [dtypes.half, dtypes.float32]:
- with ops.Graph().as_default(), self.cached_session(use_gpu=True):
+ with ops.Graph().as_default(), self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([4.0, 3.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
@@ -286,7 +286,7 @@
"""Tests the new FTRL op with support for l2 shrinkage on sparse grads."""
# TODO(tanzheny, omalleyt): Fix test in eager mode.
for dtype in [dtypes.half, dtypes.float32]:
- with ops.Graph().as_default(), self.cached_session(use_gpu=True):
+ with ops.Graph().as_default(), self.cached_session():
var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
var1 = variables.Variable([[4.0], [3.0]], dtype=dtype)
grads0 = ops.IndexedSlices(
@@ -321,7 +321,7 @@
"""Verifies that l2 shrinkage in FTRL does not change lr schedule."""
# TODO(tanzheny, omalleyt): Fix test in eager mode.
for dtype in [dtypes.half, dtypes.float32]:
- with ops.Graph().as_default(), self.cached_session(use_gpu=True) as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([1.0, 2.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
@@ -404,7 +404,7 @@
def testEquivAdagradwithoutRegularization(self):
# TODO(tanzheny, omalleyt): Fix test in eager mode.
for dtype in [dtypes.half, dtypes.float32]:
- with ops.Graph().as_default(), self.cached_session(use_gpu=True):
+ with ops.Graph().as_default(), self.cached_session():
val0, val1 = self.applyOptimizer(
ftrl.Ftrl(
3.0,
@@ -415,7 +415,7 @@
l2_regularization_strength=0.0),
dtype)
- with ops.Graph().as_default(), self.cached_session(use_gpu=True):
+ with ops.Graph().as_default(), self.cached_session():
val2, val3 = self.applyOptimizer(
adagrad.AdagradOptimizer(3.0, initial_accumulator_value=0.1), dtype)
@@ -449,7 +449,7 @@
def testEquivSparseGradientDescentwithoutRegularization(self):
# TODO(tanzheny, omalleyt): Fix test in eager mode.
for dtype in [dtypes.half, dtypes.float32]:
- with ops.Graph().as_default(), self.cached_session(use_gpu=True):
+ with ops.Graph().as_default(), self.cached_session():
val0, val1 = self.applyOptimizer(
ftrl.Ftrl(
3.0,
@@ -461,7 +461,7 @@
dtype,
is_sparse=True)
- with ops.Graph().as_default(), self.cached_session(use_gpu=True):
+ with ops.Graph().as_default(), self.cached_session():
val2, val3 = self.applyOptimizer(
gradient_descent.GradientDescentOptimizer(3.0),
dtype,
@@ -473,7 +473,7 @@
def testEquivGradientDescentwithoutRegularization(self):
# TODO(tanzheny, omalleyt): Fix test in eager mode.
for dtype in [dtypes.half, dtypes.float32]:
- with ops.Graph().as_default(), self.cached_session(use_gpu=True):
+ with ops.Graph().as_default(), self.cached_session():
val0, val1 = self.applyOptimizer(
ftrl.Ftrl(
3.0,
@@ -484,7 +484,7 @@
l2_regularization_strength=0.0),
dtype)
- with ops.Graph().as_default(), self.cached_session(use_gpu=True):
+ with ops.Graph().as_default(), self.cached_session():
val2, val3 = self.applyOptimizer(
gradient_descent.GradientDescentOptimizer(3.0), dtype)
diff --git a/tensorflow/python/keras/utils/BUILD b/tensorflow/python/keras/utils/BUILD
index cd25e47..73a25b3 100644
--- a/tensorflow/python/keras/utils/BUILD
+++ b/tensorflow/python/keras/utils/BUILD
@@ -9,6 +9,7 @@
default_visibility = [
"//tensorflow/python/feature_column:__pkg__",
"//tensorflow/python/keras:__subpackages__",
+ "//tensorflow/tools/pip_package:__pkg__",
],
licenses = ["notice"], # Apache 2.0
)
@@ -40,7 +41,6 @@
":control_flow_util",
":engine_utils",
":generic_utils",
- ":kpl_test_utils",
":layer_utils",
":multi_gpu_utils",
":np_utils",
@@ -59,7 +59,10 @@
name = "kpl_test_utils",
srcs = ["kpl_test_utils.py"],
srcs_version = "PY3",
- deps = [],
+ deps = [
+ "//tensorflow/python/keras",
+ "//tensorflow/python/keras/layers/preprocessing:string_lookup",
+ ],
)
py_library(
@@ -262,7 +265,6 @@
"dataset_creator.py",
],
srcs_version = "PY3",
- visibility = ["//tensorflow/tools/pip_package:__pkg__"],
deps = [
"//tensorflow/python:util",
],
@@ -272,6 +274,9 @@
name = "dataset_creator_test",
srcs = ["dataset_creator_test.py"],
python_version = "PY3",
+ tags = [
+ "no_tfrt", # TODO(b/180537361): Reenable TFRT after the issue is resolved.
+ ],
deps = [
":dataset_creator",
"//tensorflow/python/distribute:multi_worker_test_base",
diff --git a/tensorflow/python/keras/utils/dataset_creator_test.py b/tensorflow/python/keras/utils/dataset_creator_test.py
index fa7df62..ebb544e 100644
--- a/tensorflow/python/keras/utils/dataset_creator_test.py
+++ b/tensorflow/python/keras/utils/dataset_creator_test.py
@@ -50,8 +50,6 @@
next(iter(dataset_ops.DatasetV2.from_tensor_slices([1, 1]))))
def test_dataset_creator_usage_in_parameter_server_model_fit(self):
- self.skipTest("TODO(rchao): Enable this test once training API changes for "
- "DatasetFactory is submitted.")
cluster_def = multi_worker_test_base.create_in_process_cluster(
num_workers=2, num_ps=1, rpc_layer="grpc")
cluster_def["chief"] = [
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 2f77828..a828959 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -2915,7 +2915,10 @@
name = "atrous_convolution_test",
size = "medium",
srcs = ["atrous_convolution_test.py"],
- tags = ["manual"],
+ tags = [
+ "manual",
+ "no_cuda_asan",
+ ],
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
@@ -3337,7 +3340,6 @@
shard_count = 30,
tags = [
"no_cuda11",
- "no_cuda_asan", # b/179030928
],
deps = [
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/python/kernel_tests/aggregate_ops_test.py b/tensorflow/python/kernel_tests/aggregate_ops_test.py
index adb4f3a..9221f83 100644
--- a/tensorflow/python/kernel_tests/aggregate_ops_test.py
+++ b/tensorflow/python/kernel_tests/aggregate_ops_test.py
@@ -58,7 +58,7 @@
def testAddN(self):
np.random.seed(12345)
- with self.session(use_gpu=True) as sess:
+ with self.session():
for dtype in self._supported_types():
for count in range(1, self._MAX_N + 1):
data = [self._buildData((2, 2), dtype) for _ in range(count)]
@@ -71,7 +71,7 @@
@test_util.run_deprecated_v1
def testUnknownShapes(self):
np.random.seed(12345)
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
for dtype in self._supported_types():
data = self._buildData((2, 2), dtype)
for count in range(1, self._MAX_N + 1):
diff --git a/tensorflow/python/kernel_tests/argmax_op_test.py b/tensorflow/python/kernel_tests/argmax_op_test.py
index 8a6ac74..2b4431a 100644
--- a/tensorflow/python/kernel_tests/argmax_op_test.py
+++ b/tensorflow/python/kernel_tests/argmax_op_test.py
@@ -97,7 +97,7 @@
def testFloatInt32Output(self):
x = np.asarray(100 * np.random.randn(200), dtype=np.float32)
expected_values = x.argmax()
- with self.session(use_gpu=True):
+ with self.session():
ans = math_ops.argmax(x, axis=0, output_type=dtypes.int32)
tf_ans = self.evaluate(ans)
self.assertEqual(np.int32, tf_ans.dtype)
@@ -105,7 +105,7 @@
# the values don't have a range that exceeds 32-bit integers.
self.assertAllEqual(tf_ans, expected_values)
expected_values = x.argmin()
- with self.session(use_gpu=True):
+ with self.session():
ans = math_ops.argmin(x, axis=0, output_type=dtypes.int32)
tf_ans = self.evaluate(ans)
self.assertEqual(np.int32, tf_ans.dtype)
diff --git a/tensorflow/python/kernel_tests/array_ops/batch_gather_op_test.py b/tensorflow/python/kernel_tests/array_ops/batch_gather_op_test.py
index e41053b..16ac476 100644
--- a/tensorflow/python/kernel_tests/array_ops/batch_gather_op_test.py
+++ b/tensorflow/python/kernel_tests/array_ops/batch_gather_op_test.py
@@ -46,7 +46,7 @@
def testSimpleGather(self, indices_dtype):
data = np.array([0, 1, 2, 3, 7, 5, 8, 9, 10, 11, 15, 13])
indices = [3, 4]
- with self.session(use_gpu=True):
+ with self.session():
for dtype in _TEST_TYPES:
params_np = self._buildParams(data, dtype)
params = constant_op.constant(params_np)
@@ -62,7 +62,7 @@
def test2DArray(self, indices_dtype):
data = np.array([[0, 1, 2, 3, 7, 5], [8, 9, 10, 11, 15, 13]])
indices = [[3], [4]]
- with self.session(use_gpu=True):
+ with self.session():
for dtype in _TEST_TYPES:
params_np = self._buildParams(data, dtype)
params = constant_op.constant(params_np)
@@ -77,7 +77,7 @@
def testHigherRank(self):
data = np.array([[[0, 1, 2], [3, 7, 5]], [[8, 9, 10], [11, 15, 13]]])
indices = [[[2, 0], [1, 2]], [[2, 0], [0, 1]]]
- with self.session(use_gpu=True):
+ with self.session():
for dtype in _TEST_TYPES:
params_np = self._buildParams(data, dtype)
params = constant_op.constant(params_np)
@@ -113,7 +113,7 @@
self.evaluate(array_ops.batch_gather(params, [7]))
def testEmptySlices(self):
- with self.session(use_gpu=True):
+ with self.session():
for dtype in _TEST_TYPES:
for itype in np.int32, np.int64:
params = np.zeros((7, 0, 0), dtype=dtype.as_numpy_dtype)
diff --git a/tensorflow/python/kernel_tests/array_ops/gather_op_test.py b/tensorflow/python/kernel_tests/array_ops/gather_op_test.py
index f0c762e..f8050b7 100644
--- a/tensorflow/python/kernel_tests/array_ops/gather_op_test.py
+++ b/tensorflow/python/kernel_tests/array_ops/gather_op_test.py
@@ -59,7 +59,7 @@
return data
def testScalar1D(self):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
data = np.array([0, 1, 2, 3, 7, 5])
for dtype in _TEST_TYPES:
for indices in 4, [1, 2, 2, 4, 5]:
@@ -74,7 +74,7 @@
self.assertEqual(np_val.shape, gather_t.get_shape())
def testScalar2D(self):
- with self.session(use_gpu=True):
+ with self.session():
data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8],
[9, 10, 11], [12, 13, 14]])
for dtype in _TEST_TYPES:
@@ -90,7 +90,7 @@
self.assertEqual(expected_shape, gather_t.get_shape())
def testSimpleTwoD32(self):
- with self.session(use_gpu=True):
+ with self.session():
data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8],
[9, 10, 11], [12, 13, 14]])
for dtype in _TEST_TYPES:
@@ -304,7 +304,7 @@
# On GPU the bad indices do not raise error but fetch 0 values
if not test.is_gpu_available():
return
- with self.session(use_gpu=True):
+ with self.session():
params = [[0, 1, 2], [3, 4, 5]]
with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 2\)"):
array_ops.gather(params, [[7]], axis=0).eval()
diff --git a/tensorflow/python/kernel_tests/array_ops/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/array_ops/scatter_nd_ops_test.py
index b3c566b..e19db9b 100644
--- a/tensorflow/python/kernel_tests/array_ops/scatter_nd_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops/scatter_nd_ops_test.py
@@ -211,7 +211,7 @@
scatter = state_ops.scatter_nd_update(ref, indices, updates)
init = variables.global_variables_initializer()
- with self.session(use_gpu=True) as sess:
+ with self.session():
self.evaluate(init)
result = self.evaluate(scatter)
self.assertAllClose(result, expected)
@@ -225,7 +225,7 @@
scatter = state_ops.scatter_nd_update(ref, indices, updates)
init = variables.global_variables_initializer()
- with self.session(use_gpu=True) as sess:
+ with self.session():
self.evaluate(init)
result = self.evaluate(scatter)
self.assertAllClose(result, expected)
diff --git a/tensorflow/python/kernel_tests/array_ops/slice_op_test.py b/tensorflow/python/kernel_tests/array_ops/slice_op_test.py
index d8097ad..55cb164 100644
--- a/tensorflow/python/kernel_tests/array_ops/slice_op_test.py
+++ b/tensorflow/python/kernel_tests/array_ops/slice_op_test.py
@@ -40,7 +40,7 @@
def testEmpty(self):
inp = np.random.rand(4, 4).astype("f")
for k in xrange(4):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
a = constant_op.constant(inp, shape=[4, 4], dtype=dtypes.float32)
slice_t = a[2, k:k]
slice_val = self.evaluate(slice_t)
@@ -49,7 +49,7 @@
def testInt32(self):
inp = np.random.rand(4, 4).astype("i")
for k in xrange(4):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
a = constant_op.constant(inp, shape=[4, 4], dtype=dtypes.int32)
slice_t = a[2, k:k]
slice_val = self.evaluate(slice_t)
@@ -119,7 +119,7 @@
def testSelectAll(self):
for _ in range(10):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
inp = np.random.rand(4, 4, 4, 4).astype("f")
a = constant_op.constant(inp, shape=[4, 4, 4, 4], dtype=dtypes.float32)
@@ -133,7 +133,7 @@
def testSingleDimension(self):
for _ in range(10):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
inp = np.random.rand(10).astype("f")
a = constant_op.constant(inp, shape=[10], dtype=dtypes.float32)
@@ -229,7 +229,7 @@
def testSingleElementAll(self):
for _ in range(10):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
inp = np.random.rand(4, 4).astype("f")
a = constant_op.constant(inp, shape=[4, 4], dtype=dtypes.float32)
@@ -312,7 +312,7 @@
self.assertAllEqual(m2.get_shape().as_list(), [1, 2, 3])
def _testGradientSlice(self, input_shape, slice_begin, slice_size):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
num_inputs = np.prod(input_shape)
num_grads = np.prod(slice_size)
inp = np.random.rand(num_inputs).astype("f").reshape(input_shape)
@@ -362,7 +362,7 @@
self.assertAllClose(np_ans, result)
def _testGradientVariableSize(self):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
inp = constant_op.constant([1.0, 2.0, 3.0], name="in")
out = array_ops.slice(inp, [1], [-1])
grad_actual = self.evaluate(gradients_impl.gradients(out, inp)[0])
@@ -380,7 +380,7 @@
# Regression test for bug in slice. A low-level bug in Eigen was causing
# incorrect results for negative indices in multi-dimensional tensors.
# See b/114318298.
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x = constant_op.constant([[1., 2., 3.], [4., 5., 6.], [7., 8., 7]])
loss1 = math_ops.reduce_sum(x[:-1, :-1] * 1.0)
loss2 = math_ops.reduce_sum(x[:-1][:, :-1])
@@ -477,7 +477,7 @@
self.assertEqual([None, 2], c.get_shape().as_list())
def testSliceOfSlice(self):
- with self.session(use_gpu=True):
+ with self.session():
a = constant_op.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
b = a[1:, :]
c = b[:-1, :]
diff --git a/tensorflow/python/kernel_tests/array_ops/stack_op_test.py b/tensorflow/python/kernel_tests/array_ops/stack_op_test.py
index ab1dd1d..f0e7db4 100644
--- a/tensorflow/python/kernel_tests/array_ops/stack_op_test.py
+++ b/tensorflow/python/kernel_tests/array_ops/stack_op_test.py
@@ -52,7 +52,7 @@
@test_util.run_deprecated_v1
def testSimple(self):
np.random.seed(7)
- with self.session(use_gpu=True):
+ with self.session():
for shape in (2,), (3,), (2, 3), (3, 2), (8, 2, 10):
rank = len(shape)
for axis in range(-rank, rank):
@@ -90,7 +90,7 @@
@test_util.run_deprecated_v1
def testConst(self):
np.random.seed(7)
- with self.session(use_gpu=True):
+ with self.session():
# Verify that shape induction works with shapes produced via const stack
a = constant_op.constant([1, 2, 3, 4, 5, 6])
b = array_ops.reshape(a, array_ops.stack([2, 3]))
@@ -155,7 +155,7 @@
data = np.random.randn(*shape)
shapes = [shape[1:]] * shape[0]
with self.subTest(shape=shape):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
# TODO(irving): Remove list() once we handle maps correctly
xs = list(map(constant_op.constant, data))
c = array_ops.stack(xs)
@@ -171,7 +171,7 @@
out_shape = list(shape[1:])
out_shape.insert(1, shape[0])
with self.subTest(shape=shape):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
# TODO(irving): Remove list() once we handle maps correctly
xs = list(map(constant_op.constant, data))
c = array_ops.stack(xs, axis=1)
@@ -241,7 +241,7 @@
for axis in range(-rank, rank):
test_arrays = np_split_squeeze(expected, axis)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
with self.subTest(shape=shape, dtype=dtype, axis=axis):
actual_pack = array_ops.stack(test_arrays, axis=axis)
self.assertEqual(expected.shape, actual_pack.get_shape())
@@ -265,7 +265,7 @@
def testComplex(self):
np.random.seed(7)
- with self.session(use_gpu=True):
+ with self.session():
for shape in (2,), (3,), (2, 3), (3, 2), (8, 2, 10):
for dtype in [np.complex64, np.complex128]:
with self.subTest(shape=shape, dtype=dtype):
@@ -279,7 +279,7 @@
@test_util.run_deprecated_v1
def testSimple(self):
- with self.session(use_gpu=True):
+ with self.session():
self.assertAllEqual(
[1, 0, 2],
ops.convert_to_tensor([1, constant_op.constant(0), 2]).eval())
@@ -299,7 +299,7 @@
]).eval())
def testWithNDArray(self):
- with self.session(use_gpu=True):
+ with self.session():
result = ops.convert_to_tensor([[[0., 0.],
constant_op.constant([1., 1.])],
np.array(
@@ -310,7 +310,7 @@
@test_util.run_deprecated_v1
def testVariable(self):
- with self.session(use_gpu=True):
+ with self.session():
v = variables.Variable(17)
result = ops.convert_to_tensor([[0, 0, 0], [0, v, 0], [0, 0, 0]])
self.evaluate(v.initializer)
@@ -364,7 +364,7 @@
@test_util.run_deprecated_v1
def testPlaceholder(self):
- with self.session(use_gpu=True):
+ with self.session():
# Test using placeholder with a defined shape.
ph_0 = array_ops.placeholder(dtypes.int32, shape=[])
result_0 = ops.convert_to_tensor([[0, 0, 0], [0, ph_0, 0], [0, 0, 0]])
@@ -391,7 +391,7 @@
# Dynamic shape error.
ph_1 = array_ops.placeholder(dtypes.int32)
result_1 = ops.convert_to_tensor([[0, 0, 0], [0, ph_1, 0], [0, 0, 0]])
- with self.session(use_gpu=True):
+ with self.session():
with self.assertRaises(errors_impl.InvalidArgumentError):
result_1.eval(feed_dict={ph_1: [1]})
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index 006737f..e4219f1 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -474,7 +474,7 @@
@test_util.run_deprecated_v1
def testReverseRowsOf3Channels(self):
"""Tests optimized code for reversing rows with last dim size = 3."""
- with self.session(use_gpu=True):
+ with self.session():
for reverse_f in [array_ops.reverse_v2, array_ops.reverse]:
for outer_size in (1, 2):
for middle_size in list(range(50)) + [100000]:
@@ -491,7 +491,7 @@
@test_util.run_deprecated_v1
def testReverseRowsOf4Channels(self):
- with self.session(use_gpu=True):
+ with self.session():
for reverse_f in [array_ops.reverse_v2, array_ops.reverse]:
for outer_size in (1, 2):
for middle_size in list(range(50)) + [100000]:
@@ -508,7 +508,7 @@
@test_util.run_deprecated_v1
def testReverseColumnsOf3Channels(self):
- with self.session(use_gpu=True):
+ with self.session():
for reverse_f in [array_ops.reverse_v2, array_ops.reverse]:
for outer_size in list(range(50)) + [100000]:
for middle_size in (1, 2):
@@ -641,7 +641,7 @@
def test_basic_slice(self):
for tensor_type in STRIDED_SLICE_TYPES:
with self.subTest(tensor_type=tensor_type):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
checker = StridedSliceChecker(
self, StridedSliceChecker.REF_TENSOR, tensor_type=tensor_type)
_ = checker[:, :, :]
@@ -696,7 +696,7 @@
@test_util.run_deprecated_v1
def testDegenerateSlices(self):
- with self.session(use_gpu=True):
+ with self.session():
checker = StridedSliceChecker(self, StridedSliceChecker.REF_TENSOR)
# degenerate by offering a forward interval with a negative stride
_ = checker[0:-1:-1, :, :]
@@ -717,7 +717,7 @@
@test_util.run_deprecated_v1
def testEllipsis(self):
- with self.session(use_gpu=True):
+ with self.session():
raw = [[[[[1, 2], [3, 4], [5, 6]]], [[[7, 8], [9, 10], [11, 12]]]]]
checker = StridedSliceChecker(self, raw)
@@ -738,7 +738,7 @@
@test_util.run_deprecated_v1
def testShrink(self):
- with self.session(use_gpu=True):
+ with self.session():
raw = [[[[[1, 2, 4, 5], [5, 6, 7, 8], [9, 10, 11, 12]]],
[[[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]]]]
checker = StridedSliceChecker(self, raw)
@@ -749,7 +749,7 @@
@test_util.run_deprecated_v1
def testBothNewAxisAndShrink(self):
- with self.session(use_gpu=True):
+ with self.session():
ones = array_ops.placeholder(shape=[2, 2], dtype=dtypes.int16)
self.assertAllEqual(
ones[array_ops.newaxis, :,
@@ -757,7 +757,7 @@
@test_util.run_deprecated_v1
def testTensorIndexing(self):
- with self.session(use_gpu=True):
+ with self.session():
raw = [[[[[1, 2, 4, 5], [5, 6, 7, 8], [9, 10, 11, 12]]],
[[[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]]]]
checker = StridedSliceChecker(self, raw, check_type_infer=False)
@@ -769,7 +769,7 @@
_ = checker[..., 2**64 // 2**63] # Test longs in Python 2
def testTensorIndexingTypeError(self):
- with self.session(use_gpu=True):
+ with self.session():
checker = StridedSliceChecker(self, StridedSliceChecker.REF_TENSOR)
expected = re.escape(array_ops._SLICE_TYPE_ERROR)
with self.assertRaisesRegex(TypeError, expected):
@@ -787,7 +787,7 @@
@test_util.run_deprecated_v1
def testExpand(self):
- with self.session(use_gpu=True):
+ with self.session():
raw = [[[[[1, 2, 4, 5], [5, 6, 7, 8], [9, 10, 11, 12]]],
[[[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]]]]
checker = StridedSliceChecker(self, raw)
@@ -805,7 +805,7 @@
@test_util.run_deprecated_v1
def testExpandVariable(self):
- with self.session(use_gpu=True):
+ with self.session():
x = variables.Variable(7, dtype=dtypes.int32)
self.evaluate(x.initializer)
y = x[None].eval()
@@ -814,7 +814,7 @@
@test_util.run_deprecated_v1
def testOptimizedCases(self):
- with self.session(use_gpu=True):
+ with self.session():
checker = StridedSliceChecker(self,
StridedSliceChecker.REF_TENSOR_ALIGNED)
# Identity
@@ -830,7 +830,7 @@
@test_util.run_v1_only("currently failing on v2")
def testMasks(self):
- with self.session(use_gpu=True):
+ with self.session():
scalar = np.array(0)
# Test tensor type mask
checker = StridedSliceChecker(self, StridedSliceChecker.REF_TENSOR)
@@ -870,7 +870,7 @@
@test_util.run_deprecated_v1
def testUnknown(self):
- with self.session(use_gpu=True):
+ with self.session():
uncertain_tensor = array_ops.placeholder(dtypes.float32)
a = StridedSliceShapeChecker(uncertain_tensor)
a_slice_shape = a[...]
@@ -882,7 +882,7 @@
@test_util.run_deprecated_v1
def testTensorShapeUncertain(self):
- with self.session(use_gpu=True):
+ with self.session():
uncertain_tensor = array_ops.placeholder(
dtypes.float32, shape=(5, None, 7))
a = StridedSliceShapeChecker(uncertain_tensor)
@@ -906,7 +906,7 @@
@test_util.run_deprecated_v1
def testTensorValuedIndexShape(self):
- with self.session(use_gpu=True):
+ with self.session():
defined_shape_tensor = array_ops.placeholder(
dtypes.float32, shape=(5, 3, 7))
index_value = array_ops.placeholder(dtypes.int32, shape=())
@@ -965,7 +965,7 @@
@test_util.run_v1_only("b/120545219")
def testGradient(self):
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
var = variables.Variable(
array_ops.reshape(
math_ops.range(1, 97, 1, dtype=dtypes.float32), shape=(6, 4, 4)))
@@ -992,7 +992,7 @@
@test_util.run_v1_only("b/120545219")
def testGradientZero(self):
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
var = variables.Variable(8.)
init = variables.global_variables_initializer()
sess.run(init)
@@ -1001,7 +1001,7 @@
@test_util.run_deprecated_v1
def testInt64Indices(self):
- with self.session(use_gpu=True) as sess:
+ with self.session():
a = math_ops.range(3, dtype=dtypes.float32)
index = constant_op.constant(1, dtype=dtypes.int64)
b = 2. * a[index]
@@ -1014,7 +1014,7 @@
@test_util.run_deprecated_v1
def testHostVsDevice(self):
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
var2 = variables.Variable(
array_ops.reshape(
math_ops.cast(math_ops.range(1, 5, 1), dtypes.float32),
@@ -1029,7 +1029,7 @@
@test_util.run_deprecated_v1
def testInt64Shape(self):
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
original_dy = array_ops.reshape(
math_ops.cast(math_ops.range(1, 5, 1), dtypes.float32),
shape=(4, 1, 1))
@@ -1044,7 +1044,7 @@
@test_util.run_deprecated_v1
def testMixedIndexTypes(self):
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
original_dy = array_ops.reshape(
math_ops.cast(math_ops.range(1, 5, 1), dtypes.float32),
shape=(4, 1, 1))
@@ -1133,7 +1133,7 @@
if self.tensor_type.is_complex:
value -= 1j * value
- with self.test.test_session(use_gpu=True) as sess:
+ with self.test.test_session() as sess:
if self._use_resource:
var = resource_variable_ops.ResourceVariable(self.x)
else:
@@ -1514,7 +1514,7 @@
def testInvertPermutation(self):
for dtype in [dtypes.int32, dtypes.int64]:
with self.subTest(dtype=dtype):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x = constant_op.constant([3, 4, 0, 2, 1], dtype=dtype)
y = array_ops.invert_permutation(x)
self.assertAllEqual(y.get_shape(), [5])
@@ -1597,7 +1597,7 @@
def testInvertPermutation(self):
for dtype in [dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64]:
with self.subTest(dtype=dtype):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x = constant_op.constant([0, 1, 2, 3], dtype=dtype)
y = gen_array_ops.snapshot(x)
self.assertAllEqual(y, [0, 1, 2, 3])
diff --git a/tensorflow/python/kernel_tests/atrous_conv2d_test.py b/tensorflow/python/kernel_tests/atrous_conv2d_test.py
index e0cf7c2..1aa0b03 100644
--- a/tensorflow/python/kernel_tests/atrous_conv2d_test.py
+++ b/tensorflow/python/kernel_tests/atrous_conv2d_test.py
@@ -61,7 +61,7 @@
@test_util.run_deprecated_v1
def testAtrousConv2DForward(self):
- with self.session(use_gpu=True):
+ with self.session():
# Input: [batch, height, width, input_depth]
height = 9
for width in [9, 10]: # Test both odd and even width.
@@ -108,7 +108,7 @@
padding = "SAME" # The padding needs to be "SAME"
np.random.seed(1) # Make it reproducible.
- with self.session(use_gpu=True):
+ with self.session():
# Input: [batch, height, width, input_depth]
for height in range(15, 17):
for width in range(15, 17):
@@ -138,7 +138,7 @@
@test_util.run_deprecated_v1
def testGradient(self):
- with self.session(use_gpu=True):
+ with self.session():
# Input: [batch, height, width, input_depth]
x_shape = [2, 5, 6, 2]
# Filter: [kernel_height, kernel_width, input_depth, output_depth]
@@ -166,7 +166,7 @@
@test_util.run_deprecated_v1
def testAtrousConv2DTransposeForward(self):
- with self.session(use_gpu=True):
+ with self.session():
# Input: [batch, height, width, input_depth]
height = 9
for width in [9, 10]: # Test both odd and even width.
@@ -206,7 +206,7 @@
@test_util.run_deprecated_v1
def testAtrousDepthwiseConv2DForward(self):
strides = [1, 1, 1, 1]
- with self.session(use_gpu=True):
+ with self.session():
# Input: [batch, height, width, input_depth]
height = 9
for width in [9, 10]: # Test both odd and even width.
diff --git a/tensorflow/python/kernel_tests/banded_triangular_solve_op_test.py b/tensorflow/python/kernel_tests/banded_triangular_solve_op_test.py
index bd0fdae..4545a2a 100644
--- a/tensorflow/python/kernel_tests/banded_triangular_solve_op_test.py
+++ b/tensorflow/python/kernel_tests/banded_triangular_solve_op_test.py
@@ -86,7 +86,7 @@
a_np = np.tile(a_np, batch_dims + [1, 1])
b = np.tile(b, batch_dims + [1, 1])
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
a_tf = a
b_tf = b
if use_placeholder:
@@ -199,7 +199,7 @@
# right-hand sides.
matrix = np.array([[1., 1.], [1., 1.]])
rhs = np.array([[1., 0.]])
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
with self.assertRaises(ValueError):
self._verifySolve(matrix, rhs)
with self.assertRaises(ValueError):
@@ -208,7 +208,7 @@
# Number of bands exceeds the dimension of the matrix.
matrix = np.ones((6, 4))
rhs = np.ones((4, 2))
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
with self.assertRaises(ValueError):
self._verifySolve(matrix, rhs)
with self.assertRaises(ValueError):
diff --git a/tensorflow/python/kernel_tests/basic_gpu_test.py b/tensorflow/python/kernel_tests/basic_gpu_test.py
index a64032e..73f0209 100644
--- a/tensorflow/python/kernel_tests/basic_gpu_test.py
+++ b/tensorflow/python/kernel_tests/basic_gpu_test.py
@@ -40,13 +40,13 @@
class GPUBinaryOpsTest(test.TestCase):
def _compareGPU(self, x, y, np_func, tf_func):
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
inx = ops.convert_to_tensor(x)
iny = ops.convert_to_tensor(y)
out = tf_func(inx, iny)
tf_gpu = self.evaluate(out)
- with self.cached_session(use_gpu=False) as sess:
+ with self.cached_session(use_gpu=False):
inx = ops.convert_to_tensor(x)
iny = ops.convert_to_tensor(y)
out = tf_func(inx, iny)
@@ -143,7 +143,7 @@
np_out = np.floor_divide(x, y + 0.1)
- with self.session(use_gpu=True) as sess:
+ with self.session():
inx = ops.convert_to_tensor(x)
iny = ops.convert_to_tensor(y + 0.1)
ofunc = inx / iny
@@ -167,7 +167,7 @@
def _compareGpu(self, x, y, np_func, tf_func):
np_ans = np_func(x, y)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
inx = ops.convert_to_tensor(x)
iny = ops.convert_to_tensor(y)
out = tf_func(inx, iny)
diff --git a/tensorflow/python/kernel_tests/batch_matmul_op_test.py b/tensorflow/python/kernel_tests/batch_matmul_op_test.py
index ac82a32..331fded 100644
--- a/tensorflow/python/kernel_tests/batch_matmul_op_test.py
+++ b/tensorflow/python/kernel_tests/batch_matmul_op_test.py
@@ -166,7 +166,7 @@
def Loss(x, y):
return math_ops.reduce_sum(math_ops.matmul(x, y, adjoint_a, adjoint_b))
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
((x_jacob_t, y_jacob_t),
(x_jacob_n, y_jacob_n)) = gradient_checker_v2.compute_gradient(
Loss, [x, y], delta=delta)
diff --git a/tensorflow/python/kernel_tests/bincount_op_test.py b/tensorflow/python/kernel_tests/bincount_op_test.py
index 133d339..4ca8133 100644
--- a/tensorflow/python/kernel_tests/bincount_op_test.py
+++ b/tensorflow/python/kernel_tests/bincount_op_test.py
@@ -36,7 +36,7 @@
class BincountTest(test_util.TensorFlowTestCase):
def test_empty(self):
- with self.session(use_gpu=True):
+ with self.session():
self.assertAllEqual(
self.evaluate(bincount_ops.bincount([], minlength=5)),
[0, 0, 0, 0, 0])
@@ -54,7 +54,7 @@
np.float64)
def test_values(self):
- with self.session(use_gpu=True):
+ with self.session():
self.assertAllEqual(
self.evaluate(bincount_ops.bincount([1, 1, 1, 2, 2, 3])),
[0, 3, 2, 1])
@@ -74,7 +74,7 @@
np.ones(10000))
def test_maxlength(self):
- with self.session(use_gpu=True):
+ with self.session():
self.assertAllEqual(
self.evaluate(bincount_ops.bincount([5], maxlength=3)), [0, 0, 0])
self.assertAllEqual(
@@ -84,7 +84,7 @@
def test_random_with_weights(self):
num_samples = 10000
- with self.session(use_gpu=True):
+ with self.session():
np.random.seed(42)
for dtype in [dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64]:
arr = np.random.randint(0, 1000, num_samples)
@@ -98,7 +98,7 @@
def test_random_without_weights(self):
num_samples = 10000
- with self.session(use_gpu=True):
+ with self.session():
np.random.seed(42)
for dtype in [np.int32, np.float32]:
arr = np.random.randint(0, 1000, num_samples)
@@ -108,7 +108,7 @@
np.bincount(arr, weights))
def test_zero_weights(self):
- with self.session(use_gpu=True):
+ with self.session():
self.assertAllEqual(
self.evaluate(bincount_ops.bincount(np.arange(1000), np.zeros(1000))),
np.zeros(1000))
diff --git a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py
index f1e1ff1..fd177c6 100644
--- a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py
+++ b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py
@@ -33,21 +33,21 @@
def testBroadcastToBasic(self):
for dtype in [np.uint8, np.uint16, np.int8, np.int16, np.int32, np.int64]:
- with self.session(use_gpu=True):
+ with self.session():
x = np.array([1, 2, 3], dtype=dtype)
v_tf = array_ops.broadcast_to(constant_op.constant(x), [3, 3])
v_np = np.broadcast_to(x, [3, 3])
self.assertAllEqual(v_tf, v_np)
def testBroadcastToString(self):
- with self.session(use_gpu=True):
+ with self.session():
x = np.array([b"1", b"2", b"3"])
v_tf = array_ops.broadcast_to(constant_op.constant(x), [3, 3])
v_np = np.broadcast_to(x, [3, 3])
self.assertAllEqual(v_tf, v_np)
def testBroadcastToBool(self):
- with self.session(use_gpu=True):
+ with self.session():
x = np.array([True, False, True], dtype=np.bool)
v_tf = array_ops.broadcast_to(constant_op.constant(x), [3, 3])
v_np = np.broadcast_to(x, [3, 3])
@@ -56,7 +56,7 @@
def testBroadcastToShape(self):
for input_dim in range(1, 6):
for output_dim in range(input_dim, 6):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
input_shape = [2] * input_dim
output_shape = [2] * output_dim
x = np.array(np.random.randint(5, size=input_shape), dtype=np.int32)
@@ -67,7 +67,7 @@
def testBroadcastToShapeInnerDim(self):
input_shape = [2, 1, 3]
output_shape = [2, 5, 3]
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x = np.array(np.random.randint(5, size=input_shape), dtype=np.int32)
v_tf = array_ops.broadcast_to(constant_op.constant(x), output_shape)
v_np = np.broadcast_to(x, output_shape)
@@ -76,7 +76,7 @@
def testBroadcastToShapeLargerDim(self):
input_shape = [2, 1, 3, 2, 2, 2]
output_shape = [1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 15, 3, 2, 2, 2]
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x = np.array(np.random.randint(5, size=input_shape), dtype=np.int32)
v_tf = array_ops.broadcast_to(constant_op.constant(x), output_shape)
v_np = np.broadcast_to(x, output_shape)
@@ -85,21 +85,21 @@
def testBroadcastToShapeLargerDim2(self):
input_shape = [2, 1, 3, 2, 2, 2, 1, 1, 1]
output_shape = [1, 1, 1, 2, 5, 3, 2, 2, 2, 3, 3, 3]
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x = np.array(np.random.randint(5, size=input_shape), dtype=np.int32)
v_tf = array_ops.broadcast_to(constant_op.constant(x), output_shape)
v_np = np.broadcast_to(x, output_shape)
self.assertAllEqual(v_tf, v_np)
def testBroadcastToScalar(self):
- with self.session(use_gpu=True):
+ with self.session():
x = np.array(1, dtype=np.int32)
v_tf = array_ops.broadcast_to(constant_op.constant(x), [3, 3])
v_np = np.broadcast_to(x, [3, 3])
self.assertAllEqual(v_tf, v_np)
def testBroadcastScalarToNonScalar(self):
- with self.session(use_gpu=True):
+ with self.session():
x = np.array(1.0, dtype=np.float)
v_tf = array_ops.broadcast_to(constant_op.constant(1.0), [2, 3, 4,
1, 1, 1])
@@ -108,7 +108,7 @@
def testBroadcastToShapeTypeAndInference(self):
for dtype in [dtypes.int32, dtypes.int64]:
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x = np.array([1, 2, 3])
v_tf = array_ops.broadcast_to(
constant_op.constant(x),
diff --git a/tensorflow/python/kernel_tests/bucketize_op_test.py b/tensorflow/python/kernel_tests/bucketize_op_test.py
index 59c30d8..de73948 100644
--- a/tensorflow/python/kernel_tests/bucketize_op_test.py
+++ b/tensorflow/python/kernel_tests/bucketize_op_test.py
@@ -36,14 +36,14 @@
constant_op.constant([-5, 0, 2, 3, 5, 8, 10, 11, 12]),
boundaries=[0, 3, 8, 11])
expected_out = [0, 1, 1, 2, 2, 3, 3, 4, 4]
- with self.session(use_gpu=True) as sess:
+ with self.session():
self.assertAllEqual(expected_out, self.evaluate(op))
def testEmptyFloat(self):
op = math_ops._bucketize(
array_ops.zeros([0, 3], dtype=dtypes.float32), boundaries=[])
expected_out = np.zeros([0, 3], dtype=np.float32)
- with self.session(use_gpu=True):
+ with self.session():
self.assertAllEqual(expected_out, self.evaluate(op))
def testFloat(self):
@@ -51,7 +51,7 @@
constant_op.constant([-5., 0., 2., 3., 5., 8., 10., 11., 12.]),
boundaries=[0., 3., 8., 11.])
expected_out = [0, 1, 1, 2, 2, 3, 3, 4, 4]
- with self.session(use_gpu=True) as sess:
+ with self.session():
self.assertAllEqual(expected_out, self.evaluate(op))
def test2DInput(self):
@@ -59,14 +59,14 @@
constant_op.constant([[-5, 0, 2, 3, 5], [8, 10, 11, 12, 0]]),
boundaries=[0, 3, 8, 11])
expected_out = [[0, 1, 1, 2, 2], [3, 3, 4, 4, 1]]
- with self.session(use_gpu=True) as sess:
+ with self.session():
self.assertAllEqual(expected_out, self.evaluate(op))
@test_util.run_deprecated_v1
def testInvalidBoundariesOrder(self):
op = math_ops._bucketize(
constant_op.constant([-5, 0]), boundaries=[0, 8, 3, 11])
- with self.session(use_gpu=True) as sess:
+ with self.session():
with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
"Expected sorted boundaries"):
self.evaluate(op)
diff --git a/tensorflow/python/kernel_tests/cast_op_test.py b/tensorflow/python/kernel_tests/cast_op_test.py
index 7b79415..c1f8cc3 100644
--- a/tensorflow/python/kernel_tests/cast_op_test.py
+++ b/tensorflow/python/kernel_tests/cast_op_test.py
@@ -108,7 +108,7 @@
with self.cached_session(use_gpu=False):
b = math_ops.cast(math_ops.cast(a, dtypes.bfloat16), dtypes.float32)
self.assertAllClose(a, self.evaluate(b), rtol=1 / 128.)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
b = math_ops.cast(math_ops.cast(a, dtypes.bfloat16), dtypes.float32)
self.assertAllClose(a, self.evaluate(b), rtol=1 / 128.)
diff --git a/tensorflow/python/kernel_tests/cholesky_op_test.py b/tensorflow/python/kernel_tests/cholesky_op_test.py
index 0697f7d..cc03e60 100644
--- a/tensorflow/python/kernel_tests/cholesky_op_test.py
+++ b/tensorflow/python/kernel_tests/cholesky_op_test.py
@@ -166,7 +166,7 @@
@test_util.disable_xla("b/123337890")
def testNotInvertibleCPU(self):
# The input should be invertible.
- with self.session(use_gpu=True):
+ with self.session():
with self.assertRaisesRegex(
errors_impl.InvalidArgumentError,
"Cholesky decomposition was not successful. The"
diff --git a/tensorflow/python/kernel_tests/clip_ops_test.py b/tensorflow/python/kernel_tests/clip_ops_test.py
index d0c805f..d85e393 100644
--- a/tensorflow/python/kernel_tests/clip_ops_test.py
+++ b/tensorflow/python/kernel_tests/clip_ops_test.py
@@ -52,7 +52,7 @@
# ClipByValue test
def testClipByValue(self):
- with self.session(use_gpu=True):
+ with self.session():
x = constant_op.constant([-5.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3])
np_ans = [[-4.4, 2.0, 3.0], [4.0, 4.4, 4.4]]
clip_value = 4.4
@@ -73,7 +73,7 @@
dtypes.int64,
dtypes.uint8,
]:
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x = constant_op.constant([1, 2, 3, 4, 5, 6], shape=[2, 3], dtype=dtype)
np_ans = [[2, 2, 3], [4, 4, 4]]
clip_value_min = 2
@@ -95,7 +95,7 @@
dtypes.int64,
dtypes.uint8,
]:
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x = constant_op.constant([1, 2, 3, 4, 5, 6], shape=[2, 3], dtype=dtype)
np_ans = [[2, 2, 3], [4, 4, 4]]
clip_value_min = constant_op.constant(
@@ -118,7 +118,7 @@
dtypes.int64,
dtypes.uint8,
]:
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x = constant_op.constant([1, 2, 3, 4, 5, 6], shape=[2, 3], dtype=dtype)
np_ans = [[4, 4, 4], [4, 5, 6]]
clip_value_min = 4
@@ -141,7 +141,7 @@
dtypes.int64,
dtypes.uint8,
]:
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x = constant_op.constant([1, 2, 3, 4, 5, 6], shape=[2, 3], dtype=dtype)
np_ans = [[2, 2, 3], [5, 5, 6]]
clip_value_min = constant_op.constant(
@@ -154,7 +154,7 @@
self.assertAllClose(np_ans, tf_ans)
def testClipByValueBadShape(self):
- with self.session(use_gpu=True):
+ with self.session():
x = constant_op.constant([-5.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3, 1])
# Use a nonsensical shape.
clip = constant_op.constant([1.0, 2.0])
@@ -176,7 +176,7 @@
def _testClipIndexedSlicesByValue(self, values, indices, shape,
clip_value_min, clip_value_max, expected):
- with self.session(use_gpu=True) as sess:
+ with self.session():
values = constant_op.constant(values)
indices = constant_op.constant(indices)
shape = constant_op.constant(shape)
@@ -211,7 +211,7 @@
# ClipByNorm tests
def testClipByNormClipped(self):
# Norm clipping when clip_norm < 5
- with self.session(use_gpu=True):
+ with self.session():
x = constant_op.constant([-3.0, 0.0, 0.0, 4.0, 0.0, 0.0], shape=[2, 3])
# Norm of x = sqrt(3^2 + 4^2) = 5
np_ans = [[-2.4, 0.0, 0.0], [3.2, 0.0, 0.0]]
@@ -227,14 +227,14 @@
@test_util.run_deprecated_v1
def testClipByNormGradientZeros(self):
- with self.session(use_gpu=True):
+ with self.session():
x = array_ops.zeros([3])
b = clip_ops.clip_by_norm(x, 1.)
grad, = gradients_impl.gradients(b, x)
self.assertAllEqual(grad, [1., 1., 1.])
def testClipByNormBadShape(self):
- with self.session(use_gpu=True):
+ with self.session():
x = constant_op.constant([-3.0, 0.0, 0.0, 4.0, 0.0, 0.0], shape=[2, 3, 1])
# Use a nonsensical shape.
clip = constant_op.constant([1.0, 2.0])
@@ -243,7 +243,7 @@
def testClipByNormNotClipped(self):
# No norm clipping when clip_norm >= 5
- with self.session(use_gpu=True):
+ with self.session():
x = constant_op.constant([-3.0, 0.0, 0.0, 4.0, 0.0, 0.0], shape=[2, 3])
# Norm of x = sqrt(3^2 + 4^2) = 5
np_ans = [[-3.0, 0.0, 0.0], [4.0, 0.0, 0.0]]
@@ -255,7 +255,7 @@
def testClipByNormZero(self):
# No norm clipping when norm = 0
- with self.session(use_gpu=True):
+ with self.session():
x = constant_op.constant([0.0, 0.0, 0.0, 0.0, 0.0, 0.0], shape=[2, 3])
# Norm = 0, no changes
np_ans = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
@@ -267,7 +267,7 @@
def testClipByNormClippedWithDim0(self):
# Norm clipping when clip_norm < 5
- with self.session(use_gpu=True):
+ with self.session():
x = constant_op.constant([-3.0, 0.0, 0.0, 4.0, 0.0, 3.0], shape=[2, 3])
# Norm of x[:, 0] = sqrt(3^2 + 4^2) = 5, x[:, 2] = 3
np_ans = [[-2.4, 0.0, 0.0], [3.2, 0.0, 3.0]]
@@ -279,7 +279,7 @@
def testClipByNormClippedWithDim1(self):
# Norm clipping when clip_norm < 5
- with self.session(use_gpu=True):
+ with self.session():
x = constant_op.constant([-3.0, 0.0, 0.0, 4.0, 0.0, 3.0], shape=[2, 3])
# Norm of x[0, :] = 3, x[1, :] = sqrt(3^2 + 4^2) = 5
np_ans = [[-3.0, 0.0, 0.0], [3.2, 0.0, 2.4]]
@@ -291,7 +291,7 @@
def testClipByNormNotClippedWithAxes(self):
# No norm clipping when clip_norm >= 5
- with self.session(use_gpu=True):
+ with self.session():
x = constant_op.constant([-3.0, 0.0, 0.0, 4.0, 0.0, 3.0], shape=[2, 3])
# Norm of x[0, :] = 3, x[1, :] = sqrt(3^2 + 4^2) = 5
np_ans = [[-3.0, 0.0, 0.0], [4.0, 0.0, 3.0]]
@@ -305,7 +305,7 @@
@test_util.run_deprecated_v1
def testClipByGlobalNormClipped(self):
# Norm clipping when clip_norm < 5
- with self.session(use_gpu=True):
+ with self.session():
x0 = constant_op.constant([-2.0, 0.0, 0.0, 4.0, 0.0, 0.0], shape=[2, 3])
x1 = constant_op.constant([1.0, -2.0])
# Global norm of x0 and x1 = sqrt(1 + 4^2 + 2^2 + 2^2) = 5
@@ -327,7 +327,7 @@
@test_util.run_deprecated_v1
def testClipByGlobalNormClippedTensor(self):
# Norm clipping when clip_norm < 5
- with self.session(use_gpu=True):
+ with self.session():
x0 = constant_op.constant([-2.0, 0.0, 0.0, 4.0, 0.0, 0.0], shape=[2, 3])
x1 = constant_op.constant([1.0, -2.0])
# Global norm of x0 and x1 = sqrt(1 + 4^2 + 2^2 + 2^2) = 5
@@ -349,7 +349,7 @@
@test_util.run_deprecated_v1
def testClipByGlobalNormSupportsNone(self):
# Norm clipping when clip_norm < 5
- with self.session(use_gpu=True):
+ with self.session():
x0 = constant_op.constant([-2.0, 0.0, 0.0, 4.0, 0.0, 0.0], shape=[2, 3])
x1 = constant_op.constant([1.0, -2.0])
# Global norm of x0 and x1 = sqrt(1 + 4^2 + 2^2 + 2^2) = 5
@@ -373,7 +373,7 @@
@test_util.run_deprecated_v1
def testClipByGlobalNormWithIndexedSlicesClipped(self):
# Norm clipping when clip_norm < 5
- with self.session(use_gpu=True):
+ with self.session():
x0 = constant_op.constant([-2.0, 0.0, 0.0, 4.0, 0.0, 0.0], shape=[2, 3])
x1 = ops.IndexedSlices(
constant_op.constant([1.0, -2.0]), constant_op.constant([3, 4]))
@@ -407,7 +407,7 @@
@test_util.run_deprecated_v1
def testClipByGlobalNormNotClipped(self):
# No norm clipping when clip_norm >= 5
- with self.session(use_gpu=True):
+ with self.session():
x0 = constant_op.constant([-2.0, 0.0, 0.0, 4.0, 0.0, 0.0], shape=[2, 3])
x1 = constant_op.constant([1.0, -2.0])
# Global norm of x0 and x1 = sqrt(1 + 4^2 + 2^2 + 2^2) = 5
@@ -427,7 +427,7 @@
@test_util.run_deprecated_v1
def testClipByGlobalNormZero(self):
# No norm clipping when norm = 0
- with self.session(use_gpu=True):
+ with self.session():
x0 = constant_op.constant([0.0, 0.0, 0.0, 0.0, 0.0, 0.0], shape=[2, 3])
x1 = constant_op.constant([0.0, 0.0])
# Norm = 0, no changes
@@ -447,7 +447,7 @@
@test_util.run_deprecated_v1
def testClipByGlobalNormInf(self):
# Expect all NaNs when global norm is inf.
- with self.session(use_gpu=True):
+ with self.session():
x0 = constant_op.constant([-2.0, 0.0, np.inf, 4.0, 0.0, 0.0],
shape=[2, 3])
x1 = constant_op.constant([1.0, -2.0])
@@ -463,7 +463,7 @@
def testClipByAverageNormClipped(self):
# Norm clipping when average clip_norm < 0.83333333
- with self.session(use_gpu=True):
+ with self.session():
x = constant_op.constant([-3.0, 0.0, 0.0, 4.0, 0.0, 0.0], shape=[2, 3])
# Average norm of x = sqrt(3^2 + 4^2) / 6 = 0.83333333
np_ans = [[-2.88, 0.0, 0.0], [3.84, 0.0, 0.0]]
@@ -475,7 +475,7 @@
def testClipByAverageNormClippedTensor(self):
# Norm clipping when average clip_norm < 0.83333333
- with self.session(use_gpu=True):
+ with self.session():
x = constant_op.constant([-3.0, 0.0, 0.0, 4.0, 0.0, 0.0], shape=[2, 3])
# Average norm of x = sqrt(3^2 + 4^2) / 6 = 0.83333333
np_ans = [[-2.88, 0.0, 0.0], [3.84, 0.0, 0.0]]
@@ -487,7 +487,7 @@
def testClipByAverageNormNotClipped(self):
# No norm clipping when average clip_norm >= 0.83333333
- with self.session(use_gpu=True):
+ with self.session():
x = constant_op.constant([-3.0, 0.0, 0.0, 4.0, 0.0, 0.0], shape=[2, 3])
# Average norm of x = sqrt(3^2 + 4^2) / 6 = 0.83333333
np_ans = [[-3.0, 0.0, 0.0], [4.0, 0.0, 0.0]]
@@ -499,7 +499,7 @@
def testClipByAverageNormZero(self):
# No norm clipping when average clip_norm = 0
- with self.session(use_gpu=True):
+ with self.session():
x = constant_op.constant([0.0, 0.0, 0.0, 0.0, 0.0, 0.0], shape=[2, 3])
# Average norm = 0, no changes
np_ans = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
@@ -512,7 +512,7 @@
def testClipByAverageNormReplacedWithClipByNorm(self):
# Check clip_by_average_norm(t) is the same as
# clip_by_norm(t, clip_norm * tf.compat.v1.to_float(tf.size(t)))
- with self.session(use_gpu=True):
+ with self.session():
x = constant_op.constant([-3.0, 0.0, 0.0, 4.0, 0.0, 0.0], shape=[2, 3])
# Average norm of x = sqrt(3^2 + 4^2) / 6 = 0.83333333
# expected answer [[-2.88, 0.0, 0.0], [3.84, 0.0, 0.0]]
@@ -532,7 +532,7 @@
y = clip_ops.clip_by_value(zero, 1.0, 1.0)
z = clip_ops.clip_by_value(zero, zero, 1.0)
w = clip_ops.clip_by_value(zero, 1.0, zero)
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
sess.run([x, y, z, w], feed_dict={zero: np.zeros((7, 0))})
diff --git a/tensorflow/python/kernel_tests/concat_op_test.py b/tensorflow/python/kernel_tests/concat_op_test.py
index bcc31872..da4f4f8 100644
--- a/tensorflow/python/kernel_tests/concat_op_test.py
+++ b/tensorflow/python/kernel_tests/concat_op_test.py
@@ -38,7 +38,7 @@
@test_util.run_deprecated_v1
def testHStack(self):
- with self.session(use_gpu=True):
+ with self.session():
p1 = array_ops.placeholder(dtypes.float32, shape=[4, 4])
p2 = array_ops.placeholder(dtypes.float32, shape=[4, 4])
c = array_ops.concat([p1, p2], 0)
@@ -54,7 +54,7 @@
@test_util.run_deprecated_v1
def testVStack(self):
- with self.session(use_gpu=True):
+ with self.session():
p1 = array_ops.placeholder(dtypes.float32, shape=[4, 4])
p2 = array_ops.placeholder(dtypes.float32, shape=[4, 4])
c = array_ops.concat([p1, p2], 1)
@@ -70,7 +70,7 @@
@test_util.run_deprecated_v1
def test4DStack(self):
- with self.session(use_gpu=True):
+ with self.session():
p1 = array_ops.placeholder(dtypes.float32, shape=[2, 3, 1, 1])
p2 = array_ops.placeholder(dtypes.float32, shape=[2, 3, 4, 1])
c = array_ops.concat([p1, p2], 2)
@@ -121,7 +121,7 @@
dtype_feed = dtypes.float32
else:
dtype_feed = dtype
- with self.session(use_gpu=True):
+ with self.session():
p = []
for i in np.arange(num_tensors):
input_shape = shape
@@ -315,7 +315,7 @@
@test_util.run_deprecated_v1
def testGradientWithUnknownInputDim(self):
- with self.session(use_gpu=True):
+ with self.session():
x = array_ops.placeholder(dtypes.float32)
y = array_ops.placeholder(dtypes.float32)
c = array_ops.concat([x, y], 2)
@@ -526,7 +526,7 @@
# shared memory is not large for all the inputs
@test_util.run_deprecated_v1
def testConcatLargeNumberOfTensors(self):
- with self.session(use_gpu=True):
+ with self.session():
for concat_dim in range(2):
params = {}
p = []
diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py
index cb014fc..68d6cad 100644
--- a/tensorflow/python/kernel_tests/constant_op_test.py
+++ b/tensorflow/python/kernel_tests/constant_op_test.py
@@ -54,7 +54,7 @@
def _testGpu(self, x):
np_ans = np.array(x)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
tf_ans = ops.convert_to_tensor(x).eval()
dtype = dtypes_lib.as_dtype(np_ans.dtype)
if dtype.is_floating or dtype.is_complex:
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 9026100..8d70dc9 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -557,7 +557,7 @@
@test_util.run_v1_only("b/120545219")
def testCondColocation(self):
- with self.session(use_gpu=True):
+ with self.session():
with ops.device("/cpu:0"):
v = variables.Variable(7.0)
@@ -1224,7 +1224,7 @@
def testCondGradMultiDevice(self):
config = config_pb2.ConfigProto(device_count={"CPU": 2},
allow_soft_placement=True)
- with self.cached_session(use_gpu=True, config=config) as sess:
+ with self.cached_session(config=config) as sess:
pred = array_ops.placeholder(dtypes.bool, [])
x = array_ops.placeholder(dtypes.float32)
y = array_ops.placeholder(dtypes.float32)
@@ -2621,7 +2621,7 @@
def testWhileCondGradMultiDevice(self):
config = config_pb2.ConfigProto(device_count={"CPU": 2},
allow_soft_placement=True)
- with self.cached_session(use_gpu=True, config=config) as sess:
+ with self.cached_session(config=config) as sess:
pred = array_ops.placeholder(dtypes.bool, [])
x_init = constant_op.constant(1.0)
@@ -4911,7 +4911,7 @@
if test_util.is_gpu_available():
self.skipTest("b/128646478 fails in opensource")
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
with ops.device(test.gpu_device_name()):
value = constant_op.constant(1.0)
with ops.device("/cpu:0"):
diff --git a/tensorflow/python/kernel_tests/conv1d_transpose_test.py b/tensorflow/python/kernel_tests/conv1d_transpose_test.py
index 02ac5af..f068239 100644
--- a/tensorflow/python/kernel_tests/conv1d_transpose_test.py
+++ b/tensorflow/python/kernel_tests/conv1d_transpose_test.py
@@ -153,7 +153,7 @@
def testConv1DTransposeSingleStrideNCW(self):
# `NCW` data format is only supported for CUDA device.
if test.is_gpu_available(cuda_only=True):
- with self.session(use_gpu=True):
+ with self.session():
strides = [1, 1, 1]
# Input, output: [batch, depth, width]
@@ -184,7 +184,7 @@
def testConv1DTransposeSameNCW(self):
# `NCW` data format is only supported for CUDA device.
if test.is_gpu_available(cuda_only=True):
- with self.session(use_gpu=True):
+ with self.session():
strides = [1, 1, 2]
# Input, output: [batch, depth, width]
@@ -216,7 +216,7 @@
def testConv1DTransposeValidNCW(self):
# `NCW` data format is only supported for CUDA device.
if test.is_gpu_available(cuda_only=True):
- with self.session(use_gpu=True):
+ with self.session():
strides = [1, 1, 2]
# Input, output: [batch, depth, width]
diff --git a/tensorflow/python/kernel_tests/conv2d_backprop_filter_grad_test.py b/tensorflow/python/kernel_tests/conv2d_backprop_filter_grad_test.py
index e14a719..2a57d68 100644
--- a/tensorflow/python/kernel_tests/conv2d_backprop_filter_grad_test.py
+++ b/tensorflow/python/kernel_tests/conv2d_backprop_filter_grad_test.py
@@ -77,7 +77,7 @@
@test_util.run_deprecated_v1
def testGradientDilatedConv(self):
if test.is_gpu_available(cuda_only=True):
- with self.session(use_gpu=True):
+ with self.session():
for padding in [
"SAME",
"VALID",
diff --git a/tensorflow/python/kernel_tests/conv2d_transpose_test.py b/tensorflow/python/kernel_tests/conv2d_transpose_test.py
index 96f1c05..60f1650 100644
--- a/tensorflow/python/kernel_tests/conv2d_transpose_test.py
+++ b/tensorflow/python/kernel_tests/conv2d_transpose_test.py
@@ -186,7 +186,7 @@
def testConv2DTransposeSingleStrideNCHW(self):
# `NCHW` data format is only supported for CUDA device.
if test.is_gpu_available(cuda_only=True):
- with self.session(use_gpu=True):
+ with self.session():
strides = [1, 1, 1, 1]
# Input, output: [batch, depth, height, width, depth]
@@ -221,7 +221,7 @@
def testConv2DTransposeSameNCHW(self):
# `NCHW` data format is only supported for CUDA device.
if test.is_gpu_available(cuda_only=True):
- with self.session(use_gpu=True):
+ with self.session():
strides = [1, 1, 2, 2]
# Input, output: [batch, depth, height, width]
@@ -257,7 +257,7 @@
def testConv2DTransposeValidNCHW(self):
# `NCHW` data format is only supported for CUDA device.
if test.is_gpu_available(cuda_only=True):
- with self.session(use_gpu=True):
+ with self.session():
strides = [1, 1, 2, 2]
# Input, output: [batch, depth, height, width]
diff --git a/tensorflow/python/kernel_tests/conv_ops_3d_test.py b/tensorflow/python/kernel_tests/conv_ops_3d_test.py
index d0c4fea..5a7fa64 100644
--- a/tensorflow/python/kernel_tests/conv_ops_3d_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_3d_test.py
@@ -211,7 +211,7 @@
x2, filter_in, strides=[1, 1, 1, 1, 1], padding="VALID")
self.assertEqual(conv1.shape, tensor_in_sizes_batch)
self.assertEqual(conv2.shape, tensor_in_sizes_expanded_batch)
- self.assertAllEqual(conv1, self.evaluate(conv2).reshape(conv1.shape))
+ self.assertAllClose(conv1, self.evaluate(conv2).reshape(conv1.shape))
@test_util.run_in_graph_and_eager_modes
def testConvolutionClass3DExpandedBatch(self):
@@ -237,7 +237,7 @@
conv2 = convolver2(x2, filter_in)
self.assertEqual(conv1.shape, tensor_in_sizes_batch)
self.assertEqual(conv2.shape, tensor_in_sizes_expanded_batch)
- self.assertAllEqual(conv1, self.evaluate(conv2).reshape(conv1.shape))
+ self.assertAllClose(conv1, self.evaluate(conv2).reshape(conv1.shape))
@test_util.run_in_graph_and_eager_modes
def testConvolutionWith2SpatialDimensionsAndExpandedBatch(self):
@@ -253,7 +253,7 @@
x2, filter_in, strides=[1, 1, 1], padding="VALID")
self.assertEqual(conv1.shape, tensor_in_sizes_batch)
self.assertEqual(conv2.shape, tensor_in_sizes_expanded_batch)
- self.assertAllEqual(conv1, self.evaluate(conv2).reshape(conv1.shape))
+ self.assertAllClose(conv1, self.evaluate(conv2).reshape(conv1.shape))
def testConv3D1x1x1Filter(self):
expected_output = [
diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py
index dd03312..44a67cc 100644
--- a/tensorflow/python/kernel_tests/conv_ops_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_test.py
@@ -2787,7 +2787,7 @@
expected: An array containing the expected operation outputs.
data_format: string data format for input tensor.
"""
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
t1 = self._InitValues(tensor_in_sizes)
f1 = self._InitValues(depthwise_filter_in_sizes)
f1.set_shape(depthwise_filter_in_sizes)
@@ -2899,7 +2899,7 @@
depthwise_filter_in_sizes = [2, 2, 2, 3]
pointwise_filter_in_sizes = [1, 1, 6, 7]
padding = [[0, 0], [1, 2], [3, 4], [0, 0]]
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
# Compute the 'expected' values by manually padding before calling
# separable_conv2d
t1 = self._InitValues(tensor_in_sizes)
diff --git a/tensorflow/python/kernel_tests/decode_image_op_test.py b/tensorflow/python/kernel_tests/decode_image_op_test.py
index a2c0c7f..8c5aade 100644
--- a/tensorflow/python/kernel_tests/decode_image_op_test.py
+++ b/tensorflow/python/kernel_tests/decode_image_op_test.py
@@ -37,7 +37,7 @@
def testBmp(self):
# Read a real bmp and verify shape
path = os.path.join(prefix_path, "bmp", "testdata", "lena.bmp")
- with self.session(use_gpu=True) as sess:
+ with self.session():
bmp0 = io_ops.read_file(path)
image0 = image_ops.decode_image(bmp0)
image1 = image_ops.decode_bmp(bmp0)
@@ -53,7 +53,7 @@
stride = 5
shape = (12, height, width, 3)
- with self.session(use_gpu=True) as sess:
+ with self.session():
gif0 = io_ops.read_file(path)
image0 = image_ops.decode_image(gif0)
image1 = image_ops.decode_gif(gif0)
@@ -82,7 +82,7 @@
def testJpeg(self):
# Read a real jpeg and verify shape
path = os.path.join(prefix_path, "jpeg", "testdata", "jpeg_merge_test1.jpg")
- with self.session(use_gpu=True) as sess:
+ with self.session():
jpeg0 = io_ops.read_file(path)
image0 = image_ops.decode_image(jpeg0)
image1 = image_ops.decode_jpeg(jpeg0)
@@ -100,7 +100,7 @@
inputs = [(1, "lena_gray.png")]
for channels_in, filename in inputs:
for channels in 0, 1, 3, 4:
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
path = os.path.join(prefix_path, "png", "testdata", filename)
png0 = io_ops.read_file(path)
image0 = image_ops.decode_image(png0, channels=channels)
diff --git a/tensorflow/python/kernel_tests/depthtospace_op_test.py b/tensorflow/python/kernel_tests/depthtospace_op_test.py
index 27461ac..04564d8 100644
--- a/tensorflow/python/kernel_tests/depthtospace_op_test.py
+++ b/tensorflow/python/kernel_tests/depthtospace_op_test.py
@@ -56,7 +56,7 @@
self.evaluate(output_nhwc)
if test.is_gpu_available():
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
# test NHWC (default) on GPU
x_tf = array_ops.depth_to_space(input_nhwc, block_size)
self.assertAllEqual(x_tf, outputs)
@@ -126,7 +126,7 @@
self.assertAllEqual(x_tf.shape, x_out.shape)
self.evaluate(x_tf)
if test.is_gpu_available():
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
# test NHWC (default) on GPU
x_tf = array_ops.depth_to_space(input_nhwc, block_size)
self.assertAllEqual(x_tf.shape, x_out.shape)
@@ -343,7 +343,7 @@
return
assert 4 == x.ndim
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
tf_x = ops.convert_to_tensor(x)
tf_y = array_ops.depth_to_space(tf_x, block_size, data_format=data_format)
diff --git a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
index 266a0f8..e26de9b 100644
--- a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
+++ b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
@@ -425,7 +425,7 @@
# GitHub issue 22110.
if not test.is_gpu_available():
return
- with self.session(use_gpu=True):
+ with self.session():
x = array_ops.placeholder(dtypes.float32)
f = np.ones([1, 1, 1, 1], np.float32)
v = nn_impl.depthwise_conv2d(
diff --git a/tensorflow/python/kernel_tests/determinant_op_test.py b/tensorflow/python/kernel_tests/determinant_op_test.py
index 4eb2be0..d8154be 100644
--- a/tensorflow/python/kernel_tests/determinant_op_test.py
+++ b/tensorflow/python/kernel_tests/determinant_op_test.py
@@ -154,7 +154,7 @@
@test_util.run_v1_only("b/120545219")
def testConcurrentExecutesWithoutError(self):
- with self.session(use_gpu=True) as sess:
+ with self.session():
matrix1 = random_ops.random_normal([5, 5], seed=42)
matrix2 = random_ops.random_normal([5, 5], seed=42)
det1 = linalg_ops.matrix_determinant(matrix1)
diff --git a/tensorflow/python/kernel_tests/diag_op_test.py b/tensorflow/python/kernel_tests/diag_op_test.py
index 8e8586b..99b4133 100644
--- a/tensorflow/python/kernel_tests/diag_op_test.py
+++ b/tensorflow/python/kernel_tests/diag_op_test.py
@@ -374,7 +374,7 @@
@test_util.run_deprecated_v1
def testVector(self):
- with self.session(use_gpu=True):
+ with self.session():
v = np.array([1.0, 2.0, 3.0])
mat = np.diag(v)
v_diag = array_ops.matrix_diag(v)
@@ -397,7 +397,7 @@
self.assertAllEqual(v_diags, solution[0])
def _testVectorBatch(self, dtype):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
v_batch = np.array([[1.0, 0.0, 3.0], [4.0, 5.0, 6.0]]).astype(dtype)
mat_batch = np.array([[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 3.0]],
[[4.0, 0.0, 0.0], [0.0, 5.0, 0.0],
@@ -441,7 +441,7 @@
@test_util.run_deprecated_v1
def testRectangularBatch(self):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
# Stores expected num_rows and num_cols (when the other is given).
# expected[d_lower, d_upper] = (expected_num_rows, expected_num_cols)
test_list = list()
@@ -542,7 +542,7 @@
@test_util.run_deprecated_v1
def testInvalidShapeAtEval(self):
- with self.session(use_gpu=True):
+ with self.session():
v = array_ops.placeholder(dtype=dtypes_lib.float32)
with self.assertRaisesOpError("diagonal must be at least 1-dim"):
array_ops.matrix_diag(v).eval(feed_dict={v: 0.0})
@@ -550,7 +550,7 @@
@test_util.run_deprecated_v1
def testGrad(self):
shapes = ((3,), (7, 4))
- with self.session(use_gpu=True):
+ with self.session():
for shape in shapes:
x = constant_op.constant(np.random.rand(*shape), np.float32)
y = array_ops.matrix_diag(x)
@@ -564,7 +564,7 @@
tests = dict() # tests[shape] = (d_lower, d_upper)
tests[(3,)] = (-1, -1)
tests[(7, 3, 4)] = (-1, 1)
- with self.session(use_gpu=True):
+ with self.session():
for shape, diags in tests.items():
x = constant_op.constant(np.random.rand(*shape), np.float32)
for align in alignment_list:
@@ -580,7 +580,7 @@
@test_util.run_deprecated_v1
def testSquare(self):
- with self.session(use_gpu=True):
+ with self.session():
v = np.array([1.0, 2.0, 3.0])
mat = np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [1.0, 1.0, 1.0]])
mat_set_diag = np.array([[1.0, 1.0, 0.0], [1.0, 2.0, 1.0],
@@ -603,7 +603,7 @@
@test_util.run_deprecated_v1
def testRectangular(self):
- with self.session(use_gpu=True):
+ with self.session():
v = np.array([3.0, 4.0])
mat = np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0]])
expected = np.array([[3.0, 1.0, 0.0], [1.0, 4.0, 1.0]])
@@ -631,7 +631,7 @@
self.assertAllEqual(output, solution)
def _testSquareBatch(self, dtype):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
v_batch = np.array([[-1.0, 0.0, -3.0], [-4.0, -5.0, -6.0]]).astype(dtype)
mat_batch = np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0], [1.0, 0.0, 3.0]],
[[4.0, 0.0, 4.0], [0.0, 5.0, 0.0],
@@ -668,7 +668,7 @@
@test_util.run_deprecated_v1
def testRectangularBatch(self):
- with self.session(use_gpu=True):
+ with self.session():
v_batch = np.array([[-1.0, -2.0], [-4.0, -5.0]])
mat_batch = np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0]],
[[4.0, 0.0, 4.0], [0.0, 5.0, 0.0]]])
@@ -701,7 +701,7 @@
@test_util.run_deprecated_v1
def testInvalidShapeAtEval(self):
- with self.session(use_gpu=True):
+ with self.session():
v = array_ops.placeholder(dtype=dtypes_lib.float32)
with self.assertRaisesOpError("input must be at least 2-dim"):
array_ops.matrix_set_diag(v, [v]).eval(feed_dict={v: 0.0})
@@ -717,7 +717,7 @@
})
def _testGrad(self, input_shape, diag_shape, diags, align):
- with self.session(use_gpu=True):
+ with self.session():
x = constant_op.constant(
np.random.rand(*input_shape), dtype=dtypes_lib.float32)
x_diag = constant_op.constant(
@@ -751,7 +751,7 @@
@test_util.run_deprecated_v1
def testGradWithNoShapeInformation(self):
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
v = array_ops.placeholder(dtype=dtypes_lib.float32)
mat = array_ops.placeholder(dtype=dtypes_lib.float32)
grad_input = array_ops.placeholder(dtype=dtypes_lib.float32)
@@ -774,7 +774,7 @@
@test_util.run_deprecated_v1
def testSquare(self):
- with self.session(use_gpu=True):
+ with self.session():
v = np.array([1.0, 2.0, 3.0])
mat = np.diag(v)
mat_diag = array_ops.matrix_diag_part(mat)
@@ -798,7 +798,7 @@
@test_util.run_deprecated_v1
def testRectangular(self):
- with self.session(use_gpu=True):
+ with self.session():
mat = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
mat_diag = array_ops.matrix_diag_part(mat)
self.assertAllEqual(mat_diag, np.array([1.0, 5.0]))
@@ -817,7 +817,7 @@
self.assertAllEqual(mat_diag, solution[0])
def _testSquareBatch(self, dtype):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
v_batch = np.array([[1.0, 0.0, 3.0], [4.0, 5.0, 6.0]]).astype(dtype)
mat_batch = np.array([[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 3.0]],
[[4.0, 0.0, 0.0], [0.0, 5.0, 0.0],
@@ -853,7 +853,7 @@
@test_util.run_deprecated_v1
def testRectangularBatch(self):
- with self.session(use_gpu=True):
+ with self.session():
v_batch = np.array([[1.0, 2.0], [4.0, 5.0]])
mat_batch = np.array([[[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]],
[[4.0, 0.0, 0.0], [0.0, 5.0, 0.0]]])
@@ -880,7 +880,7 @@
matrix = array_ops.placeholder(dtypes_lib.int32, shape=[None, None])
result = array_ops.matrix_diag_part(matrix, k=-1)
input_matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
- with self.session(use_gpu=True):
+ with self.session():
result_eval = result.eval(feed_dict={matrix: input_matrix})
self.assertAllEqual([4, 8], result_eval)
@@ -891,7 +891,7 @@
@test_util.run_deprecated_v1
def testInvalidShapeAtEval(self):
- with self.session(use_gpu=True):
+ with self.session():
v = array_ops.placeholder(dtype=dtypes_lib.float32)
with self.assertRaisesOpError("input must be at least 2-dim"):
array_ops.matrix_diag_part(v).eval(feed_dict={v: 0.0})
@@ -899,7 +899,7 @@
@test_util.run_deprecated_v1
def testGrad(self):
shapes = ((3, 3), (2, 3), (3, 2), (5, 3, 3))
- with self.session(use_gpu=True):
+ with self.session():
for shape in shapes:
x = constant_op.constant(np.random.rand(*shape), dtype=np.float32)
y = array_ops.matrix_diag_part(x)
@@ -913,7 +913,7 @@
tests = dict() # tests[shape] = (d_lower, d_upper)
tests[(3, 3)] = (-1, -1)
tests[(7, 3, 4)] = (-1, 1)
- with self.session(use_gpu=True):
+ with self.session():
for align in alignment_list:
for shape, diags in tests.items():
x = constant_op.constant(np.random.rand(*shape), np.float32)
diff --git a/tensorflow/python/kernel_tests/dynamic_partition_op_test.py b/tensorflow/python/kernel_tests/dynamic_partition_op_test.py
index 2858f11..10c1567 100644
--- a/tensorflow/python/kernel_tests/dynamic_partition_op_test.py
+++ b/tensorflow/python/kernel_tests/dynamic_partition_op_test.py
@@ -39,7 +39,7 @@
@test_util.run_deprecated_v1
def testSimpleOneDimensional(self):
- with self.session(use_gpu=True) as sess:
+ with self.session():
data = constant_op.constant([0, 13, 2, 39, 4, 17], dtype=dtypes.float32)
indices = constant_op.constant([0, 0, 2, 3, 2, 1])
partitions = data_flow_ops.dynamic_partition(
@@ -60,7 +60,7 @@
@test_util.run_deprecated_v1
def testSimpleTwoDimensional(self):
- with self.session(use_gpu=True) as sess:
+ with self.session():
data = constant_op.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11],
[12, 13, 14], [15, 16, 17]],
dtype=dtypes.float32)
@@ -87,7 +87,7 @@
indices_list = [x % 2 for x in range(num)]
part1 = [x for x in range(num) if x % 2 == 0]
part2 = [x for x in range(num) if x % 2 == 1]
- with self.session(use_gpu=True) as sess:
+ with self.session():
data = constant_op.constant(data_list, dtype=dtypes.float32)
indices = constant_op.constant(indices_list, dtype=dtypes.int32)
partitions = data_flow_ops.dynamic_partition(
@@ -109,7 +109,7 @@
parts = [[] for _ in range(num_partitions)]
for i in range(rows):
parts[(i ** 2) % num_partitions].append(data_list[i])
- with self.session(use_gpu=True) as sess:
+ with self.session():
data = constant_op.constant(data_list, dtype=dtypes.float32)
indices = constant_op.constant(indices_list, dtype=dtypes.int32)
partitions = data_flow_ops.dynamic_partition(
@@ -125,7 +125,7 @@
def testSimpleComplex(self):
data_list = [1 + 2j, 3 + 4j, 5 + 6j, 7 + 8j]
indices_list = [1, 0, 1, 0]
- with self.session(use_gpu=True) as sess:
+ with self.session():
data = constant_op.constant(data_list, dtype=dtypes.complex64)
indices = constant_op.constant(indices_list, dtype=dtypes.int32)
partitions = data_flow_ops.dynamic_partition(
@@ -138,7 +138,7 @@
def testScalarPartitions(self):
data_list = [10, 13, 12, 11]
- with self.session(use_gpu=True) as sess:
+ with self.session():
data = constant_op.constant(data_list, dtype=dtypes.float64)
indices = 3
partitions = data_flow_ops.dynamic_partition(
@@ -159,7 +159,7 @@
@test_util.run_deprecated_v1
def testHigherRank(self):
np.random.seed(7)
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
for n in 2, 3:
for shape in (4,), (4, 5), (4, 5, 2):
partitions = np.random.randint(n, size=np.prod(shape)).reshape(shape)
@@ -184,7 +184,7 @@
def testEmptyParts(self):
data_list = [1, 2, 3, 4]
indices_list = [1, 3, 1, 3]
- with self.session(use_gpu=True) as sess:
+ with self.session():
data = constant_op.constant(data_list, dtype=dtypes.float32)
indices = constant_op.constant(indices_list, dtype=dtypes.int32)
partitions = data_flow_ops.dynamic_partition(
@@ -200,7 +200,7 @@
def testEmptyDataTwoDimensional(self):
data_list = [[], []]
indices_list = [0, 1]
- with self.session(use_gpu=True) as sess:
+ with self.session():
data = constant_op.constant(data_list, dtype=dtypes.float32)
indices = constant_op.constant(indices_list, dtype=dtypes.int32)
partitions = data_flow_ops.dynamic_partition(
@@ -216,7 +216,7 @@
def testEmptyPartitions(self):
data_list = []
indices_list = []
- with self.session(use_gpu=True) as sess:
+ with self.session():
data = constant_op.constant(data_list, dtype=dtypes.float32)
indices = constant_op.constant(indices_list, dtype=dtypes.int32)
partitions = data_flow_ops.dynamic_partition(
@@ -237,7 +237,7 @@
data_list = [1, 2, 3, 4, 5, 6]
indices_list = [6, 5, 4, 3, 1, 0]
- with self.session(use_gpu=True) as sess:
+ with self.session():
data = constant_op.constant(data_list, dtype=dtypes.float32)
indices = constant_op.constant(indices_list, dtype=dtypes.int32)
partitions = data_flow_ops.dynamic_partition(
@@ -258,7 +258,7 @@
data_list = [1, 2, 3, 4, 5, 6]
indices_list = [10, 11, 2, 12, 0, 1000]
- with self.session(use_gpu=True) as sess:
+ with self.session():
data = constant_op.constant(data_list, dtype=dtypes.float32)
indices = constant_op.constant(indices_list, dtype=dtypes.int32)
partitions = data_flow_ops.dynamic_partition(
@@ -282,7 +282,7 @@
data_list = [1.1, 2.1, 3.1, 4.1, 5.1, 6.1]
indices_list = [90, 70, 60, 100, 110, 40]
- with self.session(use_gpu=True) as sess:
+ with self.session():
data = constant_op.constant(data_list, dtype=dtypes.float32)
indices = constant_op.constant(indices_list, dtype=dtypes.int32)
partitions = data_flow_ops.dynamic_partition(
@@ -295,7 +295,7 @@
@test_util.run_deprecated_v1
def testErrorIndexOutOfRange(self):
- with self.cached_session() as sess:
+ with self.cached_session():
data = constant_op.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11],
[12, 13, 14]])
indices = constant_op.constant([0, 2, 99, 2, 2])
@@ -346,7 +346,7 @@
inds += [13]*194 + [14]*194 + [15]*192
self.assertEqual(len(inds), x.shape[0])
partitioned = data_flow_ops.dynamic_partition(x, inds, 16)
- with self.cached_session() as sess:
+ with self.cached_session():
res = self.evaluate(partitioned)
self.assertEqual(res[-1].shape[0], 192)
diff --git a/tensorflow/python/kernel_tests/eig_op_test.py b/tensorflow/python/kernel_tests/eig_op_test.py
index b1c8395..e9e311b 100644
--- a/tensorflow/python/kernel_tests/eig_op_test.py
+++ b/tensorflow/python/kernel_tests/eig_op_test.py
@@ -55,7 +55,7 @@
@test_util.run_deprecated_v1
def testConcurrentExecutesWithoutError(self):
all_ops = []
- with self.session(use_gpu=True) as sess:
+ with self.session():
for compute_v_ in True, False:
matrix1 = random_ops.random_normal([5, 5], seed=42)
matrix2 = random_ops.random_normal([5, 5], seed=42)
@@ -84,7 +84,7 @@
"self_adjoint_eig_fail_if_denorms_flushed.txt")).astype(np.float32)
self.assertEqual(matrix.shape, (32, 32))
matrix_tensor = constant_op.constant(matrix)
- with self.session(use_gpu=True) as _:
+ with self.session() as _:
(e, v) = self.evaluate(linalg_ops.self_adjoint_eig(matrix_tensor))
self.assertEqual(e.size, 32)
self.assertAllClose(
@@ -166,7 +166,7 @@
a = RandomInput()
np_e, np_v = np.linalg.eig(a)
- with self.session(use_gpu=True):
+ with self.session():
if compute_v_:
tf_e, tf_v = linalg_ops.eig(constant_op.constant(a))
@@ -222,7 +222,7 @@
tol = 1e-2
else:
tol = 1e-7
- with self.session(use_gpu=True):
+ with self.session():
def Compute(x):
e, v = linalg_ops.eig(x)
diff --git a/tensorflow/python/kernel_tests/embedding_ops_test.py b/tensorflow/python/kernel_tests/embedding_ops_test.py
index e1a5086..917d7ae 100644
--- a/tensorflow/python/kernel_tests/embedding_ops_test.py
+++ b/tensorflow/python/kernel_tests/embedding_ops_test.py
@@ -1048,7 +1048,7 @@
@test_util.run_deprecated_v1
def testCint32Gpu(self):
- with self.session(use_gpu=True):
+ with self.session():
indices = [
ops.convert_to_tensor([0, 1, 2]),
ops.convert_to_tensor([2, 3])
@@ -1076,7 +1076,7 @@
@test_util.run_deprecated_v1
def testInt32Gpu(self):
- with self.session(use_gpu=True):
+ with self.session():
indices = [
ops.convert_to_tensor([0, 1, 2]),
ops.convert_to_tensor([2, 3])
diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py
index 84a95934..aa2cf66 100644
--- a/tensorflow/python/kernel_tests/functional_ops_test.py
+++ b/tensorflow/python/kernel_tests/functional_ops_test.py
@@ -340,7 +340,7 @@
lambda elem_, input_: (a, b), elems, initializer=(0., 0.))
loss = l0 + array_ops.stop_gradient(l1)
grad = gradients_impl.gradients(ys=[loss], xs=[a, b])
- with self.test_session(use_gpu=True) as sess:
+ with self.test_session():
self.evaluate(variables.global_variables_initializer())
self.evaluate(grad)
@@ -933,7 +933,7 @@
def ReturnsTooManyArgs(unused_i, v):
return v, v
- with self.test_session(use_gpu=True):
+ with self.test_session():
with self.assertRaisesRegex(errors.InvalidArgumentError,
"must be a scalar"):
functional_ops.For([0], 10, 1, [0.0], Foo)[0].eval()
diff --git a/tensorflow/python/kernel_tests/gather_nd_op_test.py b/tensorflow/python/kernel_tests/gather_nd_op_test.py
index 026683d..15b1e21 100644
--- a/tensorflow/python/kernel_tests/gather_nd_op_test.py
+++ b/tensorflow/python/kernel_tests/gather_nd_op_test.py
@@ -39,7 +39,7 @@
class GatherNdTest(test.TestCase):
def _testSimpleDtype(self, dtype):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
params = constant_op.constant(np.array([8, 1, 2, 3, 7, 5], dtype=dtype))
indices = constant_op.constant([[4], [4], [0]])
gather_nd_t = array_ops.gather_nd(params, indices)
@@ -60,7 +60,7 @@
@test_util.run_deprecated_v1
@test_util.disable_xla("b/123337890") # Error messages differ
def testEmptyIndicesAndParamsOKButJustEmptyParamsFails(self):
- with self.session(use_gpu=True):
+ with self.session():
params = np.ones((3, 3), dtype=np.float32)
indices_empty = np.empty((0, 2), dtype=np.int32)
@@ -91,7 +91,7 @@
self.assertAllClose(np.empty((0,), dtype=np.float32), gather_nd_ok_val)
def testIndexScalar(self):
- with self.session(use_gpu=True):
+ with self.session():
params = np.array(
[[-8, -1, -2, -3, -7, -5], [8, 1, 2, 3, 7, 5]], dtype=np.float32).T
indices = constant_op.constant([4, 1])
@@ -101,7 +101,7 @@
self.assertAllEqual(np.array(7), gather_nd_val)
def testParamsRankLargerThanIndexIndexScalarSlices(self):
- with self.session(use_gpu=True):
+ with self.session():
params = np.array(
[[-8, -1, -2, -3, -7, -5], [8, 1, 2, 3, 7, 5]], dtype=np.float32).T
indices = constant_op.constant([4])
@@ -111,7 +111,7 @@
self.assertAllEqual(np.array([-7, 7]), gather_nd_val)
def testParamsRankLargerThanIndexSlices(self):
- with self.session(use_gpu=True):
+ with self.session():
params = np.array(
[[-8, -1, -2, -3, -7, -5], [8, 1, 2, 3, 7, 5]], dtype=np.float32).T
indices = constant_op.constant([[4], [4], [0]])
@@ -122,7 +122,7 @@
self.assertAllEqual(np.array([[-7, 7], [-7, 7], [-8, 8]]), gather_nd_val)
def testHigherRankParamsLargerThanIndexSlices(self):
- with self.session(use_gpu=True):
+ with self.session():
params = np.array(
[[[-8, -1, -2, -3, -7, -5], [8, 1, 2, 3, 7, 5]],
[[-80, -10, -20, -30, -70, -50], [80, 10, 20, 30, 70, 50]]],
@@ -136,7 +136,7 @@
self.assertAllEqual(params[[4, 4, 0]], gather_nd_val)
def testEmptyIndicesLastRankMeansCopyEntireTensor(self):
- with self.session(use_gpu=True):
+ with self.session():
params = np.array(
[[[-8, -1, -2, -3, -7, -5], [8, 1, 2, 3, 7, 5]],
[[-80, -10, -20, -30, -70, -50], [80, 10, 20, 30, 70, 50]]],
@@ -153,7 +153,7 @@
gather_nd_val)
def testHigherRankParamsAndIndicesLargerThanIndexSlices(self):
- with self.session(use_gpu=True):
+ with self.session():
params = np.array(
[[[-8, -1, -2, -3, -7, -5], [8, 1, 2, 3, 7, 5]],
[[-80, -10, -20, -30, -70, -50], [80, 10, 20, 30, 70, 50]]],
@@ -168,7 +168,7 @@
gather_nd_val)
def testHigherRankParams(self):
- with self.session(use_gpu=True):
+ with self.session():
shape = (10, 20, 5, 1, 17)
params = np.random.rand(*shape)
indices = np.vstack([np.random.randint(0, s, size=2000) for s in shape]).T
@@ -180,7 +180,7 @@
self.assertEqual([2000], gather_nd_t.get_shape())
def testHigherRankParamsAndIndices(self):
- with self.session(use_gpu=True):
+ with self.session():
shape = (10, 20, 5, 1, 17)
params = np.random.rand(*shape)
indices = np.vstack([np.random.randint(0, s, size=2000) for s in shape]).T
@@ -220,7 +220,7 @@
# On GPU the bad indices do not raise error but fetch 0 values
if not test.is_gpu_available():
return
- with self.session(use_gpu=True):
+ with self.session():
params = [0, 1, 2]
indices = [[[0], [7]]] # Make this one higher rank
gather_nd = array_ops.gather_nd(params, indices)
@@ -244,7 +244,7 @@
# On GPU the bad indices do not raise error but fetch 0 values
if not test.is_gpu_available():
return
- with self.session(use_gpu=True):
+ with self.session():
params = [[0, 1, 2]]
indices = [[[0], [0], [1]]] # Make this one higher rank
gather_nd = array_ops.gather_nd(params, indices)
@@ -261,7 +261,7 @@
grad_vals = constant_op.constant([1, 2], dtype=dtypes.float64)
grads = gradients_impl.gradients([outputs], [inputs], [grad_vals])[0]
expected_grads = np.array([[1, 0], [0, 2]], dtype=np.float64)
- with self.session(use_gpu=True):
+ with self.session():
assert np.array_equal(expected_grads, self.evaluate(grads))
@test_util.run_deprecated_v1
@@ -273,7 +273,7 @@
grad_vals = constant_op.constant([[1, 2], [3, 4]], dtype=dtypes.float64)
grads = gradients_impl.gradients([outputs], [inputs], [grad_vals])[0]
expected_grads = np.array([[3, 4], [1, 2]], dtype=np.float64)
- with self.session(use_gpu=True):
+ with self.session():
self.assertIndexedSlices(grads)
self.assertAllEqual(expected_grads, ops.convert_to_tensor(grads))
@@ -290,7 +290,7 @@
grads = gradients_impl.gradients([outputs], [inputs], [grad_vals])[0]
expected_grads = np.array(
[[[5, 6], [1, 2]], [[3, 4], [7, 8]]], dtype=np.float64)
- with self.session(use_gpu=True):
+ with self.session():
self.assertAllEqual(expected_grads, self.evaluate(grads))
@test_util.run_deprecated_v1
@@ -320,7 +320,7 @@
[[[[5, 6], [1, 2]]]],
[[[[3, 4], [7, 8]]]]
]]], dtype=np.float64)
- with self.session(use_gpu=True):
+ with self.session():
self.assertAllEqual(expected_grads, self.evaluate(grads))
@test_util.run_deprecated_v1
@@ -336,7 +336,7 @@
grads = gradients_impl.gradients([outputs], [inputs], [grad_vals])[0]
expected_grads = np.array(
[[[5, 6], [1, 2]], [[3, 4], [7, 8]]], dtype=np.float64)
- with self.session(use_gpu=True):
+ with self.session():
self.assertAllEqual(expected_grads, self.evaluate(grads))
@test_util.run_deprecated_v1
@@ -358,7 +358,7 @@
[1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0], [3, 3, 3, 3, 3, 3, 3, 3, 3]],
dtype=np.float64)
- with self.session(use_gpu=True):
+ with self.session():
self.assertIndexedSlices(grads)
self.assertAllEqual(expected_grads, ops.convert_to_tensor(grads))
diff --git a/tensorflow/python/kernel_tests/in_topk_op_test.py b/tensorflow/python/kernel_tests/in_topk_op_test.py
index c636cee..be3fee3 100644
--- a/tensorflow/python/kernel_tests/in_topk_op_test.py
+++ b/tensorflow/python/kernel_tests/in_topk_op_test.py
@@ -29,7 +29,7 @@
def _validateInTopK(self, predictions, target, k, expected):
np_ans = np.array(expected, np.bool)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
precision = nn_ops.in_top_k(predictions, target, k)
out = self.evaluate(precision)
self.assertAllClose(np_ans, out)
diff --git a/tensorflow/python/kernel_tests/init_ops_test.py b/tensorflow/python/kernel_tests/init_ops_test.py
index f2348c6..898d6f3 100644
--- a/tensorflow/python/kernel_tests/init_ops_test.py
+++ b/tensorflow/python/kernel_tests/init_ops_test.py
@@ -102,7 +102,7 @@
"""
def func():
- with tc.test_session(use_gpu=True):
+ with tc.test_session():
return init([num]).eval()
return func
@@ -112,7 +112,7 @@
@test_util.run_deprecated_v1
def testZerosInitializer(self):
- with self.session(use_gpu=True):
+ with self.session():
shape = [2, 3]
x = variable_scope.get_variable(
"x", shape=shape, initializer=init_ops.zeros_initializer())
@@ -121,7 +121,7 @@
@test_util.run_deprecated_v1
def testOnesInitializer(self):
- with self.session(use_gpu=True):
+ with self.session():
shape = [2, 3]
x = variable_scope.get_variable(
"x", shape=shape, initializer=init_ops.ones_initializer())
@@ -130,7 +130,7 @@
@test_util.run_deprecated_v1
def testConstantZeroInitializer(self):
- with self.session(use_gpu=True):
+ with self.session():
shape = [2, 3]
x = variable_scope.get_variable(
"x", shape=shape, initializer=init_ops.constant_initializer(0.0))
@@ -139,7 +139,7 @@
@test_util.run_deprecated_v1
def testConstantOneInitializer(self):
- with self.session(use_gpu=True):
+ with self.session():
shape = [2, 3]
x = variable_scope.get_variable(
"x", shape=shape, initializer=init_ops.constant_initializer(1.0))
@@ -148,7 +148,7 @@
@test_util.run_deprecated_v1
def testConstantIntInitializer(self):
- with self.session(use_gpu=True):
+ with self.session():
shape = [2, 3]
x = variable_scope.get_variable(
"x",
@@ -161,7 +161,7 @@
@test_util.run_deprecated_v1
def testConstantTupleInitializer(self):
- with self.session(use_gpu=True):
+ with self.session():
shape = [3]
x = variable_scope.get_variable(
"x",
@@ -173,7 +173,7 @@
self.assertAllEqual(x, [10, 20, 30])
def _testNDimConstantInitializer(self, name, value, shape, expected):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
init = init_ops.constant_initializer(value, dtype=dtypes.int32)
x = variable_scope.get_variable(name, shape=shape, initializer=init)
self.evaluate(x.initializer)
@@ -198,7 +198,7 @@
def _testNDimConstantInitializerLessValues(self, name, value, shape,
expected):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
init = init_ops.constant_initializer(value, dtype=dtypes.int32)
x = variable_scope.get_variable(name, shape=shape, initializer=init)
self.evaluate(x.initializer)
@@ -225,7 +225,7 @@
def _testNDimConstantInitializerMoreValues(self, value, shape):
ops.reset_default_graph()
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
init = init_ops.constant_initializer(value, dtype=dtypes.int32)
self.assertRaises(
ValueError,
@@ -398,7 +398,7 @@
init = init_ops.variance_scaling_initializer(
distribution="truncated_normal")
- with self.session(use_gpu=True), \
+ with self.session(), \
test.mock.patch.object(
random_ops, "truncated_normal", wraps=random_ops.truncated_normal) \
as mock_truncated_normal:
@@ -415,7 +415,7 @@
expect_var = 1. / shape[0]
init = init_ops.variance_scaling_initializer(distribution="normal")
- with self.session(use_gpu=True), \
+ with self.session(), \
test.mock.patch.object(
random_ops, "truncated_normal", wraps=random_ops.truncated_normal) \
as mock_truncated_normal:
@@ -433,7 +433,7 @@
init = init_ops.variance_scaling_initializer(
distribution="untruncated_normal")
- with self.session(use_gpu=True), \
+ with self.session(), \
test.mock.patch.object(
random_ops, "random_normal", wraps=random_ops.random_normal) \
as mock_random_normal:
@@ -450,7 +450,7 @@
expect_var = 1. / shape[0]
init = init_ops.variance_scaling_initializer(distribution="uniform")
- with self.session(use_gpu=True):
+ with self.session():
x = init(shape).eval()
self.assertNear(np.mean(x), expect_mean, err=1e-2)
@@ -461,7 +461,7 @@
class RangeTest(test.TestCase):
def _Range(self, start, limit, delta):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
tf_ans = math_ops.range(start, limit, delta, name="range")
self.assertEqual([len(np.arange(start, limit, delta))],
tf_ans.get_shape())
@@ -481,7 +481,7 @@
@test_util.run_deprecated_v1
def testLimitOnly(self):
- with self.session(use_gpu=True):
+ with self.session():
self.assertAllEqual(np.arange(5), math_ops.range(5))
def testEmpty(self):
@@ -910,7 +910,7 @@
outputs_2norm = linalg_ops.norm(outputs)
ratio = outputs_2norm / inputs_2norm
my_ops = variables.global_variables_initializer()
- with self.session(use_gpu=True) as sess:
+ with self.session():
self.evaluate(my_ops)
# Check the shape of the outputs
t = self.evaluate(outputs)
@@ -925,7 +925,7 @@
shape = [3, 3, 10, 10]
count = 70
tol = 1e-5
- with self.session(use_gpu=True):
+ with self.session():
for i in range(count):
x = variable_scope.get_variable(
"{}".format(i),
@@ -996,7 +996,7 @@
shape = [3, 10, 10]
count = 70
tol = 1e-5
- with self.session(use_gpu=True):
+ with self.session():
for i in range(count):
x = variable_scope.get_variable(
"{}".format(i),
@@ -1063,7 +1063,7 @@
outputs_2norm = linalg_ops.norm(outputs)
ratio = outputs_2norm / inputs_2norm
my_ops = variables.global_variables_initializer()
- with self.session(use_gpu=True) as sess:
+ with self.session():
self.evaluate(my_ops)
# Check the shape of the outputs
t = self.evaluate(outputs)
@@ -1167,7 +1167,7 @@
outputs_2norm = linalg_ops.norm(outputs)
ratio = outputs_2norm / inputs_2norm
my_ops = variables.global_variables_initializer()
- with self.session(use_gpu=True) as sess:
+ with self.session():
self.evaluate(my_ops)
# Check the shape of the outputs
t = self.evaluate(outputs)
@@ -1227,7 +1227,7 @@
shape = [3, 3, 3, 5, 5]
count = 20
tol = 1e-5
- with self.session(use_gpu=True):
+ with self.session():
for i in range(count):
x = variable_scope.get_variable(
"{}".format(i),
@@ -1302,7 +1302,7 @@
outputs_2norm = linalg_ops.norm(outputs)
ratio = outputs_2norm / inputs_2norm
my_ops = variables.global_variables_initializer()
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
self.evaluate(my_ops)
# Check the shape of the outputs
t = self.evaluate(outputs)
diff --git a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_dense_mat_mul_grad_test.py b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_dense_mat_mul_grad_test.py
index 4841c18..d39f2e9 100644
--- a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_dense_mat_mul_grad_test.py
+++ b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_dense_mat_mul_grad_test.py
@@ -78,7 +78,7 @@
b_mats_val = np.transpose(b_mats_val, (0, 2, 1))
if adjoint_b:
b_mats_val = np.conj(b_mats_val)
- with self.test_session(use_gpu=True):
+ with self.test_session():
a_mats = ops.convert_to_tensor(a_mats_val, dtype=datatype)
b_mats = ops.convert_to_tensor(b_mats_val, dtype=datatype)
a_sm = dense_to_csr_sparse_matrix(a_mats)
diff --git a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_grad_test.py b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_grad_test.py
index 0cda66a..c548ced 100644
--- a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_grad_test.py
+++ b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_grad_test.py
@@ -64,7 +64,7 @@
sparsify = lambda m: m * (m > 0)
for dense_shape in ([53, 65, 127], [127, 65]):
mats_val = sparsify(np.random.randn(*dense_shape))
- with self.test_session(use_gpu=True) as sess:
+ with self.test_session() as sess:
mats = math_ops.cast(mats_val, dtype=dtypes.float32)
sparse_mats = dense_to_csr_sparse_matrix(mats)
dense_mats = sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(
@@ -96,7 +96,7 @@
grad_vals = np.random.randn(*dense_shape).astype(np.float32)
expected_a_grad = alpha * grad_vals
expected_b_grad = beta * grad_vals
- with self.test_session(use_gpu=True) as sess:
+ with self.test_session() as sess:
a_mats = math_ops.cast(a_mats_val, dtype=dtypes.float32)
b_mats = math_ops.cast(b_mats_val, dtype=dtypes.float32)
a_sm = dense_to_csr_sparse_matrix(a_mats)
diff --git a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_sparse_mat_mul_grad_test.py b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_sparse_mat_mul_grad_test.py
index 07d1e6a..27bedc0 100644
--- a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_sparse_mat_mul_grad_test.py
+++ b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_sparse_mat_mul_grad_test.py
@@ -79,7 +79,7 @@
b_mats_val = np.transpose(b_mats_val, (0, 2, 1))
if adjoint_b:
b_mats_val = np.conj(b_mats_val)
- with self.test_session(use_gpu=True):
+ with self.test_session():
a_mats = ops.convert_to_tensor(a_mats_val, dtype=datatype)
b_mats = ops.convert_to_tensor(b_mats_val, dtype=datatype)
a_sm = dense_to_csr_sparse_matrix(a_mats)
diff --git a/tensorflow/python/kernel_tests/linalg_ops_test.py b/tensorflow/python/kernel_tests/linalg_ops_test.py
index 2cdddda..eb42917 100644
--- a/tensorflow/python/kernel_tests/linalg_ops_test.py
+++ b/tensorflow/python/kernel_tests/linalg_ops_test.py
@@ -59,7 +59,7 @@
def test_works_with_five_different_random_pos_def_matrices(self):
for n in range(1, 6):
for np_type, atol in [(np.float32, 0.05), (np.float64, 1e-5)]:
- with self.session(use_gpu=True):
+ with self.session():
# Create 2 x n x n matrix
array = np.array(
[_RandomPDMatrix(n, self.rng),
@@ -85,7 +85,7 @@
with self.subTest(n=n, np_dtype=np_dtype, atol=atol):
matrix = _RandomPDMatrix(n, self.rng, np_dtype)
_, logdet_np = np.linalg.slogdet(matrix)
- with self.session(use_gpu=True):
+ with self.session():
# Create 2 x n x n matrix
# matrix = np.array(
# [_RandomPDMatrix(n, self.rng, np_dtype),
@@ -99,7 +99,7 @@
with self.subTest(np_dtype=np_dtype, atol=atol):
matrix = (np.eye(20) * 1e-6).astype(np_dtype)
_, logdet_np = np.linalg.slogdet(matrix)
- with self.session(use_gpu=True):
+ with self.session():
logdet_tf = linalg.logdet(matrix)
self.assertAllClose(logdet_np, self.evaluate(logdet_tf), atol=atol)
@@ -117,7 +117,7 @@
with self.subTest(n=n, np_dtype=np_dtype, atol=atol):
matrix = _RandomPDMatrix(n, self.rng, np_dtype)
sign_np, log_abs_det_np = np.linalg.slogdet(matrix)
- with self.session(use_gpu=True):
+ with self.session():
sign_tf, log_abs_det_tf = linalg.slogdet(matrix)
self.assertAllClose(
log_abs_det_np, self.evaluate(log_abs_det_tf), atol=atol)
@@ -129,7 +129,7 @@
with self.subTest(np_dtype=np_dtype, atol=atol):
matrix = (np.eye(20) * 1e-6).astype(np_dtype)
sign_np, log_abs_det_np = np.linalg.slogdet(matrix)
- with self.session(use_gpu=True):
+ with self.session():
sign_tf, log_abs_det_tf = linalg.slogdet(matrix)
self.assertAllClose(
log_abs_det_np, self.evaluate(log_abs_det_tf), atol=atol)
@@ -259,7 +259,7 @@
num_columns=num_columns_placeholder,
batch_shape=batch_shape_placeholder,
dtype=dtype)
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
eye_tf = sess.run(
eye,
feed_dict={
diff --git a/tensorflow/python/kernel_tests/lrn_op_test.py b/tensorflow/python/kernel_tests/lrn_op_test.py
index fbe628c..f548804 100644
--- a/tensorflow/python/kernel_tests/lrn_op_test.py
+++ b/tensorflow/python/kernel_tests/lrn_op_test.py
@@ -55,7 +55,7 @@
return output
def _RunAndVerify(self, dtype):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
# random shape
shape = np.random.randint(1, 16, size=4)
# Make depth at least 2 to make it meaningful
@@ -103,7 +103,7 @@
@test_util.run_deprecated_v1
def testGradientsZeroInput(self):
- with self.session(use_gpu=True):
+ with self.session():
shape = [4, 4, 4, 4]
p = array_ops.placeholder(dtypes.float32, shape=shape)
inp_array = np.zeros(shape).astype("f")
@@ -116,7 +116,7 @@
self.assertShapeEqual(expected, grad)
def _RunAndVerifyGradients(self, dtype):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
# random shape
shape = np.random.randint(1, 5, size=4)
# Make depth at least 2 to make it meaningful
diff --git a/tensorflow/python/kernel_tests/manip_ops_test.py b/tensorflow/python/kernel_tests/manip_ops_test.py
index 2e43d4a..1b8319a 100644
--- a/tensorflow/python/kernel_tests/manip_ops_test.py
+++ b/tensorflow/python/kernel_tests/manip_ops_test.py
@@ -42,12 +42,12 @@
def _testRoll(self, np_input, shift, axis):
expected_roll = np.roll(np_input, shift, axis)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
roll = manip_ops.roll(np_input, shift, axis)
self.assertAllEqual(roll, expected_roll)
def _testGradient(self, np_input, shift, axis):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
inx = constant_op.constant(np_input.tolist())
xs = list(np_input.shape)
y = manip_ops.roll(inx, shift, axis)
@@ -98,7 +98,7 @@
self._testAll(np.random.randint(-100, 100, (5)).astype(np.int32), 3, -1)
self._testAll(np.random.randint(-100, 100, (4, 4)).astype(np.int32), 3, -2)
# Make sure negative axis should be 0 <= axis + dims < dims
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
"is out of range"):
manip_ops.roll(np.random.randint(-100, 100, (4, 4)).astype(np.int32),
@@ -122,7 +122,7 @@
tensor = array_ops.placeholder(dtype=dtypes.int32)
shift = 1
axis = 0
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
"input must be 1-D or higher"):
manip_ops.roll(tensor, shift, axis).eval(feed_dict={tensor: 7})
@@ -140,7 +140,7 @@
tensor = [[1, 2], [3, 4]]
shift = 1
axis = array_ops.placeholder(dtype=dtypes.int32)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
"axis must be a scalar or a 1-D vector"):
manip_ops.roll(tensor, shift, axis).eval(feed_dict={axis: [[0, 1]]})
@@ -158,7 +158,7 @@
tensor = [[1, 2], [3, 4]]
shift = array_ops.placeholder(dtype=dtypes.int32)
axis = 1
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
"shift must be a scalar or a 1-D vector"):
manip_ops.roll(tensor, shift, axis).eval(feed_dict={shift: [[0, 1]]})
@@ -175,7 +175,7 @@
tensor = [[1, 2], [3, 4]]
shift = array_ops.placeholder(dtype=dtypes.int32)
axis = [0, 1]
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
"shift and axis must have the same size"):
manip_ops.roll(tensor, shift, axis).eval(feed_dict={shift: [1]})
@@ -184,7 +184,7 @@
tensor = [1, 2]
shift = 1
axis = 1
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
"is out of range"):
manip_ops.roll(tensor, shift, axis).eval()
diff --git a/tensorflow/python/kernel_tests/map_stage_op_test.py b/tensorflow/python/kernel_tests/map_stage_op_test.py
index dd16fad..516fc37 100644
--- a/tensorflow/python/kernel_tests/map_stage_op_test.py
+++ b/tensorflow/python/kernel_tests/map_stage_op_test.py
@@ -46,7 +46,7 @@
G.finalize()
- with self.session(use_gpu=True, graph=G) as sess:
+ with self.session(graph=G) as sess:
sess.run(stage, feed_dict={x: -1, pi: 0})
for i in range(10):
_, yval = sess.run([stage, y], feed_dict={x: i, pi: i + 1, gi: i})
@@ -68,7 +68,7 @@
G.finalize()
- with self.session(use_gpu=True, graph=G) as sess:
+ with self.session(graph=G) as sess:
sess.run(stage, feed_dict={x: -1, pi: 0})
for i in range(10):
_, yval = sess.run([stage, y], feed_dict={x: i, pi: i + 1, gi: i})
@@ -96,7 +96,7 @@
G.finalize()
- with self.session(use_gpu=True, graph=G) as sess:
+ with self.session(graph=G) as sess:
sess.run(stage, feed_dict={x: -1, pi: 0})
for i in range(10):
_, yval = sess.run([stage, y], feed_dict={x: i, pi: i + 1, gi: i})
@@ -146,7 +146,7 @@
n = 10
- with self.session(use_gpu=True, graph=G) as sess:
+ with self.session(graph=G) as sess:
for i in range(n):
sess.run(stage, feed_dict={x: i, pi: i})
@@ -174,7 +174,7 @@
G.finalize()
- with self.session(use_gpu=True, graph=G) as sess:
+ with self.session(graph=G) as sess:
sess.run(stage, feed_dict={x: -1, pi: 3})
self.assertEqual(sess.run(size), 1)
sess.run(stage, feed_dict={x: -1, pi: 1})
@@ -209,7 +209,7 @@
queue = Queue.Queue()
n = 8
- with self.session(use_gpu=True, graph=G) as sess:
+ with self.session(graph=G) as sess:
# Stage data in a separate thread which will block
# when it hits the staging area's capacity and thus
# not fill the queue with n tokens
@@ -273,7 +273,7 @@
queue = Queue.Queue()
n = 8
- with self.session(use_gpu=True, graph=G) as sess:
+ with self.session(graph=G) as sess:
# Stage data in a separate thread which will block
# when it hits the staging area's capacity and thus
# not fill the queue with n tokens
@@ -334,7 +334,7 @@
n = 10
- with self.session(use_gpu=True, graph=G) as sess:
+ with self.session(graph=G) as sess:
# Keys n-1..0
keys = list(reversed(six.moves.range(n)))
@@ -372,7 +372,7 @@
G.finalize()
- with self.session(use_gpu=True, graph=G) as sess:
+ with self.session(graph=G) as sess:
# 0 complete and incomplete entries
self.assertTrue(sess.run([size, isize]) == [0, 0])
# Stage key 0, x and f tuple entries
@@ -430,7 +430,7 @@
G.finalize()
- with self.session(use_gpu=True, graph=G) as sess:
+ with self.session(graph=G) as sess:
# 0 complete and incomplete entries
self.assertTrue(sess.run([size, isize]) == [0, 0])
# Stage key 0, x and f tuple entries
@@ -482,7 +482,7 @@
G.finalize()
- with self.session(use_gpu=True, graph=G) as sess:
+ with self.session(graph=G) as sess:
# 0 complete and incomplete entries
self.assertTrue(sess.run([size, isize]) == [0, 0])
# Stage key 0, x and f tuple entries
@@ -574,7 +574,7 @@
G.finalize()
- with self.session(use_gpu=True, graph=G) as sess:
+ with self.session(graph=G) as sess:
# Stage complete tuple
sess.run(stage_xvf, feed_dict={pi: 0, x: 1, f: 2, v: 3})
diff --git a/tensorflow/python/kernel_tests/matrix_exponential_op_test.py b/tensorflow/python/kernel_tests/matrix_exponential_op_test.py
index 61e2610..19091a7 100644
--- a/tensorflow/python/kernel_tests/matrix_exponential_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_exponential_op_test.py
@@ -149,7 +149,7 @@
@test_util.run_deprecated_v1
def testDynamic(self):
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
inp = array_ops.placeholder(ops.dtypes.float32)
expm = linalg_impl.matrix_exponential(inp)
matrix = np.array([[1., 2.], [3., 4.]])
@@ -157,7 +157,7 @@
@test_util.run_deprecated_v1
def testConcurrentExecutesWithoutError(self):
- with self.session(use_gpu=True) as sess:
+ with self.session():
matrix1 = random_ops.random_normal([5, 5], seed=42)
matrix2 = random_ops.random_normal([5, 5], seed=42)
expm1 = linalg_impl.matrix_exponential(matrix1)
diff --git a/tensorflow/python/kernel_tests/matrix_inverse_op_test.py b/tensorflow/python/kernel_tests/matrix_inverse_op_test.py
index 9a5a467..eebd568 100644
--- a/tensorflow/python/kernel_tests/matrix_inverse_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_inverse_op_test.py
@@ -37,7 +37,7 @@
def _verifyInverse(self, x, np_type):
for adjoint in False, True:
y = x.astype(np_type)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
# Verify that x^{-1} * x == Identity matrix.
inv = linalg_ops.matrix_inverse(y, adjoint=adjoint)
tf_ans = test_util.matmul_without_tf32(inv, y, adjoint_b=adjoint)
@@ -139,7 +139,7 @@
@test_util.deprecated_graph_mode_only
def testConcurrentExecutesWithoutError(self):
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
all_ops = []
for adjoint_ in True, False:
matrix1 = random_ops.random_normal([5, 5], seed=42)
diff --git a/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py b/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py
index d2e9c7c..e75d0df 100644
--- a/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py
@@ -124,7 +124,7 @@
feed_dict = None
self.assertEqual(np_ans.shape, tf_ans.get_shape())
if feed_dict:
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
tf_ans_val = sess.run(tf_ans, feed_dict=feed_dict)
else:
tf_ans_val = self.evaluate(tf_ans)
@@ -137,7 +137,7 @@
tf_r = math_ops.matmul(a, tf_r, adjoint_a=True)
tf_r_norm = linalg_ops.norm(tf_r, ord="fro", axis=[-2, -1])
if feed_dict:
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
tf_ans_val, tf_r_norm_val = sess.run([tf_ans, tf_r_norm],
feed_dict=feed_dict)
else:
@@ -147,7 +147,7 @@
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
def testWrongDimensions(self):
# The matrix and right-hand sides should have the same number of rows.
- with self.session(use_gpu=True):
+ with self.session():
matrix = constant_op.constant([[1., 0.], [0., 1.]])
rhs = constant_op.constant([[1., 0.]])
with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
diff --git a/tensorflow/python/kernel_tests/matrix_solve_op_test.py b/tensorflow/python/kernel_tests/matrix_solve_op_test.py
index 209e604..0d149de 100644
--- a/tensorflow/python/kernel_tests/matrix_solve_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_solve_op_test.py
@@ -63,7 +63,7 @@
a_ph = array_ops.placeholder(dtypes.as_dtype(np_type))
b_ph = array_ops.placeholder(dtypes.as_dtype(np_type))
tf_ans = linalg_ops.matrix_solve(a_ph, b_ph, adjoint=adjoint)
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
out = sess.run(tf_ans, {a_ph: a, b_ph: b})
else:
tf_ans = linalg_ops.matrix_solve(a, b, adjoint=adjoint)
diff --git a/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py b/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py
index a497a0d..2c85b1d 100644
--- a/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py
@@ -195,7 +195,7 @@
def testNonSquareMatrix(self):
# A non-square matrix should cause an error.
matrix = np.array([[1., 2., 3.], [3., 4., 5.]])
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
with self.assertRaises(ValueError):
self._verifySolve(matrix, matrix)
with self.assertRaises(ValueError):
@@ -207,7 +207,7 @@
# right-hand sides.
matrix = np.array([[1., 0.], [0., 1.]])
rhs = np.array([[1., 0.]])
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
with self.assertRaises(ValueError):
self._verifySolve(matrix, rhs)
with self.assertRaises(ValueError):
diff --git a/tensorflow/python/kernel_tests/norm_op_test.py b/tensorflow/python/kernel_tests/norm_op_test.py
index f378719..ff32a58 100644
--- a/tensorflow/python/kernel_tests/norm_op_test.py
+++ b/tensorflow/python/kernel_tests/norm_op_test.py
@@ -68,7 +68,7 @@
def _CompareNorm(self, matrix):
np_norm = np.linalg.norm(matrix, ord=ord_, axis=axis_, keepdims=keep_dims_)
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
if use_static_shape_:
tf_matrix = constant_op.constant(matrix)
tf_norm = linalg_ops.norm(
diff --git a/tensorflow/python/kernel_tests/pad_op_test.py b/tensorflow/python/kernel_tests/pad_op_test.py
index 30abf9a..6372188 100644
--- a/tensorflow/python/kernel_tests/pad_op_test.py
+++ b/tensorflow/python/kernel_tests/pad_op_test.py
@@ -372,7 +372,7 @@
for dtype in [dtypes.int32, dtypes.int64]:
paddings = np.zeros((0, 2))
inp = np.asarray(7)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
tf_val = array_ops.pad(inp, constant_op.constant(paddings, dtype=dtype))
out = self.evaluate(tf_val)
self.assertAllEqual(inp, out)
@@ -397,7 +397,7 @@
padded,
[paddings_value[i][0] + inp.shape.dims[i].value for i in range(4)],
[-1, -1, -1, -1])
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
self.assertAllEqual(inp, self.evaluate(middle))
self.assertAllEqual(
np.zeros([row[0] for row in paddings_value]), self.evaluate(left))
diff --git a/tensorflow/python/kernel_tests/pool_test.py b/tensorflow/python/kernel_tests/pool_test.py
index 0e6bbeb..cb408ae 100644
--- a/tensorflow/python/kernel_tests/pool_test.py
+++ b/tensorflow/python/kernel_tests/pool_test.py
@@ -248,7 +248,7 @@
def testPoolNC(self):
if test.is_gpu_available(cuda_only=True):
# "NC*" format is currently only supported on CUDA.
- with self.session(use_gpu=True):
+ with self.session():
for padding in ["SAME", "VALID"]:
self._test(
input_shape=[2, 2, 9],
diff --git a/tensorflow/python/kernel_tests/pooling_ops_test.py b/tensorflow/python/kernel_tests/pooling_ops_test.py
index 98e043e..73dca4f 100644
--- a/tensorflow/python/kernel_tests/pooling_ops_test.py
+++ b/tensorflow/python/kernel_tests/pooling_ops_test.py
@@ -906,7 +906,7 @@
self._testDepthwiseMaxPoolInvalidConfig([1, 2, 2, 4], [1, 1, 1, 3],
[1, 1, 1, 3], "evenly divide")
if test.is_gpu_available():
- with self.session(use_gpu=True):
+ with self.session():
t = variables.Variable(np.ones([1, 2, 2, 4]))
self.evaluate(variables.global_variables_initializer())
with self.assertRaisesOpError("for CPU devices"):
@@ -922,7 +922,7 @@
for dtype in [np.float32, np.float16] \
+ [np.float64] if not test.is_built_with_rocm() else []:
tensor_input = np.random.rand(*input_shape).astype(dtype)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
t = constant_op.constant(tensor_input, shape=input_shape)
out_op, _ = nn_ops.max_pool_with_argmax(t, ksize, strides, padding)
gpu_val = self.evaluate(out_op)
@@ -942,7 +942,7 @@
# in the input.
tensor_input = np.random.random_integers(0, 3, input_shape).astype(dtype)
tensor_output = np.random.rand(*output_shape).astype(dtype)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
t = constant_op.constant(tensor_input, shape=input_shape)
_, argmax_op = nn_ops.max_pool_with_argmax(t, ksize, strides, padding)
argmax = self.evaluate(argmax_op)
diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py
index b374119..d924e65 100644
--- a/tensorflow/python/kernel_tests/py_func_test.py
+++ b/tensorflow/python/kernel_tests/py_func_test.py
@@ -597,6 +597,7 @@
self.assertIsNone(ret)
@test_util.run_in_graph_and_eager_modes
+ @test_util.disable_tfrt("b/180469928")
def testEagerPyFuncInDefun(self):
with test_util.device(use_gpu=True):
def wrapper():
@@ -755,7 +756,7 @@
y = script_ops.eager_py_func(func=f, inp=[x], Tout=dtypes.float32)
z = script_ops.eager_py_func(func=g, inp=[y], Tout=dtypes.float32)
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
output = sess.run(z, feed_dict={x: 3.0})
self.assertEqual(output, 18.0)
diff --git a/tensorflow/python/kernel_tests/qr_op_test.py b/tensorflow/python/kernel_tests/qr_op_test.py
index 7804aa7..720a4d7 100644
--- a/tensorflow/python/kernel_tests/qr_op_test.py
+++ b/tensorflow/python/kernel_tests/qr_op_test.py
@@ -145,7 +145,7 @@
if use_static_shape_:
q_tf_val, r_tf_val = self.evaluate([q_tf, r_tf])
else:
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
q_tf_val, r_tf_val = sess.run([q_tf, r_tf], feed_dict={x_tf: x_np})
q_dims = q_tf_val.shape
diff --git a/tensorflow/python/kernel_tests/random/multinomial_op_big_test.py b/tensorflow/python/kernel_tests/random/multinomial_op_big_test.py
index 5767205..2bf15db 100644
--- a/tensorflow/python/kernel_tests/random/multinomial_op_big_test.py
+++ b/tensorflow/python/kernel_tests/random/multinomial_op_big_test.py
@@ -34,7 +34,7 @@
def testLargeDynamicRange(self):
random_seed.set_random_seed(10)
counts_by_indices = {}
- with self.test_session(use_gpu=True) as sess:
+ with self.test_session():
samples = random_ops.multinomial(
constant_op.constant([[-30, 0]], dtype=dtypes.float32),
num_samples=1000000,
@@ -52,7 +52,7 @@
def testLargeDynamicRange2(self):
random_seed.set_random_seed(10)
counts_by_indices = {}
- with self.test_session(use_gpu=True) as sess:
+ with self.test_session():
samples = random_ops.multinomial(
constant_op.constant([[0, -30]], dtype=dtypes.float32),
num_samples=1000000,
@@ -72,7 +72,7 @@
random_seed.set_random_seed(10)
counts_by_indices = {}
# here the cpu undersamples and won't pass this test either
- with self.test_session(use_gpu=True) as sess:
+ with self.test_session():
samples = random_ops.multinomial(
constant_op.constant([[0, -17]], dtype=dtypes.float32),
num_samples=1000000,
diff --git a/tensorflow/python/kernel_tests/random/parameterized_truncated_normal_op_test.py b/tensorflow/python/kernel_tests/random/parameterized_truncated_normal_op_test.py
index 309c3e4..5ec054f 100644
--- a/tensorflow/python/kernel_tests/random/parameterized_truncated_normal_op_test.py
+++ b/tensorflow/python/kernel_tests/random/parameterized_truncated_normal_op_test.py
@@ -129,7 +129,7 @@
# TruncatedNormalMoments requires scipy.stats.
# Give up early if we are unable to import it.
random_seed.set_random_seed(seed)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
if use_stateless:
# Generate a seed that stateless ops can use.
new_seed = random_ops.random_uniform([2],
@@ -163,7 +163,7 @@
try:
import scipy.stats # pylint: disable=g-import-not-at-top
random_seed.set_random_seed(seed)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
if use_stateless:
new_seed = random_ops.random_uniform([2],
seed=seed,
@@ -298,7 +298,7 @@
minvals=-1.,
maxvals=1.)
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
samples, samples_stateless = sess.run([sample_op, sample_op_stateless])
# 0. is more than 16 standard deviations from the mean, and
# should have a likelihood < 1e-57.
@@ -313,7 +313,7 @@
minval = variables.Variable(-1.)
maxval = variables.Variable(1.)
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
with backprop.GradientTape(persistent=True) as tape:
samples = stateless.stateless_parameterized_truncated_normal(
[1], [1, 2], mean, stddev, minval, maxval)
diff --git a/tensorflow/python/kernel_tests/random/random_ops_test.py b/tensorflow/python/kernel_tests/random/random_ops_test.py
index 135e440..0063c7f 100644
--- a/tensorflow/python/kernel_tests/random/random_ops_test.py
+++ b/tensorflow/python/kernel_tests/random/random_ops_test.py
@@ -230,7 +230,7 @@
@test_util.run_deprecated_v1
def testLargeShape(self):
- with self.session(use_gpu=True):
+ with self.session():
v = variables.Variable(
array_ops.zeros(dtype=dtypes.float32, shape=[2**33, 1]))
n = random_ops.truncated_normal(v.shape)
@@ -238,7 +238,7 @@
@test_util.run_deprecated_v1
def testNoCSE(self):
- with self.session(use_gpu=True):
+ with self.session():
shape = [2, 3, 4]
rnd1 = random_ops.truncated_normal(shape, 0.0, 1.0, dtypes.float32)
rnd2 = random_ops.truncated_normal(shape, 0.0, 1.0, dtypes.float32)
@@ -371,7 +371,7 @@
def testNoCSE(self):
shape = [2, 3, 4]
for dtype in dtypes.float16, dtypes.float32, dtypes.int32:
- with self.session(use_gpu=True):
+ with self.session():
rnd1 = random_ops.random_uniform(shape, 0, 17, dtype=dtype)
rnd2 = random_ops.random_uniform(shape, 0, 17, dtype=dtype)
diff = (rnd2 - rnd1).eval()
diff --git a/tensorflow/python/kernel_tests/random/random_poisson_test.py b/tensorflow/python/kernel_tests/random/random_poisson_test.py
index eafa1d9..2d94533 100644
--- a/tensorflow/python/kernel_tests/random/random_poisson_test.py
+++ b/tensorflow/python/kernel_tests/random/random_poisson_test.py
@@ -104,7 +104,7 @@
merged.
"""
for dtype in dtypes.float16, dtypes.float32, dtypes.float64:
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
rnd1 = random_ops.random_poisson(2.0, [24], dtype=dtype)
rnd2 = random_ops.random_poisson(2.0, [24], dtype=dtype)
diff = rnd2 - rnd1
diff --git a/tensorflow/python/kernel_tests/random/stateless_random_ops_test.py b/tensorflow/python/kernel_tests/random/stateless_random_ops_test.py
index 24b5a36..f60f5c4 100644
--- a/tensorflow/python/kernel_tests/random/stateless_random_ops_test.py
+++ b/tensorflow/python/kernel_tests/random/stateless_random_ops_test.py
@@ -240,7 +240,7 @@
def _test_determinism(self, case, seed_type):
# Stateless values should be equal iff the seeds are equal (roughly)
seeds = [(x, y) for x in range(5) for y in range(5)] * 3 # pylint: disable=g-complex-comprehension
- with self.test_session(use_gpu=True), ops.device(get_device().name):
+ with self.test_session(), ops.device(get_device().name):
_, stateless_op, _ = case
if context.executing_eagerly():
values = [
diff --git a/tensorflow/python/kernel_tests/reduction_ops_test.py b/tensorflow/python/kernel_tests/reduction_ops_test.py
index 601c542..e51acf3 100644
--- a/tensorflow/python/kernel_tests/reduction_ops_test.py
+++ b/tensorflow/python/kernel_tests/reduction_ops_test.py
@@ -156,7 +156,7 @@
def _compare(self, x, reduction_axes, keepdims, feed_dict=None):
np_ans = self._np_reduce(x, reduction_axes, keepdims)
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
tf_ans = self._tf_reduce(x, reduction_axes, keepdims)
out = sess.run(tf_ans, feed_dict)
self.assertAllClose(np_ans, out)
@@ -178,7 +178,7 @@
if reduction_axes is not None and np.shape(reduction_axes) == (1,):
# Test scalar reduction_axes argument
self._compareGradient(x, reduction_axes[0], rtol=rtol, atol=atol)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
t = ops.convert_to_tensor(x)
su = self._tf_reduce(t, reduction_axes, False)
jacob_t, jacob_n = gradient_checker.compute_gradient(
@@ -208,7 +208,7 @@
def testAxesType(self):
for dtype in [dtypes.int64, dtypes.int32]:
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
v = math_ops.reduce_sum([0, 0], constant_op.constant(0, dtype=dtype))
tf_v = self.evaluate(v)
self.assertAllEqual(tf_v, 0)
@@ -403,7 +403,7 @@
@test_util.run_deprecated_v1
def testEmptyGradients(self):
- with self.session(use_gpu=True):
+ with self.session():
x = array_ops.zeros([0, 3])
y = math_ops.reduce_sum(x, [1])
error = gradient_checker.compute_gradient_error(x, [0, 3], y, [0])
@@ -411,7 +411,7 @@
@test_util.run_deprecated_v1
def testDegenerate(self):
- with self.session(use_gpu=True):
+ with self.session():
for dtype in (dtypes.float16, dtypes.float32, dtypes.float64,
dtypes.complex64, dtypes.complex128):
# A large number is needed to get Eigen to die
@@ -446,7 +446,7 @@
def testAxesType(self):
for dtype in [dtypes.int64, dtypes.int32]:
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
v = math_ops.reduce_mean([0, 0], constant_op.constant(0, dtype=dtype))
tf_v = self.evaluate(v)
self.assertAllEqual(tf_v, 0)
@@ -525,7 +525,7 @@
@test_util.run_deprecated_v1
def testEmptyGradients(self):
- with self.session(use_gpu=True):
+ with self.session():
x = array_ops.zeros([0, 3])
y = math_ops.reduce_mean(x, [1])
error = gradient_checker.compute_gradient_error(x, [0, 3], y, [0])
@@ -533,7 +533,7 @@
@test_util.run_deprecated_v1
def testDegenerate(self):
- with self.session(use_gpu=True):
+ with self.session():
for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
# A large number is needed to get Eigen to die
x = array_ops.zeros((0, 9938), dtype=dtype)
@@ -560,7 +560,7 @@
@test_util.run_deprecated_v1
def testAxesType(self):
for dtype in [dtypes.int64, dtypes.int32]:
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
v = math_ops.reduce_mean([0, 0], constant_op.constant(0, dtype=dtype))
tf_v = self.evaluate(v)
self.assertAllEqual(tf_v, 0)
@@ -609,7 +609,7 @@
np_arr = self._makeIncremental((2,) * rank, dtypes.complex128)
self._compareAllAxes(np_arr)
- with self.session(use_gpu=True):
+ with self.session():
for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
# A large number is needed to get Eigen to die
x = array_ops.zeros((0, 9938), dtype=dtype)
@@ -640,7 +640,7 @@
def testAxesType(self):
for dtype in [dtypes.int64, dtypes.int32]:
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
v = math_ops.reduce_prod([0, 0], constant_op.constant(0, dtype=dtype))
tf_v = self.evaluate(v)
self.assertAllEqual(tf_v, 0)
@@ -711,7 +711,7 @@
@test_util.run_deprecated_v1
def testEmptyGradients(self):
- with self.session(use_gpu=True):
+ with self.session():
x = array_ops.zeros([0, 3])
y = math_ops.reduce_prod(x, [1])
error = gradient_checker.compute_gradient_error(x, [0, 3], y, [0])
@@ -719,7 +719,7 @@
@test_util.run_deprecated_v1
def testDegenerate(self):
- with self.session(use_gpu=True):
+ with self.session():
for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
# A large number is needed to get Eigen to die
x = array_ops.zeros((0, 9938), dtype=dtype)
@@ -750,7 +750,7 @@
def testAxesType(self):
for dtype in [dtypes.int64, dtypes.int32]:
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
v = math_ops.reduce_min([0, 0], constant_op.constant(0, dtype=dtype))
tf_v = self.evaluate(v)
self.assertAllEqual(tf_v, 0)
@@ -866,7 +866,7 @@
def testAxesType(self):
for dtype in [dtypes.int64, dtypes.int32]:
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
v = math_ops.reduce_max([0, 0], constant_op.constant(0, dtype=dtype))
tf_v = self.evaluate(v)
self.assertAllEqual(tf_v, 0)
@@ -998,7 +998,7 @@
def testAxesType(self):
for dtype in [dtypes.int64, dtypes.int32]:
- with self.session(use_gpu=True) as sess:
+ with self.session():
v = math_ops.reduce_all([True, True],
constant_op.constant(0, dtype=dtype))
tf_v = self.evaluate(v)
@@ -1047,7 +1047,7 @@
def testAxesType(self):
for dtype in [dtypes.int64, dtypes.int32]:
- with self.session(use_gpu=True) as sess:
+ with self.session():
v = math_ops.reduce_any([True, True],
constant_op.constant(0, dtype=dtype))
tf_v = self.evaluate(v)
diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py
index 81f1058..c091c09 100644
--- a/tensorflow/python/kernel_tests/relu_op_test.py
+++ b/tensorflow/python/kernel_tests/relu_op_test.py
@@ -104,6 +104,11 @@
def testNoElement(self):
self._testRelu(np.array([[], []], dtype=np.float32))
+ @test_util.disable_xla("b/157978028: Does not yet pass with XLA")
+ def testNaNPropagation(self):
+ for t in [np.float16, np.float32, np.float64]:
+ self._testRelu(np.array([-1, np.nan, 1, np.nan]).astype(t))
+
# The gradient test for ReLU is a bit tricky as the derivative is not well
# defined at around zero and we want to avoid that in terms of input values.
def testGradientFloat32(self):
@@ -234,6 +239,11 @@
self._testRelu6(
np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t))
+ @test_util.disable_xla("b/157978028: Does not yet pass with XLA")
+ def testNaNPropagation(self):
+ for t in [np.float16, np.float32, np.float64]:
+ self._testRelu6(np.array([-1, np.nan, 1, 7, np.nan]).astype(t))
+
# The gradient test for ReLU6 is a bit tricky as the derivative is
# not well defined at around zero and six and we want to avoid that
# in terms of input values.
@@ -294,6 +304,11 @@
np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
alpha=0.1)
+ def testNaNPropagation(self):
+ for t in [np.float16, np.float32, np.float64]:
+ self._testLeakyRelu(np.array([-1, np.nan, 1, np.nan]).astype(t),
+ alpha=0.2)
+
# The gradient test for Leaky ReLU is a bit tricky as the derivative is not
# well defined at around zero and we want to avoid that in terms of input
# values.
@@ -411,6 +426,10 @@
for t in [np.float16, np.float32, np.float64]:
self._testElu(np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t))
+ def testNaNPropagation(self):
+ for t in [np.float16, np.float32, np.float64]:
+ self._testElu(np.array([-1, np.nan, 1, np.nan]).astype(t))
+
def testGradientFloat32(self):
with self.cached_session():
x_val = [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]]
diff --git a/tensorflow/python/kernel_tests/rnn_cell_test.py b/tensorflow/python/kernel_tests/rnn_cell_test.py
index bb47d60..c096357 100644
--- a/tensorflow/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/python/kernel_tests/rnn_cell_test.py
@@ -223,7 +223,7 @@
self.assertEqual(out.get_shape(), inp.get_shape())
self.assertEqual(out.dtype, inp.dtype)
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
input_value = np.random.randn(batch_size, input_size)
values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value})
@@ -260,7 +260,7 @@
self.assertEqual(out.get_shape().as_list(), inp.get_shape().as_list())
self.assertEqual(out.dtype, inp.dtype)
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
input_value = np.random.randn(batch_size, input_size)
values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value})
full_dropout_values = sess.run(
@@ -288,7 +288,7 @@
cell, inputs, sequence_length=sequence_length, dtype=dtypes.float32)
self.assertEqual(len(dynamic_outputs), len(inputs))
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
input_value = np.random.randn(batch_size, input_size)
dynamic_values = sess.run(
dynamic_outputs,
@@ -324,7 +324,7 @@
1.0 * (2 + 1) * np.ones((input_size)))))
def _testScope(self, factory, prefix="prefix", use_outer_scope=True):
- with self.session(use_gpu=True, graph=ops.Graph()):
+ with self.session(graph=ops.Graph()):
if use_outer_scope:
with variable_scope.variable_scope(prefix) as scope:
factory(scope)
@@ -388,7 +388,7 @@
input_size = 5
batch_size = 2
max_length = 8
- with self.session(use_gpu=True, graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=self._seed)
cell = rnn_cell.LSTMCell(
@@ -411,7 +411,7 @@
input_size = 5
batch_size = 2
max_length = 8
- with self.session(use_gpu=True, graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=self._seed)
cell = rnn_cell.LSTMCell(
@@ -442,7 +442,7 @@
input_size = 5
batch_size = 2
max_length = 8
- with self.session(use_gpu=True, graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=self._seed)
state_saver = TestStateSaver(batch_size, 2 * num_units)
@@ -583,7 +583,7 @@
batch_size = 2
num_proj = 4
max_length = 8
- with self.session(use_gpu=True, graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=self._seed)
inputs = max_length * [
@@ -681,7 +681,7 @@
num_proj_shards = 3
num_unit_shards = 2
max_length = 8
- with self.session(use_gpu=True, graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=self._seed)
@@ -715,7 +715,7 @@
num_proj_shards = 3
num_unit_shards = 2
max_length = 8
- with self.session(use_gpu=True, graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
initializer = init_ops.random_uniform_initializer(-1, 1, seed=self._seed)
inputs = max_length * [
array_ops.placeholder(dtypes.float64, shape=(None, input_size))
@@ -752,7 +752,7 @@
num_proj_shards = 3
num_unit_shards = 2
max_length = 8
- with self.session(use_gpu=True, graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
inputs = max_length * [
array_ops.placeholder(dtypes.float32, shape=(None, input_size))
]
@@ -809,7 +809,7 @@
num_proj_shards = 3
num_unit_shards = 2
max_length = 8
- with self.session(use_gpu=True, graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
sequence_length = array_ops.placeholder(dtypes.int64)
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=self._seed)
@@ -1151,7 +1151,7 @@
state_is_tuple=False)
########### Step 1: Run static graph and generate readouts
- with self.session(use_gpu=True, graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
if in_graph_mode:
concat_inputs = array_ops.placeholder(
dtypes.float32, shape=(time_steps, batch_size, input_size))
@@ -1211,7 +1211,7 @@
static_individual_variable_gradients, feed_dict=feeds)
########## Step 2: Run dynamic graph and generate readouts
- with self.session(use_gpu=True, graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
if in_graph_mode:
concat_inputs = array_ops.placeholder(
dtypes.float32, shape=(time_steps, batch_size, input_size))
@@ -1372,7 +1372,7 @@
return input_value, inputs, outputs, state_fw, state_bw, sequence_length
def _testBidirectionalRNN(self, use_shape):
- with self.session(use_gpu=True, graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
input_value, inputs, outputs, state_fw, state_bw, sequence_length = (
self._createBidirectionalRNN(use_shape, True))
variables_lib.global_variables_initializer().run()
@@ -1419,7 +1419,7 @@
self.assertAllClose(s_fw, s_bw)
def _testBidirectionalRNNWithoutSequenceLength(self, use_shape):
- with self.session(use_gpu=True, graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
input_value, inputs, outputs, state_fw, state_bw, _ = (
self._createBidirectionalRNN(use_shape, False))
variables_lib.global_variables_initializer().run()
@@ -1504,7 +1504,7 @@
def _testBidirectionalDynamicRNN(self, use_shape, use_state_tuple,
use_time_major, use_sequence_length):
- with self.session(use_gpu=True, graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
input_value, inputs, outputs, state_fw, state_bw, sequence_length = (
self._createBidirectionalDynamicRNN(
use_shape, use_state_tuple, use_time_major, use_sequence_length))
@@ -1582,7 +1582,7 @@
# REMARKS: factory(scope) is a function accepting a scope
# as an argument, such scope can be None, a string
# or a VariableScope instance.
- with self.session(use_gpu=True, graph=ops.Graph()):
+ with self.session(graph=ops.Graph()):
if use_outer_scope:
with variable_scope.variable_scope(prefix) as scope:
factory(scope)
@@ -1905,7 +1905,7 @@
batch_size = 2
state_saver = TestStateSaver(batch_size, 2 * num_units)
- with self.session(use_gpu=True, graph=ops.Graph()):
+ with self.session(graph=ops.Graph()):
if use_outer_scope:
with variable_scope.variable_scope(prefix) as scope:
self._factory(scope=scope, state_saver=state_saver)
@@ -1984,7 +1984,7 @@
sequence_length = np.random.randint(0, time_steps, size=batch_size)
- with self.session(use_gpu=True, graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
concat_inputs = array_ops.placeholder(
dtypes.float32, shape=(time_steps, batch_size, input_size))
@@ -2006,7 +2006,7 @@
sess.run([outputs_dynamic, state_dynamic], feed_dict=feeds)
def _testScope(self, factory, prefix="prefix", use_outer_scope=True):
- with self.session(use_gpu=True, graph=ops.Graph()):
+ with self.session(graph=ops.Graph()):
if use_outer_scope:
with variable_scope.variable_scope(prefix) as scope:
factory(scope)
@@ -2298,7 +2298,7 @@
np.ones((max_time, batch_size, 1), np.int64), output_vals[1])
def _testScope(self, factory, prefix="prefix", use_outer_scope=True):
- with self.session(use_gpu=True, graph=ops.Graph()):
+ with self.session(graph=ops.Graph()):
if use_outer_scope:
with variable_scope.variable_scope(prefix) as scope:
factory(scope)
@@ -2416,7 +2416,7 @@
sequence_length=sequence_length,
dtype=dtypes.float32)
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
opts = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
run_metadata = config_pb2.RunMetadata()
variables_lib.global_variables_initializer().run()
@@ -2903,7 +2903,7 @@
return
gpu_dev = test.gpu_device_name()
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 1, 3])
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py
index 27732de..7bf9940 100644
--- a/tensorflow/python/kernel_tests/rnn_test.py
+++ b/tensorflow/python/kernel_tests/rnn_test.py
@@ -212,7 +212,7 @@
else:
inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1))
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
outputs, state = rnn.dynamic_rnn(
cell, inputs, dtype=dtypes.float32, sequence_length=[4])
if not in_eager_mode:
@@ -232,7 +232,7 @@
else:
inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1))
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
outputs, state = rnn.dynamic_rnn(
cell, inputs, dtype=dtypes.float32, sequence_length=[4])
if not in_eager_mode:
@@ -262,7 +262,7 @@
else:
inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1))
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
outputs, state = rnn.dynamic_rnn(
cell, inputs, dtype=dtypes.float32, sequence_length=[4])
state = (state[0], state[1].stack())
diff --git a/tensorflow/python/kernel_tests/scan_ops_test.py b/tensorflow/python/kernel_tests/scan_ops_test.py
index b0161b8..e802d5b 100644
--- a/tensorflow/python/kernel_tests/scan_ops_test.py
+++ b/tensorflow/python/kernel_tests/scan_ops_test.py
@@ -79,7 +79,7 @@
def _compare(self, x, axis, exclusive, reverse):
np_out = handle_options(np.cumsum, x, axis, exclusive, reverse)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
tf_out = math_ops.cumsum(x, axis, exclusive, reverse).eval()
self.assertAllClose(np_out, tf_out)
@@ -101,7 +101,7 @@
for dtype in self.valid_dtypes:
x = np.arange(1, 6).reshape([5]).astype(dtype)
for axis_dtype in [dtypes.int64, dtypes.int32]:
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
axis = constant_op.constant(0, axis_dtype)
tf_out = math_ops.cumsum(x, axis).eval()
@@ -152,7 +152,7 @@
def testInvalidAxis(self):
x = np.arange(0, 10).reshape([2, 5]).astype(np.float32)
input_tensor = ops.convert_to_tensor(x)
- with self.session(use_gpu=True):
+ with self.session():
with self.assertRaisesWithPredicateMatch(
errors_impl.InvalidArgumentError,
lambda e: "Expected scan axis in the range [-2, 2)" in str(e)):
@@ -168,7 +168,7 @@
def _compareGradient(self, shape, axis, exclusive, reverse):
x = np.arange(0, 50).reshape(shape).astype(np.float64)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
t = ops.convert_to_tensor(x)
result = math_ops.cumsum(t, axis, exclusive, reverse)
jacob_t, jacob_n = gradient_checker.compute_gradient(
@@ -212,7 +212,7 @@
def _compare(self, x, axis, exclusive, reverse):
np_out = handle_options(np.cumprod, x, axis, exclusive, reverse)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
tf_out = math_ops.cumprod(x, axis, exclusive, reverse).eval()
self.assertAllClose(np_out, tf_out)
@@ -234,7 +234,7 @@
for dtype in self.valid_dtypes:
x = np.arange(1, 6).reshape([5]).astype(dtype)
for axis_dtype in [dtypes.int64, dtypes.int32]:
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
axis = constant_op.constant(0, axis_dtype)
tf_out = math_ops.cumprod(x, axis).eval()
@@ -278,7 +278,7 @@
def testInvalidAxis(self):
x = np.arange(0, 10).reshape([2, 5]).astype(np.float32)
input_tensor = ops.convert_to_tensor(x)
- with self.session(use_gpu=True):
+ with self.session():
with self.assertRaisesWithPredicateMatch(
errors_impl.InvalidArgumentError,
lambda e: "Expected scan axis in the range [-2, 2)" in str(e)):
@@ -294,7 +294,7 @@
def _compareGradient(self, shape, axis, exclusive, reverse):
x = np.arange(1, 9).reshape(shape).astype(np.float64)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
t = ops.convert_to_tensor(x)
result = math_ops.cumprod(t, axis, exclusive, reverse)
jacob_t, jacob_n = gradient_checker.compute_gradient(
diff --git a/tensorflow/python/kernel_tests/scatter_ops_test.py b/tensorflow/python/kernel_tests/scatter_ops_test.py
index b9206bf..5787098 100644
--- a/tensorflow/python/kernel_tests/scatter_ops_test.py
+++ b/tensorflow/python/kernel_tests/scatter_ops_test.py
@@ -134,7 +134,7 @@
repeat_indices=False,
updates_are_scalar=False):
np.random.seed(8)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
for indices_shape in (), (2,), (3, 7), (3, 4, 7):
for extra_shape in (), (5,), (5, 9):
# Generate random indices with no duplicates for easy numpy comparison
diff --git a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
index 6a9350b..d4ff43b 100644
--- a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
+++ b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
@@ -307,7 +307,7 @@
ops_list = self.complex_ops_list if dtype.is_complex else self.ops_list
tf_x, np_x = self._input(shape, dtype=dtype)
for use_gpu in [True, False]:
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
for np_op1, np_op2, tf_op, init_op in ops_list:
# sqrt_n doesn't support integers
if (np_op2 == self._sqrt_n_reduce_op and dtype.is_integer):
@@ -333,7 +333,7 @@
for indices in indices_flat, indices_flat.reshape(5, 2):
shape = indices.shape + (2,)
for dtype in dtypes:
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
tf_x, np_x = self._input(shape)
num_segments_constant = constant_op.constant(
num_segments, dtype=dtype)
@@ -433,7 +433,7 @@
shape = [n, num_cols]
num_segments = max(indices) + 1
for dtype in self.differentiable_dtypes:
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
tf_x, np_x = self._input(shape, dtype=dtype)
# Results from UnsortedSegmentSum
unsorted_s = math_ops.unsorted_segment_sum(
@@ -470,7 +470,7 @@
def testEmptySecondDimension(self):
dtypes = [np.float16, np.float32, np.float64, np.int64, np.int32,
np.complex64, np.complex128]
- with self.session(use_gpu=True):
+ with self.session():
for dtype in dtypes:
for itype in (np.int32, np.int64):
data = np.zeros((2, 0), dtype=dtype)
@@ -486,7 +486,7 @@
for indices in indices_flat, indices_flat.reshape(5, 2):
shape = indices.shape + (2,)
for dtype in self.all_dtypes:
- with self.session(use_gpu=True):
+ with self.session():
tf_x, np_x = self._input(shape, dtype=dtype)
np_ans = self._segmentReduce(
indices, np_x, np.add, op2=None, num_segments=num_segments)
diff --git a/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py b/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py
index 40f8b31..64a8bc1 100644
--- a/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py
+++ b/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py
@@ -55,7 +55,7 @@
@test_util.run_deprecated_v1
def testConcurrentExecutesWithoutError(self):
all_ops = []
- with self.session(use_gpu=True) as sess:
+ with self.session():
for compute_v_ in True, False:
matrix1 = random_ops.random_normal([5, 5], seed=42)
matrix2 = random_ops.random_normal([5, 5], seed=42)
@@ -84,7 +84,7 @@
"self_adjoint_eig_fail_if_denorms_flushed.txt")).astype(np.float32)
self.assertEqual(matrix.shape, (32, 32))
matrix_tensor = constant_op.constant(matrix)
- with self.session(use_gpu=True) as sess:
+ with self.session():
(e, v) = self.evaluate(linalg_ops.self_adjoint_eig(matrix_tensor))
self.assertEqual(e.size, 32)
self.assertAllClose(
@@ -156,7 +156,7 @@
else:
atol = 1e-12
np_e, np_v = np.linalg.eigh(a)
- with self.session(use_gpu=True):
+ with self.session():
if compute_v_:
tf_e, tf_v = linalg_ops.self_adjoint_eig(constant_op.constant(a))
@@ -211,7 +211,8 @@
tol = 1e-2
else:
tol = 1e-7
- with self.session(use_gpu=True):
+ with self.session():
+
def Compute(x):
e, v = linalg_ops.self_adjoint_eig(x)
# (complex) Eigenvectors are only unique up to an arbitrary phase
diff --git a/tensorflow/python/kernel_tests/shape_ops_test.py b/tensorflow/python/kernel_tests/shape_ops_test.py
index 5a165c9..c5f6d02 100644
--- a/tensorflow/python/kernel_tests/shape_ops_test.py
+++ b/tensorflow/python/kernel_tests/shape_ops_test.py
@@ -267,7 +267,7 @@
for dtype in [dtypes.int32, dtypes.int64]:
x = np.zeros([2])
np_ans = np.expand_dims(x, axis=0)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
tensor = array_ops.expand_dims(x, constant_op.constant(0, dtype))
tf_ans = self.evaluate(tensor)
self.assertShapeEqual(np_ans, tensor)
@@ -433,7 +433,7 @@
def testSimple(self):
# multiples could be int32 or int64
for dtype in [dtypes.int32, dtypes.int64]:
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
inp = np.random.rand(4, 1).astype(np.float32)
a = constant_op.constant(inp)
tiled = array_ops.tile(a, constant_op.constant([1, 4], dtype=dtype))
@@ -505,7 +505,7 @@
bytes: (dtypes.string, bytes)
}
for dtype_np, (dtype_tf, cast) in types_to_test.items():
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
inp = np.random.rand(4, 1).astype(dtype_np)
a = constant_op.constant(
[cast(x) for x in inp.ravel(order="C")],
@@ -601,7 +601,7 @@
@test_util.run_deprecated_v1
def testGradientSimpleReductionOnGPU(self):
- with self.session(use_gpu=True):
+ with self.session():
inp = np.random.rand(4, 1).astype("f")
a = constant_op.constant(
[float(x) for x in inp.flatten()], shape=[4, 1], dtype=dtypes.float32)
@@ -616,7 +616,7 @@
@test_util.run_deprecated_v1
def testGradientStridedReductionOnGPU(self):
- with self.session(use_gpu=True):
+ with self.session():
inp = np.random.rand(4, 2).astype("f")
a = constant_op.constant(
[float(x) for x in inp.flatten()], shape=[4, 2], dtype=dtypes.float32)
diff --git a/tensorflow/python/kernel_tests/signal/dct_ops_test.py b/tensorflow/python/kernel_tests/signal/dct_ops_test.py
index d4f9e39..7379526 100644
--- a/tensorflow/python/kernel_tests/signal/dct_ops_test.py
+++ b/tensorflow/python/kernel_tests/signal/dct_ops_test.py
@@ -190,7 +190,7 @@
# "ortho" normalization is not implemented for type I.
if dct_type == 1 and norm == "ortho":
return
- with self.session(use_gpu=True):
+ with self.session():
tol = 5e-4 if dtype == np.float32 else 1e-7
signals = np.random.rand(*shape).astype(dtype)
n = np.random.randint(1, 2 * signals.shape[-1])
diff --git a/tensorflow/python/kernel_tests/signal/fft_ops_test.py b/tensorflow/python/kernel_tests/signal/fft_ops_test.py
index 762bdc5..7563a40 100644
--- a/tensorflow/python/kernel_tests/signal/fft_ops_test.py
+++ b/tensorflow/python/kernel_tests/signal/fft_ops_test.py
@@ -87,7 +87,8 @@
if test.is_built_with_rocm():
self.skipTest("Complex datatype not yet supported in ROCm.")
return
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
+
def f(inx, iny):
inx.set_shape(x.shape)
iny.set_shape(y.shape)
@@ -123,12 +124,12 @@
def _tf_fft(self, x, rank, fft_length=None, feed_dict=None):
# fft_length unused for complex FFTs.
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
return sess.run(self._tf_fft_for_rank(rank)(x), feed_dict=feed_dict)
def _tf_ifft(self, x, rank, fft_length=None, feed_dict=None):
# fft_length unused for complex FFTs.
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
return sess.run(self._tf_ifft_for_rank(rank)(x), feed_dict=feed_dict)
def _np_fft(self, x, rank, fft_length=None):
@@ -299,12 +300,12 @@
class RFFTOpsTest(BaseFFTOpsTest, parameterized.TestCase):
def _tf_fft(self, x, rank, fft_length=None, feed_dict=None):
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
return sess.run(
self._tf_fft_for_rank(rank)(x, fft_length), feed_dict=feed_dict)
def _tf_ifft(self, x, rank, fft_length=None, feed_dict=None):
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
return sess.run(
self._tf_ifft_for_rank(rank)(x, fft_length), feed_dict=feed_dict)
diff --git a/tensorflow/python/kernel_tests/signal/shape_ops_test.py b/tensorflow/python/kernel_tests/signal/shape_ops_test.py
index 6d9c77a..dc99390 100644
--- a/tensorflow/python/kernel_tests/signal/shape_ops_test.py
+++ b/tensorflow/python/kernel_tests/signal/shape_ops_test.py
@@ -327,7 +327,7 @@
def test_gradient_numerical(self):
if context.executing_eagerly():
return
- with self.session(use_gpu=True):
+ with self.session():
signal_shape = (2, 128)
signal = array_ops.ones(signal_shape)
frame_length = 33
diff --git a/tensorflow/python/kernel_tests/signal/spectral_ops_test.py b/tensorflow/python/kernel_tests/signal/spectral_ops_test.py
index f7844c6..920d775 100644
--- a/tensorflow/python/kernel_tests/signal/spectral_ops_test.py
+++ b/tensorflow/python/kernel_tests/signal/spectral_ops_test.py
@@ -266,7 +266,7 @@
# TODO(rjryan): Update gradient tests for Eager.
if context.executing_eagerly():
return
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
signal_length = 512
# An all-zero signal has all zero gradients with respect to the sum of the
diff --git a/tensorflow/python/kernel_tests/spacetobatch_op_test.py b/tensorflow/python/kernel_tests/spacetobatch_op_test.py
index 0147f2b..97b23b8 100644
--- a/tensorflow/python/kernel_tests/spacetobatch_op_test.py
+++ b/tensorflow/python/kernel_tests/spacetobatch_op_test.py
@@ -101,7 +101,7 @@
"""
def _testPad(self, inputs, paddings, block_size, outputs):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
# outputs = space_to_batch(inputs)
x_tf = self.space_to_batch(
math_ops.cast(inputs, dtypes.float32),
@@ -327,7 +327,7 @@
array_ops.space_to_depth(
array_ops.transpose(x, [3, 1, 2, 0]), block_size=block_size),
[3, 1, 2, 0])
- with self.session(use_gpu=True):
+ with self.session():
self.assertAllEqual(y1, y2)
@@ -526,7 +526,7 @@
# Check the gradients.
def _checkGrad(self, x, paddings, block_size):
assert 4 == x.ndim
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
tf_x = ops.convert_to_tensor(x)
tf_y = self.space_to_batch(tf_x, paddings, block_size)
epsilon = 1e-5
diff --git a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_grad_test.py b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_grad_test.py
index 8e2115f..0e000c1 100644
--- a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_grad_test.py
+++ b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_grad_test.py
@@ -73,7 +73,7 @@
matmul = sparse_ops.sparse_tensor_dense_matmul(
sp_t, dense_t, adjoint_a=adjoint_a, adjoint_b=adjoint_b, name=name)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
dense_t_shape = [m, k] if adjoint_b else [k, m]
sp_t_val_shape = [nnz]
err = gradient_checker.compute_gradient_error(
diff --git a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py
index 8ec1756..2abc4e2 100644
--- a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py
@@ -66,7 +66,7 @@
x_values = x[np.where(x)]
x_shape = x.shape
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
sp_x_value = sparse_tensor.SparseTensorValue(
indices=x_indices, values=x_values, dense_shape=x_shape)
tf_value_ans = sparse_ops.sparse_tensor_dense_matmul(
diff --git a/tensorflow/python/kernel_tests/sparse_xent_op_test.py b/tensorflow/python/kernel_tests/sparse_xent_op_test.py
index c53f196..f709acd 100644
--- a/tensorflow/python/kernel_tests/sparse_xent_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_xent_op_test.py
@@ -64,7 +64,7 @@
def _testXent(self, np_features, np_labels):
np_loss, np_backprop = self._npXent(np_features, np_labels)
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
loss, backprop = gen_nn_ops.sparse_softmax_cross_entropy_with_logits(
np_features, np_labels)
tf_loss, tf_backprop = self.evaluate([loss, backprop])
@@ -73,7 +73,7 @@
def testSingleClass(self):
for label_dtype in np.int32, np.int64:
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
loss, backprop = gen_nn_ops.sparse_softmax_cross_entropy_with_logits(
np.array([[1.], [-1.], [0.]]).astype(np.float32),
np.array([0, 0, 0]).astype(label_dtype))
@@ -145,19 +145,19 @@
np.array([1.3862, 3.4420]), np_loss, rtol=1.e-3, atol=1.e-3)
def testShapeMismatch(self):
- with self.session(use_gpu=True):
+ with self.session():
with self.assertRaisesRegex(ValueError, ".*Rank mismatch:*"):
nn_ops.sparse_softmax_cross_entropy_with_logits(
labels=[[0, 2]], logits=[[0., 1.], [2., 3.], [2., 3.]])
def testScalar(self):
- with self.session(use_gpu=True):
+ with self.session():
with self.assertRaisesRegex(ValueError, ".*Logits cannot be scalars*"):
nn_ops.sparse_softmax_cross_entropy_with_logits(
labels=constant_op.constant(0), logits=constant_op.constant(1.0))
def testLabelsPlaceholderScalar(self):
- with ops_lib.Graph().as_default(), self.session(use_gpu=True):
+ with ops_lib.Graph().as_default(), self.session():
labels = array_ops.placeholder(np.int32)
y = nn_ops.sparse_softmax_cross_entropy_with_logits(
labels=labels, logits=[[7.]])
@@ -165,7 +165,7 @@
y.eval(feed_dict={labels: 0})
def testVector(self):
- with self.session(use_gpu=True):
+ with self.session():
loss = nn_ops.sparse_softmax_cross_entropy_with_logits(
labels=constant_op.constant(0), logits=constant_op.constant([1.0]))
self.assertAllClose(0.0, self.evaluate(loss))
@@ -193,7 +193,7 @@
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
def testGradient(self):
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
l = constant_op.constant([3, 0, 1], name="l")
f = constant_op.constant(
[0.1, 0.2, 0.3, 0.4, 0.1, 0.4, 0.9, 1.6, 0.1, 0.8, 2.7, 6.4],
diff --git a/tensorflow/python/kernel_tests/split_op_test.py b/tensorflow/python/kernel_tests/split_op_test.py
index 16f92db..58674ab 100644
--- a/tensorflow/python/kernel_tests/split_op_test.py
+++ b/tensorflow/python/kernel_tests/split_op_test.py
@@ -55,13 +55,13 @@
model_input = array_ops.placeholder(dtypes.float32)
inp = np.zeros((1, 10))
# check that we still fail at runtime if the shapes were unknown
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors_impl.InvalidArgumentError):
sess.run(array_ops.split(model_input, [4]), {model_input: inp})
# scalar Tensors are not permitted as num_splits
for axis in [0, -2]:
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
with self.assertRaises(ValueError):
# pylint: disable=expression-not-assigned
sess.run(
@@ -83,7 +83,7 @@
model_input2 = array_ops.placeholder(dtypes.float32, shape=[None, 2])
result = array_ops.split(model_input2, [2, 2], axis=0)[0]
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
sess.run(result, feed_dict={model_input2: np.ones([4, 2])})
@test_util.run_deprecated_v1
@@ -92,7 +92,7 @@
value = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
with self.assertRaises(ValueError) as context:
sess.run(array_ops.split(value, size_splits), {size_splits: [2, 2, 6]})
self.assertTrue("Cannot infer num from shape" in str(context.exception))
@@ -214,7 +214,7 @@
@test_util.run_deprecated_v1
def testOutputShape(self):
for axis in [1, -1]:
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
tensor = array_ops.placeholder(dtypes.float32, shape=[None, 12])
size_splits = [3, 7, 2]
outputs = array_ops.split(tensor, size_splits, axis)
@@ -315,7 +315,7 @@
def _testGradientsSimple(self, dtype):
inp = self._makeData((4, 4), dtype)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
inp_tensor = ops.convert_to_tensor(inp)
s = array_ops.split(value=inp_tensor, num_or_size_splits=4, axis=1)
inp_grads = [self._makeData((4, 1), dtype)for _ in range(4)]
@@ -382,7 +382,7 @@
splits = array_ops.placeholder(dtypes.int32, [3])
y = array_ops.split(values, splits, axis=x)
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
"must have exactly one element"):
sess.run(y, {x: np.array([], dtype=np.int32), splits: [4, 11, 15]})
diff --git a/tensorflow/python/kernel_tests/stage_op_test.py b/tensorflow/python/kernel_tests/stage_op_test.py
index 29cd00b..8ea4c5d 100644
--- a/tensorflow/python/kernel_tests/stage_op_test.py
+++ b/tensorflow/python/kernel_tests/stage_op_test.py
@@ -43,7 +43,7 @@
G.finalize()
- with self.session(use_gpu=True, graph=G) as sess:
+ with self.session(graph=G) as sess:
sess.run(stage, feed_dict={x: -1})
for i in range(10):
_, yval = sess.run([stage, y], feed_dict={x: i})
@@ -63,7 +63,7 @@
G.finalize()
- with self.session(use_gpu=True, graph=G) as sess:
+ with self.session(graph=G) as sess:
sess.run(stage, feed_dict={x: -1})
for i in range(10):
_, yval = sess.run([stage, y], feed_dict={x: i})
@@ -89,7 +89,7 @@
G.finalize()
- with self.session(use_gpu=True, graph=G) as sess:
+ with self.session(graph=G) as sess:
sess.run(stage, feed_dict={x: -1})
for i in range(10):
_, yval = sess.run([stage, y], feed_dict={x: i})
@@ -131,7 +131,7 @@
G.finalize()
- with self.session(use_gpu=True, graph=G) as sess:
+ with self.session(graph=G) as sess:
for i in range(10):
sess.run(stage, feed_dict={x: i})
@@ -156,7 +156,7 @@
G.finalize()
- with self.session(use_gpu=True, graph=G) as sess:
+ with self.session(graph=G) as sess:
sess.run(stage, feed_dict={x: -1})
self.assertEqual(sess.run(size), 1)
sess.run(stage, feed_dict={x: -1})
@@ -189,7 +189,7 @@
queue = Queue.Queue()
n = 8
- with self.session(use_gpu=True, graph=G) as sess:
+ with self.session(graph=G) as sess:
# Stage data in a separate thread which will block
# when it hits the staging area's capacity and thus
# not fill the queue with n tokens
@@ -254,7 +254,7 @@
queue = Queue.Queue()
n = 8
- with self.session(use_gpu=True, graph=G) as sess:
+ with self.session(graph=G) as sess:
# Stage data in a separate thread which will block
# when it hits the staging area's capacity and thus
# not fill the queue with n tokens
diff --git a/tensorflow/python/kernel_tests/svd_op_test.py b/tensorflow/python/kernel_tests/svd_op_test.py
index 8bbfc51..d64697b 100644
--- a/tensorflow/python/kernel_tests/svd_op_test.py
+++ b/tensorflow/python/kernel_tests/svd_op_test.py
@@ -163,7 +163,7 @@
if use_static_shape_:
s_tf_val, u_tf_val, v_tf_val = self.evaluate([s_tf, u_tf, v_tf])
else:
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
s_tf_val, u_tf_val, v_tf_val = sess.run(
[s_tf, u_tf, v_tf], feed_dict={x_tf: x_np})
else:
@@ -172,7 +172,7 @@
if use_static_shape_:
s_tf_val = self.evaluate(s_tf)
else:
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
s_tf_val = sess.run(s_tf, feed_dict={x_tf: x_np})
if compute_uv_:
@@ -284,7 +284,7 @@
epsilon = np.finfo(dtype_).eps
delta = 0.1 * epsilon**(1.0 / 3.0)
tol = 1e-5
- with self.session(use_gpu=True):
+ with self.session():
tf_a = constant_op.constant(a)
if compute_uv_:
tf_s, tf_u, tf_v = _NormalizingSvd(tf_a, full_matrices_)
diff --git a/tensorflow/python/kernel_tests/tensor_array_ops_test.py b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
index 892c585..a642d33 100644
--- a/tensorflow/python/kernel_tests/tensor_array_ops_test.py
+++ b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
@@ -83,7 +83,7 @@
@test_util.run_in_graph_and_eager_modes
def testTensorArrayWriteRead(self):
- with self.session(use_gpu=True):
+ with self.session():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
@@ -104,7 +104,7 @@
self.assertAllEqual(-3.0, d2)
def _testTensorArrayWritePack(self, tf_dtype):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
ta = tensor_array_ops.TensorArray(
dtype=tf_dtype, tensor_array_name="foo", size=3)
@@ -133,7 +133,7 @@
self._testTensorArrayWritePackMaybeLegacy()
def testEmptyTensorArrayPack(self):
- with self.session(use_gpu=True):
+ with self.session():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, tensor_array_name="foo", size=3)
@@ -148,7 +148,7 @@
self.assertAllEqual([3, 0, 1], c0.shape)
def testTensorArrayWriteConcatInParallel(self):
- with self.session(use_gpu=True):
+ with self.session():
def _concat_1():
ta = tensor_array_ops.TensorArray(
@@ -189,7 +189,7 @@
self.assertAllEqual([1, 1, 1, 8, 9, 8, 9, 8, 9], c0)
def _testTensorArrayWriteConcat(self, tf_dtype):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
ta = tensor_array_ops.TensorArray(
dtype=tf_dtype, tensor_array_name="foo", size=3, infer_shape=False)
@@ -217,7 +217,7 @@
self._testTensorArrayWriteConcat(dtypes.string)
def _testTensorArrayReadOrPackNotAllValuesAvailableFillsZeros(self):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
@@ -251,7 +251,7 @@
@test_util.run_v1_only("Uses placeholders")
def testSkipEagerTensorArrayReadUninitializedInferShapeFillsZeros(self):
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
@@ -261,7 +261,7 @@
[[0.0, 0.0]], sess.run(ta.write(1, val).read(0), {val: [[4.0, 5.0]]}))
def _testTensorArrayUnpackRead(self, tf_dtype):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
convert = _make_converter(tf_dtype)
ta = _make_ta(3, "foo", dtype=tf_dtype)
@@ -311,7 +311,7 @@
self._testTensorArrayUnpackReadMaybeLegacy()
def _testTensorArraySplitRead(self, tf_dtype):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
convert = _make_converter(tf_dtype)
# Split an empty vector
@@ -365,7 +365,7 @@
@test_util.disable_control_flow_v2("v2 does not support TensorArray.grad.")
@test_util.run_v1_only("v2 does not support TensorArray.grad.")
def testSkipEagerTensorGradArrayWriteRead(self):
- with self.session(use_gpu=True) as session:
+ with self.session() as session:
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
@@ -401,7 +401,7 @@
def testSkipEagerTensorArrayGradGrad(self):
if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
self.skipTest("Legacy TensorArray does not support double derivatives.")
- with self.test_session(use_gpu=True) as session:
+ with self.test_session() as session:
x = constant_op.constant(4.0)
ta = tensor_array_ops.TensorArray(
@@ -420,7 +420,7 @@
@test_util.disable_control_flow_v2("v2 does not support TensorArray.grad.")
@test_util.run_v1_only("v2 does not support TensorArray.grad.")
def testSkipEagerTensorGradArrayDynamicWriteRead(self):
- with self.session(use_gpu=True) as session:
+ with self.session() as session:
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
@@ -463,7 +463,7 @@
@test_util.disable_control_flow_v2("v2 does not support TensorArray.grad.")
@test_util.run_v1_only("v2 does not support TensorArray.grad.")
def testSkipEagerTensorGradAccessTwiceReceiveSameObject(self):
- with self.session(use_gpu=True) as session:
+ with self.session() as session:
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, tensor_array_name="foo", size=3)
g_ta_0 = ta.grad("grad")
@@ -479,7 +479,7 @@
self.assertAllEqual([[4.0, 5.0]], d_r1_0)
def testTensorArrayWriteWrongIndexOrDataTypeFails(self):
- with self.session(use_gpu=True):
+ with self.session():
ta = _make_ta(3, "foo", dtype=dtypes.float32)
# TODO(b/129870929): Remove the last 2 checks (runtime checks) after
# back back from preferred_dtype= to dtype= in convert_to_tensor. Also
@@ -518,7 +518,7 @@
self.evaluate(ta.write(3, 3.0).flow)
def testTensorArrayReadWrongIndexOrDataTypeFails(self):
- with self.session(use_gpu=True):
+ with self.session():
ta = _make_ta(3, "foo", dtype=dtypes.float32)
w0 = ta.write(0, [[4.0, 5.0]])
@@ -553,7 +553,7 @@
@test_util.disable_control_flow_v2("v2 allows multiple writes.")
@test_util.run_v1_only("v2 allows multiple writes.")
def testSkipEagerTensorArrayWriteMultipleFails(self):
- with self.session(use_gpu=True):
+ with self.session():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, tensor_array_name="foo", size=3)
@@ -563,7 +563,7 @@
self.evaluate(ta.write(2, 3.0).write(2, 3.0).flow)
def testTensorArrayConcatIncompatibleShapesFails(self):
- with self.session(use_gpu=True):
+ with self.session():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
@@ -597,7 +597,7 @@
self.evaluate(w3.concat())
def testTensorArraySplitIncompatibleShapesFails(self):
- with self.session(use_gpu=True):
+ with self.session():
in_eager_mode = context.executing_eagerly()
ta = _make_ta(3, "foo")
with self.assertRaisesOpError(
@@ -636,7 +636,7 @@
self.evaluate(ta.split([1.0], [1]).flow)
def _testTensorArrayWriteGradientAddMultipleAdds(self, dtype):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
ta = tensor_array_ops.TensorArray(
dtype=dtype, tensor_array_name="foo", size=3, infer_shape=False)
ta_grad = ta.grad("grad")
@@ -679,7 +679,7 @@
@test_util.disable_control_flow_v2("Low level legacy TA op test.")
@test_util.run_v1_only("Low level legacy TA op test.")
def testSkipEagerTensorArrayGradWithShapeKnownElementShape(self):
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
ta = tensor_array_ops.TensorArray(
size=3,
dtype=dtypes.float32,
@@ -710,7 +710,7 @@
@test_util.disable_control_flow_v2("Low level legacy TA op test.")
@test_util.run_v1_only("Low level legacy TA op test.")
def testSkipEagerTensorArrayGradWithShapeUnknownElementShape(self):
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
ta = tensor_array_ops.TensorArray(
size=3, dtype=dtypes.float32,
element_shape=None) # Note that element_shape is unknown
@@ -733,7 +733,7 @@
sess.run(read_value, feed_dict={value: fed_value}))
def testMultiTensorArray(self):
- with self.session(use_gpu=True):
+ with self.session():
h1 = tensor_array_ops.TensorArray(
size=1, dtype=dtypes.float32, tensor_array_name="foo")
w1 = h1.write(0, 4.0)
@@ -749,7 +749,7 @@
self.assertAllClose(9.0, val)
def _testTensorArrayGradientWriteReadType(self, dtype):
- with self.cached_session(use_gpu=True) as session:
+ with self.cached_session() as session:
ta = tensor_array_ops.TensorArray(
dtype=dtypes.as_dtype(dtype),
tensor_array_name="foo",
@@ -801,7 +801,7 @@
self._testTensorArrayGradientWriteReadType(dtype)
def _testTensorArrayGradientWritePackConcatAndRead(self):
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
@@ -839,7 +839,7 @@
@test_util.disable_control_flow_v2("v2 does not support clear_after_read.")
@test_util.run_v1_only("v2 does not support clear_after_read.")
def testTensorArrayReadTwice(self):
- with self.session(use_gpu=True):
+ with self.session():
value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]])
ta_readonce = tensor_array_ops.TensorArray(
@@ -867,7 +867,7 @@
self.assertAllEqual([1.0, -1.0], self.evaluate(r1_readtwice))
def _testTensorArrayGradientUnpackRead(self):
- with self.cached_session(use_gpu=True) as session:
+ with self.cached_session() as session:
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
@@ -897,7 +897,7 @@
@test_util.deprecated_graph_mode_only
def testSkipEagerTensorArrayGradientSplitConcat(self):
- with self.session(use_gpu=True) as session:
+ with self.session() as session:
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, tensor_array_name="foo", size=2,
infer_shape=False)
@@ -920,7 +920,7 @@
grad_vals[0])
def _testTensorArrayGradientDynamicUnpackRead(self):
- with self.cached_session(use_gpu=True) as session:
+ with self.cached_session() as session:
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
@@ -946,20 +946,20 @@
self._testTensorArrayGradientDynamicUnpackRead()
def testCloseTensorArray(self):
- with self.session(use_gpu=True):
+ with self.session():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, tensor_array_name="foo", size=3)
self.evaluate(ta.close())
def testSizeTensorArray(self):
- with self.session(use_gpu=True):
+ with self.session():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, tensor_array_name="foo", size=3)
s = ta.size()
self.assertAllEqual(3, self.evaluate(s))
def testWriteCloseTensorArray(self):
- with self.session(use_gpu=True):
+ with self.session():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
@@ -971,7 +971,8 @@
def _testWhileLoopWritePackGradients(self, dynamic_size, dtype):
np_dtype = dtype.as_numpy_dtype
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
+
def func(v0, state0, var):
ta = tensor_array_ops.TensorArray(
dtype=dtype,
@@ -1068,7 +1069,8 @@
dynamic_size=True, dtype=dtypes.float32)
def testGradSerialTwoLoops(self):
- with self.session(use_gpu=True):
+ with self.session():
+
def loop(x):
num_steps = 100
acc = tensor_array_ops.TensorArray(
@@ -1117,7 +1119,7 @@
@test_util.deprecated_graph_mode_only
def testSkipEagerSumOfTwoReadVariablesWithoutRepeatGrad(self):
- with self.session(use_gpu=True) as session:
+ with self.session() as session:
a = array_ops.identity(
np.arange(
3 * 5, dtype=np.float32).reshape(3, 5) + 1)
@@ -1195,7 +1197,7 @@
@test_util.deprecated_graph_mode_only
def testSkipEagerWriteShape(self):
- with self.session(use_gpu=True):
+ with self.session():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, tensor_array_name="foo", size=3)
c0 = constant_op.constant([4.0, 5.0])
@@ -1220,7 +1222,7 @@
@test_util.deprecated_graph_mode_only
def testSkipEagerPartlyUnknownShape(self):
- with self.session(use_gpu=True):
+ with self.session():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, tensor_array_name="foo", size=6)
@@ -1260,7 +1262,7 @@
self.assertAllEqual([5, 4, 2, 3], r5.get_shape().as_list())
def _testUnpackShape(self):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
@@ -1297,7 +1299,7 @@
@test_util.deprecated_graph_mode_only
def testSplitShape(self):
- with self.session(use_gpu=True):
+ with self.session():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
@@ -1329,7 +1331,7 @@
@test_util.deprecated_graph_mode_only
def testSkipEagerWriteUnknownShape(self):
- with self.session(use_gpu=True):
+ with self.session():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
@@ -1341,7 +1343,7 @@
self.assertAllEqual(r0.get_shape(), tensor_shape.unknown_shape())
def _testGradientWhenNotAllComponentsRead(self):
- with self.cached_session(use_gpu=True) as session:
+ with self.cached_session() as session:
ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2)
x = constant_op.constant([2.0, 3.0])
w = ta.unstack(x)
@@ -1357,7 +1359,7 @@
@test_util.deprecated_graph_mode_only
def testSkipEagerWriteButNotAllComponentsReadGrad(self):
- with self.cached_session(use_gpu=True) as session:
+ with self.cached_session() as session:
x0 = constant_op.constant(5.0)
x1 = constant_op.constant(10.0)
ta = tensor_array_ops.TensorArray(
@@ -1369,7 +1371,7 @@
self.assertAllEqual(grad_r0_x1_vals, [1.0, 0.0])
def _testTensorArrayUnpackDynamic(self):
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, size=3, dynamic_size=True)
x = constant_op.constant([1.0, 2.0, 3.0])
@@ -1386,7 +1388,7 @@
@test_util.run_deprecated_v1
def testSkipEagerTensorArraySplitDynamic(self):
- with self.session(use_gpu=True) as sess:
+ with self.session():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, size=3, dynamic_size=True)
x = constant_op.constant([1.0, 2.0, 3.0])
@@ -1449,7 +1451,7 @@
ta_gather_with_unknown_indices_shape([0])
def _testTensorArrayEvalEmpty(self):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, size=0, dynamic_size=False, infer_shape=False)
v2_msg = ("Tried to stack elements of an empty list with "
@@ -1469,7 +1471,7 @@
# this test is ill-defined for Eager mode --- unpacking an empty tensor
# gives an empty list / there is not equivalent of "mark_used" in Eager
def _testTensorArrayEvalEmptyWithDefault(self):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, size=0, dynamic_size=False, infer_shape=True)
self.assertEqual(0, ta.size().eval())
@@ -1491,7 +1493,7 @@
@test_util.run_deprecated_v1
def testSkipEagerTensorArrayScatterReadAndGradients(self):
- with self.session(use_gpu=True) as session:
+ with self.session() as session:
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
@@ -1518,7 +1520,7 @@
@test_util.run_deprecated_v1
def testSkipEagerTensorArrayScatterPartialReadAndGradients(self):
- with self.session(use_gpu=True) as session:
+ with self.session() as session:
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
@@ -1554,7 +1556,7 @@
@test_util.run_v1_only("b/118890905")
def testTensorArrayWriteGatherAndGradients(self):
- with self.session(use_gpu=True) as session:
+ with self.session() as session:
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
@@ -1703,7 +1705,7 @@
[s for s in dev_stats[d] if "TensorArray" == s.node_name])
def testTensorArrayIdentity(self):
- with self.session(use_gpu=True):
+ with self.session():
ta0 = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2,
infer_shape=False)
ta1 = tensor_array_ops.TensorArray(dtype=dtypes.int32, size=4,
@@ -1769,7 +1771,7 @@
# dy is outside of the gradients name scope; tf.gradients must
# wrap it in the correct name scope.
dx, = gradients_impl.gradients(ys=[y], xs=[x], grad_ys=[dy])
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
vdx, vdy = self.evaluate([dx, dy])
self.assertAllClose(vdx, vdy)
@@ -1777,7 +1779,7 @@
def testSkipEagerTensorArrayInt64GPU(self):
if not test.is_gpu_available():
return
- with self.session(use_gpu=True, force_gpu=True) as sess:
+ with self.session(force_gpu=True) as sess:
value = array_ops.placeholder(dtypes.int64)
ta = tensor_array_ops.TensorArray(dtype=dtypes.int64, size=2)
ta = ta.scatter([0, 1], value)
diff --git a/tensorflow/python/kernel_tests/tensordot_op_test.py b/tensorflow/python/kernel_tests/tensordot_op_test.py
index 268f689..845b634 100644
--- a/tensorflow/python/kernel_tests/tensordot_op_test.py
+++ b/tensorflow/python/kernel_tests/tensordot_op_test.py
@@ -179,7 +179,7 @@
for _ in range(num_trials):
a_np, b_np, a_dims_np, b_dims_np = _generate_random_tensors_and_dims()
np_ans = np.tensordot(a_np, b_np, axes=(a_dims_np, b_dims_np))
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
if dynamic_shape_:
a = array_ops.placeholder(dtype_)
b = array_ops.placeholder(dtype_)
@@ -219,7 +219,7 @@
all_axes.append(a_np.ndim - 1)
for axes in all_axes:
np_ans = np.tensordot(a_np, b_np, axes=axes)
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
if dynamic_shape_:
a = array_ops.placeholder(dtype_)
b = array_ops.placeholder(dtype_)
diff --git a/tensorflow/python/kernel_tests/topk_op_test.py b/tensorflow/python/kernel_tests/topk_op_test.py
index b17a8f0..e5c8e17 100644
--- a/tensorflow/python/kernel_tests/topk_op_test.py
+++ b/tensorflow/python/kernel_tests/topk_op_test.py
@@ -47,7 +47,7 @@
sorted=True): # pylint: disable=redefined-builtin
np_expected_values = np.array(expected_values)
np_expected_indices = np.array(expected_indices)
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
values_op, indices_op = nn_ops.top_k(inputs, k, sorted=sorted)
values, indices = self.evaluate([values_op, indices_op])
@@ -196,7 +196,7 @@
@test_util.run_deprecated_v1
def testKNegative(self):
inputs = [[0.1, 0.2], [0.3, 0.4]]
- with self.session(use_gpu=True):
+ with self.session():
k = array_ops.placeholder(dtypes.int32)
values, _ = nn_ops.top_k(inputs, k)
with self.assertRaisesOpError("Need k >= 0, got -7"):
@@ -211,7 +211,7 @@
@test_util.run_deprecated_v1
def testTopKGradients(self):
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
inputs = array_ops.placeholder(dtypes.float32, shape=[2, 5])
values, _ = nn_ops.top_k(inputs, 3)
grad = sess.run(
diff --git a/tensorflow/python/kernel_tests/trace_op_test.py b/tensorflow/python/kernel_tests/trace_op_test.py
index 52640c0..6812030 100644
--- a/tensorflow/python/kernel_tests/trace_op_test.py
+++ b/tensorflow/python/kernel_tests/trace_op_test.py
@@ -31,7 +31,7 @@
def compare(self, x):
np_ans = np.trace(x, axis1=-2, axis2=-1)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
tf_ans = math_ops.trace(x).eval()
self.assertAllClose(tf_ans, np_ans)
diff --git a/tensorflow/python/kernel_tests/transpose_op_test.py b/tensorflow/python/kernel_tests/transpose_op_test.py
index 8709621..2c6f5ea 100644
--- a/tensorflow/python/kernel_tests/transpose_op_test.py
+++ b/tensorflow/python/kernel_tests/transpose_op_test.py
@@ -79,7 +79,7 @@
np_ans = self._np_transpose(x, perm)
if conjugate:
np_ans = np.conj(np_ans)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
inx = ops.convert_to_tensor(x)
y = array_ops.transpose(inx, p, conjugate=conjugate)
tf_ans = self.evaluate(y)
@@ -170,7 +170,7 @@
inp = np.arange(
1, total_size + 1, dtype=datatype).reshape(input_shape)
np_ans = self._np_transpose(inp, perm)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
inx = ops.convert_to_tensor(inp)
y = array_ops.transpose(inx, perm)
tf_ans = self.evaluate(y)
@@ -193,7 +193,7 @@
inp = np.arange(
1, total_size + 1, dtype=np.float32).reshape(input_shape)
np_ans = self._np_transpose(inp, perm)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
inx = ops.convert_to_tensor(inp)
y = array_ops.transpose(inx, perm)
tf_ans = self.evaluate(y)
@@ -230,7 +230,7 @@
inp = np.arange(
1, total_size + 1, dtype=np.float32).reshape(input_shape)
np_ans = self._np_transpose(inp, perm)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
inx = ops.convert_to_tensor(inp)
y = array_ops.transpose(inx, perm)
tf_ans = self.evaluate(y)
@@ -255,7 +255,7 @@
inp = np.arange(
1, total_size + 1, dtype=datatype).reshape(input_shape)
np_ans = self._np_transpose(inp, perm)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
inx = ops.convert_to_tensor(inp)
y = array_ops.transpose(inx, perm)
tf_ans = self.evaluate(y)
@@ -278,7 +278,7 @@
inp = np.arange(
1, total_size + 1, dtype=np.float32).reshape(input_shape)
np_ans = self._np_transpose(inp, perm)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
inx = ops.convert_to_tensor(inp)
y = array_ops.transpose(inx, perm)
tf_ans = self.evaluate(y)
@@ -331,7 +331,7 @@
with self.subTest(input_shape=input_shape, perm=perm):
inp = np.random.randint(10, size=input_shape)
np_ans = self._np_transpose(inp, perm)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
inx = ops.convert_to_tensor(inp)
y = array_ops.transpose(inx, perm)
tf_ans = self.evaluate(y)
@@ -355,7 +355,7 @@
x = np.arange(0, 8).reshape([2, 4]).astype(np.float32)
p = np.array([1, 0]).astype(perm_dtype)
np_ans = np.copy(x).transpose(p)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
inx = ops.convert_to_tensor(x)
inp = constant_op.constant(p)
y = array_ops.transpose(inx, inp)
diff --git a/tensorflow/python/kernel_tests/tridiagonal_matmul_op_test.py b/tensorflow/python/kernel_tests/tridiagonal_matmul_op_test.py
index 456f13e..3854400 100644
--- a/tensorflow/python/kernel_tests/tridiagonal_matmul_op_test.py
+++ b/tensorflow/python/kernel_tests/tridiagonal_matmul_op_test.py
@@ -80,7 +80,7 @@
diags_matrix_batch, rhs_batch, diagonals_format='matrix')
]
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
results = self.evaluate(results)
results_batch = self.evaluate(results_batch)
@@ -114,7 +114,7 @@
diags = constant_op.constant(diags, dtype=dtype)
rhs = constant_op.constant(rhs, dtype=dtype)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
grad_reference, _ = gradient_checker_v2.compute_gradient(
reference_matmul, [diags, rhs])
grad_theoretical, grad_numerical = gradient_checker_v2.compute_gradient(
@@ -155,7 +155,7 @@
constant_op.constant(rhs, dtype=dtypes.complex128),
diagonals_format='matrix')
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
result = self.evaluate(result)
self.assertAllClose(result, expected_result)
diff --git a/tensorflow/python/kernel_tests/tridiagonal_solve_op_test.py b/tensorflow/python/kernel_tests/tridiagonal_solve_op_test.py
index 3045461..c278fed 100644
--- a/tensorflow/python/kernel_tests/tridiagonal_solve_op_test.py
+++ b/tensorflow/python/kernel_tests/tridiagonal_solve_op_test.py
@@ -77,7 +77,7 @@
diags_format="compact",
transpose_rhs=False,
conjugate_rhs=False):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
pivoting = True
if hasattr(self, "pivoting"):
pivoting = self.pivoting
@@ -412,7 +412,7 @@
transpose_rhs=transpose_rhs,
conjugate_rhs=conjugate_rhs)
res = math_ops.reduce_sum(x * y)
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
actual_grad_diags = sess.run(
tape_diags.gradient(res, diags), feed_dict=feed_dict)
actual_rhs_diags = sess.run(
@@ -563,7 +563,7 @@
return
x = linalg_impl.tridiagonal_solve(
diags, rhs, diags_format, partial_pivoting=self.pivoting)
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
result = sess.run(x, feed_dict={diags: diags_feed, rhs: rhs_feed})
self.assertAllClose(result, expected)
@@ -648,7 +648,7 @@
rhs,
diagonals_format="sequence",
partial_pivoting=self.pivoting)
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
result = sess.run(
x,
feed_dict={
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py
index e04bf97..6396695 100644
--- a/tensorflow/python/kernel_tests/variables_test.py
+++ b/tensorflow/python/kernel_tests/variables_test.py
@@ -150,7 +150,7 @@
@test_util.run_deprecated_v1
def testResourceAssignments(self):
- with self.session(use_gpu=True):
+ with self.session():
var = resource_variable_ops.ResourceVariable(0.0)
plus_one = var.assign_add(1.0)
minus_one = var.assign_sub(2.0)
diff --git a/tensorflow/python/kernel_tests/where_op_test.py b/tensorflow/python/kernel_tests/where_op_test.py
index c16d016..b54cb7e 100644
--- a/tensorflow/python/kernel_tests/where_op_test.py
+++ b/tensorflow/python/kernel_tests/where_op_test.py
@@ -38,7 +38,7 @@
class WhereOpTest(test.TestCase):
def _testWhere(self, x, truth, expected_err_re=None, fn=array_ops.where):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
ans = fn(x)
self.assertTrue(ans.get_shape().is_compatible_with([None, x.ndim]))
if expected_err_re is None:
@@ -49,7 +49,7 @@
self.evaluate(ans)
def _testWrongNumbers(self, fn=array_ops.where):
- with self.session(use_gpu=True):
+ with self.session():
with self.assertRaises(ValueError):
fn([False, True], [1, 2], None)
with self.assertRaises(ValueError):
@@ -103,7 +103,7 @@
def _testThreeArgument(self, fn=array_ops.where):
x = np.array([[-2, 3, -1], [1, -3, -3]])
np_val = np.where(x > 0, x * x, -x)
- with self.test_session(use_gpu=True):
+ with self.test_session():
tf_val = self.evaluate(fn(constant_op.constant(x) > 0, x * x, -x))
self.assertAllEqual(tf_val, np_val)
@@ -223,7 +223,7 @@
x = np.zeros((7, 11))
y = np.ones((7, 11))
np_val = np.where(f < 0, x, y)
- with self.test_session(use_gpu=True):
+ with self.test_session():
tf_val = self.evaluate(
array_ops.where_v2(constant_op.constant(f) < 0, x, y))
self.assertAllEqual(tf_val, np_val)
@@ -232,7 +232,7 @@
x = np.zeros((7, 11))
y = np.ones((7, 11))
np_val = np.where(True, x, y)
- with self.test_session(use_gpu=True):
+ with self.test_session():
tf_val = self.evaluate(
array_ops.where_v2(
constant_op.constant(True, dtype=dtypes.bool), x, y))
@@ -242,7 +242,7 @@
x = np.zeros(7)
y = np.ones(7)
np_val = np.where([True], x, y)
- with self.test_session(use_gpu=True):
+ with self.test_session():
tf_val = self.evaluate(
array_ops.where_v2(
constant_op.constant([True], dtype=dtypes.bool), x, y))
@@ -253,7 +253,7 @@
x = np.random.randn(3, 4)
y = np.random.randn(3, 4)
np_val = np.where(pred, x, y)
- with self.test_session(use_gpu=True):
+ with self.test_session():
tf_val = self.evaluate(array_ops.where_v2(pred, x, y))
self.assertAllClose(tf_val, np_val)
@@ -263,7 +263,7 @@
c_mat = np.array([[False] * 192, [True] * 192] * 8192) # [16384, 192]
c_vec = np.array([False, True] * 8192) # [16384]
np_val = np.where(c_mat, x * x, -x)
- with self.session(use_gpu=True):
+ with self.session():
tf_val = array_ops.where(c_vec, x * x, -x).eval()
self.assertAllEqual(tf_val, np_val)
diff --git a/tensorflow/python/kernel_tests/xent_op_test.py b/tensorflow/python/kernel_tests/xent_op_test.py
index 6e60a93..b1adbd3 100644
--- a/tensorflow/python/kernel_tests/xent_op_test.py
+++ b/tensorflow/python/kernel_tests/xent_op_test.py
@@ -319,7 +319,7 @@
features = np.zeros([0, 2, 4]).astype(np.float32)
labels = np.zeros([0, 2, 4]).astype(np.float32)
np_loss, _ = self._npXent(features, labels)
- with self.session(use_gpu=True) as sess:
+ with self.session():
loss = nn_ops.softmax_cross_entropy_with_logits(
labels=labels, logits=features)
tf_loss = self.evaluate(loss)
diff --git a/tensorflow/python/ops/batch_ops_test.py b/tensorflow/python/ops/batch_ops_test.py
index fb8746e..b63c4c8 100644
--- a/tensorflow/python/ops/batch_ops_test.py
+++ b/tensorflow/python/ops/batch_ops_test.py
@@ -56,7 +56,7 @@
"""Tests that a single batched tensor executes together and only once."""
if context.executing_eagerly():
return
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
batched, index, _ = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=2,
@@ -98,7 +98,7 @@
"""Test that batching with padding up to an allowed batch size works."""
if context.executing_eagerly():
return
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[2])
batched, index, _ = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=10,
@@ -130,7 +130,7 @@
"""Tests that multiple batched tensors execute together."""
if context.executing_eagerly():
return
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
batched, _, _ = batch_ops.batch(
@@ -171,7 +171,7 @@
"""Tests illegally feeding tensors with different dim0 sizes."""
if context.executing_eagerly():
return
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[2])
batched, index, _ = batch_ops.batch(
@@ -187,7 +187,7 @@
"""Tests that batch and unbatch work together."""
if context.executing_eagerly():
return
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
batched, index, id_t = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=10,
@@ -213,7 +213,7 @@
"""Tests that the batch_function decorator works."""
if context.executing_eagerly():
return
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
# TODO(apassos): Removing this line causes test flakiness! Ideally should
# be investigated.
default_inp = array_ops.placeholder_with_default(2, shape=[]) # pylint: disable=unused-variable
@@ -241,7 +241,7 @@
"""Tests that the batch_function decorator works."""
if context.executing_eagerly():
return
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
captured_inp0 = array_ops.placeholder_with_default(2., shape=[])
captured_inp1 = resource_variable_ops.ResourceVariable(3.)
with ops.device("/cpu:0"):
@@ -270,7 +270,7 @@
def testBatchDecoratedGpu(self):
if context.executing_eagerly():
return
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
@batch_ops.batch_function(1, 10, 100000)
def computation(in_t):
@@ -292,6 +292,31 @@
self.assertEqual(thread_results[0], [10 + test_util.is_gpu_available()])
self.assertEqual(main_results[0], [20 + test_util.is_gpu_available()])
+ def testParallelRunsWithCpuAndGpu(self):
+ # Run multiple instances of a batch function in parallel. This is a
+ # regression test: this used to fail because _Send nodes for one call would
+ # send the tensor to the _Recv node for a different call.
+ if context.executing_eagerly():
+ return
+ @batch_ops.batch_function(1, 2, 1)
+ def f(x):
+ with ops.device("/GPU:0"):
+ x = x + 1.
+ with ops.device("/CPU:0"):
+ return x + 1
+ num_calls = 10
+ placeholders = [array_ops.placeholder(dtypes.float32, shape=(1,))
+ for _ in range(num_calls)]
+ results = []
+ for p in placeholders:
+ (result,) = f(p)
+ results.append(result)
+ inputs = [[float(i)] for i in range(num_calls)]
+ expected = [[float(i + 2)] for i in range(num_calls)]
+ with self.session() as sess:
+ outputs = sess.run(results, feed_dict=dict(zip(placeholders, inputs)))
+ self.assertAllEqual(outputs, expected)
+
def testSoftPlacement(self):
if context.executing_eagerly():
return
@@ -324,7 +349,7 @@
"""Tests that the batch_function op works."""
if context.executing_eagerly():
return
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
@function.Defun(dtypes.int32)
def computation(in_t):
@@ -355,7 +380,7 @@
"""Tests that batch_function op works with captured input."""
if context.executing_eagerly():
return
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
@@ -391,7 +416,7 @@
"""Tests that batch_function op works with error in the inputs."""
if context.executing_eagerly():
return
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
@function.Defun(dtypes.int32, dtypes.int32)
@@ -415,13 +440,10 @@
def testBatchFunctionOpWithLargeBatchSplitted(self):
"""Tests that the batch_function op works with large batch splitted."""
- if test_util.is_xla_enabled():
- self.skipTest("b/178649404")
-
if context.executing_eagerly():
return
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
@function.Defun(dtypes.int32)
def computation(in_t):
@@ -475,7 +497,7 @@
"""Tests that the batch_function decorator works."""
if context.executing_eagerly():
return
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
@batch_ops.batch_function(1, 10, 100000)
def computation(in_t):
@@ -499,7 +521,7 @@
"""Tests that the unbatch timeout works."""
if context.executing_eagerly():
return
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
batched, index, id_t = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=2,
diff --git a/tensorflow/python/ops/bitwise_ops_test.py b/tensorflow/python/ops/bitwise_ops_test.py
index d154b67..2a24db8 100644
--- a/tensorflow/python/ops/bitwise_ops_test.py
+++ b/tensorflow/python/ops/bitwise_ops_test.py
@@ -39,7 +39,7 @@
dtype_list = [dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64]
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
for dtype in dtype_list:
lhs = constant_op.constant([0, 5, 3, 14], dtype=dtype)
rhs = constant_op.constant([5, 0, 7, 11], dtype=dtype)
@@ -62,7 +62,7 @@
def count_bits(x):
return sum(bin(z).count("1") for z in six.iterbytes(x.tobytes()))
for dtype in dtype_list:
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
print("PopulationCount test: ", dtype)
inputs = np.array(raw_inputs, dtype=dtype.as_numpy_dtype)
truth = [count_bits(x) for x in inputs]
@@ -76,7 +76,7 @@
dtype_list = [dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64]
inputs = [0, 5, 3, 14]
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
for dtype in dtype_list:
# Because of issues with negative numbers, let's test this indirectly.
# 1. invert(a) and a = 0
@@ -101,7 +101,7 @@
dtype_list = [np.int8, np.int16, np.int32, np.int64,
np.uint8, np.uint16, np.uint32, np.uint64]
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
for dtype in dtype_list:
lhs = np.array([0, 5, 3, 14], dtype=dtype)
rhs = np.array([5, 0, 7, 3], dtype=dtype)
@@ -115,7 +115,7 @@
def testShiftsWithNegativeLHS(self):
dtype_list = [np.int8, np.int16, np.int32, np.int64]
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
for dtype in dtype_list:
lhs = np.array([-1, -5, -3, -14], dtype=dtype)
rhs = np.array([5, 0, 7, 11], dtype=dtype)
@@ -129,7 +129,7 @@
def testImplementationDefinedShiftsDoNotCrash(self):
dtype_list = [np.int8, np.int16, np.int32, np.int64]
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
for dtype in dtype_list:
lhs = np.array([-1, -5, -3, -14], dtype=dtype)
rhs = np.array([-2, 64, 101, 32], dtype=dtype)
@@ -146,7 +146,7 @@
dtype_list = [dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
dtypes.uint8, dtypes.uint16]
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
for dtype in dtype_list:
lhs = constant_op.constant([[0], [3], [5]], dtype=dtype)
rhs = constant_op.constant([[1, 2, 4]], dtype=dtype)
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index 5bd31aa..0d49bb7 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -749,7 +749,7 @@
mat_value = rng.randn(m, m).astype("float32")
x_value = rng.randn(m).astype("float32")
hess_value = mat_value + mat_value.T
- with self.session(use_gpu=True):
+ with self.session():
mat = constant_op.constant(mat_value)
x = constant_op.constant(x_value)
x_mat_x = math_ops.reduce_sum(x[:, None] * mat * x[None, :])
@@ -766,7 +766,7 @@
mat_values = [rng.randn(m, m).astype("float32") for _ in range(n)]
x_values = [rng.randn(m).astype("float32") for _ in range(n)]
hess_values = [mat_value + mat_value.T for mat_value in mat_values]
- with self.session(use_gpu=True):
+ with self.session():
mats = [constant_op.constant(mat_value) for mat_value in mat_values]
xs = [constant_op.constant(x_value) for x_value in x_values]
xs_mats_xs = [
@@ -781,7 +781,7 @@
@test_util.run_v1_only("b/120545219")
def testHessianInvalidDimension(self):
for shape in [(10, 10), None]:
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x = array_ops.placeholder(dtypes.float32, shape)
# Expect a ValueError because the dimensions are wrong
with self.assertRaises(ValueError):
@@ -795,7 +795,7 @@
m = 3
rng = np.random.RandomState([1, 2, 3])
x_value = rng.randn(m, m).astype("float32")
- with self.session(use_gpu=True):
+ with self.session():
x = constant_op.constant(x_value)
x_square = math_ops.reduce_sum(
math_ops.matmul(array_ops.transpose(x), x) * 0.5
@@ -815,7 +815,7 @@
n = 4
rng = np.random.RandomState([1, 2, 3])
x_value = rng.randn(m, n).astype("float32")
- with self.session(use_gpu=True):
+ with self.session():
x = constant_op.constant(x_value)
x_square = math_ops.reduce_sum(
math_ops.matmul(array_ops.transpose(x), x) * 0.5
diff --git a/tensorflow/python/ops/histogram_ops_test.py b/tensorflow/python/ops/histogram_ops_test.py
index 94217d9..da72e3b 100644
--- a/tensorflow/python/ops/histogram_ops_test.py
+++ b/tensorflow/python/ops/histogram_ops_test.py
@@ -109,7 +109,7 @@
value_range = [0.0, 5.0]
values = []
expected_bin_counts = [0, 0, 0, 0, 0]
- with self.session(use_gpu=True):
+ with self.session():
hist = histogram_ops.histogram_fixed_width(values, value_range, nbins=5)
self.assertEqual(dtypes.int32, hist.dtype)
self.assertAllClose(expected_bin_counts, self.evaluate(hist))
@@ -120,7 +120,7 @@
value_range = [0.0, 5.0]
values = [-1.0, 0.0, 1.5, 2.0, 5.0, 15]
expected_bin_counts = [2, 1, 1, 0, 2]
- with self.session(use_gpu=True):
+ with self.session():
hist = histogram_ops.histogram_fixed_width(
values, value_range, nbins=5, dtype=dtypes.int64)
self.assertEqual(dtypes.int64, hist.dtype)
@@ -132,7 +132,7 @@
value_range = np.float64([0.0, 5.0])
values = np.float64([-1.0, 0.0, 1.5, 2.0, 5.0, 15])
expected_bin_counts = [2, 1, 1, 0, 2]
- with self.session(use_gpu=True):
+ with self.session():
hist = histogram_ops.histogram_fixed_width(values, value_range, nbins=5)
self.assertEqual(dtypes.int32, hist.dtype)
self.assertAllClose(expected_bin_counts, self.evaluate(hist))
@@ -143,7 +143,7 @@
value_range = [0.0, 5.0]
values = [[-1.0, 0.0, 1.5], [2.0, 5.0, 15]]
expected_bin_counts = [2, 1, 1, 0, 2]
- with self.session(use_gpu=True):
+ with self.session():
hist = histogram_ops.histogram_fixed_width(values, value_range, nbins=5)
self.assertEqual(dtypes.int32, hist.dtype)
self.assertAllClose(expected_bin_counts, self.evaluate(hist))
@@ -154,7 +154,7 @@
values = [[-1.0, 0.0, 1.5], [2.0, 5.0, 15]]
expected_bin_counts = [2, 1, 1, 0, 2]
placeholder = array_ops.placeholder(dtypes.int32)
- with self.session(use_gpu=True):
+ with self.session():
hist = histogram_ops.histogram_fixed_width(values, value_range, nbins=5)
self.assertAllEqual(hist.shape.as_list(), (5,))
self.assertEqual(dtypes.int32, hist.dtype)
diff --git a/tensorflow/python/ops/image_grad_test_base.py b/tensorflow/python/ops/image_grad_test_base.py
index 92cdf5d..5982176 100644
--- a/tensorflow/python/ops/image_grad_test_base.py
+++ b/tensorflow/python/ops/image_grad_test_base.py
@@ -50,7 +50,7 @@
input_tensor = constant_op.constant(x, shape=in_shape)
resize_out = image_ops.resize_nearest_neighbor(input_tensor,
out_shape[1:3])
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
self.assertEqual(out_shape, list(resize_out.get_shape()))
resize_out = self.evaluate(resize_out)
self.assertEqual(out_shape, list(resize_out.shape))
@@ -65,7 +65,7 @@
def resize_nn(t, shape=out_shape):
return image_ops.resize_nearest_neighbor(t, shape[1:3])
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
input_tensor = constant_op.constant(x, shape=in_shape)
err = gradient_checker_v2.max_error(
*gradient_checker_v2.compute_gradient(
@@ -82,7 +82,7 @@
def resize_nn(t, shape=out_shape):
return image_ops.resize_nearest_neighbor(t, shape[1:3])
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
input_tensor = constant_op.constant(x, shape=in_shape)
err = gradient_checker_v2.max_error(
*gradient_checker_v2.compute_gradient(
@@ -106,7 +106,7 @@
grad_cpu = gradient_checker_v2.compute_gradient(
resize_nn, [input_tensor], delta=1 / 8)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
input_tensor = constant_op.constant(x, shape=in_shape)
grad_gpu = gradient_checker_v2.compute_gradient(
resize_nn, [input_tensor], delta=1 / 8)
@@ -444,7 +444,7 @@
constant_op.constant(boxes, shape=[num_boxes, 4]),
constant_op.constant(box_ind, shape=[num_boxes]),
constant_op.constant(crop_size, shape=[2]))
- with self.session(use_gpu=True) as sess:
+ with self.session():
self.assertEqual(crops_shape, list(crops.get_shape()))
crops = self.evaluate(crops)
self.assertEqual(crops_shape, list(crops.shape))
@@ -561,7 +561,7 @@
x = np.random.randint(0, high=255, size=[2, 20, 30, 3]).astype(nptype)
rgb_input_tensor = constant_op.constant(x, shape=in_shape)
hsv_out = gen_image_ops.rgb_to_hsv(rgb_input_tensor)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
self.assertEqual(out_shape, list(hsv_out.get_shape()))
hsv_out = self.evaluate(hsv_out)
self.assertEqual(out_shape, list(hsv_out.shape))
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index 11720b2..aec150c 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -901,10 +901,17 @@
"""
with ops.name_scope(None, 'central_crop', [image]):
image = ops.convert_to_tensor(image, name='image')
- if central_fraction <= 0.0 or central_fraction > 1.0:
- raise ValueError('central_fraction must be within (0, 1]')
- if central_fraction == 1.0:
- return image
+ central_fraction_static = tensor_util.constant_value(central_fraction)
+ if central_fraction_static is not None:
+ if central_fraction_static <= 0.0 or central_fraction_static > 1.0:
+ raise ValueError('central_fraction must be within (0, 1]')
+ if central_fraction_static == 1.0:
+ return image
+ else:
+ assert_ops = _assert(
+ math_ops.logical_or(central_fraction > 0.0, central_fraction <= 1.0),
+ ValueError, 'central_fraction must be within (0, 1]')
+ image = control_flow_ops.with_dependencies(assert_ops, image)
_AssertAtLeast3DImage(image)
rank = image.get_shape().ndims
@@ -932,24 +939,29 @@
img_w, dynamic_w = _get_dim(image, 2)
img_d = image.get_shape()[3]
+ dynamic_h = dynamic_h or (central_fraction_static is None)
+ dynamic_w = dynamic_w or (central_fraction_static is None)
+
# Compute the bounding boxes for the crop. The type and value of the
# bounding boxes depend on the `image` tensor's rank and whether / not the
# dimensions are statically defined.
if dynamic_h:
img_hd = math_ops.cast(img_h, dtypes.float64)
- bbox_h_start = math_ops.cast((img_hd - img_hd * central_fraction) / 2,
- dtypes.int32)
+ bbox_h_start = math_ops.cast(
+ (img_hd - img_hd * math_ops.cast(central_fraction, dtypes.float64)) /
+ 2, dtypes.int32)
else:
img_hd = float(img_h)
- bbox_h_start = int((img_hd - img_hd * central_fraction) / 2)
+ bbox_h_start = int((img_hd - img_hd * central_fraction_static) / 2)
if dynamic_w:
img_wd = math_ops.cast(img_w, dtypes.float64)
- bbox_w_start = math_ops.cast((img_wd - img_wd * central_fraction) / 2,
- dtypes.int32)
+ bbox_w_start = math_ops.cast(
+ (img_wd - img_wd * math_ops.cast(central_fraction, dtypes.float64)) /
+ 2, dtypes.int32)
else:
img_wd = float(img_w)
- bbox_w_start = int((img_wd - img_wd * central_fraction) / 2)
+ bbox_w_start = int((img_wd - img_wd * central_fraction_static) / 2)
bbox_h_size = img_h - bbox_h_start * 2
bbox_w_size = img_w - bbox_w_start * 2
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index 7b477aa..fdcbc68 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -71,7 +71,7 @@
inp = np.random.rand(*shape).astype(nptype)
# Convert to HSV and back, as a batch and individually
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
batch0 = constant_op.constant(inp)
batch1 = image_ops.rgb_to_hsv(batch0)
batch2 = image_ops.hsv_to_rgb(batch1)
@@ -92,7 +92,7 @@
data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
for nptype in [np.float32, np.float64]:
rgb_np = np.array(data, dtype=nptype).reshape([2, 2, 3]) / 255.
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
hsv = image_ops.rgb_to_hsv(rgb_np)
rgb = image_ops.hsv_to_rgb(hsv)
rgb_tf = self.evaluate(rgb)
@@ -113,7 +113,7 @@
inp = np.random.rand(*shape).astype(nptype)
# Convert to YIQ and back, as a batch and individually
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
batch0 = constant_op.constant(inp)
batch1 = image_ops.rgb_to_yiq(batch0)
batch2 = image_ops.yiq_to_rgb(batch1)
@@ -145,7 +145,7 @@
inp = np.random.rand(*shape).astype(nptype)
# Convert to YUV and back, as a batch and individually
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
batch0 = constant_op.constant(inp)
batch1 = image_ops.rgb_to_yuv(batch0)
batch2 = image_ops.yuv_to_rgb(batch1)
@@ -187,7 +187,7 @@
def _TestRGBToGrayscale(self, x_np):
y_np = self._RGBToGrayscale(x_np)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x_tf = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.rgb_to_grayscale(x_tf)
y_tf = self.evaluate(y)
@@ -209,7 +209,7 @@
y_np = np.array(
[[1, 1, 1], [2, 2, 2]], dtype=np.uint8).reshape([1, 1, 2, 3])
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x_tf = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.grayscale_to_rgb(x_tf)
y_tf = self.evaluate(y)
@@ -219,7 +219,7 @@
x_np = np.array([[1, 2]], dtype=np.uint8).reshape([1, 2, 1])
y_np = np.array([[1, 1, 1], [2, 2, 2]], dtype=np.uint8).reshape([1, 2, 3])
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x_tf = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.grayscale_to_rgb(x_tf)
y_tf = self.evaluate(y)
@@ -233,7 +233,7 @@
# tests if an exception is raised if a three dimensional
# input is used, i.e. the images have shape [batch size, height, width]
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
# 3-D input with batch dimension.
x_np = np.array([[1, 2]], dtype=np.uint8).reshape([1, 1, 2])
@@ -246,7 +246,7 @@
# tests if an exception is raised if a two dimensional
# input is used, i.e. the images have shape [height, width]
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
# 1-D input without batch dimension.
x_np = np.array([[1, 2]], dtype=np.uint8).reshape([2])
@@ -263,23 +263,23 @@
# Shape inference works and produces expected output where possible
rgb_shape = [7, None, 19, 3]
gray_shape = rgb_shape[:-1] + [1]
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
rgb_tf = array_ops.placeholder(dtypes.uint8, shape=rgb_shape)
gray = image_ops.rgb_to_grayscale(rgb_tf)
self.assertEqual(gray_shape, gray.get_shape().as_list())
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
gray_tf = array_ops.placeholder(dtypes.uint8, shape=gray_shape)
rgb = image_ops.grayscale_to_rgb(gray_tf)
self.assertEqual(rgb_shape, rgb.get_shape().as_list())
# Shape inference does not break for unknown shapes
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
rgb_tf_unknown = array_ops.placeholder(dtypes.uint8)
gray_unknown = image_ops.rgb_to_grayscale(rgb_tf_unknown)
self.assertFalse(gray_unknown.get_shape())
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
gray_tf_unknown = array_ops.placeholder(dtypes.uint8)
rgb_unknown = image_ops.grayscale_to_rgb(gray_tf_unknown)
self.assertFalse(rgb_unknown.get_shape())
@@ -424,7 +424,7 @@
y_data = [0, 13, 1, 54, 226, 59, 8, 234, 150, 255, 39, 1]
y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x = constant_op.constant(x_np, shape=x_shape)
y = image_ops.adjust_hue(x, delta)
y_tf = self.evaluate(y)
@@ -439,7 +439,7 @@
y_data = [13, 0, 11, 226, 54, 221, 234, 8, 92, 1, 217, 255]
y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x = constant_op.constant(x_np, shape=x_shape)
y = image_ops.adjust_hue(x, delta)
y_tf = self.evaluate(y)
@@ -454,7 +454,7 @@
y_data = [13, 0, 11, 226, 54, 221, 234, 8, 92, 1, 217, 255]
y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x = constant_op.constant(x_np, shape=x_shape)
y = image_ops.adjust_hue(x, delta)
y_tf = self.evaluate(y)
@@ -479,7 +479,7 @@
return y_v.reshape(x_np.shape)
def _adjustHueTf(self, x_np, delta_h):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x = constant_op.constant(x_np)
y = image_ops.adjust_hue(x, delta_h)
y_tf = self.evaluate(y)
@@ -910,7 +910,7 @@
y_data = [6, 9, 13, 140, 180, 226, 135, 121, 234, 172, 255, 128]
y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x = constant_op.constant(x_np, shape=x_shape)
y = image_ops.adjust_saturation(x, saturation_factor)
y_tf = self.evaluate(y)
@@ -925,7 +925,7 @@
y_data = [0, 5, 13, 0, 106, 226, 30, 0, 234, 89, 255, 0]
y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x = constant_op.constant(x_np, shape=x_shape)
y = image_ops.adjust_saturation(x, saturation_factor)
y_tf = self.evaluate(y)
@@ -940,7 +940,7 @@
y_data = [6, 9, 13, 140, 180, 226, 135, 121, 234, 172, 255, 128]
y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x = constant_op.constant(x_np, shape=x_shape)
y = image_ops.adjust_saturation(x, saturation_factor)
y_tf = self.evaluate(y)
@@ -979,7 +979,7 @@
"gb_same",
"rgb_same",
]
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
for x_shape in x_shapes:
for test_style in test_styles:
x_np = np.random.rand(*x_shape) * 255.
@@ -1007,7 +1007,7 @@
def testInvolutionLeftRight(self):
x_np = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.uint8).reshape([2, 3, 1])
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x_tf = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.flip_left_right(image_ops.flip_left_right(x_tf))
y_tf = self.evaluate(y)
@@ -1017,7 +1017,7 @@
x_np = np.array(
[[[1, 2, 3], [1, 2, 3]], [[1, 2, 3], [1, 2, 3]]],
dtype=np.uint8).reshape([2, 2, 3, 1])
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x_tf = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.flip_left_right(image_ops.flip_left_right(x_tf))
y_tf = self.evaluate(y)
@@ -1027,7 +1027,7 @@
x_np = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.uint8).reshape([2, 3, 1])
y_np = np.array([[3, 2, 1], [3, 2, 1]], dtype=np.uint8).reshape([2, 3, 1])
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x_tf = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.flip_left_right(x_tf)
y_tf = self.evaluate(y)
@@ -1041,7 +1041,7 @@
[[[3, 2, 1], [3, 2, 1]], [[3, 2, 1], [3, 2, 1]]],
dtype=np.uint8).reshape([2, 2, 3, 1])
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x_tf = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.flip_left_right(x_tf)
y_tf = self.evaluate(y)
@@ -1054,7 +1054,7 @@
y_np = np.array([[3, 2, 1], [3, 2, 1]], dtype=np.uint8).reshape([2, 3, 1])
seed = 42
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x_tf = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.random_flip_left_right(x_tf, seed=seed)
self.assertTrue(y.op.name.startswith("random_flip_left_right"))
@@ -1081,7 +1081,7 @@
x_np = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.uint8).reshape([2, 3, 1])
y_np = np.array([[3, 2, 1], [3, 2, 1]], dtype=np.uint8).reshape([2, 3, 1])
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x_tf = constant_op.constant(x_np, shape=x_np.shape)
count_flipped = 0
count_unflipped = 0
@@ -1216,7 +1216,7 @@
x_np = np.vstack([x_np_raw for _ in range(batch_size)])
y_np = np.vstack([y_np_raw for _ in range(batch_size)])
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x_tf = constant_op.constant(x_np, shape=x_np.shape)
count_flipped = 0
count_unflipped = 0
@@ -1238,7 +1238,7 @@
def testInvolutionUpDown(self):
x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1])
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x_tf = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.flip_up_down(image_ops.flip_up_down(x_tf))
y_tf = self.evaluate(y)
@@ -1249,7 +1249,7 @@
[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
dtype=np.uint8).reshape([2, 2, 3, 1])
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x_tf = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.flip_up_down(image_ops.flip_up_down(x_tf))
y_tf = self.evaluate(y)
@@ -1259,7 +1259,7 @@
x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1])
y_np = np.array([[4, 5, 6], [1, 2, 3]], dtype=np.uint8).reshape([2, 3, 1])
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x_tf = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.flip_up_down(x_tf)
y_tf = self.evaluate(y)
@@ -1273,7 +1273,7 @@
[[[4, 5, 6], [1, 2, 3]], [[10, 11, 12], [7, 8, 9]]],
dtype=np.uint8).reshape([2, 2, 3, 1])
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x_tf = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.flip_up_down(x_tf)
y_tf = self.evaluate(y)
@@ -1286,7 +1286,7 @@
y_np = np.array([[4, 5, 6], [1, 2, 3]], dtype=np.uint8).reshape([2, 3, 1])
seed = 42
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x_tf = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.random_flip_up_down(x_tf, seed=seed)
self.assertTrue(y.op.name.startswith("random_flip_up_down"))
@@ -1312,7 +1312,7 @@
x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1])
y_np = np.array([[4, 5, 6], [1, 2, 3]], dtype=np.uint8).reshape([2, 3, 1])
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x_tf = constant_op.constant(x_np, shape=x_np.shape)
count_flipped = 0
count_unflipped = 0
@@ -1344,7 +1344,7 @@
x_np = np.vstack([x_np_raw for _ in range(batch_size)])
y_np = np.vstack([y_np_raw for _ in range(batch_size)])
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x_tf = constant_op.constant(x_np, shape=x_np.shape)
count_flipped = 0
count_unflipped = 0
@@ -1366,7 +1366,7 @@
def testInvolutionTranspose(self):
x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1])
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x_tf = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.transpose(image_ops.transpose(x_tf))
y_tf = self.evaluate(y)
@@ -1377,7 +1377,7 @@
[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
dtype=np.uint8).reshape([2, 2, 3, 1])
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x_tf = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.transpose(image_ops.transpose(x_tf))
y_tf = self.evaluate(y)
@@ -1387,7 +1387,7 @@
x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1])
y_np = np.array([[1, 4], [2, 5], [3, 6]], dtype=np.uint8).reshape([3, 2, 1])
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x_tf = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.transpose(x_tf)
y_tf = self.evaluate(y)
@@ -1402,7 +1402,7 @@
[[[1, 4], [2, 5], [3, 6]], [[7, 10], [8, 11], [9, 12]]],
dtype=np.uint8).reshape([2, 3, 2, 1])
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x_tf = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.transpose(x_tf)
y_tf = self.evaluate(y)
@@ -1454,7 +1454,7 @@
def testRot90GroupOrder(self):
image = np.arange(24, dtype=np.uint8).reshape([2, 4, 3])
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
rotated = image
for _ in xrange(4):
rotated = image_ops.rot90(rotated)
@@ -1462,7 +1462,7 @@
def testRot90GroupOrderWithBatch(self):
image = np.arange(48, dtype=np.uint8).reshape([2, 2, 4, 3])
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
rotated = image
for _ in xrange(4):
rotated = image_ops.rot90(rotated)
@@ -1470,7 +1470,7 @@
def testRot90NumpyEquivalence(self):
image = np.arange(24, dtype=np.uint8).reshape([2, 4, 3])
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
for k in xrange(4):
y_np = np.rot90(image, k=k)
self.assertAllEqual(
@@ -1478,7 +1478,7 @@
def testRot90NumpyEquivalenceWithBatch(self):
image = np.arange(48, dtype=np.uint8).reshape([2, 2, 4, 3])
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
for k in xrange(4):
y_np = np.rot90(image, k=k, axes=(1, 2))
self.assertAllEqual(
@@ -1507,7 +1507,7 @@
class AdjustContrastTest(test_util.TensorFlowTestCase):
def _testContrast(self, x_np, y_np, contrast_factor):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.adjust_contrast(x, contrast_factor)
y_tf = self.evaluate(y)
@@ -1562,7 +1562,7 @@
return y_np
def _adjustContrastTf(self, x_np, contrast_factor):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x = constant_op.constant(x_np)
y = image_ops.adjust_contrast(x, contrast_factor)
y_tf = self.evaluate(y)
@@ -1596,7 +1596,7 @@
class AdjustBrightnessTest(test_util.TensorFlowTestCase):
def _testBrightness(self, x_np, y_np, delta, tol=1e-6):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.adjust_brightness(x, delta)
y_tf = self.evaluate(y)
@@ -1668,7 +1668,7 @@
x_np = np.arange(0, np.prod(x_shape), dtype=data_type).reshape(x_shape)
y_np = self._NumpyPerImageWhitening(x_np)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x = constant_op.constant(x_np, dtype=data_type, shape=x_shape)
y = image_ops.per_image_standardization(x)
y_tf = self.evaluate(y)
@@ -1678,14 +1678,14 @@
im_np = np.ones([19, 19, 3]).astype(np.float32) * 249
im = constant_op.constant(im_np)
whiten = image_ops.per_image_standardization(im)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
whiten_np = self.evaluate(whiten)
self.assertFalse(np.any(np.isnan(whiten_np)))
def testBatchWhitening(self):
imgs_np = np.random.uniform(0., 255., [4, 24, 24, 3])
whiten_np = [self._NumpyPerImageWhitening(img) for img in imgs_np]
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
imgs = constant_op.constant(imgs_np)
whiten = image_ops.per_image_standardization(imgs)
whiten_tf = self.evaluate(whiten)
@@ -1709,7 +1709,7 @@
y = image_ops.crop_to_bounding_box(x_tensor, offset_height, offset_width,
target_height, target_width)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
return self.evaluate(y)
def _assertReturns(self,
@@ -1910,7 +1910,7 @@
dtype=np.int32).reshape(x_shape)
y_np = np.array([[[3, 4, 5, 6], [3, 4, 5, 6]],
[[6, 5, 4, 3], [6, 5, 4, 3]]]).reshape([2, 2, 4, 1])
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
x = constant_op.constant(x_np, shape=x_shape)
y = image_ops.central_crop(x, 0.5)
y_tf = self.evaluate(y)
@@ -2003,6 +2003,21 @@
y = image_ops.central_crop(x_np, 1.0)
self.assertTrue(y.op.name.startswith("central_crop"))
+ def testCentralFractionTensor(self):
+ # Test case for GitHub issue 45324.
+ x_shape = [240, 320, 3]
+ y_shape = [80, 106, 3]
+
+ @def_function.function(autograph=False)
+ def f(x, central_fraction):
+ return image_ops.central_crop(x, central_fraction)
+
+ x_np = np.zeros(x_shape, dtype=np.int32)
+ y_np = np.zeros(y_shape, dtype=np.int32)
+ y_tf = self.evaluate(f(x_np, constant_op.constant(0.33)))
+ self.assertAllEqual(y_tf, y_np)
+ self.assertAllEqual(y_tf.shape, y_np.shape)
+
class PadToBoundingBoxTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
@@ -2022,7 +2037,7 @@
def pad_bbox(*args):
return image_ops.pad_to_bounding_box(*args)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
return self.evaluate(pad_bbox(x_tensor, offset_height, offset_width,
target_height, target_width))
@@ -2079,7 +2094,7 @@
i = constant_op.constant([1, 0, 4, 3], dtype=dtypes.int64)
y_tf = image_ops.pad_to_bounding_box(x, i[0], i[1], i[2], i[3])
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
self.assertAllClose(y, self.evaluate(y_tf))
def testNoOp(self):
@@ -2259,7 +2274,7 @@
fraction_object_covered = []
num_iter = 1000
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
image_tf = constant_op.constant(image, shape=image.shape)
image_size_tf = constant_op.constant(
image_size_np, shape=image_size_np.shape)
@@ -2386,7 +2401,7 @@
def testSampleDistortedBoundingBoxShape(self):
# Shape function requires placeholders and a graph.
with ops.Graph().as_default():
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
image_size = constant_op.constant(
[40, 50, 1], shape=[3], dtype=dtypes.int32)
bounding_box = constant_op.constant(
@@ -2424,7 +2439,7 @@
def testDefaultMinObjectCovered(self):
# By default min_object_covered=0.1 if not provided
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
image_size = constant_op.constant(
[40, 50, 1], shape=[3], dtype=dtypes.int32)
bounding_box = constant_op.constant(
@@ -2651,7 +2666,7 @@
img_np = np.array(data, dtype=nptype).reshape(img_shape)
for method in self.METHODS:
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
image = constant_op.constant(img_np, shape=img_shape)
y = image_ops.resize_images_v2(image, [target_height, target_width],
method)
@@ -2662,7 +2677,7 @@
self.assertAllClose(resized, img_np, atol=1e-5)
# Resizing with a single image must leave the shape unchanged also.
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
img_single = img_np.reshape(single_shape)
image = constant_op.constant(img_single, shape=single_shape)
y = image_ops.resize_images_v2(image, [target_height, target_width],
@@ -2688,7 +2703,7 @@
img_np = np.array(data, dtype=np.uint8).reshape(img_shape)
for method in self.METHODS:
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
image = constant_op.constant(img_np, shape=img_shape)
y = resize_func(image, [6, 4], method)
yshape = array_ops.shape(y)
@@ -2698,7 +2713,7 @@
self.assertAllClose(resized, img_np, atol=1e-5)
# Resizing with a single image must leave the shape unchanged also.
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
img_single = img_np.reshape(single_shape)
image = constant_op.constant(img_single, shape=single_shape)
y = resize_func(image, [6, 4], self.METHODS[0])
@@ -2831,7 +2846,7 @@
for method in self.METHODS:
if test.is_gpu_available() and self.shouldRunOnGPU(method, nptype):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
image = constant_op.constant(img_np, shape=img_shape)
y = image_ops.resize_images_v2(
image, [target_height, target_width], method)
@@ -2888,7 +2903,7 @@
]
for nptype in self.TYPES:
for method in expected_data:
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
img_np = np.array(data, dtype=nptype).reshape(img_shape)
image = constant_op.constant(img_np, shape=img_shape)
y = image_ops.resize_images_v2(image, [target_height, target_width],
@@ -2908,7 +2923,7 @@
methods_to_test = ((gen_image_ops.resize_bilinear, "triangle"),
(gen_image_ops.resize_bicubic, "keyscubic"))
for legacy_method, new_method in methods_to_test:
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
img_np = np.array(data, dtype=np.float32).reshape(img_shape)
image = constant_op.constant(img_np, shape=img_shape)
legacy_result = legacy_method(
@@ -2945,7 +2960,7 @@
73, 33, 23, 39, 73, 33, 23, 39, 14, 16, 19, 21, 14, 16, 19, 21
]
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
image = constant_op.constant(img_np, shape=img_shape)
y = image_ops.resize_images_v2(image, [target_height, target_width],
image_ops.ResizeMethod.AREA)
@@ -2963,7 +2978,7 @@
for nptype in [np.float32, np.float64]:
img_np = np.arange(
0, np.prod(input_shape), dtype=nptype).reshape(input_shape)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
image = constant_op.constant(img_np, shape=input_shape)
new_size = constant_op.constant([target_height, target_width])
out_op = image_ops.resize_images_v2(
@@ -3039,7 +3054,7 @@
def testNameScope(self):
# Testing name scope requires placeholders and a graph.
with ops.Graph().as_default():
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
single_image = array_ops.placeholder(dtypes.float32, shape=[50, 60, 3])
y = image_ops.resize_images(single_image, [55, 66])
self.assertTrue(y.op.name.startswith("resize"))
@@ -3060,7 +3075,7 @@
t, ops.convert_to_tensor(target_max),
preserve_aspect_ratio=preserve_aspect_ratio)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
return self.evaluate(resize_func(x_tensor))
def _assertResizeEqual(self,
@@ -3199,7 +3214,7 @@
img_np = np.array(data, dtype=nptype).reshape(img_shape)
for method in self.METHODS:
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
image = constant_op.constant(img_np, shape=img_shape)
y = image_ops.resize_images(image, [target_height, target_width],
method)
@@ -3209,7 +3224,7 @@
self.assertAllClose(resized, img_np, atol=1e-5)
# Resizing with a single image must leave the shape unchanged also.
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
img_single = img_np.reshape(single_shape)
image = constant_op.constant(img_single, shape=single_shape)
y = image_ops.resize_images(image, [target_height, target_width],
@@ -3234,7 +3249,7 @@
img_np = np.array(data, dtype=np.uint8).reshape(img_shape)
for method in self.METHODS:
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
image = constant_op.constant(img_np, shape=img_shape)
y = resize_func(image, [6, 4], method)
yshape = array_ops.shape(y)
@@ -3243,7 +3258,7 @@
self.assertAllClose(resized, img_np, atol=1e-5)
# Resizing with a single image must leave the shape unchanged also.
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
img_single = img_np.reshape(single_shape)
image = constant_op.constant(img_single, shape=single_shape)
y = resize_func(image, [6, 4], self.METHODS[0])
@@ -3374,7 +3389,7 @@
for method in self.METHODS:
if test.is_gpu_available() and self.shouldRunOnGPU(method, nptype):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
image = constant_op.constant(img_np, shape=img_shape)
y = image_ops.resize_images(image, [target_height, target_width],
method)
@@ -3411,7 +3426,7 @@
image_ops.ResizeMethodV1.NEAREST_NEIGHBOR,
image_ops.ResizeMethodV1.AREA
]:
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
img_np = np.array(data, dtype=nptype).reshape(img_shape)
image = constant_op.constant(img_np, shape=img_shape)
y = image_ops.resize_images(
@@ -3448,7 +3463,7 @@
image_ops.ResizeMethodV1.NEAREST_NEIGHBOR,
image_ops.ResizeMethodV1.AREA
]:
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
img_np = np.array(data, dtype=nptype).reshape(img_shape)
image = constant_op.constant(img_np, shape=img_shape)
y = image_ops.resize_images(
@@ -3476,7 +3491,7 @@
75, 81, 80, 72, 69, 70, 105, 112, 75, 36, 45, 92, 111, 105
]
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
image = constant_op.constant(img_np, shape=img_shape)
y = image_ops.resize_images(image, [target_height, target_width],
image_ops.ResizeMethodV1.BICUBIC)
@@ -3499,7 +3514,7 @@
73, 33, 23, 39, 73, 33, 23, 39, 14, 16, 19, 21, 14, 16, 19, 21
]
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
image = constant_op.constant(img_np, shape=img_shape)
y = image_ops.resize_images(image, [target_height, target_width],
image_ops.ResizeMethodV1.AREA)
@@ -3518,7 +3533,7 @@
for align_corners in [True, False]:
img_np = np.arange(
0, np.prod(input_shape), dtype=nptype).reshape(input_shape)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
image = constant_op.constant(img_np, shape=input_shape)
new_size = constant_op.constant([target_height, target_width])
out_op = image_ops.resize_images(
@@ -3586,7 +3601,7 @@
# Testing name scope requires placeholders and a graph.
with ops.Graph().as_default():
img_shape = [1, 3, 2, 1]
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
single_image = array_ops.placeholder(dtypes.float32, shape=[50, 60, 3])
y = image_ops.resize_images(single_image, [55, 66])
self.assertTrue(y.op.name.startswith("resize"))
@@ -3603,7 +3618,7 @@
y = image_ops.resize_images(
x_tensor, target_max, preserve_aspect_ratio=preserve_aspect_ratio)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
return self.evaluate(y)
def _assertResizeEqual(self, x, x_shape, y, y_shape,
@@ -3687,7 +3702,7 @@
else:
x_tensor = x
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
return self.evaluate(
image_ops.resize_image_with_pad_v1(x_tensor, target_height,
target_width))
@@ -3807,7 +3822,7 @@
else:
x_tensor = x
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
return self.evaluate(
image_ops.resize_image_with_pad_v2(x_tensor, target_height,
target_width))
@@ -3929,7 +3944,7 @@
def resize_crop_or_pad(*args):
return image_ops.resize_image_with_crop_or_pad(*args)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
return self.evaluate(
resize_crop_or_pad(x_tensor, target_height, target_width))
@@ -4176,7 +4191,7 @@
# Read a real jpeg and verify shape
path = ("tensorflow/core/lib/jpeg/testdata/"
"jpeg_merge_test1.jpg")
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
jpeg0 = io_ops.read_file(path)
image0 = image_ops.decode_jpeg(jpeg0)
image1 = image_ops.decode_jpeg(image_ops.encode_jpeg(image0))
@@ -4192,7 +4207,7 @@
cmyk_path = os.path.join(base, "jpeg_merge_test1_cmyk.jpg")
shape = 256, 128, 3
for channels in 3, 0:
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
rgb = image_ops.decode_jpeg(
io_ops.read_file(rgb_path), channels=channels)
cmyk = image_ops.decode_jpeg(
@@ -4248,7 +4263,7 @@
self.evaluate(result)
def testSynthetic(self):
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
# Encode it, then decode it, then encode it
image0 = constant_op.constant(simple_color_ramp())
jpeg0 = image_ops.encode_jpeg(image0)
@@ -4269,7 +4284,7 @@
self.assertLessEqual(len(jpeg0), 6000)
def testSyntheticFasterAlgorithm(self):
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
# Encode it, then decode it, then encode it
image0 = constant_op.constant(simple_color_ramp())
jpeg0 = image_ops.encode_jpeg(image0)
@@ -4293,7 +4308,7 @@
self.assertLessEqual(len(jpeg0), 6000)
def testDefaultDCTMethodIsIntegerFast(self):
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
# Compare decoding with both dct_option=INTEGER_FAST and
# default. They should be the same.
image0 = constant_op.constant(simple_color_ramp())
@@ -4308,7 +4323,7 @@
def testShape(self):
# Shape function requires placeholders and a graph.
with ops.Graph().as_default():
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
jpeg = constant_op.constant("nonsense")
for channels in 0, 1, 3:
image = image_ops.decode_jpeg(jpeg, channels=channels)
@@ -4319,7 +4334,7 @@
# Read a real jpeg and verify shape.
path = ("tensorflow/core/lib/jpeg/testdata/"
"jpeg_merge_test1.jpg")
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
jpeg = io_ops.read_file(path)
# Extract shape without decoding.
image_shape = self.evaluate(image_ops.extract_jpeg_shape(jpeg))
@@ -4329,7 +4344,7 @@
# Read a cmyk jpeg image, and verify its shape.
path = ("tensorflow/core/lib/jpeg/testdata/"
"jpeg_merge_test1_cmyk.jpg")
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
jpeg = io_ops.read_file(path)
image_shape = self.evaluate(image_ops.extract_jpeg_shape(jpeg))
# Cmyk jpeg image has 4 channels.
@@ -4346,7 +4361,7 @@
jpeg = io_ops.read_file(path)
image = image_ops.decode_jpeg(jpeg)
random_jpeg_image = image_ops.random_jpeg_quality(image, 40, 100)
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
# Test randomization.
random_jpeg_images = [sess.run(random_jpeg_image) for _ in range(5)]
are_images_equal = []
@@ -4398,11 +4413,11 @@
image = image_ops.decode_jpeg(jpeg)
adjust_jpeg_quality_image = image_ops.adjust_jpeg_quality(
image, jpeg_quality)
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
sess.run(adjust_jpeg_quality_image)
def testAdjustJpegQualityShape(self):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
image = constant_op.constant(
np.arange(24, dtype=np.uint8).reshape([2, 4, 3]))
adjusted_image = image_ops.adjust_jpeg_quality(image, 80)
@@ -4418,7 +4433,7 @@
(3, "lena_palette.png"), (4, "lena_palette_trns.png"))
for channels_in, filename in inputs:
for channels in 0, 1, 3, 4:
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
png0 = io_ops.read_file(prefix + filename)
image0 = image_ops.decode_png(png0, channels=channels)
png0, image0 = self.evaluate([png0, image0])
@@ -4428,7 +4443,7 @@
self.assertAllEqual(image0, self.evaluate(image1))
def testSynthetic(self):
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
# Encode it, then decode it
image0 = constant_op.constant(simple_color_ramp())
png0 = image_ops.encode_png(image0, compression=7)
@@ -4443,7 +4458,7 @@
self.assertLessEqual(len(png0), 750)
def testSyntheticUint16(self):
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
# Encode it, then decode it
image0 = constant_op.constant(simple_color_ramp(), dtype=dtypes.uint16)
png0 = image_ops.encode_png(image0, compression=7)
@@ -4458,7 +4473,7 @@
self.assertLessEqual(len(png0), 1500)
def testSyntheticTwoChannel(self):
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
# Strip the b channel from an rgb image to get a two-channel image.
gray_alpha = simple_color_ramp()[:, :, 0:2]
image0 = constant_op.constant(gray_alpha)
@@ -4469,7 +4484,7 @@
self.assertAllEqual(image0, image1)
def testSyntheticTwoChannelUint16(self):
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
# Strip the b channel from an rgb image to get a two-channel image.
gray_alpha = simple_color_ramp()[:, :, 0:2]
image0 = constant_op.constant(gray_alpha, dtype=dtypes.uint16)
@@ -4482,7 +4497,7 @@
def testShape(self):
# Shape function requires placeholders and a graph.
with ops.Graph().as_default():
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
png = constant_op.constant("nonsense")
for channels in 0, 1, 3:
image = image_ops.decode_png(png, channels=channels)
@@ -4500,7 +4515,7 @@
STRIDE = 5
shape = (12, HEIGHT, WIDTH, 3)
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
gif0 = io_ops.read_file(prefix + filename)
image0 = image_ops.decode_gif(gif0)
gif0, image0 = self.evaluate([gif0, image0])
@@ -4528,14 +4543,14 @@
def testShape(self):
# Shape function requires placeholders and a graph.
with ops.Graph().as_default():
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
gif = constant_op.constant("nonsense")
image = image_ops.decode_gif(gif)
self.assertEqual(image.get_shape().as_list(), [None, None, None, 3])
def testAnimatedGif(self):
# Test if all frames in the animated GIF file is properly decoded.
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
base = "tensorflow/core/lib/gif/testdata"
gif = io_ops.read_file(os.path.join(base, "pendulum_sm.gif"))
gt_frame0 = io_ops.read_file(os.path.join(base, "pendulum_sm_frame0.png"))
@@ -4560,7 +4575,7 @@
x_np = np.array(original, dtype=original_dtype.as_numpy_dtype())
y_np = np.array(expected, dtype=output_dtype.as_numpy_dtype())
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
image = constant_op.constant(x_np)
y = image_ops.convert_image_dtype(image, output_dtype)
self.assertTrue(y.dtype == output_dtype)
@@ -4577,7 +4592,7 @@
# Tests with Tensor.op requires a graph.
with ops.Graph().as_default():
# Make sure converting to the same data type creates only an identity op
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
image = constant_op.constant([1], dtype=dtypes.uint8)
image_ops.convert_image_dtype(image, dtypes.uint8)
y = image_ops.convert_image_dtype(image, dtypes.uint8)
@@ -4586,7 +4601,7 @@
def testConvertBetweenInteger(self):
# Make sure converting to between integer types scales appropriately
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
self._convert([0, 255], dtypes.uint8, dtypes.int16, [0, 255 * 128])
self._convert([0, 32767], dtypes.int16, dtypes.uint8, [0, 255])
self._convert([0, 2**32], dtypes.int64, dtypes.int32, [0, 1])
@@ -4594,7 +4609,7 @@
def testConvertBetweenFloat(self):
# Make sure converting to between float types does nothing interesting
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
self._convert([-1.0, 0, 1.0, 200000], dtypes.float32, dtypes.float64,
[-1.0, 0, 1.0, 200000])
self._convert([-1.0, 0, 1.0, 200000], dtypes.float64, dtypes.float32,
@@ -4602,14 +4617,14 @@
def testConvertBetweenIntegerAndFloat(self):
# Make sure converting from and to a float type scales appropriately
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
self._convert([0, 1, 255], dtypes.uint8, dtypes.float32,
[0, 1.0 / 255.0, 1])
self._convert([0, 1.1 / 255.0, 1], dtypes.float32, dtypes.uint8,
[0, 1, 255])
def testConvertBetweenInt16AndInt8(self):
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
# uint8, uint16
self._convert([0, 255 * 256], dtypes.uint16, dtypes.uint8, [0, 255])
self._convert([0, 255], dtypes.uint8, dtypes.uint16, [0, 255 * 256])
@@ -4640,7 +4655,7 @@
"""
# Create a TensorFlow session.
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
# Add a constant to the TensorFlow graph that holds the input.
x_tf = constant_op.constant(x_np, shape=x_np.shape)
@@ -5256,7 +5271,7 @@
img = array_ops.placeholder(dtype=dtypes.float32)
img_np = np.array((2, 2))
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
_, _, checks = image_ops_impl._verify_compatible_image_shapes(img, img)
with self.assertRaises(errors.InvalidArgumentError):
sess.run(checks, {img: img_np})
@@ -5270,7 +5285,7 @@
img1_np = np.array([1, 2, 2, 1])
img2_np = np.array([1, 3, 3, 1])
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
_, _, checks = image_ops_impl._verify_compatible_image_shapes(
img1, img2)
with self.assertRaises(errors.InvalidArgumentError):
@@ -5289,7 +5304,7 @@
return np.expand_dims(im, axis=0)
def _LoadTestImages(self):
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
q20 = self._LoadTestImage(sess, "cat_q20.jpg")
q72 = self._LoadTestImage(sess, "cat_q72.jpg")
q95 = self._LoadTestImage(sess, "cat_q95.jpg")
@@ -5309,7 +5324,7 @@
image2 = self._RandomImage((8, 8, 1), 1)
psnr = self._PSNR_NumPy(image1, image2, 1)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
tf_image1 = constant_op.constant(image1, shape=image1.shape,
dtype=dtypes.float32)
tf_image2 = constant_op.constant(image2, shape=image2.shape,
@@ -5322,7 +5337,7 @@
image2 = self._RandomImage((10, 8, 8, 1), 1)
psnr = self._PSNR_NumPy(image1, image2, 1)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
tf_image1 = constant_op.constant(image1, shape=image1.shape,
dtype=dtypes.float32)
tf_image2 = constant_op.constant(image2, shape=image2.shape,
@@ -5343,7 +5358,7 @@
self.assertNear(35.302, psnr3, 0.001)
# Test TensorFlow implementation.
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
tf_q20 = constant_op.constant(q20, shape=q20.shape, dtype=dtypes.float32)
tf_q72 = constant_op.constant(q72, shape=q72.shape, dtype=dtypes.float32)
tf_q95 = constant_op.constant(q95, shape=q95.shape, dtype=dtypes.float32)
@@ -5357,7 +5372,7 @@
def testInfinity(self):
q20, _, _ = self._LoadTestImages()
psnr = self._PSNR_NumPy(q20, q20, 1)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
tf_q20 = constant_op.constant(q20, shape=q20.shape, dtype=dtypes.float32)
tf_psnr = self.evaluate(image_ops.psnr(tf_q20, tf_q20, 1, "psnr"))
self.assertAllClose(psnr, tf_psnr, atol=0.001)
@@ -5371,7 +5386,7 @@
img1 = image_ops.convert_image_dtype(img1, dtypes.float32)
img2 = image_ops.convert_image_dtype(img2, dtypes.float32)
psnr_float32 = image_ops.psnr(img1, img2, 1.0)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
self.assertAllClose(
self.evaluate(psnr_uint8), self.evaluate(psnr_float32), atol=0.001)
@@ -5396,7 +5411,7 @@
return np.expand_dims(im, axis=0)
def _LoadTestImages(self):
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
return [self._LoadTestImage(sess, f) for f in self._filenames]
def _RandomImage(self, shape, max_val):
@@ -5412,7 +5427,7 @@
return image_ops.ssim(
*x, max_val=1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
scores = [
self.evaluate(ssim_func(t))
for t in itertools.combinations_with_replacement(img, 2)
@@ -5436,7 +5451,7 @@
filter_sigma=1.5,
k1=0.01,
k2=0.03)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
self.assertAllClose(expected, self.evaluate(ssim), atol=1e-4)
def testBatchNumpyInputs(self):
@@ -5447,7 +5462,7 @@
img1 = np.concatenate(img1)
img2 = np.concatenate(img2)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
img1 = self.evaluate(constant_op.constant(img1))
img2 = self.evaluate(constant_op.constant(img2))
@@ -5459,7 +5474,7 @@
filter_sigma=1.5,
k1=0.01,
k2=0.03)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
self.assertAllClose(expected, self.evaluate(ssim), atol=1e-4)
def testBroadcast(self):
@@ -5472,7 +5487,7 @@
ssim = image_ops.ssim(
img1, img2, 1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
self.assertAllClose(expected, self.evaluate(ssim), atol=1e-4)
def testNegative(self):
@@ -5492,7 +5507,7 @@
filter_sigma=1.5,
k1=0.01,
k2=0.03)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
self.assertLess(self.evaluate(ssim), 0)
def testInt(self):
@@ -5506,7 +5521,7 @@
img2 = image_ops.convert_image_dtype(img2, dtypes.float32)
ssim_float32 = image_ops.ssim(
img1, img2, 1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
self.assertAllClose(
self.evaluate(ssim_uint8), self.evaluate(ssim_float32), atol=0.001)
@@ -5531,7 +5546,7 @@
return np.expand_dims(im, axis=0)
def _LoadTestImages(self):
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
return [self._LoadTestImage(sess, f) for f in self._filenames]
def _RandomImage(self, shape, max_val):
@@ -5550,7 +5565,7 @@
return image_ops.ssim_multiscale(
*x, max_val=1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
scores = [
self.evaluate(ssim_func(t))
for t in itertools.combinations_with_replacement(img, 2)
@@ -5627,7 +5642,7 @@
filter_sigma=1.5,
k1=0.01,
k2=0.03)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
self.assertAllClose(expected, self.evaluate(msssim), 1e-4)
def testBroadcast(self):
@@ -5641,7 +5656,7 @@
score_tensor = image_ops.ssim_multiscale(
img1, img2, 1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
self.assertAllClose(expected, self.evaluate(score_tensor), 1e-4)
def testRange(self):
@@ -5651,7 +5666,7 @@
If any of the value is negative so that the geometric mean is not
well-defined, then treat the MS-SSIM score as zero.
"""
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session() as sess:
img1 = self._LoadTestImage(sess, "checkerboard1.png")
img2 = self._LoadTestImage(sess, "checkerboard3.png")
images = [img1, img2, np.zeros_like(img1),
@@ -5680,7 +5695,7 @@
img2 = image_ops.convert_image_dtype(img2, dtypes.float32)
ssim_float32 = image_ops.ssim_multiscale(
img1, img2, 1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
self.assertAllClose(
self.evaluate(ssim_uint8), self.evaluate(ssim_float32), atol=0.001)
@@ -5688,7 +5703,7 @@
"""Test case for GitHub issue 28241."""
image = np.random.random([512, 512, 1])
score_tensor = image_ops.ssim_multiscale(image, image, max_val=1.0)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
_ = self.evaluate(score_tensor)
@@ -5728,7 +5743,7 @@
batch = constant_op.constant(batch)
assert batch.get_shape().as_list() == [2, 2, 3, 2]
dy, dx = image_ops.image_gradients(batch)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
actual_dy = self.evaluate(dy)
actual_dx = self.evaluate(dx)
self.assertAllClose(expected_dy, actual_dy)
@@ -5749,7 +5764,7 @@
expected = np.reshape([[[0, 0], [0, 12], [0, 0]],
[[0, 0], [0, 12], [0, 0]]], [1, 2, 3, 1, 2])
sobel = image_ops.sobel_edges(img)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
actual_sobel = self.evaluate(sobel)
self.assertAllClose(expected, actual_sobel)
@@ -5771,7 +5786,7 @@
expected_batch = np.concatenate([expected_two_channel] * batch_size, axis=0)
sobel = image_ops.sobel_edges(img)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
actual_sobel = self.evaluate(sobel)
self.assertAllClose(expected_batch, actual_sobel)
@@ -5842,7 +5857,7 @@
def testJpegUint16(self):
for horizon in self._FORWARD_COMPATIBILITY_HORIZONS:
with compat.forward_compatibility_horizon(*horizon):
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
base = "tensorflow/core/lib/jpeg/testdata"
jpeg0 = io_ops.read_file(os.path.join(base, "jpeg_merge_test1.jpg"))
image0 = image_ops.decode_image(jpeg0, dtype=dtypes.uint16)
@@ -5854,7 +5869,7 @@
def testPngUint16(self):
for horizon in self._FORWARD_COMPATIBILITY_HORIZONS:
with compat.forward_compatibility_horizon(*horizon):
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
base = "tensorflow/core/lib/png/testdata"
png0 = io_ops.read_file(os.path.join(base, "lena_rgba.png"))
image0 = image_ops.decode_image(png0, dtype=dtypes.uint16)
@@ -5873,7 +5888,7 @@
def testGifUint16(self):
for horizon in self._FORWARD_COMPATIBILITY_HORIZONS:
with compat.forward_compatibility_horizon(*horizon):
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
base = "tensorflow/core/lib/gif/testdata"
gif0 = io_ops.read_file(os.path.join(base, "scan.gif"))
image0 = image_ops.decode_image(gif0, dtype=dtypes.uint16)
@@ -5885,7 +5900,7 @@
def testBmpUint16(self):
for horizon in self._FORWARD_COMPATIBILITY_HORIZONS:
with compat.forward_compatibility_horizon(*horizon):
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
base = "tensorflow/core/lib/bmp/testdata"
bmp0 = io_ops.read_file(os.path.join(base, "lena.bmp"))
image0 = image_ops.decode_image(bmp0, dtype=dtypes.uint16)
@@ -5897,7 +5912,7 @@
def testJpegFloat32(self):
for horizon in self._FORWARD_COMPATIBILITY_HORIZONS:
with compat.forward_compatibility_horizon(*horizon):
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
base = "tensorflow/core/lib/jpeg/testdata"
jpeg0 = io_ops.read_file(os.path.join(base, "jpeg_merge_test1.jpg"))
image0 = image_ops.decode_image(jpeg0, dtype=dtypes.float32)
@@ -5909,7 +5924,7 @@
def testPngFloat32(self):
for horizon in self._FORWARD_COMPATIBILITY_HORIZONS:
with compat.forward_compatibility_horizon(*horizon):
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
base = "tensorflow/core/lib/png/testdata"
png0 = io_ops.read_file(os.path.join(base, "lena_rgba.png"))
image0 = image_ops.decode_image(png0, dtype=dtypes.float32)
@@ -5921,7 +5936,7 @@
def testGifFloat32(self):
for horizon in self._FORWARD_COMPATIBILITY_HORIZONS:
with compat.forward_compatibility_horizon(*horizon):
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
base = "tensorflow/core/lib/gif/testdata"
gif0 = io_ops.read_file(os.path.join(base, "scan.gif"))
image0 = image_ops.decode_image(gif0, dtype=dtypes.float32)
@@ -5933,7 +5948,7 @@
def testBmpFloat32(self):
for horizon in self._FORWARD_COMPATIBILITY_HORIZONS:
with compat.forward_compatibility_horizon(*horizon):
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
base = "tensorflow/core/lib/bmp/testdata"
bmp0 = io_ops.read_file(os.path.join(base, "lena.bmp"))
image0 = image_ops.decode_image(bmp0, dtype=dtypes.float32)
@@ -5945,7 +5960,7 @@
def testExpandAnimations(self):
for horizon in self._FORWARD_COMPATIBILITY_HORIZONS:
with compat.forward_compatibility_horizon(*horizon):
- with self.cached_session(use_gpu=True) as sess:
+ with self.cached_session():
base = "tensorflow/core/lib/gif/testdata"
gif0 = io_ops.read_file(os.path.join(base, "scan.gif"))
diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py
index 63773ee..c5b9eea 100644
--- a/tensorflow/python/ops/lookup_ops.py
+++ b/tensorflow/python/ops/lookup_ops.py
@@ -647,7 +647,8 @@
value_index,
vocab_size=None,
delimiter="\t",
- name=None):
+ name=None,
+ value_index_offset=0):
"""Constructs a table initializer object to populate from a text file.
It generates one key-value pair per line. The type of table key and
@@ -675,6 +676,13 @@
vocab_size: The number of elements in the file, if known.
delimiter: The delimiter to separate fields in a line.
name: A name for the operation (optional).
+ value_index_offset: A number to add to all indices extracted from the file
+ This is useful for cases where a user would like to reserve one or more
+ low index values for control characters. For instance, if you would
+ like to ensure that no vocabulary item is mapped to index 0 (so you can
+ reserve 0 for a masking value), you can set value_index_offset to 1;
+ this will mean that the first vocabulary element is mapped to 1
+ instead of 0.
Raises:
ValueError: when the filename is empty, or when the table key and value
@@ -718,6 +726,7 @@
self._name = name
self._filename = self._track_trackable(
trackable.Asset(filename), "_filename")
+ self._offset = value_index_offset
super(TextFileInitializer, self).__init__(key_dtype, value_dtype)
@@ -740,7 +749,8 @@
self._filename, dtypes.string, name="asset_filepath")
init_op = gen_lookup_ops.initialize_table_from_text_file_v2(
table.resource_handle, filename, self._key_index, self._value_index,
- -1 if self._vocab_size is None else self._vocab_size, self._delimiter)
+ -1 if self._vocab_size is None else self._vocab_size, self._delimiter,
+ self._offset)
ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
# If the filename tensor is anything other than a string constant (e.g.,
# if it is a placeholder) then it does not make sense to track it as an
diff --git a/tensorflow/python/ops/math_grad_test.py b/tensorflow/python/ops/math_grad_test.py
index bbd30ef..773084c 100644
--- a/tensorflow/python/ops/math_grad_test.py
+++ b/tensorflow/python/ops/math_grad_test.py
@@ -46,7 +46,7 @@
l = np.random.randn(*left_shape)
r = np.random.randn(*right_shape)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
left_tensor = constant_op.constant(l, shape=left_shape)
right_tensor = constant_op.constant(r, shape=right_shape)
output = math_ops.squared_difference(left_tensor, right_tensor)
@@ -83,7 +83,7 @@
self._biasedRandN(
shape, bias=bias), dtype=dtype)
- with self.cached_session(use_gpu=True):
+ with self.cached_session():
output = math_ops.abs(value)
error = gradient_checker.compute_gradient_error(
value, shape, output, output.get_shape().as_list())
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 1cdc901..131dab9 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -4791,8 +4791,9 @@
evaluated using Horner's method, i.e.
- `p(x) = coeffs[n-1] + x * (coeffs[n-2] + ... + x * (coeffs[1]
- + x * coeffs[0]))`
+ ```python
+ p(x) = coeffs[n-1] + x * (coeffs[n-2] + ... + x * (coeffs[1] + x * coeffs[0]))
+ ```
Usage Example:
diff --git a/tensorflow/python/ops/nccl_ops_test.py b/tensorflow/python/ops/nccl_ops_test.py
index 5b3e3e6..239ef1a 100644
--- a/tensorflow/python/ops/nccl_ops_test.py
+++ b/tensorflow/python/ops/nccl_ops_test.py
@@ -76,7 +76,7 @@
for dtype in [np.float16, np.float32, np.int32, np.int64, np.float64]:
# Create session inside outer loop to test use of
# same communicator across multiple sessions.
- with self.test_session(use_gpu=True) as sess:
+ with self.test_session():
for devices in device_sets:
shape = (3, 4)
diff --git a/tensorflow/python/ops/numpy_ops/np_interop_test.py b/tensorflow/python/ops/numpy_ops/np_interop_test.py
index d265b5e..73c0dab 100644
--- a/tensorflow/python/ops/numpy_ops/np_interop_test.py
+++ b/tensorflow/python/ops/numpy_ops/np_interop_test.py
@@ -22,6 +22,7 @@
import tensorflow.compat.v2 as tf
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import numpy_ops as np
from tensorflow.python.ops.numpy_ops import np_math_ops
@@ -229,6 +230,7 @@
# self.assertIsInstance(reduced, np.ndarray)
self.assertAllClose(reduced, 15)
+ @test_util.disable_tfrt('b/180469928')
def testPyFuncInterop(self):
def py_func_fn(a, b):
return a + b
diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
index bc1bb54..cbda91f 100644
--- a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
+++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
@@ -987,6 +987,7 @@
self._test_loop_fn(loop_fn, 2)
+ @test_util.disable_tfrt("b/180206304")
def test_create_inside_and_read(self):
def loop_fn(i):
diff --git a/tensorflow/python/ops/ragged/ragged_concat_ops.py b/tensorflow/python/ops/ragged/ragged_concat_ops.py
index cd710f4..190a02c 100644
--- a/tensorflow/python/ops/ragged/ragged_concat_ops.py
+++ b/tensorflow/python/ops/ragged/ragged_concat_ops.py
@@ -23,7 +23,6 @@
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
-from tensorflow.python.ops.ragged import ragged_array_ops
from tensorflow.python.ops.ragged import ragged_gather_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_util
@@ -107,7 +106,9 @@
name: A name prefix for the returned tensor (optional).
Returns:
- A `RaggedTensor` with rank `R+1`.
+ A `RaggedTensor` with rank `R+1` (if `R>0`).
+ If `R==0`, then the result will be returned as a 1D `Tensor`, since
+ `RaggedTensor` can only be used when `rank>1`.
`result.ragged_rank=1+max(axis, max(rt.ragged_rank for rt in values]))`.
Raises:
@@ -148,11 +149,8 @@
rt_inputs = list(rt_inputs)
# Special case: if there's only one input, then return it as-is.
- if len(rt_inputs) == 1:
- if stack_values:
- return ragged_array_ops.expand_dims(rt_inputs[0], axis=axis)
- else:
- return rt_inputs[0]
+ if len(rt_inputs) == 1 and not stack_values:
+ return rt_inputs[0]
# Check the rank (number of dimensions) of the input tensors.
ndims = None
diff --git a/tensorflow/python/ops/ragged/ragged_stack_op_test.py b/tensorflow/python/ops/ragged/ragged_stack_op_test.py
index 6e1db50..2866c7f 100644
--- a/tensorflow/python/ops/ragged/ragged_stack_op_test.py
+++ b/tensorflow/python/ops/ragged/ragged_stack_op_test.py
@@ -319,6 +319,26 @@
rt_inputs=([['a00', 'a01'], [], ['a20', 'a21']],),
axis=0,
expected=[[[b'a00', b'a01'], [], [b'a20', b'a21']]]),
+ dict(
+ descr='One input (uniform 0D)',
+ rt_inputs=(1,),
+ ragged_ranks=[0],
+ axis=0,
+ expected=[1]),
+ dict(
+ descr='One input (uniform 1D)',
+ rt_inputs=([1, 2],),
+ ragged_ranks=[0],
+ axis=0,
+ expected=[[1, 2]],
+ expected_ragged_rank=1),
+ dict(
+ descr='One input (uniform 2D)',
+ rt_inputs=([[1, 2], [3, 4], [5, 6]],),
+ ragged_ranks=[0],
+ axis=0,
+ expected=[[[1, 2], [3, 4], [5, 6]]],
+ expected_ragged_rank=2),
) # pyformat: disable
def testRaggedStack(self,
descr,
diff --git a/tensorflow/python/ops/special_math_ops_test.py b/tensorflow/python/ops/special_math_ops_test.py
index ba184b2..6caeb6b 100644
--- a/tensorflow/python/ops/special_math_ops_test.py
+++ b/tensorflow/python/ops/special_math_ops_test.py
@@ -48,7 +48,7 @@
# Should evaluate to 1 and 1/2.
x_one = [1, 1.]
x_one_half = [2, 1.]
- with self.session(use_gpu=True):
+ with self.session():
self.assertAllClose(
1, self.evaluate(math_ops.exp(special_math_ops.lbeta(x_one))))
self.assertAllClose(
@@ -60,7 +60,7 @@
# Should evaluate to 1 and 1/2.
x_one = [1, 1.]
x_one_half = [2, 1.]
- with self.session(use_gpu=True):
+ with self.session():
ph = array_ops.placeholder(dtypes.float32)
beta_ph = math_ops.exp(special_math_ops.lbeta(ph))
self.assertAllClose(1, beta_ph.eval(feed_dict={ph: x_one}))
@@ -76,7 +76,7 @@
# = Gamma(1) * Gamma(1) * Gamma(1) * Gamma(1) / Gamma(1 + 1 + 1 + 1)
# = 1 / 6
expected_beta_x = 1 / 6 * np.ones((3, 2, 3))
- with self.session(use_gpu=True):
+ with self.session():
x_ph = array_ops.placeholder(dtypes.float32, [3, 2, 3, None])
beta_ph = math_ops.exp(special_math_ops.lbeta(x_ph))
self.assertAllClose(expected_beta_x,
@@ -86,7 +86,7 @@
def test_two_dimensional_arg(self):
# Should evaluate to 1/2.
x_one_half = [[2, 1.], [2, 1.]]
- with self.session(use_gpu=True):
+ with self.session():
self.assertAllClose(
[0.5, 0.5],
self.evaluate(math_ops.exp(special_math_ops.lbeta(x_one_half))))
@@ -96,7 +96,7 @@
def test_two_dimensional_arg_dynamic(self):
# Should evaluate to 1/2.
x_one_half = [[2, 1.], [2, 1.]]
- with self.session(use_gpu=True):
+ with self.session():
ph = array_ops.placeholder(dtypes.float32)
beta_ph = math_ops.exp(special_math_ops.lbeta(ph))
self.assertAllClose([0.5, 0.5],
@@ -106,7 +106,7 @@
def test_two_dimensional_proper_shape(self):
# Should evaluate to 1/2.
x_one_half = [[2, 1.], [2, 1.]]
- with self.session(use_gpu=True):
+ with self.session():
self.assertAllClose(
[0.5, 0.5],
self.evaluate(math_ops.exp(special_math_ops.lbeta(x_one_half))))
@@ -119,7 +119,7 @@
@test_util.run_in_graph_and_eager_modes
def test_complicated_shape(self):
- with self.session(use_gpu=True):
+ with self.session():
x = ops.convert_to_tensor(np.random.rand(3, 2, 2))
self.assertAllEqual(
(3, 2), self.evaluate(array_ops.shape(special_math_ops.lbeta(x))))
@@ -133,7 +133,7 @@
# as the answer, always.
x_a = [5.5]
x_b = [0.1]
- with self.session(use_gpu=True):
+ with self.session():
self.assertAllClose(
1,
self.evaluate(math_ops.exp(special_math_ops.lbeta(x_a))),
@@ -144,7 +144,7 @@
@test_util.run_in_graph_and_eager_modes
def test_empty_rank1_returns_negative_infinity(self):
- with self.session(use_gpu=True):
+ with self.session():
x = constant_op.constant([], shape=[0])
lbeta_x = special_math_ops.lbeta(x)
expected_result = constant_op.constant(-np.inf, shape=())
@@ -155,7 +155,7 @@
@test_util.run_in_graph_and_eager_modes
def test_empty_rank2_with_zero_last_dim_returns_negative_infinity(self):
- with self.session(use_gpu=True):
+ with self.session():
event_size = 0
for batch_size in [0, 1, 2]:
x = constant_op.constant([], shape=[batch_size, event_size])
@@ -168,7 +168,7 @@
@test_util.run_in_graph_and_eager_modes
def test_empty_rank2_with_zero_batch_dim_returns_empty(self):
- with self.session(use_gpu=True):
+ with self.session():
batch_size = 0
for event_size in [0, 1, 2]:
x = constant_op.constant([], shape=[batch_size, event_size])
diff --git a/tensorflow/python/ops/v1_compat_tests/gradient_checker_test.py b/tensorflow/python/ops/v1_compat_tests/gradient_checker_test.py
index 7ecad0a..607af47 100644
--- a/tensorflow/python/ops/v1_compat_tests/gradient_checker_test.py
+++ b/tensorflow/python/ops/v1_compat_tests/gradient_checker_test.py
@@ -65,7 +65,7 @@
@test_util.run_deprecated_v1
def testAddSimpleGPU(self):
np.random.seed(2) # Fix seed to avoid flakiness
- with self.session(use_gpu=True):
+ with self.session():
# a test case for Add operation
size = (2, 3)
x1 = constant_op.constant(2.0, shape=size, name="x1")
@@ -225,7 +225,7 @@
s = label_data.sum(axis=1)
label_data /= s[:, None]
- with self.session(use_gpu=True):
+ with self.session():
# We treat the inputs as "parameters" here
inp = constant_op.constant(
inp_data.tolist(),
diff --git a/tensorflow/python/profiler/internal/python_hooks.cc b/tensorflow/python/profiler/internal/python_hooks.cc
index 5a7aa11..00af408 100644
--- a/tensorflow/python/profiler/internal/python_hooks.cc
+++ b/tensorflow/python/profiler/internal/python_hooks.cc
@@ -88,6 +88,8 @@
} // namespace
+/*static*/ PythonHookContext* PythonHooks::e2e_context_ = nullptr;
+
std::string PythonTraceEntry::Name() const {
std::string event_name;
if (code_object) {
@@ -103,7 +105,7 @@
return singleton;
}
-void PythonHooks::Start(const PythonHooksOptions& options) {
+void PythonHookContext::Start(const PythonHooksOptions& options) {
if (!Py_IsInitialized()) return;
#if PY_MAJOR_VERSION < 3 || (PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION < 7)
@@ -135,21 +137,24 @@
auto atexit = py::module::import("atexit");
atexit.attr("register")(py::cpp_function([]() {
PythonHooks* singleton = PythonHooks::GetSingleton();
- singleton->Stop();
- singleton->CollectData(&(singleton->end_to_end_xplane_.emplace()));
+ auto e2e_context = singleton->Stop();
+ // Serialize into internal storage before the tracked PyCodeObjects
+ // went out of scope.
+ if (e2e_context) {
+ e2e_context->CollectData(nullptr);
+ PythonHooks::set_e2e_context(e2e_context.release());
+ }
}));
} catch (const py::error_already_set& e) {
LOG(ERROR) << "Can't install atexit handler for e2e mode." << e.what();
}
}
PyGILState_Release(gil_state);
- active_session_ = true;
}
}
-void PythonHooks::Stop() {
+void PythonHookContext::Stop() {
if (!Py_IsInitialized()) return;
- if (!active_session_) return; // Makes sure Stop() can be reentrant.
if (options_.enable_python_traceme || options_.enable_trace_python_function) {
PyGILState_STATE gil_state = PyGILState_Ensure();
if (options_.enable_trace_python_function) {
@@ -159,12 +164,14 @@
EnableTraceMe(false);
}
PyGILState_Release(gil_state);
- active_session_ = false;
}
}
-void PythonHooks::CollectData(XPlane* raw_plane) {
- DCHECK(raw_plane);
+void PythonHookContext::CollectData(XPlane* raw_plane) {
+ if (raw_plane == nullptr) {
+ end_to_end_xplane_.emplace();
+ raw_plane = &*end_to_end_xplane_;
+ }
XPlaneBuilder plane(raw_plane);
for (auto& it : entries_) {
uint64 thread_id = it.first;
@@ -189,7 +196,7 @@
entries_.clear();
}
-void PythonHooks::Finalize(XSpace* space) {
+void PythonHookContext::Finalize(XSpace* space) {
if (space && options_.enable_trace_python_function) {
XPlane* plane =
FindOrAddMutablePlaneWithName(space, kPythonTracerPlaneName);
@@ -237,7 +244,8 @@
ProfileFast(reinterpret_cast<PyFrameObject*>(frame.ptr()), what, arg.ptr());
}
-void PythonHooks::ProfileFast(PyFrameObject* frame, int what, PyObject* arg) {
+void PythonHookContext::ProfileFast(PyFrameObject* frame, int what,
+ PyObject* arg) {
const int64 thread_id = Env::Default()->GetCurrentThreadId();
uint64 now = GetCurrentTimeNanos();
auto& thread_traces = entries_[thread_id];
@@ -293,7 +301,7 @@
}
}
-void PythonHooks::SetProfilerInAllThreads() {
+void PythonHookContext::SetProfilerInAllThreads() {
// We also want any new threads started to use our profiler.
// NOTE: threading does not provide a C API equivalent to
// `threading.setprofile` so we are forced to go via Python to setup the
@@ -301,10 +309,11 @@
// thread we unregister the Python profile function and use
// `PyEval_SetProfile` to register a C profiler which has significantly less
// overhead (>2x faster).
+ PythonHooks* singleton = PythonHooks::GetSingleton();
py::cpp_function callback =
- py::cpp_function([this](const py::object& frame, const string& event,
- const py::object& arg) {
- ProfileSlow(frame, event, arg);
+ py::cpp_function([singleton](const py::object& frame, const string& event,
+ const py::object& arg) {
+ singleton->ProfileSlow(frame, event, arg);
SysSetProfileNone();
PyEval_SetProfile(ProfileFunction<PythonHooks>, nullptr);
});
@@ -324,7 +333,7 @@
PyThreadState_Swap(curr_thread);
}
-void PythonHooks::ClearProfilerInAllThreads() {
+/*static*/ void PythonHookContext::ClearProfilerInAllThreads() {
PyThreadState* curr_thread = PyThreadState_Get();
PyThreadState* next_thread = curr_thread;
while (next_thread != nullptr) {
@@ -339,7 +348,7 @@
ThreadingSetProfile(py::none());
}
-void PythonHooks::EnableTraceMe(bool enable) {
+/*static*/ void PythonHookContext::EnableTraceMe(bool enable) {
const char* kModuleName =
"tensorflow.python.profiler.trace";
try {
diff --git a/tensorflow/python/profiler/internal/python_hooks.h b/tensorflow/python/profiler/internal/python_hooks.h
index b30fcc3..22145d3 100644
--- a/tensorflow/python/profiler/internal/python_hooks.h
+++ b/tensorflow/python/profiler/internal/python_hooks.h
@@ -20,6 +20,7 @@
#include <vector>
#include "absl/container/flat_hash_map.h"
+#include "absl/memory/memory.h"
#include "pybind11/cast.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
@@ -83,33 +84,76 @@
std::stack<PythonTraceEntry> active;
};
+class PythonHookContext {
+ public:
+ void Start(const PythonHooksOptions& option);
+ void Stop();
+ void Finalize(XSpace* space);
+ void ProfileFast(PyFrameObject* frame, int what, PyObject* arg);
+
+ private:
+ void CollectData(XPlane* raw_plane);
+ static void EnableTraceMe(bool enable);
+
+ void SetProfilerInAllThreads();
+ static void ClearProfilerInAllThreads();
+
+ void operator=(const PythonHookContext&) = delete;
+ void operator=(PythonHookContext&&) = delete;
+
+ absl::flat_hash_map<int64, PerThreadEvents> entries_;
+ uint64 start_timestamp_ns_;
+ PythonHooksOptions options_;
+ // In end to end mode, Python get uninitialized before Stop()/Finalize(), we
+ // need to buffer the result.
+ absl::optional<XPlane> end_to_end_xplane_;
+};
+
// Singleton for tracing python function calls.
class PythonHooks {
public:
static PythonHooks* GetSingleton();
- void Start(const PythonHooksOptions& option);
- void Stop();
- void Finalize(XSpace* space);
+ void Start(const PythonHooksOptions& option) {
+ if (active_context_) return;
+ active_context_ = std::make_unique<PythonHookContext>();
+ active_context_->Start(option);
+ }
+
+ std::unique_ptr<PythonHookContext> Stop() {
+ if (e2e_context_) {
+ auto* e2e_context = e2e_context_;
+ e2e_context_ = nullptr;
+ return absl::WrapUnique(e2e_context);
+ }
+
+ if (!active_context_) return nullptr;
+ active_context_->Stop();
+ std::unique_ptr<PythonHookContext> output = std::move(active_context_);
+ active_context_.reset();
+ return output;
+ }
+
void ProfileSlow(const py::object& frame, const string& event,
const py::object& arg);
- void ProfileFast(PyFrameObject* frame, int what, PyObject* arg);
+
+ void ProfileFast(PyFrameObject* frame, int what, PyObject* arg) {
+ if (TF_PREDICT_TRUE(active_context_)) {
+ active_context_->ProfileFast(frame, what, arg);
+ }
+ }
+
+ static void set_e2e_context(PythonHookContext* e2e_context) {
+ e2e_context_ = e2e_context;
+ }
+
+ static PythonHookContext* e2e_context() { return e2e_context_; }
private:
- void EnableTraceMe(bool enable);
- void CollectData(XPlane* raw_plane);
-
- void SetProfilerInAllThreads();
- void ClearProfilerInAllThreads();
-
- // entries_ are accessed when GIL is held, therefore no race conditions.
- absl::flat_hash_map<int64, PerThreadEvents> entries_;
- uint64 start_timestamp_ns_;
- bool active_session_ = false;
- PythonHooksOptions options_;
- // In end to end mode, Python get uninitialized before Stop()/Finalize(), we
- // need to buffer the result.
- absl::optional<XPlane> end_to_end_xplane_;
+ // active_context_ are accessed when GIL is held, therefore no race
+ // conditions.
+ std::unique_ptr<PythonHookContext> active_context_;
+ static PythonHookContext* e2e_context_;
};
} // namespace profiler
diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py
index 1979c92..63990cf 100644
--- a/tensorflow/python/saved_model/save.py
+++ b/tensorflow/python/saved_model/save.py
@@ -432,7 +432,7 @@
if isinstance(secondary_referrer, base.Trackable):
trackable_referrers.append(secondary_referrer)
raise AssertionError(
- ("Tried to export a function which references untracked resource {}."
+ ("Tried to export a function which references untracked resource {}. "
"TensorFlow objects (e.g. tf.Variable) captured by functions must "
"be tracked by assigning them to an attribute of a tracked object "
"or assigned to an attribute of the main object directly.\n\n"
diff --git a/tensorflow/python/tpu/tpu_embedding_v2.py b/tensorflow/python/tpu/tpu_embedding_v2.py
index 419f37e..206776e 100644
--- a/tensorflow/python/tpu/tpu_embedding_v2.py
+++ b/tensorflow/python/tpu/tpu_embedding_v2.py
@@ -297,9 +297,14 @@
# Thus we must fix a common order to tables and ensure they have unique
# names.
- # Set table order here
- self._table_config = list(
- {feature.table for feature in nest.flatten(feature_config)})
+ # Set table order here to the order of the first occurence of the table in a
+ # feature provided by the user. The order of this struct must be fixed
+ # to provide the user with deterministic behavior over multiple
+ # instantiations.
+ self._table_config = []
+ for feature in nest.flatten(feature_config):
+ if feature.table not in self._table_config:
+ self._table_config.append(feature.table)
# Ensure tables have unique names. Also error check the optimizer as we
# specifically don't do that in the TableConfig class to allow high level
diff --git a/tensorflow/python/tpu/tpu_embedding_v2_test.py b/tensorflow/python/tpu/tpu_embedding_v2_test.py
index f524c8e..6c649ee 100644
--- a/tensorflow/python/tpu/tpu_embedding_v2_test.py
+++ b/tensorflow/python/tpu/tpu_embedding_v2_test.py
@@ -1273,6 +1273,32 @@
# not matter.
mid_level_api.build(self.batch_size)
+ def test_same_config_different_instantiations(self):
+ num_tables = 30
+ table_dim = np.random.randint(1, 128, size=[num_tables])
+ table_vocab_size = np.random.randint(100, 1000, size=[num_tables])
+ table_names = ['table{}'.format(i) for i in range(num_tables)]
+ table_data = list(zip(table_dim, table_vocab_size, table_names))
+ strategy = self._get_strategy()
+
+ def tpu_embedding_config():
+ feature_configs = []
+ for dim, vocab, name in table_data:
+ feature_configs.append(tpu_embedding_v2_utils.FeatureConfig(
+ table=tpu_embedding_v2_utils.TableConfig(
+ vocabulary_size=int(vocab), dim=int(dim),
+ initializer=init_ops_v2.Zeros(), name=name)))
+ optimizer = tpu_embedding_v2_utils.Adagrad(
+ learning_rate=0.1)
+ with strategy.scope():
+ mid_level_api = tpu_embedding_v2.TPUEmbedding(
+ feature_config=feature_configs,
+ optimizer=optimizer)
+ mid_level_api._batch_size = 128
+ return mid_level_api._create_config_proto()
+
+ self.assertProtoEquals(tpu_embedding_config(), tpu_embedding_config())
+
def _unpack(strategy, per_replica_output):
per_replica_output = strategy.experimental_local_results(per_replica_output)
diff --git a/tensorflow/python/util/BUILD b/tensorflow/python/util/BUILD
index dfe44e2..e0eb8d0 100644
--- a/tensorflow/python/util/BUILD
+++ b/tensorflow/python/util/BUILD
@@ -90,6 +90,34 @@
)
tf_python_pybind_extension(
+ name = "_pywrap_nest",
+ srcs = ["nest_wrapper.cc"],
+ hdrs = ["nest.h"],
+ module_name = "_pywrap_nest",
+ deps = [
+ "//tensorflow/python:pybind11_lib",
+ "//third_party/python_runtime:headers",
+ "@pybind11",
+ ],
+)
+
+cc_library(
+ name = "cpp_nest",
+ srcs = ["nest.cc"],
+ hdrs = ["nest.h"],
+ deps = [
+ ":cpp_python_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core/platform:logging",
+ "//tensorflow/core/platform:stringpiece",
+ "//tensorflow/python/lib/core:safe_pyobject_ptr",
+ "//third_party/python_runtime:headers",
+ ],
+ alwayslink = 1,
+)
+
+tf_python_pybind_extension(
name = "_pywrap_kernel_registry",
srcs = ["kernel_registry_wrapper.cc"],
hdrs = ["kernel_registry.h"],
diff --git a/tensorflow/python/util/nest.cc b/tensorflow/python/util/nest.cc
new file mode 100644
index 0000000..63d6ab2
--- /dev/null
+++ b/tensorflow/python/util/nest.cc
@@ -0,0 +1,146 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/python/util/nest.h"
+
+#include <utility>
+
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stringpiece.h"
+#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
+#include "tensorflow/python/util/util.h"
+
+namespace tensorflow {
+
+namespace {
+
+// Gets a string representation of the input object.
+//
+// Args:
+// o: a python object.
+// length: If set to negative, the whole string is returned. Otherwise, the
+// string gets clipped to 'length' in size.
+//
+// Returns:
+// A string representation.
+std::string PyObject_ToString(PyObject* o, int length = -1) {
+ auto str_o = make_safe(PyObject_Str(o));
+ std::string str = PyUnicode_AsUTF8(str_o.get());
+ if (length < 0 || str.size() <= length) {
+ return str;
+ }
+ tensorflow::StringPiece str_piece(str);
+ return tensorflow::strings::StrCat(str_piece.substr(length), "...");
+}
+
+// Gets a list of keys from a dict or mapping type object.
+//
+// Args:
+// o: a dictionary or mapping type object.
+//
+// Returns:
+// A new reference to a list.
+//
+// Raises:
+// TypeError: if `o` is not a dict or mapping type object.
+PyObject* GetKeysFromDictOrMapping(PyObject* o) {
+ if (PyDict_Check(o)) {
+ return PyDict_Keys(o);
+ } else if (PyMapping_Check(o)) {
+ return PyMapping_Keys(o);
+ } else {
+ auto* o_type = Py_TYPE(o);
+ PyErr_SetString(
+ PyExc_TypeError,
+ tensorflow::strings::StrCat(
+ "Expecting a type compatible with dict or mapping, got '",
+ o_type->tp_name, "'")
+ .c_str());
+ return nullptr;
+ }
+}
+
+} // namespace
+
+PyObject* FlattenDictItems(PyObject* dict) {
+ if (!PyDict_Check(dict) && !swig::IsMapping(dict)) {
+ PyErr_SetString(PyExc_TypeError,
+ tensorflow::strings::StrCat(
+ "FlattenDictItems: 'dict' must be a dictionary or ",
+ "collection.Mapping type object, instead of '",
+ Py_TYPE(dict)->tp_name, "'.")
+ .c_str());
+ return nullptr;
+ }
+ PyObject* flat_dictionary = PyDict_New();
+ auto keys = make_safe(GetKeysFromDictOrMapping(dict));
+ for (size_t i = 0; i < PyList_Size(keys.get()); ++i) {
+ auto* key = PyList_GetItem(keys.get(), i);
+ // We use a general approach in case 'dict' is a PyMapping type,
+ // but not a PyDict type.
+ auto* value = PyObject_GetItem(dict, key);
+ if (swig::IsSequence(key)) {
+ // The dict might contain list - list pairs.
+ auto flat_keys = make_safe(swig::Flatten(key, false));
+ auto flat_values = make_safe(swig::Flatten(value, false));
+ size_t flat_keys_sz = PyList_Size(flat_keys.get());
+ size_t flat_values_sz = PyList_Size(flat_values.get());
+ if (flat_keys_sz != flat_values_sz) {
+ PyErr_SetString(
+ PyExc_ValueError,
+ tensorflow::strings::StrCat(
+ "Could not flatten dictionary. Key had ", flat_keys_sz,
+ " elements, but value had ", flat_values_sz,
+ " elements. Key: ", PyObject_ToString(flat_keys.get()),
+ ", value: ", PyObject_ToString(flat_values.get()), ".")
+ .c_str());
+ Py_DecRef(flat_dictionary);
+ return nullptr;
+ }
+ for (size_t i = 0; i < flat_keys_sz; ++i) {
+ auto* flat_key = PyList_GetItem(flat_keys.get(), i);
+ auto* flat_value = PyList_GetItem(flat_values.get(), i);
+ if (PyDict_GetItem(flat_dictionary, flat_key) != nullptr) {
+ PyErr_SetString(
+ PyExc_ValueError,
+ tensorflow::strings::StrCat(
+ "Cannot flatten dict because this key is not unique: ",
+ PyObject_ToString(flat_key))
+ .c_str());
+ Py_DecRef(flat_dictionary);
+ return nullptr;
+ }
+ PyDict_SetItem(flat_dictionary, flat_key, flat_value);
+ }
+ } else {
+ if (PyDict_GetItem(flat_dictionary, key) != nullptr) {
+ PyErr_SetString(
+ PyExc_ValueError,
+ tensorflow::strings::StrCat(
+ "Cannot flatten dict because this key is not unique: ",
+ PyObject_ToString(key))
+ .c_str());
+ Py_DecRef(flat_dictionary);
+ return nullptr;
+ }
+ PyDict_SetItem(flat_dictionary, key, value);
+ }
+ // Manually decrease because PyObject_GetItem() returns a new reference.
+ Py_DECREF(value);
+ }
+ return flat_dictionary;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/python/util/nest.h b/tensorflow/python/util/nest.h
new file mode 100644
index 0000000..43829f4
--- /dev/null
+++ b/tensorflow/python/util/nest.h
@@ -0,0 +1,37 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_PYTHON_COMPAT_NEST_H_
+#define TENSORFLOW_PYTHON_COMPAT_NEST_H_
+
+#include <Python.h>
+
+namespace tensorflow {
+// Returns a dictionary with flattened keys and values.
+//
+// Args:
+// dict: the dictionary to zip
+//
+// Returns:
+// An new reference to the zipped dictionary.
+//
+// Raises:
+// TypeError: If the input is not a dictionary.
+// ValueError: If any key and value do not have the same structure layout, or
+// if keys are not unique.
+PyObject* FlattenDictItems(PyObject* dict);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_PYTHON_COMPAT_NEST_H_
diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py
index 910da98..a3b1530 100644
--- a/tensorflow/python/util/nest.py
+++ b/tensorflow/python/util/nest.py
@@ -1,4 +1,4 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -48,6 +48,7 @@
import wrapt as _wrapt
from tensorflow.python.platform import tf_logging
+from tensorflow.python.util import _pywrap_nest
from tensorflow.python.util import _pywrap_utils
from tensorflow.python.util.compat import collections_abc as _collections_abc
from tensorflow.python.util.tf_export import tf_export
@@ -562,30 +563,7 @@
ValueError: If any key and value do not have the same structure layout, or
if keys are not unique.
"""
- if not isinstance(dictionary, (dict, _collections_abc.Mapping)):
- raise TypeError("input must be a dictionary")
- flat_dictionary = {}
- for i, v in _six.iteritems(dictionary):
- if not is_sequence(i):
- if i in flat_dictionary:
- raise ValueError(
- "Could not flatten dictionary: key %s is not unique." % i)
- flat_dictionary[i] = v
- else:
- flat_i = flatten(i)
- flat_v = flatten(v)
- if len(flat_i) != len(flat_v):
- raise ValueError(
- "Could not flatten dictionary. Key had %d elements, but value had "
- "%d elements. Key: %s, value: %s."
- % (len(flat_i), len(flat_v), flat_i, flat_v))
- for new_i, new_v in zip(flat_i, flat_v):
- if new_i in flat_dictionary:
- raise ValueError(
- "Could not flatten dictionary: key %s is not unique."
- % (new_i))
- flat_dictionary[new_i] = new_v
- return flat_dictionary
+ return _pywrap_nest.FlattenDictItems(dictionary)
def _packed_nest_with_indices(structure, flat, index, is_seq, sequence_fn=None):
diff --git a/tensorflow/python/util/nest_wrapper.cc b/tensorflow/python/util/nest_wrapper.cc
new file mode 100644
index 0000000..6b87caa
--- /dev/null
+++ b/tensorflow/python/util/nest_wrapper.cc
@@ -0,0 +1,35 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "pybind11/pybind11.h"
+#include "tensorflow/python/lib/core/pybind11_lib.h"
+#include "tensorflow/python/util/nest.h"
+
+namespace py = pybind11;
+
+PYBIND11_MODULE(_pywrap_nest, m) {
+ m.doc() = R"pbdoc(
+ _pywrap_nest
+ -----
+ )pbdoc";
+ m.def(
+ "FlattenDictItems",
+ [](const py::handle& dict) {
+ return tensorflow::PyoOrThrow(tensorflow::FlattenDictItems(dict.ptr()));
+ },
+ R"pbdoc(
+ Returns a dictionary with flattened keys and values.
+ )pbdoc");
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-category-encoding.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-category-encoding.pbtxt
index 3cf5dd9..de99a16 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-category-encoding.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-category-encoding.pbtxt
@@ -1,9 +1,6 @@
path: "tensorflow.keras.layers.experimental.preprocessing.CategoryEncoding"
tf_class {
- is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.category_encoding_v1.CategoryEncoding\'>"
is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.category_encoding.CategoryEncoding\'>"
- is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer_v1.CombinerPreprocessingLayer\'>"
- is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.CombinerPreprocessingLayer\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.PreprocessingLayer\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
@@ -141,11 +138,11 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'max_tokens\', \'output_mode\', \'sparse\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'binary\', \'False\'], "
+ argspec: "args=[\'self\', \'num_tokens\', \'output_mode\', \'sparse\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'binary\', \'False\'], "
}
member_method {
name: "adapt"
- argspec: "args=[\'self\', \'data\', \'reset_state\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'data\', \'batch_size\', \'steps\', \'reset_state\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], "
}
member_method {
name: "add_loss"
@@ -260,14 +257,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "set_num_elements"
- argspec: "args=[\'self\', \'num_elements\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "set_tfidf_data"
- argspec: "args=[\'self\', \'tfidf_data\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.lookup.-text-file-initializer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.lookup.-text-file-initializer.pbtxt
index ff9a0ce..7c69b37 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.lookup.-text-file-initializer.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.lookup.-text-file-initializer.pbtxt
@@ -14,7 +14,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'filename\', \'key_dtype\', \'key_index\', \'value_dtype\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'\\t\', \'None\'], "
+ argspec: "args=[\'self\', \'filename\', \'key_dtype\', \'key_index\', \'value_dtype\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\', \'value_index_offset\'], varargs=None, keywords=None, defaults=[\'None\', \'\\t\', \'None\', \'0\'], "
}
member_method {
name: "initialize"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index 1cdb121..a9bdab2 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -1950,11 +1950,11 @@
}
member_method {
name: "InitializeTableFromTextFile"
- argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'None\'], "
+ argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'offset\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'0\', \'None\'], "
}
member_method {
name: "InitializeTableFromTextFileV2"
- argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'None\'], "
+ argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'offset\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'0\', \'None\'], "
}
member_method {
name: "InitializeTableV2"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-category-encoding.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-category-encoding.pbtxt
index 68d9bec..de99a16 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-category-encoding.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-category-encoding.pbtxt
@@ -1,7 +1,6 @@
path: "tensorflow.keras.layers.experimental.preprocessing.CategoryEncoding"
tf_class {
is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.category_encoding.CategoryEncoding\'>"
- is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.CombinerPreprocessingLayer\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.PreprocessingLayer\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
@@ -139,11 +138,11 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'max_tokens\', \'output_mode\', \'sparse\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'binary\', \'False\'], "
+ argspec: "args=[\'self\', \'num_tokens\', \'output_mode\', \'sparse\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'binary\', \'False\'], "
}
member_method {
name: "adapt"
- argspec: "args=[\'self\', \'data\', \'reset_state\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'data\', \'batch_size\', \'steps\', \'reset_state\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], "
}
member_method {
name: "add_loss"
@@ -258,14 +257,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "set_num_elements"
- argspec: "args=[\'self\', \'num_elements\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "set_tfidf_data"
- argspec: "args=[\'self\', \'tfidf_data\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.lookup.-text-file-initializer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.lookup.-text-file-initializer.pbtxt
index ff9a0ce..7c69b37 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.lookup.-text-file-initializer.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.lookup.-text-file-initializer.pbtxt
@@ -14,7 +14,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'filename\', \'key_dtype\', \'key_index\', \'value_dtype\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'\\t\', \'None\'], "
+ argspec: "args=[\'self\', \'filename\', \'key_dtype\', \'key_index\', \'value_dtype\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\', \'value_index_offset\'], varargs=None, keywords=None, defaults=[\'None\', \'\\t\', \'None\', \'0\'], "
}
member_method {
name: "initialize"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index 1cdb121..a9bdab2 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -1950,11 +1950,11 @@
}
member_method {
name: "InitializeTableFromTextFile"
- argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'None\'], "
+ argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'offset\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'0\', \'None\'], "
}
member_method {
name: "InitializeTableFromTextFileV2"
- argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'None\'], "
+ argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'offset\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'0\', \'None\'], "
}
member_method {
name: "InitializeTableV2"
diff --git a/tensorflow/tools/ci_build/install/install_pip_packages.sh b/tensorflow/tools/ci_build/install/install_pip_packages.sh
index 0dadcbd..9afc778 100755
--- a/tensorflow/tools/ci_build/install/install_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_pip_packages.sh
@@ -19,7 +19,6 @@
# Get the latest version of pip so it recognize manylinux2010
wget https://bootstrap.pypa.io/get-pip.py
python3.6 get-pip.py
-python get-pip.py
rm -f get-pip.py
# Install pip packages from whl files to avoid the time-consuming process of
@@ -27,41 +26,31 @@
# Pin wheel==0.31.1 to work around issue
# https://github.com/pypa/auditwheel/issues/102
-pip2 install wheel==0.31.1
pip3 install wheel==0.31.1
# Install last working version of setuptools. This must happen before we install
# absl-py, which uses install_requires notation introduced in setuptools 20.5.
-pip2 install --upgrade setuptools==39.1.0
pip3 install --upgrade setuptools==39.1.0
-pip2 install virtualenv
pip3 install virtualenv
# Install six and future.
-pip2 install --upgrade six==1.12.0
pip3 install --upgrade six==1.12.0
-pip2 install "future>=0.17.1"
pip3 install "future>=0.17.1"
# Install absl-py.
-pip2 install --upgrade absl-py
pip3 install --upgrade absl-py
# Install werkzeug.
-pip2 install --upgrade werkzeug==0.11.10
pip3 install --upgrade werkzeug==0.11.10
# Install bleach. html5lib will be picked up as a dependency.
-pip2 install --upgrade bleach==2.0.0
pip3 install --upgrade bleach==2.0.0
# Install markdown.
-pip2 install --upgrade markdown==2.6.8
pip3 install --upgrade markdown==2.6.8
# Install protobuf.
-pip2 install --upgrade protobuf==3.6.1
pip3 install --upgrade protobuf==3.6.1
# Remove obsolete version of six, which can sometimes confuse virtualenv.
@@ -71,27 +60,20 @@
# https://github.com/tensorflow/tensorflow/issues/6968
# This workaround isn't needed for Ubuntu 16.04 or later.
if $(cat /etc/*-release | grep -q 14.04); then
- pip2 install --no-binary=:all: --upgrade numpy==1.14.5
pip3 install --no-binary=:all: --upgrade numpy==1.14.5
else
- pip2 install --upgrade numpy==1.14.5
pip3 install --upgrade numpy==1.14.5
fi
-pip2 install scipy==1.2.2
pip3 install scipy==1.4.1
-pip2 install scikit-learn==0.18.1
pip3 install scikit-learn==0.18.1
# pandas required by `inflow`
-pip2 install pandas==0.19.2
pip3 install pandas==0.19.2
# Benchmark tests require the following:
-pip2 install psutil
pip3 install psutil
-pip2 install py-cpuinfo
pip3 install py-cpuinfo
# pylint==1.6.4 requires python-astroid (>= 1.4.5) requires lazy-object-proxy
@@ -99,57 +81,40 @@
# when using setuptools 39.1.0.
# NOTE: Using the updated version of pylint for python3 as python2 is EOL,
# thus using the updated version of lazy-object-proxy==1.4.3
-pip2 install lazy-object-proxy==1.4.1
pip3 install lazy-object-proxy==1.4.3
# pylint tests require the following version. pylint==1.6.4 hangs erratically,
# thus using the updated version of 2.5.3 only for python3 as python2 is EOL
# and this version is not available.
-pip2 install pylint==1.6.4
pip3 install pylint==2.5.3
# pycodestyle tests require the following:
-pip2 install pycodestyle
pip3 install pycodestyle
-# tf.mock require the following for python2:
-pip2 install mock
-
-pip2 install portpicker
pip3 install portpicker
# TensorFlow Serving integration tests require the following:
-pip2 install grpcio
pip3 install grpcio
# Eager-to-graph execution needs astor, gast and termcolor:
-pip2 install --upgrade astor
pip3 install --upgrade astor
-pip2 install --upgrade gast
pip3 install --upgrade gast
-pip2 install --upgrade termcolor
pip3 install --upgrade termcolor
# Keras
-pip2 install keras_preprocessing==1.1.0 --no-deps
pip3 install keras_preprocessing==1.1.0 --no-deps
-pip2 install --upgrade h5py==2.8.0
pip3 install --upgrade h5py==3.1.0
# Estimator
-pip2 install tf-estimator-nightly --no-deps
pip3 install tf-estimator-nightly --no-deps
# Tensorboard
-pip2 install tb-nightly --no-deps
pip3 install tb-nightly --no-deps
# Argparse
-pip2 install --upgrade argparse
pip3 install --upgrade argparse
# tree
-pip2 install dm-tree
pip3 install dm-tree
# tf.distribute multi worker tests require the following:
diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt
index 959bc57..621dd44 100644
--- a/tensorflow/tools/def_file_filter/symbols_pybind.txt
+++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt
@@ -22,6 +22,9 @@
tensorflow::swig::IsEagerTensorSlow
tensorflow::swig::GetRegisteredPyObject
+[//tensorflow/python/util:cpp_nest] # nest
+tensorflow::FlattenDictItems
+
[//tensorflow/core/util:port] # util_port
tensorflow::IsGoogleCudaEnabled
tensorflow::IsBuiltWithROCm
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index 7e0e92a..c2fae97 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -135,6 +135,7 @@
"//tensorflow/python/keras/tests:model_subclassing_test_util",
"//tensorflow/python/keras/tests:model_architectures",
"//tensorflow/python/keras/utils:dataset_creator",
+ "//tensorflow/python/keras/utils:kpl_test_utils",
"//tensorflow/python/keras/benchmarks:keras_benchmark_lib_pip",
"//tensorflow/python/kernel_tests:cudnn_deterministic_base",
"//tensorflow/python/kernel_tests:bias_op_base",
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 9b49339..a311809 100755
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -39,7 +39,7 @@
load("//third_party/ruy:workspace.bzl", ruy = "repo")
load("//third_party/sobol_data:workspace.bzl", sobol_data = "repo")
load("//third_party/vulkan_headers:workspace.bzl", vulkan_headers = "repo")
-load("//third_party/toolchains/remote_config:configs.bzl", "initialize_rbe_configs")
+load("@tf_toolchains//toolchains/remote_config:configs.bzl", "initialize_rbe_configs")
def initialize_third_party():
""" Load third party repositories. See above load() statements. """
@@ -201,12 +201,11 @@
tf_http_archive(
name = "eigen_archive",
build_file = clean_dep("//third_party:eigen.BUILD"),
- patch_file = clean_dep("//third_party/eigen3:gpu_packet_math.patch"),
- sha256 = "768b744d98505db4d73562b7813ee1e102dd185cf79a7ef1d5dbcc6e7e918eaf", # SHARED_EIGEN_SHA
- strip_prefix = "eigen-352f1422d3ceea19a04cab297c6339e0870e1c6c",
+ sha256 = "d76992f1972e4ff270221c7ee8125610a8e02bb46708a7295ee646e99287083b", # SHARED_EIGEN_SHA
+ strip_prefix = "eigen-90ee821c563fa20db4d64d6991ddca256d5c52f2",
urls = [
- "https://storage.googleapis.com/mirror.tensorflow.org/gitlab.com/libeigen/eigen/-/archive/352f1422d3ceea19a04cab297c6339e0870e1c6c/eigen-352f1422d3ceea19a04cab297c6339e0870e1c6c.tar.gz",
- "https://gitlab.com/libeigen/eigen/-/archive/352f1422d3ceea19a04cab297c6339e0870e1c6c/eigen-352f1422d3ceea19a04cab297c6339e0870e1c6c.tar.gz",
+ "https://storage.googleapis.com/mirror.tensorflow.org/gitlab.com/libeigen/eigen/-/archive/90ee821c563fa20db4d64d6991ddca256d5c52f2/eigen-90ee821c563fa20db4d64d6991ddca256d5c52f2.tar.gz",
+ "https://gitlab.com/libeigen/eigen/-/archive/90ee821c563fa20db4d64d6991ddca256d5c52f2/eigen-90ee821c563fa20db4d64d6991ddca256d5c52f2.tar.gz",
],
)
@@ -366,12 +365,12 @@
tf_http_archive(
name = "org_sqlite",
build_file = clean_dep("//third_party:sqlite.BUILD"),
- sha256 = "8ff0b79fd9118af7a760f1f6a98cac3e69daed325c8f9f0a581ecb62f797fd64",
- strip_prefix = "sqlite-amalgamation-3340000",
+ sha256 = "e0b1c0345fe4338b936e17da8e1bd88366cd210e576834546977f040c12a8f68",
+ strip_prefix = "sqlite-amalgamation-3340100",
system_build_file = clean_dep("//third_party/systemlibs:sqlite.BUILD"),
urls = [
- "https://storage.googleapis.com/mirror.tensorflow.org/www.sqlite.org/2020/sqlite-amalgamation-3340000.zip",
- "https://www.sqlite.org/2020/sqlite-amalgamation-3340000.zip",
+ "https://storage.googleapis.com/mirror.tensorflow.org/www.sqlite.org/2021/sqlite-amalgamation-3340100.zip",
+ "https://www.sqlite.org/2021/sqlite-amalgamation-3340100.zip",
],
)
@@ -685,8 +684,8 @@
)
# Check out LLVM and MLIR from llvm-project.
- LLVM_COMMIT = "9db6e97a8605f6a447ed171e59d5fa46fdfdc432"
- LLVM_SHA256 = "f30fe9eb9a342187d25babccd85c3af4f09ee7340108a9f3a259af1dc0c76484"
+ LLVM_COMMIT = "892d2822b62ebcaa7aa0b006b5ea4f26593c1618"
+ LLVM_SHA256 = "223c0e99ff272b0eb6245026ec0fefd6254c1f1e794b76171868fcc843a0b6f5"
LLVM_URLS = [
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
"https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
@@ -1137,17 +1136,6 @@
],
)
- tf_http_archive(
- name = "tf_toolchains",
- sha256 = "eb175afa73e5a33d2b5d2aabcfde6c8c3395fd7001eb5ba765a5cd98cce714ba",
- strip_prefix = "toolchains-0.0.2",
- build_file = clean_dep("//third_party:tf_toolchains.BUILD"),
- urls = [
- "http://mirror.tensorflow.org/github.com/tensorflow/toolchains/archive/v0.0.2.tar.gz",
- "https://github.com/tensorflow/toolchains/archive/v0.0.2.tar.gz",
- ],
- )
-
def tf_bind():
"""Bind targets for some external repositories"""
##############################################################################
diff --git a/tensorflow/workspace3.bzl b/tensorflow/workspace3.bzl
index 8ae8799..de0144b 100644
--- a/tensorflow/workspace3.bzl
+++ b/tensorflow/workspace3.bzl
@@ -13,6 +13,16 @@
],
)
+ http_archive(
+ name = "tf_toolchains",
+ sha256 = "d60f9637c64829e92dac3f4477a2c45cdddb9946c5da0dd46db97765eb9de08e",
+ strip_prefix = "toolchains-1.1.5",
+ urls = [
+ "http://mirror.tensorflow.org/github.com/tensorflow/toolchains/archive/v1.1.5.tar.gz",
+ "https://github.com/tensorflow/toolchains/archive/v1.1.5.tar.gz",
+ ],
+ )
+
# Alias so it can be loaded without assigning to a different symbol to prevent
# shadowing previous loads and trigger a buildifier warning.
tf_workspace3 = workspace
diff --git a/third_party/eigen3/gpu_packet_math.patch b/third_party/eigen3/gpu_packet_math.patch
deleted file mode 100644
index c0f466c..0000000
--- a/third_party/eigen3/gpu_packet_math.patch
+++ /dev/null
@@ -1,98 +0,0 @@
-diff -ru a/Eigen/src/Geometry/arch/Geometry_SSE.h b/Eigen/src/Geometry/arch/Geometry_SSE.h
---- a/Eigen/src/Geometry/arch/Geometry_SSE.h
-+++ b/Eigen/src/Geometry/arch/Geometry_SSE.h
-@@ -33,13 +33,14 @@
- Packet4f b = be.template packet<BAlignment,Packet4f>(0);
- Packet4f s1 = pmul(vec4f_swizzle1(a,1,2,0,2),vec4f_swizzle1(b,2,0,1,2));
- Packet4f s2 = pmul(vec4f_swizzle1(a,3,3,3,1),vec4f_swizzle1(b,0,1,2,1));
-- pstoret<float,Packet4f,ResAlignment>(
-- &res.x(),
-- padd(psub(pmul(a,vec4f_swizzle1(b,3,3,3,3)),
-- pmul(vec4f_swizzle1(a,2,0,1,0),
-- vec4f_swizzle1(b,1,2,0,0))),
-- pxor(mask,padd(s1,s2))));
--
-+ pstoret<float, Packet4f, ResAlignment>(
-+ &res.x(),
-+ padd<Packet4f>(
-+ psub<Packet4f>(pmul<Packet4f>(a, vec4f_swizzle1(b, 3, 3, 3, 3)),
-+ pmul<Packet4f>(vec4f_swizzle1(a, 2, 0, 1, 0),
-+ vec4f_swizzle1(b, 1, 2, 0, 0))),
-+ pxor<Packet4f>(mask, padd(s1, s2))));
-+
- return res;
- }
- };
-diff -ru a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h
---- a/Eigen/src/Core/GenericPacketMath.h
-+++ b/Eigen/src/Core/GenericPacketMath.h
-@@ -255,49 +255,43 @@
- return std::complex<RealScalar>(b, b);
- }
-
--template <typename Packet, typename Op>
--EIGEN_DEVICE_FUNC inline Packet bitwise_helper(const Packet& a, const Packet& b, Op op) {
-+/** \internal \returns the bitwise and of \a a and \a b */
-+template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
-+pand(const Packet& a, const Packet& b) {
- const unsigned char* a_ptr = reinterpret_cast<const unsigned char*>(&a);
- const unsigned char* b_ptr = reinterpret_cast<const unsigned char*>(&b);
- Packet c;
- unsigned char* c_ptr = reinterpret_cast<unsigned char*>(&c);
- for (size_t i = 0; i < sizeof(Packet); ++i) {
-- *c_ptr++ = op(*a_ptr++, *b_ptr++);
-+ *c_ptr++ = *a_ptr++ & *b_ptr++;
- }
- return c;
- }
-
--/** \internal \returns the bitwise and of \a a and \a b */
--template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
--pand(const Packet& a, const Packet& b) {
--#if defined(EIGEN_HIP_DEVICE_COMPILE)
-- return bitwise_helper(a ,b, std::bit_and<unsigned char>());
--#else
-- EIGEN_USING_STD(bit_and);
-- return bitwise_helper(a ,b, bit_and<unsigned char>());
--#endif
--}
--
- /** \internal \returns the bitwise or of \a a and \a b */
- template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
- por(const Packet& a, const Packet& b) {
--#if defined(EIGEN_HIP_DEVICE_COMPILE)
-- return bitwise_helper(a ,b, std::bit_or<unsigned char>());
--#else
-- EIGEN_USING_STD(bit_or);
-- return bitwise_helper(a ,b, bit_or<unsigned char>());
--#endif
-+ const unsigned char* a_ptr = reinterpret_cast<const unsigned char*>(&a);
-+ const unsigned char* b_ptr = reinterpret_cast<const unsigned char*>(&b);
-+ Packet c;
-+ unsigned char* c_ptr = reinterpret_cast<unsigned char*>(&c);
-+ for (size_t i = 0; i < sizeof(Packet); ++i) {
-+ *c_ptr++ = *a_ptr++ | *b_ptr++;
-+ }
-+ return c;
- }
-
- /** \internal \returns the bitwise xor of \a a and \a b */
- template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
- pxor(const Packet& a, const Packet& b) {
--#if defined(EIGEN_HIP_DEVICE_COMPILE)
-- return bitwise_helper(a ,b, std::bit_xor<unsigned char>());
--#else
-- EIGEN_USING_STD(bit_xor);
-- return bitwise_helper(a ,b, bit_xor<unsigned char>());
--#endif
-+ const unsigned char* a_ptr = reinterpret_cast<const unsigned char*>(&a);
-+ const unsigned char* b_ptr = reinterpret_cast<const unsigned char*>(&b);
-+ Packet c;
-+ unsigned char* c_ptr = reinterpret_cast<unsigned char*>(&c);
-+ for (size_t i = 0; i < sizeof(Packet); ++i) {
-+ *c_ptr++ = *a_ptr++ ^ *b_ptr++;
-+ }
-+ return c;
- }
-
- /** \internal \returns the bitwise and of \a a and not \a b */
diff --git a/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl b/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl
index 7e06749..0d1423f 100644
--- a/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl
+++ b/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl
@@ -153,6 +153,13 @@
]
_, argv = GetOptionValue(argv, '--no-cuda-include-ptx')
+ # nvcc doesn't respect the INCLUDE and LIB env vars from MSVC,
+ # so we explicity specify the system include paths and library search paths.
+ if 'INCLUDE' in os.environ:
+ nvccopts += [('--system-include="%s"' % p) for p in os.environ['INCLUDE'].split(";")]
+ if 'LIB' in os.environ:
+ nvccopts += [('--library-path="%s"' % p) for p in os.environ['LIB'].split(";")]
+
nvccopts += nvcc_compiler_options
nvccopts += undefines
nvccopts += defines
diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl
index a017ab4..8b2c4ce 100644
--- a/third_party/gpus/cuda_configure.bzl
+++ b/third_party/gpus/cuda_configure.bzl
@@ -1393,6 +1393,21 @@
else:
_create_local_cuda_repository(repository_ctx)
+# For @bazel_tools//tools/cpp:windows_cc_configure.bzl
+_MSVC_ENVVARS = [
+ "BAZEL_VC",
+ "BAZEL_VC_FULL_VERSION",
+ "BAZEL_VS",
+ "BAZEL_WINSDK_FULL_VERSION",
+ "VS90COMNTOOLS",
+ "VS100COMNTOOLS",
+ "VS110COMNTOOLS",
+ "VS120COMNTOOLS",
+ "VS140COMNTOOLS",
+ "VS150COMNTOOLS",
+ "VS160COMNTOOLS",
+]
+
_ENVIRONS = [
_GCC_HOST_COMPILER_PATH,
_GCC_HOST_COMPILER_PREFIX,
@@ -1410,7 +1425,7 @@
"TMP",
"TMPDIR",
"TF_CUDA_PATHS",
-]
+] + _MSVC_ENVVARS
remote_cuda_configure = repository_rule(
implementation = _create_local_cuda_repository,
diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD
index 7ddc874..828035d 100644
--- a/third_party/mlir/BUILD
+++ b/third_party/mlir/BUILD
@@ -553,21 +553,6 @@
td_srcs = [":LLVMArmNeonTdFiles"],
)
-cc_library(
- name = "TargetLLVMArmNeonIntr",
- srcs = ["lib/Target/LLVMIR/LLVMArmNeonIntr.cpp"],
- includes = ["include"],
- deps = [
- ":IR",
- ":LLVMArmNeon",
- ":LLVMArmNeonConversionIncGen",
- ":LLVMIRModuleTranslation",
- ":Translation",
- "@llvm-project//llvm:Core",
- "@llvm-project//llvm:Support",
- ],
-)
-
##---------------------------------------------------------------------------##
# ArmSVE dialect.
##---------------------------------------------------------------------------##
@@ -716,21 +701,6 @@
td_srcs = [":LLVMArmSVETdFiles"],
)
-cc_library(
- name = "TargetLLVMArmSVEIntr",
- srcs = ["lib/Target/LLVMIR/LLVMArmSVEIntr.cpp"],
- includes = ["include"],
- deps = [
- ":IR",
- ":LLVMArmSVE",
- ":LLVMArmSVEConversionIncGen",
- ":LLVMIRModuleTranslation",
- ":Translation",
- "@llvm-project//llvm:Core",
- "@llvm-project//llvm:Support",
- ],
-)
-
##---------------------------------------------------------------------------##
# AVX512 dialect.
##---------------------------------------------------------------------------##
@@ -872,21 +842,6 @@
td_srcs = [":LLVMAVX512TdFiles"],
)
-cc_library(
- name = "TargetLLVMAVX512Intr",
- srcs = ["lib/Target/LLVMIR/LLVMAVX512Intr.cpp"],
- includes = ["include"],
- deps = [
- ":IR",
- ":LLVMAVX512",
- ":LLVMAVX512ConversionIncGen",
- ":LLVMIRModuleTranslation",
- ":Translation",
- "@llvm-project//llvm:Core",
- "@llvm-project//llvm:Support",
- ],
-)
-
##---------------------------------------------------------------------------##
# SCF dialect.
##---------------------------------------------------------------------------##
@@ -1200,6 +1155,7 @@
":LinalgToLLVM",
":LinalgToSPIRV",
":LinalgToStandard",
+ ":MathToLLVM",
":OpenMPToLLVM",
":PDLToPDLInterp",
":SCFToGPUPass",
@@ -1925,12 +1881,14 @@
cc_library(
name = "GPUCommonTransforms",
+ srcs = [
+ "lib/Conversion/GPUCommon/GPUOpsLowering.cpp",
+ ],
hdrs = [
+ "lib/Conversion/GPUCommon/GPUOpsLowering.h",
"lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h",
"lib/Conversion/GPUCommon/OpToFuncCallLowering.h",
],
- # TODO(b/155492113): Move back to hdrs once fixed.
- textual_hdrs = ["lib/Conversion/GPUCommon/GPUOpsLowering.h"],
deps = [
":GPUDialect",
":IR",
@@ -1975,6 +1933,7 @@
":GPUToNVVMGen",
":GPUTransforms",
":IR",
+ ":MathDialect",
":NVVMDialect",
":Pass",
":StandardToLLVM",
@@ -2055,6 +2014,7 @@
":GPUDialect",
":GPUToROCDLTGen",
":GPUTransforms",
+ ":MathDialect",
":Pass",
":ROCDLDialect",
":StandardToLLVM",
@@ -2105,10 +2065,10 @@
":GPUDialect",
":IR",
":LLVMDialect",
+ ":NVVMToLLVMIRTranslation",
":Pass",
":StandardToLLVM",
":Support",
- ":TargetNVVMIR",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:NVPTXCodeGen",
"@llvm-project//llvm:Support",
@@ -2783,6 +2743,7 @@
deps = [
":ConversionPassIncGen",
":IR",
+ ":MathDialect",
":Pass",
":SPIRVConversion",
":SPIRVDialect",
@@ -3334,6 +3295,7 @@
":ConversionPassIncGen",
":IR",
":LLVMDialect",
+ ":MathDialect",
":Parser",
":Pass",
":StandardOps",
@@ -3577,6 +3539,8 @@
"lib/Target/LLVMIR/TypeTranslation.cpp",
],
hdrs = [
+ "include/mlir/Target/LLVMIR/Export.h",
+ "include/mlir/Target/LLVMIR/LLVMTranslationInterface.h",
"include/mlir/Target/LLVMIR/ModuleTranslation.h",
"include/mlir/Target/LLVMIR/TypeTranslation.h",
],
@@ -3596,6 +3560,119 @@
)
cc_library(
+ name = "LLVMAVX512ToLLVMIRTranslation",
+ srcs = glob(["lib/Target/LLVMIR/Dialect/LLVMAVX512/*.cpp"]),
+ hdrs = glob(["include/mlir/Target/LLVMIR/Dialect/LLVMAVX512/*.h"]),
+ includes = ["include"],
+ deps = [
+ ":IR",
+ ":LLVMAVX512",
+ ":LLVMAVX512ConversionIncGen",
+ ":LLVMIRModuleTranslation",
+ ":Support",
+ "@llvm-project//llvm:Core",
+ "@llvm-project//llvm:Support",
+ ],
+)
+
+cc_library(
+ name = "LLVMArmNeonToLLVMIRTranslation",
+ srcs = glob(["lib/Target/LLVMIR/Dialect/LLVMArmNeon/*.cpp"]),
+ hdrs = glob(["include/mlir/Target/LLVMIR/Dialect/LLVMArmNeon/*.h"]),
+ includes = ["include"],
+ deps = [
+ ":IR",
+ ":LLVMArmNeon",
+ ":LLVMArmNeonConversionIncGen",
+ ":LLVMArmNeonIncGen",
+ ":LLVMIRModuleTranslation",
+ ":Support",
+ "@llvm-project//llvm:Core",
+ "@llvm-project//llvm:Support",
+ ],
+)
+
+cc_library(
+ name = "LLVMArmSVEToLLVMIRTranslation",
+ srcs = glob(["lib/Target/LLVMIR/Dialect/LLVMArmSVE/*.cpp"]),
+ hdrs = glob(["include/mlir/Target/LLVMIR/Dialect/LLVMArmSVE/*.h"]),
+ includes = ["include"],
+ deps = [
+ ":IR",
+ ":LLVMArmSVE",
+ ":LLVMArmSVEConversionIncGen",
+ ":LLVMIRModuleTranslation",
+ ":Support",
+ "@llvm-project//llvm:Core",
+ "@llvm-project//llvm:Support",
+ ],
+)
+
+cc_library(
+ name = "NVVMToLLVMIRTranslation",
+ srcs = glob(["lib/Target/LLVMIR/Dialect/NVVM/*.cpp"]),
+ hdrs = glob(["include/mlir/Target/LLVMIR/Dialect/NVVM/*.h"]),
+ includes = ["include"],
+ deps = [
+ ":IR",
+ ":LLVMIRModuleTranslation",
+ ":NVVMConversionIncGen",
+ ":NVVMDialect",
+ ":Support",
+ "@llvm-project//llvm:Core",
+ "@llvm-project//llvm:Support",
+ ],
+)
+
+cc_library(
+ name = "ROCDLToLLVMIRTranslation",
+ srcs = glob(["lib/Target/LLVMIR/Dialect/ROCDL/*.cpp"]),
+ hdrs = glob(["include/mlir/Target/LLVMIR/Dialect/ROCDL/*.h"]),
+ includes = ["include"],
+ deps = [
+ ":IR",
+ ":LLVMIRModuleTranslation",
+ ":ROCDLConversionIncGen",
+ ":ROCDLDialect",
+ ":Support",
+ "@llvm-project//llvm:Core",
+ "@llvm-project//llvm:Support",
+ ],
+)
+
+cc_library(
+ name = "LLVMToLLVMIRTranslation",
+ srcs = glob(["lib/Target/LLVMIR/Dialect/LLVMIR/*.cpp"]),
+ hdrs = glob(["include/mlir/Target/LLVMIR/Dialect/LLVMIR/*.h"]),
+ includes = ["include"],
+ deps = [
+ ":IR",
+ ":LLVMConversionIncGen",
+ ":LLVMDialect",
+ ":LLVMIRModuleTranslation",
+ ":Support",
+ "@llvm-project//llvm:Core",
+ "@llvm-project//llvm:Support",
+ ],
+)
+
+cc_library(
+ name = "OpenMPToLLVMIRTranslation",
+ srcs = glob(["lib/Target/LLVMIR/Dialect/OpenMP/*.cpp"]),
+ hdrs = glob(["include/mlir/Target/LLVMIR/Dialect/OpenMP/*.h"]),
+ includes = ["include"],
+ deps = [
+ ":IR",
+ ":LLVMIRModuleTranslation",
+ ":OpenMPDialect",
+ ":Support",
+ "@llvm-project//llvm:Core",
+ "@llvm-project//llvm:FrontendOpenMP",
+ "@llvm-project//llvm:Support",
+ ],
+)
+
+cc_library(
name = "TargetLLVMIR",
srcs = [
"lib/Target/LLVMIR/ConvertFromLLVMIR.cpp",
@@ -3605,14 +3682,23 @@
includes = ["include"],
deps = [
":IR",
+ ":LLVMAVX512",
+ ":LLVMAVX512ToLLVMIRTranslation",
+ ":LLVMArmNeon",
+ ":LLVMArmNeonToLLVMIRTranslation",
+ ":LLVMArmSVE",
+ ":LLVMArmSVEToLLVMIRTranslation",
":LLVMConversionIncGen",
":LLVMDialect",
":LLVMIRModuleTranslation",
+ ":LLVMToLLVMIRTranslation",
+ ":NVVMDialect",
+ ":NVVMToLLVMIRTranslation",
":OpenMPDialect",
+ ":OpenMPToLLVMIRTranslation",
+ ":ROCDLDialect",
+ ":ROCDLToLLVMIRTranslation",
":Support",
- ":TargetLLVMAVX512Intr",
- ":TargetLLVMArmNeonIntr",
- ":TargetLLVMArmSVEIntr",
":Translation",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:IRReader",
@@ -3621,52 +3707,20 @@
)
cc_library(
- name = "TargetNVVMIR",
- srcs = ["lib/Target/LLVMIR/ConvertToNVVMIR.cpp"],
- hdrs = ["include/mlir/Target/NVVMIR.h"],
- includes = ["include"],
- deps = [
- ":GPUDialect",
- ":IR",
- ":LLVMDialect",
- ":LLVMIRModuleTranslation",
- ":NVVMConversionIncGen",
- ":NVVMDialect",
- ":Support",
- ":Translation",
- "@llvm-project//llvm:Core",
- "@llvm-project//llvm:Support",
- ],
-)
-
-cc_library(
- name = "TargetROCDLIR",
- srcs = ["lib/Target/LLVMIR/ConvertToROCDLIR.cpp"],
- hdrs = ["include/mlir/Target/ROCDLIR.h"],
- includes = ["include"],
- deps = [
- ":GPUDialect",
- ":IR",
- ":LLVMDialect",
- ":LLVMIRModuleTranslation",
- ":ROCDLConversionIncGen",
- ":ROCDLDialect",
- ":Support",
- ":Translation",
- "@llvm-project//llvm:Core",
- "@llvm-project//llvm:Support",
- ],
-)
-
-# TODO(zinenko): Update these so that we can simplify mapping to cmake.
-cc_library(
name = "ExecutionEngine",
- srcs = ["lib/ExecutionEngine/ExecutionEngine.cpp"],
- hdrs = ["include/mlir/ExecutionEngine/ExecutionEngine.h"],
+ srcs = [
+ "include/mlir/ExecutionEngine/CRunnerUtils.h",
+ "lib/ExecutionEngine/ExecutionEngine.cpp",
+ ],
+ hdrs = [
+ "include/mlir/ExecutionEngine/ExecutionEngine.h",
+ "include/mlir/ExecutionEngine/MemRefUtils.h",
+ ],
includes = ["include"],
deps = [
":IR",
":LLVMDialect",
+ ":LLVMIRModuleTranslation",
":Support",
":TargetLLVMIR",
":Translation",
@@ -3731,10 +3785,14 @@
name = "AllTranslations",
hdrs = ["include/mlir/InitAllTranslations.h"],
deps = [
+ ":LLVMAVX512ToLLVMIRTranslation",
+ ":LLVMArmNeonToLLVMIRTranslation",
+ ":LLVMArmSVEToLLVMIRTranslation",
+ ":LLVMToLLVMIRTranslation",
+ ":NVVMToLLVMIRTranslation",
+ ":ROCDLToLLVMIRTranslation",
":SPIRVTranslateRegistration",
":TargetLLVMIR",
- ":TargetNVVMIR",
- ":TargetROCDLIR",
],
)
@@ -3803,6 +3861,9 @@
":LinalgToSPIRV",
":LinalgToStandard",
":LinalgTransforms",
+ ":MathDialect",
+ ":MathToLLVM",
+ ":MathTransforms",
":NVVMDialect",
":OpenACCDialect",
":OpenMPDialect",
@@ -3885,7 +3946,9 @@
cc_library(
name = "MlirJitRunner",
srcs = ["lib/ExecutionEngine/JitRunner.cpp"],
- hdrs = ["include/mlir/ExecutionEngine/JitRunner.h"],
+ hdrs = [
+ "include/mlir/ExecutionEngine/JitRunner.h",
+ ],
includes = ["include"],
deps = [
":AllPassesAndDialectsNoRegistration",
@@ -3893,6 +3956,8 @@
":ExecutionEngineUtils",
":IR",
":LLVMDialect",
+ ":LLVMToLLVMIRTranslation",
+ ":OpenMPToLLVMIRTranslation",
":Parser",
":Pass",
":SCFToStandard",
@@ -3942,9 +4007,13 @@
srcs = ["tools/mlir-cpu-runner/mlir-cpu-runner.cpp"],
linkopts = ["-ldl"],
deps = [
- ":AllPassesAndDialectsNoRegistration",
":ExecutionEngineUtils",
+ ":IR",
+ ":LLVMDialect",
":MlirJitRunner",
+ ":OpenMPDialect",
+ ":OpenMPToLLVMIRTranslation",
+ ":TargetLLVMIR",
"@llvm-project//llvm:AsmParser",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:X86AsmParser",
@@ -4012,7 +4081,6 @@
name = "mlir-cuda-runner",
srcs = ["tools/mlir-cuda-runner/mlir-cuda-runner.cpp"],
deps = [
- ":AllPassesAndDialectsNoRegistration",
":Async",
":AsyncToLLVM",
":AsyncTransforms",
@@ -4024,11 +4092,14 @@
":GPUTransforms",
":IR",
":LLVMDialect",
+ ":LLVMIRModuleTranslation",
":MlirJitRunner",
":NVVMDialect",
+ ":NVVMToLLVMIRTranslation",
":Pass",
+ ":StandardOps",
":StandardToLLVM",
- ":TargetNVVMIR",
+ ":TargetLLVMIR",
":Transforms",
"//devtools/build/runtime:get_runfiles_dir",
"//third_party/gpus/cuda:cuda_headers",
@@ -4042,17 +4113,21 @@
name = "mlir-vulkan-runner",
srcs = ["tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp"],
deps = [
- ":AllPassesAndDialectsNoRegistration",
":ExecutionEngineUtils",
+ ":GPUDialect",
":GPUToSPIRV",
":GPUToVulkanTransforms",
":GPUTransforms",
+ ":LLVMDialect",
+ ":LLVMIRModuleTranslation",
":MlirJitRunner",
":Pass",
":SPIRVDialect",
":SPIRVTransforms",
+ ":StandardOps",
":StandardToLLVM",
":StandardToSPIRV",
+ ":TargetLLVMIR",
"@llvm-project//llvm:Support",
],
)
@@ -4061,19 +4136,20 @@
name = "mlir-spirv-cpu-runner",
srcs = ["tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp"],
deps = [
- ":AllPassesAndDialectsNoRegistration",
":ExecutionEngineUtils",
":GPUDialect",
":GPUToSPIRV",
":GPUTransforms",
":IR",
":LLVMDialect",
+ ":LLVMIRModuleTranslation",
":MlirJitRunner",
":Pass",
":SPIRVConversion",
":SPIRVDialect",
":SPIRVToLLVM",
":SPIRVTransforms",
+ ":StandardOps",
":StandardToLLVM",
":TargetLLVMIR",
"@llvm-project//llvm:Core",
@@ -4498,6 +4574,34 @@
],
)
+filegroup(
+ name = "LinalgSparseOpsTdFiles",
+ srcs = [
+ "include/mlir/Dialect/Linalg/IR/LinalgBase.td",
+ "include/mlir/Dialect/Linalg/IR/LinalgSparseOps.td",
+ "include/mlir/Interfaces/ViewLikeInterface.td",
+ ":OpBaseTdFiles",
+ ],
+)
+
+gentbl(
+ name = "LinalgSparseOpsIncGen",
+ strip_include_prefix = "include",
+ tbl_outs = [
+ (
+ "-gen-op-decls",
+ "include/mlir/Dialect/Linalg/IR/LinalgSparseOps.h.inc",
+ ),
+ (
+ "-gen-op-defs",
+ "include/mlir/Dialect/Linalg/IR/LinalgSparseOps.cpp.inc",
+ ),
+ ],
+ tblgen = ":mlir-tblgen",
+ td_file = "include/mlir/Dialect/Linalg/IR/LinalgSparseOps.td",
+ td_srcs = [":LinalgSparseOpsTdFiles"],
+)
+
gentbl(
name = "LinalgInterfacesIncGen",
strip_include_prefix = "include",
@@ -4643,6 +4747,7 @@
":LinalgInterfacesIncGen",
":LinalgNamedStructuredOpsIncGen",
":LinalgOpsIncGen",
+ ":LinalgSparseOpsIncGen",
":LinalgStructuredOpsIncGen",
":Parser",
":SideEffectInterfaces",
@@ -4700,7 +4805,9 @@
":LLVMDialect",
":LinalgOps",
":LinalgPassIncGen",
+ ":LinalgSparseOpsIncGen",
":LinalgStructuredOpsIncGen",
+ ":MathDialect",
":Pass",
":SCFDialect",
":SCFToStandard",
@@ -4748,6 +4855,14 @@
"include/mlir/Dialect/Vector/VectorOpsDialect.h.inc",
),
(
+ "-gen-enum-decls",
+ "include/mlir/Dialect/Vector/VectorOpsEnums.h.inc",
+ ),
+ (
+ "-gen-enum-defs",
+ "include/mlir/Dialect/Vector/VectorOpsEnums.cpp.inc",
+ ),
+ (
"-gen-op-doc",
"g3doc/Dialects/Vector/VectorOps.md",
),
@@ -4951,6 +5066,7 @@
":ConversionPassIncGen",
":IR",
":LinalgOps",
+ ":MathDialect",
":Pass",
":StandardOps",
":TosaDialect",
@@ -5074,3 +5190,120 @@
],
visibility = [":friends"],
)
+
+filegroup(
+ name = "MathOpsTdFiles",
+ srcs = [
+ "include/mlir/Dialect/Math/IR/MathBase.td",
+ "include/mlir/Dialect/Math/IR/MathOps.td",
+ ":OpBaseTdFiles",
+ ":SideEffectTdFiles",
+ ":VectorInterfacesTdFiles",
+ ],
+)
+
+gentbl(
+ name = "MathBaseIncGen",
+ strip_include_prefix = "include",
+ tbl_outs = [
+ (
+ "-gen-dialect-decls -dialect=math",
+ "include/mlir/Dialect/Math/IR/MathOpsDialect.h.inc",
+ ),
+ ],
+ tblgen = ":mlir-tblgen",
+ td_file = "include/mlir/Dialect/Math/IR/MathBase.td",
+ td_srcs = [
+ ":MathOpsTdFiles",
+ ],
+)
+
+gentbl(
+ name = "MathOpsIncGen",
+ strip_include_prefix = "include",
+ tbl_outs = [
+ (
+ "-gen-op-decls",
+ "include/mlir/Dialect/Math/IR/MathOps.h.inc",
+ ),
+ (
+ "-gen-op-defs",
+ "include/mlir/Dialect/Math/IR/MathOps.cpp.inc",
+ ),
+ ],
+ tblgen = ":mlir-tblgen",
+ td_file = "include/mlir/Dialect/Math/IR/MathOps.td",
+ td_srcs = [
+ ":MathOpsTdFiles",
+ ],
+)
+
+cc_library(
+ name = "MathDialect",
+ srcs = glob(
+ [
+ "lib/Dialect/Math/IR/*.cpp",
+ "lib/Dialect/Math/IR/*.h",
+ ],
+ ),
+ hdrs = [
+ "include/mlir/Dialect/Math/EDSC/Intrinsics.h",
+ "include/mlir/Dialect/Math/IR/Math.h",
+ "include/mlir/Transforms/InliningUtils.h",
+ ],
+ includes = ["include"],
+ deps = [
+ ":EDSC",
+ ":IR",
+ ":MathBaseIncGen",
+ ":MathOpsIncGen",
+ ":SideEffectInterfaces",
+ ":Support",
+ ":VectorInterfaces",
+ "@llvm-project//llvm:Support",
+ ],
+)
+
+cc_library(
+ name = "MathTransforms",
+ srcs = glob([
+ "lib/Dialect/Math/Transforms/*.cpp",
+ "lib/Dialect/Math/Transforms/*.h",
+ ]),
+ hdrs = glob(["include/mlir/Dialect/Math/Transforms/*.h"]),
+ includes = ["include"],
+ deps = [
+ ":IR",
+ ":MathDialect",
+ ":Pass",
+ ":SCFDialect",
+ ":StandardOps",
+ ":Support",
+ ":Transforms",
+ "@llvm-project//llvm:Support",
+ ],
+)
+
+cc_library(
+ name = "MathToLLVM",
+ srcs = glob([
+ "lib/Conversion/MathToLLVM/*.cpp",
+ "lib/Conversion/MathToLLVM/*.h",
+ ]) + ["lib/Conversion/PassDetail.h"],
+ hdrs = glob([
+ "include/mlir/Conversion/MathToLLVM/*.h",
+ ]),
+ includes = ["include"],
+ deps = [
+ ":ConversionPassIncGen",
+ ":IR",
+ ":LLVMDialect",
+ ":MathDialect",
+ ":Pass",
+ ":StandardToLLVM",
+ ":Support",
+ ":Transforms",
+ "@llvm-project//llvm:Core",
+ "@llvm-project//llvm:Support",
+ ],
+)
diff --git a/third_party/mlir/test.BUILD b/third_party/mlir/test.BUILD
index 926fed9..0a3c9ed 100644
--- a/third_party/mlir/test.BUILD
+++ b/third_party/mlir/test.BUILD
@@ -262,17 +262,22 @@
"@llvm-project//mlir:GPUTransforms",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LLVMDialect",
+ "@llvm-project//mlir:LLVMIRModuleTranslation",
"@llvm-project//mlir:LLVMTransforms",
"@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:LinalgTransforms",
+ "@llvm-project//mlir:MathTransforms",
+ "@llvm-project//mlir:NVVMDialect",
+ "@llvm-project//mlir:NVVMToLLVMIRTranslation",
"@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:ROCDLDialect",
+ "@llvm-project//mlir:ROCDLToLLVMIRTranslation",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:SPIRVDialect",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:StandardOpsTransforms",
"@llvm-project//mlir:Support",
- "@llvm-project//mlir:TargetNVVMIR",
- "@llvm-project//mlir:TargetROCDLIR",
+ "@llvm-project//mlir:TargetLLVMIR",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:Transforms",
"@llvm-project//mlir:VectorOps",
diff --git a/third_party/toolchains/remote_config/BUILD b/third_party/toolchains/remote_config/BUILD
deleted file mode 100644
index e69de29..0000000
--- a/third_party/toolchains/remote_config/BUILD
+++ /dev/null
diff --git a/third_party/toolchains/remote_config/containers.bzl b/third_party/toolchains/remote_config/containers.bzl
deleted file mode 100644
index 18c6caa..0000000
--- a/third_party/toolchains/remote_config/containers.bzl
+++ /dev/null
@@ -1,61 +0,0 @@
-"""Docker images used with remote config and RBE."""
-
-load("//third_party/toolchains/preconfig/generate:containers.bzl", "container_digests")
-
-containers = {
- # Built with //tensorflow/tools/ci_build/Dockerfile.rbe.ubuntu16.04-manylinux2010.
- "ubuntu16.04-manylinux2010": {
- "registry": "gcr.io",
- "repository": "tensorflow-testing/nosla-ubuntu16.04-manylinux2010",
- "digest": container_digests["ubuntu16.04-manylinux2010"],
- },
-
- # Built with //tensorflow/tools/ci_build/Dockerfile.rbe.cuda10.0-cudnn7-ubuntu16.04-manylinux2010.
- "cuda10.0-cudnn7-ubuntu16.04-manylinux2010": {
- "registry": "gcr.io",
- "repository": "tensorflow-testing/nosla-cuda10.0-cudnn7-ubuntu16.04-manylinux2010",
- "digest": container_digests["cuda10.0-cudnn7-ubuntu16.04-manylinux2010"],
- },
-
- # Built with //tensorflow/tools/ci_build/Dockerfile.rbe.cuda10.1-cudnn7-ubuntu16.04-manylinux2010.
- "cuda10.1-cudnn7-ubuntu16.04-manylinux2010": {
- "registry": "gcr.io",
- "repository": "tensorflow-testing/nosla-cuda10.1-cudnn7-ubuntu16.04-manylinux2010",
- "digest": container_digests["cuda10.1-cudnn7-ubuntu16.04-manylinux2010"],
- },
-
- # Built with //tensorflow/tools/ci_build/Dockerfile.rbe.cuda10.1-cudnn7-ubuntu16.04-manylinux2010-multipython.
- "cuda10.1-cudnn7-ubuntu16.04-manylinux2010-multipython": {
- "registry": "gcr.io",
- "repository": "tensorflow-testing/nosla-cuda10.1-cudnn7-ubuntu16.04-manylinux2010-multipython",
- "digest": container_digests["cuda10.1-cudnn7-ubuntu16.04-manylinux2010-multipython"],
- },
-
- # Built with //tensorflow/tools/ci_build/Dockerfile.rbe.cuda10.1-cudnn7-ubuntu18.04-manylinux2010-multipython.
- "cuda10.1-cudnn7-ubuntu18.04-manylinux2010-multipython": {
- "registry": "gcr.io",
- "repository": "tensorflow-testing/nosla-cuda10.1-cudnn7-ubuntu18.04-manylinux2010-multipython",
- "digest": container_digests["cuda10.1-cudnn7-ubuntu18.04-manylinux2010-multipython"],
- },
-
- # Built with //tensorflow/tools/ci_build/Dockerfile.rbe.cuda11.0-cudnn8-ubuntu18.04-manylinux2010-multipython.
- "cuda11.0-cudnn8-ubuntu18.04-manylinux2010-multipython": {
- "registry": "gcr.io",
- "repository": "tensorflow-testing/nosla-cuda11.0-cudnn8-ubuntu18.04-manylinux2010-multipython",
- "digest": container_digests["cuda11.0-cudnn8-ubuntu18.04-manylinux2010-multipython"],
- },
-
- # Built with //tensorflow/tools/ci_build/Dockerfile.rbe.rocm-ubuntu18.04-manylinux2010-multipython.
- "rocm-ubuntu18.04-manylinux2010-multipython": {
- "registry": "gcr.io",
- "repository": "tensorflow-testing/nosla-rocm-ubuntu18.04-manylinux2010-multipython",
- "digest": container_digests["rocm-ubuntu18.04-manylinux2010-multipython"],
- },
-
- # Built by gunan@ from a private Dockerfile.
- "windows-1803": {
- "registry": "gcr.io",
- "repository": "tensorflow-testing/tf-win-rbe",
- "digest": container_digests["windows-1803"],
- },
-}
diff --git a/third_party/toolchains/remote_config/rbe_config.bzl b/third_party/toolchains/remote_config/rbe_config.bzl
deleted file mode 100644
index 08c115a..0000000
--- a/third_party/toolchains/remote_config/rbe_config.bzl
+++ /dev/null
@@ -1,165 +0,0 @@
-"""Macro that creates external repositories for remote config."""
-
-load("//third_party/py:python_configure.bzl", "local_python_configure", "remote_python_configure")
-load("//third_party/gpus:cuda_configure.bzl", "remote_cuda_configure")
-load("//third_party/nccl:nccl_configure.bzl", "remote_nccl_configure")
-load("//third_party/gpus:rocm_configure.bzl", "remote_rocm_configure")
-load("//third_party/tensorrt:tensorrt_configure.bzl", "remote_tensorrt_configure")
-load("//third_party/toolchains/remote_config:containers.bzl", "containers")
-load("//third_party/remote_config:remote_platform_configure.bzl", "remote_platform_configure")
-
-def _container_image_uri(container_name):
- container = containers[container_name]
- return "docker://%s/%s@%s" % (container["registry"], container["repository"], container["digest"])
-
-def _tensorflow_rbe_config(name, compiler, python_versions, os, rocm_version = None, cuda_version = None, cudnn_version = None, tensorrt_version = None, tensorrt_install_path = None, cudnn_install_path = None, compiler_prefix = None, sysroot = None, python_install_path = "/usr"):
- if cuda_version != None and rocm_version != None:
- fail("Specifying both cuda_version and rocm_version is not supported.")
-
- env = {
- "ABI_VERSION": "gcc",
- "ABI_LIBC_VERSION": "glibc_2.19",
- "BAZEL_COMPILER": compiler,
- "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu",
- "BAZEL_TARGET_LIBC": "glibc_2.19",
- "BAZEL_TARGET_CPU": "k8",
- "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu",
- "CC_TOOLCHAIN_NAME": "linux_gnu_x86",
- "CC": compiler,
- "CLEAR_CACHE": "1",
- "HOST_CXX_COMPILER": compiler,
- "HOST_C_COMPILER": compiler,
- }
-
- if cuda_version != None:
- # The cuda toolchain currently contains its own C++ toolchain definition,
- # so we do not fetch local_config_cc.
- env.update({
- "TF_NEED_CUDA": "1",
- "TF_CUDA_CLANG": "1" if compiler.endswith("clang") else "0",
- "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0",
- "TF_ENABLE_XLA": "1",
- "TF_CUDNN_VERSION": cudnn_version,
- "TF_CUDA_VERSION": cuda_version,
- "CUDNN_INSTALL_PATH": cudnn_install_path if cudnn_install_path != None else "/usr/lib/x86_64-linux-gnu",
- "TF_NEED_TENSORRT": "1",
- "TF_TENSORRT_VERSION": tensorrt_version,
- "TENSORRT_INSTALL_PATH": tensorrt_install_path if tensorrt_install_path != None else "/usr/lib/x86_64-linux-gnu",
- "GCC_HOST_COMPILER_PATH": compiler if not compiler.endswith("clang") else "",
- "GCC_HOST_COMPILER_PREFIX": compiler_prefix if compiler_prefix != None else "/usr/bin",
- "CLANG_CUDA_COMPILER_PATH": compiler if compiler.endswith("clang") else "",
- "TF_SYSROOT": sysroot if sysroot else "",
- })
-
- container_name = "cuda%s-cudnn%s-%s" % (cuda_version, cudnn_version, os)
- container_image = _container_image_uri(container_name)
- exec_properties = {
- "container-image": container_image,
- "Pool": "default",
- }
-
- remote_cuda_configure(
- name = "%s_config_cuda" % name,
- environ = env,
- exec_properties = exec_properties,
- )
-
- remote_nccl_configure(
- name = "%s_config_nccl" % name,
- environ = env,
- exec_properties = exec_properties,
- )
-
- remote_tensorrt_configure(
- name = "%s_config_tensorrt" % name,
- environ = env,
- exec_properties = exec_properties,
- )
- elif rocm_version != None:
- # The rocm toolchain currently contains its own C++ toolchain definition,
- # so we do not fetch local_config_cc.
- env.update({
- "TF_NEED_ROCM": "1",
- "TF_ENABLE_XLA": "0",
- })
-
- container_name = "rocm-%s" % (os)
- container_image = _container_image_uri(container_name)
- exec_properties = {
- "container-image": container_image,
- "Pool": "default",
- }
-
- remote_rocm_configure(
- name = "%s_config_rocm" % name,
- environ = env,
- exec_properties = exec_properties,
- )
- elif python_versions != None:
- container_image = _container_image_uri(os)
- exec_properties = {
- "container-image": container_image,
- "Pool": "default",
- }
-
- else:
- fail("Neither cuda_version, rocm_version nor python_version specified.")
-
- remote_platform_configure(
- name = "%s_config_platform" % name,
- platform = "linux",
- platform_exec_properties = exec_properties,
- )
- for python_version in python_versions:
- env.update({
- "PYTHON_BIN_PATH": "%s/bin/python%s" % (python_install_path, python_version),
- })
-
- # For backwards compatibility do not add the python version to the name
- # if we only create a single python configuration.
- version = python_version if len(python_versions) > 1 else ""
- remote_python_configure(
- name = "%s_config_python%s" % (name, version),
- environ = env,
- exec_properties = exec_properties,
- platform_constraint = "@%s_config_platform//:platform_constraint" % name,
- )
-
-def _tensorflow_rbe_win_config(name, python_bin_path, container_name = "windows-1803"):
- container_image = _container_image_uri(container_name)
- exec_properties = {
- "container-image": container_image,
- "OSFamily": "Windows",
- }
-
- env = {
- "PYTHON_BIN_PATH": python_bin_path,
- }
-
- remote_platform_configure(
- name = "%s_config_platform" % name,
- platform = "windows",
- platform_exec_properties = exec_properties,
- )
-
- remote_python_configure(
- name = "%s_config_python" % name,
- environ = env,
- exec_properties = exec_properties,
- platform_constraint = "@%s_config_platform//:platform_constraint" % name,
- )
-
-def _tensorflow_local_config(name):
- remote_platform_configure(
- name = "%s_config_platform" % name,
- platform = "local",
- platform_exec_properties = {},
- )
- local_python_configure(
- name = "%s_config_python" % name,
- platform_constraint = "@%s_config_platform//:platform_constraint" % name,
- )
-
-tensorflow_rbe_config = _tensorflow_rbe_config
-tensorflow_rbe_win_config = _tensorflow_rbe_win_config
-tensorflow_local_config = _tensorflow_local_config