/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Our general strategy for preventing conflicts between concurrent
// reads and writes of resource variables is to:
// * For read operations, we:
// - acquire the variable's mutex (in "shared" mode);
// - make a (shallow) copy of the Tensor object, which increments
// the reference count on the variable's TensorBuffer;
// - release the variable's mutex;
// - use the copy of the Tensor object to do the read.
// * For write operations, we:
// - acquire the variable's mutex (in "exclusive" mode);
// - check the reference count of variable's TensorBuffer and
// if it is >1, make a deep copy of the variable's Tensor;
// - mutate the variable's Tensor;
// - and release the variable's mutex.
// This allows several read operations to all use the same
// TensorBuffer without needing to copy. When it comes time to write,
// a copy is only made if there is an outstanding read using the
// buffer. Write operations are serialized by the variable's mutex.
//
// For sparse operations (scatter, gather, sparse optimizer updates),
// we need to avoid copies, since there may not be enough memory for
// two copies of the whole tensor. To support this, we make two
// modifications to the above strategy:
// * For sparse reads (gather), we hold the variable's mutex (still in
// "shared" mode) for the duration of the whole read. This means
// that as long as you only do sparse read operations, no write will
// see a reference count >1.
// * For sparse write operations where the user explicitly specifies
// that they want to perform the write without locks held
// (use_locking=false), we never copy, even if the variable's
// reference count is >1.
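//
// As an illustrative sketch only (assuming a `Var* var`; the actual kernels
// below add dtype checks, error handling, and copy-on-read support), the two
// dense paths look roughly like:
//
//   // Read: shallow-copy the Tensor under a shared lock, then read later.
//   Tensor snapshot;
//   {
//     tf_shared_lock l(*var->mu());
//     snapshot = *var->tensor();  // bumps the TensorBuffer refcount
//   }
//   // ... read from `snapshot` without holding the lock ...
//
//   // Write: deep-copy first if an outstanding read still holds the buffer.
//   {
//     mutex_lock l(*var->mu());
//     if (!var->tensor()->RefCountIsOne()) {
//       // allocate a fresh buffer and copy the contents before mutating
//     }
//     // ... mutate *var->tensor() in place ...
//   }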
#define EIGEN_USE_THREADS
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif
#include <memory>
#include <vector>
#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/gather_functor.h"
#include "tensorflow/core/kernels/gather_nd_op.h"
#include "tensorflow/core/kernels/resource_variable_ops.h"
#include "tensorflow/core/kernels/resource_variable_util.h"
#include "tensorflow/core/kernels/scatter_functor.h"
#include "tensorflow/core/kernels/training_op_helpers.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/determinism.h"
#include "tensorflow/core/util/util.h"
namespace tensorflow {
REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp").Device(DEVICE_CPU),
ResourceHandlesOp<Var>);
ReadVariableOp::ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) {
OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
}
namespace {
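// Copies the contents of the variable's tensor `t` into output `output_idx`
// of `ctx`. DT_VARIANT tensors are allocated on and copied via the host; for
// other dtypes the copy goes through the op's device context when one is
// available (waiting for the device-side copy to finish), and falls back to
// an element-wise Eigen assignment on the CPU otherwise.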
Status CopyVariable(int output_idx, OpKernelContext* ctx, const Tensor* t) {
Tensor* output;
Notification n;
Status status;
AllocatorAttributes attr;
if (t->dtype() == DT_VARIANT) {
attr.set_on_host(true);
}
TF_RETURN_IF_ERROR(
ctx->allocate_output(output_idx, t->shape(), &output, attr));
if (t->dtype() == DT_VARIANT) {
output->flat<Variant>() = t->flat<Variant>();
} else if (ctx->op_device_context() != nullptr) {
// TODO(apassos): remove the down_cast by just returning Device* from
// OpKernelContext
Device* device = down_cast<Device*>(ctx->device());
ctx->op_device_context()->CopyTensorInSameDevice(
t, device, output, [&n, &status](const Status& s) {
status = s;
n.Notify();
});
n.WaitForNotification();
return status;
} else {
switch (t->dtype()) {
#define HANDLER(type) \
case DataTypeToEnum<type>::value: \
output->flat<type>() = t->flat<type>(); \
break;
TF_CALL_ALL_TYPES(HANDLER);
#undef HANDLER
default:
return errors::Internal("Unsupported dtype", t->dtype());
}
}
return Status::OK();
}
} // namespace
void ReadVariableOp::Compute(OpKernelContext* ctx) {
core::RefCountPtr<Var> variable;
const ResourceHandle& handle = HandleFromInput(ctx, 0);
const auto status = LookupResource(ctx, handle, &variable);
OP_REQUIRES(ctx, status.ok(),
errors::FailedPrecondition(
"Could not find variable ", handle.name(), ". ",
"This could mean that the variable has been deleted. ",
"In TF1, it can also mean the variable is uninitialized. ",
"Debug info: container=", handle.container(),
", status error message=", status.error_message()));
tf_shared_lock ml(*variable->mu());
// We're acquiring a reference to the underlying buffer while
// holding a shared lock to guarantee ordering of reads and
// writes when in copy-on-write mode.
const Tensor* t = variable->tensor();
if (!variable->copy_on_read_mode.load()) {
OP_REQUIRES(
ctx, dtype_ == t->dtype(),
errors::InvalidArgument(
"Trying to read variable with wrong dtype. Expected ",
DataTypeString(dtype_), " got ", DataTypeString(t->dtype())));
ctx->set_output(0, *t);
} else {
OP_REQUIRES_OK(ctx, CopyVariable(0, ctx, t));
}
}
ReadVariablesOp::ReadVariablesOp(OpKernelConstruction* c) : OpKernel(c) {
int n;
OP_REQUIRES_OK(c, c->GetAttr("N", &n));
OP_REQUIRES_OK(c, c->GetAttr("dtypes", &dtypes_));
OP_REQUIRES(c, n == dtypes_.size(),
errors::InvalidArgument(
"Mismatched number of arguments to ReadVariablesOp (", n,
" vs. ", dtypes_.size(), ")"));
}
void ReadVariablesOp::Compute(OpKernelContext* ctx) {
std::vector<core::RefCountPtr<Var>> variables(dtypes_.size());
std::vector<const ResourceHandle*> handles(dtypes_.size());
for (size_t i = 0; i < dtypes_.size(); ++i) {
handles[i] = &HandleFromInput(ctx, i);
}
OP_REQUIRES_OK(ctx, LookupResources(ctx, handles, &variables));
std::vector<string> uninitialized_vars;
for (int64_t i = 0; i < variables.size(); i++) {
if (variables[i] == nullptr) {
uninitialized_vars.push_back(handles[i]->name());
}
}
OP_REQUIRES(ctx, uninitialized_vars.empty(),
errors::FailedPrecondition(
"In ReadVariablesOp the following variables were "
"found uninitialized: ",
absl::StrJoin(uninitialized_vars, ", ")));
for (size_t i = 0; i < dtypes_.size(); ++i) {
// We're acquiring a reference to the underlying buffer while
// holding a shared lock to guarantee ordering of reads and
// writes.
tf_shared_lock ml(*variables[i]->mu());
OP_REQUIRES(ctx, dtypes_[i] == variables[i]->tensor()->dtype(),
errors::InvalidArgument(
"Trying to read variable ", handles[i]->name(),
" from Container: ", handles[i]->container(),
" with wrong dtype. Expected ", DataTypeString(dtypes_[i]),
" got ", DataTypeString(variables[i]->tensor()->dtype())));
if (variables[i]->copy_on_read_mode.load()) {
OP_REQUIRES_OK(ctx, CopyVariable(i, ctx, variables[i]->tensor()));
} else {
const Tensor& t = *variables[i]->tensor();
ctx->set_output(i, t);
}
}
}
REGISTER_KERNEL_BUILDER(Name("ReadVariableOp").Device(DEVICE_CPU),
ReadVariableOp);
REGISTER_KERNEL_BUILDER(Name("_ReadVariablesOp").Device(DEVICE_CPU),
ReadVariablesOp);
REGISTER_KERNEL_BUILDER(
Name("ReadVariableOp").Device(DEVICE_DEFAULT).HostMemory("resource"),
ReadVariableOp);
REGISTER_KERNEL_BUILDER(
Name("_ReadVariablesOp").Device(DEVICE_DEFAULT).HostMemory("resources"),
ReadVariablesOp);
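// VarHandleOp outputs a DT_RESOURCE scalar referring to the variable. For
// named (non-anonymous) variables the handle tensor is built once here in the
// constructor and re-emitted on every call; anonymous variables get a fresh
// reference-counting handle on each Compute call (see Compute below).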
VarHandleOp::VarHandleOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("container", &container_));
OP_REQUIRES_OK(context, context->GetAttr("shared_name", &name_));
OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_and_shape_.dtype));
OP_REQUIRES_OK(context, context->GetAttr("shape", &dtype_and_shape_.shape));
is_anonymous_ = name_ == ResourceHandle::ANONYMOUS_NAME;
// Use const_tensor_ if the variable is non-anonymous.
if (!is_anonymous_) {
AllocatorAttributes attr;
attr.set_on_host(true);
OP_REQUIRES_OK(context, context->allocate_temp(DT_RESOURCE, TensorShape({}),
&const_tensor_, attr));
const_tensor_.scalar<ResourceHandle>()() = MakeResourceHandle<Var>(
context, container_, name_,
std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_});
}
}
void VarHandleOp::Compute(OpKernelContext* ctx) {
if (is_anonymous_) {
Var* resource = new Var(dtype_and_shape_.dtype);
ResourceMgr* mgr = ctx->resource_manager();
ResourceHandle handle = ResourceHandle::MakeRefCountingHandle<Var>(
resource, ctx->device()->name(),
std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_});
// TODO(b/203901837): See if we can abolish all code paths that lookup
// anonymous variables and then stop publishing them to the manager.
OP_REQUIRES_OK(ctx, mgr->CreateUnowned<Var>(handle.container(),
handle.name(), resource));
AllocatorAttributes attr;
attr.set_on_host(true);
Tensor tensor;
OP_REQUIRES_OK(
ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), &tensor, attr));
tensor.scalar<ResourceHandle>()() = std::move(handle);
ctx->set_output(0, tensor);
} else {
ctx->set_output(0, const_tensor_);
}
}
REGISTER_KERNEL_BUILDER(Name("VarHandleOp").Device(DEVICE_CPU), VarHandleOp);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU_KERNELS(type) \
namespace functor { \
template <> \
void DenseUpdate<GPUDevice, type, ASSIGN>::operator()( \
const GPUDevice& d, typename TTypes<type>::Flat lhs, \
typename TTypes<type>::ConstFlat rhs); \
extern template struct DenseUpdate<GPUDevice, type, ASSIGN>; \
}
TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_bfloat16(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
TF_CALL_variant(REGISTER_GPU_KERNELS);
TF_CALL_uint32(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_DEFAULT_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("VarHandleOp") \
.Device(DEVICE_DEFAULT) \
.HostMemory("resource") \
.TypeConstraint<type>("dtype"), \
VarHandleOp)
TF_CALL_GPU_ALL_TYPES(REGISTER_DEFAULT_KERNELS);
TF_CALL_bfloat16(REGISTER_DEFAULT_KERNELS);
TF_CALL_int64(REGISTER_DEFAULT_KERNELS);
TF_CALL_variant(REGISTER_DEFAULT_KERNELS);
TF_CALL_uint32(REGISTER_DEFAULT_KERNELS);
#undef REGISTER_DEFAULT_KERNELS
REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp")
.Device(DEVICE_DEFAULT)
.HostMemory("resources")
.TypeConstraint("dtypes",
{DT_INT64, DT_COMPLEX64,
DT_COMPLEX128, DT_HALF, DT_FLOAT,
DT_DOUBLE, DT_BOOL, DT_VARIANT}),
ResourceHandlesOp<Var>);
REGISTER_KERNEL_BUILDER(
Name("VariableShape").Device(DEVICE_CPU).TypeConstraint<int32>("out_type"),
VariableShapeOp<int32>);
REGISTER_KERNEL_BUILDER(Name("VariableShape")
.Device(DEVICE_CPU)
.TypeConstraint<int64_t>("out_type"),
VariableShapeOp<int64_t>);
REGISTER_KERNEL_BUILDER(Name("VariableShape")
.Device(DEVICE_DEFAULT)
.TypeConstraint<int32>("out_type")
.HostMemory("output")
.HostMemory("input"),
VariableShapeOp<int32>);
REGISTER_KERNEL_BUILDER(Name("VariableShape")
.Device(DEVICE_DEFAULT)
.TypeConstraint<int64_t>("out_type")
.HostMemory("output")
.HostMemory("input"),
VariableShapeOp<int64_t>);
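// Deletes the resource referenced by the input handle. When
// ignore_lookup_error is set, a missing resource is not an error.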
DestroyResourceOp::DestroyResourceOp(OpKernelConstruction* ctx)
: OpKernel(ctx) {
OP_REQUIRES_OK(ctx,
ctx->GetAttr("ignore_lookup_error", &ignore_lookup_error_));
}
void DestroyResourceOp::Compute(OpKernelContext* ctx) {
const ResourceHandle& p = HandleFromInput(ctx, 0);
Status status = DeleteResource(ctx, p);
if (ignore_lookup_error_ && errors::IsNotFound(status)) {
return;
}
OP_REQUIRES_OK(ctx, status);
}
REGISTER_KERNEL_BUILDER(Name("DestroyResourceOp").Device(DEVICE_CPU),
DestroyResourceOp);
REGISTER_KERNEL_BUILDER(
Name("DestroyResourceOp").Device(DEVICE_DEFAULT).HostMemory("resource"),
DestroyResourceOp);
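// Assigns the value tensor to the variable, creating the variable on first
// use. In the common case the variable simply adopts the value's buffer,
// which the copy-on-write semantics described at the top of this file make
// safe; when the variable is in copy-on-read mode, the value is copied into
// freshly allocated memory instead.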
template <typename Device, typename T>
class AssignVariableOp : public OpKernel {
public:
explicit AssignVariableOp(OpKernelConstruction* c) : OpKernel(c) {
OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
if (!c->GetAttr("_grappler_relax_allocator_constraints",
&relax_constraints_)
.ok()) {
relax_constraints_ = false;
}
if (c->HasAttr("validate_shape")) {
OP_REQUIRES_OK(c, c->GetAttr("validate_shape", &validate_shape_));
}
}
void Compute(OpKernelContext* context) override {
OP_REQUIRES(context, dtype_ == context->input(1).dtype(),
errors::InvalidArgument(
"Variable and value dtypes don't match; respectively, ",
DataTypeString(dtype_), " and ",
DataTypeString(context->input(1).dtype())));
core::RefCountPtr<Var> variable;
const Tensor& value = context->input(1);
// Note: every resource-variable-manipulating op assumes copy-on-write
// semantics, and creates a copy of the variable's Tensor if its refcount is
// bigger than 1 when we try to modify it. This means we never need to copy
// the original tensor for AssignVariableOp; even if there are other live
// users of it, we know none of them can modify it, so adopting it is always
// safe. Even in esoteric cases where the same tensor is used to initialize
// multiple variables, or the tensor is a constant, this holds, because any
// future write will trigger a copy.
OP_REQUIRES_OK(context, LookupOrCreateResource<Var>(
context, HandleFromInput(context, 0), &variable,
[this, &value](Var** ptr) {
*ptr = new Var(dtype_);
*(*ptr)->tensor() = value;
(*ptr)->is_initialized = true;
return Status::OK();
}));
mutex_lock ml(*variable->mu());
// The (variable->tensor()->dtype() == DT_INVALID && !variable->is_initialized)
// check below allows an XLA-specific situation in which AssignVariableOp is
// the first op to touch the variable, i.e. the variable is still
// uninitialized when the assignment runs. With TF-XLA this can happen when
// execution takes the 'fallback' path, which essentially invokes TensorFlow
// ops via partitioned_call.
OP_REQUIRES(context,
(variable->tensor()->dtype() == DT_INVALID &&
!variable->is_initialized) ||
variable->tensor()->dtype() == dtype_,
errors::InvalidArgument(
"Trying to assign variable with wrong dtype. Expected ",
DataTypeString(variable->tensor()->dtype()), " got ",
DataTypeString(dtype_)));
if (validate_shape_) {
OP_REQUIRES(
context,
(!variable->is_initialized ||
variable->tensor()->shape().IsSameSize(value.shape())),
errors::InvalidArgument(
"Trying to assign to variable with tensor with wrong shape."
" Expected ",
variable->tensor()->shape().DebugString(), " got ",
value.shape().DebugString()));
}
if (variable->copy_on_read_mode.load()) {
AllocatorAttributes attr;
attr.set_gpu_compatible(true);
attr.set_nic_compatible(true);
OP_REQUIRES_OK(context,
context->allocate_temp(value.dtype(), value.shape(),
variable->tensor(), attr));
functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
copy_functor(context->eigen_device<Device>(),
variable->tensor()->flat<T>(), value.flat<T>());
} else {
*variable->tensor() = value;
}
variable->is_initialized = true;
}
private:
DataType dtype_;
bool relax_constraints_;
bool validate_shape_ = false;
};
template <typename Device>
class AssignVariableOp<Device, Variant> : public OpKernel {
public:
explicit AssignVariableOp(OpKernelConstruction* c) : OpKernel(c) {
OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
OP_REQUIRES(c, dtype_ == DT_VARIANT,
errors::Internal("Variant kernel called with dtype: ",
DataTypeString(dtype_)));
}
void Compute(OpKernelContext* context) override {
const Tensor& value = context->input(1);
core::RefCountPtr<Var> variable;
OP_REQUIRES_OK(context, LookupOrCreateResource<Var>(
context, HandleFromInput(context, 0), &variable,
[](Var** ptr) {
// Created on host.
*ptr = new Var(DT_VARIANT);
return Status::OK();
}));
// For purposes of forwarding DT_VARIANT, we want the least
// restrictive attr; we already know the input is on host.
AllocatorAttributes attr;
// Copying is unnecessary if we are the last user of the value
// tensor; in that case we can just adopt the input tensor's buffer.
// Note that Variant objects themselves always reside on host.
//
// We nevertheless want to signal to the runtime that the tensor
// should reside in memory of the associated device, as Variant
// tensors may be marked as sitting on either CPU or GPU. This
// helps to elide one or more copies.
std::unique_ptr<Tensor> input_alias = context->forward_input(
1, OpKernelContext::Params::kNoReservation /*output_index*/, DT_VARIANT,
value.shape(),
DEVICE_MEMORY /* HOST_MEMORY is only reserved for special cases */,
attr);
mutex_lock ml(*variable->mu());
OP_REQUIRES(context, variable->tensor()->dtype() == DT_VARIANT,
errors::InvalidArgument(
"Trying to assign variable with wrong dtype. Expected ",
DataTypeString(variable->tensor()->dtype()), " got ",
DataTypeString(DT_VARIANT)));
variable->is_initialized = true;
*variable->tensor() = Tensor(DT_VARIANT, value.shape());
if (input_alias) {
*variable->tensor() = *input_alias;
return;
}
// Need to copy, but maybe we can reuse the variable's buffer?
if (!variable->tensor()->RefCountIsOne() ||
!variable->tensor()->shape().IsSameSize(value.shape())) {
// Allocation of DT_VARIANT is always on host.
attr.set_on_host(true);
OP_REQUIRES_OK(context, context->allocate_temp(DT_VARIANT, value.shape(),
variable->tensor(), attr));
}
const auto elements_in = value.flat<Variant>();
auto elements_out = variable->tensor()->flat<Variant>();
for (int64_t i = 0; i < elements_in.size(); ++i) {
elements_out(i) = elements_in(i);
}
}
private:
DataType dtype_;
};
#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("AssignVariableOp") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("dtype"), \
AssignVariableOp<Eigen::ThreadPoolDevice, type>);
TF_CALL_ALL_TYPES(REGISTER_KERNELS);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("AssignVariableOp") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("dtype") \
.HostMemory("resource"), \
AssignVariableOp<GPUDevice, type>);
TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_bfloat16(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
TF_CALL_uint32(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_KERNEL_BUILDER(Name("AssignVariableOp")
.Device(DEVICE_DEFAULT)
.TypeConstraint<Variant>("dtype")
.HostMemory("resource"),
AssignVariableOp<CPUDevice, Variant>);
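// Applies a dense in-place update `variable Op= value` (registered below for
// ADD and SUB), copying the variable's buffer first via
// PrepareToUpdateVariable() if other users still hold references to it.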
template <typename Device, typename T, DenseUpdateType Op>
class AssignUpdateVariableOp : public OpKernel {
public:
explicit AssignUpdateVariableOp(OpKernelConstruction* c) : OpKernel(c) {}
void Compute(OpKernelContext* context) override {
core::RefCountPtr<Var> variable;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&variable));
const Tensor& value = context->input(1);
// TODO(apassos): We could possibly avoid the copy done by
// PrepareToUpdateVariable() for commutative operations like Op ==
// ADD if value's refcount was 1.
mutex_lock ml(*variable->mu());
Tensor* var_tensor = variable->tensor();
OP_REQUIRES_OK(context, ValidateAssignUpdateVariableOpShapes(
var_tensor->shape(), value.shape()));
OP_REQUIRES_OK(
context, PrepareToUpdateVariable<Device, T>(
context, var_tensor, variable->copy_on_read_mode.load()));
functor::DenseUpdate<Device, T, Op> update_functor;
update_functor(context->eigen_device<Device>(), var_tensor->flat<T>(),
value.flat<T>());
}
};
#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("AssignAddVariableOp") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("dtype"), \
AssignUpdateVariableOp<Eigen::ThreadPoolDevice, type, ADD>); \
REGISTER_KERNEL_BUILDER( \
Name("AssignSubVariableOp") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("dtype"), \
AssignUpdateVariableOp<Eigen::ThreadPoolDevice, type, SUB>);
TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("AssignAddVariableOp") \
.Device(DEVICE_GPU) \
.HostMemory("resource") \
.TypeConstraint<type>("dtype"), \
AssignUpdateVariableOp<GPUDevice, type, ADD>); \
REGISTER_KERNEL_BUILDER(Name("AssignSubVariableOp") \
.Device(DEVICE_GPU) \
.HostMemory("resource") \
.TypeConstraint<type>("dtype"), \
AssignUpdateVariableOp<GPUDevice, type, SUB>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
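// Outputs a scalar bool that is true iff the variable resource exists and has
// been initialized.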
class VarIsInitializedOp : public OpKernel {
public:
explicit VarIsInitializedOp(OpKernelConstruction* c) : OpKernel(c) {}
void Compute(OpKernelContext* context) override {
Tensor* output = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, TensorShape({}), &output));
auto output_tensor = output->tensor<bool, 0>();
core::RefCountPtr<Var> variable;
Status s = LookupResource(context, HandleFromInput(context, 0), &variable);
if (!s.ok()) {
output_tensor() = false;
return;
}
mutex_lock ml(*variable->mu());
output_tensor() = variable->is_initialized;
}
};
REGISTER_KERNEL_BUILDER(Name("VarIsInitializedOp").Device(DEVICE_CPU),
VarIsInitializedOp);
REGISTER_KERNEL_BUILDER(Name("VarIsInitializedOp")
.Device(DEVICE_DEFAULT)
.HostMemory("resource")
.HostMemory("is_initialized"),
VarIsInitializedOp);
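// Gathers slices of the variable's tensor according to `indices`, optionally
// treating the leading `batch_dims` dimensions of params and indices as batch
// dimensions. The variable's mutex is held in shared mode for the entire
// gather (see the NOTE in Compute).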
template <typename Device, typename T, typename Index>
class ResourceGatherOp : public OpKernel {
public:
explicit ResourceGatherOp(OpKernelConstruction* c) : OpKernel(c) {
OP_REQUIRES_OK(c, c->GetAttr("batch_dims", &batch_dims_));
}
void Compute(OpKernelContext* c) override {
core::RefCountPtr<Var> v;
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
// NOTE: We hold the lock for the whole gather operation instead
// of increasing the reference count of v->tensor() to avoid a
// situation where a write to the same variable will see a
// reference count greater than one and make a copy of the
// (potentially very large) tensor buffer.
tf_shared_lock ml(*v->mu());
const Tensor& params = *v->tensor();
const Tensor& indices = c->input(1);
OP_REQUIRES(
c, TensorShapeUtils::IsVectorOrHigher(params.shape()),
errors::InvalidArgument("params must be at least 1 dimensional"));
OP_REQUIRES(
c, params.shape().dims() >= batch_dims_,
errors::InvalidArgument("params must have at least ", batch_dims_,
" (batch_dims) dimensions but it has shape ",
params.shape().DebugString()));
// Check that we have enough index space
const int64_t N = indices.NumElements();
OP_REQUIRES(
c, params.dim_size(0) <= std::numeric_limits<Index>::max(),
errors::InvalidArgument("params.shape[0] too large for ",
DataTypeString(DataTypeToEnum<Index>::v()),
" indexing: ", params.dim_size(0), " > ",
std::numeric_limits<Index>::max()));
// The result shape is params.shape[:batch_dims] +
// indices.shape[batch_dims:] + params.shape[batch_dims+1:].
TensorShape result_shape;
for (int i = 0; i < batch_dims_; ++i) {
result_shape.AddDim(params.dim_size(i));
}
for (int i = batch_dims_; i < indices.dims(); ++i) {
result_shape.AddDim(indices.dim_size(i));
}
for (int i = batch_dims_ + 1; i < params.dims(); ++i) {
result_shape.AddDim(params.dim_size(i));
}
Tensor* out = nullptr;
Tensor tmp;
if (params.dtype() == DT_VARIANT) {
tmp = Tensor(DT_VARIANT, result_shape);
c->set_output(0, tmp);
out = &tmp;
} else {
OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out));
}
if (N > 0) {
Tensor tmp_indices;
// Points to the original or updated (if batch_dims is set) indices.
const Tensor* op_indices = &indices;
if (batch_dims_ > 0) {
OP_REQUIRES_OK(c, c->allocate_temp(indices.dtype(), indices.shape(),
&tmp_indices));
functor::DenseUpdate<Device, Index, ASSIGN> copy_functor;
copy_functor(c->eigen_device<Device>(), tmp_indices.flat<Index>(),
indices.flat<Index>());
AddBatchOffsets(c, &tmp_indices, params);
if (!c->status().ok()) return;
op_indices = &tmp_indices;
}
int64_t gather_dim_size = 1;
for (int idx = 0; idx <= batch_dims_; ++idx) {
gather_dim_size *= params.dim_size(idx);
}
int64_t inner_size = 1;
for (int i = batch_dims_ + 1; i < params.dims(); ++i) {
inner_size *= params.dim_size(i);
}
auto params_flat = params.shaped<T, 3>({1, gather_dim_size, inner_size});
const auto indices_flat = op_indices->flat<Index>();
auto out_flat = out->shaped<T, 3>({1, N, out->NumElements() / N});
functor::GatherFunctor<Device, T, Index> functor;
int64_t bad_i = functor(c, params_flat, indices_flat, out_flat);
OP_REQUIRES(
c, bad_i < 0,
errors::InvalidArgument(
"indices", SliceDebugString(indices.shape(), bad_i), " = ",
indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")"));
}
}
private:
// Add the batch offset derived from params to each batch of indices.
// Example: batch_dims = 1, indices = [[0, 1, 2], [0, 1, 2]].
// If indexing into a params dimension of size 4, the flattened indices
// become [0, 1, 2, 4, 5, 6].
void AddBatchOffsets(OpKernelContext* ctx, Tensor* indices,
const Tensor& params) {
int64_t batch_size = 1; // The size of all batch dimensions.
for (int idx = 0; idx < batch_dims_; ++idx) {
batch_size *= params.dim_size(idx);
}
OP_REQUIRES(
ctx, batch_size != 0,
errors::InvalidArgument(
"Inner size of indices would result in batch_size of 0 and a ",
"division by 0 in the implementation. This is illegal"));
auto indices_flat = indices->flat<Index>();
int64_t const index_inner_size = indices->NumElements() / batch_size;
int64_t const batch_offset = params.dim_size(batch_dims_);
for (int64_t batch_idx = 0, dest_idx = 0; batch_idx < batch_size;
++batch_idx) {
for (int64_t idx = 0; idx < index_inner_size; ++idx) {
indices_flat(dest_idx++) += batch_offset * batch_idx;
}
}
}
int32 batch_dims_ = 0;
};
#define REGISTER_GATHER_FULL(dev, type, index_type) \
REGISTER_KERNEL_BUILDER(Name("ResourceGather") \
.Device(DEVICE_##dev) \
.HostMemory("resource") \
.TypeConstraint<type>("dtype") \
.TypeConstraint<index_type>("Tindices"), \
ResourceGatherOp<dev##Device, type, index_type>)
#define REGISTER_GATHER_ALL_INDICES(dev, type) \
REGISTER_GATHER_FULL(dev, type, int32); \
REGISTER_GATHER_FULL(dev, type, int64_t)
#define REGISTER_GATHER_CPU(type) REGISTER_GATHER_ALL_INDICES(CPU, type)
// Registration of the CPU implementations.
TF_CALL_ALL_TYPES(REGISTER_GATHER_CPU);
TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_CPU);
// Registers GPU kernels.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GATHER_GPU(type) REGISTER_GATHER_ALL_INDICES(GPU, type)
TF_CALL_int64(REGISTER_GATHER_GPU);
TF_CALL_GPU_ALL_TYPES(REGISTER_GATHER_GPU);
// Variant objects themselves sit on CPU, even if they contain data
// pointing to a device.
REGISTER_KERNEL_BUILDER(Name("ResourceGather")
.Device(DEVICE_DEFAULT)
.HostMemory("resource")
.HostMemory("indices")
.TypeConstraint<Variant>("dtype")
.TypeConstraint<int32>("Tindices"),
ResourceGatherOp<CPUDevice, Variant, int32>)
REGISTER_KERNEL_BUILDER(Name("ResourceGather")
.Device(DEVICE_DEFAULT)
.HostMemory("resource")
.HostMemory("indices")
.TypeConstraint<Variant>("dtype")
.TypeConstraint<int64_t>("Tindices"),
ResourceGatherOp<CPUDevice, Variant, int64>)
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#undef REGISTER_GATHER_CPU
#undef REGISTER_GATHER_GPU
#undef REGISTER_GATHER_ALL_INDICES
#undef REGISTER_GATHER_FULL
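// GatherNd on a resource variable: gathers slices addressed by
// multi-dimensional indices, holding the variable's mutex in shared mode for
// the duration of the operation.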
template <typename Device, typename T, typename Index>
class ResourceGatherNdOp : public OpKernel {
public:
explicit ResourceGatherNdOp(OpKernelConstruction* c) : OpKernel(c) {}
void Compute(OpKernelContext* c) override {
core::RefCountPtr<Var> v;
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
// NOTE: We hold the lock for the whole gather operation instead
// of increasing the reference count of v->tensor() to avoid a
// situation where a write to the same variable will see a
// reference count greater than one and make a copy of the
// (potentially very large) tensor buffer.
tf_shared_lock ml(*v->mu());
const Tensor& params = *v->tensor();
const Tensor& indices = c->input(1);
Tensor out;
OP_REQUIRES_OK(
c, functor::DoGatherNd<Device, T, Index>(c, params, indices, &out));
c->set_output(0, out);
}
};
#define REGISTER_GATHER_ND_FULL(dev, type, index_type) \
REGISTER_KERNEL_BUILDER(Name("ResourceGatherNd") \
.Device(DEVICE_##dev) \
.HostMemory("resource") \
.TypeConstraint<type>("dtype") \
.TypeConstraint<index_type>("Tindices"), \
ResourceGatherNdOp<dev##Device, type, index_type>)
#define REGISTER_GATHER_ND_ALL_INDICES(dev, type) \
REGISTER_GATHER_ND_FULL(dev, type, int32); \
REGISTER_GATHER_ND_FULL(dev, type, int64_t)
#define REGISTER_GATHER_ND_CPU(type) REGISTER_GATHER_ND_ALL_INDICES(CPU, type)
// Registration of the CPU implementations.
TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU);
// Registers GPU kernels.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GATHER_ND_GPU(type) REGISTER_GATHER_ND_ALL_INDICES(GPU, type)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GATHER_ND_GPU);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#undef REGISTER_GATHER_ND_CPU
#undef REGISTER_GATHER_ND_GPU
#undef REGISTER_GATHER_ND_ALL_INDICES
#undef REGISTER_GATHER_ND_FULL
namespace {
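// Compile-time check for whether a kernel is instantiated for the CPU device;
// used below to enable the division-by-zero check for ResourceScatterDiv.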
template <typename Device>
bool isCPUDevice() {
return false;
}
template <>
bool isCPUDevice<CPUDevice>() {
return true;
}
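// Returns false if any element of `updates` is zero, which would cause a
// division by zero in ResourceScatterDiv. The Variant specialization always
// passes.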
template <typename T>
bool ValidateInput(const Tensor& updates) {
const auto updates_flat = updates.flat<T>();
const T zero(0);
for (int i = 0; i < updates.NumElements(); i++) {
if (updates_flat(i) == zero) return false;
}
return true;
}
template <>
bool ValidateInput<Variant>(const Tensor& updates) {
return true;
}
template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
Status DoScatter(OpKernelContext* c, Tensor* params, const Tensor& indices,
const Tensor& updates, Index num_indices);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
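// Enqueues a device-to-host copy of `device_tensor` into a newly allocated
// host tensor on the op's stream. The caller is responsible for synchronizing
// the stream before reading `host_tensor` (see DoScatterOnCpu below).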
template <typename T>
Status CopyTensorToHost(OpKernelContext* c, const Tensor& device_tensor,
Tensor* host_tensor) {
AllocatorAttributes alloc_attr;
alloc_attr.set_on_host(true);
alloc_attr.set_gpu_compatible(true);
auto stream = c->op_device_context()->stream();
TF_RETURN_IF_ERROR(c->allocate_temp(
device_tensor.dtype(), device_tensor.shape(), host_tensor, alloc_attr));
se::DeviceMemoryBase device_ptr(
const_cast<Tensor&>(device_tensor).flat<T>().data(),
device_tensor.flat<T>().size() * sizeof(T));
stream->ThenMemcpy(host_tensor->flat<T>().data(), device_ptr,
device_tensor.NumElements() * sizeof(T));
if (!stream) {
return errors::Internal("Failed to copy indices to host");
}
return Status::OK();
}
// Copies inputs to the CPU, runs DoScatter on the CPU, then copies output
// back to GPU. This is useful because the CPU implementation is deterministic
// and the GPU implementation is not. Tensor inputs to this function must be on
// the GPU.
template <typename T, typename Index, scatter_op::UpdateOp Op>
Status DoScatterOnCpu(OpKernelContext* c, Tensor* params, const Tensor& indices,
const Tensor& updates, Index num_indices) {
auto stream = c->op_device_context()->stream();
Tensor host_indices;
TF_RETURN_IF_ERROR(CopyTensorToHost<Index>(c, indices, &host_indices));
Tensor host_updates;
TF_RETURN_IF_ERROR(CopyTensorToHost<T>(c, updates, &host_updates));
Tensor host_params;
TF_RETURN_IF_ERROR(CopyTensorToHost<T>(c, *params, &host_params));
TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
TF_RETURN_IF_ERROR(DoScatter<CPUDevice, T, Index, Op>(
c, &host_params, host_indices, host_updates, num_indices));
// Copy 'host_params' to device.
se::DeviceMemoryBase params_ptr(params->flat<T>().data(),
params->flat<T>().size() * sizeof(T));
stream->ThenMemcpy(&params_ptr, host_params.flat<T>().data(),
host_params.NumElements() * sizeof(T));
if (!stream) {
return errors::Internal("Failed to copy params to device");
}
// Deallocate host_params' buffer once the host-to-device copy is complete.
// host_params is captured by value in the lambda so that its buffer is only
// destructed once the lambda is destructed.
c->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
stream, [host_params] {});
return Status::OK();
}
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
Status DoScatter(OpKernelContext* c, Tensor* params, const Tensor& indices,
const Tensor& updates, Index num_indices) {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
if (std::is_same<Device, GPUDevice>::value &&
tensorflow::OpDeterminismRequired()) {
if (!DataTypeCanUseMemcpy(params->dtype())) {
return errors::Unimplemented(
"GPU Scatter ops for dtype ", DataTypeString(params->dtype()),
" do not yet have a deterministic implementation");
}
return DoScatterOnCpu<T, Index, op>(c, params, indices, updates,
num_indices);
}
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
auto indices_flat = indices.flat<Index>();
auto params_flat = params->flat_outer_dims<T>();
int64_t num_updates = updates.NumElements();
auto updates_flat =
updates.shaped<T, 2>({num_indices, num_updates / num_indices});
functor::ScatterFunctor<Device, T, Index, op> functor;
const Index bad_i = functor(c, c->template eigen_device<Device>(),
params_flat, updates_flat, indices_flat);
if (bad_i >= 0) {
return errors::InvalidArgument(
"indices", SliceDebugString(indices.shape(), bad_i), " = ",
indices_flat(bad_i), " is not in [0, ", params->dim_size(0), ")");
}
return Status::OK();
}
} // namespace
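// Implements the ResourceScatter* family of ops: applies the sparse update
// `op` with `updates` at `indices` to the variable's tensor. For non-POD
// dtypes, or when use_locking=true, the update runs under an exclusive lock;
// otherwise only a shared lock is held, per the locking strategy described at
// the top of this file.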
template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
class ResourceScatterUpdateOp : public OpKernel {
public:
explicit ResourceScatterUpdateOp(OpKernelConstruction* c) : OpKernel(c) {
// We use the same kernel for many operations.
// Each operation has a different set of attributes defined in its nodes.
Status s = c->GetAttr("use_locking", &use_exclusive_lock_);
if (!s.ok()) {
use_exclusive_lock_ = false;
}
}
void Compute(OpKernelContext* c) override {
core::RefCountPtr<Var> v;
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
const bool is_non_pod_dtype = c->input_dtype(0) == DT_RESOURCE ||
c->input_dtype(0) == DT_STRING ||
c->input_dtype(0) == DT_VARIANT;
if (is_non_pod_dtype || use_exclusive_lock_) {
mutex_lock ml(*v->mu());
DoCompute(c);
} else {
// For POD dtypes, we can safely run the update while holding the mutex only
// in shared mode.
tf_shared_lock ml(*v->mu());
DoCompute(c);
}
}
private:
bool use_exclusive_lock_;
void DoCompute(OpKernelContext* c) {
core::RefCountPtr<Var> v;
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
Tensor* params = v->tensor();
const Tensor& indices = c->input(1);
const Tensor& updates = c->input(2);
// Check that rank(updates.shape) = rank(indices.shape + params.shape[1:])
OP_REQUIRES(c,
updates.dims() == 0 ||
updates.dims() == indices.dims() + params->dims() - 1,
errors::InvalidArgument(
"Must have updates.shape = indices.shape + "
"params.shape[1:] or updates.shape = [], got ",
"updates.shape ", updates.shape().DebugString(),
", indices.shape ", indices.shape().DebugString(),
", params.shape ", params->shape().DebugString()));
// Check that we have enough index space
const int64_t N_big = indices.NumElements();
OP_REQUIRES(
c, N_big <= std::numeric_limits<Index>::max(),
errors::InvalidArgument("indices has too many elements for ",
DataTypeString(DataTypeToEnum<Index>::v()),
" indexing: ", N_big, " > ",
std::numeric_limits<Index>::max()));
const Index N = static_cast<Index>(N_big);
OP_REQUIRES(
c, params->dim_size(0) <= std::numeric_limits<Index>::max(),
errors::InvalidArgument("params.shape[0] too large for ",
DataTypeString(DataTypeToEnum<Index>::v()),
" indexing: ", params->dim_size(0), " > ",
std::numeric_limits<Index>::max()));
// Prevent division by 0
if (isCPUDevice<Device>() && op == tensorflow::scatter_op::UpdateOp::DIV) {
OP_REQUIRES(c, ValidateInput<T>(updates),
errors::InvalidArgument("updates must not contain 0"));
}
if (N > 0) {
auto indices_flat = indices.flat<Index>();
auto params_flat = params->flat_outer_dims<T>();
if (TensorShapeUtils::IsScalar(updates.shape())) {
const auto update = updates.scalar<T>();
functor::ScatterScalarFunctor<Device, T, Index, op> functor;
const Index bad_i = functor(c, c->template eigen_device<Device>(),
params_flat, update, indices_flat);
OP_REQUIRES(c, bad_i < 0,
errors::InvalidArgument(
"indices", SliceDebugString(indices.shape(), bad_i),
" = ", indices_flat(bad_i), " is not in [0, ",
params->dim_size(0), ")"));
} else {
OP_REQUIRES(
c, TensorShapeUtils::StartsWith(updates.shape(), indices.shape()),
errors::InvalidArgument(
"The shape of indices (", indices.shape().DebugString(),
") must be a prefix of the shape of updates (",
updates.shape().DebugString(), ")"));
OP_REQUIRES_OK(
c, DoScatter<Device, T, Index, op>(c, params, indices, updates, N));
}
}
}
};
#define REGISTER_SCATTER_KERNEL_INDEX(type, index_type, dev, name, op) \
REGISTER_KERNEL_BUILDER( \
Name(name) \
.Device(DEVICE_##dev) \
.HostMemory("resource") \
.TypeConstraint<type>("dtype") \
.TypeConstraint<index_type>("Tindices"), \
ResourceScatterUpdateOp<dev##Device, type, index_type, op>)
#define REGISTER_SCATTER_KERNEL(type, dev, name, op) \
REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \
REGISTER_SCATTER_KERNEL_INDEX(type, int64_t, dev, name, op);
#define REGISTER_SCATTER_ARITHMETIC(type, dev) \
REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterAdd", \
scatter_op::UpdateOp::ADD); \
REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterSub", \
scatter_op::UpdateOp::SUB); \
REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMul", \
scatter_op::UpdateOp::MUL); \
REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterDiv", \
scatter_op::UpdateOp::DIV); \
REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterUpdate", \
scatter_op::UpdateOp::ASSIGN);
#define REGISTER_SCATTER_MINMAX(type, dev) \
REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMin", \
scatter_op::UpdateOp::MIN); \
REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMax", \
scatter_op::UpdateOp::MAX);
// Registers CPU kernels.
#define REGISTER_SCATTER_ARITHMETIC_CPU(type) \
REGISTER_SCATTER_ARITHMETIC(type, CPU);
#define REGISTER_SCATTER_MINMAX_CPU(type) REGISTER_SCATTER_MINMAX(type, CPU);
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_CPU);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_CPU);
REGISTER_SCATTER_KERNEL(tstring, CPU, "ResourceScatterUpdate",
scatter_op::UpdateOp::ASSIGN);
REGISTER_SCATTER_KERNEL(bool, CPU, "ResourceScatterUpdate",
scatter_op::UpdateOp::ASSIGN);
REGISTER_SCATTER_KERNEL(Variant, CPU, "ResourceScatterUpdate",
scatter_op::UpdateOp::ASSIGN);
// Registers GPU kernels.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_SCATTER_ARITHMETIC_GPU(type) \
REGISTER_SCATTER_ARITHMETIC(type, GPU);
#define REGISTER_SCATTER_MINMAX_GPU(type) REGISTER_SCATTER_MINMAX(type, GPU);
#define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_GPU);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_GPU);
REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
.Device(DEVICE_DEFAULT)
.HostMemory("resource")
.HostMemory("indices")
.TypeConstraint<Variant>("dtype")
.TypeConstraint<int32>("Tindices"),
ResourceScatterUpdateOp<CPUDevice, Variant, int32,
scatter_op::UpdateOp::ASSIGN>)
REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
.Device(DEVICE_GPU)
.HostMemory("resource")
.TypeConstraint<bool>("dtype")
.TypeConstraint<int32>("Tindices"),
ResourceScatterUpdateOp<GPUDevice, bool, int32,
scatter_op::UpdateOp::ASSIGN>)
REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
.Device(DEVICE_DEFAULT)
.HostMemory("resource")
.HostMemory("indices")
.TypeConstraint<Variant>("dtype")
.TypeConstraint<int64_t>("Tindices"),
ResourceScatterUpdateOp<CPUDevice, Variant, int64,
scatter_op::UpdateOp::ASSIGN>)
REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
.Device(DEVICE_GPU)
.HostMemory("resource")
.TypeConstraint<int64_t>("dtype")
.TypeConstraint<int64_t>("Tindices"),
ResourceScatterUpdateOp<GPUDevice, int64, int64,
scatter_op::UpdateOp::ASSIGN>)
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#undef REGISTER_SCATTER_ARITHMETIC
#undef REGISTER_SCATTER_ARITHMETIC_CPU
#undef REGISTER_SCATTER_MINMAX
#undef REGISTER_SCATTER_MINMAX_CPU
#undef REGISTER_SCATTER_KERNEL
#undef REGISTER_SCATTER_KERNEL_INDEX
} // namespace tensorflow