TPU embedding implementation of the proximal Yogi optimizer.
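
For reference, a rough per-element sketch of the update this optimizer family
performs, following Algorithm 2 of the Yogi paper (sign variant) together with
the dynamic learning rate formula documented in the proto below. This is
illustrative only, not the TPU kernel: the names (ProximalYogiSlot,
DynamicLearningRate, ProximalYogiStep) are made up for exposition, and the
l1/l2 proximal handling is omitted.

  #include <cmath>

  struct ProximalYogiSlot {
    float v = 0.0f;  // second-moment estimate ("v" state variable)
    float m = 0.0f;  // first-moment estimate ("m" state variable)
  };

  // Bias-corrected learning rate from the proto comment:
  // user learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t), t = current step.
  inline float DynamicLearningRate(float user_lr, float beta1, float beta2,
                                   int t) {
    return user_lr * std::sqrt(1.0f - std::pow(beta2, t)) /
           (1.0f - std::pow(beta1, t));
  }

  // One step for a single weight w with gradient g (no l1/l2 terms).
  inline float ProximalYogiStep(float w, float g, ProximalYogiSlot& slot,
                                float lr, float beta1, float beta2,
                                float epsilon) {
    const float g2 = g * g;
    // Yogi second-moment update: v -= (1 - beta2) * sign(v - g^2) * g^2.
    // copysign matches the proto's SignActivation (sign of +0 treated as +1);
    // TanhActivation would use tanh(10 * (v - g^2)) in place of the sign.
    slot.v -= (1.0f - beta2) * std::copysign(g2, slot.v - g2);
    slot.m = beta1 * slot.m + (1.0f - beta1) * g;
    return w - lr * slot.m / (std::sqrt(slot.v) + epsilon);
  }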

PiperOrigin-RevId: 271680647
diff --git a/tensorflow/core/protobuf/tpu/optimization_parameters.proto b/tensorflow/core/protobuf/tpu/optimization_parameters.proto
index f52f7bf..778a97e 100644
--- a/tensorflow/core/protobuf/tpu/optimization_parameters.proto
+++ b/tensorflow/core/protobuf/tpu/optimization_parameters.proto
@@ -202,12 +202,6 @@
   // \beta_2 from Algorithm 2 in the paper.
   float beta2 = 3;
 
-  // Initial value of V variable in paper.
-  float initial_v = 4;
-
-  // Initial value of linear variable in FTRL.
-  float initial_linear = 5;
-
   // x -> copysign(1, x) (i.e., return 1 for an input of +0 rather than 0).
   message SignActivation {}
 
@@ -222,6 +216,45 @@
   }
 }
 
+// The proximal Yogi optimizer does not implement hyper-parameter update; use
+// the dynamic learning rate feature instead, setting the learning rate to:
+// user learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)
+// Here, t is the current timestep.
+//
+// https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization.pdf
+// plus some extensions based on FTRL.
+//
+// Note that the code by default implements the lazy version of proximal Yogi.
+message ProximalYogiParameters {
+  // The L1 regularization parameter.
+  float l1 = 1;
+
+  // The L2 regularization parameter.
+  float l2 = 2;
+
+  // The exponential decay rate for the 1st moment estimates.
+  float beta1 = 3;
+
+  // The exponential decay rate for the 2nd moment estimates.
+  float beta2 = 4;
+
+  // A constant trading off adaptivity and noise.
+  float epsilon = 5;
+
+  // x -> copysign(1, x) (i.e., return 1 for an input of +0 rather than 0).
+  message SignActivation {}
+
+  // x -> tanh(x * 10)
+  message TanhActivation {}
+
+  // Activation used in place of the sign function in the v_t update in
+  // Algorithm 2 of the paper.
+  oneof activation {
+    SignActivation sign = 8;
+    TanhActivation tanh = 9;
+  }
+}
+
 // Status of using gradient accumulation (doing two passes over the input
 // gradients: one to accumulate them into a temporary array and another to apply
 // them using the actual optimization algorithm). The extra message is to wrap
@@ -293,6 +326,7 @@
     AdadeltaParameters adadelta = 12;
     ProximalAdagradParameters proximal_adagrad = 14;
     OnlineYogiParameters online_yogi = 20;
+    ProximalYogiParameters proximal_yogi = 21;
   }
 
   reserved 15;  // Old use_gradient_accumulation.
diff --git a/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc b/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc
index d2f34549..9ce394f 100644
--- a/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc
+++ b/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc
@@ -49,6 +49,8 @@
       return "ProximalAdagrad";
     case OptimizationAlgorithm::kOnlineYogi:
       return "OnlineYogi";
+    case OptimizationAlgorithm::kProximalYogi:
+      return "ProximalYogi";
     case OptimizationAlgorithm::PARAMETERS_NOT_SET:
       return "*** Not set ***";
   }
@@ -81,6 +83,8 @@
       return "proximal Adagrad";
     case OptimizationAlgorithm::kOnlineYogi:
       return "online Yogi";
+    case OptimizationAlgorithm::kProximalYogi:
+      return "proximal Yogi";
     case OptimizationAlgorithm::PARAMETERS_NOT_SET:
       return "unknown (not specified)";
   }
@@ -128,6 +132,9 @@
     case OptimizationAlgorithm::kOnlineYogi:
       *count = 2;
       return Status::OK();
+    case OptimizationAlgorithm::kProximalYogi:
+      *count = 2;
+      return Status::OK();
     case OptimizationAlgorithm::PARAMETERS_NOT_SET:
       return errors::InvalidArgument("No optimization algorithm specified");
   }
@@ -256,6 +263,13 @@
           MakeStandardStateVariableSpecification("linears", 0.0));
       break;
     }
+    case OptimizationAlgorithm::kProximalYogi: {
+      state_variables->push_back(
+          MakeStandardStateVariableSpecification("v", 0.0));
+      state_variables->push_back(
+          MakeStandardStateVariableSpecification("m", 0.0));
+      break;
+    }
     case OptimizationAlgorithm::PARAMETERS_NOT_SET: {
       return errors::InvalidArgument("No optimization algorithm specified");
     }
@@ -292,6 +306,7 @@
       OptimizationAlgorithm::kAdadelta,
       OptimizationAlgorithm::kProximalAdagrad,
       OptimizationAlgorithm::kOnlineYogi,
+      OptimizationAlgorithm::kProximalYogi,
   };
 }
 
@@ -536,7 +551,8 @@
       return Status::OK();
     }
     case OptimizationAlgorithm::kBoundedAdagrad:
-    case OptimizationAlgorithm::kOnlineYogi: {
+    case OptimizationAlgorithm::kOnlineYogi:
+    case OptimizationAlgorithm::kProximalYogi: {
       *internal = true;
       return Status::OK();
     }
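
For reference, a minimal sketch of how client code might populate the new
message through the generated C++ API. It assumes the enclosing message is
tensorflow.tpu.OptimizationParameters (as suggested by the oneof above); the
function name and the hyper-parameter values are placeholders for illustration,
not recommended defaults.

  #include "tensorflow/core/protobuf/tpu/optimization_parameters.pb.h"

  // Builds a ProximalYogiParameters config inside OptimizationParameters.
  tensorflow::tpu::OptimizationParameters MakeProximalYogiConfig() {
    tensorflow::tpu::OptimizationParameters params;
    auto* yogi = params.mutable_proximal_yogi();
    yogi->set_l1(0.0f);
    yogi->set_l2(0.0f);
    yogi->set_beta1(0.9f);
    yogi->set_beta2(0.999f);
    yogi->set_epsilon(1e-3f);
    yogi->mutable_sign();  // selects SignActivation for the v_t update
    return params;
  }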