| # Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Generate tensorflow graphs for testing tfcompile.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import argparse |
| import os |
| import sys |
| |
| from tensorflow.core.protobuf import saver_pb2 |
| from tensorflow.python.client import session |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import function |
| from tensorflow.python.framework import ops |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import control_flow_ops |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops import variables |
| from tensorflow.python.platform import app |
| from tensorflow.python.training import saver as saver_lib |
| |
# Parsed command-line flags. Left as None at import time; the __main__ block
# rebinds it to the argparse namespace before app.run() invokes main().
FLAGS = None
| |
| |
def tfadd(_):
  """Builds a graph adding the constant vectors [1] and [2] into 'x_y_sum'.

  Args:
    _: Output directory passed by write_graph; unused by this builder.
  """
  # Arguments are evaluated left to right, so 'x_const' is created before
  # 'y_const', and both before the add op.
  math_ops.add(
      constant_op.constant([1], name='x_const'),
      constant_op.constant([2], name='y_const'),
      name='x_y_sum')
| |
| |
def tfadd_with_ckpt(out_dir):
  """Builds 'x_y_sum' = 'x_hold' + 'y_saved' and checkpoints 'y_saved' as [42].

  The variable 'y_saved' starts at [0] and is bumped by 42 in a session before
  a V1-format checkpoint is written.

  Args:
    out_dir: Directory that receives 'test_graph_tfadd_with_ckpt.ckpt'.
  """
  x = array_ops.placeholder(dtypes.int32, name='x_hold')
  y = variables.Variable(constant_op.constant([0]), name='y_saved')
  math_ops.add(x, y, name='x_y_sum')

  # initialize_all_variables() is deprecated; global_variables_initializer()
  # is its direct replacement.
  init_op = variables.global_variables_initializer()
  saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V1)
  with session.Session() as sess:
    sess.run(init_op)
    sess.run(y.assign(y + 42))
    # Without the checkpoint, the variable won't be set to 42.
    ckpt = os.path.join(out_dir, 'test_graph_tfadd_with_ckpt.ckpt')
    saver.save(sess, ckpt)
| |
| |
def tfadd_with_ckpt_saver(out_dir):
  """Like tfadd_with_ckpt, but with a named Saver whose SaverDef is also saved.

  Builds 'x_y_sum' = 'x_hold' + 'y_saved', sets 'y_saved' to [42] in a
  session, writes a V1 checkpoint, and serializes the SaverDef of a Saver
  named 'abcprefix'.

  Args:
    out_dir: Directory that receives 'test_graph_tfadd_with_ckpt_saver.ckpt'
      and 'test_graph_tfadd_with_ckpt_saver.saver'.
  """
  x = array_ops.placeholder(dtypes.int32, name='x_hold')
  y = variables.Variable(constant_op.constant([0]), name='y_saved')
  math_ops.add(x, y, name='x_y_sum')

  # initialize_all_variables() is deprecated; global_variables_initializer()
  # is its direct replacement.
  init_op = variables.global_variables_initializer()
  saver = saver_lib.Saver(name='abcprefix', write_version=saver_pb2.SaverDef.V1)
  with session.Session() as sess:
    sess.run(init_op)
    sess.run(y.assign(y + 42))
    # Without the checkpoint, the variable won't be set to 42.
    ckpt_file = os.path.join(out_dir, 'test_graph_tfadd_with_ckpt_saver.ckpt')
    saver.save(sess, ckpt_file)
    # Without the SaverDef, the restore op won't be named correctly.
    saver_file = os.path.join(out_dir, 'test_graph_tfadd_with_ckpt_saver.saver')
    with open(saver_file, 'wb') as f:
      f.write(saver.as_saver_def().SerializeToString())
| |
| |
def tfassert_eq(_):
  """Builds an Assert on x == y next to an 'x_y_diff' = x + (-y) output.

  Nothing here makes the Assert a control dependency of 'x_y_diff'; the two
  ops simply coexist in the graph.

  Args:
    _: Output directory passed by write_graph; unused by this builder.
  """
  lhs = array_ops.placeholder(dtypes.int32, name='x_hold')
  rhs = array_ops.placeholder(dtypes.int32, name='y_hold')
  are_equal = math_ops.equal(lhs, rhs)
  control_flow_ops.Assert(are_equal, ['Expected x == y.'], name='assert_eq')
  rhs_negated = math_ops.negative(rhs)
  math_ops.add(lhs, rhs_negated, name='x_y_diff')
| |
| |
def tfcond(_):
  """Builds 'result' = x_hold if p_hold else y_hold, using a cond.

  Args:
    _: Output directory passed by write_graph; unused by this builder.
  """
  predicate = array_ops.placeholder(dtypes.bool, name='p_hold')
  when_true = array_ops.placeholder(dtypes.int32, name='x_hold')
  when_false = array_ops.placeholder(dtypes.int32, name='y_hold')
  selected = control_flow_ops.cond(
      predicate, lambda: when_true, lambda: when_false)
  array_ops.identity(selected, name='result')
