Added beta parameter from FTRL paper to main optimizer class.
PiperOrigin-RevId: 325290409
Change-Id: I0aa85a26b188b9ab3e1faa7462dca4c5d81f8712
diff --git a/RELEASE.md b/RELEASE.md
index 62bdc11..0eb673b 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -95,6 +95,7 @@
* Error messages when Functional API construction goes wrong (and when ops cannot be converted to Keras layers automatically) should be clearer and easier to understand.
* `Optimizer.minimize` can now accept a loss `Tensor` and a `GradientTape`
as an alternative to accepting a `callable` loss.
+  * Added `beta` parameter to the FTRL optimizer to match the FTRL paper.
* `tf.function` / AutoGraph:
* Added `experimental_follow_type_hints` argument for `tf.function`. When
True, the function may use type annotations to optimize the tracing
diff --git a/tensorflow/python/training/ftrl.py b/tensorflow/python/training/ftrl.py
index c7b3867..6c8a6ce 100644
--- a/tensorflow/python/training/ftrl.py
+++ b/tensorflow/python/training/ftrl.py
@@ -49,7 +49,8 @@
name="Ftrl",
accum_name=None,
linear_name=None,
- l2_shrinkage_regularization_strength=0.0):
+ l2_shrinkage_regularization_strength=0.0,
+ beta=None):
r"""Construct a new FTRL optimizer.
Args:
@@ -79,10 +80,11 @@
function w.r.t. the weights w.
Specifically, in the absence of L1 regularization, it is equivalent to
the following update rule:
- w_{t+1} = w_t - lr_t / (1 + 2*L2*lr_t) * g_t -
- 2*L2_shrinkage*lr_t / (1 + 2*L2*lr_t) * w_t
+ w_{t+1} = w_t - lr_t / (beta + 2*L2*lr_t) * g_t -
+ 2*L2_shrinkage*lr_t / (beta + 2*L2*lr_t) * w_t
where lr_t is the learning rate at t.
When input is sparse shrinkage will only happen on the active weights.
+      beta: A float value; corresponds to the beta parameter in the FTRL paper.
Raises:
ValueError: If one of the arguments is invalid.
@@ -119,12 +121,13 @@
self._initial_accumulator_value = initial_accumulator_value
self._l1_regularization_strength = l1_regularization_strength
self._l2_regularization_strength = l2_regularization_strength
+ self._beta = (0.0 if beta is None else beta)
self._l2_shrinkage_regularization_strength = (
l2_shrinkage_regularization_strength)
self._learning_rate_tensor = None
self._learning_rate_power_tensor = None
self._l1_regularization_strength_tensor = None
- self._l2_regularization_strength_tensor = None
+ self._adjusted_l2_regularization_strength_tensor = None
self._l2_shrinkage_regularization_strength_tensor = None
self._accum_name = accum_name
self._linear_name = linear_name
@@ -142,8 +145,14 @@
self._learning_rate, name="learning_rate")
self._l1_regularization_strength_tensor = ops.convert_to_tensor(
self._l1_regularization_strength, name="l1_regularization_strength")
- self._l2_regularization_strength_tensor = ops.convert_to_tensor(
- self._l2_regularization_strength, name="l2_regularization_strength")
+ # L2 regularization strength with beta added in so that the underlying
+ # TensorFlow ops do not need to include that parameter.
+ self._adjusted_l2_regularization_strength_tensor = ops.convert_to_tensor(
+ self._l2_regularization_strength + self._beta /
+ (2. * self._learning_rate),
+ name="adjusted_l2_regularization_strength")
+ assert self._adjusted_l2_regularization_strength_tensor is not None
+ self._beta_tensor = ops.convert_to_tensor(self._beta, name="beta")
self._l2_shrinkage_regularization_strength_tensor = ops.convert_to_tensor(
self._l2_shrinkage_regularization_strength,
name="l2_shrinkage_regularization_strength")
@@ -162,7 +171,7 @@
math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
math_ops.cast(self._l1_regularization_strength_tensor,
var.dtype.base_dtype),
- math_ops.cast(self._l2_regularization_strength_tensor,
+ math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
var.dtype.base_dtype),
math_ops.cast(self._learning_rate_power_tensor, var.dtype.base_dtype),
use_locking=self._use_locking)
@@ -175,7 +184,7 @@
math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
math_ops.cast(self._l1_regularization_strength_tensor,
var.dtype.base_dtype),
- math_ops.cast(self._l2_regularization_strength_tensor,
+ math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
var.dtype.base_dtype),
math_ops.cast(self._l2_shrinkage_regularization_strength_tensor,
var.dtype.base_dtype),
@@ -194,7 +203,7 @@
math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
math_ops.cast(self._l1_regularization_strength_tensor,
var.dtype.base_dtype),
- math_ops.cast(self._l2_regularization_strength_tensor,
+ math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
var.dtype.base_dtype),
math_ops.cast(self._learning_rate_power_tensor, var.dtype.base_dtype),
use_locking=self._use_locking)
@@ -207,7 +216,7 @@
math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
math_ops.cast(self._l1_regularization_strength_tensor,
var.dtype.base_dtype),
- math_ops.cast(self._l2_regularization_strength_tensor,
+ math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
var.dtype.base_dtype),
math_ops.cast(self._l2_shrinkage_regularization_strength_tensor,
var.dtype.base_dtype),
@@ -227,7 +236,7 @@
math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
math_ops.cast(self._l1_regularization_strength_tensor,
var.dtype.base_dtype),
- math_ops.cast(self._l2_regularization_strength_tensor,
+ math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
var.dtype.base_dtype),
math_ops.cast(self._learning_rate_power_tensor, var.dtype.base_dtype),
use_locking=self._use_locking)
@@ -241,7 +250,7 @@
math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
math_ops.cast(self._l1_regularization_strength_tensor,
var.dtype.base_dtype),
- math_ops.cast(self._l2_regularization_strength_tensor,
+ math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
var.dtype.base_dtype),
math_ops.cast(self._l2_shrinkage_regularization_strength_tensor,
grad.dtype.base_dtype),
@@ -260,7 +269,8 @@
indices,
math_ops.cast(self._learning_rate_tensor, grad.dtype),
math_ops.cast(self._l1_regularization_strength_tensor, grad.dtype),
- math_ops.cast(self._l2_regularization_strength_tensor, grad.dtype),
+ math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
+ grad.dtype),
math_ops.cast(self._learning_rate_power_tensor, grad.dtype),
use_locking=self._use_locking)
else:
@@ -272,7 +282,8 @@
indices,
math_ops.cast(self._learning_rate_tensor, grad.dtype),
math_ops.cast(self._l1_regularization_strength_tensor, grad.dtype),
- math_ops.cast(self._l2_regularization_strength_tensor, grad.dtype),
+ math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
+ grad.dtype),
math_ops.cast(self._l2_shrinkage_regularization_strength_tensor,
grad.dtype),
math_ops.cast(self._learning_rate_power_tensor, grad.dtype),
diff --git a/tensorflow/python/training/ftrl_test.py b/tensorflow/python/training/ftrl_test.py
index f0cbe13..ff1bf17 100644
--- a/tensorflow/python/training/ftrl_test.py
+++ b/tensorflow/python/training/ftrl_test.py
@@ -161,6 +161,65 @@
self.assertAllCloseAccordingToType(
np.array([-0.93460727, -1.86147261]), v1_val)
+ def testFtrlWithBeta(self):
+ # The v1 optimizers do not support eager execution
+ with ops.Graph().as_default():
+ for dtype in [dtypes.half, dtypes.float32]:
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([4.0, 3.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.02], dtype=dtype)
+
+ opt = ftrl.FtrlOptimizer(3.0, initial_accumulator_value=0.1, beta=0.1)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ self.evaluate(variables.global_variables_initializer())
+
+ v0_val, v1_val = self.evaluate([var0, var1])
+ self.assertAllCloseAccordingToType([1.0, 2.0], v0_val)
+ self.assertAllCloseAccordingToType([4.0, 3.0], v1_val)
+
+ # Run 10 steps FTRL
+ for _ in range(10):
+ update.run()
+ v0_val, v1_val = self.evaluate([var0, var1])
+ self.assertAllCloseAccordingToType(
+ np.array([-6.096838, -9.162214]), v0_val)
+ self.assertAllCloseAccordingToType(
+ np.array([-0.717741, -1.425132]), v1_val)
+
+ def testFtrlWithL2_Beta(self):
+ # The v1 optimizers do not support eager execution
+ with ops.Graph().as_default():
+ for dtype in [dtypes.half, dtypes.float32]:
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([4.0, 3.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.02], dtype=dtype)
+
+ opt = ftrl.FtrlOptimizer(
+ 3.0,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.0,
+ l2_regularization_strength=0.1,
+ beta=0.1)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ self.evaluate(variables.global_variables_initializer())
+
+ v0_val, v1_val = self.evaluate([var0, var1])
+ self.assertAllCloseAccordingToType([1.0, 2.0], v0_val)
+ self.assertAllCloseAccordingToType([4.0, 3.0], v1_val)
+
+ # Run 10 steps FTRL
+ for _ in range(10):
+ update.run()
+ v0_val, v1_val = self.evaluate([var0, var1])
+ self.assertAllCloseAccordingToType(
+ np.array([-2.735487, -4.704625]), v0_val)
+ self.assertAllCloseAccordingToType(
+ np.array([-0.294335, -0.586556]), v1_val)
+
def testFtrlWithL1_L2(self):
# The v1 optimizers do not support eager execution
with ops.Graph().as_default():
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.-ftrl-optimizer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-ftrl-optimizer.pbtxt
index 1d1aceb..9e12ae9 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.train.-ftrl-optimizer.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.train.-ftrl-optimizer.pbtxt
@@ -18,7 +18,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'learning_rate\', \'learning_rate_power\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'use_locking\', \'name\', \'accum_name\', \'linear_name\', \'l2_shrinkage_regularization_strength\'], varargs=None, keywords=None, defaults=[\'-0.5\', \'0.1\', \'0.0\', \'0.0\', \'False\', \'Ftrl\', \'None\', \'None\', \'0.0\'], "
+ argspec: "args=[\'self\', \'learning_rate\', \'learning_rate_power\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'use_locking\', \'name\', \'accum_name\', \'linear_name\', \'l2_shrinkage_regularization_strength\', \'beta\'], varargs=None, keywords=None, defaults=[\'-0.5\', \'0.1\', \'0.0\', \'0.0\', \'False\', \'Ftrl\', \'None\', \'None\', \'0.0\', \'None\'], "
}
member_method {
name: "apply_gradients"