/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h"
#include <string>
#include <utility>
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
namespace tensorflow {
namespace tpu {
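
// Returns the short, code-style name of the given optimization algorithm
// (e.g. "BoundedAdagrad"), or a placeholder when no algorithm is set.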
std::string GetOptimizationAlgorithmName(OptimizationAlgorithm alg) {
switch (alg) {
case OptimizationAlgorithm::kAdagrad:
return "Adagrad";
case OptimizationAlgorithm::kAdagradMomentum:
return "AdagradMomentum";
case OptimizationAlgorithm::kBoundedAdagrad:
return "BoundedAdagrad";
case OptimizationAlgorithm::kStochasticGradientDescent:
return "StochasticGradientDescent";
case OptimizationAlgorithm::kFtrl:
return "FTRL";
case OptimizationAlgorithm::kAdam:
return "ADAM";
case OptimizationAlgorithm::kMomentum:
return "Momentum";
case OptimizationAlgorithm::kRmsProp:
return "RMSProp";
case OptimizationAlgorithm::kCenteredRmsProp:
return "CenteredRMSProp";
case OptimizationAlgorithm::kMdlAdagradLight:
return "MDLAdagradLight";
case OptimizationAlgorithm::kAdadelta:
return "Adadelta";
case OptimizationAlgorithm::kProximalAdagrad:
return "ProximalAdagrad";
case OptimizationAlgorithm::kOnlineYogi:
return "OnlineYogi";
case OptimizationAlgorithm::kProximalYogi:
return "ProximalYogi";
case OptimizationAlgorithm::kFrequencyEstimator:
return "FrequencyEstimator";
case OptimizationAlgorithm::kUserDefinedProgram:
return "UserDefinedProgram";
case OptimizationAlgorithm::kAssign:
return "Assign";
case OptimizationAlgorithm::PARAMETERS_NOT_SET:
return "*** Not set ***";
}
return "*** Not set ***";
}
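
// Returns a human-readable name for the algorithm, for use in log and error
// messages (e.g. "Adagrad with Momentum").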
std::string GetOptimizationAlgorithmFriendlyName(OptimizationAlgorithm alg) {
switch (alg) {
case OptimizationAlgorithm::kAdagrad:
return "Adagrad";
case OptimizationAlgorithm::kAdagradMomentum:
return "Adagrad with Momentum";
case OptimizationAlgorithm::kBoundedAdagrad:
return "Bounded Adagrad";
case OptimizationAlgorithm::kStochasticGradientDescent:
return "stochastic gradient descent";
case OptimizationAlgorithm::kFtrl:
return "FTRL";
case OptimizationAlgorithm::kAdam:
return "ADAM";
case OptimizationAlgorithm::kMomentum:
return "Momentum";
case OptimizationAlgorithm::kRmsProp:
return "RMSProp";
case OptimizationAlgorithm::kCenteredRmsProp:
return "centered RMSProp";
case OptimizationAlgorithm::kMdlAdagradLight:
return "MDL Adagrad Light";
case OptimizationAlgorithm::kAdadelta:
return "Adadelta";
case OptimizationAlgorithm::kProximalAdagrad:
return "proximal Adagrad";
case OptimizationAlgorithm::kOnlineYogi:
return "online Yogi";
case OptimizationAlgorithm::kProximalYogi:
return "proximal Yogi";
case OptimizationAlgorithm::kFrequencyEstimator:
return "frequency estimator";
case OptimizationAlgorithm::kUserDefinedProgram:
return "UserDefinedProgram";
case OptimizationAlgorithm::kAssign:
return "Assign";
case OptimizationAlgorithm::PARAMETERS_NOT_SET:
return "unknown (not specified)";
}
return "unknown (not specified)";
}
// Returns the number of optimization parameter vectors used by the optimization
// algorithm, excluding the weights themselves and assuming no gradient
// accumulation.
Status GetBaseAuxiliaryParameterCount(const OptimizationParameters& params,
int* count) {
switch (params.parameters_case()) {
case OptimizationAlgorithm::kAdagrad:
*count = 1;
return Status::OK();
case OptimizationAlgorithm::kAdagradMomentum:
*count = 2;
return Status::OK();
case OptimizationAlgorithm::kBoundedAdagrad:
*count = 1;
return Status::OK();
case OptimizationAlgorithm::kStochasticGradientDescent:
*count = 0;
return Status::OK();
case OptimizationAlgorithm::kFtrl:
*count = 2;
return Status::OK();
case OptimizationAlgorithm::kAdam:
*count = 2;
return Status::OK();
case OptimizationAlgorithm::kMomentum:
*count = 1;
return Status::OK();
case OptimizationAlgorithm::kRmsProp:
*count = 2;
return Status::OK();
case OptimizationAlgorithm::kCenteredRmsProp:
*count = 3;
return Status::OK();
case OptimizationAlgorithm::kMdlAdagradLight:
*count = 3;
return Status::OK();
case OptimizationAlgorithm::kAdadelta:
*count = 2;
return Status::OK();
case OptimizationAlgorithm::kProximalAdagrad:
*count = 1;
return Status::OK();
case OptimizationAlgorithm::kOnlineYogi:
*count = 2;
return Status::OK();
case OptimizationAlgorithm::kProximalYogi:
*count = 2;
return Status::OK();
case OptimizationAlgorithm::kFrequencyEstimator:
*count = 1;
return Status::OK();
case OptimizationAlgorithm::kUserDefinedProgram: {
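// The user-defined program returns one output per state variable (the
// parameters plus the auxiliary slots) and takes one or two extra inputs
// beyond those, so the auxiliary count is the number of outputs minus one.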
const xla::ProgramShapeProto& program_shape =
params.user_defined_program().program().host_program_shape();
const int num_inputs = program_shape.parameters_size();
const int num_outputs = program_shape.result().tuple_shapes_size();
if ((num_inputs < 2) || ((num_inputs != num_outputs + 1) &&
(num_inputs != num_outputs + 2))) {
return errors::InvalidArgument(
"User-defined TPU embedding optimizer program must have at least "
"two inputs and the number of outputs must be 1 or 2 less than the "
"number of inputs. Received ",
num_inputs, " input(s) and ", num_outputs, "output(s).");
}
*count = num_outputs - 1;
return Status::OK();
}
case OptimizationAlgorithm::kAssign:
*count = 0;
return Status::OK();
case OptimizationAlgorithm::PARAMETERS_NOT_SET:
return errors::InvalidArgument("No optimization algorithm specified");
}
return errors::InvalidArgument("No optimization algorithm specified");
}
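
// Reports whether the algorithm described by `params` can support gradient
// accumulation, i.e. whether one additional auxiliary parameter (the gradient
// accumulator) still fits within kMaxAuxiliaryParameterCount.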
Status GetGradientAccumulationSupport(const OptimizationParameters& params,
GradientAccumulationSupport* support) {
int auxiliary_parameter_count;
TF_RETURN_IF_ERROR(
GetBaseAuxiliaryParameterCount(params, &auxiliary_parameter_count));
*support = auxiliary_parameter_count + 1 <= kMaxAuxiliaryParameterCount
? GradientAccumulationSupport::kSupported
: GradientAccumulationSupport::kNotSupported;
return Status::OK();
}
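
// Decides whether gradient accumulation will actually be used: the requested
// status (enabled by default when unspecified) is combined with the
// algorithm's support, and an error is returned when accumulation is
// requested for an algorithm that cannot support it.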
Status UseGradientAccumulation(const OptimizationParameters& params,
bool* use_gradient_accumulation) {
GradientAccumulationSupport support;
TF_RETURN_IF_ERROR(GetGradientAccumulationSupport(params, &support));
bool raw_gradient_accumulation_status = false;
switch (params.gradient_accumulation_status()) {
case GradientAccumulationStatus::UNSPECIFIED: {
// Gradient accumulation is now enabled by default when unspecified.
raw_gradient_accumulation_status = true;
break;
}
case GradientAccumulationStatus::DISABLED: {
raw_gradient_accumulation_status = false;
break;
}
case GradientAccumulationStatus::ENABLED: {
raw_gradient_accumulation_status = true;
break;
}
default:
return errors::Internal(
absl::StrCat("Unsupported gradient accumulation status ",
GradientAccumulationStatus_Status_Name(
params.gradient_accumulation_status())));
}
switch (support) {
case GradientAccumulationSupport::kSupported: {
*use_gradient_accumulation = raw_gradient_accumulation_status;
break;
}
case GradientAccumulationSupport::kNotSupported: {
if (raw_gradient_accumulation_status) {
return errors::InvalidArgument(strings::Printf(
"Optimization algorithm %s does not support gradient accumulation "
"but parameters specify it.",
GetOptimizationAlgorithmName(params.parameters_case()).c_str()));
}
*use_gradient_accumulation = false;
break;
}
}
return Status::OK();
}
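
// Fills `state_variables` with the ordered list of state variables for the
// algorithm: the "parameters" table itself, the algorithm's auxiliary slots,
// and a trailing "gradient_accumulators" slot when gradient accumulation is
// in use.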
Status GetOptimizationAlgorithmStateVariables(
const OptimizationParameters& params,
std::vector<StateVariableSpecification>* state_variables) {
// The parameter set for the weights themselves is required to be named
// "parameters". The rest should stay stable for compatibility. There is an
// internal function, GetOptimizationAlgorithmStateVariableInternalIndices,
// that needs to be updated along with this one.
bool use_gradient_accumulation;
TF_RETURN_IF_ERROR(
UseGradientAccumulation(params, &use_gradient_accumulation));
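// Appends a state variable with the given name whose contents are supplied
// by the user (in contrast to the constant-filled gradient accumulators
// added below).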
auto add_state_variable = [&](const std::string& name) {
StateVariableSpecification spec;
spec.set_name(name);
(void)spec.mutable_user_defined();
state_variables->push_back(spec);
};
switch (params.parameters_case()) {
case OptimizationAlgorithm::kAdagrad: {
add_state_variable("parameters");
add_state_variable("accumulators");
break;
}
case OptimizationAlgorithm::kAdagradMomentum: {
add_state_variable("parameters");
add_state_variable("accumulators");
add_state_variable("momenta");
break;
}
case OptimizationAlgorithm::kBoundedAdagrad: {
add_state_variable("parameters");
add_state_variable("accumulators");
break;
}
case OptimizationAlgorithm::kStochasticGradientDescent: {
add_state_variable("parameters");
break;
}
case OptimizationAlgorithm::kFtrl: {
add_state_variable("parameters");
add_state_variable("accumulators");
add_state_variable("linears");
break;
}
case OptimizationAlgorithm::kAdam: {
add_state_variable("parameters");
add_state_variable("momenta");
add_state_variable("velocities");
break;
}
case OptimizationAlgorithm::kMomentum: {
add_state_variable("parameters");
add_state_variable("momenta");
break;
}
case OptimizationAlgorithm::kRmsProp: {
add_state_variable("parameters");
add_state_variable("ms");
add_state_variable("mom");
break;
}
case OptimizationAlgorithm::kCenteredRmsProp: {
add_state_variable("parameters");
add_state_variable("ms");
add_state_variable("mom");
add_state_variable("mg");
break;
}
case OptimizationAlgorithm::kMdlAdagradLight: {
add_state_variable("parameters");
add_state_variable("accumulators");
add_state_variable("weights");
add_state_variable("benefits");
break;
}
case OptimizationAlgorithm::kAdadelta: {
add_state_variable("parameters");
add_state_variable("accumulators");
add_state_variable("updates");
break;
}
case OptimizationAlgorithm::kProximalAdagrad: {
add_state_variable("parameters");
add_state_variable("accumulators");
break;
}
case OptimizationAlgorithm::kOnlineYogi: {
add_state_variable("parameters");
add_state_variable("vs");
add_state_variable("linears");
break;
}
case OptimizationAlgorithm::kProximalYogi: {
add_state_variable("parameters");
add_state_variable("v");
add_state_variable("m");
break;
}
case OptimizationAlgorithm::kFrequencyEstimator: {
add_state_variable("parameters");
add_state_variable("last_hit_step");
break;
}
case OptimizationAlgorithm::kUserDefinedProgram: {
add_state_variable("parameters");
int num_slots = -1;
TF_RETURN_IF_ERROR(GetBaseAuxiliaryParameterCount(params, &num_slots));
for (int i = 0; i < num_slots; ++i) {
add_state_variable(absl::StrCat("Slot_", i));
}
break;
}
case OptimizationAlgorithm::kAssign: {
add_state_variable("parameters");
break;
}
case OptimizationAlgorithm::PARAMETERS_NOT_SET: {
return errors::InvalidArgument("No optimization algorithm specified");
}
}
// This needs to be last for compatibility.
if (use_gradient_accumulation) {
StateVariableSpecification gradient_acc;
gradient_acc.set_name("gradient_accumulators");
gradient_acc.mutable_fill_with_constant()->set_initial_value(
GradientAccumulatorInitialValue());
state_variables->push_back(std::move(gradient_acc));
}
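// At most kMaxAuxiliaryParameterCount auxiliary state variables, plus the
// parameters themselves, are allowed.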
if (state_variables->size() > kMaxAuxiliaryParameterCount + 1) {
return errors::InvalidArgument(
"Optimization algorithm",
GetOptimizationAlgorithmName(params.parameters_case()),
"does not support gradient accumulation because it "
"already has too many other accumulators");
}
return Status::OK();
}
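
// Returns the list of optimization algorithms recognized by the TPU embedding
// library.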
std::vector<OptimizationAlgorithm> GetOptimizationAlgorithms() {
return {
OptimizationAlgorithm::kAdagrad,
OptimizationAlgorithm::kAdagradMomentum,
OptimizationAlgorithm::kBoundedAdagrad,
OptimizationAlgorithm::kStochasticGradientDescent,
OptimizationAlgorithm::kFtrl,
OptimizationAlgorithm::kAdam,
OptimizationAlgorithm::kMomentum,
OptimizationAlgorithm::kRmsProp,
OptimizationAlgorithm::kCenteredRmsProp,
OptimizationAlgorithm::kMdlAdagradLight,
OptimizationAlgorithm::kAdadelta,
OptimizationAlgorithm::kProximalAdagrad,
OptimizationAlgorithm::kOnlineYogi,
OptimizationAlgorithm::kProximalYogi,
OptimizationAlgorithm::kFrequencyEstimator,
OptimizationAlgorithm::kUserDefinedProgram,
OptimizationAlgorithm::kAssign,
};
}
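
// Shape function for the ops that load optimization state onto the TPU:
// exactly one of table_id / table_name must be set, and every input must be a
// rank-2 shape compatible with the parameter shape.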
Status LoadOpShapeFunction::operator()(
shape_inference::InferenceContext* c) const {
int table_id;
TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
string table_name;
TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name));
// Exactly one must be non-default.
if ((table_id >= 0) == (!table_name.empty())) {
return errors::InvalidArgument(
"exactly one of table_id or table_name must be non-default");
}
int num_shards;
TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));
int shard_id;
TF_RETURN_IF_ERROR(c->GetAttr("shard_id", &shard_id));
// Verify shapes have rank 2 and are compatible when they are
// required to be valid.
shape_inference::ShapeHandle parameter_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &parameter_shape));
for (int j = 1; j < c->num_inputs(); ++j) {
shape_inference::ShapeHandle accumulator_j_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(j), 2, &accumulator_j_shape));
shape_inference::ShapeHandle merged;
TF_RETURN_IF_ERROR(c->Merge(parameter_shape, accumulator_j_shape, &merged));
}
return Status::OK();
}
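
// Shape function for the ops that retrieve optimization state from the TPU:
// exactly one of table_id / table_name must be set, and each output is set to
// a rank-2 shape with unknown dimensions.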
Status RetrieveOpShapeFunction::operator()(
shape_inference::InferenceContext* c) const {
int table_id;
TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
string table_name;
TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name));
// Exactly one must be non-default.
if ((table_id >= 0) == (!table_name.empty())) {
return errors::InvalidArgument(
"exactly one of table_id or table_name must be non-default");
}
int num_shards;
TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));
int shard_id;
TF_RETURN_IF_ERROR(c->GetAttr("shard_id", &shard_id));
for (int j = 0; j < c->num_outputs(); ++j) {
c->set_output(j, c->MakeShape(std::vector<shape_inference::DimensionHandle>(
2, c->UnknownDim())));
}
return Status::OK();
}
} // namespace tpu
} // namespace tensorflow