| # Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Tests for tensorflow.python.framework.ops.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import gc |
| import os |
| import threading |
| import weakref |
| |
| from tensorflow.core.framework import attr_value_pb2 |
| from tensorflow.core.protobuf import config_pb2 |
| from tensorflow.python.client import session |
| from tensorflow.python.eager import context |
| from tensorflow.python.eager import function as eager_function |
| from tensorflow.python.framework import common_shapes |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import device as pydev |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import errors |
| from tensorflow.python.framework import function |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import sparse_tensor |
| from tensorflow.python.framework import tensor_shape |
| from tensorflow.python.framework import tensor_util |
| from tensorflow.python.framework import test_ops |
| from tensorflow.python.framework import test_util |
| from tensorflow.python.framework import versions |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import control_flow_ops |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops import resource_variable_ops |
| from tensorflow.python.ops import resources |
| from tensorflow.python.ops import variable_scope |
| from tensorflow.python.ops import variables |
| import tensorflow.python.ops.gradients # pylint: disable=unused-import |
| from tensorflow.python.platform import googletest |
| from tensorflow.python.util import compat |
| |
| ops._set_call_cpp_shape_fn(common_shapes.call_cpp_shape_fn) |
| |
| |
| class ResourceTest(test_util.TensorFlowTestCase): |
| |
| def testBuildGraph(self): |
| with self.cached_session(): |
| pt = test_ops.stub_resource_handle_op(container="a", shared_name="b") |
| test_ops.resource_create_op(pt).run() |
| |
| def testInitialize(self): |
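| # A registered resource is reported as uninitialized until |
| # initialize_resources() has been run. |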
| with self.cached_session(): |
| handle = test_ops.stub_resource_handle_op(container="a", shared_name="b") |
| resources.register_resource( |
| handle=handle, |
| create_op=test_ops.resource_create_op(handle), |
| is_initialized_op=test_ops.resource_initialized_op(handle)) |
| self.assertEqual( |
| len( |
| resources.report_uninitialized_resources( |
| resources.shared_resources()).eval()), 1) |
| resources.initialize_resources(resources.shared_resources()).run() |
| self.assertEqual( |
| len( |
| resources.report_uninitialized_resources( |
| resources.shared_resources()).eval()), 0) |
| |
| |
| class TensorAndShapeTest(test_util.TensorFlowTestCase): |
| |
| def testShape(self): |
| op = ops.Operation( |
| ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32]) |
| t = op.outputs[0] |
| self.assertEqual(tensor_shape.unknown_shape(), t.get_shape()) |
| t.set_shape([1, 2, 3]) |
| self.assertEqual([1, 2, 3], t.get_shape()) |
| |
| def testIterable(self): |
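| # Graph-mode Tensors are not iterable; attempting to iterate over one |
| # raises a TypeError. |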
| op = ops.Operation( |
| ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32]) |
| t = op.outputs[0] |
| self.assertTrue(isinstance(t, ops.Tensor)) |
| with self.assertRaisesRegexp(TypeError, "iter"): |
| for _ in t: |
| pass |
| |
| def testAddShape(self): |
| with self.cached_session(): |
| a = array_ops.zeros([2, 3]) |
| b = array_ops.ones([1, 3]) |
| c = a + b |
| self.assertEqual([2, 3], c.shape) |
| |
| def testUnknownDim(self): |
| with self.cached_session(): |
| a = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3]) |
| b = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3]) |
| c = a + b |
| self.assertEqual([2, None, 3], c.shape.as_list()) |
| |
| def testUnknownShape(self): |
| with self.cached_session(): |
| a = array_ops.placeholder(dtype=dtypes.float32, shape=None) |
| b = array_ops.ones([1, 3]) |
| c = a + b |
| self.assertEqual(tensor_shape.unknown_shape(), c.shape) |
| |
| def testScalarShape(self): |
| with self.cached_session(): |
| a = array_ops.placeholder(dtype=dtypes.float32, shape=[]) |
| b = array_ops.ones([]) |
| c = a + b |
| self.assertEqual(tensor_shape.scalar(), c.shape) |
| |
| def testShapeFunctionError(self): |
| with self.cached_session(): |
| a = array_ops.ones([1, 2, 3]) |
| b = array_ops.ones([4, 5, 6]) |
| with self.assertRaisesRegexp( |
| ValueError, |
| r"Dimensions must be equal, but are 2 and 5 for 'add' \(op: 'Add'\) " |
| r"with input shapes: \[1,2,3\], \[4,5,6\]."): |
| _ = a + b |
| |
| |
| class IndexedSlicesTest(test_util.TensorFlowTestCase): |
| |
| def testToTensor(self): |
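| # Converting IndexedSlices to a Tensor densifies it: rows not covered by |
| # the indices are filled with zeros. |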
| with self.cached_session(): |
| values = constant_op.constant([2, 3, 5, 7], shape=[2, 2]) |
| indices = constant_op.constant([0, 2]) |
| dense_shape = constant_op.constant([3, 2]) |
| x = ops.IndexedSlices(values, indices, dense_shape) |
| tensor = ops.convert_to_tensor(x, name="tensor") |
| self.assertAllEqual(tensor.eval(), [[2, 3], [0, 0], [5, 7]]) |
| |
| def testNegation(self): |
| with self.cached_session(): |
| values = constant_op.constant([2, 3, 5, 7], shape=[2, 2]) |
| indices = constant_op.constant([0, 2]) |
| x = -ops.IndexedSlices(values, indices) |
| self.assertAllEqual(x.values.eval(), [[-2, -3], [-5, -7]]) |
| self.assertAllEqual(x.indices.eval(), [0, 2]) |
| |
| def testScalarMul(self): |
| with self.cached_session(): |
| values = constant_op.constant([2, 3, 5, 7], shape=[2, 2]) |
| indices = constant_op.constant([0, 2]) |
| x = math_ops.scalar_mul(-2, ops.IndexedSlices(values, indices)) |
| self.assertAllEqual(x.values.eval(), [[-4, -6], [-10, -14]]) |
| self.assertAllEqual(x.indices.eval(), [0, 2]) |
| |
| |
| class NodeDefConstructorTest(test_util.TensorFlowTestCase): |
| |
| def testNoArgs(self): |
| nodedef = ops._NodeDef("None", "bar") |
| self.assertProtoEquals("op: 'None' name: 'bar'", nodedef) |
| |
| def testArgs(self): |
| nodedef = ops._NodeDef("foo", "bar", device="/device:baz:*") |
| self.assertProtoEquals("op:'foo' name:'bar' device:'/device:baz:*'", |
| nodedef) |
| nodedef = ops._NodeDef("foo", "bar", device=pydev.DeviceSpec(job="j")) |
| self.assertProtoEquals("op:'foo' name:'bar' device:'/job:j'", nodedef) |
| |
| |
| def _apply_op(g, *args, **kwargs): |
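| """Creates an op in `g`, returning its single output or list of outputs.""" |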
| op = g.create_op(*args, **kwargs) |
| if len(op.outputs) == 1: |
| return op.outputs[0] |
| else: |
| return op.outputs |
| |
| |
| class OperationTest(test_util.TensorFlowTestCase): |
| |
| def testNoInputs(self): |
| op = test_ops.float_output_string_output(name="myop").a.op |
| self.assertEqual(2, len(op.values())) |
| self.assertEqual(0, len(op.inputs)) |
| self.assertEqual("myop", op.name) |
| |
| float_t, label_str_t = op.values() |
| self.assertEqual(dtypes.float32, float_t.dtype) |
| self.assertEqual(op, float_t.op) |
| self.assertEqual(0, float_t._value_index) |
| self.assertEqual(0, len(float_t.consumers())) |
| self.assertEqual("myop", float_t._as_node_def_input()) |
| |
| self.assertEqual(dtypes.string, label_str_t.dtype) |
| self.assertEqual(op, label_str_t.op) |
| self.assertEqual(1, label_str_t._value_index) |
| self.assertEqual(0, len(label_str_t.consumers())) |
| self.assertEqual("myop:1", label_str_t._as_node_def_input()) |
| |
| self.assertProtoEquals("op:'FloatOutputStringOutput' name:'myop'", |
| op.node_def) |
| |
| def testNoOutputs(self): |
| op1 = test_ops.float_output(name="myop1").op |
| float_t, = op1.values() |
| op2 = test_ops.float_input(float_t, name="myop2") |
| self.assertEqual(0, len(op2.values())) |
| self.assertEqual(1, len(op2.inputs)) |
| self.assertIs(float_t, op2.inputs[0]) |
| |
| self.assertEqual(1, len(float_t.consumers())) |
| self.assertEqual(op2, float_t.consumers()[0]) |
| |
| self.assertProtoEquals("op:'FloatOutput' name:'myop1'", op1.node_def) |
| self.assertProtoEquals("op:'FloatInput' name:'myop2' input:'myop1'", |
| op2.node_def) |
| |
| def testInputsAndOutputs(self): |
| op1 = test_ops.float_output(name="myop1").op |
| self.assertEqual(1, len(op1.values())) |
| float1_t, = op1.values() |
| |
| op2 = test_ops.float_output_string_output(name="myop2").a.op |
| self.assertEqual(2, len(op2.values())) |
| float2_t, label2_str_t = op2.values() |
| |
| # Note that we consume label2_str_t twice here. |
| op3 = test_ops.foo2(float1_t, label2_str_t, label2_str_t, name="myop3").d.op |
| self.assertEqual(2, len(op3.values())) |
| |
| self.assertEqual(1, len(float1_t.consumers())) |
| self.assertEqual(op3, float1_t.consumers()[0]) |
| |
| self.assertEqual(0, len(float2_t.consumers())) |
| |
| self.assertEqual(2, len(label2_str_t.consumers())) |
| self.assertEqual(op3, label2_str_t.consumers()[0]) |
| self.assertEqual(op3, label2_str_t.consumers()[1]) |
| |
| self.assertProtoEquals(""" |
| op:'Foo2' name:'myop3' |
| input:'myop1' input:'myop2:1' input:'myop2:1' |
| """, op3.node_def) |
| |
| def testDeviceObject(self): |
| op = ops.Operation(ops._NodeDef("None", "myop"), ops.Graph(), [], []) |
| op._set_device("/job:goo/device:GPU:0") |
| self.assertProtoEquals( |
| "op:'None' name:'myop' device:'/job:goo/device:GPU:0' ", op.node_def) |
| op = ops.Operation(ops._NodeDef("None", "op2"), ops.Graph(), [], []) |
| op._set_device( |
| pydev.DeviceSpec( |
| job="muu", device_type="CPU", device_index=0)) |
| self.assertProtoEquals( |
| "op:'None' name:'op2' device:'/job:muu/device:CPU:0'", op.node_def) |
| |
| def testReferenceInput(self): |
| g = ops.Graph() |
| op1 = ops.Operation( |
| ops._NodeDef("RefOutputFloatOutput", "op1"), g, [], |
| [dtypes.float32_ref, dtypes.float32]) |
| self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", op1.node_def) |
| self.assertEqual([], list(op1.inputs)) |
| ref_t, nonref_t = op1.values() |
| # NOTE(mrry): Must specify input_types to preserve ref-typed input. |
| op2 = ops.Operation( |
| ops._NodeDef("RefInputFloatInput", "op2"), |
| g, [ref_t, nonref_t], [], |
| input_types=[dtypes.float32_ref, dtypes.float32]) |
| self.assertProtoEquals( |
| "op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'", |
| op2.node_def) |
| self.assertEqual([ref_t, nonref_t], list(op2.inputs)) |
| op3 = ops.Operation( |
| ops._NodeDef("TwoFloatInputs", "op3"), g, [ref_t, nonref_t], []) |
| self.assertProtoEquals( |
| "op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'", |
| op3.node_def) |
| |
| def testInvalidNames(self): |
| g = ops.Graph() |
| with self.assertRaises(ValueError): |
| ops.Operation(ops._NodeDef("op", ""), g) |
| with self.assertRaises(ValueError): |
| ops.Operation(ops._NodeDef("op", "_invalid"), g) |
| with self.assertRaises(ValueError): |
| ops.Operation(ops._NodeDef("op", "-invalid"), g) |
| with self.assertRaises(ValueError): |
| ops.Operation(ops._NodeDef("op", "/invalid"), g) |
| with self.assertRaises(ValueError): |
| ops.Operation(ops._NodeDef("op", "invalid:0"), g) |
| |
| def testNoShapeFunction(self): |
| op = test_ops.a() |
| self.assertEqual(tensor_shape.unknown_shape(), op.get_shape()) |
| |
| def testConvertToTensorNestedArray(self): |
| with self.cached_session(): |
| values = [[2], [3], [5], [7]] |
| tensor = ops.convert_to_tensor(values) |
| self.assertAllEqual((4, 1), tensor.get_shape().as_list()) |
| self.assertAllEqual(values, tensor.eval()) |
| |
| def testShapeTuple(self): |
| with self.cached_session(): |
| c = constant_op.constant(1) |
| self.assertEqual(c._shape_tuple(), ()) # pylint: disable=protected-access |
| |
| def testConvertToTensorEager(self): |
| with context.eager_mode(): |
| t = constant_op.constant(1) |
| self.assertTrue(isinstance(t, ops.EagerTensor)) |
| converted = ops.convert_to_tensor(t) |
| self.assertTrue(isinstance(converted, ops.EagerTensor)) |
| converted = ops.convert_to_tensor(1) |
| self.assertTrue(isinstance(converted, ops.EagerTensor)) |
| |
| def testConvertToTensorNestedTuple(self): |
| with self.cached_session(): |
| values = ((2,), (3,), (5,), (7,)) |
| tensor = ops.convert_to_tensor(values) |
| self.assertAllEqual((4, 1), tensor.get_shape().as_list()) |
| self.assertAllEqual(values, ops.convert_to_tensor(values).eval()) |
| |
| def testConvertToTensorNestedTensors(self): |
| with self.cached_session(): |
| values = ((2,), (3,), (5,), (7,)) |
| tensor = ops.convert_to_tensor( |
| [constant_op.constant(row) for row in values]) |
| self.assertAllEqual((4, 1), tensor.get_shape().as_list()) |
| self.assertAllEqual(values, tensor.eval()) |
| tensor = ops.convert_to_tensor( |
| [[constant_op.constant(v) for v in row] for row in values]) |
| self.assertAllEqual((4, 1), tensor.get_shape().as_list()) |
| self.assertAllEqual(values, tensor.eval()) |
| |
| def testConvertToTensorNestedMix(self): |
| with self.cached_session(): |
| values = ([2], (3,), [constant_op.constant(5)], constant_op.constant([7])) |
| tensor = ops.convert_to_tensor(values) |
| self.assertAllEqual((4, 1), tensor.get_shape().as_list()) |
| self.assertAllEqual(((2,), (3,), (5,), (7,)), tensor.eval()) |
| |
| def testConvertToTensorPreferred(self): |
| with self.cached_session(): |
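| # The preferred dtype is used when the values are compatible with it. |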
| values = [2, 3, 5, 7] |
| tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.float32) |
| self.assertEqual(dtypes.float32, tensor.dtype) |
| |
| with self.cached_session(): |
| # An empty list can be converted to a tensor of any requested dtype. |
| values = [] |
| tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64) |
| self.assertEqual(dtypes.int64, tensor.dtype) |
| |
| with self.cached_session(): |
| # The preferred dtype is incompatible with these values, so the |
| # conversion falls back to float32 instead. |
| values = [1.23] |
| tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64) |
| self.assertEqual(dtypes.float32, tensor.dtype) |
| |
| def testConvertToInvalidTensorType(self): |
| with self.assertRaises(TypeError): |
| # Forcing an invalid dtype should fail with a type error. |
| values = [1.23] |
| _ = ops.convert_to_tensor(values, dtype=dtypes.int64) |
| |
| def testNoConvert(self): |
| # Operation cannot be converted to Tensor. |
| op = control_flow_ops.no_op() |
| with self.assertRaisesRegexp(TypeError, |
| r"Can't convert Operation '.*' to Tensor"): |
| ops.convert_to_tensor(op) |
| |
| def testStr(self): |
| node_def = ops._NodeDef("None", "op1") |
| op = ops.Operation(node_def, ops.Graph(), [], [dtypes.float32]) |
| self.assertEqual(str(node_def), str(op)) |
| |
| def testRepr(self): |
| op = ops.Operation( |
| ops._NodeDef("None", "op1"), ops.Graph(), [], [dtypes.float32]) |
| self.assertEqual("<tf.Operation 'op1' type=None>", repr(op)) |
| |
| def testGetAttr(self): |
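| # Default attr values of each supported type should round-trip through |
| # get_attr(). |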
| op = test_ops.default_attrs() |
| self.assertEqual(op.get_attr("string_val"), b"abc") |
| self.assertEqual(op.get_attr("string_list_val"), [b"abc", b""]) |
| self.assertEqual(op.get_attr("int_val"), 123) |
| self.assertEqual(op.get_attr("int_list_val"), [1, 2, 3]) |
| self.assertEqual(op.get_attr("float_val"), 10.0) |
| self.assertEqual(op.get_attr("float_list_val"), [10.0]) |
| self.assertEqual(op.get_attr("bool_val"), True) |
| self.assertEqual(op.get_attr("bool_list_val"), [True, False]) |
| self.assertEqual(op.get_attr("shape_val"), |
| tensor_shape.as_shape([2, 1]).as_proto()) |
| self.assertEqual(op.get_attr("shape_list_val"), |
| [tensor_shape.as_shape([]).as_proto(), |
| tensor_shape.as_shape([1]).as_proto()]) |
| self.assertEqual(op.get_attr("tensor_val"), |
| tensor_util.make_tensor_proto(1, dtypes.int32)) |
| self.assertEqual(op.get_attr("tensor_list_val"), |
| [tensor_util.make_tensor_proto(1, dtypes.int32)]) |
| |
| type_val = op.get_attr("type_val") |
| # First check that type_val is a DType, because the assertEqual below would |
| # pass regardless, since DType overrides __eq__. |
| self.assertIsInstance(type_val, dtypes.DType) |
| self.assertEqual(type_val, dtypes.int32) |
| |
| type_list_val = op.get_attr("type_list_val") |
| self.assertTrue(all(isinstance(x, dtypes.DType) for x in type_list_val)) |
| self.assertEqual(type_list_val, [dtypes.int32, dtypes.float32]) |
| |
| @function.Defun(dtypes.float32, func_name="MyFunc") |
| def func(x): |
| return x |
| |
| op = test_ops.func_attr(func) |
| self.assertEqual(op.get_attr("f"), |
| attr_value_pb2.NameAttrList(name="MyFunc")) |
| |
| # Try fetching missing attr |
| with self.assertRaisesRegexp( |
| ValueError, "Operation 'FuncAttr' has no attr named 'FakeAttr'."): |
| op.get_attr("FakeAttr") |
| |
| # TODO(b/65162920): remove this test when users who are directly mutating the |
| # node_def have been updated to proper usage. |
| def testSetAttr(self): |
| op = test_ops.int_attr().op |
| op._set_attr("foo", attr_value_pb2.AttrValue(i=2)) |
| # TODO(skyewm): add node_def check |
| self.assertEqual(op.get_attr("foo"), 2) |
| |
| # TODO(nolivia): test all error cases |
| def testAddControlInput(self): |
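| # Control inputs that are already present are deduplicated rather than |
| # added twice. |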
| with ops.Graph().as_default(): |
| x = constant_op.constant(1).op |
| y = constant_op.constant(2).op |
| z = constant_op.constant(3).op |
| z._add_control_input(x) # pylint: disable=protected-access |
| self.assertEqual(z.control_inputs, [x]) |
| z._add_control_input(x) # pylint: disable=protected-access |
| self.assertEqual(z.control_inputs, [x]) |
| z._add_control_inputs([x, y, y]) # pylint: disable=protected-access |
| self.assertEqual(z.control_inputs, [x, y]) |
| self.assertEqual(x._control_outputs, [z]) |
| |
| def testRemoveAllControlInputs(self): |
| a = constant_op.constant(1) |
| with ops.control_dependencies([a]): |
| b = constant_op.constant(2) |
| c = constant_op.constant(3) |
| d = constant_op.constant(4) |
| e = constant_op.constant(5) |
| with ops.control_dependencies([a, c]): |
| f = d + e |
| |
| self.assertEqual(a.op.control_inputs, []) |
| self.assertEqual(b.op.control_inputs, [a.op]) |
| self.assertEqual(f.op.control_inputs, [a.op, c.op]) |
| |
| a.op._remove_all_control_inputs() # pylint: disable=protected-access |
| self.assertEqual(a.op.control_inputs, []) |
| |
| b.op._remove_all_control_inputs() # pylint: disable=protected-access |
| self.assertEqual(b.op.control_inputs, []) |
| |
| f.op._remove_all_control_inputs() # pylint: disable=protected-access |
| self.assertEqual(f.op.control_inputs, []) |
| self.assertEqual(list(f.op.inputs), [d, e]) |
| |
| def testControlInputCycle(self): |
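| # A cycle introduced via control edges is only detected when the graph is |
| # run. |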
| graph = ops.Graph() |
| with graph.as_default(): |
| z = constant_op.constant(0) |
| x = constant_op.constant(1) |
| y = constant_op.constant(2) |
| y.op._add_control_input(z.op) # pylint: disable=protected-access |
| y.op._add_control_input(x.op) # pylint: disable=protected-access |
| x.op._add_control_input(y.op) # pylint: disable=protected-access |
| with self.session(graph=graph) as sess: |
| with self.assertRaisesRegexp( |
| errors.InvalidArgumentError, |
| "Graph is invalid, contains a cycle with 2 nodes"): |
| sess.run(x) |
| |
| def testUpdateInput(self): |
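| # _update_input() rewires an op's input and keeps the consumer lists of |
| # the affected tensors consistent. |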
| g = ops.Graph() |
| with g.as_default(): |
| x = constant_op.constant(1) |
| y = constant_op.constant(2) |
| z = x + y |
| |
| z.op._update_input(0, y) # pylint: disable=protected-access |
| self.assertEqual(list(z.op.inputs), [y, y]) |
| self.assertEqual(x.consumers(), []) |
| self.assertEqual(y.consumers(), [z.op, z.op]) |
| with session.Session(graph=g) as sess: |
| self.assertEqual(sess.run(z), 4) |
| |
| z.op._update_input(0, x) # pylint: disable=protected-access |
| self.assertEqual(list(z.op.inputs), [x, y]) |
| self.assertEqual(x.consumers(), [z.op]) |
| self.assertEqual(y.consumers(), [z.op]) |
| with session.Session(graph=g) as sess: |
| self.assertEqual(sess.run(z), 3) |
| |
| z.op._update_input(1, y) # pylint: disable=protected-access |
| self.assertEqual(list(z.op.inputs), [x, y]) |
| self.assertEqual(x.consumers(), [z.op]) |
| self.assertEqual(y.consumers(), [z.op]) |
| with session.Session(graph=g) as sess: |
| self.assertEqual(sess.run(z), 3) |
| |
| def testUpdateInputGraphError(self): |
| g_0 = ops.Graph() |
| g_1 = ops.Graph() |
| with g_0.as_default(): |
| x = constant_op.constant(1) |
| with g_1.as_default(): |
| y = constant_op.constant(2) |
| z = y * 2 |
| with self.assertRaisesRegexp(ValueError, "must be from the same graph"): |
| z.op._update_input(0, x) # pylint: disable=protected-access |
| |
| def testUpdateInputTypeError(self): |
| g = ops.Graph() |
| with g.as_default(): |
| w = constant_op.constant(0) |
| x = constant_op.constant("") |
| y = constant_op.constant(1) |
| z = y + w |
| z.op._update_input(0, x) # pylint: disable=protected-access |
| with session.Session(graph=g) as sess: |
| with self.assertRaisesRegexp( |
| errors.InvalidArgumentError, |
| "Input 0 of node add was passed string from Const_1:0 incompatible " |
| "with expected int32"): |
| sess.run(z) |
| |
| def testUpdateInputShapeError(self): |
| g = ops.Graph() |
| with g.as_default(): |
| w = constant_op.constant(2, shape=[3, 1]) |
| x = constant_op.constant(0, shape=[3, 1]) |
| y = constant_op.constant(1, shape=[2, 2]) |
| z = w + x |
| with self.assertRaisesRegexp( |
| errors.InvalidArgumentError, |
| r"Cannot update edge, incompatible shapes: \[2,2\] and \[3,1\]"): |
| z.op._update_input(0, y) # pylint: disable=protected-access |
| |
| def testUpdateInputOutOfRange(self): |
| g = ops.Graph() |
| with g.as_default(): |
| x = constant_op.constant(1) |
| with self.assertRaisesRegexp( |
| errors.OutOfRangeError, |
| r"Cannot update edge. Input index \[1\] is greater than the number of " |
| r"total inputs \[0\]." |
| ): |
| x.op._update_input(1, x) # pylint: disable=protected-access |
| |
| def testOpDef(self): |
| x = constant_op.constant(0) |
| y = constant_op.constant(1) |
| z = x + y |
| |
| self.assertEqual(x.op.op_def.name, "Const") |
| self.assertEqual(len(x.op.op_def.input_arg), 0) |
| self.assertEqual(len(x.op.op_def.output_arg), 1) |
| |
| self.assertEqual(z.op.op_def.name, "Add") |
| self.assertEqual(len(z.op.op_def.input_arg), 2) |
| self.assertEqual(len(z.op.op_def.output_arg), 1) |
| |
| def testInputFromDifferentGraphError(self): |
| g_0 = ops.Graph() |
| g_1 = ops.Graph() |
| with g_0.as_default(): |
| x = constant_op.constant(1) |
| with g_1.as_default(): |
| y = constant_op.constant(2) |
| with self.assertRaisesRegexp(ValueError, "must be from the same graph"): |
| y * x # pylint: disable=pointless-statement |
| |
| def testInputsAreImmutable(self): |
| g = ops.Graph() |
| with g.as_default(): |
| x = test_ops.int_output() |
| op = test_ops.int_input_int_output(x, name="myop").op |
| with self.assertRaisesRegexp( |
| AttributeError, "'_InputList' object has no attribute 'append'"): |
| op.inputs.append(None) |
| |
| |
| class CreateOpTest(test_util.TensorFlowTestCase): |
| |
| def testNodeDefArgs(self): |
| g = ops.Graph() |
| op1 = g.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1") |
| with g.device("/device:GPU:0"): |
| op2 = g.create_op( |
| "FloatOutputStringOutput", [], [dtypes.float32, dtypes.string], None, |
| name="myop2") |
| op3 = g.create_op( |
| "Foo3", |
| [list(op1.values())[0], list(op2.values())[1], list(op2.values())[0]], |
| [dtypes.float32, dtypes.int32], |
| None, |
| name="myop3") |
| self.assertDeviceEqual(None, op1.device) |
| self.assertDeviceEqual("/device:GPU:0", op2.device) |
| self.assertDeviceEqual(None, op3.device) |
| self.assertProtoEquals("name:'myop1' op:'FloatOutput'", op1.node_def) |
| self.assertProtoEquals( |
| "name:'myop2' op:'FloatOutputStringOutput' device:'/device:GPU:0'", |
| op2.node_def) |
| self.assertProtoEquals( |
| "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'Foo3'", |
| op3.node_def) |
| |
| def testReferenceInput(self): |
| g = ops.Graph() |
| op1 = g.create_op( |
| "RefOutputFloatOutput", [], [dtypes.float32_ref, dtypes.float32], |
| name="op1") |
| self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", op1.node_def) |
| ref_t, nonref_t = op1.values() |
| # NOTE(mrry): Must specify input_types to preserve ref-typed input. |
| op2 = g.create_op( |
| "RefInputFloatInput", [ref_t, nonref_t], [], |
| input_types=[dtypes.float32_ref, dtypes.float32], |
| name="op2") |
| self.assertProtoEquals( |
| "op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'", |
| op2.node_def) |
| op3 = g.create_op("TwoFloatInputs", [ref_t, nonref_t], [], name="op3") |
| self.assertProtoEquals( |
| "op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'", |
| op3.node_def) |
| |
| def testFinalized(self): |
| g = ops.Graph() |
| g.finalize() |
| with self.assertRaises(RuntimeError): |
| g.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1") |
| |
| # Test unfinalize. |
| g._unsafe_unfinalize() |
| g.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1") |
| |
| |
| # NOTE(skyewm): these cases test the private Graph._create_op_from_tf_operation |
| # method. Arguably we should only test the public APIs that depend on this |
| # method. However, this logic is complex and tricky, and it can be difficult to |
| # ascertain if we have adequate coverage (e.g. a graph may run successfully if |
| # the control flow context isn't set properly, but a more complicated use case |
| # that might not be obvious to test will fail). Thus we instead explicitly test |
| # the low-level behavior. |
| class CreateOpFromTFOperationTest(test_util.TensorFlowTestCase): |
| |
| def testBasic(self): |
| g = ops.Graph() |
| with g.as_default(): |
| x = test_ops.int_output() |
| c_op = ops._create_c_op( |
| g, ops._NodeDef("IntInputIntOutput", "myop"), [x], []) |
| op = g._create_op_from_tf_operation(c_op) |
| |
| self.assertEqual(op.name, "myop") |
| self.assertEqual(op.type, "IntInputIntOutput") |
| self.assertEqual(len(op.outputs), 1) |
| self.assertEqual(op.outputs[0].shape, tensor_shape.unknown_shape()) |
| self.assertEqual(list(op.inputs), [x]) |
| self.assertEqual(op.control_inputs, []) |
| self.assertEqual(op.graph, g) |
| self.assertEqual(x.consumers(), [op]) |
| self.assertIsNotNone(op.traceback) |
| self.assertEqual(g.get_operation_by_name("myop"), op) |
| self.assertEqual(g.get_tensor_by_name("myop:0"), op.outputs[0]) |
| |
| @test_util.enable_c_shapes |
| def testShape(self): |
| g = ops.Graph() |
| with g.as_default(): |
| x = constant_op.constant([[1, 2, 3], [4, 5, 6]]) |
| c_op = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), [x], []) |
| op = g._create_op_from_tf_operation(c_op) |
| |
| self.assertEqual(op.name, "myop") |
| self.assertEqual(op.type, "Identity") |
| self.assertEqual(len(op.outputs), 1) |
| self.assertEqual(op.outputs[0].shape, tensor_shape.matrix(2, 3)) |
| |
| def testUniqueName(self): |
| g = ops.Graph() |
| with g.as_default(): |
| c_op = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop"), [], []) |
| c_op2 = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop_1"), [], []) |
| op = g._create_op_from_tf_operation(c_op) |
| op2 = g._create_op_from_tf_operation(c_op2) |
| |
| # Create ops with the same names as op and op2. We expect the new names to |
| # be uniquified. |
| op3 = test_ops.int_output(name="myop").op |
| op4 = test_ops.int_output(name="myop_1").op |
| |
| self.assertEqual(op.name, "myop") |
| self.assertEqual(op2.name, "myop_1") |
| self.assertEqual(op3.name, "myop_2") |
| self.assertEqual(op4.name, "myop_1_1") |
| |
| def testCond(self): |
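| # An op added inside a cond branch via the C API should be wired into the |
| # branch's control flow context. |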
| g = ops.Graph() |
| with g.as_default(): |
| x = test_ops.int_output() |
| |
| def true_fn(): |
| ops._create_c_op(ops.get_default_graph(), |
| ops._NodeDef("IntInput", "cond/myop"), [x], []) |
| new_ops = g._add_new_tf_operations() |
| self.assertEqual(len(new_ops), 1) |
| return x |
| |
| control_flow_ops.cond(x < 10, true_fn, lambda: x) |
| |
| op = g.get_operation_by_name("cond/myop") |
| self.assertIsNotNone(op) |
| self.assertEqual(op.name, "cond/myop") |
| self.assertEqual(op.type, "IntInput") |
| self.assertEqual(op.outputs, []) |
| op_input = op.inputs[0].op |
| self.assertEqual(op_input.type, "Switch") |
| self.assertEqual(op_input.inputs[0], x) |
| self.assertEqual(op.graph, g) |
| # pylint: disable=protected-access |
| self.assertIsNotNone(op._get_control_flow_context()) |
| self.assertEqual(op._get_control_flow_context().name, |
| "cond/cond_text") |
| # pylint: enable=protected-access |
| |
| def testWhileLoop(self): |
| g = ops.Graph() |
| with g.as_default(): |
| x = test_ops.int_output() |
| |
| def body(i): |
| ops._create_c_op(ops.get_default_graph(), |
| ops._NodeDef("IntInput", "myloop/myop"), [x], []) |
| new_ops = g._add_new_tf_operations() |
| self.assertEqual(len(new_ops), 1) |
| return i |
| |
| control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") |
| |
| op = g.get_operation_by_name("myloop/myop") |
| self.assertIsNotNone(op) |
| self.assertEqual(op.name, "myloop/myop") |
| self.assertEqual(op.type, "IntInput") |
| self.assertEqual(op.outputs, []) |
| op_input = op.inputs[0].op |
| self.assertEqual(op_input.type, "Enter") |
| self.assertEqual(list(op_input.inputs), [x]) |
| self.assertEqual(op.graph, g) |
| # pylint: disable=protected-access |
| self.assertIsNotNone(op._get_control_flow_context()) |
| self.assertEqual(op._get_control_flow_context().name, |
| "myloop/while_context") |
| # pylint: enable=protected-access |
| |
| def testWhileLoopWithInternalControlDep(self): |
| g = ops.Graph() |
| with g.as_default(): |
| x = test_ops.int_output() |
| |
| def body(i): |
| c = constant_op.constant(1.0, name="c") |
| ops._create_c_op(ops.get_default_graph(), |
| ops._NodeDef("IntInput", "myloop/myop"), [x], []) |
| with ops.control_dependencies([c]): |
| new_ops = g._add_new_tf_operations() |
| self.assertEqual(len(new_ops), 1) |
| return i |
| |
| control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") |
| |
| op = g.get_operation_by_name("myloop/myop") |
| self.assertIsNotNone(op) |
| c = g.get_operation_by_name("myloop/c") |
| self.assertIsNotNone(c) |
| # The internal control dependency is preserved. |
| self.assertEqual(op.control_inputs, [c]) |
| |
| def testWhileLoopWithExternalControlDep(self): |
| g = ops.Graph() |
| with g.as_default(): |
| x = test_ops.int_output() |
| c = constant_op.constant(1.0) |
| |
| def body(i): |
| ops._create_c_op(ops.get_default_graph(), |
| ops._NodeDef("IntInput", "myloop/myop"), [x], []) |
| with ops.control_dependencies([c]): |
| new_ops = g._add_new_tf_operations() |
| self.assertEqual(len(new_ops), 1) |
| return i |
| |
| control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") |
| |
| op = g.get_operation_by_name("myloop/myop") |
| self.assertIsNotNone(op) |
| # The external control dependency is removed and replaced with an internal |
| # one. |
| self.assertNotEqual(op.control_inputs[0], c.op) |
| self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context()) |
| |
| |
| class ApplyOpTest(test_util.TensorFlowTestCase): |
| |
| def testNodeDefArgs(self): |
| g = ops.Graph() |
| t1 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop1") |
| with g.device("/device:GPU:0"): |
| t2 = _apply_op( |
| g, "TwoIntOutputs", [], [dtypes.int32, dtypes.int32], name="myop2") |
| t3 = _apply_op( |
| g, |
| "Foo1", [t1, t2[1], t2[0]], [dtypes.float32, dtypes.int32], |
| name="myop3") |
| self.assertTrue(isinstance(t1, ops.Tensor)) |
| self.assertTrue(isinstance(t2, list)) |
| self.assertTrue(isinstance(t3, list)) |
| self.assertTrue(isinstance(t3[0], ops.Tensor)) |
| self.assertEqual("myop1", t1._as_node_def_input()) |
| self.assertEqual("myop2", t2[0]._as_node_def_input()) |
| self.assertEqual("myop2:1", t2[1]._as_node_def_input()) |
| self.assertEqual("myop3", t3[0]._as_node_def_input()) |
| # Validate that we got the right ops as well |
| self.assertProtoEquals("name:'myop1' op:'FloatOutput'", t1.op.node_def) |
| self.assertProtoEquals( |
| "name:'myop2' op:'TwoIntOutputs' device:'/device:GPU:0'", |
| t2[0].op.node_def) |
| self.assertProtoEquals( |
| "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'Foo1'", |
| t3[0].op.node_def) |
| |
| def testReferenceInput(self): |
| g = ops.Graph() |
| ref_t, nonref_t = _apply_op( |
| g, "RefOutputFloatOutput", [], [dtypes.float32_ref, dtypes.float32], |
| name="op1") |
| self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", |
| ref_t.op.node_def) |
| # NOTE(mrry): Must specify input_types to preserve ref-typed input. |
| out_2 = _apply_op( |
| g, |
| "RefInputFloatInputIntOutput", [ref_t, nonref_t], [dtypes.int32], |
| input_types=[dtypes.float32_ref, dtypes.float32], |
| name="op2") |
| self.assertProtoEquals( |
| "op:'RefInputFloatInputIntOutput' name:'op2' input:'op1' input:'op1:1'", |
| out_2.op.node_def) |
| out_3 = _apply_op( |
| g, "TwoFloatInputsIntOutput", [ref_t, nonref_t], [dtypes.int32], |
| name="op3") |
| self.assertProtoEquals( |
| "op:'TwoFloatInputsIntOutput' name:'op3' input:'op1' input:'op1:1'", |
| out_3.op.node_def) |
| |
| |
| class NameStackTest(test_util.TensorFlowTestCase): |
| |
| def testBasics(self): |
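| # unique_name() with mark_as_used=False previews the next unique name |
| # without reserving it, so the same name is returned until actually used. |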
| g = ops.Graph() |
| self.assertEqual("foo", g.unique_name("foo", mark_as_used=False)) |
| self.assertEqual("foo", g.unique_name("foo", mark_as_used=False)) |
| self.assertEqual("foo", g.unique_name("foo")) |
| self.assertEqual("foo_1", g.unique_name("foo", mark_as_used=False)) |
| self.assertEqual("foo_1", g.unique_name("foo")) |
| self.assertEqual("foo_2", g.unique_name("foo", mark_as_used=False)) |
| self.assertEqual("foo_2", g.unique_name("foo")) |
| self.assertEqual("foo_1_1", g.unique_name("foo_1", mark_as_used=False)) |
| self.assertEqual("foo_1_1", g.unique_name("foo_1")) |
| self.assertEqual("foo_1_2", g.unique_name("foo_1", mark_as_used=False)) |
| self.assertEqual("foo_1_2", g.unique_name("foo_1")) |
| self.assertEqual("foo_1_2_1", g.unique_name("foo_1_2", mark_as_used=False)) |
| self.assertEqual("foo_1_2_1", g.unique_name("foo_1_2")) |
| with g.name_scope("bar"): |
| self.assertEqual("bar/foo", g.unique_name("foo", mark_as_used=False)) |
| self.assertEqual("bar/foo", g.unique_name("foo")) |
| self.assertEqual("bar/foo_1", g.unique_name("foo", mark_as_used=False)) |
| self.assertEqual("bar/foo_1", g.unique_name("foo")) |
| with g.name_scope(None): |
| self.assertEqual("foo_3", g.unique_name("foo", mark_as_used=False)) |
| self.assertEqual("foo_3", g.unique_name("foo")) |
| with g.name_scope("baz"): |
| self.assertEqual( |
| "bar/baz/foo", g.unique_name( |
| "foo", mark_as_used=False)) |
| self.assertEqual("bar/baz/foo", g.unique_name("foo")) |
| self.assertEqual( |
| "bar/baz/foo_1", g.unique_name( |
| "foo", mark_as_used=False)) |
| self.assertEqual("bar/baz/foo_1", g.unique_name("foo")) |
| with g.name_scope("baz"): |
| self.assertEqual( |
| "bar/baz_1/foo", g.unique_name( |
| "foo", mark_as_used=False)) |
| self.assertEqual("bar/baz_1/foo", g.unique_name("foo")) |
| self.assertEqual( |
| "bar/baz_1/foo_1", g.unique_name( |
| "foo", mark_as_used=False)) |
| self.assertEqual("bar/baz_1/foo_1", g.unique_name("foo")) |
| with g.name_scope("quux"): |
| self.assertEqual("quux/foo", g.unique_name("foo", mark_as_used=False)) |
| self.assertEqual("quux/foo", g.unique_name("foo")) |
| with g.name_scope("bar"): |
| with g.name_scope("baz"): |
| self.assertEqual( |
| "bar_1/baz/foo", g.unique_name( |
| "foo", mark_as_used=False)) |
| self.assertEqual("bar_1/baz/foo", g.unique_name("foo")) |
| self.assertEqual("foo_4", g.unique_name("foo", mark_as_used=False)) |
| self.assertEqual("foo_4", g.unique_name("foo")) |
| self.assertEqual("bar_2", g.unique_name("bar", mark_as_used=False)) |
| self.assertEqual("bar_2", g.unique_name("bar")) |
| |
| def testNameAndVariableScope(self): |
| with self.cached_session() as sess: |
| with sess.graph.name_scope("l0"): |
| with variable_scope.variable_scope("l1"): |
| with sess.graph.name_scope("l1") as scope: |
| self.assertEqual("l0/l1/l1/", scope) |
| self.assertEqual( |
| "l0/l1/l1/foo", |
| sess.graph.unique_name( |
| "foo", mark_as_used=False)) |
| self.assertEqual("l0/l1/l1/foo", sess.graph.unique_name("foo")) |
| with sess.graph.name_scope("l2") as scope: |
| self.assertEqual("l0/l1/l2/", scope) |
| self.assertEqual( |
| "l0/l1/l2/foo", |
| sess.graph.unique_name( |
| "foo", mark_as_used=False)) |
| self.assertEqual("l0/l1/l2/foo", sess.graph.unique_name("foo")) |
| |
| def testOutOfOrderUniqueName(self): |
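| # Explicitly taking "foo_2" first should not stop later "foo" requests |
| # from being uniquified around it. |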
| g = ops.Graph() |
| self.assertEqual("foo_2", g.unique_name("foo_2")) |
| self.assertEqual("foo", g.unique_name("foo")) |
| self.assertEqual("foo_1", g.unique_name("foo")) |
| self.assertEqual("foo_3", g.unique_name("foo")) |
| |
| def testUniqueNameCaseInsensitivity(self): |
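| # Uniquification compares names case-insensitively but preserves the |
| # requested casing. |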
| g = ops.Graph() |
| self.assertEqual("foo", g.unique_name("foo")) |
| self.assertEqual("Foo_1", g.unique_name("Foo")) |
| with g.name_scope("bar"): |
| self.assertEqual("bar/foo", g.unique_name("foo")) |
| with g.name_scope("Bar"): |
| self.assertEqual("Bar_1/foo", g.unique_name("foo")) |
| |
| def testInvalidNameRaisesError(self): |
| g = ops.Graph() |
| with g.name_scope(""): # Should not raise |
| pass |
| with g.name_scope("foo/"): # Should not raise |
| with g.name_scope("_bar"): # Should not raise |
| pass |
| with self.assertRaises(ValueError): |
| with g.name_scope("foo:0"): |
| pass |
| with self.assertRaises(ValueError): |
| with g.name_scope("_bar"): |
| pass |
| |
| |
| class NameTest(test_util.TensorFlowTestCase): |
| |
| def testGenerateName(self): |
| g = ops.Graph() |
| op0 = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32]) |
| self.assertEqual("TwoFloatOutputs", op0.name) |
| self.assertEqual("TwoFloatOutputs:0", op0.outputs[0].name) |
| self.assertEqual("TwoFloatOutputs:1", op0.outputs[1].name) |
| |
| op1 = g.create_op("FloatOutput", [], [dtypes.float32]) |
| self.assertEqual("FloatOutput", op1.name) |
| self.assertEqual("FloatOutput:0", op1.outputs[0].name) |
| |
| op2 = g.create_op("FloatOutput", [], [dtypes.float32]) |
| self.assertEqual("FloatOutput_1", op2.name) |
| self.assertEqual("FloatOutput_1:0", op2.outputs[0].name) |
| |
| op3 = g.create_op("FloatOutput", [], [dtypes.float32], name="my_op") |
| self.assertEqual("my_op", op3.name) |
| self.assertEqual("my_op:0", op3.outputs[0].name) |
| |
| def testNameScope(self): |
| g = ops.Graph() |
| |
| with g.name_scope("foo") as foo: |
| self.assertEqual("foo/", foo) |
| with g.name_scope("foo2") as foo2: |
| self.assertEqual("foo/foo2/", foo2) |
| with g.name_scope(None) as empty1: |
| self.assertEqual("", empty1) |
| with g.name_scope("foo3") as foo3: |
| self.assertEqual("foo3/", foo3) |
| with g.name_scope("") as empty2: |
| self.assertEqual("", empty2) |
| |
| self.assertEqual("FloatOutput", |
| g.create_op("FloatOutput", [], [dtypes.float32]).name) |
| with g.name_scope("bar") as scope: |
| self.assertEqual("bar/FloatOutput", |
| g.create_op("FloatOutput", [], [dtypes.float32]).name) |
| self.assertEqual("bar/FloatOutput_1", |
| g.create_op("FloatOutput", [], [dtypes.float32]).name) |
| # If you use the value from "with .. as", that value is used as-is. |
| self.assertEqual( |
| "bar", g.create_op( |
| "FloatOutput", [], [dtypes.float32], name=scope).name) |
| with g.name_scope("baz") as scope: |
| with g.name_scope("quux"): |
| self.assertEqual("baz/quux/FloatOutput", |
| g.create_op("FloatOutput", [], [dtypes.float32]).name) |
| # If you use the value from the enclosing "with .. as", nothing is pushed. |
| with g.name_scope(scope): |
| self.assertEqual("baz/FloatOutput", |
| g.create_op("FloatOutput", [], [dtypes.float32]).name) |
| self.assertEqual( |
| "baz", g.create_op( |
| "FloatOutput", [], [dtypes.float32], name=scope).name) |
| self.assertEqual( |
| "trailing", |
| g.create_op( |
| "FloatOutput", [], [dtypes.float32], name="trailing/").name) |
| with g.name_scope("bar"): |
| self.assertEqual("bar_1/FloatOutput", |
| g.create_op("FloatOutput", [], [dtypes.float32]).name) |
| with g.name_scope("bar/"): |
| self.assertEqual("bar/FloatOutput_2", |
| g.create_op("FloatOutput", [], [dtypes.float32]).name) |
| |
| |
| class DeviceTest(test_util.TensorFlowTestCase): |
| |
| def testNoDevice(self): |
| g = ops.Graph() |
| op = g.create_op("FloatOutput", [], [dtypes.float32]) |
| self.assertDeviceEqual(None, op.device) |
| gd = g.as_graph_def() |
| self.assertProtoEqualsVersion(""" |
| node { name: "FloatOutput" op: "FloatOutput" } |
| """, gd) |
| |
| def testDevicePartialString(self): |
| g = ops.Graph() |
| with g.device("/job:worker/replica:2"): |
| g.create_op("FloatOutput", [], [dtypes.float32]) |
| gd = g.as_graph_def() |
| self.assertProtoEqualsVersion(""" |
| node { name: "FloatOutput" op: "FloatOutput" |
| device: "/job:worker/replica:2" } |
| """, gd) |
| |
| def testDeviceFull(self): |
| g = ops.Graph() |
| with g.device( |
| pydev.DeviceSpec( |
| job="worker", replica=2, task=0, device_type="CPU", |
| device_index=3)): |
| g.create_op("FloatOutput", [], [dtypes.float32]) |
| gd = g.as_graph_def() |
| self.assertProtoEqualsVersion(""" |
| node { name: "FloatOutput" op: "FloatOutput" |
| device: "/job:worker/replica:2/task:0/device:CPU:3" } |
| """, gd) |
| |
| def testNesting(self): |
| g = ops.Graph() |
| with g.device("/job:worker/replica:2"): |
| g.create_op("FloatOutput", [], [dtypes.float32]) |
| with g.device("/job:worker/replica:3/task:0"): |
| g.create_op("FloatOutput", [], [dtypes.float32]) |
| g.create_op("FloatOutput", [], [dtypes.float32]) |
| gd = g.as_graph_def() |
| self.assertProtoEqualsVersion(""" |
| node { name: "FloatOutput" op: "FloatOutput" |
| device: "/job:worker/replica:2" } |
| node { name: "FloatOutput_1" op: "FloatOutput" |
| device: "/job:worker/replica:3/task:0" } |
| node { name: "FloatOutput_2" op: "FloatOutput" |
| device: "/job:worker/replica:2" } |
| """, gd) |
| |
| def testNestingString(self): |
| g = ops.Graph() |
| with g.device("/job:worker/replica:2"): |
| g.create_op("FloatOutput", [], [dtypes.float32]) |
| with g.device("/job:worker/replica:3/task:0"): |
| g.create_op("FloatOutput", [], [dtypes.float32]) |
| g.create_op("FloatOutput", [], [dtypes.float32]) |
| gd = g.as_graph_def() |
| self.assertProtoEqualsVersion(""" |
| node { name: "FloatOutput" op: "FloatOutput" |
| device: "/job:worker/replica:2" } |
| node { name: "FloatOutput_1" op: "FloatOutput" |
| device: "/job:worker/replica:3/task:0" } |
| node { name: "FloatOutput_2" op: "FloatOutput" |
| device: "/job:worker/replica:2" } |
| """, gd) |
| |
| def testNestingOverrideGpuCpu(self): |
| g = ops.Graph() |
| with g.device("/job:worker/replica:2/device:CPU:1"): |
| g.create_op("FloatOutput", [], [dtypes.float32]) |
| with g.device("/job:worker/replica:2/device:GPU:2"): |
| g.create_op("FloatOutput", [], [dtypes.float32]) |
| g.create_op("FloatOutput", [], [dtypes.float32]) |
| gd = g.as_graph_def() |
| self.assertProtoEqualsVersion(""" |
| node { name: "FloatOutput" op: "FloatOutput" |
| device: "/job:worker/replica:2/device:CPU:1" } |
| node { name: "FloatOutput_1" op: "FloatOutput" |
| device: "/job:worker/replica:2/device:GPU:2" } |
| node { name: "FloatOutput_2" op: "FloatOutput" |
| device: "/job:worker/replica:2/device:CPU:1" } |
| """, gd) |
| |
| def testNestingWithMergeDeviceFunction(self): |
| g = ops.Graph() |
| |
| with g.device(pydev.merge_device("/device:GPU:0")): |
| g.create_op("FloatOutput", [], [dtypes.float32]) |
| with g.device(pydev.merge_device("/job:worker")): |
| g.create_op("FloatOutput", [], [dtypes.float32]) |
| with g.device(pydev.merge_device("/device:CPU:0")): |
| g.create_op("FloatOutput", [], [dtypes.float32]) |
| with g.device(pydev.merge_device("/job:ps")): |
| g.create_op("FloatOutput", [], [dtypes.float32]) |
| with g.device(pydev.merge_device(None)): |
| g.create_op("FloatOutput", [], [dtypes.float32]) |
| |
| gd = g.as_graph_def() |
| self.assertProtoEqualsVersion(""" |
| node { name: "FloatOutput" op: "FloatOutput" |
| device: "/device:GPU:0" } |
| node { name: "FloatOutput_1" op: "FloatOutput" |
| device: "/job:worker/device:GPU:0" } |
| node { name: "FloatOutput_2" op: "FloatOutput" |
| device: "/job:worker/device:CPU:0" } |
| node { name: "FloatOutput_3" op: "FloatOutput" |
| device: "/job:ps/device:CPU:0" } |
| node { name: "FloatOutput_4" op: "FloatOutput" |
| device: "/job:ps/device:CPU:0" } |
| """, gd) |
| |
| def testNestingWithDeviceStrings(self): |
| g = ops.Graph() |
| |
| with g.device("/device:GPU:0"): |
| g.create_op("FloatOutput", [], [dtypes.float32]) |
| with g.device("/job:worker"): |
| g.create_op("FloatOutput", [], [dtypes.float32]) |
| with g.device("/device:CPU:0"): |
| g.create_op("FloatOutput", [], [dtypes.float32]) |
| with g.device("/job:ps"): |
| g.create_op("FloatOutput", [], [dtypes.float32]) |
| with g.device(""): |
| g.create_op("FloatOutput", [], [dtypes.float32]) |
| |
| gd = g.as_graph_def() |
| self.assertProtoEqualsVersion(""" |
| node { name: "FloatOutput" op: "FloatOutput" |
| device: "/device:GPU:0" } |
| node { name: "FloatOutput_1" op: "FloatOutput" |
| device: "/job:worker/device:GPU:0" } |
| node { name: "FloatOutput_2" op: "FloatOutput" |
| device: "/job:worker/device:CPU:0" } |
| node { name: "FloatOutput_3" op: "FloatOutput" |
| device: "/job:ps/device:CPU:0" } |
| node { name: "FloatOutput_4" op: "FloatOutput" |
| device: "/job:ps/device:CPU:0" } |
| """, gd) |
| |
| def testNestingWithDeviceStringWildcard(self): |
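| # When a wildcard device spec is combined with a concrete device index, |
| # the concrete index wins regardless of nesting order. |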
| g = ops.Graph() |
| |
| with g.device("/device:GPU:7"): |
| g.create_op("FloatOutput", [], [dtypes.float32]) |
| with g.device("/device:GPU:*"): |
| g.create_op("FloatOutput", [], [dtypes.float32]) |
| |
| with g.device("/device:CPU:*"): |
| g.create_op("FloatOutput", [], [dtypes.float32]) |
| with g.device("/device:CPU:5"): |
| g.create_op("FloatOutput", [], [dtypes.float32]) |
| |
| gd = g.as_graph_def() |
| self.assertProtoEqualsVersion(""" |
| node { name: "FloatOutput" op: "FloatOutput" |
| device: "/device:GPU:7" } |
| node { name: "FloatOutput_1" op: "FloatOutput" |
| device: "/device:GPU:7" } |
| node { name: "FloatOutput_2" op: "FloatOutput" |
| device: "/device:CPU:*" } |
| node { name: "FloatOutput_3" op: "FloatOutput" |
| device: "/device:CPU:5" } |
| """, gd) |
| |
| def testNoneClearsDefault(self): |
| g = ops.Graph() |
| with g.device("/job:worker/replica:2/device:CPU:1"): |
| g.create_op("FloatOutput", [], [dtypes.float32]) |
| with g.device(None): |
| g.create_op("FloatOutput", [], [dtypes.float32]) |
| g.create_op("FloatOutput", [], [dtypes.float32]) |
| gd = g.as_graph_def() |
| self.assertProtoEqualsVersion(""" |
| node { name: "FloatOutput" op: "FloatOutput" |
| device: "/job:worker/replica:2/device:CPU:1" } |
| node { name: "FloatOutput_1" op: "FloatOutput" } |
| node { name: "FloatOutput_2" op: "FloatOutput" |
| device: "/job:worker/replica:2/device:CPU:1" } |
| """, gd) |
| |
| def testNoneIgnoresOuterDeviceFunction(self): |
| g = ops.Graph() |
| with g.device(lambda op: "/job:worker/replica:2/device:CPU:1"): |
| g.create_op("FloatOutput", [], [dtypes.float32]) |
| with g.device(None): |
| g.create_op("FloatOutput", [], [dtypes.float32]) |
| g.create_op("FloatOutput", [], [dtypes.float32]) |
| gd = g.as_graph_def() |
| self.assertProtoEqualsVersion(""" |
| node { name: "FloatOutput" op: "FloatOutput" |
| device: "/job:worker/replica:2/device:CPU:1" } |
| node { name: "FloatOutput_1" op: "FloatOutput" } |
| node { name: "FloatOutput_2" op: "FloatOutput" |
| device: "/job:worker/replica:2/device:CPU:1" } |
| """, gd) |
| |
| def _overwritingDeviceFunction(self, unused_op): |
| # This device function unconditionally overwrites the device of ops. |
| # |
| # NOTE(mrry): Writing device functions like this is not |
| # recommended. Instead, in most cases you should use |
| # `pydev.merge_device("/job:ps")` or simply `"/job:ps"` as the |
| # argument to `tf.device()` and the device component will be merged in. |
| return "/job:overwrite" |
| |
| def testOverwritingBehavior(self): |
| g = ops.Graph() |
| with g.device(self._overwritingDeviceFunction): |
| g.create_op("FloatOutput", [], [dtypes.float32]) |
| with g.device("/job:ps"): # Will be overwritten. |
| g.create_op("FloatOutput", [], [dtypes.float32]) |
| with g.device(pydev.merge_device("/job:ps")): # Will be overwritten. |
| g.create_op("FloatOutput", [], [dtypes.float32]) |
| with g.device(None): # Disables overwriting device function |
| with g.device("/job:ps"): |
| g.create_op("FloatOutput", [], [dtypes.float32]) |
| with g.device(None): # Disables overwriting device function |
| with g.device(pydev.merge_device("/job:ps")): |
| g.create_op("FloatOutput", [], [dtypes.float32]) |
| gd = g.as_graph_def() |
| self.assertProtoEqualsVersion(""" |
| node { name: "FloatOutput" op: "FloatOutput" |
| device: "/job:overwrite" } |
| node { name: "FloatOutput_1" op: "FloatOutput" |
| device: "/job:overwrite" } |
| node { name: "FloatOutput_2" op: "FloatOutput" |
| device: "/job:overwrite" } |
| node { name: "FloatOutput_3" op: "FloatOutput" |
| device: "/job:ps" } |
| node { name: "FloatOutput_4" op: "FloatOutput" |
| device: "/job:ps" } |
| """, gd) |
| |
| |
| class MultithreadedGraphStateTest(test_util.TensorFlowTestCase): |
| |
| class TestThread(threading.Thread): |
| |
| def __init__(self, graph, replica_id): |
| super(MultithreadedGraphStateTest.TestThread, self).__init__() |
| self._graph = graph |
| self._replica_id = replica_id |
| # This thread sets this event once it has mutated the graph. The caller |
| # can wait for that. |
| self.has_mutated_graph = threading.Event() |
| # This thread waits for when it should continue. The caller can set this |
| # event. |
| self.should_continue = threading.Event() |
| |
| def run(self): |
| # Mutate the graph's thread-local stack, then set `has_mutated_graph`, |
| # wait for `should_continue`, and finally add an op that is affected by |
| # that stack. |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def testDeviceFunctionStack(self): |
| |
| class DeviceSettingThread(self.TestThread): |
| |
| def run(self): |
| with g.device("/job:worker/replica:{}".format(self._replica_id)): |
| self.has_mutated_graph.set() |
| self.should_continue.wait() |
| self.should_continue.clear() |
| g.create_op( |
| "FloatOutput", [], [dtypes.float32], |
| name="FloatOutput_{}".format(self._replica_id)) |
| |
| g = ops.Graph() |
| # If `switch_to_thread_local` isn't called, then device placement of the |
| # ops below is not deterministic. |
| g.switch_to_thread_local() |
| threads = [DeviceSettingThread(g, i) for i in range(3)] |
| for t in threads: |
| t.start() |
| t.has_mutated_graph.wait() |
| t.has_mutated_graph.clear() |
| for t in threads: |
| t.should_continue.set() |
| t.join() |
| |
| gd = g.as_graph_def() |
| self.assertProtoEqualsVersion(""" |
| node { name: "FloatOutput_0" op: "FloatOutput" |
| device: "/job:worker/replica:0" } |
| node { name: "FloatOutput_1" op: "FloatOutput" |
| device: "/job:worker/replica:1" } |
| node { name: "FloatOutput_2" op: "FloatOutput" |
| device: "/job:worker/replica:2" } |
| """, gd) |
| |
| def testColocateWith(self): |
| |
| class ColocatingThread(self.TestThread): |
| |
| def __init__(self, graph, replica_id, op_to_colocate_with): |
| super(ColocatingThread, self).__init__(graph, replica_id) |
| self._op_to_colocate_with = op_to_colocate_with |
| |
| def run(self): |
| with g.colocate_with(self._op_to_colocate_with): |
| self.has_mutated_graph.set() |
| self.should_continue.wait() |
| self.should_continue.clear() |
| g.create_op( |
| "FloatOutput", [], [dtypes.float32], |
| name="FloatOutput_{}".format(self._replica_id)) |
| |
| g = ops.Graph() |
| ops_to_colocate_with = [] |
| for i in range(3): |
| with g.device("/job:worker/replica:{}".format(i)): |
| ops_to_colocate_with.append( |
| g.create_op( |
| "FloatOutput", [], [dtypes.float32], |
| name="ColocateWithMe_{}".format(i))) |
| |
| # If `switch_to_thread_local` isn't called, then `device` and `attr` |
| # values for the ops below are not deterministic. |
| g.switch_to_thread_local() |
| threads = [ |
| ColocatingThread(g, i, ops_to_colocate_with[i]) for i in range(3) |
| ] |
| for t in threads: |
| t.start() |
| t.has_mutated_graph.wait() |
| t.has_mutated_graph.clear() |
| for t in threads: |
| t.should_continue.set() |
| t.join() |
| |
| gd = g.as_graph_def() |
| self.assertProtoEqualsVersion(""" |
| node { name: "ColocateWithMe_0" op: "FloatOutput" |
| device: "/job:worker/replica:0" } |
| node { name: "ColocateWithMe_1" op: "FloatOutput" |
| device: "/job:worker/replica:1" } |
| node { name: "ColocateWithMe_2" op: "FloatOutput" |
| device: "/job:worker/replica:2" } |
| node { name: "FloatOutput_0" op: "FloatOutput" |
| device: "/job:worker/replica:0" |
| attr { key: "_class" |
| value { list { |
| s: "loc:@ColocateWithMe_0"}}}} |
| node { name: "FloatOutput_1" op: "FloatOutput" |
| device: "/job:worker/replica:1" |
| attr { key: "_class" |
| value { list { |
| s: "loc:@ColocateWithMe_1"}}}} |
| node { name: "FloatOutput_2" op: "FloatOutput" |
| device: "/job:worker/replica:2" |
| attr { key: "_class" |
| value { list { |
| s: "loc:@ColocateWithMe_2"}}}} |
| """, gd) |
| |
| def testControlDependencies(self): |
| |
| class DependingThread(self.TestThread): |
| |
| def __init__(self, graph, replica_id, dependency_op): |
| super(DependingThread, self).__init__(graph, replica_id) |
| self._dependency_op = dependency_op |
| |
| def run(self): |
| with g.control_dependencies([self._dependency_op]): |
| self.has_mutated_graph.set() |
| self.should_continue.wait() |
| self.should_continue.clear() |
| g.create_op( |
| "FloatOutput", [], [dtypes.float32], |
| name="FloatOutput_{}".format(self._replica_id)) |
| |
| g = ops.Graph() |
| dependency_ops = [] |
| for i in range(3): |
| dependency_ops.append( |
| g.create_op( |
| "FloatOutput", [], [dtypes.float32], |
| name="ColocateWithMe_{}".format(i))) |
| |
| # If `switch_to_thread_local` isn't called, then `input` values for the |
| # ops below are not deterministic. |
| g.switch_to_thread_local() |
| threads = [DependingThread(g, i, dependency_ops[i]) for i in range(3)] |
| for t in threads: |
| t.start() |
| t.has_mutated_graph.wait() |
| t.has_mutated_graph.clear() |
| for t in threads: |
| t.should_continue.set() |
| t.join() |
| |
| gd = g.as_graph_def() |
| self.assertProtoEqualsVersion(""" |
| node { name: "ColocateWithMe_0" op: "FloatOutput" } |
| node { name: "ColocateWithMe_1" op: "FloatOutput" } |
| node { name: "ColocateWithMe_2" op: "FloatOutput" } |
| node { name: "FloatOutput_0" op: "FloatOutput" |
| input: "^ColocateWithMe_0" } |
| node { name: "FloatOutput_1" op: "FloatOutput" |
| input: "^ColocateWithMe_1" } |
| node { name: "FloatOutput_2" op: "FloatOutput" |
| input: "^ColocateWithMe_2" } |
| """, gd) |
| |
| def testNameStack(self): |
| |
| class NameSettingThread(self.TestThread): |
| |
| def run(self): |
| with g.name_scope("foo"): |
| op1 = g.create_op("FloatOutput", [], [dtypes.float32]) |
| self.has_mutated_graph.set() |
| self.should_continue.wait() |
| self.should_continue.clear() |
| op2 = g.create_op("FloatOutput", [], [dtypes.float32]) |
| self.result = (op1, op2) |
| |
| g = ops.Graph() |
| threads = [NameSettingThread(g, i) for i in range(3)] |
| for t in threads: |
| t.start() |
| t.has_mutated_graph.wait() |
| t.has_mutated_graph.clear() |
| |
| for t in threads: |
| t.should_continue.set() |
| t.join() |
| |
| suffixes = ["", "_1", "_2"] |
| for t, s in zip(threads, suffixes): |
| self.assertEqual("foo" + s + "/FloatOutput", t.result[0].name) |
| self.assertEqual("foo" + s + "/FloatOutput_1", t.result[1].name) |
| |
| |
| class ObjectWithName(object): |
| |
| def __init__(self, name): |
| self._name = name |
| |
| @property |
| def name(self): |
| return self._name |
| |
| |
| class CollectionTest(test_util.TensorFlowTestCase): |
| |
| def test_get_collections(self): |
| g = ops.Graph() |
| self.assertSequenceEqual(g.collections, []) |
| g.add_to_collection("key", 12) |
| g.add_to_collection("key", 15) |
| self.assertSequenceEqual(g.collections, ["key"]) |
| g.add_to_collection("other", "foo") |
| self.assertSequenceEqual(sorted(g.collections), ["key", "other"]) |
| |
| def test_add_to_collection(self): |
| g = ops.Graph() |
| g.add_to_collection("key", 12) |
| g.add_to_collection("other", "foo") |
| g.add_to_collection("key", 34) |
| |
| # Note that only blank1 is returned by the scope-filtered lookups below. |
| g.add_to_collection("blah", 27) |
| blank1 = ObjectWithName("prefix/foo") |
| g.add_to_collection("blah", blank1) |
| blank2 = ObjectWithName("junk/foo") |
| g.add_to_collection("blah", blank2) |
| |
| self.assertEqual([12, 34], g.get_collection("key")) |
| self.assertEqual([], g.get_collection("nothing")) |
| self.assertEqual([27, blank1, blank2], g.get_collection("blah")) |
| self.assertEqual([blank1], g.get_collection("blah", "prefix")) |
| self.assertEqual([blank1], g.get_collection("blah", ".*x")) |
| |
    # Make sure that get_collection() returns a first-level copy of the
    # collection, while get_collection_ref() returns the original list
    # (see the sketch after this class).
| other_collection_snapshot = g.get_collection("other") |
| other_collection_ref = g.get_collection_ref("other") |
| self.assertEqual(["foo"], other_collection_snapshot) |
| self.assertEqual(["foo"], other_collection_ref) |
| g.add_to_collection("other", "bar") |
| self.assertEqual(["foo"], other_collection_snapshot) |
| self.assertEqual(["foo", "bar"], other_collection_ref) |
| self.assertEqual(["foo", "bar"], g.get_collection("other")) |
| self.assertTrue(other_collection_ref is g.get_collection_ref("other")) |
| |
| # Verify that getting an empty collection ref returns a modifiable list. |
| empty_coll_ref = g.get_collection_ref("empty") |
| self.assertEqual([], empty_coll_ref) |
| empty_coll = g.get_collection("empty") |
| self.assertEqual([], empty_coll) |
| self.assertFalse(empty_coll is empty_coll_ref) |
| empty_coll_ref2 = g.get_collection_ref("empty") |
| self.assertTrue(empty_coll_ref2 is empty_coll_ref) |
| # Add to the collection. |
| empty_coll_ref.append("something") |
| self.assertEqual(["something"], empty_coll_ref) |
| self.assertEqual(["something"], empty_coll_ref2) |
| self.assertEqual([], empty_coll) |
| self.assertEqual(["something"], g.get_collection("empty")) |
| empty_coll_ref3 = g.get_collection_ref("empty") |
| self.assertTrue(empty_coll_ref3 is empty_coll_ref) |
| |
| def test_add_to_collections_uniquify(self): |
| g = ops.Graph() |
| g.add_to_collections([1, 2, 1], "key") |
| # Make sure "key" is not added twice |
| self.assertEqual(["key"], g.get_collection(1)) |
| |
| def test_add_to_collections_from_list(self): |
| g = ops.Graph() |
| g.add_to_collections(["abc", "123"], "key") |
| self.assertEqual(["key"], g.get_collection("abc")) |
| self.assertEqual(["key"], g.get_collection("123")) |
| |
| def test_add_to_collections_from_tuple(self): |
| g = ops.Graph() |
| g.add_to_collections(("abc", "123"), "key") |
| self.assertEqual(["key"], g.get_collection("abc")) |
| self.assertEqual(["key"], g.get_collection("123")) |
| |
| def test_add_to_collections_from_generator(self): |
| g = ops.Graph() |
| |
| def generator(): |
| yield "abc" |
| yield "123" |
| |
| g.add_to_collections(generator(), "key") |
| self.assertEqual(["key"], g.get_collection("abc")) |
| self.assertEqual(["key"], g.get_collection("123")) |
| |
| def test_add_to_collections_from_set(self): |
| g = ops.Graph() |
| g.add_to_collections(set(["abc", "123"]), "key") |
| self.assertEqual(["key"], g.get_collection("abc")) |
| self.assertEqual(["key"], g.get_collection("123")) |
| |
| def test_add_to_collections_from_string(self): |
| g = ops.Graph() |
| g.add_to_collections("abc", "key") |
| self.assertEqual(["key"], g.get_collection("abc")) |
| |
| def test_default_graph(self): |
| with ops.Graph().as_default(): |
| ops.add_to_collection("key", 90) |
| ops.add_to_collection("key", 100) |
| # Collections are ordered. |
| self.assertEqual([90, 100], ops.get_collection("key")) |
| |
| def test_defun(self): |
| with context.eager_mode(): |
| |
| @eager_function.defun |
| def defun(): |
| ops.add_to_collection("int", 1) |
| ops.add_to_collection("tensor", constant_op.constant(2)) |
| |
| @eager_function.defun |
| def inner_defun(): |
| self.assertEqual(ops.get_collection("int"), [1]) |
| three = ops.get_collection("tensor")[0] + ops.get_collection("int")[0] |
| ops.add_to_collection("int", 2) |
| self.assertEqual(ops.get_collection("int"), [1, 2]) |
| ops.add_to_collection("foo", "bar") |
| self.assertEqual(ops.get_collection("foo"), ["bar"]) |
| return three |
| |
| self.assertEqual(ops.get_collection("int"), [1]) |
| three = inner_defun() |
| self.assertEqual(ops.get_collection("int"), [1, 2]) |
| self.assertEqual(ops.get_collection("foo"), ["bar"]) |
| return three |
| |
| three = defun() |
| self.assertEqual(three.numpy(), 3) |
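

# Editor's illustrative sketch (not exercised by the test runner; the helper
# name is an editor-chosen placeholder): the distinction checked in
# CollectionTest.test_add_to_collection above -- get_collection() returns a
# copy, so later additions are only visible through get_collection_ref().
def _collection_ref_vs_snapshot_sketch():
  g = ops.Graph()
  g.add_to_collection("items", "foo")
  snapshot = g.get_collection("items")  # A copy, frozen at ["foo"].
  ref = g.get_collection_ref("items")  # The graph's own list.
  g.add_to_collection("items", "bar")
  assert snapshot == ["foo"]
  assert ref == ["foo", "bar"]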
| |
| |
| ops.NotDifferentiable("FloatOutput") |
| |
| |
| @ops.RegisterGradient("CopyOp") |
| def _CopyGrad(op, x_grad): # pylint: disable=invalid-name |
| _ = op |
| return x_grad |
| |
| |
| @ops.RegisterGradient("copy_override") |
| def _CopyOverrideGrad(op, x_grad): # pylint: disable=invalid-name |
| _ = op |
| return x_grad |
| |
| |
| class RegistrationTest(test_util.TensorFlowTestCase): |
| |
| def testRegisterGradients(self): |
| x = test_ops.float_output() |
| y = test_ops.copy_op(x) |
| fn = ops.get_gradient_function(y.op) |
| self.assertEqual(_CopyGrad, fn) |
| |
| def testOverrideGradients(self): |
| g = ops.Graph() |
| with g.as_default(): |
| x = test_ops.float_output() |
| with g.gradient_override_map({"CopyOp": "copy_override"}): |
| y = test_ops.copy_op(x) |
| fn = ops.get_gradient_function(y.op) |
| self.assertEqual(_CopyOverrideGrad, fn) |
| |
| def testNonExistentOverride(self): |
| g = ops.Graph() |
| with g.as_default(): |
| x = test_ops.float_output() |
| with g.gradient_override_map({"CopyOp": "unknown_override"}): |
| y = test_ops.copy_op(x) |
| with self.assertRaisesRegexp(LookupError, "unknown_override"): |
| ops.get_gradient_function(y.op) |
| |
| |
| class ComparisonTest(test_util.TensorFlowTestCase): |
| |
| def testMembershipAllowed(self): |
| g = ops.Graph() |
| t1 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop1") |
| t2 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop2") |
| self.assertTrue(isinstance(t1, ops.Tensor)) |
| self.assertTrue(isinstance(t2, ops.Tensor)) |
| self.assertTrue(t1 in [t1]) |
| self.assertTrue(t1 not in [t2]) |
| |
| |
| class ControlDependenciesTest(test_util.TensorFlowTestCase): |
| |
| def testBasic(self): |
| g = ops.Graph() |
| with g.as_default(): |
| # Creating unregistered ops with _apply_op() doesn't work with the C API |
| # TODO(skyewm): address this more consistently. Possible solutions are |
| # to use registered ops in all tests, create a way to register ops in |
| # Python tests, or conditionally disable the op registration check in |
| # the C API. |
| a = constant_op.constant(1.0) |
| b = constant_op.constant(1.0) |
| with g.control_dependencies([a]): |
| c = constant_op.constant(1.0) |
| d = array_ops.identity(b) |
| e = array_ops.identity(c) |
| |
| self.assertEqual(c.op.control_inputs, [a.op]) |
| self.assertEqual(d.op.control_inputs, [a.op]) |
      # e consumes c, which was created under the same control_dependencies
      # scope, so e is already dominated by the dependency on a and receives
      # no redundant control input (see the sketch after this class).
| self.assertEqual(e.op.control_inputs, []) |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testEager(self): |
| def future(): |
| future.calls += 1 |
| return constant_op.constant(2.0) |
| future.calls = 0 |
| |
| if context.executing_eagerly(): |
| a = constant_op.constant(1.0) |
| b = future |
| with ops.control_dependencies([a, b]): |
| c = constant_op.constant(3.0) |
| self.assertEqual(future.calls, 1) |
| else: |
| g = ops.Graph() |
| with g.as_default(): |
| a = constant_op.constant(1.0) |
| b = future() |
| with g.control_dependencies([a, b]): |
| c = constant_op.constant(3.0) |
| self.assertEqual(c.op.control_inputs, [a.op, b.op]) |
| self.assertEqual(future.calls, 1) |
| |
| def testBasicWithConversion(self): |
| g = ops.Graph() |
| a = _apply_op(g, "FloatOutput", [], [dtypes.float32]) |
| |
| class ConvertibleObj(object): |
| |
| def _as_graph_element(self): |
| return a |
| |
| with g.control_dependencies([ConvertibleObj()]): |
| c = _apply_op(g, "FloatOutput", [], [dtypes.float32]) |
| |
| self.assertEqual(c.op.control_inputs, [a.op]) |
| |
| def testNested(self): |
| g = ops.Graph() |
| a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) |
| a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) |
| a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) |
| a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) |
| |
| with g.control_dependencies([a_1, a_2, a_3, a_4]): |
| b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) |
| |
| with g.control_dependencies([a_1]): |
| with g.control_dependencies([a_2]): |
| with g.control_dependencies([a_3]): |
| with g.control_dependencies([a_4]): |
| b_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) |
| |
| self.assertItemsEqual([a_1.op, a_2.op, a_3.op, a_4.op], |
| b_1.op.control_inputs) |
| self.assertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs) |
| |
| def testClear(self): |
| g = ops.Graph() |
| a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) |
| a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) |
| a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) |
| a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) |
| |
| with g.control_dependencies([a_1]): |
| with g.control_dependencies([a_2]): |
| with g.control_dependencies(None): |
| with g.control_dependencies([a_3]): |
| with g.control_dependencies([a_4]): |
| # deps [a_3, a_4] |
| b_3_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) |
| # deps = [a_3] |
| b_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) |
| # deps back to None |
| b_none = _apply_op(g, "FloatOutput", [], [dtypes.float32]) |
| # deps back to [a_1, a_2] |
| b_1_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) |
| # deps back to [a_1] |
| b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) |
| with g.control_dependencies(None): |
| # deps are None again |
| b_none2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) |
| |
| self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs) |
| self.assertItemsEqual([a_3.op], b_3.op.control_inputs) |
| self.assertItemsEqual([], b_none.op.control_inputs) |
| self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs) |
| self.assertItemsEqual([a_1.op], b_1.op.control_inputs) |
| self.assertItemsEqual([], b_none2.op.control_inputs) |
| |
| def testComplex(self): |
| g = ops.Graph() |
| |
    # Usage pattern:
    # * Nodes a_i are zero-input ops defined at the outermost scope, and are
    #   used as control inputs for the ith nested scope.
    # * Nodes b_i are two-input ops on (a_3, a_4) defined at each scope.
    # * Nodes c_i are two-input ops on (a_1, b_1) defined at each scope.
    # * Nodes d_i are two-input ops on (b_i, c_i) defined at each scope.
    # * Nodes e_i are two-input ops on (e_{i-1}, e_{i-1}) at each scope i > 1.
| |
| a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) |
| a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) |
| a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) |
| a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) |
| |
| with g.control_dependencies([a_1]): |
| b_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4], |
| [dtypes.float32]) |
| c_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1], |
| [dtypes.float32]) |
| d_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_1, c_1], |
| [dtypes.float32]) |
| e_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) |
| with g.control_dependencies([a_2]): |
| b_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4], |
| [dtypes.float32]) |
| c_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1], |
| [dtypes.float32]) |
| d_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_2, c_2], |
| [dtypes.float32]) |
| e_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_1, e_1], |
| [dtypes.float32]) |
| with g.control_dependencies([a_3]): |
| b_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4], |
| [dtypes.float32]) |
| c_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1], |
| [dtypes.float32]) |
| d_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_3, c_3], |
| [dtypes.float32]) |
| e_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_2, e_2], |
| [dtypes.float32]) |
| with g.control_dependencies([a_4]): |
| b_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4], |
| [dtypes.float32]) |
| c_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1], |
| [dtypes.float32]) |
| d_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_4, c_4], |
| [dtypes.float32]) |
| e_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_3, e_3], |
| [dtypes.float32]) |
| |
| self.assertItemsEqual([a_1.op], b_1.op.control_inputs) |
| self.assertItemsEqual([a_1.op, a_2.op], b_2.op.control_inputs) |
| self.assertItemsEqual([a_1.op, a_2.op], b_3.op.control_inputs) |
| self.assertItemsEqual([a_1.op, a_2.op], b_4.op.control_inputs) |
| |
| self.assertItemsEqual([], c_1.op.control_inputs) |
| self.assertItemsEqual([a_2.op], c_2.op.control_inputs) |
| self.assertItemsEqual([a_2.op, a_3.op], c_3.op.control_inputs) |
| self.assertItemsEqual([a_2.op, a_3.op, a_4.op], c_4.op.control_inputs) |
| |
| self.assertItemsEqual([], d_1.op.control_inputs) |
| self.assertItemsEqual([], d_2.op.control_inputs) |
| self.assertItemsEqual([], d_3.op.control_inputs) |
| self.assertItemsEqual([], d_4.op.control_inputs) |
| |
| self.assertItemsEqual([a_1.op], e_1.op.control_inputs) |
| self.assertItemsEqual([a_2.op], e_2.op.control_inputs) |
| self.assertItemsEqual([a_3.op], e_3.op.control_inputs) |
| self.assertItemsEqual([a_4.op], e_4.op.control_inputs) |
| |
| def testRepeatedDependency(self): |
| g = ops.Graph() |
| a = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32]) |
| a_0, a_1 = a.outputs |
| with g.control_dependencies([a_0]): |
| b = _apply_op(g, "FloatOutput", [], [dtypes.float32]) |
| with g.control_dependencies([a_1]): |
| c = _apply_op(g, "FloatOutput", [], [dtypes.float32]) |
| |
| self.assertEqual(b.op.control_inputs, [a]) |
| self.assertEqual(c.op.control_inputs, [a]) |
| |
| def testNoControlDependencyWithDataDependency(self): |
| g = ops.Graph() |
| a = _apply_op(g, "FloatOutput", [], [dtypes.float32]) |
| with g.control_dependencies([a]): |
| b = _apply_op(g, "Identity", [a], [dtypes.float32]) |
| |
| self.assertEqual(b.op.control_inputs, []) |
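

# Editor's illustrative sketch (not exercised by the test runner; the helper
# name is an editor-chosen placeholder): the behavior checked by
# ControlDependenciesTest above -- an op with a direct data dependency on the
# control-dependency target gets no redundant control input, while unrelated
# ops created in the same scope do.
def _control_dependency_elision_sketch():
  g = ops.Graph()
  with g.as_default():
    a = constant_op.constant(1.0)
    with g.control_dependencies([a]):
      b = array_ops.identity(a)  # Consumes `a` directly: no control edge.
      c = constant_op.constant(2.0)  # Unrelated: gets `^a` as control input.
  assert b.op.control_inputs == []
  assert c.op.control_inputs == [a.op]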
| |
| |
| class OpScopeTest(test_util.TensorFlowTestCase): |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testNames(self): |
| with ops.name_scope("foo") as foo: |
| self.assertEqual("foo/", foo) |
| with ops.name_scope("foo2") as foo2: |
| self.assertEqual("foo/foo2/", foo2) |
| with ops.name_scope(None) as empty1: |
| self.assertEqual("", empty1) |
| with ops.name_scope("foo3") as foo3: |
| self.assertEqual("foo3/", foo3) |
| with ops.name_scope("") as empty2: |
| self.assertEqual("", empty2) |
| with ops.name_scope("foo/") as outer_foo: |
| self.assertEqual("foo/", outer_foo) |
| with ops.name_scope("") as empty3: |
| self.assertEqual("", empty3) |
| with ops.name_scope("foo4") as foo4: |
| self.assertEqual("foo/foo4/", foo4) |
| with ops.name_scope("foo5//") as foo5: |
| self.assertEqual("foo5//", foo5) |
| with ops.name_scope("foo6") as foo6: |
| self.assertEqual("foo5//foo6/", foo6) |
| with ops.name_scope("/") as foo7: |
| self.assertEqual("/", foo7) |
| with ops.name_scope("//") as foo8: |
| self.assertEqual("//", foo8) |
| with ops.name_scope("a//b/c") as foo9: |
| self.assertEqual("foo/a//b/c/", foo9) |
| with ops.name_scope("a//b/c") as foo10: |
| self.assertEqual("a//b/c/", foo10) |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testEagerDefaultScopeName(self): |
| with ops.name_scope(None, "default") as scope: |
| self.assertEqual(scope, "default/") |
| with ops.name_scope(None, "default2") as scope2: |
| self.assertEqual(scope2, "default/default2/") |
| |
| def testNoScopeName(self): |
| g0 = ops.Graph() |
| values = [ |
| g0.create_op("A", [], [dtypes.float32]), |
| g0.create_op("B", [], [dtypes.float32]) |
| ] |
| with self.assertRaises(ValueError): |
| with ops.name_scope(None, values=values): |
| pass |
| with self.assertRaises(ValueError): |
| with ops.name_scope(None, None, values): |
| pass |
| |
| def testEmptyScopeName(self): |
| g0 = ops.Graph() |
| a = g0.create_op("A", [], [dtypes.float32]) |
| b = g0.create_op("B", [], [dtypes.float32]) |
| with ops.name_scope("", values=[a, b]) as scope: |
| self.assertEqual("", scope) |
| self.assertEqual(g0, ops.get_default_graph()) |
| with ops.name_scope("", "my_default_scope", [a, b]) as scope: |
| self.assertEqual("", scope) |
| self.assertEqual(g0, ops.get_default_graph()) |
| |
| def testDefaultScopeName(self): |
| g0 = ops.Graph() |
| a = g0.create_op("A", [], [dtypes.float32]) |
| b = g0.create_op("B", [], [dtypes.float32]) |
| scope_name = "my_scope" |
| default_scope_name = "my_default_scope" |
| with ops.name_scope(scope_name, default_scope_name, [a, b]) as scope: |
| self.assertEqual("%s/" % scope_name, scope) |
| self.assertEqual(g0, ops.get_default_graph()) |
| with ops.name_scope(None, default_scope_name, [a, b]) as scope: |
| self.assertEqual("%s/" % default_scope_name, scope) |
| self.assertEqual(g0, ops.get_default_graph()) |
| |
| def _testGraphElements(self, graph_elements): |
| scope_name = "my_scope" |
| with ops.name_scope(scope_name, values=graph_elements) as scope: |
| self.assertEqual("%s/" % scope_name, scope) |
| self.assertEqual(graph_elements[0].graph, ops.get_default_graph()) |
| g1 = ops.Graph() |
| a = g1.create_op("A", [], [dtypes.float32]) |
| with self.assertRaises(ValueError): |
| with ops.name_scope(scope_name, values=graph_elements + [a]): |
| pass |
| |
| def testTensor(self): |
| g0 = ops.Graph() |
| a = g0.create_op("A", [], [dtypes.float32]) |
| b = g0.create_op("B", [], [dtypes.float32]) |
| self._testGraphElements([a, b]) |
| |
| def testSparseTensor(self): |
| g0 = ops.Graph() |
| a = g0.create_op("A", [], [dtypes.float32]) |
| b = g0.create_op("B", [], [dtypes.float32]) |
| sparse = sparse_tensor.SparseTensor( |
| _apply_op(g0, "Int64Output", [], [dtypes.int64]), |
| _apply_op(g0, "FloatOutput", [], [dtypes.float32]), |
| _apply_op(g0, "Int64Output", [], [dtypes.int64])) |
| self._testGraphElements([a, sparse, b]) |
| |
| def testVariable(self): |
| g0 = ops.Graph() |
| with g0.as_default(): |
| variable = variables.Variable([1.0]) |
| a = g0.create_op("A", [], [dtypes.float32]) |
| b = g0.create_op("B", [], [dtypes.float32]) |
| self._testGraphElements([a, variable, b]) |
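

# Editor's illustrative sketch (not exercised by the test runner; the helper
# name is an editor-chosen placeholder): the scope re-entry convention checked
# in OpScopeTest.testNames above -- a name that ends with "/" re-enters that
# existing scope verbatim instead of creating a uniquified sub-scope.
def _name_scope_reentry_sketch():
  with ops.Graph().as_default():
    with ops.name_scope("outer") as outer:  # outer == "outer/"
      pass
    with ops.name_scope(outer):  # Re-enters the existing "outer/" scope.
      with ops.name_scope("inner") as inner:
        return inner  # "outer/inner/"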
| |
| |
| class InitScopeTest(test_util.TensorFlowTestCase): |
| |
| def testClearsControlDependencies(self): |
| g = ops.Graph() |
| a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) |
| a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) |
| a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) |
| a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) |
| |
| with g.as_default(): |
| with g.control_dependencies([a_1]): |
| with g.control_dependencies([a_2]): |
| with ops.init_scope(): |
| with g.control_dependencies([a_3]): |
| with g.control_dependencies([a_4]): |
| # deps [a_3, a_4] |
| b_3_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) |
| # deps = [a_3] |
| b_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) |
| # deps back to None |
| b_none = _apply_op(g, "FloatOutput", [], [dtypes.float32]) |
| # deps back to [a_1, a_2] |
| b_1_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) |
| # deps back to [a_1] |
| b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) |
| with ops.init_scope(): |
| # deps are None again |
| b_none2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) |
| |
| self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs) |
| self.assertItemsEqual([a_3.op], b_3.op.control_inputs) |
| self.assertItemsEqual([], b_none.op.control_inputs) |
| self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs) |
| self.assertItemsEqual([a_1.op], b_1.op.control_inputs) |
| self.assertItemsEqual([], b_none2.op.control_inputs) |
| |
| def testLiftsOpsFromFunctions(self): |
| g0 = ops.Graph() |
| g1 = ops.Graph() |
| g1._building_function = True # pylint: disable=protected-access |
| g2 = ops.Graph() |
| g2._building_function = True # pylint: disable=protected-access |
| |
| with g0.as_default(): |
| with g1.as_default(): |
| with g2.as_default(): |
| with ops.init_scope(): |
| _ = constant_op.constant(1.0) |
| |
| self.assertEqual(len(g2.get_operations()), 0) |
| self.assertEqual(len(g1.get_operations()), 0) |
| self.assertEqual(len(g0.get_operations()), 1) |
| |
| def testPreservesDevices(self): |
| g0 = ops.Graph() |
| with g0.as_default(), ops.device("CPU:0"): |
| g1 = ops.Graph() |
| g1._building_function = True # pylint: disable=protected-access |
| with g1.as_default(), ops.device("GPU:0"): |
| with ops.init_scope(): |
| # init_scope should preserve device set under `g1`. |
| on_gpu = constant_op.constant(1.0) |
| self.assertEqual(on_gpu.device, "/device:GPU:0") |
| still_on_gpu = constant_op.constant(1.0) |
| self.assertEqual(still_on_gpu.device, "/device:GPU:0") |
| on_cpu = constant_op.constant(1.0) |
| self.assertEqual(on_cpu.device, "/device:CPU:0") |
| |
| def testComposes(self): |
| g0 = ops.Graph() |
| g1 = ops.Graph() |
| g1._building_function = True # pylint: disable=protected-access |
| g2 = ops.Graph() |
| g2._building_function = True # pylint: disable=protected-access |
| g3 = ops.Graph() |
| g3._building_function = False # pylint: disable=protected-access |
| |
| with g0.as_default(): |
| with g1.as_default(): |
| with ops.init_scope(): |
| # This op should be lifted into g0. |
| _ = constant_op.constant(1.0) |
| self.assertIs(g0, ops.get_default_graph()) |
| self.assertEqual(len(g2.get_operations()), 0) |
| self.assertEqual(len(g1.get_operations()), 0) |
| self.assertEqual(len(g0.get_operations()), 1) |
| with g2.as_default(): |
| with ops.init_scope(): |
| # This op should be lifted into g0. |
| _ = constant_op.constant(1.0) |
| self.assertIs(g0, ops.get_default_graph()) |
| with g3.as_default(): |
| with ops.init_scope(): |
| # This op should be lifted into g3, because g3 is not building a |
| # function. |
| _ = constant_op.constant(1.0) |
| self.assertIs(g3, ops.get_default_graph()) |
| |
| self.assertEqual(len(g3.get_operations()), 1) |
| self.assertEqual(len(g2.get_operations()), 0) |
| self.assertEqual(len(g1.get_operations()), 0) |
| self.assertEqual(len(g0.get_operations()), 2) |
| |
| def testEscapesToEagerContext(self): |
| g = ops.Graph() |
| g._building_function = True # pylint: disable=protected-access |
| with context.eager_mode(): |
| with context.graph_mode(): |
| with g.as_default(): |
| with ops.init_scope(): |
| # Because g is building a function, init_scope should |
| # escape out to the eager context. |
| self.assertTrue(context.executing_eagerly()) |
| # g should be reinstated as the default graph, and the |
| # graph context should be re-entered. |
| self.assertIs(g, ops.get_default_graph()) |
| self.assertFalse(context.executing_eagerly()) |
| |
| def testStaysInEagerWhenOnlyEagerContextActive(self): |
| with context.eager_mode(): |
| with ops.init_scope(): |
        self.assertTrue(context.executing_eagerly())
      self.assertTrue(context.executing_eagerly())
| |
| def testEscapesDefunWhenInEagerMode(self): |
| |
| def function_with_variables(): |
| with ops.init_scope(): |
| self.v = resource_variable_ops.ResourceVariable(3) |
| return self.v.assign_add(1) |
| |
| with context.eager_mode(): |
| # Each invocation of function_with_variables recreates a variable. |
| self.assertEqual(4, int(function_with_variables())) |
| self.assertEqual(4, int(function_with_variables())) |
| |
| compiled = eager_function.defun(function_with_variables) |
      # The init_scope in function_with_variables lifts the variable out of
      # the graph function constructed by defun; hence, compiled now appears
      # to be stateful (the pattern is sketched after this class).
| self.assertEqual(4, int(compiled())) |
| self.assertEqual(5, int(compiled())) |
| |
| def testEscapesDefunWhenInGraphMode(self): |
| def function_with_variables(name): |
| with ops.init_scope(): |
| _ = variable_scope.get_variable(name, shape=(1,)) |
| |
| g = ops.Graph() |
| with g.as_default(): |
| with self.cached_session(): |
| # First ensure that graphs that are not building functions are |
| # not escaped. |
| function_with_variables("foo") |
| with self.assertRaisesRegexp(ValueError, |
| r"Variable foo already exists.*"): |
| # This will fail because reuse is not set to True. |
| function_with_variables("foo") |
| |
| compiled = eager_function.defun(function_with_variables) |
| compiled("bar") |
| self.assertEqual( |
| len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)), 2) |
| |
| # The second call to `compiled` should not create variables: the |
| # init_scope has lifted the variable creation code out of the defun. |
| compiled("bar") |
| self.assertEqual( |
| len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)), 2) |
| |
| def testEscapesNestedDefun(self): |
| |
| def inner_function(): |
| with ops.init_scope(): |
| self.v = resource_variable_ops.ResourceVariable(1) |
| return self.v.assign_add(2) |
| |
| def outer_function(inner=None): |
| with ops.init_scope(): |
| self.v0 = resource_variable_ops.ResourceVariable(0) |
| return self.v0.assign_add(1) + inner() |
| |
| with context.eager_mode(): |
| # Each invocation of outer_function recreates variables. |
| self.assertEqual(4, int(outer_function(inner=inner_function))) |
| self.assertEqual(4, int(outer_function(inner=inner_function))) |
| |
| compiled_inner = eager_function.defun(inner_function) |
| compiled_outer = eager_function.defun(outer_function) |
| # The init_scope lifts variables out of the graph functions |
| # constructed by defun; hence, compiled_outer should now appear to be |
| # stateful. |
| self.assertEqual(4, int(compiled_outer(inner=compiled_inner))) |
| self.assertEqual(7, int(compiled_outer(inner=compiled_inner))) |
| |
| def testFallsBackToGlobalGraphWhenAllGraphsAreBuildingFunctions(self): |
| with context.graph_mode(): |
| ops.reset_default_graph() |
| # This doesn't push anything onto the graph stack, but it does |
| # set the stack's global graph. |
| global_graph = ops.get_default_graph() |
| fn_graph = ops.Graph() |
| |
| # pylint: disable=protected-access |
| fn_graph._building_function = True |
| self.assertEqual(len(ops._default_graph_stack.stack), 0) |
| with fn_graph.as_default(): |
| self.assertEqual(len(ops._default_graph_stack.stack), 1) |
| with ops.init_scope(): |
| self.assertGreater(len(ops._default_graph_stack.stack), 1) |
| dummy = constant_op.constant(1.0) |
| self.assertEqual(len(ops._default_graph_stack.stack), 1) |
| # Note that the global graph is _not_ on the graph stack. |
| self.assertEqual(len(ops._default_graph_stack.stack), 0) |
| # Ensure that `dummy` was added to the global graph. |
| self.assertEqual(global_graph, dummy.graph) |
| # pylint: enable=protected-access |
| |
| def testInstallsDefaultGraphWhenGraphStackIsEmptyInGraphMode(self): |
| with context.graph_mode(): |
| # pylint: disable=protected-access |
| self.assertEqual(len(ops._default_graph_stack.stack), 0) |
| with ops.init_scope(): |
| self.assertGreater(len(ops._default_graph_stack.stack), 0) |
| self.assertEqual(len(ops._default_graph_stack.stack), 0) |
| # pylint: enable=protected-access |
| |
| def testPreservesNameScopeInGraphConstruction(self): |
| with ops.Graph().as_default(): |
| function_graph = ops.Graph() |
| with function_graph.as_default(): |
| with ops.name_scope("inner"), ops.init_scope(): |
| self.assertEqual(ops.get_name_scope(), "inner") |
| self.assertEqual(ops.get_name_scope(), "") |
| |
| def testEnteringGraphFromEagerIsSticky(self): |
| with context.eager_mode(): |
| g = ops.Graph() |
| with g.as_default(): |
| with ops.init_scope(): |
| self.assertFalse(context.executing_eagerly()) |
| self.assertEqual(g, ops.get_default_graph()) |
| |
| def testMixGraphEager(self): |
| with context.eager_mode(): |
| c = constant_op.constant(1.0) |
| with ops.Graph().as_default(): |
| with self.assertRaisesRegexp( |
| RuntimeError, "Attempting to capture an EagerTensor"): |
| math_ops.add(c, c) |
| c2 = constant_op.constant(2.0) |
| with self.assertRaisesRegexp( |
| TypeError, "contains objects other than 'EagerTensor'"): |
| math_ops.add(c2, c2) |
| |
| def testPreservesNameScopeInEagerExecution(self): |
| with context.eager_mode(): |
| def foo(): |
| with ops.name_scope("inner"), ops.init_scope(): |
| if context.executing_eagerly(): |
| # A trailing slash is always appended when eager execution is |
| # enabled. |
| self.assertEqual(context.context().scope_name, "inner/") |
| else: |
| self.assertEqual(ops.get_name_scope(), "inner") |
| |
| foo() |
| self.assertEqual(ops.get_name_scope(), "") |
| foo_compiled = eager_function.defun(foo) |
| foo_compiled() |
| self.assertEqual(ops.get_name_scope(), "") |
| |
| |
| class GraphTest(test_util.TensorFlowTestCase): |
| |
| def setUp(self): |
| ops.reset_default_graph() |
| |
| def _AssertDefault(self, expected): |
| self.assertIs(expected, ops.get_default_graph()) |
| |
| def testResetDefaultGraphNesting(self): |
| g0 = ops.Graph() |
| with self.assertRaises(AssertionError): |
| with g0.as_default(): |
| ops.reset_default_graph() |
| |
| def testGraphContextManagerCancelsEager(self): |
| with context.eager_mode(): |
| with ops.Graph().as_default(): |
| self.assertFalse(context.executing_eagerly()) |
| |
| def testGraphContextManager(self): |
| g0 = ops.Graph() |
| with g0.as_default() as g1: |
| self.assertIs(g0, g1) |
| |
| def testDefaultGraph(self): |
| orig = ops.get_default_graph() |
| self._AssertDefault(orig) |
| g0 = ops.Graph() |
| self._AssertDefault(orig) |
| context_manager_0 = g0.as_default() |
| self._AssertDefault(orig) |
| with context_manager_0 as g0: |
| self._AssertDefault(g0) |
| with ops.Graph().as_default() as g1: |
| self._AssertDefault(g1) |
| self._AssertDefault(g0) |
| self._AssertDefault(orig) |
| |
| def testPreventFeeding(self): |
| g = ops.Graph() |
| a = constant_op.constant(2.0) |
| self.assertTrue(g.is_feedable(a)) |
| g.prevent_feeding(a) |
| self.assertFalse(g.is_feedable(a)) |
| |
| def testPreventFetching(self): |
| g = ops.Graph() |
| a = constant_op.constant(2.0) |
| self.assertTrue(g.is_fetchable(a)) |
| g.prevent_fetching(a.op) |
| self.assertFalse(g.is_fetchable(a)) |
| |
| def testAsGraphElementConversions(self): |
| |
| class ConvertibleObj(object): |
| |
| def _as_graph_element(self): |
| return "FloatOutput:0" |
| |
| class NonConvertibleObj(object): |
| |
| pass |
| |
| g = ops.Graph() |
| a = _apply_op(g, "FloatOutput", [], [dtypes.float32]) |
| self.assertEqual(a, g.as_graph_element(ConvertibleObj())) |
| with self.assertRaises(TypeError): |
| g.as_graph_element(NonConvertibleObj()) |
| |
  # Regression test against creating custom __del__ methods in classes
  # involved in cyclic references, e.g. Graph and Operation. (Python won't gc
  # cycles that require calling a __del__ method, because the __del__ method
  # could theoretically increase the object's refcount to "save" it from gc,
  # and any already-deleted objects in the cycle would then have to be
  # restored.) The weakref pattern used here is sketched after this class.
| def testGarbageCollected(self): |
| # Create a graph we can delete and a weak reference to monitor if it's gc'd |
| g = ops.Graph() |
| g_ref = weakref.ref(g) |
| # Create some ops |
| with g.as_default(): |
| a = constant_op.constant(2.0) |
| b = constant_op.constant(3.0) |
| c = math_ops.add(a, b) |
| # Create a session we can delete |
| with session.Session(graph=g) as sess: |
| sess.run(c) |
| # Delete all references and trigger gc |
| del g |
| del a |
| del b |
| del c |
| del sess |
| gc.collect() |
| self.assertIsNone(g_ref()) |
| |
| def testRunnableAfterInvalidShape(self): |
| with ops.Graph().as_default(): |
| with self.assertRaises(ValueError): |
| math_ops.add([1, 2], [1, 2, 3]) |
| a = constant_op.constant(1) |
| with session.Session() as sess: |
| sess.run(a) |
| |
| def testRunnableAfterInvalidShapeWithKernelLabelMap(self): |
| g = ops.Graph() |
| with g.as_default(): |
| with g._kernel_label_map({"KernelLabelRequired": "overload_1"}): |
| with self.assertRaises(ValueError): |
| test_ops.kernel_label_required(1) |
| a = constant_op.constant(1) |
| with session.Session() as sess: |
| sess.run(a) |
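

# Editor's illustrative sketch (not exercised by the test runner; the helper
# name is an editor-chosen placeholder): the weakref pattern used by
# GraphTest.testGarbageCollected above -- hold only a weak reference, drop
# every strong reference, force a collection, and check that the weak
# reference has been cleared.
def _weakref_gc_check_sketch():
  g = ops.Graph()
  g_ref = weakref.ref(g)
  del g
  gc.collect()
  return g_ref() is None  # True once the graph's reference cycles are gone.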
| |
| |
| class AttrScopeTest(test_util.TensorFlowTestCase): |
| |
| def _get_test_attrs(self): |
| x = control_flow_ops.no_op() |
| try: |
| a = compat.as_text(x.get_attr("_A")) |
| except ValueError: |
| a = None |
| try: |
| b = compat.as_text(x.get_attr("_B")) |
| except ValueError: |
| b = None |
| return (a, b) |
| |
| def testNoLabel(self): |
| with self.cached_session(): |
| self.assertAllEqual((None, None), self._get_test_attrs()) |
| |
| def testLabelMap(self): |
| with self.cached_session() as sess: |
| a1 = self._get_test_attrs() |
| with sess.graph._attr_scope({ |
| "_A": attr_value_pb2.AttrValue(s=compat.as_bytes("foo")) |
| }): |
| a2 = self._get_test_attrs() |
| with sess.graph._attr_scope({ |
| "_A": None, |
| "_B": attr_value_pb2.AttrValue(s=compat.as_bytes("bar")) |
| }): |
| a3 = self._get_test_attrs() |
| with sess.graph._attr_scope({ |
| "_A": attr_value_pb2.AttrValue(s=compat.as_bytes("baz")) |
| }): |
| a4 = self._get_test_attrs() |
| a5 = self._get_test_attrs() |
| a6 = self._get_test_attrs() |
| a7 = self._get_test_attrs() |
| |
| self.assertAllEqual((None, None), a1) |
| self.assertAllEqual(("foo", None), a2) |
| self.assertAllEqual((None, "bar"), a3) |
| self.assertAllEqual(("baz", "bar"), a4) |
| self.assertAllEqual((None, "bar"), a5) |
| self.assertAllEqual(("foo", None), a6) |
| self.assertAllEqual((None, None), a7) |
| |
| |
| ops.RegisterShape("KernelLabel")(common_shapes.scalar_shape) |
| |
| |
| class KernelLabelTest(test_util.TensorFlowTestCase): |
| |
| def testNoLabel(self): |
| with self.cached_session(): |
| self.assertAllEqual(b"My label is: default", |
| test_ops.kernel_label().eval()) |
| |
| def testLabelMap(self): |
| with self.cached_session() as sess: |
| default_1 = test_ops.kernel_label() |
| # pylint: disable=protected-access |
| with sess.graph._kernel_label_map({"KernelLabel": "overload_1"}): |
| overload_1_1 = test_ops.kernel_label() |
| with sess.graph._kernel_label_map({"KernelLabel": "overload_2"}): |
| overload_2 = test_ops.kernel_label() |
| with sess.graph._kernel_label_map({"KernelLabel": ""}): |
| default_2 = test_ops.kernel_label() |
| overload_1_2 = test_ops.kernel_label() |
| # pylint: enable=protected-access |
| default_3 = test_ops.kernel_label() |
| |
| self.assertAllEqual(b"My label is: default", default_1.eval()) |
| self.assertAllEqual(b"My label is: default", default_2.eval()) |
| self.assertAllEqual(b"My label is: default", default_3.eval()) |
| self.assertAllEqual(b"My label is: overload_1", overload_1_1.eval()) |
| self.assertAllEqual(b"My label is: overload_1", overload_1_2.eval()) |
| self.assertAllEqual(b"My label is: overload_2", overload_2.eval()) |
| |
| |
| class AsGraphDefTest(test_util.TensorFlowTestCase): |
| |
| def testGraphDefVersion(self): |
| """Test that the graphdef version is plumbed through to kernels.""" |
| with ops.Graph().as_default() as g: |
| version = g.graph_def_versions.producer |
| with self.session(graph=g): |
| v = test_ops.graph_def_version().eval() |
| self.assertEqual(version, v) |
| |
| def testAddShapes(self): |
| with ops.Graph().as_default() as g: |
| t1, t2, t3, t4, t5 = _apply_op(g, "FiveFloatOutputs", [], |
| [dtypes.float32] * 5) |
| t1.set_shape(None) |
| t2.set_shape([]) |
| t3.set_shape([None]) |
| t4.set_shape([43, 37]) |
| t5.set_shape([43, None]) |
| |
| b = constant_op.constant(1.0) # pylint: disable=unused-variable |
| |
| gd = g.as_graph_def(add_shapes=True) |
| self.assertProtoEqualsVersion(""" |
| node { name: "FiveFloatOutputs" op: "FiveFloatOutputs" |
| attr { |
| key: "_output_shapes" |
| value { |
| list { |
| shape { unknown_rank: true } |
| shape { } |
| shape { dim { size: -1 } } |
| shape { dim { size: 43 } dim { size: 37 } } |
| shape { dim { size: 43 } dim { size: -1 } } |
| } |
| } |
| } |
| } |
| node { name: "Const" op: "Const" |
| attr { |
| key: "_output_shapes" |
| value { |
| list { |
| shape { } |
| } |
| } |
| } |
| attr { |
| key: "dtype" |
| value { type: DT_FLOAT } |
| } |
| attr { |
| key: "value" |
| value { |
| tensor { |
| dtype: DT_FLOAT |
| tensor_shape { } |
| float_val: 1.0 } } } } |
| """, gd) |
| |
| |
| @ops.RegisterStatistics("a", "flops") |
| def _calc_a_forward_flops(unused_graph, unused_node): |
| return ops.OpStats("flops", 20) |
| |
| |
| class StatisticsTest(test_util.TensorFlowTestCase): |
| |
| def testRegisteredNode(self): |
| graph = ops.Graph() |
| node = ops._NodeDef("a", "an_a") |
| flops = ops.get_stats_for_node_def(graph, node, "flops") |
| self.assertEqual(20, flops.value) |
| missing_stat = ops.get_stats_for_node_def(graph, node, "missing_stat") |
| self.assertEqual(None, missing_stat.value) |
| |
| def testUnregisteredNode(self): |
| graph = ops.Graph() |
| node = ops._NodeDef("b", "a_b") |
| weight_params = ops.get_stats_for_node_def(graph, node, "weight_params") |
| self.assertEqual(None, weight_params.value) |
| |
| def testAccumulateStatistics(self): |
| flops_total = ops.OpStats("flops") |
| self.assertEqual(None, flops_total.value) |
| second_flops = ops.OpStats("flops", 3) |
| flops_total += second_flops |
| self.assertEqual(3, flops_total.value) |
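

# Editor's illustrative sketch (not exercised by the test runner; the helper
# name is an editor-chosen placeholder): OpStats objects of the same statistic
# type accumulate with +=, which is how per-node costs reported through
# RegisterStatistics handlers can be summed.
def _accumulate_flops_sketch():
  total = ops.OpStats("flops", 0)
  for node_flops in (20, 3, 7):
    total += ops.OpStats("flops", node_flops)
  return total.value  # 30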
| |
| |
| class DeviceStackTest(test_util.TensorFlowTestCase): |
| |
| def testBasicDeviceAssignmentMetadata(self): |
| |
| def device_func(unused_op): |
| return "/cpu:*" |
| |
| const_zero = constant_op.constant([0.0], name="zero") |
| with ops.device("/cpu"): |
| const_one = constant_op.constant([1.0], name="one") |
| with ops.device("/cpu:0"): |
| const_two = constant_op.constant([2.0], name="two") |
| with ops.device(device_func): |
| const_three = constant_op.constant(3.0, name="three") |
| |
| self.assertEqual(0, len(const_zero.op._device_assignments)) |
| |
| one_list = const_one.op._device_assignments |
| self.assertEqual(1, len(one_list)) |
| self.assertEqual("/cpu", one_list[0].obj) |
| self.assertEqual("ops_test.py", os.path.basename(one_list[0].filename)) |
| |
| two_list = const_two.op._device_assignments |
| self.assertEqual(2, len(two_list)) |
| devices = [t.obj for t in two_list] |
| self.assertEqual(set(["/cpu", "/cpu:0"]), set(devices)) |
| |
| three_list = const_three.op._device_assignments |
| self.assertEqual(1, len(three_list)) |
| func_description = three_list[0].obj |
| expected_regex = r"device_func<.*ops_test.py, [0-9]+" |
| self.assertRegexpMatches(func_description, expected_regex) |
| |
| def testDeviceAssignmentMetadataForGraphDeviceAndTfDeviceFunctions(self): |
| |
| with ops.device("/cpu"): |
| const_one = constant_op.constant([1.0], name="one") |
| with ops.get_default_graph().device("/cpu"): |
| const_two = constant_op.constant([2.0], name="two") |
| |
| one_metadata = const_one.op._device_assignments[0] |
| two_metadata = const_two.op._device_assignments[0] |
| |
| # Verify both types of device assignment return the right stack info. |
| self.assertRegexpMatches("ops_test.py", |
| os.path.basename(one_metadata.filename)) |
| self.assertEqual(one_metadata.filename, two_metadata.filename) |
| self.assertEqual(one_metadata.lineno + 2, two_metadata.lineno) |
| |
| |
| class ColocationGroupTest(test_util.TensorFlowTestCase): |
| |
| def testBasic(self): |
| a = constant_op.constant([2.0], name="a") |
| with ops.colocate_with(a.op): |
| b = constant_op.constant(3.0) |
| c = constant_op.constant(4.0) |
| self.assertEqual([b"loc:@a"], a.op.colocation_groups()) |
| self.assertEqual([b"loc:@a"], b.op.colocation_groups()) |
| with self.assertRaises(ValueError): |
| c.op.get_attr("_class") |
| |
| def testBasicColocationMetadata(self): |
| const_two = constant_op.constant([2.0], name="two") |
| with ops.colocate_with(const_two.op): |
| const_three = constant_op.constant(3.0, name="three") |
| locations_dict = const_three.op._colocation_dict |
| self.assertIn("two", locations_dict) |
| metadata = locations_dict["two"] |
| self.assertIsNone(metadata.obj) |
| # Check that this test's filename is recorded as the file containing the |
| # colocation statement. |
| self.assertEqual("ops_test.py", os.path.basename(metadata.filename)) |
| |
| def testColocationDeviceInteraction(self): |
| with ops.device("/cpu:0"): |
| with ops.device("/device:GPU:0"): |
| a = constant_op.constant([2.0], name="a") |
| with ops.colocate_with(a.op): |
        # 'b' is created in the scope of /cpu:0, but it is colocated with
        # 'a', which is on '/device:GPU:0'. colocate_with overrides devices
        # because it is a stronger constraint (see the sketch after this
        # class).
| b = constant_op.constant(3.0) |
| self.assertEqual([b"loc:@a"], b.op.colocation_groups()) |
| self.assertEqual(a.op.device, b.op.device) |
| |
| def testColocationCanonicalization(self): |
| with ops.device("/device:GPU:0"): |
| _ = constant_op.constant(2.0) |
| with ops.device(lambda op: "/device:GPU:0"): |
| b = constant_op.constant(3.0) |
| with ops.get_default_graph().colocate_with(b): |
| with ops.device("/device:GPU:0"): |
| c = constant_op.constant(4.0) |
| |
    # The first constant's device will be /device:GPU:0.
    # b's device will be /device:GPU:0.
    # c's device will be /device:GPU:0 because it inherits b's device name,
    # after the names are canonicalized.
| self.assertEqual(b.op.device, c.op.device) |
| |
| def testLocationOverrides(self): |
| with ops.device("/cpu:0"): |
| with ops.device("/device:GPU:0"): |
| a = constant_op.constant([2.0], name="a") |
| # Note that this colocation is "redundant", since we are |
| # within the scope of "/device:GPU:0". However, we would like to |
| # preserve in the GraphDef that these two ops should be |
| # colocated in a portable way. |
| with ops.colocate_with(a.op): |
| b = constant_op.constant(3.0) |
| c = constant_op.constant(4.0) |
| d = constant_op.constant(5.0) |
| |
| self.assertEqual([b"loc:@a"], b.op.colocation_groups()) |
| self.assertEqual("/device:GPU:0", a.op.device) |
| self.assertEqual(a.op.device, b.op.device) |
| |
| # Test that device function stack is restored. |
| self.assertEqual("/device:GPU:0", c.op.device) |
| self.assertEqual("/device:CPU:0", d.op.device) |
| |
| def testNestedColocateWith(self): |
| a = constant_op.constant([2.0], name="a") |
| with ops.colocate_with(a.op): |
| b = constant_op.constant(3.0) |
| with ops.colocate_with(b.op): |
| c = constant_op.constant(4.0) |
| self.assertEqual([b"loc:@a"], b.op.colocation_groups()) |
| self.assertEqual([b"loc:@a"], c.op.colocation_groups()) |
| |
| def testMultiColocationGroups(self): |
| a = constant_op.constant([2.0], name="a") |
| b = constant_op.constant(3.0, name="b") |
| with ops.colocate_with(a.op): |
| with ops.colocate_with(b.op): |
| c = constant_op.constant(4.0) |
| self.assertEqual(set([b"loc:@a", b"loc:@b"]), set(c.op.colocation_groups())) |
| |
| def testColocationIgnoreStack(self): |
| a = constant_op.constant([2.0], name="a") |
| b = constant_op.constant(3.0, name="b") |
| with ops.colocate_with(a.op): |
| with ops.colocate_with(b.op, ignore_existing=True): |
| c = constant_op.constant(4.0) |
| self.assertEqual(set([b"loc:@b"]), set(c.op.colocation_groups())) |
| |
| def testColocateWithReset(self): |
| a = constant_op.constant([2.0], name="a") |
| with ops.colocate_with(a.op): |
| b = constant_op.constant(3.0, name="b") |
| with ops.colocate_with(None, ignore_existing=True): |
| c = constant_op.constant(4.0, name="c") |
| self.assertEqual([b"loc:@a"], b.op.colocation_groups()) |
| self.assertEqual([b"loc:@c"], c.op.colocation_groups()) |
| |
| def testColocateWithInitialNoneThenNested(self): |
| a = constant_op.constant([2.0], name="a") |
| with ops.colocate_with(a.op): |
| with ops.colocate_with(None, ignore_existing=True): |
| b = constant_op.constant(3.0, name="b") |
| with ops.colocate_with(b.op): |
| c = constant_op.constant(4.0, name="c") |
| self.assertEqual([b"loc:@b"], b.op.colocation_groups()) |
| self.assertEqual([b"loc:@b"], c.op.colocation_groups()) |
| |
| def testColocateVariables(self): |
| a = variables.Variable([2.0], name="a") |
| with ops.colocate_with(a.op): |
| b = variables.Variable([3.0], name="b") |
| self.assertEqual([b"loc:@a"], b.op.colocation_groups()) |
| |
| def testInconsistentDeviceWithinColocate(self): |
| with ops.device("/device:GPU:0"): |
| a = constant_op.constant([2.0], name="a") |
| with ops.colocate_with(a.op): |
| # This is allowed due to legacy but clearly wrong, since we |
| # should really be colocating with 'a'. We allow devices to |
| # override colocate_with, but we log warnings to suggest that |
| # this is probably unintentional or misguided. |
| with ops.device("/cpu:0"): |
| b = constant_op.constant([3.0], name="b") |
| |
| self.assertEqual("/device:CPU:0", b.device) |
| |
| def testMakeColocationConflictMessage(self): |
| """Test that provides an example of a complicated error message.""" |
| # We could test the message with any ops, but this test will be more |
| # instructive with a real colocation conflict. |
| with ops.device("/device:GPU:0"): |
| a = constant_op.constant([2.0], name="a") |
| with ops.colocate_with(a.op): |
| with ops.device("/cpu:0"): |
| b = constant_op.constant([3.0], name="b") |
| # The definition-location of the nodes will be wrong because of running |
| # from within a TF unittest. The rest of the info should be correct. |
| message = ops.get_default_graph()._make_colocation_conflict_message(a.op, |
| b.op) |
| self.assertRegexpMatches(message, |
| r"Tried to colocate op 'a' \(defined at.*\)") |
| self.assertRegexpMatches(message, "No node-device.*'a'") |
| self.assertRegexpMatches(message, "Device assignments active.*'a'") |
| self.assertRegexpMatches(message, "GPU:0") |
| self.assertRegexpMatches(message, "Node-device colocations active.*'b'") |
| self.assertRegexpMatches(message, "Device assignments active.*'b'") |
| self.assertRegexpMatches(message, "cpu:0") |
| |
| |
| class DeprecatedTest(test_util.TensorFlowTestCase): |
| |
| def testSuccess(self): |
| with ops.Graph().as_default() as g: |
| test_util.set_producer_version(g, 7) |
| old = test_ops.old() |
| with self.session(graph=g): |
| old.run() |
| |
| def _error(self): |
| return ((r"Op Old is not available in GraphDef version %d\. " |
| r"It has been removed in version 8\. For reasons\.") % |
| versions.GRAPH_DEF_VERSION) |
| |
| def testGraphConstructionFail(self): |
| with ops.Graph().as_default(): |
| with self.assertRaisesRegexp(NotImplementedError, self._error()): |
| test_ops.old() |
| |
| |
| class DenseTensorLikeTypeTest(test_util.TensorFlowTestCase): |
| |
| def testSuccess(self): |
| op = ops.Operation( |
| ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32]) |
| t = op.outputs[0] |
| self.assertTrue(ops.is_dense_tensor_like(t)) |
| |
| v = variables.Variable([17]) |
| self.assertTrue(ops.is_dense_tensor_like(v)) |
| |
| class BadClassNoName(object): |
| pass |
| |
| class BadClassBadName(object): |
| |
| def name(self): |
| pass |
| |
| class BadClassNoDtype(object): |
| |
| @property |
| def name(self): |
| pass |
| |
| class BadClassBadDtype(object): |
| |
| @property |
| def name(self): |
| pass |
| |
| def dtype(self): |
| pass |
| |
| def testBadClass(self): |
| with self.assertRaisesRegexp(TypeError, "`name`"): |
| ops.register_dense_tensor_like_type( |
| DenseTensorLikeTypeTest.BadClassNoName) |
| with self.assertRaisesRegexp(TypeError, "`name`"): |
| ops.register_dense_tensor_like_type( |
| DenseTensorLikeTypeTest.BadClassBadName) |
| with self.assertRaisesRegexp(TypeError, "`dtype`"): |
| ops.register_dense_tensor_like_type( |
| DenseTensorLikeTypeTest.BadClassNoDtype) |
| with self.assertRaisesRegexp(TypeError, "`dtype`"): |
| ops.register_dense_tensor_like_type( |
| DenseTensorLikeTypeTest.BadClassBadDtype) |
| |
| |
| class NameScopeTest(test_util.TensorFlowTestCase): |
| |
| def testStripAndPrependScope(self): |
    strs = [
        "hidden1/hidden1/weights",  # Same prefix. Should strip.
        "hidden1///hidden1/weights",  # Extra "/". Should strip.
        "^hidden1/hidden1/weights",  # Same prefix. Should strip.
        "loc:@hidden1/hidden1/weights",  # Same prefix. Should strip.
        "hhidden1/hidden1/weights",  # Different prefix. Should keep.
        "hidden1"  # Not a prefix. Should keep.
    ]
    expected_stripped = [
        "hidden1/weights", "hidden1/weights", "^hidden1/weights",
        "loc:@hidden1/weights", "hhidden1/hidden1/weights", "hidden1"
    ]
    expected_prepended = [
        "hidden2/hidden1/weights", "hidden2/hidden1/weights",
        "^hidden2/hidden1/weights", "loc:@hidden2/hidden1/weights",
        "hidden2/hhidden1/hidden1/weights", "hidden2/hidden1"
    ]
    name_scope_to_strip = "hidden1"
    name_scope_to_add = "hidden2"
    for es, ep, s in zip(expected_stripped, expected_prepended, strs):
      stripped = ops.strip_name_scope(s, name_scope_to_strip)
      self.assertEqual(es, stripped)
      self.assertEqual(
          ep, ops.prepend_name_scope(stripped, name_scope_to_add))
| |
| def testGetNameScope(self): |
| with ops.Graph().as_default() as g: |
| with ops.name_scope("scope1"): |
| with ops.name_scope("scope2"): |
| with ops.name_scope("scope3"): |
| self.assertEqual("scope1/scope2/scope3", g.get_name_scope()) |
| self.assertEqual("scope1/scope2", g.get_name_scope()) |
| self.assertEqual("scope1", g.get_name_scope()) |
| self.assertEqual("", g.get_name_scope()) |
| |
| def testTwoGraphs(self): |
| |
| def f(): |
| g1 = ops.Graph() |
| g2 = ops.Graph() |
| with g1.as_default(): |
| with g2.as_default(): |
| with ops.name_scope("_"): |
| pass |
| |
| self.assertRaisesRegexp(ValueError, "'_' is not a valid scope name", f) |
| |
| |
| class TracebackTest(test_util.TensorFlowTestCase): |
| |
| def testTracebackWithStartLines(self): |
| with self.cached_session() as sess: |
| a = constant_op.constant(2.0) |
| sess.run( |
| a, |
| options=config_pb2.RunOptions( |
| trace_level=config_pb2.RunOptions.FULL_TRACE)) |
| self.assertTrue(sess.graph.get_operations()) |
| |
| # Tests that traceback_with_start_lines is the same as traceback |
| # but includes one more element at the end. |
| for op in sess.graph.get_operations(): |
| self.assertEquals(len(op.traceback), len(op.traceback_with_start_lines)) |
| for frame, frame_with_start_line in zip( |
| op.traceback, op.traceback_with_start_lines): |
| self.assertEquals(5, len(frame_with_start_line)) |
| self.assertEquals(frame, frame_with_start_line[:-1]) |
| |
| |
| class EnableEagerExecutionTest(test_util.TensorFlowTestCase): |
| |
| def testBadArgumentsToEnableEagerExecution(self): |
| with self.assertRaisesRegexp(TypeError, "config must be a tf.ConfigProto"): |
| ops.enable_eager_execution(context.DEVICE_PLACEMENT_SILENT) |
| with self.assertRaisesRegexp(ValueError, "device_policy must be one of"): |
| c = config_pb2.ConfigProto() |
| ops.enable_eager_execution(c, c) |
| with self.assertRaisesRegexp(ValueError, "execution_mode must be one of"): |
| c = config_pb2.ConfigProto() |
| ops.enable_eager_execution(c, execution_mode=c) |
| |
| |
| if __name__ == "__main__": |
| googletest.main() |