| # Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Tests for tensorflow.python.client.session.Session.""" |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import collections |
| import random |
| import os |
| import sys |
| import threading |
| import time |
| import warnings |
| |
| import numpy as np |
| import six |
| from six.moves import xrange # pylint: disable=redefined-builtin |
| |
| from tensorflow.core.framework import attr_value_pb2 |
| from tensorflow.core.lib.core import error_codes_pb2 |
| from tensorflow.core.protobuf import config_pb2 |
| from tensorflow.python.client import session |
| from tensorflow.python.framework import common_shapes |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import device as framework_device_lib |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import errors |
| from tensorflow.python.framework import function |
| from tensorflow.python.framework import importer |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import sparse_tensor |
| from tensorflow.python.framework import tensor_util |
| from tensorflow.python.framework import test_util |
| from tensorflow.python.framework import versions |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import control_flow_ops |
| from tensorflow.python.ops import data_flow_ops |
| from tensorflow.python.ops import gen_control_flow_ops |
| # Import gradients to resolve circular imports |
| from tensorflow.python.ops import gradients # pylint: disable=unused-import |
| from tensorflow.python.ops import gradients_impl |
| from tensorflow.python.ops import math_ops |
| # Import resource_variable_ops for the variables-to-tensor implicit conversion. |
| from tensorflow.python.ops import resource_variable_ops # pylint: disable=unused-import |
| from tensorflow.python.ops import state_ops |
| from tensorflow.python.ops import variables |
| from tensorflow.python.platform import googletest |
| from tensorflow.python.training import server_lib |
| from tensorflow.python.util import compat |
| |
| try: |
| import attr # pylint:disable=g-import-not-at-top |
| except ImportError: |
| attr = None |
| |
| |
| # NOTE(mrry): Dummy shape registration for ops used in the tests, since they |
| # don't have C++ op registrations on which to attach C++ shape fns. |
| ops.RegisterShape('ConstructionFails')(common_shapes.unknown_shape) |
| |
| |
| class SessionTest(test_util.TensorFlowTestCase): |
| |
def setUp(self):
  """Runs the base-class setup, then makes warnings fire on every trigger."""
  super(SessionTest, self).setUp()
  # The 'always' action prints a matching warning each time it is raised
  # instead of only the first time.
  warnings.simplefilter(action='always')
| |
def testUseExistingGraph(self):
  """Evaluates a matmul through a session given an explicit graph."""
  graph = ops.Graph()
  with graph.as_default(), ops.device('/cpu:0'):
    lhs = constant_op.constant(6.0, shape=[1, 1])
    rhs = constant_op.constant(7.0, shape=[1, 1])
    product = math_ops.matmul(lhs, rhs, name='matmul')
  # The session is opened outside the graph context and receives the graph
  # explicitly.
  with session.Session(graph=graph):
    self.assertAllEqual(product.eval(), [[42.0]])
| |
def testUseDefaultGraph(self):
  """Evaluates a matmul through a session that is given no graph argument."""
  with ops.Graph().as_default(), ops.device('/cpu:0'):
    lhs = constant_op.constant(6.0, shape=[1, 1])
    rhs = constant_op.constant(7.0, shape=[1, 1])
    product = math_ops.matmul(lhs, rhs, name='matmul')
    # No graph is passed: the session is created while the new graph is the
    # default one.
    with session.Session():
      self.assertAllEqual(product.eval(), [[42.0]])
| |
def testCreate(self):
  """Evaluates an identity of a constant, with and without feeding it."""
  with session.Session():
    source = constant_op.constant(10.0, shape=[2, 3], name='W1')
    passthrough = array_ops.identity(source)
    # Test with feed.
    # TODO(mrry): Investigate why order='F' didn't work.
    fed = np.asarray([[0, 1, 2], [3, 4, 5]], dtype=np.float32, order='C')
    self.assertAllEqual(fed, passthrough.eval({'W1:0': fed}))
    # Test without feed.
    unfed_expected = np.asarray(
        [[10.0, 10.0, 10.0], [10.0, 10.0, 10.0]], dtype=np.float32)
    self.assertAllEqual(unfed_expected, passthrough.eval())
| |
def testManyCPUs(self):
  """Requests two CPU devices and no GPU, then checks the device list."""
  two_cpu_config = config_pb2.ConfigProto(device_count={'CPU': 2, 'GPU': 0})
  with session.Session(config=two_cpu_config) as sess:
    inp = constant_op.constant(10.0, name='W1')
    self.assertAllEqual(inp.eval(), 10.0)

    listed = sess.list_devices()
    self.assertEqual(2, len(listed))
    for listed_device in listed:
      spec = framework_device_lib.DeviceSpec.from_string(listed_device.name)
      self.assertEqual('CPU', spec.device_type)
| |
def testPerSessionThreads(self):
  """Evaluates a constant with use_per_session_threads enabled."""
  per_session_config = config_pb2.ConfigProto(use_per_session_threads=True)
  with session.Session(config=per_session_config):
    ten = constant_op.constant(10.0, name='W1')
    self.assertAllEqual(ten.eval(), 10.0)
| |
def testSessionInterOpThreadPool(self):
  """Runs a constant after each of three inter-op pools is configured."""
  config = config_pb2.ConfigProto()
  pools = config.session_inter_op_thread_pool

  # First pool: added with its default field values.
  pools.add()
  with session.Session(config=config) as s:
    inp = constant_op.constant(10.0, name='W1')
    self.assertAllEqual([10.0], s.run([inp]))

  # Second pool: a single thread.
  single_thread_pool = pools.add()
  single_thread_pool.num_threads = 1
  with session.Session(config=config) as s:
    inp = constant_op.constant(20.0, name='W2')
    self.assertAllEqual([20.0], s.run([inp]))

  # Third pool: a single thread and a global name; the run selects it by
  # index (the last configured pool) through RunOptions.
  named_pool = pools.add()
  named_pool.num_threads = 1
  named_pool.global_name = 't1'
  run_options = config_pb2.RunOptions()
  run_options.inter_op_thread_pool = len(pools) - 1
  with session.Session(config=config) as s:
    inp = constant_op.constant(30.0, name='W2')
    self.assertAllEqual([30.0], s.run([inp], options=run_options))
| |
def testErrorsReported(self):
  """Fetching the name 'foo:0', which the test never creates, must raise."""
  with session.Session() as sess:
    constant_op.constant(10.0, name='W1')
    self.assertRaises(ValueError, sess.run, 'foo:0')
| |
def testErrorPayload(self):
  """The error from evaluating an unfed placeholder carries that op."""
  with session.Session():
    unfed = array_ops.placeholder(dtypes.float32)

    def raised_by_placeholder(e):
      return e.op == unfed.op

    with self.assertRaisesOpError(raised_by_placeholder):
      unfed.eval()
| |
def testErrorCodeWithNoNodeDef(self):
  """A bogus partial_run handle yields INVALID_ARGUMENT with no op attached."""
  with session.Session() as sess:
    lhs = array_ops.placeholder(dtypes.float32, shape=[])
    rhs = array_ops.placeholder(dtypes.float32, shape=[])
    total = math_ops.add(lhs, rhs)

    def exc_predicate(e):
      if e.op is not None or e.node_def is not None:
        return False
      return e.error_code == error_codes_pb2.INVALID_ARGUMENT

    with self.assertRaisesOpError(exc_predicate):
      # Run with a bogus handle.
      sess.partial_run('foo', total, feed_dict={lhs: 1, rhs: 2})
| |
def testErrorBasedOn(self):
  """The raised error exposes the chain of _original_op links."""
  with session.Session() as sess:
    root = constant_op.constant(0.0, shape=[2, 3])
    # NOTE(mrry): The original_op is nonsense, but used here to test that the
    # errors are reported correctly.
    with sess.graph._original_op(root.op):
      middle = array_ops.identity(root, name='id')
    with sess.graph._original_op(middle.op):
      leaf = array_ops.placeholder(dtypes.float32)

    def exc_predicate(e):
      failing_op = e.op
      return (failing_op == leaf.op and
              failing_op._original_op == middle.op and
              failing_op._original_op._original_op == root.op)

    with self.assertRaisesOpError(exc_predicate):
      leaf.eval()
| |
def testFetchNone(self):
  """None is rejected as a fetch, alone or inside a list or dict."""
  with session.Session() as sess:
    a = constant_op.constant(1.0)
    invalid_fetches = (
        None,
        [None],
        {'b': None},
        {'a': a, 'b': None},
    )
    for invalid in invalid_fetches:
      with self.assertRaises(TypeError):
        sess.run(invalid)
| |
def testFetchSingleton(self):
  """Fetches one tensor and one op via run() and via make_callable()."""
  with session.Session() as sess:
    a = constant_op.constant(42.0)
    self.assertEqual(42.0, sess.run(a))
    # An op, not a tensor: the fetched value is None.
    self.assertIsNone(sess.run(a.op))
    tensor_runner = sess.make_callable(a)
    self.assertEqual(42.0, tensor_runner())
    op_runner = sess.make_callable(a.op)
    self.assertIsNone(op_runner())
| |
def testFetchSingletonByName(self):
  """Fetches a tensor by its name string, and an op by object."""
  with session.Session() as sess:
    a = constant_op.constant(42.0)
    self.assertEqual(42.0, sess.run(a.name))
    # An op, not a tensor: the fetched value is None.
    self.assertIsNone(sess.run(a.op))
| |
def testFetchList(self):
  """A list of fetches comes back as a list, with None for each op."""
  with session.Session() as sess:
    a = constant_op.constant(42.0)
    b = control_flow_ops.no_op()  # An op, not a tensor.
    c = constant_op.constant(44.0)
    v = variables.Variable([54.0])
    assign = v.assign([63.0])
    expected = [42.0, None, 44.0, 42.0, None]

    res = sess.run([a, b, c, a.name, assign.op])
    self.assertIsInstance(res, list)
    self.assertEqual(expected, res)

    list_runner = sess.make_callable([a, b, c, a.name, assign.op])
    res = list_runner()
    self.assertIsInstance(res, list)
    self.assertEqual(expected, res)
| |
def testFetchTuple(self):
  """A tuple of fetches comes back as a tuple, with None for each op."""
  with session.Session() as sess:
    a = constant_op.constant(42.0)
    b = control_flow_ops.no_op()  # An op, not a tensor.
    c = constant_op.constant(44.0)
    expected = (42.0, None, 44.0, 42.0)

    res = sess.run((a, b, c, a.name))
    self.assertIsInstance(res, tuple)
    self.assertEqual(expected, res)

    tuple_runner = sess.make_callable((a, b, c, a.name))
    res = tuple_runner()
    self.assertIsInstance(res, tuple)
    self.assertEqual(expected, res)
| |
def testFetchNamedTuple(self):
  """A namedtuple of fetches comes back as the same namedtuple type."""
  # pylint: disable=invalid-name
  ABC = collections.namedtuple('ABC', ['a', 'b', 'c'])
  # pylint: enable=invalid-name

  def check_abc(fetched):
    self.assertIsInstance(fetched, ABC)
    self.assertEqual(42.0, fetched.a)
    self.assertIsNone(fetched.b)
    self.assertEqual(44.0, fetched.c)

  with session.Session() as sess:
    a = constant_op.constant(42.0)
    b = control_flow_ops.no_op()  # An op, not a tensor.
    c = constant_op.constant(44.0)
    check_abc(sess.run(ABC(a, b, c)))
    namedtuple_runner = sess.make_callable(ABC(a, b, c))
    check_abc(namedtuple_runner())
