| # Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Tests for the summary FileWriter and FileWriterCache classes.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import glob |
| import os.path |
| import shutil |
| import time |
| |
| from tensorflow.core.framework import graph_pb2 |
| from tensorflow.core.framework import summary_pb2 |
| from tensorflow.core.protobuf import config_pb2 |
| from tensorflow.core.protobuf import meta_graph_pb2 |
| from tensorflow.core.util import event_pb2 |
| from tensorflow.core.util.event_pb2 import SessionLog |
| from tensorflow.python.client import session |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import meta_graph |
| from tensorflow.python.framework import ops |
| from tensorflow.python.ops import summary_ops_v2 |
| from tensorflow.python.platform import gfile |
| from tensorflow.python.platform import test |
| from tensorflow.python.summary import plugin_asset |
| from tensorflow.python.summary import summary_iterator |
| from tensorflow.python.summary.writer import writer |
| from tensorflow.python.summary.writer import writer_cache |
| from tensorflow.python.util import compat |
| |
| |
class FileWriterTestCase(test.TestCase):
  """Tests `writer.FileWriter` by writing events and reading them back.

  Writers are always obtained through `_FileWriter()` so that a subclass can
  re-run every test against a differently constructed writer.
  """

  def _FileWriter(self, *args, **kwargs):
    """Returns the writer under test, forwarding all arguments unchanged."""
    return writer.FileWriter(*args, **kwargs)

  def _TestDir(self, test_name):
    """Returns the path of `test_name` under this test's temp directory."""
    return os.path.join(self.get_temp_dir(), test_name)

  def _CleanTestDir(self, test_name):
    """Returns the `test_name` directory path, deleting any existing copy.

    The directory itself is not created here.
    """
    test_dir = self._TestDir(test_name)
    if os.path.exists(test_dir):
      shutil.rmtree(test_dir)
    return test_dir

  def _EventsReader(self, test_dir):
    """Returns an event iterator over the last event file in `test_dir`.

    Args:
      test_dir: Directory expected to contain at least one "event*" file.

    Returns:
      The iterator produced by `summary_iterator.summary_iterator`.
    """
    # If the test runs multiple times in the same directory we can have more
    # than one matching event file, and we only want to read the last one.
    # glob() returns matches in arbitrary order, so sort them first; event file
    # names carry a timestamp (see testCloseAndReopen), which makes the sorted
    # order meaningful.
    event_paths = sorted(glob.glob(os.path.join(test_dir, "event*")))
    self.assertTrue(event_paths)
    return summary_iterator.summary_iterator(event_paths[-1])

  def _assertRecent(self, t):
    """Asserts that timestamp `t` is within 5 seconds of the current time."""
    self.assertLess(abs(t - time.time()), 5)

  def _assertFileVersionEvent(self, ev):
    """Asserts that `ev` is a recent event carrying the file version."""
    self._assertRecent(ev.wall_time)
    self.assertEqual("brain.Event:2", ev.file_version)

  def _assertSessionStartEvent(self, ev, step):
    """Asserts that `ev` is a recent SessionLog.START event at `step`."""
    self._assertRecent(ev.wall_time)
    self.assertEqual(step, ev.step)
    self.assertEqual(SessionLog.START, ev.session_log.status)

  def _assertNoMoreEvents(self, rr):
    """Asserts that the event iterator `rr` is exhausted."""
    self.assertRaises(StopIteration, next, rr)

  def _assertEventsWithGraph(self, test_dir, g, has_shapes):
    """Asserts the event file holds exactly version, graph and metagraph.

    Args:
      test_dir: Directory containing the event file to read.
      g: The graph that was handed to the writer.
      has_shapes: Whether the expected GraphDef is built with `add_shapes`.
    """
    expected_graph_def = g.as_graph_def(add_shapes=has_shapes)
    meta_graph_def = meta_graph.create_meta_graph_def(
        graph_def=expected_graph_def)

    rr = self._EventsReader(test_dir)

    # The first event should list the file_version.
    self._assertFileVersionEvent(next(rr))

    # The next event should have the graph.
    ev = next(rr)
    self._assertRecent(ev.wall_time)
    self.assertEqual(0, ev.step)
    ev_graph = graph_pb2.GraphDef()
    ev_graph.ParseFromString(ev.graph_def)
    self.assertProtoEquals(expected_graph_def, ev_graph)

    # The next event should have the metagraph.
    ev = next(rr)
    self._assertRecent(ev.wall_time)
    self.assertEqual(0, ev.step)
    ev_meta_graph = meta_graph_pb2.MetaGraphDef()
    ev_meta_graph.ParseFromString(ev.meta_graph_def)
    self.assertProtoEquals(meta_graph_def, ev_meta_graph)

    # We should be done.
    self._assertNoMoreEvents(rr)

  def testAddingSummaryGraphAndRunMetadata(self):
    # Writes a session log, two summaries, a graph and run metadata, then
    # checks that they are read back in that order with the steps given.
    test_dir = self._CleanTestDir("basics")
    sw = self._FileWriter(test_dir)

    sw.add_session_log(event_pb2.SessionLog(status=SessionLog.START), 1)
    sw.add_summary(
        summary_pb2.Summary(
            value=[summary_pb2.Summary.Value(tag="mee", simple_value=10.0)]),
        10)
    sw.add_summary(
        summary_pb2.Summary(
            value=[summary_pb2.Summary.Value(tag="boo", simple_value=20.0)]),
        20)
    with ops.Graph().as_default() as g:
      constant_op.constant([0], name="zero")
    sw.add_graph(g, global_step=30)

    run_metadata = config_pb2.RunMetadata()
    device_stats = run_metadata.step_stats.dev_stats.add()
    device_stats.device = "test"
    sw.add_run_metadata(run_metadata, "test run", global_step=40)
    sw.close()
    rr = self._EventsReader(test_dir)

    # The first event should list the file_version.
    self._assertFileVersionEvent(next(rr))

    # The next event should be the START message.
    self._assertSessionStartEvent(next(rr), 1)

    # The next event should have the value 'mee=10.0'.
    ev = next(rr)
    self._assertRecent(ev.wall_time)
    self.assertEqual(10, ev.step)
    self.assertProtoEquals("""
      value { tag: 'mee' simple_value: 10.0 }
    """, ev.summary)

    # The next event should have the value 'boo=20.0'.
    ev = next(rr)
    self._assertRecent(ev.wall_time)
    self.assertEqual(20, ev.step)
    self.assertProtoEquals("""
      value { tag: 'boo' simple_value: 20.0 }
    """, ev.summary)

    # The next event should have the graph_def.
    ev = next(rr)
    self._assertRecent(ev.wall_time)
    self.assertEqual(30, ev.step)
    ev_graph = graph_pb2.GraphDef()
    ev_graph.ParseFromString(ev.graph_def)
    self.assertProtoEquals(g.as_graph_def(add_shapes=True), ev_graph)

    # The next event should have metadata for the run.
    ev = next(rr)
    self._assertRecent(ev.wall_time)
    self.assertEqual(40, ev.step)
    self.assertEqual("test run", ev.tagged_run_metadata.tag)
    parsed_run_metadata = config_pb2.RunMetadata()
    parsed_run_metadata.ParseFromString(ev.tagged_run_metadata.run_metadata)
    self.assertProtoEquals(run_metadata, parsed_run_metadata)

    # We should be done.
    self._assertNoMoreEvents(rr)

  def testGraphAsNamed(self):
    test_dir = self._CleanTestDir("basics_named_graph")
    with ops.Graph().as_default() as g:
      constant_op.constant([12], name="douze")
    sw = self._FileWriter(test_dir, graph=g)
    sw.close()
    self._assertEventsWithGraph(test_dir, g, True)

  def testGraphAsPositional(self):
    test_dir = self._CleanTestDir("basics_positional_graph")
    with ops.Graph().as_default() as g:
      constant_op.constant([12], name="douze")
    sw = self._FileWriter(test_dir, g)
    sw.close()
    self._assertEventsWithGraph(test_dir, g, True)

  def testGraphDefAsNamed(self):
    test_dir = self._CleanTestDir("basics_named_graph_def")
    with ops.Graph().as_default() as g:
      constant_op.constant([12], name="douze")
    gd = g.as_graph_def()
    sw = self._FileWriter(test_dir, graph_def=gd)
    sw.close()
    self._assertEventsWithGraph(test_dir, g, False)

  def testGraphDefAsPositional(self):
    test_dir = self._CleanTestDir("basics_positional_graph_def")
    with ops.Graph().as_default() as g:
      constant_op.constant([12], name="douze")
    gd = g.as_graph_def()
    sw = self._FileWriter(test_dir, gd)
    sw.close()
    self._assertEventsWithGraph(test_dir, g, False)

  def testGraphAndGraphDef(self):
    test_dir = self._CleanTestDir("basics_graph_and_graph_def")
    with ops.Graph().as_default() as g:
      constant_op.constant([12], name="douze")
    gd = g.as_graph_def()
    # Only the construction is expected to raise; keeping the setup outside
    # the assertRaises block stops an unrelated ValueError from passing the
    # test.
    with self.assertRaises(ValueError):
      self._FileWriter(test_dir, graph=g, graph_def=gd)

  def testNeitherGraphNorGraphDef(self):
    test_dir = self._CleanTestDir("basics_string_instead_of_graph")
    # Only the construction is expected to raise (see testGraphAndGraphDef).
    with self.assertRaises(TypeError):
      self._FileWriter(test_dir, "string instead of graph object")

  def testCloseAndReopen(self):
    test_dir = self._CleanTestDir("close_and_reopen")
    sw = self._FileWriter(test_dir)
    sw.add_session_log(event_pb2.SessionLog(status=SessionLog.START), 1)
    sw.close()
    # Sleep at least one second to make sure we get a new event file name.
    time.sleep(1.2)
    sw.reopen()
    sw.add_session_log(event_pb2.SessionLog(status=SessionLog.START), 2)
    sw.close()

    # We should now have 2 events files.
    event_paths = sorted(glob.glob(os.path.join(test_dir, "event*")))
    self.assertEqual(2, len(event_paths))

    # Each file should hold the file_version event followed by the START
    # message written before the corresponding close(), and nothing else.
    for event_path, expected_step in zip(event_paths, (1, 2)):
      rr = summary_iterator.summary_iterator(event_path)
      self._assertFileVersionEvent(next(rr))
      self._assertSessionStartEvent(next(rr), expected_step)
      self._assertNoMoreEvents(rr)

  def testNonBlockingClose(self):
    test_dir = self._CleanTestDir("non_blocking_close")
    sw = self._FileWriter(test_dir)
    # Sleep 1.2 seconds to make sure event queue is empty.
    time.sleep(1.2)
    time_before_close = time.time()
    sw.close()
    # close() must have returned within the _assertRecent() window.
    self._assertRecent(time_before_close)

  def testWithStatement(self):
    test_dir = self._CleanTestDir("with_statement")
    with self._FileWriter(test_dir) as sw:
      sw.add_session_log(event_pb2.SessionLog(status=SessionLog.START), 1)
    event_paths = sorted(glob.glob(os.path.join(test_dir, "event*")))
    self.assertEqual(1, len(event_paths))

  # Checks that values returned from session Run() calls are added correctly to
  # summaries. These are numpy types so we need to check they fit in the
  # protocol buffers correctly.
  def testAddingSummariesFromSessionRunCalls(self):
    test_dir = self._CleanTestDir("global_step")
    sw = self._FileWriter(test_dir)
    with self.cached_session():
      step_i = constant_op.constant(1, dtype=dtypes.int32, shape=[])
      step_l = constant_op.constant(2, dtype=dtypes.int64, shape=[])
      # Test the summary can be passed serialized.
      summ = summary_pb2.Summary(
          value=[summary_pb2.Summary.Value(tag="i", simple_value=1.0)])
      sw.add_summary(summ.SerializeToString(), step_i.eval())
      sw.add_summary(
          summary_pb2.Summary(
              value=[summary_pb2.Summary.Value(tag="l", simple_value=2.0)]),
          step_l.eval())
      sw.close()

    rr = self._EventsReader(test_dir)

    # File_version.
    ev = next(rr)
    self.assertTrue(ev)
    self._assertFileVersionEvent(ev)

    # Summary passed serialized.
    ev = next(rr)
    self.assertTrue(ev)
    self._assertRecent(ev.wall_time)
    self.assertEqual(1, ev.step)
    self.assertProtoEquals("""
      value { tag: 'i' simple_value: 1.0 }
    """, ev.summary)

    # Summary passed as SummaryObject.
    ev = next(rr)
    self.assertTrue(ev)
    self._assertRecent(ev.wall_time)
    self.assertEqual(2, ev.step)
    self.assertProtoEquals("""
      value { tag: 'l' simple_value: 2.0 }
    """, ev.summary)

    # We should be done.
    self._assertNoMoreEvents(rr)

  def testPluginMetadataStrippedFromSubsequentEvents(self):
    # Uses its own directory name rather than sharing "basics" with
    # testAddingSummaryGraphAndRunMetadata.
    test_dir = self._CleanTestDir("plugin_metadata_stripped")
    sw = self._FileWriter(test_dir)

    sw.add_session_log(event_pb2.SessionLog(status=SessionLog.START), 1)

    # We add 2 summaries with the same tags. They both have metadata. The writer
    # should strip the metadata from the second one.
    for _ in range(2):
      value = summary_pb2.Summary.Value(tag="foo", simple_value=10.0)
      value.metadata.plugin_data.plugin_name = "bar"
      value.metadata.plugin_data.content = compat.as_bytes("... content ...")
      sw.add_summary(summary_pb2.Summary(value=[value]), 10)

    sw.close()
    rr = self._EventsReader(test_dir)

    # The first event should list the file_version.
    self._assertFileVersionEvent(next(rr))

    # The next event should be the START message.
    self._assertSessionStartEvent(next(rr), 1)

    # This is the first event with tag foo. It should contain SummaryMetadata.
    ev = next(rr)
    self.assertProtoEquals("""
      value {
        tag: "foo"
        simple_value: 10.0
        metadata {
          plugin_data {
            plugin_name: "bar"
            content: "... content ..."
          }
        }
      }
    """, ev.summary)

    # This is the second event with tag foo. It should lack SummaryMetadata
    # because the file writer should have stripped it.
    ev = next(rr)
    self.assertProtoEquals("""
      value {
        tag: "foo"
        simple_value: 10.0
      }
    """, ev.summary)

    # We should be done.
    self._assertNoMoreEvents(rr)

  def testFileWriterWithSuffix(self):
    test_dir = self._CleanTestDir("test_suffix")
    sw = self._FileWriter(test_dir, filename_suffix="_test_suffix")
    for _ in range(10):
      sw.add_summary(
          summary_pb2.Summary(value=[
              summary_pb2.Summary.Value(tag="float_ten", simple_value=10.0)
          ]),
          10)
    sw.close()
    sw.reopen()
    sw.close()
    event_filenames = glob.glob(os.path.join(test_dir, "event*"))
    # Guard against the loop below passing vacuously when nothing was written.
    self.assertTrue(event_filenames)
    for filename in event_filenames:
      self.assertTrue(filename.endswith("_test_suffix"))

  def testPluginAssetSerialized(self):
    class ExamplePluginAsset(plugin_asset.PluginAsset):
      """Plugin asset exposing two small text files."""
      plugin_name = "example"

      def assets(self):
        return {"foo.txt": "foo!", "bar.txt": "bar!"}

    with ops.Graph().as_default() as g:
      plugin_asset.get_plugin_asset(ExamplePluginAsset)

      logdir = self.get_temp_dir()
      fw = self._FileWriter(logdir)
      fw.add_graph(g)
      # Close the writer so it is not left open after the test; the original
      # test never released it.
      fw.close()
    plugin_dir = os.path.join(logdir, writer._PLUGINS_DIR, "example")

    with gfile.Open(os.path.join(plugin_dir, "foo.txt"), "r") as f:
      content = f.read()
    self.assertEqual(content, "foo!")

    with gfile.Open(os.path.join(plugin_dir, "bar.txt"), "r") as f:
      content = f.read()
    self.assertEqual(content, "bar!")