| |
| |
def tfgather(_):
  """Builds 'gather_output' by gathering float 'params' at int32 'indices'.

  Args:
    _: Output directory passed by write_graph; unused by this builder.
  """
  # 'params' is created first, then 'indices', then the gather op.
  array_ops.gather(
      array_ops.placeholder(dtypes.float32, name='params'),
      array_ops.placeholder(dtypes.int32, name='indices'),
      name='gather_output')
| |
| |
def tfmatmul(_):
  """Builds 'x_y_prod', the matrix product of two float placeholders.

  Args:
    _: Output directory passed by write_graph; unused by this builder.
  """
  math_ops.matmul(
      array_ops.placeholder(dtypes.float32, name='x_hold'),
      array_ops.placeholder(dtypes.float32, name='y_hold'),
      name='x_y_prod')
| |
| |
def tfmatmulandadd(_):
  """Builds a graph with two outputs over the same inputs: product and sum.

  Exercises multiple fetchable outputs ('x_y_prod' and 'x_y_sum').

  Args:
    _: Output directory passed by write_graph; unused by this builder.
  """
  operands = [
      array_ops.placeholder(dtypes.float32, name='x_hold'),
      array_ops.placeholder(dtypes.float32, name='y_hold'),
  ]
  math_ops.matmul(*operands, name='x_y_prod')
  math_ops.add(*operands, name='x_y_sum')
| |
| |
def tffunction(_):
  """Builds a graph that calls a Defun-defined add on two constants.

  The call node is named 'func_call'; its inputs are 'x_const' ([1]) and
  'y_const' ([2]).

  Args:
    _: Output directory passed by write_graph; unused by this builder.
  """

  # Declared with two int32 inputs. The inner function is intentionally left
  # exactly as-is (name, argument names, body, no docstring).
  @function.Defun(dtypes.int32, dtypes.int32)
  def test_func(a, b):
    return a + b

  test_func(  # pylint: disable=unexpected-keyword-arg
      constant_op.constant([1], name='x_const'),
      constant_op.constant([2], name='y_const'),
      name='func_call')
| |
| |
def tfsplits(_):
  """A more complex graph, including splits.

  Starts from two [2, 2] float placeholders 'x' and 'y' and runs three rounds.
  Each round:
    1. splits both along axis 0 into two halves and adds 1 to each first half;
    2. forms 'concat_x0_y1' and 'concat_y0_x1' from mixed halves;
    3. replaces x with their matmul ('a_b') and y with that plus the round's
       'x_y_prod' (the matmul of the incoming x and y).
  The final y is exposed as 'result'.

  Args:
    _: Output directory passed by write_graph; unused by this builder.
  """
  lhs = array_ops.placeholder(dtypes.float32, shape=[2, 2], name='x')
  rhs = array_ops.placeholder(dtypes.float32, shape=[2, 2], name='y')
  for _ in range(3):
    lhs_first, lhs_second = array_ops.split(lhs, 2, 0)
    rhs_first, rhs_second = array_ops.split(rhs, 2, 0)
    lhs_first = lhs_first + 1
    rhs_first = rhs_first + 1
    product = math_ops.matmul(lhs, rhs, name='x_y_prod')
    mixed_a = array_ops.concat(
        [lhs_first, rhs_second], axis=0, name='concat_x0_y1')
    mixed_b = array_ops.concat(
        [rhs_first, lhs_second], axis=0, name='concat_y0_x1')
    lhs = math_ops.matmul(mixed_a, mixed_b, name='a_b')
    rhs = math_ops.add(lhs, product)
  array_ops.identity(rhs, name='result')
| |
| |
def write_graph(build_graph, out_dir):
  """Builds a fresh graph with `build_graph` and writes its GraphDef to disk.

  Args:
    build_graph: Callable invoked as build_graph(out_dir) while a new Graph is
      the default graph. Its __name__ determines the output file name.
    out_dir: Directory passed to build_graph and used for the output file
      'test_graph_<build_graph.__name__>.pb'.
  """
  graph = ops.Graph()
  with graph.as_default():
    build_graph(out_dir)
  pb_name = 'test_graph_%s.pb' % build_graph.__name__
  pb_path = os.path.join(out_dir, pb_name)
  with open(pb_path, 'wb') as pb_file:
    serialized = graph.as_graph_def().SerializeToString()
    pb_file.write(serialized)
| |
| |
def main(_):
  """Writes every test graph (plus any checkpoint/saver files) to FLAGS.out_dir.

  Args:
    _: argv as forwarded by app.run; unused.
  """
  builders = (
      tfadd,
      tfadd_with_ckpt,
      tfadd_with_ckpt_saver,
      tfassert_eq,
      tfcond,
      tffunction,
      tfgather,
      tfmatmul,
      tfmatmulandadd,
      tfsplits,
  )
  for builder in builders:
    write_graph(builder, FLAGS.out_dir)
| |
| |
if __name__ == '__main__':
  # The previously registered custom 'bool' argparse type was never used by
  # any argument, so it has been dropped.
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--out_dir',
      type=str,
      default='',
      help='Output directory for graphs, checkpoints and savers.')
  FLAGS, unparsed = parser.parse_known_args()
  # Hand app.run only the program name plus whatever argparse did not
  # recognize; --out_dir is consumed here and read by main() via FLAGS.
  app.run(main=main, argv=[sys.argv[0]] + unparsed)