| |
def testFetchDict(self):
  """A dict of fetches comes back as a dict keyed the same way."""
  with session.Session() as sess:
    a = constant_op.constant(42.0)
    b = control_flow_ops.no_op()  # An op, not a tensor.
    c = constant_op.constant(44.0)
    res = sess.run({'a': a, 'b': b, 'c': c})
    self.assertIsInstance(res, dict)
    self.assertEqual(42.0, res['a'])
    self.assertIsNone(res['b'])
    self.assertEqual(44.0, res['c'])
| |
def testFetchOrderedDict(self):
  """An OrderedDict of fetches keeps both its type and its key order."""
  with session.Session() as sess:
    a = constant_op.constant(42.0)
    b = control_flow_ops.no_op()  # An op, not a tensor.
    c = constant_op.constant(44.0)
    res = sess.run(collections.OrderedDict([(3, a), (2, b), (1, c)]))
    self.assertIsInstance(res, collections.OrderedDict)
    self.assertEqual([3, 2, 1], list(res.keys()))
    self.assertEqual(42.0, res[3])
    self.assertIsNone(res[2])
    self.assertEqual(44.0, res[1])
| |
def testFetchAttrs(self):
  """Fetches an attrs-decorated object, without and with a feed."""
  if attr is None:
    self.skipTest('attr module is unavailable.')

  @attr.s
  class SampleAttr(object):
    field1 = attr.ib()
    field2 = attr.ib()

  floats = np.array([1.2, 3.4, 5.6])
  matrix = np.array([[1, 2], [4, 3]])
  override = np.array([10, 20, 30])

  floats_t = constant_op.constant(floats)
  matrix_t = constant_op.constant(matrix)
  sample = SampleAttr(floats_t, matrix_t)

  def check_result(fetched, expected_field1):
    self.assertIsInstance(fetched, SampleAttr)
    self.assertAllEqual(expected_field1, fetched.field1)
    self.assertAllEqual(matrix, fetched.field2)

  with session.Session() as sess:
    check_result(sess.run(sample), floats)
    # Feeding field1 replaces the value of the first constant.
    check_result(
        sess.run(sample, feed_dict={sample.field1: override}), override)
| |
def testFetchNestedAttrs(self):
  """Fetches attrs objects nested in attrs objects, dicts and lists."""
  if attr is None:
    self.skipTest('attr module is unavailable.')

  @attr.s
  class SampleAttr(object):
    field0 = attr.ib()
    field1 = attr.ib()

  v1 = 10
  v2 = 20
  v3 = np.float32(1.2)
  v4 = np.float32(3.4)
  v5 = np.float64(100.001)
  v6 = np.float64(-23.451)
  arr1 = np.array([1.2, 6.7, 3.4])
  arr2 = np.array([7, 11, 3])

  # The constants are created in the same order as their values are numbered.
  scalar_pair = SampleAttr(constant_op.constant(v1), constant_op.constant(v2))
  array_pair = SampleAttr(
      constant_op.constant(arr1), constant_op.constant(arr2))
  keyed = {
      'A': SampleAttr(constant_op.constant(v3), constant_op.constant(v4)),
      'B': [SampleAttr(constant_op.constant(v5), constant_op.constant(v6))],
  }
  sample = SampleAttr(SampleAttr(scalar_pair, array_pair), keyed)

  with session.Session() as sess:
    result = sess.run(sample)

    # Structure and values reachable through field0.
    self.assertIsInstance(result, SampleAttr)
    outer = result.field0
    self.assertIsInstance(outer, SampleAttr)
    self.assertIsInstance(outer.field0, SampleAttr)
    self.assertIsInstance(outer.field1, SampleAttr)
    self.assertIsInstance(outer.field1.field0, np.ndarray)
    self.assertAllEqual(arr1, outer.field1.field0)
    self.assertIsInstance(outer.field1.field1, np.ndarray)
    self.assertAllEqual(arr2, outer.field1.field1)

    # Structure and values reachable through field1.
    by_key = result.field1
    self.assertIsInstance(by_key, dict)
    self.assertIn('A', by_key)
    self.assertIn('B', by_key)
    self.assertIsInstance(by_key['A'], SampleAttr)
    self.assertAllEqual([v3, v4], [by_key['A'].field0, by_key['A'].field1])
    self.assertIsInstance(by_key['B'], list)
    self.assertEqual(1, len(by_key['B']))
    only_b = by_key['B'][0]
    self.assertIsInstance(only_b, SampleAttr)
    self.assertAllEqual([v5, v6], [only_b.field0, only_b.field1])
| |
def testFetchNestingEmptyOneLevel(self):
  """Empty list, tuple and dict fetches come back as empty containers."""
  with session.Session() as sess:
    a_val = 11.0
    a = constant_op.constant(a_val)

    def check_empty_containers(fetched):
      # The first three entries mirror the empty list, tuple and dict.
      for position, container_type in enumerate((list, tuple, dict)):
        self.assertIsInstance(fetched[position], container_type)
        self.assertEqual(0, len(fetched[position]))

    res = sess.run([[], tuple(), {}])
    self.assertIsInstance(res, list)
    self.assertEqual(3, len(res))
    check_empty_containers(res)

    res = sess.run([[], tuple(), {}, a])
    self.assertIsInstance(res, list)
    self.assertEqual(4, len(res))
    check_empty_containers(res)
    self.assertEqual(a_val, res[3])
| |
def testFetchNestingOneLevel(self):
  """Fetches lists, tuples, namedtuples and dicts nested one level deep.

  Each outer container kind (list, tuple, namedtuple, dict) is fetched with
  one inner container of every kind; the result must mirror the structure.
  """
  with session.Session() as sess:
    # pylint: disable=invalid-name
    ABC = collections.namedtuple('ABC', ['a', 'b', 'c'])
    DEFG = collections.namedtuple('DEFG', ['d', 'e', 'f', 'g'])
    # pylint: enable=invalid-name
    a_val = 42.0
    b_val = None
    c_val = 44.0
    a = constant_op.constant(a_val)
    b = control_flow_ops.no_op()  # An op, not a tensor.
    c = constant_op.constant(c_val)

    def check_sequence(fetched, expected_type):
      # Inner list or tuple holding the values of (a, b, c) in order.
      self.assertIsInstance(fetched, expected_type)
      self.assertEqual(3, len(fetched))
      self.assertEqual(a_val, fetched[0])
      self.assertEqual(b_val, fetched[1])
      self.assertEqual(c_val, fetched[2])

    def check_abc(fetched):
      # Inner ABC namedtuple.
      self.assertIsInstance(fetched, ABC)
      self.assertEqual(a_val, fetched.a)
      self.assertEqual(b_val, fetched.b)
      self.assertEqual(c_val, fetched.c)

    def check_dict(fetched):
      # Inner dict keyed by 'a', 'b' and 'c'.
      self.assertIsInstance(fetched, dict)
      self.assertEqual(3, len(fetched))
      self.assertEqual(a_val, fetched['a'])
      self.assertEqual(b_val, fetched['b'])
      self.assertEqual(c_val, fetched['c'])

    # List of lists, tuples, namedtuple, and dict
    res = sess.run([[a, b, c], (a, b, c),
                    ABC(a=a, b=b, c=c), {
                        'a': a.name,
                        'c': c,
                        'b': b
                    }])
    self.assertIsInstance(res, list)
    self.assertEqual(4, len(res))
    check_sequence(res[0], list)
    check_sequence(res[1], tuple)
    check_abc(res[2])
    check_dict(res[3])

    # Tuple of lists, tuples, namedtuple, and dict
    res = sess.run(([a, b, c], (a.name, b, c), ABC(a=a, b=b, c=c), {
        'a': a,
        'c': c,
        'b': b
    }))
    self.assertIsInstance(res, tuple)
    self.assertEqual(4, len(res))
    check_sequence(res[0], list)
    check_sequence(res[1], tuple)
    check_abc(res[2])
    check_dict(res[3])

    # Namedtuple of lists, tuples, namedtuples, and dict
    res = sess.run(
        DEFG(
            d=[a, b, c],
            e=(a, b, c),
            f=ABC(a=a.name, b=b, c=c),
            g={
                'a': a,
                'c': c,
                'b': b
            }))
    self.assertIsInstance(res, DEFG)
    check_sequence(res.d, list)
    check_sequence(res.e, tuple)
    check_abc(res.f)
    check_dict(res.g)

    # Dict of lists, tuples, namedtuples, and dict
    res = sess.run({
        'd': [a, b, c],
        'e': (a, b, c),
        'f': ABC(a=a, b=b, c=c),
        'g': {
            'a': a.name,
            'c': c,
            'b': b
        }
    })
    self.assertIsInstance(res, dict)
    self.assertEqual(4, len(res))
    check_sequence(res['d'], list)
    check_sequence(res['e'], tuple)
    check_abc(res['f'])
    check_dict(res['g'])
