Lower tf.RandomUniform to HLO.
PiperOrigin-RevId: 273866236
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
index 1f66203..d39d7c8 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
@@ -650,4 +650,22 @@
// TODO(hinsu): Implement custom printer and parser.
}
+//===----------------------------------------------------------------------===//
+// XLA RngUniform Operator.
+//===----------------------------------------------------------------------===//
+def HLO_RngUniformOp : HLO_Op<"rng_uniform", []>, BASE_HLO_RngUniformOp {
+ let arguments = (ins
+ HLO_Tensor:$a,
+ HLO_Tensor:$b,
+ I64Tensor:$shape
+ );
+
+ let results = (outs HLO_Tensor);
+
+ // TODO(bgogul): Disable conversion operator for `RngUniform` as
+ // the default constructor for `xla::RngUniform` takes `Shape` as
+ // the last argument and the default converter does not deal with it.
+ let hasCustomHLOConverter = 1;
+}
+
#endif // HLO_OPS
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td
index 30b7489..bc3c733 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td
@@ -586,4 +586,15 @@
}];
}
+class BASE_HLO_RngUniformOp {
+ string summary = "RNG with uniform distribution.";
+
+ string description = [{
+ Constructs an output of a given shape with random numbers generated following
+ the uniform distribution over the interval `[a,b)`.
+
+ See https://www.tensorflow.org/xla/operation_semantics#rnguniform.
+ }];
+}
+
#endif // HLO_OPS_BASE
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index 51619f5..8534389 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -1195,3 +1195,15 @@
%0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x?xi32>, tensor<i32>) -> tensor<3xi32>
return %0 : tensor<3xi32>
}
+
+// CHECK-LABEL: func @rng_uniform
+func @rng_uniform(%arg0: tensor<3xi32>) -> tensor<12x12x64xf32> {
+ // CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
+ // CHECK: %[[ONE:.*]] = xla_hlo.constant dense<1.000000e+00> : tensor<f32>
+ // CHECK: %[[CONV:.*]] = "xla_hlo.convert"(%arg0) : (tensor<3xi32>) -> tensor<3xi64>
+ // CHECK: %[[F32:.*]] = "xla_hlo.rng_uniform"(%[[ZERO]], %[[ONE]], %[[CONV]]) {{.*}} -> tensor<12x12x64xf32>
+ %0 = "tf.RandomUniform"(%arg0) {T = "tfdtype$DT_INT32", dtype = "tfdtype$DT_FLOAT", seed = 0 : i64, seed2 = 0 : i64} : (tensor<3xi32>) -> tensor<12x12x64xf32>
+ // CHECK: return %[[F32]] : tensor<12x12x64xf32>
+ return %0 : tensor<12x12x64xf32>
+}
+
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index 794b5ac..204d9ea 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -94,6 +94,17 @@
return axis < 0 ? axis + rank : axis;
}
+/// Returns a `ConvertOp` that casts the elements to a i64 type while retaining
+/// the shape of the input value.
+static xla_hlo::ConvertOp CastElementsToI64(Location loc, Value *value,
+ PatternRewriter *rewriter) {
+ auto type = value->getType().cast<RankedTensorType>();
+ assert(type && "CastElementsToI64 requires a shaped tensor as input.");
+ ArrayRef<int64_t> shape = type.getShape();
+ auto i64_type = rewriter->getTensorType(shape, rewriter->getIntegerType(64));
+ return rewriter->create<xla_hlo::ConvertOp>(loc, i64_type, value);
+}
+
// Returns minimum value for the given int or float element type.
static xla_hlo::ConstOp GetMinValueForType(Type ty, Location loc,
PatternRewriter *rewriter) {
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
index d0f3c917..683fd2f 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
@@ -138,6 +138,9 @@
def HasRankedFirstOperand
: Constraint<CPred<"(*$0.begin())->getType().isa<RankedTensorType>()">>;
+def IsShapedTensor
+ : Constraint<CPred<"$0->getType().isa<RankedTensorType>()">>;
+
// This pattern converts TensorFlow axis format to HLO axis format which
// doesn't wrap around like TensorFlow and is always positive. For this
// conversion, use the first input to get inputs rank. Other inputs need not be
@@ -236,3 +239,19 @@
def : Pat<(TfOp:$res AnyStaticShapeTensor:$arg, $ignored),
(HLO_ReshapeOp $arg), [(AnyStaticShapeTensor $res)]>;
}
+
+//===----------------------------------------------------------------------===//
+// RngUniform.
+//===----------------------------------------------------------------------===//
+def CastElementsToI64: NativeCodeCall<
+ "CastElementsToI64($0->getLoc(), $1, &$_builder)">;
+
+// TODO(misard,phawkins): handle random number generator seeds/states correctly.
+def : Pat<(TF_RandomUniformOp:$old $shape, $seed, $seed2),
+ (HLO_RngUniformOp
+ (HLO_ConstOp
+ (NativeCodeCall<"$_builder.getFloatAttr(old.dtype(), 0.0)">)),
+ (HLO_ConstOp
+ (NativeCodeCall<"$_builder.getFloatAttr(old.dtype(), 1.0)">)),
+ (CastElementsToI64 $old, $shape)),
+ [(IsShapedTensor $shape)]>;