TPU embedding implementation of the proximal Yogi optimizer.
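
Like online Yogi, proximal Yogi does not implement hyper-parameter update on
the device; callers are expected to fold the Adam-style bias correction into
the dynamic learning rate feature, as documented in the proto comments below.
A rough host-side sketch of that scaling (the helper below is hypothetical and
not part of this change):

  #include <cmath>

  // Hypothetical helper: computes the learning rate the proto comment asks
  // callers to supply, i.e.
  //   user learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t).
  float ProximalYogiLearningRate(float user_learning_rate, float beta1,
                                 float beta2, int step) {
    const float bias_correction = std::sqrt(1.0f - std::pow(beta2, step)) /
                                  (1.0f - std::pow(beta1, step));
    return user_learning_rate * bias_correction;
  }
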
PiperOrigin-RevId: 271680647
diff --git a/tensorflow/core/protobuf/tpu/optimization_parameters.proto b/tensorflow/core/protobuf/tpu/optimization_parameters.proto
index f52f7bf..778a97e 100644
--- a/tensorflow/core/protobuf/tpu/optimization_parameters.proto
+++ b/tensorflow/core/protobuf/tpu/optimization_parameters.proto
@@ -202,12 +202,6 @@
// \beta_2 from Algorithm 2 in the paper.
float beta2 = 3;
- // Initial value of V variable in paper.
- float initial_v = 4;
-
- // Initial value of linear variable in FTRL.
- float initial_linear = 5;
-
// x -> copysign(1, x) (i.e., return 1 for an input of +0 rather than 0).
message SignActivation {}
@@ -222,6 +216,45 @@
}
}
+// The proximal Yogi optimizer does not implement hyper-parameter update; use the
+// dynamic learning rate feature instead, setting the learning rate to:
+// user learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)
+// Here, t is the current timestep.
+//
+// https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization.pdf
+// plus some extensions based on FTRL.
+//
+// Note that the code by default implements the lazy version of proximal Yogi.
+message ProximalYogiParameters {
+ // The L1 regularization parameter.
+ float l1 = 1;
+
+ // The L2 regularization parameter.
+ float l2 = 2;
+
+ // The exponential decay rate for the 1st moment estimates.
+ float beta1 = 3;
+
+ // The exponential decay rate for the 2nd moment estimates.
+ float beta2 = 4;
+
+ // A constant trading off adaptivity and noise.
+ float epsilon = 5;
+
+ // x -> copysign(1, x) (i.e., return 1 for an input of +0 rather than 0).
+ message SignActivation {}
+
+ // x -> tanh(x * 10)
+ message TanhActivation {}
+
+ // Activation to use in place of the sign function in the v_t update in
+ // Algorithm 2 of the paper.
+ oneof activation {
+ SignActivation sign = 8;
+ TanhActivation tanh = 9;
+ }
+}
+
// Status of using gradient accumulation (doing two passes over the input
// gradients: one to accumulate them into a temporary array and another to apply
// them using the actual optimization algorithm). The extra message is to wrap
@@ -293,6 +326,7 @@
AdadeltaParameters adadelta = 12;
ProximalAdagradParameters proximal_adagrad = 14;
OnlineYogiParameters online_yogi = 20;
+ ProximalYogiParameters proximal_yogi = 21;
}
reserved 15; // Old use_gradient_accumulation.
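
For reference, selecting the new algorithm from C++ would look roughly like
the sketch below; accessor names assume the standard protobuf-generated API
for the OptimizationParameters message holding this oneof, and all values are
placeholders rather than recommendations:

  #include "tensorflow/core/protobuf/tpu/optimization_parameters.pb.h"

  // Sketch only: populate the new proximal_yogi oneof field.
  tensorflow::tpu::OptimizationParameters MakeProximalYogiParams() {
    tensorflow::tpu::OptimizationParameters params;
    auto* yogi = params.mutable_proximal_yogi();
    yogi->set_l1(0.0f);        // L1 regularization parameter (placeholder).
    yogi->set_l2(0.0f);        // L2 regularization parameter (placeholder).
    yogi->set_beta1(0.9f);     // 1st-moment decay rate (placeholder).
    yogi->set_beta2(0.999f);   // 2nd-moment decay rate (placeholder).
    yogi->set_epsilon(1e-3f);  // Adaptivity/noise trade-off (placeholder).
    yogi->mutable_sign();      // Select the sign activation for the v_t update.
    return params;
  }
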
diff --git a/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc b/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc
index d2f34549..9ce394f 100644
--- a/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc
+++ b/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc
@@ -49,6 +49,8 @@
return "ProximalAdagrad";
case OptimizationAlgorithm::kOnlineYogi:
return "OnlineYogi";
+ case OptimizationAlgorithm::kProximalYogi:
+ return "ProximalYogi";
case OptimizationAlgorithm::PARAMETERS_NOT_SET:
return "*** Not set ***";
}
@@ -81,6 +83,8 @@
return "proximal Adagrad";
case OptimizationAlgorithm::kOnlineYogi:
return "online Yogi";
+ case OptimizationAlgorithm::kProximalYogi:
+ return "proximal Yogi";
case OptimizationAlgorithm::PARAMETERS_NOT_SET:
return "unknown (not specified)";
}
@@ -128,6 +132,9 @@
case OptimizationAlgorithm::kOnlineYogi:
*count = 2;
return Status::OK();
+ case OptimizationAlgorithm::kProximalYogi:
+ *count = 2;
+ return Status::OK();
case OptimizationAlgorithm::PARAMETERS_NOT_SET:
return errors::InvalidArgument("No optimization algorithm specified");
}
@@ -256,6 +263,13 @@
MakeStandardStateVariableSpecification("linears", 0.0));
break;
}
+ case OptimizationAlgorithm::kProximalYogi: {
+ state_variables->push_back(
+ MakeStandardStateVariableSpecification("v", 0.0));
+ state_variables->push_back(
+ MakeStandardStateVariableSpecification("m", 0.0));
+ break;
+ }
case OptimizationAlgorithm::PARAMETERS_NOT_SET: {
return errors::InvalidArgument("No optimization algorithm specified");
}
@@ -292,6 +306,7 @@
OptimizationAlgorithm::kAdadelta,
OptimizationAlgorithm::kProximalAdagrad,
OptimizationAlgorithm::kOnlineYogi,
+ OptimizationAlgorithm::kProximalYogi,
};
}
@@ -536,7 +551,8 @@
return Status::OK();
}
case OptimizationAlgorithm::kBoundedAdagrad:
- case OptimizationAlgorithm::kOnlineYogi: {
+ case OptimizationAlgorithm::kOnlineYogi:
+ case OptimizationAlgorithm::kProximalYogi: {
*internal = true;
return Status::OK();
}
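
The "v" and "m" state variables registered above are the second- and
first-moment accumulators of Algorithm 2 in the paper linked from the proto.
The TPU-side kernel is not part of this change; the sketch below only
illustrates the plain (non-lazy) Yogi moment updates that the new message
parameterizes, with the sign/tanh choice mirroring the activation oneof, and
it omits the l1/l2 proximal step:

  #include <cmath>

  struct YogiSlot {
    float v = 0.0f;  // 2nd-moment accumulator ("v" state variable).
    float m = 0.0f;  // 1st-moment accumulator ("m" state variable).
  };

  // sign activation: copysign(1, x); tanh activation: tanh(x * 10).
  inline float Activation(float x, bool use_tanh) {
    return use_tanh ? std::tanh(10.0f * x) : std::copysign(1.0f, x);
  }

  // Illustrative single-weight Yogi step (Zaheer et al., 2018, Algorithm 2).
  inline float YogiUpdate(float weight, float grad, float lr, float beta1,
                          float beta2, float epsilon, bool use_tanh,
                          YogiSlot* slot) {
    const float g2 = grad * grad;
    // v_t = v_{t-1} - (1 - beta2) * activation(v_{t-1} - g_t^2) * g_t^2
    slot->v -= (1.0f - beta2) * Activation(slot->v - g2, use_tanh) * g2;
    // m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
    slot->m = beta1 * slot->m + (1.0f - beta1) * grad;
    // w_t = w_{t-1} - lr * m_t / (sqrt(v_t) + epsilon)
    return weight - lr * slot->m / (std::sqrt(slot->v) + epsilon);
  }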