| |
| |
class SessionBasedFileWriterTestCase(FileWriterTestCase):
  """Re-runs the FileWriter tests with a Session supplied to the writer."""

  def _FileWriter(self, *args, **kwargs):
    """Returns a FileWriter, supplying a `session` kwarg when none is given."""
    if "session" in kwargs:
      return writer.FileWriter(*args, **kwargs)
    # Hand the writer the cached test session. It is cached for the duration
    # of this test method, so any other use of cached_session() with no graph
    # should result in re-using the same underlying Session.
    with self.cached_session() as sess:
      kwargs["session"] = sess
      return writer.FileWriter(*args, **kwargs)

  def _createTaggedSummary(self, tag):
    """Returns a Summary holding a single value that carries only `tag`."""
    return summary_pb2.Summary(value=[summary_pb2.Summary.Value(tag=tag)])

  def _assertEventFileTags(self, event_path, expected_tags):
    """Asserts a file holds the version event, then exactly these tags."""
    events = summary_iterator.summary_iterator(event_path)
    self.assertEqual("brain.Event:2", next(events).file_version)
    for expected_tag in expected_tags:
      self.assertEqual(expected_tag, next(events).summary.value[0].tag)
    self.assertRaises(StopIteration, next, events)

  def testSharing_withOtherSessionBasedFileWriters(self):
    logdir = self.get_temp_dir()
    other_logdir = logdir + "-other"
    with session.Session() as sess:
      # Initial file writer
      first = writer.FileWriter(session=sess, logdir=logdir)
      first.add_summary(self._createTaggedSummary("one"), 1)
      first.flush()

      # File writer, should share file with the initial writer
      second = writer.FileWriter(session=sess, logdir=logdir)
      second.add_summary(self._createTaggedSummary("two"), 2)
      second.flush()

      # File writer with different logdir (shouldn't be in this logdir at all)
      third = writer.FileWriter(session=sess, logdir=other_logdir)
      third.add_summary(self._createTaggedSummary("three"), 3)
      third.flush()

      # File writer in a different session (should be in separate file)
      time.sleep(1.1)  # Ensure filename has a different timestamp
      with session.Session() as other_sess:
        fourth = writer.FileWriter(session=other_sess, logdir=logdir)
        fourth.add_summary(self._createTaggedSummary("four"), 4)
        fourth.flush()

      # One more file writer, should share file with the initial writer
      fifth = writer.FileWriter(session=sess, logdir=logdir)
      fifth.add_summary(self._createTaggedSummary("five"), 5)
      fifth.flush()

    event_paths = iter(sorted(glob.glob(os.path.join(logdir, "event*"))))

    # First file should have tags "one", "two", and "five"
    self._assertEventFileTags(next(event_paths), ["one", "two", "five"])

    # Second file should have just "four"
    self._assertEventFileTags(next(event_paths), ["four"])

    # No more files
    self.assertRaises(StopIteration, next, event_paths)

    # Just check that the other logdir file exists to be sure we wrote it
    self.assertTrue(glob.glob(os.path.join(other_logdir, "event*")))

  def testSharing_withExplicitSummaryFileWriters(self):
    logdir = self.get_temp_dir()
    with session.Session() as sess:
      # Initial file writer via FileWriter(session=?)
      first = writer.FileWriter(session=sess, logdir=logdir)
      first.add_summary(self._createTaggedSummary("one"), 1)
      first.flush()

      # Next one via create_file_writer(), should use same file
      second = summary_ops_v2.create_file_writer(logdir=logdir)
      with summary_ops_v2.always_record_summaries(), second.as_default():
        second_summary = summary_ops_v2.scalar("two", 2.0, step=2)
      sess.run(second.init())
      sess.run(second_summary)
      sess.run(second.flush())

      # Next has different shared name, should be in separate file
      time.sleep(1.1)  # Ensure filename has a different timestamp
      third = summary_ops_v2.create_file_writer(logdir=logdir, name="other")
      with summary_ops_v2.always_record_summaries(), third.as_default():
        third_summary = summary_ops_v2.scalar("three", 3.0, step=3)
      sess.run(third.init())
      sess.run(third_summary)
      sess.run(third.flush())

      # Next uses a second session, should be in separate file
      time.sleep(1.1)  # Ensure filename has a different timestamp
      with session.Session() as other_sess:
        fourth = summary_ops_v2.create_file_writer(logdir=logdir)
        with summary_ops_v2.always_record_summaries(), fourth.as_default():
          fourth_summary = summary_ops_v2.scalar("four", 4.0, step=4)
        other_sess.run(fourth.init())
        other_sess.run(fourth_summary)
        other_sess.run(fourth.flush())

        # Next via FileWriter(session=?) uses same second session, should be in
        # same separate file. (This checks sharing in the other direction)
        fifth = writer.FileWriter(session=other_sess, logdir=logdir)
        fifth.add_summary(self._createTaggedSummary("five"), 5)
        fifth.flush()

      # One more via create_file_writer(), should use same file
      sixth = summary_ops_v2.create_file_writer(logdir=logdir)
      with summary_ops_v2.always_record_summaries(), sixth.as_default():
        sixth_summary = summary_ops_v2.scalar("six", 6.0, step=6)
      sess.run(sixth.init())
      sess.run(sixth_summary)
      sess.run(sixth.flush())

    event_paths = iter(sorted(glob.glob(os.path.join(logdir, "event*"))))

    # First file should have tags "one", "two", and "six"
    self._assertEventFileTags(next(event_paths), ["one", "two", "six"])

    # Second file should have just "three"
    self._assertEventFileTags(next(event_paths), ["three"])

    # Third file should have "four" and "five"
    self._assertEventFileTags(next(event_paths), ["four", "five"])

    # No more files
    self.assertRaises(StopIteration, next, event_paths)
