/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
static ShapeHandle ShapeOrHandleShape(InferenceContext* c, int input) {
auto* handle_data = c->input_handle_shapes_and_types(input);
if (handle_data != nullptr && !handle_data->empty() &&
(*handle_data)[0].dtype != DT_INVALID) {
return (*handle_data)[0].shape;
}
return c->input(input);
}
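// Example (a sketch): for a `resource` input whose handle carries shape and
// dtype, e.g. a DT_FLOAT variable of shape [10, 2], ShapeOrHandleShape(c, 0)
// returns that [10, 2] shape; for a Ref(T) or plain tensor input there is no
// handle data, so it falls back to the inferred shape of the input itself.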
// Handle the gradient and, if <sparse>, indices inputs.
// <s> is an input+output parameter, containing the current known input shape to
// the gradient.
static Status HandleGradAndIndicesInputs(InferenceContext* c, bool sparse,
int grad_idx, ShapeHandle* s) {
ShapeHandle grad = ShapeOrHandleShape(c, grad_idx);
if (!sparse) {
TF_RETURN_IF_ERROR(c->Merge(*s, grad, s));
return Status::OK();
}
// Indices is a vector whose length (dim 0) must match dim 0 of grad.
ShapeHandle indices;
TF_RETURN_IF_ERROR(c->WithRank(c->input(grad_idx + 1), 1, &indices));
DimensionHandle unused;
TF_RETURN_IF_ERROR(c->Merge(c->Dim(indices, 0), c->Dim(grad, 0), &unused));
// Trailing part of grad matches trailing part of *s.
ShapeHandle grad_unknown_first;
TF_RETURN_IF_ERROR(
c->ReplaceDim(grad, 0, c->UnknownDim(), &grad_unknown_first));
TF_RETURN_IF_ERROR(c->Merge(*s, grad_unknown_first, s));
return Status::OK();
}
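// Worked example for the sparse case: with *s == [?, 2] (var), grad == [5, 2]
// and indices == [5], the rank-1 check on indices passes, Dim(indices, 0)
// merges with Dim(grad, 0) (both 5), and merging [?, 2] with grad-with-
// unknown-dim-0 (i.e. [?, 2]) leaves *s == [?, 2]. A grad of shape [5, 3]
// would fail the final Merge because the trailing dimensions disagree.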
static Status ApplyGradientDescentShapeFn(InferenceContext* c) {
ShapeHandle unused;
ShapeHandle s = ShapeOrHandleShape(c, 0); // var
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); // alpha
TF_RETURN_IF_ERROR(c->Merge(s, c->input(2), &s)); // delta
if (c->num_outputs() > 0) {
c->set_output(0, s);
}
return Status::OK();
}
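// For context, the documented update applied by the ops below is
//   var -= alpha * delta
// so the shape fn requires a scalar alpha and a delta matching var's shape;
// the merged shape is forwarded to the output when one exists (the Ref(T)
// variant has an `out` output, the resource variant does not).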
REGISTER_OP("ApplyGradientDescent")
.Input("var: Ref(T)")
.Input("alpha: T")
.Input("delta: T")
.Output("out: Ref(T)")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
.SetShapeFn(ApplyGradientDescentShapeFn);
REGISTER_OP("ResourceApplyGradientDescent")
.Input("var: resource")
.Input("alpha: T")
.Input("delta: T")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
.SetShapeFn(ApplyGradientDescentShapeFn);
static Status ApplyProximalGradientDescentShapeFn(InferenceContext* c,
bool sparse) {
ShapeHandle unused;
ShapeHandle s = ShapeOrHandleShape(c, 0); // var
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); // alpha
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); // l1
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); // l2
TF_RETURN_IF_ERROR(
HandleGradAndIndicesInputs(c, sparse, 4 /* grad_idx */, &s));
if (c->num_outputs() > 0) {
c->set_output(0, s);
}
return Status::OK();
}
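// For context, the documented proximal (FOBOS) update for these ops is
// roughly:
//   prox_v = var - alpha * delta
//   var = sign(prox_v) / (1 + alpha * l2) * max{|prox_v| - alpha * l1, 0}
// alpha, l1 and l2 are scalars; delta (or grad/indices when sparse) must be
// compatible with var's shape, which is what the checks above enforce.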
REGISTER_OP("ApplyProximalGradientDescent")
.Input("var: Ref(T)")
.Input("alpha: T")
.Input("l1: T")
.Input("l2: T")
.Input("delta: T")
.Output("out: Ref(T)")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyProximalGradientDescentShapeFn(c, false /* sparse */);
});
REGISTER_OP("SparseApplyProximalGradientDescent")
.Input("var: Ref(T)")
.Input("alpha: T")
.Input("l1: T")
.Input("l2: T")
.Input("grad: T")
.Input("indices: Tindices")
.Output("out: Ref(T)")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyProximalGradientDescentShapeFn(c, true /* sparse */);
});
REGISTER_OP("ResourceApplyProximalGradientDescent")
.Input("var: resource")
.Input("alpha: T")
.Input("l1: T")
.Input("l2: T")
.Input("delta: T")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyProximalGradientDescentShapeFn(c, false /* sparse */);
});
REGISTER_OP("ResourceSparseApplyProximalGradientDescent")
.Input("var: resource")
.Input("alpha: T")
.Input("l1: T")
.Input("l2: T")
.Input("grad: T")
.Input("indices: Tindices")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyProximalGradientDescentShapeFn(c, true /* sparse */);
});
static Status ApplyAdadeltaShapeFn(InferenceContext* c, bool sparse) {
ShapeHandle unused;
ShapeHandle s = ShapeOrHandleShape(c, 0); // var
TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s)); // accum
TF_RETURN_IF_ERROR(
c->Merge(s, ShapeOrHandleShape(c, 2), &s)); // accum_update
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); // lr
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); // rho
TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); // epsilon
TF_RETURN_IF_ERROR(
HandleGradAndIndicesInputs(c, sparse, 6 /* grad_idx */, &s));
if (c->num_outputs() > 0) {
c->set_output(0, s);
}
return Status::OK();
}
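// For context, the documented Adadelta update is roughly (element-wise):
//   accum = rho * accum + (1 - rho) * grad^2
//   update = sqrt(accum_update + epsilon) / sqrt(accum + epsilon) * grad
//   accum_update = rho * accum_update + (1 - rho) * update^2
//   var -= lr * update
// so var, accum and accum_update must agree in shape, while lr, rho and
// epsilon are scalars.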
REGISTER_OP("ApplyAdadelta")
.Input("var: Ref(T)")
.Input("accum: Ref(T)")
.Input("accum_update: Ref(T)")
.Input("lr: T")
.Input("rho: T")
.Input("epsilon: T")
.Input("grad: T")
.Output("out: Ref(T)")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyAdadeltaShapeFn(c, false /* sparse */);
});
REGISTER_OP("SparseApplyAdadelta")
.Input("var: Ref(T)")
.Input("accum: Ref(T)")
.Input("accum_update: Ref(T)")
.Input("lr: T")
.Input("rho: T")
.Input("epsilon: T")
.Input("grad: T")
.Input("indices: Tindices")
.Output("out: Ref(T)")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyAdadeltaShapeFn(c, true /* sparse */);
});
REGISTER_OP("ResourceApplyAdadelta")
.Input("var: resource")
.Input("accum: resource")
.Input("accum_update: resource")
.Input("lr: T")
.Input("rho: T")
.Input("epsilon: T")
.Input("grad: T")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyAdadeltaShapeFn(c, false /* sparse */);
});
REGISTER_OP("ResourceSparseApplyAdadelta")
.Input("var: resource")
.Input("accum: resource")
.Input("accum_update: resource")
.Input("lr: T")
.Input("rho: T")
.Input("epsilon: T")
.Input("grad: T")
.Input("indices: Tindices")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyAdadeltaShapeFn(c, true /* sparse */);
});
static Status ApplyAdagradShapeFn(InferenceContext* c, bool sparse) {
ShapeHandle unused;
ShapeHandle s = ShapeOrHandleShape(c, 0); // var
TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s)); // accum
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); // lr
TF_RETURN_IF_ERROR(
HandleGradAndIndicesInputs(c, sparse, 3 /* grad_idx */, &s));
if (c->num_outputs() > 0) {
c->set_output(0, s);
}
return Status::OK();
}
static Status ApplyAdagradV2ShapeFn(InferenceContext* c, bool sparse) {
ShapeHandle unused;
ShapeHandle s = ShapeOrHandleShape(c, 0); // var
TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s)); // accum
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); // lr
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); // epsilon
TF_RETURN_IF_ERROR(
HandleGradAndIndicesInputs(c, sparse, 4 /* grad_idx */, &s));
if (c->num_outputs() > 0) {
c->set_output(0, s);
}
return Status::OK();
}
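// The two Adagrad variants differ only in how epsilon enters the documented
// update:
//   accum += grad^2
//   Adagrad:   var -= lr * grad / sqrt(accum)
//   AdagradV2: var -= lr * grad / (sqrt(accum) + epsilon)
// which is why the V2 shape fn checks one extra scalar and grad shifts from
// input index 3 to index 4.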
REGISTER_OP("ApplyAdagrad")
.Input("var: Ref(T)")
.Input("accum: Ref(T)")
.Input("lr: T")
.Input("grad: T")
.Output("out: Ref(T)")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
.Attr("update_slots: bool = true")
.SetShapeFn([](InferenceContext* c) {
return ApplyAdagradShapeFn(c, false /* sparse */);
});
REGISTER_OP("ResourceApplyAdagrad")
.Input("var: resource")
.Input("accum: resource")
.Input("lr: T")
.Input("grad: T")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
.Attr("update_slots: bool = true")
.SetShapeFn([](InferenceContext* c) {
return ApplyAdagradShapeFn(c, false /* sparse */);
});
REGISTER_OP("ApplyAdagradV2")
.Input("var: Ref(T)")
.Input("accum: Ref(T)")
.Input("lr: T")
.Input("epsilon: T")
.Input("grad: T")
.Output("out: Ref(T)")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
.Attr("update_slots: bool = true")
.SetShapeFn([](InferenceContext* c) {
return ApplyAdagradV2ShapeFn(c, false /* sparse */);
});
REGISTER_OP("ResourceApplyAdagradV2")
.Input("var: resource")
.Input("accum: resource")
.Input("lr: T")
.Input("epsilon: T")
.Input("grad: T")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
.Attr("update_slots: bool = true")
.SetShapeFn([](InferenceContext* c) {
return ApplyAdagradV2ShapeFn(c, false /* sparse */);
});
static Status ApplyProximalAdagradShapeFn(InferenceContext* c, bool sparse) {
ShapeHandle unused;
ShapeHandle s = ShapeOrHandleShape(c, 0); // var
TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s)); // accum
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); // lr
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); // l1
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); // l2
TF_RETURN_IF_ERROR(
HandleGradAndIndicesInputs(c, sparse, 5 /* grad_idx */, &s));
if (c->num_outputs() > 0) {
c->set_output(0, s);
}
return Status::OK();
}
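// For context, the documented proximal Adagrad update is roughly:
//   accum += grad^2
//   prox_v = var - lr * grad / sqrt(accum)
//   var = sign(prox_v) / (1 + lr * l2) * max{|prox_v| - lr * l1, 0}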
REGISTER_OP("ApplyProximalAdagrad")
.Input("var: Ref(T)")
.Input("accum: Ref(T)")
.Input("lr: T")
.Input("l1: T")
.Input("l2: T")
.Input("grad: T")
.Output("out: Ref(T)")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyProximalAdagradShapeFn(c, false /* sparse */);
});
REGISTER_OP("ResourceApplyProximalAdagrad")
.Input("var: resource")
.Input("accum: resource")
.Input("lr: T")
.Input("l1: T")
.Input("l2: T")
.Input("grad: T")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyProximalAdagradShapeFn(c, false /* sparse */);
});
REGISTER_OP("SparseApplyAdagrad")
.Input("var: Ref(T)")
.Input("accum: Ref(T)")
.Input("lr: T")
.Input("grad: T")
.Input("indices: Tindices")
.Output("out: Ref(T)")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
.Attr("update_slots: bool = true")
.SetShapeFn([](InferenceContext* c) {
return ApplyAdagradShapeFn(c, true /* sparse */);
});
REGISTER_OP("ResourceSparseApplyAdagrad")
.Input("var: resource")
.Input("accum: resource")
.Input("lr: T")
.Input("grad: T")
.Input("indices: Tindices")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
.Attr("update_slots: bool = true")
.SetShapeFn([](InferenceContext* c) {
return ApplyAdagradShapeFn(c, true /* sparse */);
});
REGISTER_OP("SparseApplyAdagradV2")
.Input("var: Ref(T)")
.Input("accum: Ref(T)")
.Input("lr: T")
.Input("epsilon: T")
.Input("grad: T")
.Input("indices: Tindices")
.Output("out: Ref(T)")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
.Attr("update_slots: bool = true")
.SetShapeFn([](InferenceContext* c) {
return ApplyAdagradV2ShapeFn(c, true /* sparse */);
});
REGISTER_OP("ResourceSparseApplyAdagradV2")
.Input("var: resource")
.Input("accum: resource")
.Input("lr: T")
.Input("epsilon: T")
.Input("grad: T")
.Input("indices: Tindices")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
.Attr("update_slots: bool = true")
.SetShapeFn([](InferenceContext* c) {
return ApplyAdagradV2ShapeFn(c, true /* sparse */);
});
static Status ApplyAdagradDAShapeFn(InferenceContext* c, bool sparse) {
ShapeHandle unused;
ShapeHandle s = ShapeOrHandleShape(c, 0); // var
TF_RETURN_IF_ERROR(
c->Merge(s, ShapeOrHandleShape(c, 1), &s)); // grad_accumulator
TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 2),
&s)); // gradient_squared_accumulator
TF_RETURN_IF_ERROR(
HandleGradAndIndicesInputs(c, sparse, 3 /* grad_idx */, &s));
int idx = sparse ? 5 : 4;
TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused)); // lr
TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused)); // l1
TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused)); // l2
TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused)); // global_step
if (c->num_outputs() > 0) {
c->set_output(0, s);
}
return Status::OK();
}
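// Note the input ordering here: grad (and indices, when sparse) precede the
// scalar hyperparameters, so the scalar block starts at input index 4 for
// the dense ops and 5 for the sparse ops. In SparseApplyAdagradDA, for
// example, inputs 5..8 are lr, l1, l2 and global_step.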
REGISTER_OP("ApplyAdagradDA")
.Input("var: Ref(T)")
.Input("gradient_accumulator: Ref(T)")
.Input("gradient_squared_accumulator: Ref(T)")
.Input("grad: T")
.Input("lr: T")
.Input("l1: T")
.Input("l2: T")
.Input("global_step: int64")
.Output("out: Ref(T)")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyAdagradDAShapeFn(c, false /* sparse */);
});
REGISTER_OP("SparseApplyAdagradDA")
.Input("var: Ref(T)")
.Input("gradient_accumulator: Ref(T)")
.Input("gradient_squared_accumulator: Ref(T)")
.Input("grad: T")
.Input("indices: Tindices")
.Input("lr: T")
.Input("l1: T")
.Input("l2: T")
.Input("global_step: int64")
.Output("out: Ref(T)")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyAdagradDAShapeFn(c, true /* sparse */);
});
REGISTER_OP("SparseApplyProximalAdagrad")
.Input("var: Ref(T)")
.Input("accum: Ref(T)")
.Input("lr: T")
.Input("l1: T")
.Input("l2: T")
.Input("grad: T")
.Input("indices: Tindices")
.Output("out: Ref(T)")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyProximalAdagradShapeFn(c, true /* sparse */);
});
REGISTER_OP("ResourceApplyAdagradDA")
.Input("var: resource")
.Input("gradient_accumulator: resource")
.Input("gradient_squared_accumulator: resource")
.Input("grad: T")
.Input("lr: T")
.Input("l1: T")
.Input("l2: T")
.Input("global_step: int64")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyAdagradDAShapeFn(c, false /* sparse */);
});
REGISTER_OP("ResourceSparseApplyAdagradDA")
.Input("var: resource")
.Input("gradient_accumulator: resource")
.Input("gradient_squared_accumulator: resource")
.Input("grad: T")
.Input("indices: Tindices")
.Input("lr: T")
.Input("l1: T")
.Input("l2: T")
.Input("global_step: int64")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyAdagradDAShapeFn(c, true /* sparse */);
});
REGISTER_OP("ResourceSparseApplyProximalAdagrad")
.Input("var: resource")
.Input("accum: resource")
.Input("lr: T")
.Input("l1: T")
.Input("l2: T")
.Input("grad: T")
.Input("indices: Tindices")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyProximalAdagradShapeFn(c, true /* sparse */);
});
static Status ApplyFtrlShapeFn(InferenceContext* c, bool sparse) {
ShapeHandle unused;
ShapeHandle s = ShapeOrHandleShape(c, 0); // var
TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s)); // accum
TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 2), &s)); // linear
TF_RETURN_IF_ERROR(
HandleGradAndIndicesInputs(c, sparse, 3 /* grad_idx */, &s));
int idx = sparse ? 5 : 4;
TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused)); // lr
TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused)); // l1
TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused)); // l2
TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused)); // lr_power
if (c->num_outputs() > 0) {
c->set_output(0, s);
}
return Status::OK();
}
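// This shape fn is also shared by the *FtrlV2 ops below. It rank-checks
// exactly four scalars after grad, so for the V2 variants (which insert
// l2_shrinkage before lr_power) the checked scalars are lr, l1, l2 and
// l2_shrinkage, and lr_power's rank is not verified.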
REGISTER_OP("ApplyFtrl")
.Input("var: Ref(T)")
.Input("accum: Ref(T)")
.Input("linear: Ref(T)")
.Input("grad: T")
.Input("lr: T")
.Input("l1: T")
.Input("l2: T")
.Input("lr_power: T")
.Output("out: Ref(T)")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyFtrlShapeFn(c, false /* sparse */);
});
REGISTER_OP("SparseApplyFtrl")
.Input("var: Ref(T)")
.Input("accum: Ref(T)")
.Input("linear: Ref(T)")
.Input("grad: T")
.Input("indices: Tindices")
.Input("lr: T")
.Input("l1: T")
.Input("l2: T")
.Input("lr_power: T")
.Output("out: Ref(T)")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyFtrlShapeFn(c, true /* sparse */);
});
REGISTER_OP("ResourceApplyFtrl")
.Input("var: resource")
.Input("accum: resource")
.Input("linear: resource")
.Input("grad: T")
.Input("lr: T")
.Input("l1: T")
.Input("l2: T")
.Input("lr_power: T")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyFtrlShapeFn(c, false /* sparse */);
});
REGISTER_OP("ResourceSparseApplyFtrl")
.Input("var: resource")
.Input("accum: resource")
.Input("linear: resource")
.Input("grad: T")
.Input("indices: Tindices")
.Input("lr: T")
.Input("l1: T")
.Input("l2: T")
.Input("lr_power: T")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyFtrlShapeFn(c, true /* sparse */);
});
REGISTER_OP("ApplyFtrlV2")
.Input("var: Ref(T)")
.Input("accum: Ref(T)")
.Input("linear: Ref(T)")
.Input("grad: T")
.Input("lr: T")
.Input("l1: T")
.Input("l2: T")
.Input("l2_shrinkage: T")
.Input("lr_power: T")
.Output("out: Ref(T)")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyFtrlShapeFn(c, false /* sparse */);
});
REGISTER_OP("SparseApplyFtrlV2")
.Input("var: Ref(T)")
.Input("accum: Ref(T)")
.Input("linear: Ref(T)")
.Input("grad: T")
.Input("indices: Tindices")
.Input("lr: T")
.Input("l1: T")
.Input("l2: T")
.Input("l2_shrinkage: T")
.Input("lr_power: T")
.Output("out: Ref(T)")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyFtrlShapeFn(c, true /* sparse */);
});
REGISTER_OP("ResourceApplyFtrlV2")
.Input("var: resource")
.Input("accum: resource")
.Input("linear: resource")
.Input("grad: T")
.Input("lr: T")
.Input("l1: T")
.Input("l2: T")
.Input("l2_shrinkage: T")
.Input("lr_power: T")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyFtrlShapeFn(c, false /* sparse */);
});
REGISTER_OP("ResourceSparseApplyFtrlV2")
.Input("var: resource")
.Input("accum: resource")
.Input("linear: resource")
.Input("grad: T")
.Input("indices: Tindices")
.Input("lr: T")
.Input("l1: T")
.Input("l2: T")
.Input("l2_shrinkage: T")
.Input("lr_power: T")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyFtrlShapeFn(c, true /* sparse */);
});
static Status ApplyMomentumShapeFn(InferenceContext* c, bool sparse) {
ShapeHandle unused;
ShapeHandle s = ShapeOrHandleShape(c, 0); // var
TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s)); // accum
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); // lr
TF_RETURN_IF_ERROR(
HandleGradAndIndicesInputs(c, sparse, 3 /* grad_idx */, &s));
int idx = sparse ? 5 : 4;
TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused)); // momentum
if (c->num_outputs() > 0) {
c->set_output(0, s);
}
return Status::OK();
}
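// For context, the documented momentum update is roughly:
//   accum = accum * momentum + grad
//   var -= lr * accum                         (use_nesterov = false)
//   var -= lr * grad + lr * momentum * accum  (use_nesterov = true)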
REGISTER_OP("ApplyMomentum")
.Input("var: Ref(T)")
.Input("accum: Ref(T)")
.Input("lr: T")
.Input("grad: T")
.Input("momentum: T")
.Output("out: Ref(T)")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
.Attr("use_nesterov: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyMomentumShapeFn(c, false /* sparse */);
});
REGISTER_OP("SparseApplyMomentum")
.Input("var: Ref(T)")
.Input("accum: Ref(T)")
.Input("lr: T")
.Input("grad: T")
.Input("indices: Tindices")
.Input("momentum: T")
.Output("out: Ref(T)")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
.Attr("use_nesterov: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyMomentumShapeFn(c, true /* sparse */);
});
REGISTER_OP("ResourceApplyMomentum")
.Input("var: resource")
.Input("accum: resource")
.Input("lr: T")
.Input("grad: T")
.Input("momentum: T")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
.Attr("use_nesterov: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyMomentumShapeFn(c, false /* sparse */);
});
REGISTER_OP("ResourceSparseApplyMomentum")
.Input("var: resource")
.Input("accum: resource")
.Input("lr: T")
.Input("grad: T")
.Input("indices: Tindices")
.Input("momentum: T")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
.Attr("use_nesterov: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyMomentumShapeFn(c, true /* sparse */);
});
REGISTER_OP("ResourceApplyKerasMomentum")
.Input("var: resource")
.Input("accum: resource")
.Input("lr: T")
.Input("grad: T")
.Input("momentum: T")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
.Attr("use_nesterov: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyMomentumShapeFn(c, false /* sparse */);
});
REGISTER_OP("ResourceSparseApplyKerasMomentum")
.Input("var: resource")
.Input("accum: resource")
.Input("lr: T")
.Input("grad: T")
.Input("indices: Tindices")
.Input("momentum: T")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
.Attr("use_nesterov: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyMomentumShapeFn(c, true /* sparse */);
});
static Status ApplyAdamShapeFn(InferenceContext* c, bool sparse) {
ShapeHandle unused;
ShapeHandle s = ShapeOrHandleShape(c, 0); // var
TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s)); // m
TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 2), &s)); // v
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); // beta1_power
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); // beta2_power
TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); // lr
TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); // beta1
TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); // beta2
TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused)); // epsilon
TF_RETURN_IF_ERROR(
HandleGradAndIndicesInputs(c, sparse, 9 /* grad_idx */, &s));
if (c->num_outputs() > 0) {
c->set_output(0, s);
}
return Status::OK();
}
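// For context, the documented Adam update is roughly (beta1_power and
// beta2_power are the running beta1^t and beta2^t supplied by the caller):
//   lr_t = lr * sqrt(1 - beta2_power) / (1 - beta1_power)
//   m = beta1 * m + (1 - beta1) * grad
//   v = beta2 * v + (1 - beta2) * grad^2
//   var -= lr_t * m / (sqrt(v) + epsilon)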
REGISTER_OP("ApplyAdam")
.Input("var: Ref(T)")
.Input("m: Ref(T)")
.Input("v: Ref(T)")
.Input("beta1_power: T")
.Input("beta2_power: T")
.Input("lr: T")
.Input("beta1: T")
.Input("beta2: T")
.Input("epsilon: T")
.Input("grad: T")
.Output("out: Ref(T)")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
.Attr("use_nesterov: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyAdamShapeFn(c, false /* sparse */);
});
REGISTER_OP("ResourceApplyAdam")
.Input("var: resource")
.Input("m: resource")
.Input("v: resource")
.Input("beta1_power: T")
.Input("beta2_power: T")
.Input("lr: T")
.Input("beta1: T")
.Input("beta2: T")
.Input("epsilon: T")
.Input("grad: T")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
.Attr("use_nesterov: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyAdamShapeFn(c, false /* sparse */);
});
static Status ApplyAdamWithAmsgradShapeFn(InferenceContext* c, bool sparse) {
ShapeHandle unused;
ShapeHandle s = ShapeOrHandleShape(c, 0); // var
TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s)); // m
TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 2), &s)); // v
TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 3), &s)); // vhat
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); // beta1_power
TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); // beta2_power
TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); // lr
TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); // beta1
TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused)); // beta2
TF_RETURN_IF_ERROR(c->WithRank(c->input(9), 0, &unused)); // epsilon
TF_RETURN_IF_ERROR(
HandleGradAndIndicesInputs(c, sparse, 10 /* grad_idx */, &s));
if (c->num_outputs() > 0) {
c->set_output(0, s);
}
return Status::OK();
}
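// Same as the Adam update above, except the AMSGrad variant keeps the
// running maximum of v and uses it in the denominator:
//   vhat = max(vhat, v)
//   var -= lr_t * m / (sqrt(vhat) + epsilon)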
REGISTER_OP("ResourceApplyAdamWithAmsgrad")
.Input("var: resource")
.Input("m: resource")
.Input("v: resource")
.Input("vhat: resource")
.Input("beta1_power: T")
.Input("beta2_power: T")
.Input("lr: T")
.Input("beta1: T")
.Input("beta2: T")
.Input("epsilon: T")
.Input("grad: T")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyAdamWithAmsgradShapeFn(c, false /* sparse */);
});
static Status ApplyAdaMaxShapeFn(InferenceContext* c, bool sparse) {
ShapeHandle unused;
ShapeHandle s = ShapeOrHandleShape(c, 0); // var
TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s)); // m
TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 2), &s)); // v
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); // beta1_power
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); // lr
TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); // beta1
TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); // beta2
TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); // epsilon
TF_RETURN_IF_ERROR(
HandleGradAndIndicesInputs(c, sparse, 8 /* grad_idx */, &s));
if (c->num_outputs() > 0) {
c->set_output(0, s);
}
return Status::OK();
}
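// For context, AdaMax is the infinity-norm variant of Adam (hence no
// beta2_power input); the documented update is roughly:
//   m = beta1 * m + (1 - beta1) * grad
//   v = max(beta2 * v, |grad|)
//   var -= lr / (1 - beta1_power) * m / (v + epsilon)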
REGISTER_OP("ApplyAdaMax")
.Input("var: Ref(T)")
.Input("m: Ref(T)")
.Input("v: Ref(T)")
.Input("beta1_power: T")
.Input("lr: T")
.Input("beta1: T")
.Input("beta2: T")
.Input("epsilon: T")
.Input("grad: T")
.Output("out: Ref(T)")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyAdaMaxShapeFn(c, false /* sparse */);
});
REGISTER_OP("ResourceApplyAdaMax")
.Input("var: resource")
.Input("m: resource")
.Input("v: resource")
.Input("beta1_power: T")
.Input("lr: T")
.Input("beta1: T")
.Input("beta2: T")
.Input("epsilon: T")
.Input("grad: T")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyAdaMaxShapeFn(c, false /* sparse */);
});
static Status ApplyRMSPropShapeFn(InferenceContext* c, bool sparse) {
ShapeHandle unused;
ShapeHandle s = ShapeOrHandleShape(c, 0); // var
TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s)); // ms
TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 2), &s)); // mom
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); // lr
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); // rho
TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); // momentum
TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); // epsilon
TF_RETURN_IF_ERROR(
HandleGradAndIndicesInputs(c, sparse, 7 /* grad_idx */, &s));
if (c->num_outputs() > 0) {
c->set_output(0, s);
}
return Status::OK();
}
static Status ApplyCenteredRMSPropShapeFn(InferenceContext* c, bool sparse) {
ShapeHandle unused;
ShapeHandle s = ShapeOrHandleShape(c, 0); // var
TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s)); // mg
TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 2), &s)); // ms
TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 3), &s)); // mom
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); // lr
TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); // rho
TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); // momentum
TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); // epsilon
TF_RETURN_IF_ERROR(
HandleGradAndIndicesInputs(c, sparse, 8 /* grad_idx */, &s));
if (c->num_outputs() > 0) {
c->set_output(0, s);
}
return Status::OK();
}
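// For context, the documented RMSProp update is roughly:
//   ms = rho * ms + (1 - rho) * grad^2
//   mom = momentum * mom + lr * grad / sqrt(ms + epsilon)
//   var -= mom
// The centered variant additionally tracks the mean gradient mg and divides
// by sqrt(ms - mg^2 + epsilon) instead, which is why it merges one extra
// var-shaped input and grad moves from input index 7 to index 8.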
REGISTER_OP("ApplyRMSProp")
.Input("var: Ref(T)")
.Input("ms: Ref(T)")
.Input("mom: Ref(T)")
.Input("lr: T")
.Input("rho: T")
.Input("momentum: T")
.Input("epsilon: T")
.Input("grad: T")
.Output("out: Ref(T)")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyRMSPropShapeFn(c, false /* sparse */);
});
REGISTER_OP("ApplyCenteredRMSProp")
.Input("var: Ref(T)")
.Input("mg: Ref(T)")
.Input("ms: Ref(T)")
.Input("mom: Ref(T)")
.Input("lr: T")
.Input("rho: T")
.Input("momentum: T")
.Input("epsilon: T")
.Input("grad: T")
.Output("out: Ref(T)")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyCenteredRMSPropShapeFn(c, false /* sparse */);
});
REGISTER_OP("SparseApplyRMSProp")
.Input("var: Ref(T)")
.Input("ms: Ref(T)")
.Input("mom: Ref(T)")
.Input("lr: T")
.Input("rho: T")
.Input("momentum: T")
.Input("epsilon: T")
.Input("grad: T")
.Input("indices: Tindices")
.Output("out: Ref(T)")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyRMSPropShapeFn(c, true /* sparse */);
});
REGISTER_OP("SparseApplyCenteredRMSProp")
.Input("var: Ref(T)")
.Input("mg: Ref(T)")
.Input("ms: Ref(T)")
.Input("mom: Ref(T)")
.Input("lr: T")
.Input("rho: T")
.Input("momentum: T")
.Input("epsilon: T")
.Input("grad: T")
.Input("indices: Tindices")
.Output("out: Ref(T)")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyCenteredRMSPropShapeFn(c, true /* sparse */);
});
REGISTER_OP("ResourceApplyRMSProp")
.Input("var: resource")
.Input("ms: resource")
.Input("mom: resource")
.Input("lr: T")
.Input("rho: T")
.Input("momentum: T")
.Input("epsilon: T")
.Input("grad: T")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyRMSPropShapeFn(c, false /* sparse */);
});
REGISTER_OP("ResourceApplyCenteredRMSProp")
.Input("var: resource")
.Input("mg: resource")
.Input("ms: resource")
.Input("mom: resource")
.Input("lr: T")
.Input("rho: T")
.Input("momentum: T")
.Input("epsilon: T")
.Input("grad: T")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyCenteredRMSPropShapeFn(c, false /* sparse */);
});
REGISTER_OP("ResourceSparseApplyRMSProp")
.Input("var: resource")
.Input("ms: resource")
.Input("mom: resource")
.Input("lr: T")
.Input("rho: T")
.Input("momentum: T")
.Input("epsilon: T")
.Input("grad: T")
.Input("indices: Tindices")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyRMSPropShapeFn(c, true /* sparse */);
});
REGISTER_OP("ResourceSparseApplyCenteredRMSProp")
.Input("var: resource")
.Input("mg: resource")
.Input("ms: resource")
.Input("mom: resource")
.Input("lr: T")
.Input("rho: T")
.Input("momentum: T")
.Input("epsilon: T")
.Input("grad: T")
.Input("indices: Tindices")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyCenteredRMSPropShapeFn(c, true /* sparse */);
});
static Status ApplyAddSignShapeFn(InferenceContext* c, bool sparse) {
ShapeHandle unused;
ShapeHandle s = ShapeOrHandleShape(c, 0); // var
TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s)); // m
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); // lr
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); // alpha
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); // sign_decay
TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); // beta
TF_RETURN_IF_ERROR(
HandleGradAndIndicesInputs(c, sparse, 6 /* grad_idx */, &s));
if (c->num_outputs() > 0) {
c->set_output(0, s);
}
return Status::OK();
}
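// For context, the documented AddSign update is roughly:
//   m = beta * m + (1 - beta) * grad
//   update = (alpha + sign_decay * sign(grad) * sign(m)) * grad
//   var -= lr * update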
REGISTER_OP("ApplyAddSign")
.Input("var: Ref(T)")
.Input("m: Ref(T)")
.Input("lr: T")
.Input("alpha: T")
.Input("sign_decay: T")
.Input("beta: T")
.Input("grad: T")
.Output("out: Ref(T)")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyAddSignShapeFn(c, /*sparse=*/false);
});
REGISTER_OP("ResourceApplyAddSign")
.Input("var: resource")
.Input("m: resource")
.Input("lr: T")
.Input("alpha: T")
.Input("sign_decay: T")
.Input("beta: T")
.Input("grad: T")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyAddSignShapeFn(c, /*sparse=*/false);
});
static Status ApplyPowerSignShapeFn(InferenceContext* c, bool sparse) {
ShapeHandle unused;
ShapeHandle s = ShapeOrHandleShape(c, 0); // var
TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s)); // m
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); // lr
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); // logbase
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); // sign_decay
TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); // beta
TF_RETURN_IF_ERROR(
HandleGradAndIndicesInputs(c, sparse, 6 /* grad_idx */, &s));
if (c->num_outputs() > 0) {
c->set_output(0, s);
}
return Status::OK();
}
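// For context, the documented PowerSign update is roughly:
//   m = beta * m + (1 - beta) * grad
//   update = exp(logbase * sign_decay * sign(grad) * sign(m)) * grad
//   var -= lr * update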
REGISTER_OP("ApplyPowerSign")
.Input("var: Ref(T)")
.Input("m: Ref(T)")
.Input("lr: T")
.Input("logbase: T")
.Input("sign_decay: T")
.Input("beta: T")
.Input("grad: T")
.Output("out: Ref(T)")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyPowerSignShapeFn(c, /*sparse=*/false);
});
REGISTER_OP("ResourceApplyPowerSign")
.Input("var: resource")
.Input("m: resource")
.Input("lr: T")
.Input("logbase: T")
.Input("sign_decay: T")
.Input("beta: T")
.Input("grad: T")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyPowerSignShapeFn(c, /*sparse=*/false);
});
} // namespace tensorflow