Move keras related summary_ops_test to keras/tests.

PiperOrigin-RevId: 306560299
Change-Id: I100c3e23973276bc4f395c46b08705c0f277fb3b
diff --git a/tensorflow/python/keras/layers/BUILD b/tensorflow/python/keras/layers/BUILD
index 15a6f6b..87589af 100644
--- a/tensorflow/python/keras/layers/BUILD
+++ b/tensorflow/python/keras/layers/BUILD
@@ -10,7 +10,6 @@
         "//tensorflow/python/distribute:__pkg__",
         "//tensorflow/python/feature_column:__pkg__",
         "//tensorflow/python/keras:__subpackages__",
-        "//tensorflow/python/kernel_tests:__pkg__",
         "//tensorflow/python/training/tracking:__pkg__",
         "//tensorflow/tools/pip_package:__pkg__",
     ],
diff --git a/tensorflow/python/keras/tests/BUILD b/tensorflow/python/keras/tests/BUILD
index 79d82ae..16d3d84 100644
--- a/tensorflow/python/keras/tests/BUILD
+++ b/tensorflow/python/keras/tests/BUILD
@@ -245,6 +245,22 @@
     ],
 )
 
+cuda_py_test(
+    name = "summary_ops_test",
+    size = "small",
+    srcs = ["summary_ops_test.py"],
+    deps = [
+        "//tensorflow/core:protos_all_py",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:lib",
+        "//tensorflow/python:platform",
+        "//tensorflow/python:summary_ops_v2",
+        "//tensorflow/python/keras/engine",
+        "//tensorflow/python/keras/layers:core",
+    ],
+)
+
 tf_py_test(
     name = "temporal_sample_weights_correctness_test",
     srcs = ["temporal_sample_weights_correctness_test.py"],
diff --git a/tensorflow/python/keras/tests/summary_ops_test.py b/tensorflow/python/keras/tests/summary_ops_test.py
new file mode 100644
index 0000000..a62abdc
--- /dev/null
+++ b/tensorflow/python/keras/tests/summary_ops_test.py
@@ -0,0 +1,147 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for V2 summary ops from summary_ops_v2."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.core.util import event_pb2
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras.engine.sequential import Sequential
+from tensorflow.python.keras.engine.training import Model
+from tensorflow.python.keras.layers.core import Activation
+from tensorflow.python.keras.layers.core import Dense
+from tensorflow.python.lib.io import tf_record
+from tensorflow.python.ops import summary_ops_v2 as summary_ops
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging as logging
+
+
+class SummaryOpsTest(test_util.TensorFlowTestCase):
+
+  def tearDown(self):
+    super(SummaryOpsTest, self).tearDown()
+    summary_ops.trace_off()
+
+  def keras_model(self, *args, **kwargs):
+    logdir = self.get_temp_dir()
+    writer = summary_ops.create_file_writer(logdir)
+    with writer.as_default():
+      summary_ops.keras_model(*args, **kwargs)
+    writer.close()
+    events = events_from_logdir(logdir)
+    # The first event contains no summary values. The written content goes to
+    # the second event.
+    return events[1]
+
+  @test_util.run_v2_only
+  def testKerasModel(self):
+    model = Sequential(
+        [Dense(10, input_shape=(100,)),
+         Activation('relu', name='my_relu')])
+    event = self.keras_model(name='my_name', data=model, step=1)
+    first_val = event.summary.value[0]
+    self.assertEqual(model.to_json(), first_val.tensor.string_val[0].decode())
+
+  @test_util.run_v2_only
+  def testKerasModel_usesDefaultStep(self):
+    model = Sequential(
+        [Dense(10, input_shape=(100,)),
+         Activation('relu', name='my_relu')])
+    try:
+      summary_ops.set_step(42)
+      event = self.keras_model(name='my_name', data=model)
+      self.assertEqual(42, event.step)
+    finally:
+      # Reset to default state for other tests.
+      summary_ops.set_step(None)
+
+  @test_util.run_v2_only
+  def testKerasModel_subclass(self):
+
+    class SimpleSubclass(Model):
+
+      def __init__(self):
+        super(SimpleSubclass, self).__init__(name='subclass')
+        self.dense = Dense(10, input_shape=(100,))
+        self.activation = Activation('relu', name='my_relu')
+
+      def call(self, inputs):
+        x = self.dense(inputs)
+        return self.activation(x)
+
+    model = SimpleSubclass()
+    with test.mock.patch.object(logging, 'warn') as mock_log:
+      self.assertFalse(
+          summary_ops.keras_model(name='my_name', data=model, step=1))
+      self.assertRegexpMatches(
+          str(mock_log.call_args), 'Model failed to serialize as JSON.')
+
+  @test_util.run_v2_only
+  def testKerasModel_otherExceptions(self):
+    model = Sequential()
+
+    with test.mock.patch.object(model, 'to_json') as mock_to_json:
+      with test.mock.patch.object(logging, 'warn') as mock_log:
+        mock_to_json.side_effect = Exception('oops')
+        self.assertFalse(
+            summary_ops.keras_model(name='my_name', data=model, step=1))
+        self.assertRegexpMatches(
+            str(mock_log.call_args),
+            'Model failed to serialize as JSON. Ignoring... oops')
+
+
+def events_from_file(filepath):
+  """Returns all events in a single event file.
+
+  Args:
+    filepath: Path to the event file.
+
+  Returns:
+    A list of all tf.Event protos in the event file.
+  """
+  records = list(tf_record.tf_record_iterator(filepath))
+  result = []
+  for r in records:
+    event = event_pb2.Event()
+    event.ParseFromString(r)
+    result.append(event)
+  return result
+
+
+def events_from_logdir(logdir):
+  """Returns all events in the single eventfile in logdir.
+
+  Args:
+    logdir: The directory in which the single event file is sought.
+
+  Returns:
+    A list of all tf.Event protos from the single event file.
+
+  Raises:
+    AssertionError: If logdir does not contain exactly one file.
+  """
+  assert gfile.Exists(logdir)
+  files = gfile.ListDirectory(logdir)
+  assert len(files) == 1, 'Found not exactly one file in logdir: %s' % files
+  return events_from_file(os.path.join(logdir, files[0]))
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 7efb0b8..7d32085 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -1215,8 +1215,6 @@
         "//tensorflow/python:variables",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:function",
-        "//tensorflow/python/keras:engine",
-        "//tensorflow/python/keras/layers",
         "@six_archive//:six",
     ],
 )