| |
def testFetchTensorObject(self):
  """Fetches tensors singly, in lists, in dicts and in nested lists."""
  with session.Session() as s:
    a = constant_op.constant(1.0, shape=[1, 2])
    b = constant_op.constant(2.0, shape=[2, 3])
    c = math_ops.matmul(a, b)
    expected_a = [[1.0, 1.0]]
    expected_b = [[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]
    expected_c = [[4.0, 4.0, 4.0]]

    self.assertAllEqual(expected_c, s.run([c])[0])
    self.assertAllEqual(expected_c, s.run(c))
    self.assertAllEqual(expected_c, c.eval())

    a_val, b_val = s.run([a, b])  # Test multiple fetches.
    self.assertAllEqual(expected_a, a_val)
    self.assertAllEqual(expected_b, b_val)

    by_key = s.run({'a': [a], 'b': b, 'z': [a, b]})
    self.assertAllEqual(expected_a, by_key['a'][0])
    self.assertAllEqual(expected_b, by_key['b'])
    self.assertAllEqual(by_key['a'][0], by_key['z'][0])
    self.assertAllEqual(by_key['b'], by_key['z'][1])

    # Test nested structures
    nested = s.run([[[a, b], b], a, [a, b]])
    self.assertAllEqual(expected_a, nested[0][0][0])
    self.assertAllEqual(expected_b, nested[0][0][1])
    # Every occurrence of the same tensor must carry the same value.
    self.assertAllEqual(nested[0][0][0], nested[1])
    self.assertAllEqual(nested[1], nested[2][0])
    self.assertAllEqual(nested[0][0][1], nested[0][1])
    self.assertAllEqual(nested[0][1], nested[2][1])
| |
def testFetchScalar(self):
  """Scalar fetches come back with the NumPy scalar type of their inputs."""
  scalar_types = (np.int32, np.int64, np.float16, np.float32, np.float64)
  with session.Session() as s:
    for scalar in scalar_types:
      x = scalar(7)
      y = scalar(8)
      expected_sum = x + y
      tf_x = constant_op.constant(x, shape=[])
      tf_y = constant_op.constant(y)
      tf_xy = math_ops.add(tf_x, tf_y)
      # Single fetch
      single = s.run(tf_xy)
      self.assertEqual(scalar, type(single))
      self.assertEqual(expected_sum, single)
      # List fetch
      from_list, = s.run([tf_xy])
      self.assertEqual(scalar, type(from_list))
      self.assertEqual(expected_sum, from_list)
      # Dict fetch
      from_dict = s.run({'xy': tf_xy})['xy']
      self.assertEqual(scalar, type(from_dict))
      self.assertEqual(expected_sum, from_dict)
      # Nested list fetch
      nested = s.run([[[tf_xy]], tf_xy, [tf_xy]])
      self.assertAllEqual(
          nested, [[[expected_sum]], expected_sum, [expected_sum]])
      for leaf in (nested[0][0][0], nested[1], nested[2][0]):
        self.assertEqual(scalar, type(leaf))
| |
def testFetchOperationObject(self):
  """Runs a variable's initializer op, then fetches the variable's value."""
  with session.Session() as sess:
    initial = constant_op.constant(1.0, shape=[1, 2])
    var = variables.Variable(initial, name='testFetchOperationObject_v')
    sess.run(var.initializer)
    self.assertAllEqual([[1.0, 1.0]], sess.run(var))
| |
def testFetchSparseTensor(self):
  """Fetches a SparseTensor alone and inside lists, dicts and nested lists.

  Every fetched value is checked either by unpacking it as a 3-tuple or
  through its indices/values/dense_shape attributes.
  """
  with session.Session() as s:
    indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
    values = np.array([1.0, 2.0]).astype(np.float32)
    shape = np.array([7, 9, 2]).astype(np.int64)
    sp = sparse_tensor.SparseTensor(
        constant_op.constant(indices), constant_op.constant(values),
        constant_op.constant(shape))

    def check_as_tuple(fetched):
      indices_out, values_out, shape_out = fetched
      self.assertAllEqual(indices_out, indices)
      self.assertAllEqual(values_out, values)
      self.assertAllEqual(shape_out, shape)

    def check_as_value(fetched):
      self.assertAllEqual(fetched.indices, indices)
      self.assertAllEqual(fetched.values, values)
      self.assertAllEqual(fetched.dense_shape, shape)

    # Single fetch, use as tuple
    check_as_tuple(s.run(sp))
    # Single fetch, use as SparseTensorValue
    check_as_value(s.run(sp))
    # Tuple fetch, use as tuple
    check_as_tuple(s.run(sp))
    # List fetch, use as tuple
    only, = s.run([sp])
    check_as_tuple(only)
    # List fetch, use as SparseTensorValue
    only, = s.run([sp])
    check_as_value(only)
    # Dict fetch (single value), use as tuple
    check_as_tuple(s.run({'sp': sp})['sp'])
    # Dict fetch (list value), use as tuple
    only, = s.run({'sp': [sp]})['sp']
    check_as_tuple(only)
    # Dict fetch, use as SparseTensorValue
    check_as_value(s.run({'sp': sp})['sp'])
    # Nested list fetch use as tuple
    nested = s.run([[[sp]], sp])
    check_as_tuple(nested[0][0][0])
    check_as_tuple(nested[1])
    # Nested list fetch, use as SparseTensorValue
    nested = s.run([[[sp]], sp])
    check_as_value(nested[0][0][0])
    check_as_value(nested[1])
| |
def testFeedSparseTensor(self):
  """Feeds a placeholder-backed SparseTensor as a tuple and as a value."""
  with session.Session() as s:
    indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
    values = np.array([1.0, 2.0]).astype(np.float32)
    shape = np.array([7, 9, 2]).astype(np.int64)
    sp = sparse_tensor.SparseTensor(
        array_ops.placeholder(dtype=np.int64, shape=(2, 3)),
        array_ops.placeholder(dtype=np.float32, shape=(2,)),
        array_ops.placeholder(dtype=np.int64, shape=(3,)),
    )
    sp_indices = array_ops.identity(sp.indices)
    sp_values = array_ops.identity(sp.values)
    sp_shape = array_ops.identity(sp.dense_shape)
    sp2 = sparse_tensor.SparseTensor(sp_indices, sp_values, sp_shape)
    components = [sp_indices, sp_values, sp_shape]

    def check_components(fetched):
      indices_out, values_out, shape_out = fetched
      self.assertAllEqual(indices_out, indices)
      self.assertAllEqual(values_out, values)
      self.assertAllEqual(shape_out, shape)

    def check_value(fetched):
      self.assertAllEqual(fetched.indices, indices)
      self.assertAllEqual(fetched.values, values)
      self.assertAllEqual(fetched.dense_shape, shape)

    # Feed with tuple
    check_components(s.run(components, {sp: (indices, values, shape)}))
    # Feed with tuple, fetch sp directly
    check_value(s.run(sp, {sp: (indices, values, shape)}))
    # Feed with SparseTensorValue
    check_components(
        s.run(components, {
            sp: sparse_tensor.SparseTensorValue(indices, values, shape)
        }))
    # Feed with SparseTensorValue, fetch SparseTensorValue
    check_value(
        s.run(sp2, {
            sp: sparse_tensor.SparseTensorValue(indices, values, shape)
        }))
    # Feed SparseTensorValue and fetch sp directly.
    check_value(
        s.run(sp, {
            sp: sparse_tensor.SparseTensorValue(indices, values, shape)
        }))
| |
def testFeedSparsePlaceholder(self):
  """Feeds a sparse_placeholder as a tuple and as a SparseTensorValue."""
  with session.Session() as s:
    indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
    values = np.array([1.0, 2.0]).astype(np.float32)
    shape = np.array([7, 9, 2]).astype(np.int64)
    sp = array_ops.sparse_placeholder(dtype=np.float32, name='placeholder1')
    sp_indices = array_ops.identity(sp.indices)
    sp_values = array_ops.identity(sp.values)
    sp_shape = array_ops.identity(sp.dense_shape)
    sp2 = sparse_tensor.SparseTensor(sp_indices, sp_values, sp_shape)
    components = [sp_indices, sp_values, sp_shape]

    def check_components(fetched):
      indices_out, values_out, shape_out = fetched
      self.assertAllEqual(indices_out, indices)
      self.assertAllEqual(values_out, values)
      self.assertAllEqual(shape_out, shape)

    # Feed with tuple
    check_components(s.run(components, {sp: (indices, values, shape)}))
    # Feed with SparseTensorValue
    check_components(
        s.run(components, {
            sp: sparse_tensor.SparseTensorValue(indices, values, shape)
        }))
    # Feed with SparseTensorValue, fetch SparseTensorValue
    sp2_out = s.run(sp2, {
        sp: sparse_tensor.SparseTensorValue(indices, values, shape)
    })
    self.assertAllEqual(sp2_out.indices, indices)
    self.assertAllEqual(sp2_out.values, values)
    self.assertAllEqual(sp2_out.dense_shape, shape)
| |
def testFeedSparsePlaceholderPartialShape(self):
  """Feeds a sparse_placeholder declared with shape [None, 9, 2]."""
  with session.Session() as s:
    indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
    values = np.array([1.0, 2.0]).astype(np.float32)
    shape = np.array([7, 9, 2]).astype(np.int64)
    sp = array_ops.sparse_placeholder(
        shape=[None, 9, 2], dtype=np.float32, name='placeholder1')
    sp_indices = array_ops.identity(sp.indices)
    sp_values = array_ops.identity(sp.values)
    sp_shape = array_ops.identity(sp.dense_shape)
    sp2 = sparse_tensor.SparseTensor(sp_indices, sp_values, sp_shape)
    components = [sp_indices, sp_values, sp_shape]

    def check_components(fetched):
      indices_out, values_out, shape_out = fetched
      self.assertAllEqual(indices_out, indices)
      self.assertAllEqual(values_out, values)
      self.assertAllEqual(shape_out, shape)

    # Feed with tuple
    check_components(s.run(components, {sp: (indices, values, shape)}))
    # Feed with SparseTensorValue
    check_components(
        s.run(components, {
            sp: sparse_tensor.SparseTensorValue(indices, values, shape)
        }))
    # Feed with SparseTensorValue, fetch SparseTensorValue
    sp2_out = s.run(sp2, {
        sp: sparse_tensor.SparseTensorValue(indices, values, shape)
    })
    self.assertAllEqual(sp2_out.indices, indices)
    self.assertAllEqual(sp2_out.values, values)
    self.assertAllEqual(sp2_out.dense_shape, shape)
| |
def testFeedSparsePlaceholderConstantShape(self):
  """Feeds only indices and values to a fully-shaped sparse_placeholder."""
  with session.Session() as s:
    indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
    values = np.array([1.0, 2.0]).astype(np.float32)
    shape = np.array([7, 9, 2]).astype(np.int64)
    sp = array_ops.sparse_placeholder(
        dtype=np.float32, shape=shape, name='placeholder1')
    # The dense shape can be read both by evaluating it and statically.
    self.assertAllEqual(sp.dense_shape.eval(session=s), shape)
    self.assertAllEqual(tensor_util.constant_value(sp.dense_shape), shape)
    sp_indices = array_ops.identity(sp.indices)
    sp_values = array_ops.identity(sp.values)
    sp_shape = array_ops.identity(sp.dense_shape)
    # Feed with tuple
    fetched = s.run([sp_indices, sp_values, sp_shape],
                    {sp: (indices, values)})
    indices_out, values_out, shape_out = fetched
    self.assertAllEqual(indices_out, indices)
    self.assertAllEqual(values_out, values)
    self.assertAllEqual(shape_out, shape)
| |
def testFetchIndexedSlices(self):
  """Fetches an IndexedSlices alone and inside a list.

  Every fetched value is checked either by unpacking it as a 3-tuple or
  through its values/indices/dense_shape attributes.
  """
  with session.Session() as s:
    indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
    values = np.array([1.0, 2.0]).astype(np.float32)
    dense_shape = np.array([7, 9, 2]).astype(np.int64)
    ind = ops.IndexedSlices(
        constant_op.constant(values), constant_op.constant(indices),
        constant_op.constant(dense_shape))

    def check_as_tuple(fetched):
      values_out, indices_out, dense_shape_out = fetched
      self.assertAllEqual(values_out, values)
      self.assertAllEqual(indices_out, indices)
      self.assertAllEqual(dense_shape_out, dense_shape)

    def check_as_value(fetched):
      self.assertAllEqual(fetched.values, values)
      self.assertAllEqual(fetched.indices, indices)
      self.assertAllEqual(fetched.dense_shape, dense_shape)

    # Single fetch, use as tuple
    check_as_tuple(s.run(ind))
    # Single fetch, use as IndexedSlicesValue
    check_as_value(s.run(ind))
    # Tuple fetch, use as tuple
    check_as_tuple(s.run(ind))
    # List fetch, use as tuple
    only, = s.run([ind])
    check_as_tuple(only)
    # List fetch, use as IndexedSlicesValue
    only, = s.run([ind])
    check_as_value(only)
| |
def testFeedIndexedSlices(self):
  """Feeds an `IndexedSlices` of placeholders as a tuple and as a value."""
  want_values = np.array([1.0, 2.0], dtype=np.float32)
  want_indices = np.array([[3, 2, 0], [4, 5, 1]], dtype=np.int64)
  want_dense_shape = np.array([7, 9, 2], dtype=np.int64)
  with session.Session() as sess:
    fed_slices = ops.IndexedSlices(
        array_ops.placeholder(dtype=np.float32, shape=(2,)),
        array_ops.placeholder(dtype=np.int64, shape=(2, 3)),
        array_ops.placeholder(dtype=np.int64, shape=(3,)),
    )
    values_t = array_ops.identity(fed_slices.values)
    indices_t = array_ops.identity(fed_slices.indices)
    dense_shape_t = array_ops.identity(fed_slices.dense_shape)
    round_trip = ops.IndexedSlices(values_t, indices_t, dense_shape_t)

    as_tuple = (want_values, want_indices, want_dense_shape)
    as_value = ops.IndexedSlicesValue(want_values, want_indices,
                                      want_dense_shape)
    # Feed first with a plain tuple, then with an IndexedSlicesValue; both
    # must yield the same component tensors.
    for feed_value in (as_tuple, as_value):
      got_values, got_indices, got_dense_shape = sess.run(
          [values_t, indices_t, dense_shape_t], {fed_slices: feed_value})
      self.assertAllEqual(got_values, want_values)
      self.assertAllEqual(got_indices, want_indices)
      self.assertAllEqual(got_dense_shape, want_dense_shape)

    # Feed with IndexedSlicesValue, fetch IndexedSlicesValue.
    fetched = sess.run(round_trip, {fed_slices: as_value})
    self.assertAllEqual(fetched.values, want_values)
    self.assertAllEqual(fetched.indices, want_indices)
    self.assertAllEqual(fetched.dense_shape, want_dense_shape)
| |
def testFetchIndexedSlicesWithoutDenseShape(self):
  """Fetches an `IndexedSlices` whose dense shape is `None`."""
  want_indices = np.array([[3, 2, 0], [4, 5, 1]], dtype=np.int64)
  want_values = np.array([1.0, 2.0], dtype=np.float32)
  # No dense shape is given, so the fetched dense shape is compared to None.
  want_dense_shape = None

  def check_unpacked(fetched):
    # The fetched value must still unpack into three components.
    got_values, got_indices, got_dense_shape = fetched
    self.assertAllEqual(got_values, want_values)
    self.assertAllEqual(got_indices, want_indices)
    self.assertAllEqual(got_dense_shape, want_dense_shape)

  def check_fields(fetched):
    # The same data must be reachable through the named fields.
    self.assertAllEqual(fetched.values, want_values)
    self.assertAllEqual(fetched.indices, want_indices)
    self.assertAllEqual(fetched.dense_shape, want_dense_shape)

  with session.Session() as sess:
    slices = ops.IndexedSlices(
        constant_op.constant(want_values),
        constant_op.constant(want_indices), None)
    # Single fetch, use as tuple.
    check_unpacked(sess.run(slices))
    # Single fetch, use as IndexedSlicesValue.
    check_fields(sess.run(slices))
    # Single fetch once more, unpacked straight from the call.
    check_unpacked(sess.run(slices))
    # List fetch, use the only element as tuple.
    (only_result,) = sess.run([slices])
    check_unpacked(only_result)
    # List fetch, use the only element as IndexedSlicesValue.
    (only_result,) = sess.run([slices])
    check_fields(only_result)
| |
def testFeedIndexedSlicesWithoutDenseShape(self):
  """Feeds an `IndexedSlices` that has no dense shape component."""
  want_values = np.array([1.0, 2.0], dtype=np.float32)
  want_indices = np.array([[3, 2, 0], [4, 5, 1]], dtype=np.int64)
  want_dense_shape = None
  with session.Session() as sess:
    fed_slices = ops.IndexedSlices(
        array_ops.placeholder(dtype=np.float32, shape=(2,)),
        array_ops.placeholder(dtype=np.int64, shape=(2, 3)), None)
    values_t = array_ops.identity(fed_slices.values)
    indices_t = array_ops.identity(fed_slices.indices)
    round_trip = ops.IndexedSlices(values_t, indices_t)

    # Feed first with a two-element tuple, then with an IndexedSlicesValue
    # whose dense shape is None.
    as_tuple = (want_values, want_indices)
    as_value = ops.IndexedSlicesValue(want_values, want_indices,
                                      want_dense_shape)
    for feed_value in (as_tuple, as_value):
      got_values, got_indices = sess.run([values_t, indices_t],
                                         {fed_slices: feed_value})
      self.assertAllEqual(got_values, want_values)
      self.assertAllEqual(got_indices, want_indices)

    # Feed with IndexedSlicesValue, fetch IndexedSlicesValue.
    fetched = sess.run(round_trip, {fed_slices: as_value})
    self.assertAllEqual(fetched.values, want_values)
    self.assertAllEqual(fetched.indices, want_indices)
    self.assertAllEqual(fetched.dense_shape, want_dense_shape)
| |
def testExtendWithStatelessOperations(self):
  """Runs a graph, adds more stateless ops, then runs the new ops."""
  with session.Session() as sess:
    ones_1x2 = constant_op.constant(1.0, shape=[1, 2])
    twos_2x3 = constant_op.constant(2.0, shape=[2, 3])
    first_product = math_ops.matmul(ones_1x2, twos_2x3)
    first_value = sess.run(first_product)
    self.assertAllEqual([[4.0, 4.0, 4.0]], first_value)
    column = constant_op.constant([1.0, 2.0, 3.0], shape=[3, 1])
    second_product = math_ops.matmul(first_product, column)
    # The ops above were created after the first run, so the session has to
    # pick them up ("extend") on this run.
    second_value = sess.run(second_product)
    self.assertAllEqual([[24.0]], second_value)
| |
def testExtendWithStatefulOperations(self):
  """Adds an assign op after a first run and checks the variable updates."""
  with session.Session() as sess:
    ones_1x2 = constant_op.constant(1.0, shape=[1, 2])
    twos_2x3 = constant_op.constant(2.0, shape=[2, 3])
    initial = math_ops.matmul(ones_1x2, twos_2x3)
    var = variables.Variable(initial, name='testExtendWithStatefulOperations_v')
    var.initializer.run()
    self.assertAllEqual([[4.0, 4.0, 4.0]], var.eval())
    threes_2x3 = constant_op.constant(3.0, shape=[2, 3])
    replacement = math_ops.matmul(ones_1x2, threes_2x3)
    assign_replacement = state_ops.assign(var, replacement)
    # The ops above were created after the first run, so the session has to
    # pick them up ("extend") on this evaluation.
    self.assertAllEqual([[6.0, 6.0, 6.0]], replacement.eval())
    # Evaluating `replacement` alone must not have touched the variable.
    self.assertAllEqual([[4.0, 4.0, 4.0]], var.eval())
    sess.run(assign_replacement)
    self.assertAllEqual([[6.0, 6.0, 6.0]], var.eval())
| |
def testExtendWithGroupBy(self):
  """Initializes two variables through one grouped op added after a run."""
  with session.Session() as sess:
    ones = constant_op.constant(1.0, shape=[1, 2])
    var_p = variables.Variable(ones, name='testExtendWithGroupBy_p')
    # Evaluate now so that everything created below is added to the graph
    # after the session has already run once.
    self.assertAllEqual([[1.0, 1.0]], ones.eval())

    twos = constant_op.constant(2.0, shape=[1, 2])
    var_q = variables.Variable(twos, name='testExtendWithGroupBy_q')
    # Running the grouped initializer is the first run that needs the new ops.
    init_both = control_flow_ops.group(var_p.initializer, var_q.initializer)
    sess.run(init_both)
    p_value, q_value = sess.run([var_p, var_q])

    self.assertAllEqual([[1.0, 1.0]], p_value)
    self.assertAllEqual([[2.0, 2.0]], q_value)
| |
def testTensorGetMethod(self):
  """Checks `Tensor.eval()` with and without a `feed_dict`."""
  with session.Session():
    lhs = constant_op.constant(1.0, shape=[1, 2])
    rhs = constant_op.constant(2.0, shape=[2, 3])
    product = math_ops.matmul(lhs, rhs)

    plain_value = product.eval()
    self.assertAllEqual([[4.0, 4.0, 4.0]], plain_value)

    # Override `lhs` by feeding it through its tensor name.
    override = {lhs.name: [[4.0, 4.0]]}
    fed_value = product.eval(feed_dict=override)
    self.assertAllEqual([[16.0, 16.0, 16.0]], fed_value)
| |
def testOperationRunMethod(self):
  """Evaluates assign ops, with and without feeding the assigned value.

  The variable is created with `dtype=a.dtype` as a keyword argument. The
  second positional parameter of `Variable` is `trainable`, not `dtype`, so
  passing the dtype positionally would not set the dtype.
  """
  with session.Session():
    a = constant_op.constant(1.0, shape=[1, 2])
    b = constant_op.constant(2.0, shape=[1, 2], name='b')
    v = variables.Variable(a, dtype=a.dtype)
    assign_a_to_v = state_ops.assign(v, a)

    assign_a_to_v.eval()

    v_val = v.eval()
    self.assertAllEqual([[1.0, 1.0]], v_val)

    assign_b_to_v = state_ops.assign(v, b)

    assign_b_to_v.eval()
    v_val = v.eval()
    self.assertAllEqual([[2.0, 2.0]], v_val)

    # Feeding 'b:0' replaces the constant's value for this evaluation only.
    assign_b_to_v.eval(feed_dict={'b:0': [[3.0, 3.0]]})
    v_val = v.eval()
    self.assertAllEqual([[3.0, 3.0]], v_val)
| |
def testDefaultGraph(self):
  """Checks that ops built inside a session land in the session's graph."""
  with session.Session() as sess:
    self.assertEqual(ops.get_default_graph(), sess.graph)
    ones_1x2 = constant_op.constant(1.0, shape=[1, 2])
    twos_2x3 = constant_op.constant(2.0, shape=[2, 3])
    for tensor in (ones_1x2, twos_2x3):
      self.assertEqual(ops.get_default_graph(), tensor.graph)
    initial = math_ops.matmul(ones_1x2, twos_2x3)
    var = variables.Variable(initial, name='testDefaultGraph_v')
    var.initializer.run()
    self.assertAllEqual([[4.0, 4.0, 4.0]], var.eval())
    threes_2x3 = constant_op.constant(3.0, shape=[2, 3])
    replacement = math_ops.matmul(ones_1x2, threes_2x3)
    assign_replacement = state_ops.assign(var, replacement)
    self.assertAllEqual([[6.0, 6.0, 6.0]], replacement.eval())
    # The variable keeps its old value until the assign op itself is run.
    self.assertAllEqual([[4.0, 4.0, 4.0]], var.eval())
    sess.run(assign_replacement)
    self.assertAllEqual([[6.0, 6.0, 6.0]], var.eval())
    self.assertEqual(ops.get_default_graph(), sess.graph)
| |
def _testDefaultGraphInThread(self, constructed_event, continue_event, i):
  """Thread body: builds and runs a small graph in its own session.

  Args:
    constructed_event: Event this thread sets once its first ops are built.
    continue_event: Event this thread waits on before building the rest.
    i: Integer used to give this thread's variable a distinct name.
  """
  with session.Session() as sess:
    self.assertEqual(ops.get_default_graph(), sess.graph)
    ones_1x2 = constant_op.constant(1.0, shape=[1, 2])
    twos_2x3 = constant_op.constant(2.0, shape=[2, 3])
    product = math_ops.matmul(ones_1x2, twos_2x3)
    var = variables.Variable(product, name='var_%d' % i)

    # Signal that this thread's graph exists, then block until released.
    constructed_event.set()
    continue_event.wait()

    assign_product = state_ops.assign(var, product)
    var.initializer.run()
    assign_product.eval()
    self.assertAllEqual([[4.0, 4.0, 4.0]], var.eval())
    threes_2x3 = constant_op.constant(3.0, shape=[2, 3])
    replacement = math_ops.matmul(ones_1x2, threes_2x3)
    assign_replacement = state_ops.assign(var, replacement)
    self.assertAllEqual([[6.0, 6.0, 6.0]], replacement.eval())
    # The variable keeps its old value until the assign op itself is run.
    self.assertAllEqual([[4.0, 4.0, 4.0]], var.eval())
    sess.run(assign_replacement)
    self.assertAllEqual([[6.0, 6.0, 6.0]], var.eval())
    self.assertEqual(ops.get_default_graph(), sess.graph)
| |
def testDefaultGraphWithThreads(self):
  """Runs `_testDefaultGraphInThread` concurrently in ten checked threads."""
  num_threads = 10
  release = threading.Event()
  built_events = [threading.Event() for _ in range(num_threads)]
  workers = [
      self.checkedThread(
          target=self._testDefaultGraphInThread,
          args=(built, release, index))
      for index, built in enumerate(built_events)
  ]
  for worker in workers:
    worker.start()
  # Wait until every thread has built its graph before letting any proceed.
  for built in built_events:
    built.wait()
  release.set()
  for worker in workers:
    worker.join()
| |
def testParallelRun(self):
  """Evaluates one constant from 100 threads sharing a single session."""
  with session.Session() as sess:
    five = constant_op.constant(5.0)
    go = threading.Event()

    def run_step():
      # Every thread blocks here so the evaluations start together.
      go.wait()
      self.assertEqual(five.eval(session=sess), 5.0)

    workers = [self.checkedThread(target=run_step) for _ in range(100)]
    for worker in workers:
      worker.start()
    go.set()
    for worker in workers:
      worker.join()
| |
@staticmethod
def _build_graph():
  """Adds a varied set of ops to the default graph, ten times over.

  Each round creates two placeholders (one colocated with the other), a
  while loop pinned to '/cpu:0', and gradients built inside an attr scope.
  The first round snapshots the graph as a GraphDef; every later round
  imports that snapshot again under the name 'import'.
  """
  # Random delay of up to 0.1 seconds before building.
  time.sleep(random.random() * 0.1)
  default_graph = ops.get_default_graph()
  snapshot = None
  for _ in range(10):
    first = array_ops.placeholder(dtype=dtypes.float32)
    with ops.colocate_with(first):
      second = array_ops.placeholder(dtype=dtypes.float32)
    with ops.device('/cpu:0'):
      loop_out = control_flow_ops.while_loop(
          lambda u, v: u < 10, lambda u, v: (u + 1, u * v), [first, second])
    # pylint: disable=protected-access
    with default_graph._attr_scope({'_a': attr_value_pb2.AttrValue(b=False)}):
      # pylint: enable=protected-access
      gradients_impl.gradients(loop_out, [first, second])
      if snapshot is not None:
        importer.import_graph_def(snapshot, name='import')
      else:
        snapshot = default_graph.as_graph_def()
| |
def testParallelRunAndSingleBuild(self):
  """Runs a constant from ten threads while this thread extends the graph.

  The stop event is set and the runner threads are joined in a `finally`
  block, so an exception raised while building the graph cannot leave the
  runner threads looping forever.
  """
  with session.Session() as sess:
    c = constant_op.constant(5.0)
    stop = threading.Event()

    def run_loop():
      # Keep running the constant, with small random pauses, until stopped.
      while not stop.is_set():
        time.sleep(random.random() * 0.1)
        self.assertEqual(sess.run(c), 5.0)

    threads = [self.checkedThread(target=run_loop) for _ in range(10)]
    for t in threads:
      t.start()

    try:
      SessionTest._build_graph()
    finally:
      stop.set()
      for t in threads:
        t.join()
| |
def testParallelRunAndParallelBuild(self):
  """Runs a constant from ten threads while ten others extend the graph.

  The stop event is set and the runner threads are joined in a `finally`
  block, so a failure while starting or joining the build threads cannot
  leave the runner threads looping forever.
  """
  with session.Session() as sess:
    c = constant_op.constant(5.0)
    stop = threading.Event()

    def run_loop():
      # Keep running the constant, with small random pauses, until stopped.
      while not stop.is_set():
        time.sleep(random.random() * 0.1)
        self.assertEqual(sess.run(c), 5.0)

    run_threads = [self.checkedThread(target=run_loop) for _ in range(10)]
    for t in run_threads:
      t.start()

    try:
      build_threads = [
          self.checkedThread(target=SessionTest._build_graph)
          for _ in range(10)
      ]
      for t in build_threads:
        t.start()
      for t in build_threads:
        t.join()
    finally:
      # Let the run_threads run until the build threads are finished.
      stop.set()
      for t in run_threads:
        t.join()
| |
def testRunFeedDict(self):
  """Feeds a tensor by object, by name, as a list, and via nested tuples.

  Every result is checked with `assertAllEqual`; a bare `assert` would be
  skipped when Python runs with optimizations enabled.
  """
  with session.Session() as s:
    x = array_ops.zeros([2])

    # Key is the tensor object, value is a float32 array.
    y = s.run(2 * x, feed_dict={x: np.ones(2).astype(np.float32)})
    self.assertAllEqual(y, 2 * np.ones(2))

    # Key is the tensor's name.
    y = s.run(2 * x, feed_dict={x.name: np.ones(2).astype(np.float32)})
    self.assertAllEqual(y, 2 * np.ones(2))

    # Value is a plain Python list of ints.
    y = s.run(2 * x, feed_dict={x: [1, 1]})
    self.assertAllEqual(y, 2 * np.ones(2))

    # Test nested tuple keys
    z = (((array_ops.zeros([2]),),), array_ops.zeros([2]),
         (array_ops.zeros([2]),))
    result = [z[0][0][0] * 2, z[1] * 2, z[2][0] * 2]
    values = (((np.array([1, 1]),),), np.array([2, 2]), (np.array([3, 3]),))
    result_value = s.run(result, feed_dict={z: values})
    self.assertAllEqual(result_value[0], 2 * np.ones(2))
    self.assertAllEqual(result_value[1], 2 * np.array([2, 2]))
    self.assertAllEqual(result_value[2], 2 * np.array([3, 3]))
| |
def testGraphDef(self):
  """Checks that `sess.graph_def` grows by one node per added constant."""
  with session.Session() as sess:
    # Before any op is added the GraphDef holds only version information.
    self.assertProtoEquals('versions { producer: %d min_consumer: %d }' %
                           (versions.GRAPH_DEF_VERSION,
                            versions.GRAPH_DEF_VERSION_MIN_CONSUMER),
                           sess.graph_def)
    c = constant_op.constant(5.0, name='c')
    self.assertEqual(len(sess.graph_def.node), 1)
    d = constant_op.constant(6.0, name='d')
    self.assertEqual(len(sess.graph_def.node), 2)
    self.assertAllEqual(c.eval(), 5.0)
    self.assertAllEqual(d.eval(), 6.0)
    # A node added after the session has already run is still reported.
    e = constant_op.constant(7.0, name='e')
    self.assertEqual(len(sess.graph_def.node), 3)
    self.assertAllEqual(e.eval(), 7.0)
| |
def testUseAfterClose(self):
  """Expects a RuntimeError when a session is run after its `with` block."""
  with session.Session() as sess:
    five = constant_op.constant(5.0)
    self.assertAllEqual(sess.run(five), 5.0)

  def is_closed_error(err):
    return 'Attempted to use a closed Session.' in str(err)

  # The `with` block above has exited, so this run must be rejected.
  with self.assertRaisesWithPredicateMatch(RuntimeError, is_closed_error):
    sess.run(five)
| |
def testUseAfterCloseConcurrent(self):
  """Closes a session while another thread keeps running it.

  The worker is started with `self.checkedThread`, as the other threaded
  tests in this class do, rather than a raw `threading.Thread`, so that a
  failed expectation inside the worker is not lost.
  """
  with session.Session() as sess:
    c = constant_op.constant(5.0)
    self.assertAllEqual(sess.run(c), 5.0)

    def update_thread():
      # Loop until a run raises; the only accepted outcome is the
      # closed-session RuntimeError.
      with self.assertRaisesWithPredicateMatch(
          RuntimeError,
          lambda e: 'Attempted to use a closed Session.' in str(e)):
        while True:
          sess.run(c)

    t = self.checkedThread(target=update_thread)
    t.start()
    # Give the worker time to start running before closing the session.
    time.sleep(0.1)
    sess.close()
    t.join()
| |
def testUseEmptyGraph(self):
  """Expects a RuntimeError for empty fetches on a session with no ops."""
  with session.Session() as sess:
    # An empty list, an empty tuple and an empty dict are each tried in turn.
    for empty_fetches in ([], (), {}):
      with self.assertRaisesRegexp(RuntimeError,
                                   'The Session graph is empty.'):
        sess.run(empty_fetches)
| |
def testNotEntered(self):
  """Uses a session that is never entered as a context manager.

  `sess.run` is expected to work, while `Tensor.eval()` without an explicit
  session is expected to raise because no default session is registered.
  """
  # pylint: disable=protected-access
  current_default = ops._default_session_stack.get_default()
  # pylint: enable=protected-access
  self.assertEqual(current_default, None)
  with ops.device('/cpu:0'):
    sess = session.Session()
    built_outside = constant_op.constant(5.0)
    with sess.graph.as_default():
      built_inside = constant_op.constant(5.0)
    self.assertEqual(built_outside.graph, built_inside.graph)
    self.assertEqual(sess.run(built_inside), 5.0)

    def is_no_default_session_error(err):
      return 'No default session is registered.' in str(err)

    with self.assertRaisesWithPredicateMatch(ValueError,
                                             is_no_default_session_error):
      built_inside.eval()
| |
def testInteractive(self):
  """Evaluates tensors with `eval()` under an `InteractiveSession`.

  The session is closed in a `finally` block so that a failing assertion
  does not leave it open.
  """
  with ops.device('/cpu:0'):
    sess = session.InteractiveSession()
    try:
      a = constant_op.constant(1.0, shape=[1, 2])
      b = constant_op.constant(2.0, shape=[2, 3])
      c = math_ops.matmul(a, b)
      self.assertAllEqual([[4.0, 4.0, 4.0]], c.eval())
      d = constant_op.constant([1.0, 2.0, 3.0], shape=[3, 1])
      e = math_ops.matmul(c, d)
      self.assertAllEqual([[24.0]], e.eval())
    finally:
      sess.close()
| |
def testMultipleInteractiveSessionsWarning(self):
  """Expects a warning only when a second interactive session overlaps."""
  # Reset the class-level counter so that the expected warnings are emitted
  # regardless of what ran before this test.
  session.InteractiveSession._active_session_count = 0  # pylint: disable=protected-access

  warm_up = session.InteractiveSession()
  warm_up.run(constant_op.constant(4.0))  # Run so that the session is "opened".
  warm_up.close()

  # Opening and closing interactive sessions serially should not warn.
  with warnings.catch_warnings(record=True) as serial_warnings:
    session.InteractiveSession().close()
  self.assertEqual(0, len(serial_warnings))

  # The first session left open does not warn by itself...
  with warnings.catch_warnings(record=True) as first_warnings:
    first = session.InteractiveSession()
  self.assertEqual(0, len(first_warnings))
  # ...but opening a second one while the first is still open does.
  with warnings.catch_warnings(record=True) as second_warnings:
    second = session.InteractiveSession()
  self.assertEqual(1, len(second_warnings))
  expected_text = ('An interactive session is already active. This can cause '
                   'out-of-memory errors in some cases. You must explicitly '
                   'call `InteractiveSession.close()` to release resources '
                   'held by the other session(s).')
  self.assertTrue(expected_text in str(second_warnings[0].message))
  second.close()
  first.close()
| |
def testInteractivePlacePrunedGraph(self):
  """Evaluates a valid op in an interactive session beside an unplaceable one.

  The session is closed in a `finally` block so that a failing assertion
  does not leave it open.
  """
  sess = session.InteractiveSession()
  try:
    # Build a graph that has a bad op in it (no kernel).
    #
    # This test currently does not link in any GPU kernels,
    # which is why placing this is invalid.  If at some point
    # GPU kernels are added to this test, some other different
    # op / device combo should be chosen.
    with ops.device('/device:GPU:0'):
      a = constant_op.constant(1.0, shape=[1, 2])

    b = constant_op.constant(1.0, shape=[1, 2])

    # Only run the valid op, this should work.
    b.eval()

    # Evaluating the op pinned to the GPU device is expected to fail.
    with self.assertRaises(errors.InvalidArgumentError):
      a.eval()
  finally:
    sess.close()
| |
def testDefaultSessionPlacePrunedGraph(self):
  """Expects a plain `Session` to fail when the graph holds an unplaceable op.

  The session is closed in a `finally` block so that a failing assertion
  does not leave it open.
  """
  sess = session.Session()
  try:
    # Build a graph that has a bad op in it (no kernel).
    #
    # This test currently does not link in any GPU kernels,
    # which is why placing this is invalid.  If at some point
    # GPU kernels are added to this test, some other different
    # op / device combo should be chosen.
    with ops.device('/device:GPU:0'):
      _ = constant_op.constant(1.0, shape=[1, 2])

    b = constant_op.constant(1.0, shape=[1, 2])

    with self.assertRaises(errors.InvalidArgumentError):
      # Even though we don't run the bad op, we place the entire
      # graph, which should fail with a non-interactive session.
      sess.run(b)
  finally:
    sess.close()
| |
def testSharedGraph(self):
  """Runs the same tensor from two sessions created on one graph."""
  with ops.Graph().as_default() as shared_graph, ops.device('/cpu:0'):
    lhs = constant_op.constant(1.0, shape=[1, 2])
    rhs = constant_op.constant(2.0, shape=[2, 3])
    product = math_ops.matmul(lhs, rhs)

  with session.Session(graph=shared_graph) as first, \
      session.Session(graph=shared_graph) as second:
    # Both sessions must produce the same value for the shared tensor.
    self.assertAllEqual(first.run(product), second.run(product))
| |
def testDuplicatedInputs(self):
  """Fetches the same tensor twice within a single fetch list."""
  with session.Session() as sess:
    ones = constant_op.constant(1.0, shape=[1, 2])
    twos = constant_op.constant(2.0, shape=[1, 3])
    # `ones` appears at both ends of the list; each position gets a value.
    first_ones, fetched_twos, second_ones = sess.run([ones, twos, ones])
    self.assertAllEqual(first_ones, [[1.0, 1.0]])
    self.assertAllEqual(fetched_twos, [[2.0, 2.0, 2.0]])
    self.assertAllEqual(second_ones, [[1.0, 1.0]])
| |
def testFeedAndFetch(self):
  """Round-trips random arrays of many dtypes and shapes through a feed.

  For each dtype/shape pair a placeholder and an identity op are built and
  the fed array is fetched back through the identity, through the
  placeholder itself, through both at once, and through `make_callable`.

  Both complex dtypes use the square root of the random integers as test
  data. Previously the `complex128` check was a duplicated `complex64`
  check, so `complex128` fell through to the plain cast.
  """
  with session.Session() as sess:
    for dtype in [
        dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int32,
        dtypes.uint8, dtypes.int16, dtypes.int8, dtypes.int64, dtypes.bool,
        dtypes.complex64, dtypes.complex128
    ]:
      for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]:
        np_dtype = dtype.as_numpy_dtype

        feed_t = array_ops.placeholder(dtype=dtype, shape=shape)
        out_t = array_ops.identity(feed_t)

        np_array = np.random.randint(-10, 10, shape)

        if dtype == dtypes.bool:
          np_array = np_array > 0
        elif dtype == dtypes.complex64 or dtype == dtypes.complex128:
          # Square roots of the negative entries have a nonzero imaginary
          # part, so the imaginary component is exercised as well.
          np_array = np.sqrt(np_array.astype(np_dtype))
        else:
          np_array = np_array.astype(np_dtype)

        self.assertAllEqual(np_array,
                            sess.run(out_t, feed_dict={
                                feed_t: np_array
                            }))
        # Check that we can also get the feed back.
        self.assertAllEqual(np_array,
                            sess.run(feed_t, feed_dict={
                                feed_t: np_array
                            }))
        # Also check that we can get both back.
        out_v, feed_v = sess.run(
            [out_t, feed_t], feed_dict={
                feed_t: np_array
            })
        self.assertAllEqual(np_array, out_v)
        self.assertAllEqual(np_array, feed_v)

        # Same check through a callable built once for these fetches/feeds.
        feed_fetch_runner = sess.make_callable([out_t, feed_t], [feed_t])
        out_v, feed_v = feed_fetch_runner(np_array)
        self.assertAllEqual(np_array, out_v)
        self.assertAllEqual(np_array, feed_v)