| |
| |
class FileWriterCacheTest(test.TestCase):
  """FileWriterCache tests."""

  def _test_dir(self, test_name):
    """Create an empty dir to use for tests.

    Any previous directory of the same name is removed together with all of
    its contents (including subdirectories and hidden files) and recreated.

    Args:
      test_name: Name of the test.

    Returns:
      Absolute path to the test directory.
    """
    test_dir = os.path.join(self.get_temp_dir(), test_name)
    if os.path.isdir(test_dir):
      # os.remove() on glob results fails on subdirectories and skips
      # dotfiles; rmtree guarantees the directory really ends up empty.
      shutil.rmtree(test_dir)
    os.makedirs(test_dir)
    return test_dir

  def test_cache(self):
    # The cache must hand back the same writer for the same directory and a
    # different writer for a different directory.
    with ops.Graph().as_default():
      dir1 = self._test_dir("test_cache_1")
      dir2 = self._test_dir("test_cache_2")
      sw1 = writer_cache.FileWriterCache.get(dir1)
      sw2 = writer_cache.FileWriterCache.get(dir2)
      sw3 = writer_cache.FileWriterCache.get(dir1)
      self.assertEqual(sw1, sw3)
      self.assertNotEqual(sw1, sw2)
      sw1.close()
      sw2.close()
      events1 = glob.glob(os.path.join(dir1, "event*"))
      self.assertTrue(events1)
      events2 = glob.glob(os.path.join(dir2, "event*"))
      self.assertTrue(events2)
      # No writer was ever requested for "nowriter", so it has no event files.
      events3 = glob.glob(os.path.join("nowriter", "event*"))
      self.assertFalse(events3)

  def test_clear(self):
    # After clear(), asking for the same directory must yield a new writer.
    with ops.Graph().as_default():
      dir1 = self._test_dir("test_clear")
      sw1 = writer_cache.FileWriterCache.get(dir1)
      writer_cache.FileWriterCache.clear()
      sw2 = writer_cache.FileWriterCache.get(dir1)
      self.assertNotEqual(sw1, sw2)
      # Close the freshly created writer so it is not left open after the
      # test. NOTE(review): sw1 is presumably released by clear() -- confirm
      # against FileWriterCache.clear().
      sw2.close()
| |
| |
# When executed as a script, hand control to the TensorFlow test module's
# main entry point.
if __name__ == "__main__":
  test.main()