Refine constraints for TF dialect ops
- added some unmodelled ops to dialect
- relaxed over-restrictive side effects
- moved ops that can now be auto-generated to tf_generated_ops.td
- updated extra op traits
PiperOrigin-RevId: 368536646
Change-Id: I28cccd7a1bd1709af393f3d131534a6c9fa8ee81
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index 51a7030..96b6326 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -851,77 +851,6 @@
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
}
-def TF_BatchFunctionOp : TF_Op<"BatchFunction", [AttrSizedOperandSegments, NoSideEffect]> {
- let summary = [{
-Batches all the inputs tensors to the computation done by the function.
- }];
-
- let description = [{
-So, for example, in the following code
-
- ```python
-
- # This input will be captured.
- y = tf.placeholder_with_default(1.0, shape=[])
-
- @tf.Defun(tf.float32)
- def computation(a):
- return tf.matmul(a, a) + y
-
- b = gen_batch_ops.batch_function(
- f=computation
- in_tensors=[a],
- captured_tensors=computation.captured_inputs,
- Tout=[o.type for o in computation.definition.signature.output_arg],
- num_batch_threads=1,
- max_batch_size=10,
- batch_timeout_micros=100000, # 100ms
- allowed_batch_sizes=[3, 10],
- batching_queue="")
- ```
-
-If more than one session.run call is simultaneously trying to compute `b`
-the values of `a` will be gathered, non-deterministically concatenated
-along the first axis, and only one thread will run the computation.
-
-Assumes that all arguments of the function are Tensors which will be batched
-along their first dimension.
-
-Arguments that are captured, are not batched. The session.run call which does
-the concatenation, will use the values of the captured tensors available to it.
-Therefore, typical uses of captured tensors should involve values which remain
-unchanged across session.run calls. Inference is a good example of this.
-
-SparseTensor is not supported. The return value of the decorated function
-must be a Tensor or a list/tuple of Tensors.
- }];
-
- let arguments = (ins
- Arg<Variadic<TF_Tensor>, [{The tensors to be batched.}]>:$in_tensors,
- Arg<Variadic<TF_Tensor>, [{The tensors which are captured in the function, and don't need
-to be batched.}]>:$captured_tensors,
-
- SymbolRefAttr:$f,
- I64Attr:$num_batch_threads,
- I64Attr:$max_batch_size,
- I64Attr:$batch_timeout_micros,
- DefaultValuedAttr<I64Attr, "10">:$max_enqueued_batches,
- DefaultValuedAttr<I64ArrayAttr, "{}">:$allowed_batch_sizes,
- StrAttr:$container,
- StrAttr:$shared_name,
- StrAttr:$batching_queue,
- DefaultValuedAttr<BoolAttr, "false">:$enable_large_batch_splitting
- );
-
- let results = (outs
- Res<Variadic<TF_Tensor>, [{The output tensors.}]>:$out_tensors
- );
-
- TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<0>;
- TF_DerivedOperandTypeListAttr Tcaptured = TF_DerivedOperandTypeListAttr<1>;
- TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
-}
-
def TF_BatchMatMulOp : TF_Op<"BatchMatMul", [NoSideEffect, TF_SameOperandsAndResultElementTypeResolveRef]> {
let summary = "Multiplies slices of two tensors in batches.";
@@ -5105,32 +5034,6 @@
}];
}
-def TF_GeneratorDatasetOp : TF_Op<"GeneratorDataset", [AttrSizedOperandSegments, NoSideEffect, TF_NoConstantFold]> {
- let summary = [{
-Creates a dataset that invokes a function to generate elements.
- }];
-
- let arguments = (ins
- Variadic<TF_Tensor>:$init_func_other_args,
- Variadic<TF_Tensor>:$next_func_other_args,
- Variadic<TF_Tensor>:$finalize_func_other_args,
-
- SymbolRefAttr:$init_func,
- SymbolRefAttr:$next_func,
- SymbolRefAttr:$finalize_func,
- Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
- Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes
- );
-
- let results = (outs
- TF_VariantTensor:$handle
- );
-
- TF_DerivedOperandTypeListAttr Tnext_func_args = TF_DerivedOperandTypeListAttr<1>;
- TF_DerivedOperandTypeListAttr Tfinalize_func_args = TF_DerivedOperandTypeListAttr<2>;
- TF_DerivedOperandTypeListAttr Tinit_func_args = TF_DerivedOperandTypeListAttr<0>;
-}
-
def TF_GreaterOp : TF_Op<"Greater", [NoSideEffect, ResultsBroadcastableShape]>,
WithBroadcastableCmpOpBuilder {
let summary = "Returns the truth value of (x > y) element-wise.";
@@ -9365,20 +9268,6 @@
);
}
-def TF_OptionalFromValueOp : TF_Op<"OptionalFromValue", [NoSideEffect]> {
- let summary = "Constructs an Optional variant from a tuple of tensors.";
-
- let arguments = (ins
- Variadic<TF_Tensor>:$components
- );
-
- let results = (outs
- TF_VariantTensor:$optional
- );
-
- TF_DerivedOperandTypeListAttr Toutput_types = TF_DerivedOperandTypeListAttr<0>;
-}
-
def TF_OptionalGetValueOp : TF_Op<"OptionalGetValue", [NoSideEffect]> {
let summary = [{
Returns the value stored in an Optional variant or raises an error if none exists.
@@ -9410,16 +9299,6 @@
);
}
-def TF_OptionalNoneOp : TF_Op<"OptionalNone", [NoSideEffect]> {
- let summary = "Creates an Optional variant with no value.";
-
- let arguments = (ins);
-
- let results = (outs
- TF_VariantTensor:$optional
- );
-}
-
def TF_OutfeedEnqueueTupleOp : TF_Op<"OutfeedEnqueueTuple", []> {
let summary = "Enqueue multiple Tensor values on the computation outfeed.";
@@ -9769,28 +9648,6 @@
let hasFolder = 1;
}
-def TF_PrefetchDatasetOp : TF_Op<"PrefetchDataset", [NoSideEffect]> {
- let summary = [{
-Creates a dataset that asynchronously prefetches elements from `input_dataset`.
- }];
-
- let arguments = (ins
- TF_VariantTensor:$input_dataset,
- Arg<TF_Int64Tensor, [{The maximum number of elements to buffer in an iterator over
-this dataset.}]>:$buffer_size,
-
- Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
- Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes,
- DefaultValuedAttr<I64Attr, "0">:$slack_period,
- DefaultValuedAttr<BoolAttr, "true">:$legacy_autotune,
- DefaultValuedAttr<I64Attr, "0">:$buffer_size_min
- );
-
- let results = (outs
- TF_VariantTensor:$handle
- );
-}
-
def TF_PreventGradientOp : TF_Op<"PreventGradient", [NoSideEffect, TF_SameOperandsAndResultTypeResolveRef]> {
let summary = [{
An identity op that triggers an error if a gradient is requested.
@@ -10577,7 +10434,7 @@
];
}
-def TF_RangeDatasetOp : TF_Op<"RangeDataset", [NoSideEffect, TF_NoConstantFold]> {
+def TF_RangeDatasetOp : TF_Op<"RangeDataset", []> {
let summary = [{
Creates a dataset with a range of values. Corresponds to python's xrange.
}];
@@ -13593,32 +13450,6 @@
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
-def TF_SendTPUEmbeddingGradientsOp : TF_Op<"SendTPUEmbeddingGradients", [AttrSizedOperandSegments, TF_TPUEmbeddingSideEffect]> {
- let summary = "Performs gradient updates of embedding tables.";
-
- let arguments = (ins
- Arg<Variadic<TF_Float32Tensor>, [{A TensorList of gradients with which to update embedding tables.
-This argument has the same length and shapes as the return value of
-RecvTPUEmbeddingActivations, but contains gradients of the model's loss
-with respect to the embedding activations. The embedding tables are updated
-from these gradients via the optimizer specified in the TPU embedding
-configuration given to tpu.initialize_system.}]>:$inputs,
- Arg<Variadic<TF_Float32Tensor>, [{A TensorList of float32 scalars, one for each dynamic learning
-rate tag: see the comments in
-//third_party/tensorflow/core/protobuf/tpu/optimization_parameters.proto.
-Multiple tables can share the same dynamic learning rate tag as specified
-in the configuration. If the learning rates for all tables are constant,
-this list should be empty.}]>:$learning_rates,
-
- StrAttr:$config
- );
-
- let results = (outs);
-
- TF_DerivedOperandSizeAttr NN = TF_DerivedOperandSizeAttr<1>;
- TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
-}
-
def TF_SerializeIteratorOp : TF_Op<"SerializeIterator", []> {
let summary = [{
Converts the given `resource_handle` representing an iterator to a variant tensor.
@@ -17027,24 +16858,6 @@
];
}
-def TF_TensorSliceDatasetOp : TF_Op<"TensorSliceDataset", [NoSideEffect, TF_NoConstantFold]> {
- let summary = [{
-Creates a dataset that emits each dim-0 slice of `components` once.
- }];
-
- let arguments = (ins
- Variadic<TF_Tensor>:$components,
-
- Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes
- );
-
- let results = (outs
- TF_VariantTensor:$handle
- );
-
- TF_DerivedOperandTypeListAttr Toutput_types = TF_DerivedOperandTypeListAttr<0>;
-}
-
def TF_TensorStridedSliceUpdateOp : TF_Op<"TensorStridedSliceUpdate", [NoSideEffect]> {
let summary = "Assign `value` to the sliced l-value reference of `input`.";
@@ -18819,39 +18632,6 @@
TF_DerivedResultSizeAttr num_tables = TF_DerivedResultSizeAttr<0>;
}
-def TF__SendTPUEmbeddingGradientsOp : TF_Op<"_SendTPUEmbeddingGradients", [AttrSizedOperandSegments, TF_TPUEmbeddingSideEffect]> {
- let summary = "An op that performs gradient updates of embedding tables.";
-
- let description = [{
-The gradients argument is a TensorList having the same length and shapes as the
-return value of _RecvTPUEmbeddingActivations, but contains gradients of the
-model's loss with respect to the embedding activations. The embedding tables are
-updated from these gradients via the optimizer specified in the
-TPUEmbeddingConfiguration proto given to tpu.initialize_system.
- }];
-
- let arguments = (ins
- Arg<Variadic<TF_Float32Tensor>, [{A TensorList of gradients with which to update embedding tables.}]>:$gradients,
- Arg<Variadic<TF_Float32Tensor>, [{A TensorList of learning rates used for updating the embedding
-tables via the optimizer. The length of the TensorList must be equal to the
-number of dynamic learning rate tags specified in the
-TPUEmbeddingConfiguration proto.}]>:$learning_rates,
- Arg<TF_VariantTensor, [{A Tensor with type=DT_VARIANT containing the deduplication
-data. The tensor is an XLA nested tuple containing N elements (where N is
-the ratio of the number of embedding to tensor cores per TPU chip). Each
-element of the nested tuple is a tuple of rank 1 tensors. Each tensor either
-contains indices (DT_UINT32) for embedding lookup on the TensorCore or
-weights (DT_FLOAT) to apply to the output of the embedding lookup operation.}]>:$deduplication_data,
-
- StrAttr:$config
- );
-
- let results = (outs);
-
- TF_DerivedOperandSizeAttr NumTables = TF_DerivedOperandSizeAttr<0>;
- TF_DerivedOperandSizeAttr NumLearningRateTags = TF_DerivedOperandSizeAttr<1>;
-}
-
def TF__TPUCompileMlirOp : TF_Op<"_TPUCompileMlir", []> {
let summary = [{
Compiles a computations for execution on one or more TPU devices.
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
index 97d1b95..ac1d80e 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
@@ -915,6 +915,73 @@
TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
}
+// Multiple variadic operands with different sizes are not supported by the
+// dialect generator, so we manually added the op.
+def TF_SendTPUEmbeddingGradientsOp : TF_Op<"SendTPUEmbeddingGradients", [AttrSizedOperandSegments, TF_TPUEmbeddingSideEffect]> {
+ let summary = "Performs gradient updates of embedding tables.";
+
+ let description = [{
+inputs: A TensorList of gradients with which to update embedding tables.
+ This argument has the same length and shapes as the return value of
+ RecvTPUEmbeddingActivations, but contains gradients of the model's loss
+ with respect to the embedding activations. The embedding tables are updated
+ from these gradients via the optimizer specified in the TPU embedding
+ configuration given to tpu.initialize_system.
+learning_rates: A TensorList of float32 scalars, one for each dynamic learning
+ rate tag: see the comments in
+ //third_party/tensorflow/core/protobuf/tpu/optimization_parameters.proto.
+ Multiple tables can share the same dynamic learning rate tag as specified
+ in the configuration. If the learning rates for all tables are constant,
+ this list should be empty.
+config: Serialized TPUEmbeddingConfiguration proto.
+ }];
+
+ let arguments = (ins
+ Variadic<TF_Tensor>:$inputs,
+ Variadic<TF_Tensor>:$learning_rates,
+ StrAttr:$config
+ );
+
+ TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
+ TF_DerivedOperandSizeAttr NN = TF_DerivedOperandSizeAttr<1>;
+}
+
+// Multiple variadic operands with different sizes are not supported by the
+// dialect generator, so we manually added the op.
+def TF__SendTPUEmbeddingGradientsOp : TF_Op<"_SendTPUEmbeddingGradients", [AttrSizedOperandSegments, TF_TPUEmbeddingSideEffect]> {
+ let summary = "Performs gradient updates of embedding tables.";
+
+ let description = [{
+The gradients argument is a TensorList having the same length and shapes as the
+return value of _RecvTPUEmbeddingActivations, but contains gradients of the
+model's loss with respect to the embedding activations. The embedding tables are
+updated from these gradients via the optimizer specified in the
+TPUEmbeddingConfiguration proto given to tpu.initialize_system.
+
+gradients: A TensorList of gradients with which to update embedding tables.
+learning_rates: A TensorList of learning rates used for updating the embedding
+ tables via the optimizer. The length of the TensorList must be equal to the
+ number of dynamic learning rate tags specified in the
+ TPUEmbeddingConfiguration proto.
+deduplication_data: A Tensor with type=DT_VARIANT containing the deduplication
+ data. The tensor is an XLA nested tuple containing N elements. Each
+ element of the nested tuple is a tuple of rank 1 tensors. Each tensor either
+ contains indices (DT_INT32) for embedding lookup or weights (DT_FLOAT) to
+ apply to the output of the embedding lookup operation.
+config: Serialized TPUEmbeddingConfiguration proto.
+ }];
+
+ let arguments = (ins
+ Variadic<TF_Tensor>:$gradients,
+ Variadic<TF_Tensor>:$learning_rates,
+ TF_VariantTensor:$deduplication_data,
+ StrAttr:$config
+ );
+
+ TF_DerivedOperandSizeAttr NumTables = TF_DerivedOperandSizeAttr<0>;
+ TF_DerivedOperandSizeAttr NumLearningRateTags = TF_DerivedOperandSizeAttr<1>;
+}
+
// Updated the op description text from the auto-generated op definition.
def TF__RecvTPUEmbeddingDeduplicationDataOp : TF_Op<"_RecvTPUEmbeddingDeduplicationData", []> {
let summary = [{
@@ -1121,6 +1188,23 @@
TF_DerivedOperandTypeListAttr Targuments = TF_DerivedOperandTypeListAttr<1>;
}
+def TF_TensorSliceDatasetOp : TF_Op<"TensorSliceDataset", []> {
+ let summary = [{
+ Creates a dataset that emits each dim-0 slice of `components` once.
+ }];
+
+ let arguments = (ins
+ Variadic<TF_Tensor>:$components,
+ Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes
+ );
+
+ let results = (outs
+ TF_VariantTensor:$handle
+ );
+
+ TF_DerivedOperandTypeListAttr Toutput_types = TF_DerivedOperandTypeListAttr<0>;
+}
+
def TF_ReduceDatasetOp : TF_Op<"ReduceDataset", [SameVariadicOperandSize]> {
let summary = [{
Reduces the input dataset to a singleton using a reduce function.
@@ -1258,6 +1342,77 @@
let verifier = [{ return VerifyPartitionedCall(*this); }];
}
+def TF_BatchFunctionOp : TF_Op<"BatchFunction", [AttrSizedOperandSegments]> {
+ let summary = [{
+Batches all the inputs tensors to the computation done by the function.
+ }];
+
+ let description = [{
+So, for example, in the following code
+
+ ```python
+
+ # This input will be captured.
+ y = tf.placeholder_with_default(1.0, shape=[])
+
+ @tf.Defun(tf.float32)
+ def computation(a):
+ return tf.matmul(a, a) + y
+
+ b = gen_batch_ops.batch_function(
+ f=computation
+ in_tensors=[a],
+ captured_tensors=computation.captured_inputs,
+ Tout=[o.type for o in computation.definition.signature.output_arg],
+ num_batch_threads=1,
+ max_batch_size=10,
+ batch_timeout_micros=100000, # 100ms
+ allowed_batch_sizes=[3, 10],
+ batching_queue="")
+ ```
+
+If more than one session.run call is simultaneously trying to compute `b`
+the values of `a` will be gathered, non-deterministically concatenated
+along the first axis, and only one thread will run the computation.
+
+Assumes that all arguments of the function are Tensors which will be batched
+along their first dimension.
+
+Arguments that are captured, are not batched. The session.run call which does
+the concatenation, will use the values of the captured tensors available to it.
+Therefore, typical uses of captured tensors should involve values which remain
+unchanged across session.run calls. Inference is a good example of this.
+
+SparseTensor is not supported. The return value of the decorated function
+must be a Tensor or a list/tuple of Tensors.
+ }];
+
+ let arguments = (ins
+ Variadic<TF_Tensor>:$in_tensors,
+ Variadic<TF_Tensor>:$captured_tensors,
+
+ SymbolRefAttr:$f,
+ I64Attr:$num_batch_threads,
+ I64Attr:$max_batch_size,
+ I64Attr:$batch_timeout_micros,
+ DefaultValuedAttr<I64Attr, "10">:$max_enqueued_batches,
+ DefaultValuedAttr<I64ArrayAttr, "{}">:$allowed_batch_sizes,
+ StrAttr:$container,
+ StrAttr:$shared_name,
+ StrAttr:$batching_queue,
+ DefaultValuedAttr<BoolAttr, "false">:$enable_large_batch_splitting,
+ I32ElementsAttr:$operand_segment_sizes
+ );
+
+ let results = (outs
+ Variadic<TF_Tensor>:$out_tensors
+ );
+
+ TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<0>;
+ TF_DerivedOperandTypeListAttr Tcaptured = TF_DerivedOperandTypeListAttr<1>;
+ TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
+}
+
def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary, TF_LayoutAgnostic, TF_SameOperandsAndResultElementTypeResolveRef]>,
WithBroadcastableBinOpBuilder {
let summary = "Returns x + y element-wise.";