diff --git a/tensorflow/python/kernel_tests/summary_ops_test.py b/tensorflow/python/kernel_tests/summary_ops_test.py
index 8144c007..387083c 100644
--- a/tensorflow/python/kernel_tests/summary_ops_test.py
+++ b/tensorflow/python/kernel_tests/summary_ops_test.py
@@ -38,10 +38,6 @@
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.framework import test_util
-from tensorflow.python.keras.engine.sequential import Sequential
-from tensorflow.python.keras.engine.training import Model
-from tensorflow.python.keras.layers.core import Activation
-from tensorflow.python.keras.layers.core import Dense
 from tensorflow.python.lib.io import tf_record
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import summary_ops_v2 as summary_ops
@@ -915,17 +911,6 @@
         ],
         step_stats=step_stats)
 
-  def keras_model(self, *args, **kwargs):
-    logdir = self.get_temp_dir()
-    writer = summary_ops.create_file_writer(logdir)
-    with writer.as_default():
-      summary_ops.keras_model(*args, **kwargs)
-    writer.close()
-    events = events_from_logdir(logdir)
-    # The first event contains no summary values. The written content goes to
-    # the second event.
-    return events[1]
-
   def run_trace(self, f, step=1):
     assert context.executing_eagerly()
     logdir = self.get_temp_dir()
@@ -1054,62 +1039,6 @@
       summary_ops.set_step(None)
 
   @test_util.run_v2_only
-  def testKerasModel(self):
-    model = Sequential(
-        [Dense(10, input_shape=(100,)),
-         Activation('relu', name='my_relu')])
-    event = self.keras_model(name='my_name', data=model, step=1)
-    first_val = event.summary.value[0]
-    self.assertEqual(model.to_json(), first_val.tensor.string_val[0].decode())
-
-  @test_util.run_v2_only
-  def testKerasModel_usesDefaultStep(self):
-    model = Sequential(
-        [Dense(10, input_shape=(100,)),
-         Activation('relu', name='my_relu')])
-    try:
-      summary_ops.set_step(42)
-      event = self.keras_model(name='my_name', data=model)
-      self.assertEqual(42, event.step)
-    finally:
-      # Reset to default state for other tests.
-      summary_ops.set_step(None)
-
-  @test_util.run_v2_only
-  def testKerasModel_subclass(self):
-
-    class SimpleSubclass(Model):
-
-      def __init__(self):
-        super(SimpleSubclass, self).__init__(name='subclass')
-        self.dense = Dense(10, input_shape=(100,))
-        self.activation = Activation('relu', name='my_relu')
-
-      def call(self, inputs):
-        x = self.dense(inputs)
-        return self.activation(x)
-
-    model = SimpleSubclass()
-    with test.mock.patch.object(logging, 'warn') as mock_log:
-      self.assertFalse(
-          summary_ops.keras_model(name='my_name', data=model, step=1))
-      self.assertRegexpMatches(
-          str(mock_log.call_args), 'Model failed to serialize as JSON.')
-
-  @test_util.run_v2_only
-  def testKerasModel_otherExceptions(self):
-    model = Sequential()
-
-    with test.mock.patch.object(model, 'to_json') as mock_to_json:
-      with test.mock.patch.object(logging, 'warn') as mock_log:
-        mock_to_json.side_effect = Exception('oops')
-        self.assertFalse(
-            summary_ops.keras_model(name='my_name', data=model, step=1))
-        self.assertRegexpMatches(
-            str(mock_log.call_args),
-            'Model failed to serialize as JSON. Ignoring... oops')
-
-  @test_util.run_v2_only
   def testTrace(self):
 
     @def_function.function