# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import weakref

import numpy as np

from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


class ContextTest(test.TestCase):

  def testSetGlobalSeed(self):
    c = context.Context()
    c._set_global_seed(123)
    for t in [np.int32, np.int64, np.uint32, np.uint64]:
      c._set_global_seed(t(123))
      c._set_global_seed(np.array(123, dtype=t))
      c._set_global_seed(ops.convert_to_tensor(123, dtype=t))

  def testContextIsDestroyedAfterTensors(self):
    # Create a new context
    new_context = context.Context()
    weak_c = weakref.ref(new_context)
    new_context.ensure_initialized()

    # Create a tensor with the new context as default.
    # Make sure to restore the original context.
    original_context = context.context()
    try:
      context._set_context(new_context)
      # Use a 2D tensor so that it is not cached.
      tensor1 = constant_op.constant([[3.]])
      # Produce a tensor as an operation output. This uses a different code
      # path from tensors created from Python.
      tensor2 = tensor1 * tensor1
      context._set_context(original_context)
    except:
      context._set_context(original_context)
      raise

    # Deleting our context reference should not delete the underlying object.
    del new_context
    self.assertIsNot(weak_c(), None)

    # Deleting the first tensor should not delete the context since there is
    # another tensor.
    del tensor1
    self.assertIsNot(weak_c(), None)

    # Deleting the last tensor should result in deleting its context.
    del tensor2
    self.assertIs(weak_c(), None)

  def testSimpleGraphCollection(self):

    @def_function.function
    def f(x):
      with ops.device('CPU:0'):
        return x + constant_op.constant(1.)

    with context.collect_graphs() as graphs:
      with ops.device('CPU:0'):
        x = constant_op.constant(1.)
      f(x)

    self.assertLen(graphs, 1)
    graph, = graphs
    self.assertIn('CPU:0', graph.node[1].device)

  @test_util.disable_tfrt(
      'b/171600738: tfrt does not support exporting post-optimization graph')
  def testGraphCollectionAfterDevicePlacement(self):

    @def_function.function
    def f(x):
      return x + constant_op.constant(1.)

    with context.collect_graphs() as graphs:
      with ops.device('CPU:0'):
        f(constant_op.constant(1.))

    self.assertLen(graphs, 1)
    graph, = graphs
    self.assertIn('CPU:0', graph.node[0].device)

  def testGetFunctionDef(self):

    @def_function.function
    def f():
      return constant_op.constant(1.)

    concrete = f.get_concrete_function()
    function_def = context.get_function_def(concrete.name)

    self.assertIsNot(function_def, None)

    found_const_node = False
    for node_def in function_def.node_def:
      if node_def.op == 'Const':
        found_const_node = True
        break
    self.assertTrue(found_const_node)

    with self.assertRaises(errors.NotFoundError):
      _ = context.get_function_def('this_should_not_be_found')

  @test_util.run_gpu_only
  @test_util.disable_tfrt(
      'b/169293680: TFE_GetTotalMemoryUsage is unsupported')
  def testGetMemoryInfo(self):
    array_ops.zeros([10])  # Allocate some memory on the GPU.
    self.assertGreater(
        context.context().get_memory_info('GPU:0')['current'], 0)

  @test_util.disable_tfrt(
      'b/169293680: TFE_GetTotalMemoryUsage is unsupported')
  def testGetMemoryInfoCPU(self):
    # MklCPUAllocator (mkl_cpu_allocator.h) does not throw the expected
    # exception, so skip the test when MKL is enabled.
    # TODO(gzmkl): work with Google team to address design issue in allocator.h
    if not test_util.IsMklEnabled():
      with self.assertRaisesRegex(ValueError, 'Allocator stats not available'):
        context.context().get_memory_info('CPU:0')

  @test_util.disable_tfrt(
      'b/169293680: TFE_GetTotalMemoryUsage is unsupported')
  def testGetMemoryInfoUnknownDevice(self):
    with self.assertRaisesRegex(ValueError, 'No matching devices found'):
      context.context().get_memory_info('unknown_device:0')

  @test_util.disable_tfrt(
      'b/169293680: TFE_GetTotalMemoryUsage is unsupported')
  def testGetMemoryInfoUnparsableDevice(self):
    with self.assertRaisesRegex(ValueError, 'Failed parsing device name'):
      context.context().get_memory_info('GPU')
    with self.assertRaisesRegex(ValueError, 'Failed parsing device name'):
      context.context().get_memory_info('GPU:')
    with self.assertRaisesRegex(ValueError, 'Failed parsing device name'):
      context.context().get_memory_info('GPU:CPU')

  def testListFunctionNames(self):

    @def_function.function
    def f():
      return constant_op.constant(1.)

    concrete = f.get_concrete_function()
    self.assertIn(concrete.name.decode(),
                  context.context().list_function_names())

  def testSetLogicalDeviceAfterContextInitialization(self):
    ctx = context.Context()
    ctx.set_logical_cpu_devices(4)
    self.assertEqual(len(ctx.list_logical_devices('CPU')), 4)
    # Cannot set logical devices twice.
    with self.assertRaisesRegex(RuntimeError, 'Virtual CPUs already set'):
      ctx.set_logical_cpu_devices(8)


if __name__ == '__main__':
  ops.enable_eager_execution()
  test.main()