| |
def testMakeCallableOnTensorWithRunOptions(self):
  """Calls a tensor callable with tracing options and checks the metadata."""
  with session.Session() as sess:
    answer = constant_op.constant(42.0)
    fetch_answer = sess.make_callable(answer, accept_options=True)
    trace_options = config_pb2.RunOptions(
        trace_level=config_pb2.RunOptions.FULL_TRACE)
    metadata = config_pb2.RunMetadata()
    # The metadata starts empty and is expected to be filled in by the call.
    self.assertEqual(0, len(metadata.step_stats.dev_stats))
    result = fetch_answer(options=trace_options, run_metadata=metadata)
    self.assertEqual(42.0, result)
    self.assertGreater(len(metadata.step_stats.dev_stats), 0)
| |
def testMakeCallableOnOperationWithRunOptions(self):
  """Calls an operation callable with tracing options and checks its effect."""
  with session.Session() as sess:
    counter = variables.Variable(42.0)
    increment = state_ops.assign_add(counter, 1.0)
    sess.run(counter.initializer)
    # The callable wraps the op (not its output tensor).
    run_increment = sess.make_callable(increment.op, accept_options=True)
    trace_options = config_pb2.RunOptions(
        trace_level=config_pb2.RunOptions.FULL_TRACE)
    metadata = config_pb2.RunMetadata()
    self.assertEqual(0, len(metadata.step_stats.dev_stats))
    run_increment(options=trace_options, run_metadata=metadata)
    # One call must have added exactly 1.0 to the variable.
    self.assertEqual(43.0, sess.run(counter))
    self.assertGreater(len(metadata.step_stats.dev_stats), 0)
