Only use the fused avg mean and variance update in eager mode until we fix a few modules that depend on legacy behavior.
PiperOrigin-RevId: 300392015
Change-Id: I3c5ed0da92a58fc6eca594e47768049ce837b9fa
diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index 673ea98..f291da0 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -20,6 +20,7 @@
from tensorflow.python.compat import compat
from tensorflow.python.distribute import distribution_strategy_context
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -540,7 +541,12 @@
else:
inputs_size = None
- if compat.forward_compatible(2020, 3, 6):
+ # TODO(rmlarsen): Support using fused avg updates for non-eager execution
+ # after fixing graph pattern matching and enabling fused_batch_norm to
+ # take exponential_avg_factor as a tensor input.
+ use_fused_avg_updates = (compat.forward_compatible(2020, 3, 6) and
+ context.executing_eagerly())
+ if use_fused_avg_updates:
exponential_avg_factor = 1.0 - self.momentum
else:
exponential_avg_factor = None
@@ -591,7 +597,7 @@
data_format=self._data_format)
train_op = _fused_batch_norm_training
- if compat.forward_compatible(2020, 3, 6) and inputs_size is not None:
+ if use_fused_avg_updates and inputs_size is not None:
train_op = lambda: tf_utils.smart_cond(inputs_size > 0,
_fused_batch_norm_training,
_fused_batch_norm_training_empty)
@@ -602,7 +608,7 @@
training_value = tf_utils.constant_value(training)
if training_value or training_value is None:
- if not compat.forward_compatible(2020, 3, 6):
+ if not use_fused_avg_updates:
if training_value is None:
momentum = tf_utils.smart_cond(training, lambda: self.momentum,
lambda: 1.0)
@@ -611,7 +617,7 @@
def mean_update():
"""Update self.moving_mean with the most recent data point."""
- if compat.forward_compatible(2020, 3, 6):
+ if use_fused_avg_updates:
return self._assign_new_value(self.moving_mean, mean)
else:
return self._assign_moving_average(self.moving_mean, mean, momentum,
@@ -619,7 +625,7 @@
def variance_update():
"""Update self.moving_variance with the most recent data point."""
- if compat.forward_compatible(2020, 3, 6):
+ if use_fused_avg_updates:
return self._assign_new_value(self.moving_variance, variance)
else:
return self._assign_moving_average(self.moving_variance, variance,