Lower tf.RandomUniform to HLO.

PiperOrigin-RevId: 273866236
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
index 1f66203..d39d7c8 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
@@ -650,4 +650,22 @@
   // TODO(hinsu): Implement custom printer and parser.
 }
 
+//===----------------------------------------------------------------------===//
+// XLA RngUniform Operator.
+//===----------------------------------------------------------------------===//
+def HLO_RngUniformOp : HLO_Op<"rng_uniform", []>, BASE_HLO_RngUniformOp {
+  let arguments = (ins
+    HLO_Tensor:$a,
+    HLO_Tensor:$b,
+    I64Tensor:$shape
+  );
+
+  let results = (outs HLO_Tensor);
+
+  // TODO(bgogul): `RngUniform` needs a custom HLO exporter: the builder
+  // function `xla::RngUniform` takes a `Shape` as its last argument, and
+  // the default converter does not know how to supply it.
+  let hasCustomHLOConverter = 1;
+}
+
 #endif // HLO_OPS
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td
index 30b7489..bc3c733 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td
@@ -586,4 +586,15 @@
   }];
 }
 
+class BASE_HLO_RngUniformOp {
+  string summary = "RNG with uniform distribution.";
+
+  string description = [{
+    Constructs an output of a given shape, with random numbers drawn from the
+    uniform distribution over the interval `[a,b)`.
+
+    See https://www.tensorflow.org/xla/operation_semantics#rnguniform.
+  }];
+}
+
 #endif // HLO_OPS_BASE
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index 51619f5..8534389 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -1195,3 +1195,15 @@
   %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x?xi32>, tensor<i32>) -> tensor<3xi32>
   return %0 : tensor<3xi32>
 }
+
+// CHECK-LABEL: func @rng_uniform
+func @rng_uniform(%arg0: tensor<3xi32>) -> tensor<12x12x64xf32> {
+  // CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
+  // CHECK: %[[ONE:.*]] = xla_hlo.constant dense<1.000000e+00> : tensor<f32>
+  // CHECK: %[[CONV:.*]] = "xla_hlo.convert"(%arg0) : (tensor<3xi32>) -> tensor<3xi64>
+  // CHECK: %[[F32:.*]] = "xla_hlo.rng_uniform"(%[[ZERO]], %[[ONE]], %[[CONV]]) {{.*}} -> tensor<12x12x64xf32>
+  %0 = "tf.RandomUniform"(%arg0) {T = "tfdtype$DT_INT32", dtype = "tfdtype$DT_FLOAT", seed = 0 : i64, seed2 = 0 : i64} : (tensor<3xi32>) -> tensor<12x12x64xf32>
+  // CHECK: return %[[F32]] : tensor<12x12x64xf32>
+  return %0 : tensor<12x12x64xf32>
+}
+
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index 794b5ac..204d9ea 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -94,6 +94,17 @@
   return axis < 0 ? axis + rank : axis;
 }
 
+/// Returns a `ConvertOp` that casts the elements of `value` to i64 while
+/// retaining the shape of the input value.
+static xla_hlo::ConvertOp CastElementsToI64(Location loc, Value *value,
+                                            PatternRewriter *rewriter) {
+  auto type = value->getType().dyn_cast<RankedTensorType>();
+  assert(type && "CastElementsToI64 requires a ranked tensor as input.");
+  ArrayRef<int64_t> shape = type.getShape();
+  auto i64_type = rewriter->getTensorType(shape, rewriter->getIntegerType(64));
+  return rewriter->create<xla_hlo::ConvertOp>(loc, i64_type, value);
+}
+
 // Returns minimum value for the given int or float element type.
 static xla_hlo::ConstOp GetMinValueForType(Type ty, Location loc,
                                            PatternRewriter *rewriter) {
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
index d0f3c917..683fd2f 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
@@ -138,6 +138,9 @@
 def HasRankedFirstOperand
   : Constraint<CPred<"(*$0.begin())->getType().isa<RankedTensorType>()">>;
 
+def IsShapedTensor
+  : Constraint<CPred<"$0->getType().isa<RankedTensorType>()">>;
+
 // This pattern converts TensorFlow axis format to HLO axis format which
 // doesn't wrap around like TensorFlow and is always positive. For this
 // conversion, use the first input to get inputs rank. Other inputs need not be
@@ -236,3 +239,19 @@
   def : Pat<(TfOp:$res AnyStaticShapeTensor:$arg, $ignored),
             (HLO_ReshapeOp $arg), [(AnyStaticShapeTensor $res)]>;
 }
+
+//===----------------------------------------------------------------------===//
+// RngUniform.
+//===----------------------------------------------------------------------===//
+def CastElementsToI64: NativeCodeCall<
+  "CastElementsToI64($0->getLoc(), $1, &$_builder)">;
+
+// TODO(misard,phawkins): handle random number generator seeds/states correctly.
+def : Pat<(TF_RandomUniformOp:$old $shape, $seed, $seed2),
+          (HLO_RngUniformOp
+            (HLO_ConstOp
+              (NativeCodeCall<"$_builder.getFloatAttr(old.dtype(), 0.0)">)),
+            (HLO_ConstOp
+              (NativeCodeCall<"$_builder.getFloatAttr(old.dtype(), 1.0)">)),
+            (CastElementsToI64 $old, $shape)),
+            [(IsShapedTensor $shape)]>;