| |
def testMakeCallableWithFeedListAndRunOptions(self):
  """Calls a callable that takes one fed value plus tracing options."""
  with session.Session() as sess:
    placeholder = array_ops.placeholder(dtypes.float32)
    plus_one = math_ops.add(placeholder, 1.0)
    # The feed is identified by the placeholder's name.
    add_one = sess.make_callable(
        plus_one, feed_list=[placeholder.name], accept_options=True)
    trace_options = config_pb2.RunOptions(
        trace_level=config_pb2.RunOptions.FULL_TRACE)
    metadata = config_pb2.RunMetadata()
    self.assertEqual(0, len(metadata.step_stats.dev_stats))
    result = add_one(41.0, options=trace_options, run_metadata=metadata)
    self.assertAllClose(42.0, result)
    self.assertGreater(len(metadata.step_stats.dev_stats), 0)
| |
def testOptimizedMakeCallable(self):
  """Builds a callable from `CallableOptions` three times, calling each 5x."""
  with session.Session() as sess:
    placeholder = array_ops.placeholder(dtypes.float32)
    plus_one = math_ops.add(placeholder, 1.0)
    options = config_pb2.CallableOptions()
    options.feed.append(placeholder.name)
    options.fetch.append(plus_one.name)
    for _ in range(3):
      # pylint: disable=protected-access
      add_one = sess._make_callable_from_options(options)
      # pylint: enable=protected-access
      for _ in range(5):
        fetched = add_one(np.array(1.0, dtype=np.float32))
        self.assertEqual([2.0], fetched)
