blob: a97ca34298de30a8b2622fa47ac8107996b8cb07 [file] [log] [blame]
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.tools.docs.generate2."""
import os
import pathlib
import shutil
import types
from unittest import mock
import tensorflow as tf
from tensorflow import estimator as tf_estimator
from tensorflow.python.platform import googletest
from tensorflow.tools.docs import generate2
# Make a mock tensorflow package that won't take too long to test.
fake_tf = types.ModuleType('FakeTensorFlow')
fake_tf.estimator = tf_estimator
fake_tf.keras = tf.keras
fake_tf.nn = tf.nn
fake_tf.summary = tf.summary
fake_tf.raw_ops = types.ModuleType('raw_ops')
fake_tf.raw_ops.Add = tf.raw_ops.Add
fake_tf.Module = tf.Module
fake_tf.__version__ = tf.__version__
for name in sorted(dir(tf.raw_ops))[:5]:
setattr(fake_tf.raw_ops, name, getattr(tf.raw_ops, name))
class Generate2Test(googletest.TestCase):
@mock.patch.object(generate2, 'tf', fake_tf)
def test_end_to_end(self):
generate2.MIN_NUM_FILES_EXPECTED = 1
output_dir = pathlib.Path(googletest.GetTempDir())/'output'
if os.path.exists(output_dir):
shutil.rmtree(output_dir)
os.makedirs(output_dir)
generate2.build_docs(
output_dir=output_dir,
code_url_prefix='',
search_hints=True,
)
raw_ops_page = (output_dir/'tf/raw_ops.md').read_text()
self.assertIn('/tf/raw_ops/Add.md', raw_ops_page)
if __name__ == '__main__':
googletest.main()