[MLIR/XLA] Preliminary version of the L(late/lowered)HLO dialect.
LHLO is a lower version of HLO where the buffer assignment has been performed
and is encoded into a dialect: operations are no longer functional,
and explicitly write to memory.
Dialect, file and class names are still subject to change, as a larger
refactoring is necessary.
Common items from HLO and LHLO are refactored into b(base)xla.td file:
currently, it's all the helper datatypes, as well as summary and description
for each operation.
PiperOrigin-RevId: 264698441
diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD
index 1bd9d53..928e7a8 100644
--- a/tensorflow/compiler/mlir/BUILD
+++ b/tensorflow/compiler/mlir/BUILD
@@ -34,6 +34,7 @@
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
"//tensorflow/compiler/mlir/xla",
+ "//tensorflow/compiler/mlir/xla:lxla",
"//tensorflow/compiler/mlir/xla:xla_dialect_registration",
"//tensorflow/compiler/mlir/xla:xla_legalize_control_flow",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD
index 546d981..d3df4bd 100644
--- a/tensorflow/compiler/mlir/xla/BUILD
+++ b/tensorflow/compiler/mlir/xla/BUILD
@@ -20,7 +20,9 @@
filegroup(
name = "xla_ops_td_files",
srcs = [
+ "ir/lxla_ops.td",
"ir/xla_ops.td",
+ "ir/xla_ops_base.td",
"@local_config_mlir//:OpBaseTdFiles",
],
)
@@ -45,6 +47,44 @@
)
gentbl(
+ name = "xla_ops_base_inc_gen",
+ tbl_outs = [
+ (
+ "-gen-op-decls",
+ "ir/xla_ops_base.h.inc",
+ ),
+ (
+ "-gen-op-defs",
+ "ir/xla_ops_base.cc.inc",
+ ),
+ ],
+ tblgen = "@local_config_mlir//:mlir-tblgen",
+ td_file = "ir/xla_ops_base.td",
+ td_srcs = [
+ ":xla_ops_td_files",
+ ],
+)
+
+gentbl(
+ name = "lxla_ops_inc_gen",
+ tbl_outs = [
+ (
+ "-gen-op-decls",
+ "ir/lxla_ops.h.inc",
+ ),
+ (
+ "-gen-op-defs",
+ "ir/lxla_ops.cc.inc",
+ ),
+ ],
+ tblgen = "@local_config_mlir//:mlir-tblgen",
+ td_file = "ir/lxla_ops.td",
+ td_srcs = [
+ ":xla_ops_td_files",
+ ],
+)
+
+gentbl(
name = "xla_legalize_tf_inc_gen",
tbl_outs = [
(
@@ -148,6 +188,7 @@
copts = ["-std=c++14"],
includes = ["include"],
deps = [
+ ":xla_ops_base_inc_gen",
":xla_ops_inc_gen",
"@llvm//:support",
"@local_config_mlir//:Analysis",
@@ -160,12 +201,39 @@
alwayslink = 1,
)
+cc_library(
+ name = "lxla",
+ srcs = [
+ "ir/lxla_ops.cc",
+ "ir/lxla_ops.cc.inc",
+ "ir/lxla_ops.h.inc",
+ ],
+ hdrs = [
+ "ir/lxla_ops.h",
+ "transforms/passes.h",
+ ],
+ includes = ["include"],
+ deps = [
+ ":lxla_ops_inc_gen",
+ ":xla_ops_base_inc_gen",
+ "@llvm//:support",
+ "@local_config_mlir//:Analysis",
+ "@local_config_mlir//:IR",
+ "@local_config_mlir//:Pass",
+ "@local_config_mlir//:StandardOps",
+ "@local_config_mlir//:Support",
+ "@local_config_mlir//:TransformUtils",
+ ],
+ alwayslink = 1,
+)
+
# Library with XLA dialect static initialization.
cc_library(
name = "xla_dialect_registration",
srcs = ["ir/dialect_registration.cc"],
copts = ["-std=c++14"],
deps = [
+ ":lxla",
":xla",
"@local_config_mlir//:IR",
],
@@ -315,12 +383,14 @@
srcs = [
"@local_config_mlir//:include/mlir/IR/OpBase.td",
"//tensorflow/compiler/mlir/xla:ir/xla_ops.td",
+ "//tensorflow/compiler/mlir/xla:ir/xla_ops_base.td",
],
outs = [
"operator_writers.inc",
],
cmd = ("$(location :operator_writer_gen) " +
"-I external/local_config_mlir/include " +
- "$(location //tensorflow/compiler/mlir/xla:ir/xla_ops.td) " + " -o $@"),
+ "$(location //tensorflow/compiler/mlir/xla:ir/xla_ops.td) " +
+ " -o $@"),
tools = [":operator_writer_gen"],
)
diff --git a/tensorflow/compiler/mlir/xla/ir/dialect_registration.cc b/tensorflow/compiler/mlir/xla/ir/dialect_registration.cc
index 57f8733..8d1e108 100644
--- a/tensorflow/compiler/mlir/xla/ir/dialect_registration.cc
+++ b/tensorflow/compiler/mlir/xla/ir/dialect_registration.cc
@@ -13,7 +13,9 @@
limitations under the License.
==============================================================================*/
+#include "tensorflow/compiler/mlir/xla/ir/lxla_ops.h"
#include "tensorflow/compiler/mlir/xla/ir/xla_ops.h"
// Static initialization for XLA dialect registration.
static mlir::DialectRegistration<mlir::XLA::XlaHloDialect> xla_hlo_ops;
+static mlir::DialectRegistration<mlir::LXLA::LXlaHloDialect> lxla_hlo_ops;
diff --git a/tensorflow/compiler/mlir/xla/ir/lxla_ops.cc b/tensorflow/compiler/mlir/xla/ir/lxla_ops.cc
new file mode 100644
index 0000000..e5f767f
--- /dev/null
+++ b/tensorflow/compiler/mlir/xla/ir/lxla_ops.cc
@@ -0,0 +1,64 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file defines the operations used in the XLA dialect.
+
+#include "tensorflow/compiler/mlir/xla/ir/lxla_ops.h"
+
+#include <assert.h>
+#include <stddef.h>
+#include <stdint.h>
+
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "mlir/IR/Attributes.h" // TF:local_config_mlir
+#include "mlir/IR/Builders.h" // TF:local_config_mlir
+#include "mlir/IR/Dialect.h" // TF:local_config_mlir
+#include "mlir/IR/Location.h" // TF:local_config_mlir
+#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
+#include "mlir/IR/OpDefinition.h" // TF:local_config_mlir
+#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir
+#include "mlir/IR/Operation.h" // TF:local_config_mlir
+#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir
+#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
+#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
+#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
+#include "mlir/IR/Types.h" // TF:local_config_mlir
+#include "mlir/IR/Value.h" // TF:local_config_mlir
+#include "tensorflow/compiler/mlir/xla/ir/lxla_ops.h.inc"
+
+namespace mlir {
+namespace LXLA {
+
+LXlaHloDialect::LXlaHloDialect(MLIRContext* context)
+ : Dialect(getDialectNamespace(), context) {
+ addOperations<
+#define GET_OP_LIST
+#include "tensorflow/compiler/mlir/xla/ir/lxla_ops.cc.inc"
+ >();
+}
+
+#define GET_OP_CLASSES
+#include "tensorflow/compiler/mlir/xla/ir/lxla_ops.cc.inc"
+
+// TODO(cheshire): Support folding, reuse code from xla_ops.cc.
+
+} // namespace LXLA
+} // namespace mlir
diff --git a/tensorflow/compiler/mlir/xla/ir/lxla_ops.h b/tensorflow/compiler/mlir/xla/ir/lxla_ops.h
new file mode 100644
index 0000000..dac876b
--- /dev/null
+++ b/tensorflow/compiler/mlir/xla/ir/lxla_ops.h
@@ -0,0 +1,49 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file defines the operations used in the LXLA dialect.
+
+#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_LXLA_OPS_H_
+#define TENSORFLOW_COMPILER_MLIR_XLA_IR_LXLA_OPS_H_
+
+#include "llvm/ADT/StringRef.h"
+#include "mlir/IR/Attributes.h" // TF:local_config_mlir
+#include "mlir/IR/Dialect.h" // TF:local_config_mlir
+#include "mlir/IR/Location.h" // TF:local_config_mlir
+#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
+#include "mlir/IR/OpDefinition.h" // TF:local_config_mlir
+#include "mlir/IR/Operation.h" // TF:local_config_mlir
+#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
+#include "mlir/IR/Types.h" // TF:local_config_mlir
+#include "mlir/Support/Functional.h" // TF:local_config_mlir
+
+namespace mlir {
+class OpBuilder;
+
+namespace LXLA {
+
+class LXlaHloDialect : public Dialect {
+ public:
+ explicit LXlaHloDialect(MLIRContext *context);
+ static StringRef getDialectNamespace() { return "lxla_hlo"; }
+};
+
+#define GET_OP_CLASSES
+#include "tensorflow/compiler/mlir/xla/ir/lxla_ops.h.inc"
+
+} // namespace LXLA
+} // end namespace mlir
+
+#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_LXLA_OPS_H_
diff --git a/tensorflow/compiler/mlir/xla/ir/lxla_ops.td b/tensorflow/compiler/mlir/xla/ir/lxla_ops.td
new file mode 100644
index 0000000..e670876
--- /dev/null
+++ b/tensorflow/compiler/mlir/xla/ir/lxla_ops.td
@@ -0,0 +1,323 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This is the operation definition file for LXLA.
+
+#ifdef LXLA_OPS
+#else
+#define LXLA_OPS
+
+#ifdef OP_BASE
+#else
+include "mlir/IR/OpBase.td"
+#endif // OP_BASE
+
+#ifdef XLA_OPS_BASE
+#else
+include "tensorflow/compiler/mlir/xla/ir/xla_ops_base.td"
+#endif
+
+def LXLA_Dialect : Dialect {
+ let name = "lxla_hlo";
+ let cppNamespace = "LXLA";
+}
+
+//===----------------------------------------------------------------------===//
+// XLA type definitions.
+//===----------------------------------------------------------------------===//
+
+// Any integer tensor types
+def LXLA_IntBuffer : StaticShapeMemRefOf<[XLA_Int]>;
+
+// Any floating-point tensor types
+def LXLA_FpBuffer : StaticShapeMemRefOf<[AnyFloat]>;
+
+
+def LXLA_PredBuffer : StaticShapeMemRefOf<[XLA_Pred]>;
+
+// Any integer or floating-point tensor types
+def LXLA_IntOrFpBuffer : StaticShapeMemRefOf<[XLA_Int, AnyFloat]>;
+
+def LXLA_Buffer : StaticShapeMemRefOf<[AnyFloat, AnyInteger]>;
+
+def LXLA_TupleBuffer : NestedTupleOf<[LXLA_Buffer]>;
+
+def LXLA_BufferOrTuple : AnyTypeOf<[LXLA_Buffer, LXLA_TupleBuffer]>;
+
+//===----------------------------------------------------------------------===//
+// XLA nullary op definitions.
+//===----------------------------------------------------------------------===//
+
+class LXLA_Op<string mnemonic, list<OpTrait> traits> : Op<LXLA_Dialect,
+ mnemonic, traits>;
+
+def LXLA_ConstOp : BXLA_ConstOp, LXLA_Op<"constant", []> {
+ let arguments = (ins
+ ElementsAttr:$value,
+ LXLA_Buffer:$output
+ );
+}
+
+def LXLA_IotaOp : BXLA_IotaOp, LXLA_Op<"iota", []> {
+ let arguments = (ins I64Attr:$iota_dimension,
+ LXLA_Buffer:$output);
+}
+
+//===----------------------------------------------------------------------===//
+// XLA unary elementwise op definitions.
+//===----------------------------------------------------------------------===//
+// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions
+
+class LXLA_UnaryElementwiseOp<string mnemonic> :
+ LXLA_Op<mnemonic, [SameTypeOperands]> {
+ let arguments = (ins LXLA_Buffer:$input,
+ LXLA_Buffer:$output);
+}
+
+def LXLA_AbsOp: LXLA_UnaryElementwiseOp<"abs">, BXLA_AbsOp;
+
+def LXLA_ConvertOp : LXLA_UnaryElementwiseOp<"convert">, BXLA_ConvertOp;
+
+def LXLA_ExpOp: LXLA_UnaryElementwiseOp<"exp">, BXLA_ExpOp;
+
+def LXLA_NegOp: LXLA_UnaryElementwiseOp<"neg">, BXLA_NegOp;
+
+def LXLA_SignOp: LXLA_UnaryElementwiseOp<"sign">, BXLA_SignOp;
+
+def LXLA_TanhOp: LXLA_UnaryElementwiseOp<"tanh">, BXLA_TanhOp;
+
+//===----------------------------------------------------------------------===//
+// XLA binary elementwise op definitions.
+//===----------------------------------------------------------------------===//
+
+class LXLA_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits> :
+ LXLA_Op<mnemonic, traits> {
+ let arguments = (ins
+ LXLA_Buffer:$lhs,
+ LXLA_Buffer:$rhs,
+ LXLA_Buffer:$out,
+ BroadcastDimAttr:$broadcast_dimensions
+ );
+}
+
+def LXLA_AddOp : LXLA_BinaryElementwiseOp<"add", []>, BXLA_AddOp;
+
+def LXLA_DivOp : LXLA_BinaryElementwiseOp<"div", []>, BXLA_DivOp;
+
+def LXLA_MaxOp : LXLA_BinaryElementwiseOp<"max", []>, BXLA_MaxOp;
+
+def LXLA_MinOp : LXLA_BinaryElementwiseOp<"min", []>, BXLA_MinOp;
+
+def LXLA_MulOp : LXLA_BinaryElementwiseOp<"mul", []>, BXLA_MulOp;
+
+def LXLA_SubOp : LXLA_BinaryElementwiseOp<"sub", []>, BXLA_SubOp;
+
+def LXLA_AndOp: LXLA_BinaryElementwiseOp<"and", []>, BXLA_AndOp;
+
+//===----------------------------------------------------------------------===//
+// XLA control flow op definitions.
+//===----------------------------------------------------------------------===//
+
+// TODO(b/139813999): specify required function signature in a type-safe way.
+def LXLA_ReduceOp: LXLA_Op<"reduce", [SameVariadicOperandSize]>, BXLA_ReduceOp {
+ let arguments = (ins
+ Variadic<LXLA_BufferOrTuple>:$operands_and_init,
+ Variadic<LXLA_BufferOrTuple>:$out,
+ SymbolRefAttr:$computation,
+ ElementsAttr:$dimensions
+ );
+}
+//===----------------------------------------------------------------------===//
+// XLA tuple op definitions.
+//===----------------------------------------------------------------------===//
+
+def LXLA_GetTupleElementOp: LXLA_Op<"get_tuple_element", []>, BXLA_GetTupleElementOp {
+ let arguments = (ins
+ LXLA_TupleBuffer:$input,
+ LXLA_BufferOrTuple:$out,
+ I32Attr:$index
+ );
+}
+
+def LXLA_TupleOp : LXLA_Op<"tuple", []>, BXLA_TupleOp {
+ let arguments = (ins
+ Variadic<LXLA_BufferOrTuple>:$val,
+ LXLA_TupleBuffer:$out);
+}
+
+def LXLA_CompareOp: LXLA_Op<"compare", []>, BXLA_CompareOp {
+ let arguments = (ins
+ LXLA_Buffer:$lhs,
+ LXLA_Buffer:$rhs,
+ LXLA_PredBuffer:$out,
+ BroadcastDimAttr:$broadcast_dimensions,
+ XLA_ComparisonDirectionAttr:$comparison_direction
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// XLA Slice definitions.
+//===----------------------------------------------------------------------===//
+
+def LXLA_SliceOp: LXLA_Op<
+ "slice",
+ [AllTypesMatch<["start_indices", "limit_indices"]>]> {
+ let arguments = (ins
+ LXLA_Buffer:$operand,
+ LXLA_Buffer:$output,
+ ElementsAttr:$start_indices,
+ ElementsAttr:$limit_indices
+ );
+}
+
+def XLA_DynamicUpdateSliceOp: LXLA_Op<"dynamic-update-slice", []> {
+ let arguments = (ins
+ LXLA_Buffer:$operand,
+ LXLA_Buffer:$update,
+ LXLA_Buffer:$output,
+ Variadic<LXLA_Buffer>:$start_indices
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// XLA Other op definitions.
+//===----------------------------------------------------------------------===//
+
+def XLA_BatchNormInferenceOp : LXLA_Op<"batch_norm_inference", []>,
+ BXLA_BatchNormInferenceOp {
+
+ let arguments = (ins
+ LXLA_Buffer:$operand,
+ LXLA_Buffer:$scale,
+ LXLA_Buffer:$offset,
+ LXLA_Buffer:$mean,
+ LXLA_Buffer:$variance,
+ LXLA_Buffer:$output,
+ F32Attr:$epsilon,
+ I64Attr:$feature_index
+ );
+}
+
+def LXLA_BroadcastOp : LXLA_Op<"broadcast",
+ []>, BXLA_BroadcastOp {
+ let arguments = (ins
+ LXLA_Buffer:$operand,
+ LXLA_Buffer:$output,
+ ElementsAttr:$broadcast_sizes
+ );
+}
+
+def LXLA_BroadcastInDimOp : LXLA_Op<"broadcast_in_dim",
+ []>, BXLA_BroadcastInDimOp {
+ let arguments = (ins
+ LXLA_Buffer:$operand,
+ LXLA_Buffer:$output,
+ BroadcastDimAttr:$broadcast_dimensions
+ );
+}
+
+def LXLA_ClampOp : LXLA_Op<"clamp", []>, BXLA_ClampOp {
+ let arguments = (ins
+ LXLA_Buffer:$min,
+ LXLA_Buffer:$operand,
+ LXLA_Buffer:$max,
+ LXLA_Buffer:$output
+ );
+}
+
+def LXLA_ConcatenateOp : LXLA_Op<"concatenate", []>, BXLA_ConcatenateOp {
+ let arguments = (ins
+ Variadic<LXLA_Buffer>:$val,
+ LXLA_Buffer:$output,
+ I64Attr: $dimension
+ );
+}
+
+def LXLA_ConvOp : LXLA_Op<"conv", []>, BXLA_ConvOp {
+ let arguments = (ins
+ LXLA_Buffer:$lhs,
+ LXLA_Buffer:$rhs,
+ LXLA_Buffer:$output
+ );
+}
+
+def LXLA_DotOp: LXLA_Op<"dot", []>, BXLA_DotOp {
+ let arguments = (ins
+ LXLA_Buffer:$lhs,
+ LXLA_Buffer:$rhs,
+ XLA_PrecisionConfigAttr:$precision_config,
+ LXLA_Buffer:$output
+ );
+}
+
+def LXLA_GatherOp: LXLA_Op<"gather", []>, BXLA_GatherOp {
+ let arguments = (ins
+ LXLA_Buffer:$operand,
+ LXLA_IntBuffer:$start_indices,
+ I64Attr: $index_vector_dim,
+ ElementsAttr: $offset_dims,
+ ElementsAttr: $slice_sizes,
+ ElementsAttr: $collapsed_slice_dims,
+ ElementsAttr: $start_index_map,
+ LXLA_Buffer:$output
+ );
+}
+
+def LXLA_ReshapeOp: LXLA_Op<"reshape", []>, BXLA_ReshapeOp {
+ let arguments = (ins
+ LXLA_Buffer:$operand,
+ LXLA_Buffer:$output
+ );
+}
+
+
+def LXLA_SelectOp: LXLA_Op<"select", []>, BXLA_SelectOp {
+ let arguments = (ins
+ LXLA_PredBuffer:$pred,
+ LXLA_Buffer:$on_true,
+ LXLA_Buffer:$on_false,
+ LXLA_Buffer:$output
+ );
+}
+
+def LXLA_ReverseOp: LXLA_Op<"reverse", []>, BXLA_ReverseOp {
+ let arguments = (ins
+ LXLA_Buffer:$operand,
+ ElementsAttr:$dimensions,
+ LXLA_Buffer:$output
+ );
+}
+
+def LXLA_PadOp: LXLA_Op<"pad", []>, BXLA_PadOp {
+ let arguments = (ins
+ LXLA_Buffer:$operand,
+ LXLA_Buffer:$padding_value,
+ ElementsAttr: $edge_padding_low,
+ ElementsAttr: $edge_padding_high,
+ ElementsAttr: $interior_padding,
+ LXLA_Buffer: $output
+ );
+}
+
+def LXLA_TransposeOp: LXLA_Op<"transpose", []>, BXLA_TransposeOp {
+ let arguments = (ins
+ LXLA_Buffer:$operand,
+ ElementsAttr:$permutation,
+ LXLA_Buffer:$output
+ );
+}
+
+
+#endif // LXLA_OPS
diff --git a/tensorflow/compiler/mlir/xla/ir/xla_ops.cc b/tensorflow/compiler/mlir/xla/ir/xla_ops.cc
index 36a21bd..6cfd3ae 100644
--- a/tensorflow/compiler/mlir/xla/ir/xla_ops.cc
+++ b/tensorflow/compiler/mlir/xla/ir/xla_ops.cc
@@ -55,7 +55,7 @@
>();
// Support unknown operations because not all XLA operations are registered.
- allowUnknownOperations();
+ // allowUnknownOperations();
}
Operation* XlaHloDialect::materializeConstant(OpBuilder& builder,
diff --git a/tensorflow/compiler/mlir/xla/ir/xla_ops.td b/tensorflow/compiler/mlir/xla/ir/xla_ops.td
index e7c8c4f..8b14bbc 100644
--- a/tensorflow/compiler/mlir/xla/ir/xla_ops.td
+++ b/tensorflow/compiler/mlir/xla/ir/xla_ops.td
@@ -24,6 +24,11 @@
include "mlir/IR/OpBase.td"
#endif // OP_BASE
+#ifdef XLA_OPS_BASE
+#else
+include "tensorflow/compiler/mlir/xla/ir/xla_ops_base.td"
+#endif
+
def XLA_Dialect : Dialect {
let name = "xla_hlo";
let cppNamespace = "XLA";
@@ -39,16 +44,12 @@
// XLA type definitions.
//===----------------------------------------------------------------------===//
-def XLA_Int : IntOfWidths<[8, 16, 32, 64]>;
-
// Any integer tensor types
def XLA_IntTensor : StaticShapeTensorOf<[XLA_Int]>;
// Any floating-point tensor types
def XLA_FpTensor : StaticShapeTensorOf<[AnyFloat]>;
-def XLA_Pred : TypeAlias<I1, "pred (AKA boolean or 1-bit integer)">;
-
def XLA_PredTensor : StaticShapeTensorOf<[XLA_Pred]>;
// Any integer or floating-point tensor types
@@ -64,13 +65,7 @@
// XLA nullary op definitions.
//===----------------------------------------------------------------------===//
-def XLA_ConstOp : XLA_Op<"constant", [NoSideEffect]> {
- let summary = "Constant operator";
-
- let description = [{
- Represents a constant value.
- }];
-
+def XLA_ConstOp : BXLA_ConstOp, XLA_Op<"constant", [NoSideEffect]> {
let arguments = (ins
ElementsAttr:$value
);
@@ -89,13 +84,7 @@
let hasCustomHLOConverter = 1;
}
-def XLA_IotaOp : XLA_Op<"iota", [NoSideEffect]> {
- let summary = "Iota operator";
-
- let description = [{
- Creates a rank 1 array of values starting at zero and incrementing by one.
- }];
-
+def XLA_IotaOp : BXLA_IotaOp, XLA_Op<"iota", [NoSideEffect]> {
let arguments = (ins I64Attr:$iota_dimension);
let results = (outs XLA_Tensor:$output);
@@ -117,28 +106,10 @@
let results = (outs XLA_Tensor);
}
-def XLA_AbsOp: XLA_UnaryElementwiseOp<"abs", [NoSideEffect, SameOperandsAndResultType]> {
- let summary = "Absolute value operator";
-
- let description = [{
- Returns `abs(operand)` element-wise.
-
- See
- https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
- }];
-}
+def XLA_AbsOp: XLA_UnaryElementwiseOp<"abs", [NoSideEffect, SameOperandsAndResultType]>, BXLA_AbsOp;
def XLA_ConvertOp : XLA_UnaryElementwiseOp<
- "convert", [NoSideEffect, SameOperandsAndResultShape]> {
- let summary = "Convert operator";
-
- let description = [{
- Performs element-wise conversion of values from one type to another, e.g.
- float to int.
-
- See https://www.tensorflow.org/xla/operation_semantics#convertelementtype.
- }];
-
+ "convert", [NoSideEffect, SameOperandsAndResultShape]>, BXLA_ConvertOp {
let hasFolder = 1;
// TODO(b/130357376) Convert has a special constructor. Use a custom
@@ -146,69 +117,19 @@
let hasCustomHLOConverter = 1;
}
-def XLA_ExpOp: XLA_UnaryElementwiseOp<"exp", [NoSideEffect, SameOperandsAndResultType]> {
- let summary = "Exponential operator";
+def XLA_ExpOp: XLA_UnaryElementwiseOp<"exp", [NoSideEffect, SameOperandsAndResultType]>, BXLA_ExpOp;
- let description = [{
- Returns `e^(operand)` element-wise.
+def XLA_NegOp: XLA_UnaryElementwiseOp<"neg", [NoSideEffect, SameOperandsAndResultType]>, BXLA_NegOp;
- See
- https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
- }];
-}
-
-def XLA_NegOp: XLA_UnaryElementwiseOp<"neg", [NoSideEffect, SameOperandsAndResultType]> {
- let summary = "Negation operator";
-
- let description = [{
- Returns `-operand` element-wise.
-
- See
- https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
- }];
-}
-
-def XLA_SignOp: XLA_UnaryElementwiseOp<"sign", [NoSideEffect, SameOperandsAndResultShape]> {
- let summary = "Sign operator";
-
- let description = [{
- Returns `sign(operand)` element-wise, where
-
- ```
- sign(x) = -1 : x < 0
- = -0 : x = -0
- = NaN : x = NaN
- = +0 : x = +0
- = 1 : x > 0
- ```
-
- See
- https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
- }];
-}
+def XLA_SignOp: XLA_UnaryElementwiseOp<"sign", [NoSideEffect, SameOperandsAndResultShape]>, BXLA_SignOp;
def XLA_TanhOp: XLA_UnaryElementwiseOp<"tanh",
- [ResultsAreFloatLike, NoSideEffect, SameOperandsAndResultType]> {
- let summary = "Tanh operator";
-
- let description = [{
- Returns `tanh(operand)` element-wise.
-
- See
- https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
- }];
-}
+ [ResultsAreFloatLike, NoSideEffect, SameOperandsAndResultType]>, BXLA_TanhOp;
//===----------------------------------------------------------------------===//
// XLA binary elementwise op definitions.
//===----------------------------------------------------------------------===//
-// The broadcasting dimensions correspond to a tuple that describes how a
-// smaller rank shape is broadcast into a larger rank shape. For example,
-// given a 2x3x4 cuboid and a 3x4 matrix, a broadcasting tuple (1,2) means
-// matching the matrix to dimensions 1 and 2 of the cuboid.
-def BroadcastDimAttr : OptionalAttr<ElementsAttr>;
-
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
class XLA_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits> :
XLA_Op<mnemonic, traits> {
@@ -223,86 +144,32 @@
}
def XLA_AddOp : XLA_BinaryElementwiseOp<"add",
- [Commutative, NoSideEffect, SameOperandsAndResultElementType]> {
- let summary = "Addition operator";
-
- let description = [{
- Returns `lhs + rhs` element-wise.
-
- See
- https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
- }];
-}
+ [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BXLA_AddOp;
def XLA_DivOp : XLA_BinaryElementwiseOp<"div",
- [NoSideEffect, SameOperandsAndResultElementType]> {
- let summary = "Division operator";
-
- let description = [{
- Returns `lhs / rhs` element-wise.
-
- See
- https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
- }];
-}
+ [NoSideEffect, SameOperandsAndResultElementType]>, BXLA_DivOp;
def XLA_MaxOp : XLA_BinaryElementwiseOp<"max",
- [Commutative, NoSideEffect, SameOperandsAndResultElementType]> {
- let summary = "Maximum operator";
-
- let description = [{
- Returns `max(lhs, rhs)` element-wise.
-
- See
- https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
- }];
-}
+ [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BXLA_MaxOp;
def XLA_MinOp : XLA_BinaryElementwiseOp<"min",
- [Commutative, NoSideEffect, SameOperandsAndResultElementType]> {
- let summary = "Minimum operator";
-
- let description = [{
- Returns `min(lhs, rhs)` element-wise.
-
- See
- https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
- }];
-}
+ [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BXLA_MinOp;
def XLA_MulOp : XLA_BinaryElementwiseOp<"mul",
- [Commutative, NoSideEffect, SameOperandsAndResultElementType]> {
- let summary = "Multiplication operator";
-
- let description = [{
- Returns `lhs * rhs` element-wise.
-
- See
- https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
- }];
-}
+ [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BXLA_MulOp;
def XLA_SubOp : XLA_BinaryElementwiseOp<"sub",
- [NoSideEffect, SameOperandsAndResultElementType]> {
- let summary = "Subtraction operator";
+ [NoSideEffect, SameOperandsAndResultElementType]>, BXLA_SubOp;
- let description = [{
- Returns `lhs - rhs` element-wise.
-
- See
- https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
- }];
-}
-
-def XLA_AndOp: XLA_BinaryElementwiseOp<"and", [Commutative, NoSideEffect]>;
+def XLA_AndOp: XLA_BinaryElementwiseOp<"and", [Commutative, NoSideEffect]>, BXLA_AndOp;
//===----------------------------------------------------------------------===//
// XLA control flow op definitions.
//===----------------------------------------------------------------------===//
def XLA_WhileOp: XLA_Op<"while", [NoSideEffect, SameOperandsAndResultType]> {
- let summary = "While operator";
+ string summary = "While operator";
- let description = [{
+ string description = [{
Returns the result of executing a body function until the cond body returns
true.
@@ -321,15 +188,7 @@
let hasCustomHLOConverter = 1;
}
-def XLA_ReduceOp: XLA_Op<"reduce", [NoSideEffect]> {
- let summary = "Reduce operator";
-
- let description = [{
- Returns the result of executing a reduction function on one or more arrays
- in parallel.
-
- See https://www.tensorflow.org/xla/operation_semantics#reduce.
- }];
+def XLA_ReduceOp: XLA_Op<"reduce", [NoSideEffect]>, BXLA_ReduceOp {
let arguments = (ins
Variadic<XLA_TensorOrTuple>:$operands_and_init,
@@ -346,15 +205,7 @@
//===----------------------------------------------------------------------===//
// XLA tuple op definitions.
//===----------------------------------------------------------------------===//
-def XLA_GetTupleElementOp: XLA_Op<"get_tuple_element", [NoSideEffect]> {
- let summary = "GetTupleElement operator";
-
- let description = [{
- Returns a member of a tuple specified by an index.
-
- See https://www.tensorflow.org/xla/operation_semantics#gettupleelement.
- }];
-
+def XLA_GetTupleElementOp: XLA_Op<"get_tuple_element", [NoSideEffect]>, BXLA_GetTupleElementOp {
let arguments = (ins
XLA_Tuple,
I32Attr:$index
@@ -366,15 +217,7 @@
let hasCustomHLOConverter = 1;
}
-def XLA_TupleOp : XLA_Op<"tuple", [NoSideEffect]> {
- let summary = "XLA's tuple op";
-
- let description = [{
- Groups a set of tensor inputs into a single tuple object.
-
- See https://www.tensorflow.org/xla/operation_semantics#tuple.
- }];
-
+def XLA_TupleOp : XLA_Op<"tuple", [NoSideEffect]>, BXLA_TupleOp {
let arguments = (ins Variadic<XLA_TensorOrTuple>:$val);
let results = (outs XLA_Tuple);
@@ -382,49 +225,8 @@
let hasCustomHLOConverter = 1;
}
-//===----------------------------------------------------------------------===//
-// Precision Config enum definitions.
-//===----------------------------------------------------------------------===//
-
-// These mirror the XLA PrecisionConfig proto enum.
-def XLA_PRECISION_DEFAULT : StrEnumAttrCase<"DEFAULT">;
-def XLA_PRECISION_HIGH : StrEnumAttrCase<"HIGH">;
-def XLA_PRECISION_HIGHEST : StrEnumAttrCase<"HIGHEST">;
-
-def XLA_PrecisionAttr : StrEnumAttr<"Precision",
- "XLA precision for an operand. Has backend specific meaning.",
- [XLA_PRECISION_DEFAULT, XLA_PRECISION_HIGH, XLA_PRECISION_HIGHEST]>;
-
-// TODO(b/129153247) See if it's possible to also validate the size.
-def XLA_PrecisionConfigAttr:
- OptionalAttr<
- TypedArrayAttrBase<XLA_PrecisionAttr, "Precision Config attribute">>;
-
-//===----------------------------------------------------------------------===//
-// Comparison op definitions.
-//===----------------------------------------------------------------------===//
-
-// These mirror the XLA ComparisonDirection enum.
-def XLA_COMPARISON_DIRECTION_EQ : StrEnumAttrCase<"EQ">;
-def XLA_COMPARISON_DIRECTION_NE : StrEnumAttrCase<"NE">;
-def XLA_COMPARISON_DIRECTION_GE : StrEnumAttrCase<"GE">;
-def XLA_COMPARISON_DIRECTION_GT : StrEnumAttrCase<"GT">;
-def XLA_COMPARISON_DIRECTION_LE : StrEnumAttrCase<"LE">;
-def XLA_COMPARISON_DIRECTION_LT : StrEnumAttrCase<"LT">;
-
-def XLA_ComparisonDirectionAttr : StrEnumAttr<"ComparisonDirection",
- "Which comparison operation to perform.",
- [
- XLA_COMPARISON_DIRECTION_EQ,
- XLA_COMPARISON_DIRECTION_NE,
- XLA_COMPARISON_DIRECTION_GE,
- XLA_COMPARISON_DIRECTION_GT,
- XLA_COMPARISON_DIRECTION_LE,
- XLA_COMPARISON_DIRECTION_LT
- ]>;
-
def XLA_CompareOp: XLA_Op<"compare",
- [NoSideEffect, SameOperandsAndResultShape]> {
+ [NoSideEffect, SameOperandsAndResultShape]>, BXLA_CompareOp {
let arguments = (ins
XLA_Tensor:$lhs,
XLA_Tensor:$rhs,
@@ -432,14 +234,6 @@
XLA_ComparisonDirectionAttr:$comparison_direction
);
let results = (outs XLA_PredTensor);
- let summary = "Comparison operator";
-
- let description = [{
- Compares `lhs` and `rhs` elementwise according to `comparison_direction`.
-
- See
- https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations.
- }];
}
//===----------------------------------------------------------------------===//
@@ -458,14 +252,6 @@
let results = (outs XLA_Tensor);
- let summary = "Slice operator";
-
- let description = [{
- Slices a portion of the `operand` into a new configuration.
-
- See https://www.tensorflow.org/xla/operation_semantics#slice.
- }];
-
// TODO(b/129422361) Two of the required arguments comes from the start and
// limit indices which aren't handled by the codegen.
let hasCustomHLOConverter = 1;
@@ -481,15 +267,6 @@
let results = (outs XLA_Tensor:$result);
- let summary = "Dynamic Update Slice operator";
-
- let description = [{
- DynamicUpdateSlice generates a result which is the value of the input array
- operand, with a slice update overwritten at start_indices.
-
- See https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice.
- }];
-
// TODO(b/129422361) Requires a custom constructor.
let hasCustomHLOConverter = 1;
}
@@ -499,14 +276,8 @@
// XLA Other op definitions.
//===----------------------------------------------------------------------===//
-def XLA_BatchNormInferenceOp : XLA_Op<"batch_norm_inference", [NoSideEffect]> {
- let summary = "Batch Normalization for Inference";
-
- let description = [{
- Normalizes an array across batch and spatial dimensions.
-
- See https://www.tensorflow.org/xla/operation_semantics#batchnorminference
- }];
+def XLA_BatchNormInferenceOp : XLA_Op<"batch_norm_inference", [NoSideEffect]>,
+ BXLA_BatchNormInferenceOp {
let arguments = (ins
XLA_Tensor:$operand,
@@ -522,21 +293,7 @@
}
def XLA_BroadcastOp : XLA_Op<"broadcast",
- [NoSideEffect, SameOperandsAndResultElementType]> {
- let summary = "Broadcast a tensor to a higher rank by prepending dimensions";
-
- let description = [{
- Broadcasts the operand tensor to a higher rank by prepending
- `broadcast_sizes` to the dimensions. The current values of the operand are
- copied into the other dimensions.
-
- This is a more limited form of broadcasting, that corresponds to the XLA
- client Broadcast method. For a more general form of broadcasting, see the
- BroadcastInDimOp.
-
- See https://www.tensorflow.org/xla/operation_semantics#broadcast.
- }];
-
+ [NoSideEffect, SameOperandsAndResultElementType]>, BXLA_BroadcastOp {
let arguments = (ins
XLA_Tensor:$operand,
ElementsAttr:$broadcast_sizes
@@ -594,26 +351,7 @@
}
def XLA_BroadcastInDimOp : XLA_Op<"broadcast_in_dim",
- [NoSideEffect, SameOperandsAndResultElementType]> {
- let summary = "Broadcast a tensor into the given shape by adding dimensions.";
-
- let description = [{
- Broadcasts the `operand` tensor to a higher rank. This is not the limited
- form of broadcasting exposed as the XLA client broadcast op, but rather the
- more powerful "InDim" broadcasting, which is closer to the HLO broadcast op
- and exposed in the XLA client BroadcastInDim method.
-
- `broadcast_dimensions` maps the operand dimension number to the target shape
- dimension number. It must have the same size as the rank of the operand. The
- mapped dimensions must either be the same size or the dimension being
- broadcast from must be size 1 (degenerate broadcasting).
-
- For a scalar (0D tensor) operand, `broadcast_dimensions` must be empty. The
- The scalar value will be broadcast to every element in the target shape.
-
- See https://www.tensorflow.org/xla/broadcasting.
- }];
-
+ [NoSideEffect, SameOperandsAndResultElementType]>, BXLA_BroadcastInDimOp {
let arguments = (ins
XLA_Tensor:$operand,
BroadcastDimAttr:$broadcast_dimensions
@@ -693,19 +431,7 @@
}
def XLA_ClampOp : XLA_Op<"clamp",
- [NoSideEffect, SameOperandsAndResultElementType]> {
- let summary = "Clamp operator";
-
- let description = [{
- Clamps an operand to within the range between a minimum and maximum value.
-
- Note: All three arrays must be the same shape. Alternatively, as a
- restricted form of broadcasting, min and/or max can be a scalar (0D
- tensor) of the element type of the tensor operand.
-
- See https://www.tensorflow.org/xla/operation_semantics#clamp.
- }];
-
+ [NoSideEffect, SameOperandsAndResultElementType]>, BXLA_ClampOp {
let arguments = (ins
XLA_Tensor:$min,
XLA_Tensor:$operand,
@@ -746,14 +472,7 @@
}
def XLA_ConcatenateOp : XLA_Op<"concatenate",
- [NoSideEffect, SameOperandsAndResultElementType]> {
- let summary = "XLA's concantenate op";
-
- let description = [{
- Concatenates a set of tensors along the specified dimension.
-
- See https://www.tensorflow.org/xla/operation_semantics#concatenate.
- }];
+ [NoSideEffect, SameOperandsAndResultElementType]>, BXLA_ConcatenateOp {
let arguments = (
ins Variadic<XLA_Tensor>:$val,
@@ -793,15 +512,7 @@
let hasCustomHLOConverter = 1;
}
-def XLA_ConvOp : XLA_Op<"conv", [NoSideEffect]> {
- let summary = "Convolution operator";
-
- let description = [{
- Computes a convolution of the kind used in neural networks.
-
- See https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
- }];
-
+def XLA_ConvOp : XLA_Op<"conv", [NoSideEffect]>, BXLA_ConvOp {
let arguments = (ins
XLA_Tensor:$lhs,
XLA_Tensor:$rhs
@@ -815,9 +526,9 @@
}
def XLA_CopyOp: XLA_Op<"copy", [NoSideEffect, SameOperandsAndResultType]> {
- let summary = "Copy operator";
+ string summary = "Copy operator";
- let description = [{
+ string description = [{
Returns a copy of `operand`.
}];
@@ -829,23 +540,16 @@
let hasCustomHLOConverter = 1;
}
-def XLA_DotOp: XLA_Op<"dot", [NoSideEffect]> {
+def XLA_DotOp: XLA_Op<"dot", [NoSideEffect]>, BXLA_DotOp {
let arguments = (
ins XLA_Tensor:$lhs,
XLA_Tensor:$rhs,
XLA_PrecisionConfigAttr:$precision_config
);
let results = (outs XLA_Tensor);
-
- let description = [{
- Performs dot products between vectors, vector/matrix and matrix/matrix
- multiplication.
-
- See https://www.tensorflow.org/xla/operation_semantics#dot.
- }];
}
-def XLA_GatherOp: XLA_Op<"gather", [NoSideEffect]> {
+def XLA_GatherOp: XLA_Op<"gather", [NoSideEffect]>, BXLA_GatherOp {
let arguments = (
ins XLA_Tensor:$operand,
XLA_IntTensor:$start_indices,
@@ -858,33 +562,16 @@
let results = (outs XLA_Tensor);
- let summary = "Gather operator";
-
- let description = [{
- Stitches together several slices of an input array.
-
- See https://www.tensorflow.org/xla/operation_semantics#gather.
- }];
-
// TODO(b/129422361) Attributes are not by the codegen. The optional argument
// (dimensions) needs to be added as an attribute.
let hasCustomHLOConverter = 1;
}
def XLA_ReshapeOp: XLA_Op<"reshape",
- [NoSideEffect, SameOperandsAndResultElementType]> {
+ [NoSideEffect, SameOperandsAndResultElementType]>, BXLA_ReshapeOp {
let arguments = (ins XLA_Tensor:$operand);
let results = (outs XLA_Tensor);
-
- let summary = "Reshape operator";
-
- let description = [{
- Reshapes the dimensions of `operand` into a new configuration.
-
- See https://www.tensorflow.org/xla/operation_semantics#reshape.
- }];
-
let hasFolder = 1;
// TODO(b/129422361) One of the required arguments comes from the new shape,
@@ -894,22 +581,7 @@
}
-def XLA_SelectOp: XLA_Op<"select", [NoSideEffect]> {
- let summary = "Select operator";
-
- let description = [{
- Constructs an output tensor from the elements of `on_true` and `on_false`
- based on the values of `pred`.
-
- `on_true` and `on_false` must be the same shape. For each element of `pred`,
- `res` has the corresponding element of `on_true` or `on_false` depending on
- the value in `pred`. `pred` must be the same shape as `on_true` and
- `on_false` or a scalar, in which case `res` is equal to either `on_true` or
- `on_false`.
-
- See https://www.tensorflow.org/xla/operation_semantics#select.
- }];
-
+def XLA_SelectOp: XLA_Op<"select", [NoSideEffect]>, BXLA_SelectOp {
let arguments = (ins
XLA_PredTensor:$pred,
XLA_Tensor:$on_true,
@@ -948,16 +620,7 @@
}
def XLA_ReverseOp: XLA_Op<"reverse",
- [NoSideEffect, SameOperandsAndResultElementType]> {
- let summary = "Reverse operator";
-
- let description = [{
- Reverses the specified dimensions of `operand` according to the given
- `dimensions`.
-
- See https://www.tensorflow.org/xla/operation_semantics#rev_reverse.
- }];
-
+ [NoSideEffect, SameOperandsAndResultElementType]>, BXLA_ReverseOp {
let arguments = (ins
XLA_Tensor:$operand,
ElementsAttr:$dimensions
@@ -970,16 +633,7 @@
}
def XLA_PadOp: XLA_Op<"pad",
- [NoSideEffect, SameOperandsAndResultElementType]> {
- let summary = "Pad operator";
-
- let description = [{
- Pads the edges of `operand` with the `padding_value` and according to
- the passed configuration.
-
- See https://www.tensorflow.org/xla/operation_semantics#pad.
- }];
-
+ [NoSideEffect, SameOperandsAndResultElementType]>, BXLA_PadOp {
let arguments = (ins
XLA_Tensor:$operand,
XLA_Tensor:$padding_value,
@@ -1044,17 +698,7 @@
}
def XLA_TransposeOp: XLA_Op<"transpose",
- [NoSideEffect, SameOperandsAndResultElementType]> {
- let summary = "Transpose operator";
-
- let description = [{
- Permutes the dimensions of `operand` according to the given `permutation`.
-
- `res_dimensions[i] = operand_dimensions[permutation[i]]`
-
- See https://www.tensorflow.org/xla/operation_semantics#transpose.
- }];
-
+ [NoSideEffect, SameOperandsAndResultElementType]>, BXLA_TransposeOp {
let arguments = (ins
XLA_Tensor:$operand,
ElementsAttr:$permutation
diff --git a/tensorflow/compiler/mlir/xla/ir/xla_ops_base.td b/tensorflow/compiler/mlir/xla/ir/xla_ops_base.td
new file mode 100644
index 0000000..a1aa8fc
--- /dev/null
+++ b/tensorflow/compiler/mlir/xla/ir/xla_ops_base.td
@@ -0,0 +1,495 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef XLA_OPS_BASE
+#else
+#define XLA_OPS_BASE
+
+#ifdef OP_BASE
+#else
+include "mlir/IR/OpBase.td"
+#endif // OP_BASE
+
+def XLA_Int : IntOfWidths<[8, 16, 32, 64]>;
+def XLA_Pred : TypeAlias<I1, "pred (AKA boolean or 1-bit integer)">;
+
+//===----------------------------------------------------------------------===//
+// XLA nullary op definitions.
+//===----------------------------------------------------------------------===//
+
+class BXLA_ConstOp {
+ string summary = "Constant operator";
+
+ string description = [{
+ Represents a constant value.
+ }];
+}
+
+class BXLA_IotaOp {
+ string summary = "Iota operator";
+
+ string description = [{
+ Creates a rank 1 array of values starting at zero and incrementing by one.
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// XLA unary elementwise op definitions.
+//===----------------------------------------------------------------------===//
+// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions
+
+class BXLA_AbsOp {
+ string summary = "Absolute value operator";
+
+ string description = [{
+ Returns `abs(operand)` element-wise.
+
+ See
+ https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
+ }];
+}
+
+class BXLA_ConvertOp {
+ string summary = "Convert operator";
+
+ string description = [{
+ Performs element-wise conversion of values from one type to another, e.g.
+ float to int.
+
+ See https://www.tensorflow.org/xla/operation_semantics#convertelementtype.
+ }];
+}
+
+class BXLA_ExpOp {
+ string summary = "Exponential operator";
+
+ string description = [{
+ Returns `e^(operand)` element-wise.
+
+ See
+ https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
+ }];
+}
+
+class BXLA_NegOp {
+ string summary = "Negation operator";
+
+ string description = [{
+ Returns `-operand` element-wise.
+
+ See
+ https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
+ }];
+}
+
+class BXLA_SignOp {
+ string summary = "Sign operator";
+
+ string description = [{
+ Returns `sign(operand)` element-wise, where
+
+ ```
+ sign(x) = -1 : x < 0
+ = -0 : x = -0
+ = NaN : x = NaN
+ = +0 : x = +0
+ = 1 : x > 0
+ ```
+
+ See
+ https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
+ }];
+}
+
+class BXLA_TanhOp {
+ string summary = "Tanh operator";
+
+ string description = [{
+ Returns `tanh(operand)` element-wise.
+
+ See
+ https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// XLA binary elementwise op definitions.
+//===----------------------------------------------------------------------===//
+
+// The broadcasting dimensions correspond to a tuple that describes how a
+// smaller rank shape is broadcast into a larger rank shape. For example,
+// given a 2x3x4 cuboid and a 3x4 matrix, a broadcasting tuple (1,2) means
+// matching the matrix to dimensions 1 and 2 of the cuboid.
+def BroadcastDimAttr : OptionalAttr<ElementsAttr>;
+
+class BXLA_AddOp {
+ string summary = "Addition operator";
+
+ string description = [{
+ Returns `lhs + rhs` element-wise.
+
+ See
+ https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
+ }];
+}
+
+class BXLA_DivOp {
+ string summary = "Division operator";
+
+ string description = [{
+ Returns `lhs / rhs` element-wise.
+
+ See
+ https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
+ }];
+}
+
+class BXLA_MaxOp {
+ string summary = "Maximum operator";
+
+ string description = [{
+ Returns `max(lhs, rhs)` element-wise.
+
+ See
+ https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
+ }];
+}
+
+class BXLA_MinOp {
+ string summary = "Minimum operator";
+
+ string description = [{
+ Returns `min(lhs, rhs)` element-wise.
+
+ See
+ https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
+ }];
+}
+
+class BXLA_MulOp {
+ string summary = "Multiplication operator";
+
+ string description = [{
+ Returns `lhs * rhs` element-wise.
+
+ See
+ https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
+ }];
+}
+
+class BXLA_SubOp {
+ string summary = "Subtraction operator";
+
+ string description = [{
+ Returns `lhs - rhs` element-wise.
+
+ See
+ https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
+ }];
+}
+
+class BXLA_AndOp {
+ string summary = "Logical and";
+
+ string description = [{
+ Returns `lhs /\ rhs` element-wise.
+
+ See
+ https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// XLA control flow op definitions.
+//===----------------------------------------------------------------------===//
+
+class BXLA_ReduceOp {
+ string summary = "Reduce operator";
+
+ string description = [{
+ Returns the result of executing a reduction function on one or more arrays
+ in parallel.
+
+ See https://www.tensorflow.org/xla/operation_semantics#reduce.
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// XLA tuple op definitions.
+//===----------------------------------------------------------------------===//
+class BXLA_GetTupleElementOp {
+ string summary = "GetTupleElement operator";
+
+ string description = [{
+ Returns a member of a tuple specified by an index.
+
+ See https://www.tensorflow.org/xla/operation_semantics#gettupleelement.
+ }];
+}
+
+class BXLA_TupleOp {
+ string summary = "XLA's tuple op";
+
+ string description = [{
+ Groups a set of tensor inputs into a single tuple object.
+
+ See https://www.tensorflow.org/xla/operation_semantics#tuple.
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// Precision Config enum definitions.
+//===----------------------------------------------------------------------===//
+
+// These mirror the XLA PrecisionConfig proto enum.
+def XLA_PRECISION_DEFAULT : StrEnumAttrCase<"DEFAULT">;
+def XLA_PRECISION_HIGH : StrEnumAttrCase<"HIGH">;
+def XLA_PRECISION_HIGHEST : StrEnumAttrCase<"HIGHEST">;
+
+def XLA_PrecisionAttr : StrEnumAttr<"Precision",
+ "XLA precision for an operand. Has backend specific meaning.",
+ [XLA_PRECISION_DEFAULT, XLA_PRECISION_HIGH, XLA_PRECISION_HIGHEST]>;
+
+// TODO(b/129153247) See if it's possible to also validate the size.
+def XLA_PrecisionConfigAttr:
+ OptionalAttr<
+ TypedArrayAttrBase<XLA_PrecisionAttr, "Precision Config attribute">>;
+
+//===----------------------------------------------------------------------===//
+// Comparison op definitions.
+//===----------------------------------------------------------------------===//
+
+// These mirror the XLA ComparisonDirection enum.
+def XLA_COMPARISON_DIRECTION_EQ : StrEnumAttrCase<"EQ">;
+def XLA_COMPARISON_DIRECTION_NE : StrEnumAttrCase<"NE">;
+def XLA_COMPARISON_DIRECTION_GE : StrEnumAttrCase<"GE">;
+def XLA_COMPARISON_DIRECTION_GT : StrEnumAttrCase<"GT">;
+def XLA_COMPARISON_DIRECTION_LE : StrEnumAttrCase<"LE">;
+def XLA_COMPARISON_DIRECTION_LT : StrEnumAttrCase<"LT">;
+
+def XLA_ComparisonDirectionAttr : StrEnumAttr<"ComparisonDirection",
+ "Which comparison operation to perform.",
+ [
+ XLA_COMPARISON_DIRECTION_EQ,
+ XLA_COMPARISON_DIRECTION_NE,
+ XLA_COMPARISON_DIRECTION_GE,
+ XLA_COMPARISON_DIRECTION_GT,
+ XLA_COMPARISON_DIRECTION_LE,
+ XLA_COMPARISON_DIRECTION_LT
+ ]>;
+
+class BXLA_CompareOp {
+ string summary = "Comparison operator";
+
+ string description = [{
+ Compares `lhs` and `rhs` elementwise according to `comparison_direction`.
+
+ See
+ https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations.
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// XLA Slice definitions.
+//===----------------------------------------------------------------------===//
+
+class BXLA_SliceOp {
+ string summary = "Slice operator";
+
+ string description = [{
+ Slices a portion of the `operand` into a new configuration.
+
+ See https://www.tensorflow.org/xla/operation_semantics#slice.
+ }];
+}
+
+class BXLA_DynamicUpdateSliceOp {
+ string summary = "Dynamic Update Slice operator";
+
+ string description = [{
+ DynamicUpdateSlice generates a result which is the value of the input array
+ operand, with a slice update overwritten at start_indices.
+
+ See https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice.
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// XLA Other op definitions.
+//===----------------------------------------------------------------------===//
+
+class BXLA_BatchNormInferenceOp {
+ string summary = "Batch Normalization for Inference";
+
+ string description = [{
+ Normalizes an array across batch and spatial dimensions.
+
+ See https://www.tensorflow.org/xla/operation_semantics#batchnorminference
+ }];
+}
+
+class BXLA_BroadcastOp {
+ string summary = "Broadcast a tensor to a higher rank by prepending dimensions";
+
+ string description = [{
+ Broadcasts the operand tensor to a higher rank by prepending
+ `broadcast_sizes` to the dimensions. The current values of the operand are
+ copied into the other dimensions.
+
+ This is a more limited form of broadcasting, that corresponds to the XLA
+ client Broadcast method. For a more general form of broadcasting, see the
+ BroadcastInDimOp.
+
+ See https://www.tensorflow.org/xla/operation_semantics#broadcast.
+ }];
+}
+
+class BXLA_BroadcastInDimOp {
+ string summary = "Broadcast a tensor into the given shape by adding dimensions.";
+
+ string description = [{
+ Broadcasts the `operand` tensor to a higher rank. This is not the limited
+ form of broadcasting exposed as the XLA client broadcast op, but rather the
+ more powerful "InDim" broadcasting, which is closer to the HLO broadcast op
+ and exposed in the XLA client BroadcastInDim method.
+
+ `broadcast_dimensions` maps the operand dimension number to the target shape
+ dimension number. It must have the same size as the rank of the operand. The
+ mapped dimensions must either be the same size or the dimension being
+ broadcast from must be size 1 (degenerate broadcasting).
+
+ For a scalar (0D tensor) operand, `broadcast_dimensions` must be empty. The
+ The scalar value will be broadcast to every element in the target shape.
+
+ See https://www.tensorflow.org/xla/broadcasting.
+ }];
+}
+
+class BXLA_ClampOp {
+ string summary = "Clamp operator";
+
+ string description = [{
+ Clamps an operand to within the range between a minimum and maximum value.
+
+ Note: All three arrays must be the same shape. Alternatively, as a
+ restricted form of broadcasting, min and/or max can be a scalar (0D
+ tensor) of the element type of the tensor operand.
+
+ See https://www.tensorflow.org/xla/operation_semantics#clamp.
+ }];
+}
+
+class BXLA_ConcatenateOp {
+ string summary = "XLA's concantenate op";
+
+ string description = [{
+ Concatenates a set of tensors along the specified dimension.
+
+ See https://www.tensorflow.org/xla/operation_semantics#concatenate.
+ }];
+}
+
+class BXLA_ConvOp {
+ string summary = "Convolution operator";
+
+ string description = [{
+ Computes a convolution of the kind used in neural networks.
+
+ See https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
+ }];
+}
+
+class BXLA_DotOp {
+ string summary = "Dot operator";
+ string description = [{
+ Performs dot products between vectors, vector/matrix and matrix/matrix
+ multiplication.
+
+ See https://www.tensorflow.org/xla/operation_semantics#dot.
+ }];
+}
+
+class BXLA_GatherOp{
+ string summary = "Gather operator";
+
+ string description = [{
+ Stitches together several slices of an input array.
+
+ See https://www.tensorflow.org/xla/operation_semantics#gather.
+ }];
+}
+
+class BXLA_ReshapeOp {
+ string summary = "Reshape operator";
+
+ string description = [{
+ Reshapes the dimensions of `operand` into a new configuration.
+
+ See https://www.tensorflow.org/xla/operation_semantics#reshape.
+ }];
+}
+
+class BXLA_SelectOp {
+ string summary = "Select operator";
+
+ string description = [{
+ Constructs an output tensor from the elements of `on_true` and `on_false`
+ based on the values of `pred`.
+
+ `on_true` and `on_false` must be the same shape. For each element of `pred`,
+ `res` has the corresponding element of `on_true` or `on_false` depending on
+ the value in `pred`. `pred` must be the same shape as `on_true` and
+ `on_false` or a scalar, in which case `res` is equal to either `on_true` or
+ `on_false`.
+
+ See https://www.tensorflow.org/xla/operation_semantics#select.
+ }];
+}
+
+class BXLA_ReverseOp {
+ string summary = "Reverse operator";
+
+ string description = [{
+ Reverses the specified dimensions of `operand` according to the given
+ `dimensions`.
+
+ See https://www.tensorflow.org/xla/operation_semantics#rev_reverse.
+ }];
+}
+
+class BXLA_PadOp {
+ string summary = "Pad operator";
+
+ string description = [{
+ Pads the edges of `operand` with the `padding_value` and according to
+ the passed configuration.
+
+ See https://www.tensorflow.org/xla/operation_semantics#pad.
+ }];
+}
+
+class BXLA_TransposeOp {
+ string summary = "Transpose operator";
+
+ string description = [{
+ Permutes the dimensions of `operand` according to the given `permutation`.
+
+ `res_dimensions[i] = operand_dimensions[permutation[i]]`
+
+ See https://www.tensorflow.org/xla/operation_semantics#transpose.
+ }];
+}
+
+#endif // XLA_OPS_BASE
diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir
new file mode 100644
index 0000000..94b1ffe
--- /dev/null
+++ b/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir
@@ -0,0 +1,124 @@
+// RUN: tf-opt %s -verify-diagnostics -split-input-file | FileCheck %s
+
+// -----
+
+func @enforce_static_shapes(%arg0: memref<?xf32>, %arg1: memref<?xf32>) -> () {
+ // expected-error@+1{{op operand #0 must be statically shaped memref of floating-point or integer values}}
+ "lxla_hlo.tanh"(%arg0, %arg1) : (memref<?xf32>, memref<?xf32>) -> ()
+}
+
+// -----
+
+func @enforce_same_shape(%arg0: memref<1xf32>, %arg1: memref<2xf32>) -> () {
+ // expected-error@+1{{'lxla_hlo.tanh' op requires all operands to have the same type}}
+ "lxla_hlo.tanh"(%arg0, %arg1) : (memref<1xf32>, memref<2xf32>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: func @add_memrefs
+func @add_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
+ "xla_lhlo.add"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: func @abs_memref
+func @abs_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
+ "xla_lhlo.abs"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: func @convert_memref
+func @convert_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
+ "xla_lhlo.convert"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: func @exp_memref
+func @exp_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
+ "xla_lhlo.exp"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: func @neg_memref
+func @neg_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
+ "xla_lhlo.neg"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: func @sign_memref
+func @sign_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
+ "xla_lhlo.sign"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: func @tanh_memref
+func @tanh_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
+ "xla_lhlo.tanh"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: func @add_memref
+func @add_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
+ "xla_lhlo.add"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: func @div_memref
+func @div_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
+ "xla_lhlo.div"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: func @max_memref
+func @max_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
+ "xla_lhlo.max"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: func @min_memref
+func @min_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
+ "xla_lhlo.min"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: func @mul_memref
+func @mul_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
+ "xla_lhlo.mul"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: func @sub_memref
+func @sub_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
+ "xla_lhlo.sub"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: func @and_memref
+func @and_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
+ "xla_lhlo.and"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
+}
+
+// -----
+
+func @reduce_computation(%sum: memref<1xf32>, %element: memref<1xf32>) -> () {
+ "xla_lhlo.add"(%element, %sum, %sum) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
+}
+
+// CHECK-LABEL: func @reduce_memref
+func @reduce_memref(%input: memref<10xf32>, %out: memref<1xf32>) -> () {
+ "xla_lhlo.reduce"(%input, %out) {computation = @reduce_computation} : (memref<10xf32>, memref<1xf32>) -> ()
+}