| |
def testOptimizedMakeCallableWithRunMetadata(self):
  """Checks that a `CallableOptions` callable fills in passed run metadata."""
  with session.Session() as sess:
    placeholder = array_ops.placeholder(dtypes.float32)
    plus_one = math_ops.add(placeholder, 1.0)
    options = config_pb2.CallableOptions()
    options.feed.append(placeholder.name)
    options.fetch.append(plus_one.name)
    # Tracing is requested through the options embedded in the callable.
    options.run_options.trace_level = config_pb2.RunOptions.FULL_TRACE
    # pylint: disable=protected-access
    add_one = sess._make_callable_from_options(options)
    # pylint: enable=protected-access
    metadata = config_pb2.RunMetadata()
    fetched = add_one(np.array(1.0, dtype=np.float32), run_metadata=metadata)
    self.assertEqual([2.0], fetched)
    self.assertGreater(len(metadata.step_stats.dev_stats), 0)
| |
def testFeedError(self):
  """Expects a TypeError when a `tf.Tensor` is used as a fed value."""
  with session.Session() as sess:
    placeholder = array_ops.placeholder(dtype=dtypes.float32)
    passthrough = array_ops.identity(placeholder)
    # The fed value is itself a tensor, which is the error under test.
    bad_feed = {placeholder: constant_op.constant(5.0)}
    expected_message = 'cannot be a tf.Tensor object'
    # The same bad feed is rejected by Session.run, Tensor.eval and
    # Operation.run alike.
    with self.assertRaisesRegexp(TypeError, expected_message):
      sess.run(passthrough, feed_dict=bad_feed)
    with self.assertRaisesRegexp(TypeError, expected_message):
      passthrough.eval(feed_dict=bad_feed)
    with self.assertRaisesRegexp(TypeError, expected_message):
      passthrough.op.run(feed_dict=bad_feed)
| |
def testFeedPrecisionLossError(self):
  """Expects a TypeError when the largest int64 is fed to an int32 tensor."""
  with session.Session() as sess:
    too_large = np.iinfo(np.int64).max

    implicit_int32 = constant_op.constant(1)
    explicit_int32 = constant_op.constant(1, dtype=dtypes.int32)

    fetched = constant_op.constant(1.0)

    # First the constant with no dtype given, then the one declared int32.
    for feed_target in (implicit_int32, explicit_int32):
      with self.assertRaisesRegexp(TypeError,
                                   'is not compatible with Tensor type'):
        sess.run(fetched, feed_dict={feed_target: too_large})
| |
def testStringFetch(self):
  """Fetches string constants of several shapes, including empty ones.

  The arrays use the builtin `object` as dtype; the `np.object` alias is
  deprecated in NumPy and absent from newer releases.
  """
  with session.Session():
    for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]:
      # Total number of elements for this shape.
      size = 1
      for dim in shape:
        size *= dim
      if size > 0:
        c_list = np.array(
            [compat.as_bytes(str(i)) for i in xrange(size)],
            dtype=object).reshape(shape)
      else:
        # Shapes with a zero dimension use a plain empty list instead.
        c_list = []
      c = constant_op.constant(c_list)
      self.assertAllEqual(c.eval(), c_list)
| |
def testStringFeed(self):
  """Feeds string arrays of several shapes and fetches them back.

  The arrays use the builtin `object` as dtype; the `np.object` alias is
  deprecated in NumPy and absent from newer releases.
  """
  with session.Session() as sess:
    for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]:
      # Total number of elements for this shape.
      size = 1
      for dim in shape:
        size *= dim
      c_list = np.array(
          [compat.as_bytes(str(i)) for i in xrange(size)],
          dtype=object).reshape(shape)
      feed_t = array_ops.placeholder(dtype=dtypes.string, shape=shape)
      c = array_ops.identity(feed_t)
      # Fetch through the identity op, through the placeholder, then both.
      self.assertAllEqual(sess.run(c, feed_dict={feed_t: c_list}), c_list)
      self.assertAllEqual(
          sess.run(feed_t, feed_dict={
              feed_t: c_list
          }), c_list)
      c_v, feed_v = sess.run([c, feed_t], feed_dict={feed_t: c_list})
      self.assertAllEqual(c_v, c_list)
      self.assertAllEqual(feed_v, c_list)
