/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA implementations of Random ops
// TODO(misard,phawkins): handle random number generator seeds/states correctly.
// TODO(misard,phawkins): add tests.

#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
#include "tensorflow/compiler/tf2xla/lib/random.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/value_inference.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
namespace tensorflow {
namespace {
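
// Lowers RandomUniform (tf.random.uniform) to XLA's RngUniform over [0, 1).
// The shape operand must be a compile-time constant; when its entries are
// dynamic, an upper-bound shape is compiled and the output is trimmed with
// SetDimensionSize below.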
class RandomUniformOp : public XlaOpKernel {
public:
explicit RandomUniformOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

void Compile(XlaOpKernelContext* ctx) override {
TensorShape shape;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(
0, &shape, xla::ValueInferenceMode::kUpperBound));
const DataType dtype = output_type(0);
xla::Shape xla_shape;
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape));
xla::XlaBuilder* b = ctx->builder();
LOG_FIRST_N(WARNING, 1)
<< "Warning: Using tf.random.uniform with XLA compilation will ignore "
"seeds; consider using tf.random.stateless_uniform instead if "
"reproducible behavior is desired. "
<< name();
xla::XlaOp result = xla::RngUniform(XlaHelpers::Zero(b, dtype),
XlaHelpers::One(b, dtype), xla_shape);
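// Determine which entries of the shape operand are dynamic (not known at
// compile time); those output dimensions need their true sizes restored.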
std::vector<bool> dynamic_dims;
OP_REQUIRES_OK(ctx,
ctx->ResolveInputDynamismIntoPredVector(0, &dynamic_dims));
for (int64_t i = 0; i < xla_shape.rank(); ++i) {
// If a dimension is dynamic, call set-dimension-size on the output.
if (dynamic_dims[i]) {
auto dynamic_dim_size = xla::Slice(ctx->Input(0), {i}, {i + 1}, {1});
dynamic_dim_size = xla::Reshape(dynamic_dim_size, {});
dynamic_dim_size = xla::ConvertElementType(dynamic_dim_size, xla::S32);
result = xla::SetDimensionSize(result, dynamic_dim_size, i);
}
}
ctx->SetOutput(0, result);
}

private:
TF_DISALLOW_COPY_AND_ASSIGN(RandomUniformOp);
};

REGISTER_XLA_OP(Name("RandomUniform").CompileTimeConstantInput("shape"),
RandomUniformOp);
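
// RandomShuffle is lowered through the generic MLIR-based XLA op kernel.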
REGISTER_XLA_OP(Name("RandomShuffle"), MlirXlaOpKernel);
class RandomUniformIntOp : public XlaOpKernel {
public:
explicit RandomUniformIntOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

void Compile(XlaOpKernelContext* ctx) override {
TensorShape shape;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
xla::Shape xla_shape;
OP_REQUIRES_OK(ctx,
TensorShapeToXLAShape(input_type(1), shape, &xla_shape));
const TensorShape minval_shape = ctx->InputShape(1);
const TensorShape maxval_shape = ctx->InputShape(2);
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval_shape),
errors::InvalidArgument("minval must be 0-D, got shape ",
minval_shape.DebugString()));
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval_shape),
errors::InvalidArgument("maxval must be 0-D, got shape ",
maxval_shape.DebugString()));
auto minval = ctx->Input(1);
auto maxval = ctx->Input(2);
LOG_FIRST_N(WARNING, 1)
<< "Warning: Using tf.random.uniform with XLA compilation will ignore "
"seeds; consider using tf.random.stateless_uniform instead if "
"reproducible behavior is desired. "
<< name();
ctx->SetOutput(0, xla::RngUniform(minval, maxval, xla_shape));
}

private:
TF_DISALLOW_COPY_AND_ASSIGN(RandomUniformIntOp);
};

REGISTER_XLA_OP(Name("RandomUniformInt").CompileTimeConstantInput("shape"),
RandomUniformIntOp);
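
// Lowers RandomStandardNormal to XLA's RngNormal with mean 0 and standard
// deviation 1.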
class RandomStandardNormalOp : public XlaOpKernel {
public:
explicit RandomStandardNormalOp(OpKernelConstruction* ctx)
: XlaOpKernel(ctx) {}

void Compile(XlaOpKernelContext* ctx) override {
const DataType dtype = output_type(0);
TensorShape shape;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
xla::Shape xla_shape;
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape));
xla::XlaBuilder* b = ctx->builder();
// Normal distribution with a mean of 0 and a standard deviation of 1:
xla::XlaOp result = xla::RngNormal(XlaHelpers::Zero(b, dtype),
XlaHelpers::One(b, dtype), xla_shape);
ctx->SetOutput(0, result);
}

private:
TF_DISALLOW_COPY_AND_ASSIGN(RandomStandardNormalOp);
};

REGISTER_XLA_OP(Name("RandomStandardNormal").CompileTimeConstantInput("shape"),
RandomStandardNormalOp);
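
// Lowers TruncatedNormal: draws uniform samples and maps them through the
// TruncatedNormal helper from tf2xla/lib/random.h, which produces standard
// normal values truncated to two standard deviations of the mean.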
class TruncatedNormalOp : public XlaOpKernel {
public:
explicit TruncatedNormalOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

void Compile(XlaOpKernelContext* ctx) override {
const DataType dtype = output_type(0);
TensorShape shape;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
xla::Shape xla_shape;
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape));
xla::XlaBuilder* b = ctx->builder();
xla::XlaOp one = xla::One(b, xla_shape.element_type());
xla::XlaOp min_positive =
xla::MinPositiveNormalValue(b, xla_shape.element_type());
LOG_FIRST_N(WARNING, 1)
<< "Warning: Using tf.random.truncated_normal with XLA "
"compilation will ignore seeds; consider using "
"tf.random.stateless_truncated_normal instead if "
"reproducible behavior is desired. "
<< name();
auto uniform = xla::RngUniform(min_positive, one, xla_shape);
ctx->SetOutput(0, TruncatedNormal(uniform));
}
};

REGISTER_XLA_OP(Name("TruncatedNormal")
.CompileTimeConstantInput("shape")
.TypeConstraint("dtype", {DT_FLOAT, DT_DOUBLE}),
TruncatedNormalOp);
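
// Lowers ParameterizedTruncatedNormal: the per-batch means, stddevs, and
// truncation bounds are broadcast to the output shape before the
// truncated-normal transform is applied.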
class ParameterizedTruncatedNormalOp : public XlaOpKernel {
public:
explicit ParameterizedTruncatedNormalOp(OpKernelConstruction* ctx)
: XlaOpKernel(ctx) {}

void Compile(XlaOpKernelContext* ctx) override {
const DataType dtype = output_type(0);
TensorShape shape;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
xla::Shape xla_shape;
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape));
xla::XlaBuilder* b = ctx->builder();
xla::XlaOp one = xla::One(b, xla_shape.element_type());
xla::XlaOp min_positive =
xla::MinPositiveNormalValue(b, xla_shape.element_type());
LOG_FIRST_N(WARNING, 1)
<< "Warning: Using tf.random.truncated_normal with XLA "
"compilation will ignore seeds; consider using "
"tf.random.stateless_truncated_normal instead if "
"reproducible behavior is desired. "
<< name();
xla::XlaOp uniform = xla::RngUniform(min_positive, one, xla_shape);
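// Broadcast the distribution parameters to the output shape, then apply
// the parameterized truncated-normal transform.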
auto result = b->ReportErrorOrReturn([&]() -> StatusOr<xla::XlaOp> {
TF_ASSIGN_OR_RETURN(xla::XlaOp means,
BroadcastTo(ctx->Input(1), shape.dim_sizes()));
TF_ASSIGN_OR_RETURN(xla::XlaOp stddevs,
BroadcastTo(ctx->Input(2), shape.dim_sizes()));
TF_ASSIGN_OR_RETURN(xla::XlaOp minvals,
BroadcastTo(ctx->Input(3), shape.dim_sizes()));
TF_ASSIGN_OR_RETURN(xla::XlaOp maxvals,
BroadcastTo(ctx->Input(4), shape.dim_sizes()));
return ParameterizedTruncatedNormal(uniform, means, stddevs, minvals,
maxvals);
});
ctx->SetOutput(0, result);
}
};

REGISTER_XLA_OP(Name("ParameterizedTruncatedNormal")
.CompileTimeConstantInput("shape")
.TypeConstraint("dtype", {DT_FLOAT, DT_DOUBLE}),
ParameterizedTruncatedNormalOp);

}  // namespace
}  // namespace tensorflow