Only use fused avg mean and variance updates in eager mode until we fix a few modules that depend on the legacy behavior.

PiperOrigin-RevId: 300392015
Change-Id: I3c5ed0da92a58fc6eca594e47768049ce837b9fa
diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index 673ea98..f291da0 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -20,6 +20,7 @@
 
 from tensorflow.python.compat import compat
 from tensorflow.python.distribute import distribution_strategy_context
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -540,7 +541,12 @@
     else:
       inputs_size = None
 
-    if compat.forward_compatible(2020, 3, 6):
+    # TODO(rmlarsen): Support using fused avg updates for non-eager execution
+    # after fixing graph pattern matching and enabling fused_batch_norm to
+    # take exponential_avg_factor as a tensor input.
+    use_fused_avg_updates = (compat.forward_compatible(2020, 3, 6) and
+                             context.executing_eagerly())
+    if use_fused_avg_updates:
       exponential_avg_factor = 1.0 - self.momentum
     else:
       exponential_avg_factor = None
@@ -591,7 +597,7 @@
           data_format=self._data_format)
 
     train_op = _fused_batch_norm_training
-    if compat.forward_compatible(2020, 3, 6) and inputs_size is not None:
+    if use_fused_avg_updates and inputs_size is not None:
       train_op = lambda: tf_utils.smart_cond(inputs_size > 0,
                                              _fused_batch_norm_training,
                                              _fused_batch_norm_training_empty)
@@ -602,7 +608,7 @@
 
     training_value = tf_utils.constant_value(training)
     if training_value or training_value is None:
-      if not compat.forward_compatible(2020, 3, 6):
+      if not use_fused_avg_updates:
         if training_value is None:
           momentum = tf_utils.smart_cond(training, lambda: self.momentum,
                                          lambda: 1.0)
@@ -611,7 +617,7 @@
 
       def mean_update():
         """Update self.moving_mean with the most recent data point."""
-        if compat.forward_compatible(2020, 3, 6):
+        if use_fused_avg_updates:
           return self._assign_new_value(self.moving_mean, mean)
         else:
           return self._assign_moving_average(self.moving_mean, mean, momentum,
@@ -619,7 +625,7 @@
 
       def variance_update():
         """Update self.moving_variance with the most recent data point."""
-        if compat.forward_compatible(2020, 3, 6):
+        if use_fused_avg_updates:
           return self._assign_new_value(self.moving_variance, variance)
         else:
           return self._assign_moving_average(self.moving_variance, variance,