| |
def testStringFeedWithNullCharacters(self):
  """Feeds byte strings containing NUL bytes and expects them back intact."""
  with session.Session():
    fed_strings = [b'\n\x01\x00', b'\n\x00\x01']
    placeholder = array_ops.placeholder(dtype=dtypes.string, shape=[2])
    passthrough = array_ops.identity(placeholder)
    fetched = passthrough.eval(feed_dict={placeholder: fed_strings})
    # One string ends with the NUL byte, the other has it in the middle.
    self.assertEqual(fed_strings[0], fetched[0])
    self.assertEqual(fed_strings[1], fetched[1])
| |
def testStringFeedWithUnicode(self):
  """Feeds unicode strings and expects UTF-8 encoded bytes back.

  The strings are fed once as a Python list and once as a NumPy array with
  the builtin `object` dtype; the `np.object` alias is deprecated in NumPy
  and absent from newer releases.
  """
  with session.Session():
    c_list = [
        u'\n\x01\x00', u'\n\x00\x01', u'\u26a3 unicode',
        u'\U0001f60e deal with it'
    ]
    feed_t = array_ops.placeholder(dtype=dtypes.string, shape=[len(c_list)])
    c = array_ops.identity(feed_t)

    # Fed as a plain Python list.
    out = c.eval(feed_dict={feed_t: c_list})
    for i, expected in enumerate(c_list):
      self.assertEqual(expected, out[i].decode('utf-8'))

    # Fed as an object-dtype NumPy array.
    out = c.eval(feed_dict={feed_t: np.array(c_list, dtype=object)})
    for i, expected in enumerate(c_list):
      self.assertEqual(expected, out[i].decode('utf-8'))
| |
def testInvalidTargetFails(self):
  """Expects `NotFoundError` when a session is created for a bogus target."""
  expected_message = (
      'No session factory registered for the given session options')
  with self.assertRaisesRegexp(errors.NotFoundError, expected_message):
    session.Session('INVALID_TARGET')
| |
def testFetchByNameDifferentStringTypes(self):
  """Names and fetches tensors using plain, unicode, bytes and raw strings.

  Each constant is named with a different string literal type. Every tensor
  name must come back as text, and each tensor must be fetchable by name
  through all four literal types.
  """
  with session.Session() as sess:
    c = constant_op.constant(42.0, name='c')
    d = constant_op.constant(43.0, name=u'd')
    e = constant_op.constant(44.0, name=b'e')
    f = constant_op.constant(45.0, name=r'f')

    # assertIsInstance reports the offending value and type on failure.
    for tensor in (c, d, e, f):
      self.assertIsInstance(tensor.name, six.text_type)

    # (expected value, the same tensor name spelled as four literal types)
    fetch_cases = [
        (42.0, ('c:0', u'c:0', b'c:0', r'c:0')),
        (43.0, ('d:0', u'd:0', b'd:0', r'd:0')),
        (44.0, ('e:0', u'e:0', b'e:0', r'e:0')),
        (45.0, ('f:0', u'f:0', b'f:0', r'f:0')),
    ]
    for expected, fetch_names in fetch_cases:
      for fetch_name in fetch_names:
        self.assertEqual(expected, sess.run(fetch_name))
| |
def testIncorrectGraph(self):
  """Expects ValueError when a session runs a tensor or op of another graph."""
  with ops.Graph().as_default() as first_graph:
    first_const = constant_op.constant(1.0, name='c')

  with ops.Graph().as_default() as second_graph:
    second_const = constant_op.constant(2.0, name='c')

  # Both ops carry the same name, so only the owning graph tells them apart.
  for const in (first_const, second_const):
    self.assertEqual('c', const.op.name)

  with session.Session(graph=first_graph) as first_sess:
    self.assertEqual(1.0, first_sess.run(first_const))
    for foreign in (second_const, second_const.op):
      with self.assertRaises(ValueError):
        first_sess.run(foreign)

  with session.Session(graph=second_graph) as second_sess:
    for foreign in (first_const, first_const.op):
      with self.assertRaises(ValueError):
        second_sess.run(foreign)
    self.assertEqual(2.0, second_sess.run(second_const))
| |
def testFeedDictKeyException(self):
  """Expects a TypeError when the feed_dict key is the string 'a'."""
  with session.Session() as sess:
    const = constant_op.constant(1.0, dtypes.float32, name='a')
    bad_feed = {'a': [2.0]}
    with self.assertRaisesRegexp(TypeError, 'Cannot interpret feed_dict'):
      sess.run(const, feed_dict=bad_feed)
| |
def testPerStepTrace(self):
  """Checks that step stats appear only when tracing options are passed."""
  run_options = config_pb2.RunOptions(
      trace_level=config_pb2.RunOptions.FULL_TRACE)
  run_metadata = config_pb2.RunMetadata()

  with ops.device('/cpu:0'):
    with session.Session() as sess:
      # Neither options nor metadata: nothing is recorded.
      sess.run(constant_op.constant(1.0))
      self.assertFalse(run_metadata.HasField('step_stats'))

      # Metadata without tracing options: still nothing is recorded.
      sess.run(constant_op.constant(1.0), run_metadata=run_metadata)
      self.assertFalse(run_metadata.HasField('step_stats'))

      # Tracing options plus metadata: step stats are filled in.
      sess.run(
          constant_op.constant(1.0),
          options=run_options,
          run_metadata=run_metadata)

      self.assertTrue(run_metadata.HasField('step_stats'))
      self.assertEqual(len(run_metadata.step_stats.dev_stats), 1)
| |
def testRunOptionsRunMetadata(self):
  """Runs with every combination of `options` and `run_metadata` being None."""
  run_options = config_pb2.RunOptions(
      trace_level=config_pb2.RunOptions.FULL_TRACE)
  run_metadata = config_pb2.RunMetadata()

  with ops.device('/cpu:0'):
    with session.Session() as sess:
      # all combinations are valid
      sess.run(constant_op.constant(1.0), options=None, run_metadata=None)
      sess.run(
          constant_op.constant(1.0), options=None, run_metadata=run_metadata)
      self.assertFalse(run_metadata.HasField('step_stats'))

      # Tracing requested but no metadata object passed to this call, so
      # the local `run_metadata` stays empty.
      sess.run(
          constant_op.constant(1.0), options=run_options, run_metadata=None)
      self.assertFalse(run_metadata.HasField('step_stats'))

      sess.run(
          constant_op.constant(1.0),
          options=run_options,
          run_metadata=run_metadata)

      self.assertTrue(run_metadata.HasField('step_stats'))
      self.assertEqual(len(run_metadata.step_stats.dev_stats), 1)
| |
def testFeedShapeCompatibility(self):
  """Expects errors for feeds whose shape or value does not fit the graph."""
  with session.Session() as sess:
    four_values = constant_op.constant([2.0, 2.0, 2.0, 2.0])
    target_shape = constant_op.constant([2, 2])
    reshaped = array_ops.reshape(four_values, target_shape)

    # A three-element feed for the four-element constant is rejected with a
    # ValueError.
    with self.assertRaisesRegexp(ValueError, 'Cannot feed value of shape'):
      sess.run(reshaped, feed_dict={four_values: [1.0, 2.0, 3.0]})

    # Feeding a new target shape of 3x7 is accepted as a feed, but the
    # reshape itself then fails because 4 values cannot fill 21 slots.
    reshape_error = ('Input to reshape is a tensor with 4 values, '
                     'but the requested shape has 21')
    with self.assertRaisesRegexp(errors.InvalidArgumentError, reshape_error):
      sess.run(reshaped, feed_dict={target_shape: [3, 7]})
| |
def testInferShapesFalse(self):
  """Expects no '_output_shapes' attr on the first node by default."""
  with ops.Graph().as_default(), ops.device('/cpu:0'):
    a = constant_op.constant([[1, 2]])
    # Use the session as a context manager so it is closed deterministically.
    with session.Session() as sess:
      self.assertNotIn('_output_shapes', sess.graph_def.node[0].attr)
    # Avoid lint error regarding 'unused' var a. Does not rely on how `==`
    # is defined for the constant's type.
    self.assertIsNotNone(a)
| |
def testInferShapesTrue(self):
  """Expects '_output_shapes' on the first node when infer_shapes is set."""
  config = config_pb2.ConfigProto(
      graph_options=config_pb2.GraphOptions(infer_shapes=True))
  with ops.Graph().as_default(), ops.device('/cpu:0'):
    a = constant_op.constant([[1, 2]])
    # Use the session as a context manager so it is closed deterministically.
    with session.Session(config=config) as sess:
      self.assertIn('_output_shapes', sess.graph_def.node[0].attr)
    # Avoid lint error regarding 'unused' var a. Does not rely on how `==`
    # is defined for the constant's type.
    self.assertIsNotNone(a)
| |
def testBuildCostModel(self):
  """Runs 120 steps and expects a cost graph in the metadata of step 99 only."""
  graph_options = config_pb2.GraphOptions(build_cost_model=100)
  config = config_pb2.ConfigProto(
      allow_soft_placement=True, graph_options=graph_options)
  run_options = config_pb2.RunOptions()
  with session.Session(config=config) as sess:
    with ops.device('/device:GPU:0'):
      a = array_ops.placeholder(dtypes.float32, shape=[])
      doubled = math_ops.add(a, a)
      passthrough = array_ops.identity(doubled)
      squared = math_ops.multiply(passthrough, passthrough)
    for step in xrange(120):
      # A fresh proto per step so earlier results cannot leak into later ones.
      run_metadata = config_pb2.RunMetadata()
      sess.run(
          squared,
          feed_dict={a: 1.0},
          options=run_options,
          run_metadata=run_metadata)
      has_cost_graph = run_metadata.HasField('cost_graph')
      if step != 99:
        self.assertFalse(has_cost_graph)
      else:
        self.assertTrue(has_cost_graph)
| |
def runTestOutputPartitionGraphs(self, sess):
  """Checks partition graphs are reported only when the run options ask.

  Args:
    sess: The session on which the fetches are run.
  """
  fetch = constant_op.constant(1)
  run_metadata = config_pb2.RunMetadata()
  with_graphs = config_pb2.RunOptions(output_partition_graphs=True)

  sess.run(fetch, options=with_graphs, run_metadata=run_metadata)
  self.assertGreater(len(run_metadata.partition_graphs), 0)

  # Same metadata proto, no options: the graphs are expected to be gone.
  sess.run(fetch, run_metadata=run_metadata)
  self.assertEqual(len(run_metadata.partition_graphs), 0)
| |
def testOutputPartitionGraphsDirect(self):
  """Runs the partition-graph check on a session created with no target."""
  sess = session.Session()
  self.runTestOutputPartitionGraphs(sess)
| |
def testOutputPartitionGraphsDistributed(self):
  """Runs the partition-graph check on a session bound to a local server."""
  local_server = server_lib.Server.create_local_server()
  sess = session.Session(local_server.target)
  self.runTestOutputPartitionGraphs(sess)
| |
def testNonInteractiveSessionNesting(self):
  """Expects an AssertionError when default sessions exit out of order."""
  sess1 = session.Session()
  try:
    sess1_controller = sess1.as_default()
    sess1_controller.__enter__()

    sess2 = session.Session()
    sess2_controller = sess2.as_default()
    sess2_controller.__enter__()

    # sess2 was entered last, so leaving sess1 first breaks the nesting.
    with self.assertRaisesRegexp(AssertionError, 'Nesting violated'):
      sess1_controller.__exit__(None, None, None)
  finally:
    # Always clear the default-session stack, even if the check above fails,
    # so the entered sessions cannot leak into other tests.
    ops._default_session_stack.reset()
| |
def testInteractiveSessionNesting(self):
  """Creates two InteractiveSessions and drops the names in creation order."""
  first = session.InteractiveSession()
  second = session.InteractiveSession()
  # Targets of a single `del` are removed left to right.
  del first, second
