Added the beta parameter from the FTRL paper to the v1 FtrlOptimizer class.

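Because the underlying FTRL kernels only accept an L2 regularization
strength, beta is folded into that value before it is passed down: the op
receives l2_regularization_strength + beta / (2 * learning_rate). Passing
beta=None (the default) keeps the previous behavior. A rough sketch of the
fold-in, using the illustrative values from the new tests:

    # Illustrative only: the L2 strength the kernel actually receives.
    learning_rate = 3.0
    l2 = 0.1
    beta = 0.1
    adjusted_l2 = l2 + beta / (2.0 * learning_rate)  # ~0.1167

    # New keyword on the v1 optimizer:
    opt = tf.compat.v1.train.FtrlOptimizer(
        3.0, initial_accumulator_value=0.1, beta=0.1)
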
PiperOrigin-RevId: 325290409
Change-Id: I0aa85a26b188b9ab3e1faa7462dca4c5d81f8712
diff --git a/RELEASE.md b/RELEASE.md
index 62bdc11..0eb673b 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -95,6 +95,7 @@
       * Error messages when Functional API construction goes wrong (and when ops cannot be converted to Keras layers automatically) should be clearer and easier to understand.
     * `Optimizer.minimize` can now accept a loss `Tensor` and a `GradientTape`
       as an alternative to accepting a `callable` loss.
+    * Added the `beta` parameter to the FTRL optimizer to match the paper.
 * `tf.function` / AutoGraph:
   * Added `experimental_follow_type_hints` argument for `tf.function`. When
     True, the function may use type annotations to optimize the tracing
diff --git a/tensorflow/python/training/ftrl.py b/tensorflow/python/training/ftrl.py
index c7b3867..6c8a6ce 100644
--- a/tensorflow/python/training/ftrl.py
+++ b/tensorflow/python/training/ftrl.py
@@ -49,7 +49,8 @@
                name="Ftrl",
                accum_name=None,
                linear_name=None,
-               l2_shrinkage_regularization_strength=0.0):
+               l2_shrinkage_regularization_strength=0.0,
+               beta=None):
     r"""Construct a new FTRL optimizer.
 
     Args:
@@ -79,10 +80,11 @@
         function w.r.t. the weights w.
         Specifically, in the absence of L1 regularization, it is equivalent to
         the following update rule:
-        w_{t+1} = w_t - lr_t / (1 + 2*L2*lr_t) * g_t -
-                  2*L2_shrinkage*lr_t / (1 + 2*L2*lr_t) * w_t
+        w_{t+1} = w_t - lr_t / (beta + 2*L2*lr_t) * g_t -
+                  2*L2_shrinkage*lr_t / (beta + 2*L2*lr_t) * w_t
         where lr_t is the learning rate at t.
         When input is sparse shrinkage will only happen on the active weights.
+      beta: A float value; corresponds to beta in the paper. Defaults to 0.0.
 
     Raises:
       ValueError: If one of the arguments is invalid.
@@ -119,12 +121,13 @@
     self._initial_accumulator_value = initial_accumulator_value
     self._l1_regularization_strength = l1_regularization_strength
     self._l2_regularization_strength = l2_regularization_strength
+    self._beta = (0.0 if beta is None else beta)
     self._l2_shrinkage_regularization_strength = (
         l2_shrinkage_regularization_strength)
     self._learning_rate_tensor = None
     self._learning_rate_power_tensor = None
     self._l1_regularization_strength_tensor = None
-    self._l2_regularization_strength_tensor = None
+    self._adjusted_l2_regularization_strength_tensor = None
     self._l2_shrinkage_regularization_strength_tensor = None
     self._accum_name = accum_name
     self._linear_name = linear_name
@@ -142,8 +145,14 @@
         self._learning_rate, name="learning_rate")
     self._l1_regularization_strength_tensor = ops.convert_to_tensor(
         self._l1_regularization_strength, name="l1_regularization_strength")
-    self._l2_regularization_strength_tensor = ops.convert_to_tensor(
-        self._l2_regularization_strength, name="l2_regularization_strength")
+    # L2 regularization strength with beta added in so that the underlying
+    # TensorFlow ops do not need to include that parameter.
+    self._adjusted_l2_regularization_strength_tensor = ops.convert_to_tensor(
+        self._l2_regularization_strength + self._beta /
+        (2. * self._learning_rate),
+        name="adjusted_l2_regularization_strength")
+    assert self._adjusted_l2_regularization_strength_tensor is not None
+    self._beta_tensor = ops.convert_to_tensor(self._beta, name="beta")
     self._l2_shrinkage_regularization_strength_tensor = ops.convert_to_tensor(
         self._l2_shrinkage_regularization_strength,
         name="l2_shrinkage_regularization_strength")
@@ -162,7 +171,7 @@
           math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
           math_ops.cast(self._l1_regularization_strength_tensor,
                         var.dtype.base_dtype),
-          math_ops.cast(self._l2_regularization_strength_tensor,
+          math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
                         var.dtype.base_dtype),
           math_ops.cast(self._learning_rate_power_tensor, var.dtype.base_dtype),
           use_locking=self._use_locking)
@@ -175,7 +184,7 @@
           math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
           math_ops.cast(self._l1_regularization_strength_tensor,
                         var.dtype.base_dtype),
-          math_ops.cast(self._l2_regularization_strength_tensor,
+          math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
                         var.dtype.base_dtype),
           math_ops.cast(self._l2_shrinkage_regularization_strength_tensor,
                         var.dtype.base_dtype),
@@ -194,7 +203,7 @@
           math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
           math_ops.cast(self._l1_regularization_strength_tensor,
                         var.dtype.base_dtype),
-          math_ops.cast(self._l2_regularization_strength_tensor,
+          math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
                         var.dtype.base_dtype),
           math_ops.cast(self._learning_rate_power_tensor, var.dtype.base_dtype),
           use_locking=self._use_locking)
@@ -207,7 +216,7 @@
           math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
           math_ops.cast(self._l1_regularization_strength_tensor,
                         var.dtype.base_dtype),
-          math_ops.cast(self._l2_regularization_strength_tensor,
+          math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
                         var.dtype.base_dtype),
           math_ops.cast(self._l2_shrinkage_regularization_strength_tensor,
                         var.dtype.base_dtype),
@@ -227,7 +236,7 @@
           math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
           math_ops.cast(self._l1_regularization_strength_tensor,
                         var.dtype.base_dtype),
-          math_ops.cast(self._l2_regularization_strength_tensor,
+          math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
                         var.dtype.base_dtype),
           math_ops.cast(self._learning_rate_power_tensor, var.dtype.base_dtype),
           use_locking=self._use_locking)
@@ -241,7 +250,7 @@
           math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
           math_ops.cast(self._l1_regularization_strength_tensor,
                         var.dtype.base_dtype),
-          math_ops.cast(self._l2_regularization_strength_tensor,
+          math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
                         var.dtype.base_dtype),
           math_ops.cast(self._l2_shrinkage_regularization_strength_tensor,
                         grad.dtype.base_dtype),
@@ -260,7 +269,8 @@
           indices,
           math_ops.cast(self._learning_rate_tensor, grad.dtype),
           math_ops.cast(self._l1_regularization_strength_tensor, grad.dtype),
-          math_ops.cast(self._l2_regularization_strength_tensor, grad.dtype),
+          math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
+                        grad.dtype),
           math_ops.cast(self._learning_rate_power_tensor, grad.dtype),
           use_locking=self._use_locking)
     else:
@@ -272,7 +282,8 @@
           indices,
           math_ops.cast(self._learning_rate_tensor, grad.dtype),
           math_ops.cast(self._l1_regularization_strength_tensor, grad.dtype),
-          math_ops.cast(self._l2_regularization_strength_tensor, grad.dtype),
+          math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
+                        grad.dtype),
           math_ops.cast(self._l2_shrinkage_regularization_strength_tensor,
                         grad.dtype),
           math_ops.cast(self._learning_rate_power_tensor, grad.dtype),
diff --git a/tensorflow/python/training/ftrl_test.py b/tensorflow/python/training/ftrl_test.py
index f0cbe13..ff1bf17 100644
--- a/tensorflow/python/training/ftrl_test.py
+++ b/tensorflow/python/training/ftrl_test.py
@@ -161,6 +161,65 @@
           self.assertAllCloseAccordingToType(
               np.array([-0.93460727, -1.86147261]), v1_val)
 
+  def testFtrlWithBeta(self):
+    # The v1 optimizers do not support eager execution
+    with ops.Graph().as_default():
+      for dtype in [dtypes.half, dtypes.float32]:
+        with self.cached_session():
+          var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+          var1 = variables.Variable([4.0, 3.0], dtype=dtype)
+          grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+          grads1 = constant_op.constant([0.01, 0.02], dtype=dtype)
+
+          opt = ftrl.FtrlOptimizer(3.0, initial_accumulator_value=0.1, beta=0.1)
+          update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+          self.evaluate(variables.global_variables_initializer())
+
+          v0_val, v1_val = self.evaluate([var0, var1])
+          self.assertAllCloseAccordingToType([1.0, 2.0], v0_val)
+          self.assertAllCloseAccordingToType([4.0, 3.0], v1_val)
+
+          # Run 10 steps FTRL
+          for _ in range(10):
+            update.run()
+          v0_val, v1_val = self.evaluate([var0, var1])
+          self.assertAllCloseAccordingToType(
+              np.array([-6.096838, -9.162214]), v0_val)
+          self.assertAllCloseAccordingToType(
+              np.array([-0.717741, -1.425132]), v1_val)
+
+  def testFtrlWithL2_Beta(self):
+    # The v1 optimizers do not support eager execution
+    with ops.Graph().as_default():
+      for dtype in [dtypes.half, dtypes.float32]:
+        with self.cached_session():
+          var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+          var1 = variables.Variable([4.0, 3.0], dtype=dtype)
+          grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+          grads1 = constant_op.constant([0.01, 0.02], dtype=dtype)
+
+          opt = ftrl.FtrlOptimizer(
+              3.0,
+              initial_accumulator_value=0.1,
+              l1_regularization_strength=0.0,
+              l2_regularization_strength=0.1,
+              beta=0.1)
+          update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+          self.evaluate(variables.global_variables_initializer())
+
+          v0_val, v1_val = self.evaluate([var0, var1])
+          self.assertAllCloseAccordingToType([1.0, 2.0], v0_val)
+          self.assertAllCloseAccordingToType([4.0, 3.0], v1_val)
+
+          # Run 10 steps FTRL
+          for _ in range(10):
+            update.run()
+          v0_val, v1_val = self.evaluate([var0, var1])
+          self.assertAllCloseAccordingToType(
+              np.array([-2.735487, -4.704625]), v0_val)
+          self.assertAllCloseAccordingToType(
+              np.array([-0.294335, -0.586556]), v1_val)
+
   def testFtrlWithL1_L2(self):
     # The v1 optimizers do not support eager execution
     with ops.Graph().as_default():
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.-ftrl-optimizer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-ftrl-optimizer.pbtxt
index 1d1aceb..9e12ae9 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.train.-ftrl-optimizer.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.train.-ftrl-optimizer.pbtxt
@@ -18,7 +18,7 @@
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'learning_rate\', \'learning_rate_power\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'use_locking\', \'name\', \'accum_name\', \'linear_name\', \'l2_shrinkage_regularization_strength\'], varargs=None, keywords=None, defaults=[\'-0.5\', \'0.1\', \'0.0\', \'0.0\', \'False\', \'Ftrl\', \'None\', \'None\', \'0.0\'], "
+    argspec: "args=[\'self\', \'learning_rate\', \'learning_rate_power\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'use_locking\', \'name\', \'accum_name\', \'linear_name\', \'l2_shrinkage_regularization_strength\', \'beta\'], varargs=None, keywords=None, defaults=[\'-0.5\', \'0.1\', \'0.0\', \'0.0\', \'False\', \'Ftrl\', \'None\', \'None\', \'0.0\', \'None\'], "
   }
   member_method {
     name: "apply_gradients"