| |
def testAsDefault(self):
  """Evaluates a constant through named and unnamed default sessions."""
  value = constant_op.constant(37)

  named_session = session.Session()
  with named_session.as_default():
    self.assertEqual(37, value.eval())

  # Ensure that the session remains valid even when it is not captured.
  with session.Session().as_default():
    self.assertEqual(37, value.eval())
| |
def testReentry(self):
  """Expects a RuntimeError when one session is entered twice."""
  sess = session.Session()
  expect_error = self.assertRaisesRegexp(RuntimeError, 'not re-entrant')
  with expect_error:
    # A multi-item `with` is equivalent to nesting the two statements.
    with sess, sess:
      pass
| |
def testInvalidArgument(self):
  """Expects a TypeError for a non-string target, bad config and bad graph."""
  with self.assertRaisesRegexp(TypeError, 'target must be a string'):
    session.Session(37)

  # Keyword arguments are checked one at a time, in the original order.
  keyword_cases = (
      ('config must be a tf.ConfigProto', {'config': 37}),
      ('graph must be a tf.Graph', {'graph': 37}),
  )
  for expected_message, bad_kwargs in keyword_cases:
    with self.assertRaisesRegexp(TypeError, expected_message):
      session.Session(**bad_kwargs)
| |
def testTimeoutWithShortOperations(self):
  """Runs short enqueues under an operation timeout and checks queue size."""
  num_epochs = 5
  queue = data_flow_ops.FIFOQueue(
      capacity=50, dtypes=[dtypes.int32], shapes=[()])
  enqueue_op = queue.enqueue_many(constant_op.constant([1, 2]))

  # Use a 10-second timeout, which should be longer than any
  # non-blocking enqueue_many op.
  config = config_pb2.ConfigProto(operation_timeout_in_ms=10000)
  with session.Session(config=config) as sess:
    for _ in range(num_epochs):
      sess.run(enqueue_op)
    # Each enqueue_many adds two elements.
    queue_size = sess.run(queue.size())
    self.assertEqual(queue_size, num_epochs * 2)
| |
def testRegisterFetchAndFeedConversionFunctions(self):
  """Registers a custom fetch/feed type and uses it in run and partial_run."""

  class SquaredTensor(object):
    """Holds the square of the value it is constructed from."""

    def __init__(self, tensor):
      self.sq = math_ops.square(tensor)

  def fetch_fn(squared_tensor):
    # Fetch the wrapped tensor and unwrap the single-element result list.
    return ([squared_tensor.sq], lambda val: val[0])

  def feed_fn1(feed, feed_val):
    return [(feed.sq, feed_val)]

  def feed_fn2(feed):
    return [feed.sq]

  session.register_session_run_conversion_functions(SquaredTensor, fetch_fn,
                                                    feed_fn1, feed_fn2)
  # Registering the same type a second time must be rejected.
  with self.assertRaises(ValueError):
    session.register_session_run_conversion_functions(SquaredTensor, fetch_fn,
                                                      feed_fn1, feed_fn2)
  with self.cached_session() as sess:
    feed_values = np.array([1.0, 1.5, 2.0, 2.5])
    base_values = np.array([3.0, 3.5, 4.0, 4.5])
    squared_tensor = SquaredTensor(base_values)

    # Plain fetch goes through fetch_fn.
    fetched = sess.run(squared_tensor)
    self.assertAllClose(base_values * base_values, fetched)

    # Feeding the custom object goes through feed_fn1.
    fetched = sess.run(
        squared_tensor, feed_dict={squared_tensor: feed_values * feed_values})
    self.assertAllClose(feed_values * feed_values, fetched)

    # partial_run_setup / partial_run accept the custom object as well.
    handle = sess.partial_run_setup([squared_tensor], [])
    fetched = sess.partial_run(handle, squared_tensor)
    self.assertAllClose(base_values * base_values, fetched)
| |
def testDefaultLogDevicePlacement(self):
  """Expects device placement logs when only the server config asks for them."""

  class CaptureStderr(str):
    """Class to capture stderr from C++ shared library.

    While the context is active, the file descriptor behind sys.stderr is
    redirected into a pipe; the captured text is available through str().
    """

    def __enter__(self):
      self._esc = compat.as_str('\b')
      self._output = compat.as_str('')
      self._stderr = sys.stderr
      self._fd = self._stderr.fileno()
      self._out_pipe, in_pipe = os.pipe()
      # Save the original io stream.
      self._dup_fd = os.dup(self._fd)
      # Replace the original io stream with in pipe.
      os.dup2(in_pipe, self._fd)
      # self._fd now refers to the pipe's write end as well, so the extra
      # descriptor can be closed right away instead of being leaked.
      os.close(in_pipe)
      return self

    def __exit__(self, *args):
      # Write a sentinel so read() knows where the captured output ends.
      self._stderr.write(self._esc)
      self._stderr.flush()
      self.read()
      os.close(self._out_pipe)
      # Restore the original io stream.
      os.dup2(self._dup_fd, self._fd)
      # The saved duplicate is no longer needed once the stream is restored.
      os.close(self._dup_fd)

    def read(self):
      """Accumulates pipe output until EOF or the sentinel character."""
      while True:
        data = os.read(self._out_pipe, 1)
        if not data or compat.as_str(data) == self._esc:
          break
        self._output += compat.as_str(data)

    def __str__(self):
      return self._output

  # Passing the config to the server, but not the session should still result
  # in logging device placement.
  config = config_pb2.ConfigProto(log_device_placement=True)
  server = server_lib.Server.create_local_server(config=config)
  a = constant_op.constant(1)
  b = constant_op.constant(2)
  c = a + b
  with session.Session(server.target) as sess:
    with CaptureStderr() as log:
      sess.run(c)
    # Ensure that we did log device placement. assertIn reports the captured
    # text on failure.
    self.assertIn('/job:local/replica:0/task:0/device:CPU:0', str(log))
| |
def testLocalMasterSessionTimeout(self):
  """Expects DeadlineExceededError from a session-level operation timeout."""
  # Test that the timeout passed in a config to the session works correctly.
  session_config = config_pb2.ConfigProto(operation_timeout_in_ms=1000)
  local_server = server_lib.Server.create_local_server()
  queue = data_flow_ops.FIFOQueue(1, dtypes.float32)
  dequeued_t = queue.dequeue()

  with session.Session(local_server.target, config=session_config) as sess:
    # Intentionally do not run any enqueue_ops so that dequeue will block
    # until operation_timeout_in_ms.
    expect_timeout = self.assertRaises(errors.DeadlineExceededError)
    with expect_timeout:
      sess.run(dequeued_t)
| |
def testDefaultServerTimeout(self):
  """Expects DeadlineExceededError from a server-level operation timeout."""
  # Test that the default server config timeout gets used when no Session
  # config is provided.
  server_config = config_pb2.ConfigProto(operation_timeout_in_ms=1000)
  local_server = server_lib.Server.create_local_server(config=server_config)
  queue = data_flow_ops.FIFOQueue(1, dtypes.float32)
  dequeued_t = queue.dequeue()

  with session.Session(local_server.target) as sess:
    # Intentionally do not run any enqueue_ops so that dequeue will block
    # until operation_timeout_in_ms.
    expect_timeout = self.assertRaises(errors.DeadlineExceededError)
    with expect_timeout:
      sess.run(dequeued_t)
| |
def runTestBuildGraphError(self, sess):
  """Checks that a graph-construction error surfaces from sess.run.

  Args:
    sess: The session on which the invalid graph is run.
  """
  # Ensure that errors from building the graph get propagated.
  data = array_ops.placeholder(dtypes.float32, shape=[])
  # Two Enter ops with different frame names feeding a single add.
  # pylint: disable=protected-access
  enter_1, enter_2 = [
      gen_control_flow_ops.enter(data, frame_name, False)
      for frame_name in ('foo_1', 'foo_2')
  ]
  # pylint: enable=protected-access
  total = math_ops.add(enter_1, enter_2)
  with self.assertRaisesOpError('has inputs from different frames'):
    sess.run(total, feed_dict={data: 1.0})
| |
def testBuildGraphErrorDirect(self):
  """Runs the build-graph error check on a session created with no target."""
  sess = session.Session()
  self.runTestBuildGraphError(sess)
| |
def testBuildGraphErrorDist(self):
  """Runs the build-graph error check on a session bound to a local server."""
  local_server = server_lib.Server.create_local_server()
  sess = session.Session(local_server.target)
  self.runTestBuildGraphError(sess)
| |
def testDeviceAttributes(self):
  """Checks the accessors and string form of _DeviceAttributes."""
  device_name = '/job:worker/replica:0/task:3/device:CPU:2'
  attrs = session._DeviceAttributes(device_name, 'TYPE', 1337, 1000000)

  self.assertEqual(1337, attrs.memory_limit_bytes)
  self.assertEqual(device_name, attrs.name)
  self.assertEqual('TYPE', attrs.device_type)
  self.assertEqual(1000000, attrs.incarnation)

  rendered = '%s' % attrs
  self.assertTrue(rendered.startswith('_DeviceAttributes'), rendered)
| |
def testDeviceAttributesCanonicalization(self):
  """Expects the short '/cpu:1' device suffix to be reported as 'device:CPU:1'."""
  short_name = '/job:worker/replica:0/task:3/cpu:1'
  canonical_name = '/job:worker/replica:0/task:3/device:CPU:1'
  attrs = session._DeviceAttributes(short_name, 'TYPE', 1337, 1000000)

  self.assertEqual(1337, attrs.memory_limit_bytes)
  self.assertEqual(canonical_name, attrs.name)
  self.assertEqual('TYPE', attrs.device_type)
  self.assertEqual(1000000, attrs.incarnation)

  rendered = '%s' % attrs
  self.assertTrue(rendered.startswith('_DeviceAttributes'), rendered)
| |
def runTestAddFunctionToSession(self, target=''):
  """Add a function to a session after the graph has already been run.

  Args:
    target: The target string passed to the session constructor.
  """

  @function.Defun(dtypes.float32)
  def foo(x):
    return x + 1

  one = constant_op.constant(1.0)
  with session.Session(target=target) as sess:
    # Run the graph once before the function call is added to it.
    sess.run(one)
    call_result = sess.run(foo(one))
    self.assertEqual(call_result, 2.0)
| |
def testAddFunctionToSession(self):
  """Runs the add-function scenario with the helper's default target."""
  run_scenario = self.runTestAddFunctionToSession
  run_scenario()
| |
def testAddFunctionToGrpcSession(self):
  """Runs the add-function scenario against a local server's target."""
  local_server = server_lib.Server.create_local_server()
  server_target = local_server.target
  self.runTestAddFunctionToSession(server_target)
| |
def testOpenAndCloseGrpcSession(self):
  """Enters and immediately leaves a session bound to a local server."""
  local_server = server_lib.Server.create_local_server()
  sess = session.Session(local_server.target)
  with sess:
    pass
| |
def testOpenAndCloseSession(self):
  """Enters and immediately leaves a session created with no target."""
  sess = session.Session()
  with sess:
    pass
| |
def testAutoConvertAndCheckData(self):
  """Expects a TypeError when an int is fed to a string placeholder."""
  with self.cached_session() as sess:
    a = array_ops.placeholder(dtype=dtypes.string)
    # Raw string: '\w' in a normal literal is an invalid escape sequence that
    # newer Python versions warn about. The regex itself is unchanged.
    with self.assertRaisesRegexp(
        TypeError, r"Type of feed value 1 with type <(\w+) 'int'> is not"):
      sess.run(a, feed_dict={a: 1})
| |
| |
# Hand control to the googletest entry point when run as a script.
if __name__ == '__main__':
  googletest.main()