# blob: d6b3290759312e2949579a12b36f579198f5ec9e [file] [log] [blame]
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains private utilities used mainly by the base Layer class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import threading
from tensorflow.python import tf2
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend
from tensorflow.python.keras.utils import control_flow_util
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util_v2
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.training.tracking import base as tracking
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export
# Thread-local storage. `call_context()` stores the active `CallContext` on
# its `call_context` attribute and creates one lazily when it is missing.
_call_context = threading.local()
def create_mean_metric(value, name=None):
  """Builds a `Mean` metric and immediately applies it to `value`.

  Arguments:
    value: Value to aggregate. Its `dtype` becomes the dtype of the metric.
    name: Optional name given to the metric.

  Returns:
    A `(metric_obj, result)` tuple, where `result` is whatever calling the
    metric object on `value` returned.
  """
  # Deferred import: importing keras pulls in base_layer, which imports this
  # module, while the metrics module itself depends on base_layer. Importing
  # it at the top of the file would create a cycle.
  from tensorflow.python.keras import metrics as metrics_module  # pylint: disable=g-import-not-at-top
  mean_metric = metrics_module.Mean(name=name, dtype=value.dtype)
  result = mean_metric(value)
  return mean_metric, result
def make_variable(name,
                  shape=None,
                  dtype=dtypes.float32,
                  initializer=None,
                  trainable=None,
                  caching_device=None,
                  validate_shape=True,
                  constraint=None,
                  use_resource=None,
                  collections=None,
                  synchronization=tf_variables.VariableSynchronization.AUTO,
                  aggregation=tf_variables.VariableAggregation.NONE,
                  partitioner=None):  # pylint: disable=unused-argument
  """Temporary helper that creates a variable through `VariableV1`.

  Reuse-related technicalities prevent calling
  `variable_scope.get_variable()` directly, so a less constrained code path is
  used instead. Longer term, a similar "default variable creator" should live
  on `Trackable`, at which point this helper can be deleted.

  TODO(fchollet): remove this method when no longer needed.

  Arguments:
    name: Variable name.
    shape: Variable shape.
    dtype: The type of the variable. Defaults to `float32`. Only used when
      `initializer` is callable.
    initializer: Either a concrete initial value (any non-callable object), an
      initializer instance (callable), or an initializer class, which is
      instantiated with no arguments.
    trainable: Whether the variable should be part of the layer's
      "trainable_variables" (e.g. weights, biases) or
      "non_trainable_variables" (e.g. BatchNorm mean, stddev). Note, if the
      current variable scope is marked as non-trainable then this parameter is
      ignored and any added variables are also marked as non-trainable.
      `trainable` defaults to `True` unless `synchronization` is set to
      `ON_READ`.
    caching_device: Passed to `tf.Variable`.
    validate_shape: Passed to `tf.Variable`.
    constraint: Constraint instance (callable).
    use_resource: Whether to use a `ResourceVariable`. `None` is treated as
      `True`.
    collections: List of graph collections keys. The new variable is added to
      these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
    synchronization: Indicates when a distributed variable will be aggregated.
      Accepted values are constants defined in the class
      `tf.VariableSynchronization`. By default the synchronization is set to
      `AUTO` and the current `DistributionStrategy` chooses when to
      synchronize. If `synchronization` is set to `ON_READ`, `trainable` must
      not be set to `True`.
    aggregation: Indicates how a distributed variable will be aggregated.
      Accepted values are constants defined in the class
      `tf.VariableAggregation`.
    partitioner: Not handled at this time.

  Returns:
    Variable instance.
  """
  if initializer is not None and not callable(initializer):
    # A concrete value was supplied: use it as-is and do not force a dtype.
    initial_value = initializer
    variable_dtype = None
  else:
    # An initializer class (rather than an instance) is instantiated first.
    if tf_inspect.isclass(initializer):
      initializer = initializer()
    # Bind shape and dtype now; the initializer itself is invoked later.
    initial_value = functools.partial(initializer, shape, dtype=dtype)
    variable_dtype = dtype.base_dtype

  if use_resource is None:
    use_resource = True

  # TODO(apassos,rohanj) figure out how to remove collections from here so we
  # can remove the V1.
  variable_shape = tensor_shape.TensorShape(shape)
  if not variable_shape:
    # A falsy `TensorShape` is forwarded as "no shape".
    variable_shape = None
  return tf_variables.VariableV1(
      initial_value=initial_value,
      name=name,
      trainable=trainable,
      caching_device=caching_device,
      dtype=variable_dtype,
      validate_shape=validate_shape,
      constraint=constraint,
      use_resource=use_resource,
      collections=collections,
      synchronization=synchronization,
      aggregation=aggregation,
      shape=variable_shape)
def collect_previous_mask(input_tensors):
  """Gathers the `_keras_mask` attached to each input tensor.

  Arguments:
    input_tensors: An arbitrary structure of Tensors.

  Returns:
    A structure shaped like `input_tensors` holding each tensor's
    `_keras_mask`, or `None` where a tensor has no such attribute.
  """
  return nest.map_structure(
      lambda tensor: getattr(tensor, '_keras_mask', None), input_tensors)
def have_all_keras_metadata(tensors):
  """Returns True if every flattened tensor has a `_keras_history` attribute."""
  for tensor in nest.flatten(tensors):
    if not hasattr(tensor, '_keras_history'):
      return False
  return True
def generate_placeholders_from_shape(shape):
  """Creates a placeholder of the given `shape` with dtype `backend.floatx()`."""
  float_dtype = backend.floatx()
  return array_ops.placeholder(shape=shape, dtype=float_dtype)
def create_keras_history(tensors):
  """Wraps raw TensorFlow Operations so they work with the Functional API.

  For every Tensor in `tensors` that lacks Keras metadata but originates from
  a Keras `Input` Layer, the raw TensorFlow Operations that produced it are
  replaced by `TensorFlowOpLayer` instances creating identical operations.

  Tensors that do not originate from a Keras `Input` Layer are treated as
  constants when the `TensorFlowOpLayer` instances are built.

  Arguments:
    tensors: A structure of Tensors, some of which come from raw TensorFlow
      operations and need to have Keras metadata assigned to them.

  Returns:
    created_layers: List. The `TensorFlowOpLayer` instances created to wrap
      the raw Tensorflow operations.
  """
  # Start the traversal with no processed ops and no created layers; only the
  # list of created layers is of interest to callers.
  helper_result = _create_keras_history_helper(tensors, set(), [])
  return helper_result[1]
def _create_keras_history_helper(tensors, processed_ops, created_layers):
  """Helper method for `create_keras_history`.

  Walks from `tensors` back through the ops that produced them and wraps each
  not-yet-processed op in a `TensorFlowOpLayer`. Op inputs for which
  `uses_keras_history` is true become layer inputs (and are processed
  recursively first); all other op inputs are recorded as constants, keyed by
  their position in `op.inputs`.

  Note that `processed_ops` and `created_layers` are mutated in place as well
  as returned.

  Arguments:
    tensors: A structure of Tensors for which to create Keras metadata.
    processed_ops: Set. TensorFlow operations that have already been wrapped in
      `TensorFlowOpLayer` instances.
    created_layers: List. The `TensorFlowOpLayer` instances created.

  Returns:
    Tuple. First element is the updated set of TensorFlow Operations that
    have been wrapped in `TensorFlowOpLayer` instances. Second element is
    a list of the `TensorFlowOpLayer` instances created.

  Raises:
    ValueError: If any tensor without Keras history is sparse or ragged. The
      error is raised after all the other tensors have been processed.
  """
  # Import of `base_layer` needed in order to create `TensorFlowOpLayer`.
  # Cannot be imported at top because of circular dependencies.
  # TODO(omalleyt): Resolve circular dependency.
  from tensorflow.python.keras.engine import base_layer  # pylint: disable=g-import-not-at-top
  tensor_list = nest.flatten(tensors)
  # Unsupported tensors are collected here and reported together at the end.
  sparse_ops = []
  ragged_tensors = []
  for tensor in tensor_list:
    # Tensors that already carry Keras metadata need no wrapping.
    if getattr(tensor, '_keras_history', None) is not None:
      continue
    if sparse_tensor.is_sparse(tensor):
      sparse_ops.append(tensor.op)
      continue
    if tf_utils.is_ragged(tensor):
      # Ragged tensors don't have an op property
      ragged_tensors.append(tensor)
      continue
    op = tensor.op  # The Op that created this Tensor.
    if op not in processed_ops:
      # Recursively set `_keras_history`.
      op_inputs = list(op.inputs)
      # Maps input position -> constant value (or the raw tensor, see below).
      constants = {}
      layer_inputs = []
      for i, op_input in enumerate(op_inputs):
        if uses_keras_history(op_input):
          layer_inputs.append(op_input)
        else:
          # Treat any value not originating from a `keras.Input` as
          # a constant. Variables cannot be supported.
          ds_with_session = (
              distribution_strategy_context.in_cross_replica_context() and
              not ops.executing_eagerly_outside_functions())
          using_xla = control_flow_util.GraphOrParentsInXlaContext(
              ops.get_default_graph())
          if ds_with_session or using_xla:
            # In Legacy Graph mode, evaluating here makes Session be
            # configured improperly. The downside of this is that saving
            # via `get_config` breaks, but SavedModel still works.
            constants[i] = op_input
          else:
            # Evaluate the input to a concrete value inside `init_scope`,
            # choosing the evaluation helper that matches the execution mode.
            with ops.init_scope():
              if ops.executing_eagerly_outside_functions():
                constants[i] = backend.eval_in_eager_or_function(op_input)
              else:
                constants[i] = backend.function([], op_input)([])
      layer_inputs = unnest_if_single_tensor(layer_inputs)
      # Wrap the producers of this op's Keras inputs before wrapping the op
      # itself.
      processed_ops, created_layers = _create_keras_history_helper(
          layer_inputs, processed_ops, created_layers)
      name = op.name
      node_def = op.node_def.SerializeToString()
      op_layer = base_layer.TensorFlowOpLayer(
          node_def, constants=constants, name=name)
      created_layers.append(op_layer)
      op_layer._set_connectivity_metadata(  # pylint: disable=protected-access
          args=(layer_inputs,),
          kwargs={},
          outputs=op.outputs)
      processed_ops.update([op])
  if sparse_ops or ragged_tensors:
    # The example below is embedded verbatim in the error message.
    lambda_example = """
weights_mult = lambda x: tf.sparse.sparse_dense_matmul(x, weights)
output = tf.keras.layers.Lambda(weights_mult)(input)
"""
    raise ValueError(
        'Tensorflow ops that generate ragged or sparse tensor '
        'outputs are currently not supported by Keras automatic '
        'op wrapping. Please wrap these ops in a Lambda layer: '
        '\n\n```\n{example}\n```\n'
        'Sparse ops encountered: {sparse_ops}\n'
        'Ragged tensors encountered: {ragged_tensors}\n'.format(
            example=lambda_example,
            sparse_ops=str(sparse_ops),
            ragged_tensors=str(ragged_tensors)))
  return processed_ops, created_layers
def unnest_if_single_tensor(input_tensors):
  """Unwraps a non-dict structure that holds exactly one element.

  Kept for compatibility with older configs.

  Arguments:
    input_tensors: An arbitrary structure of Tensors.

  Returns:
    The single flattened element when `input_tensors` is not a dict and
    flattens to exactly one element; otherwise `input_tensors` unchanged.
  """
  flattened = nest.flatten(input_tensors)
  # Dicts always pass through: the first layer is assumed to expect a dict
  # (as is the case with a DenseFeatures layer).
  if isinstance(input_tensors, dict) or len(flattened) != 1:
    return input_tensors
  return flattened[0]
def needs_keras_history(tensors, ignore_call_context=False):
  """Checks whether any Tensors must be wrapped in TensorFlowOpLayers.

  Inside a sublayer this never returns True, because sublayers do not need to
  create Keras History. Otherwise the result is True when at least one of
  `tensors` originates from a `keras.Input` and does not have
  `_keras_history` set.

  Arguments:
    tensors: An arbitrary nested structure of Tensors.
    ignore_call_context: Whether to skip the check for being inside a `call`
      context. This is `True` when creating KerasHistory inside `Node`, where
      Tensors are always known to be used with the Functional API.

  Returns:
    Bool, whether at least one Tensor needs to be wrapped.
  """
  flat_tensors = nest.flatten(tensors)
  if call_context().in_call and not ignore_call_context:
    return False
  missing_history = any(
      getattr(tensor, '_keras_history', None) is None
      for tensor in flat_tensors)
  if not missing_history:
    # Every tensor already has KerasHistory set.
    return False
  return uses_keras_history(tensors)
def is_in_keras_graph():
  """Returns whether execution is currently inside of a Keras graph."""
  active_context = call_context()
  return active_context.in_keras_graph
def is_in_eager_or_tf_function():
  """Returns whether running eagerly or inside of a tf.function."""
  if context.executing_eagerly():
    return True
  return is_in_tf_function()
def is_in_tf_function():
  """Returns whether execution is currently inside of a tf.function."""
  # V1 graph mode, or not inside any function at all: not a tf.function.
  if not (ops.executing_eagerly_outside_functions() and
          ops.inside_function()):
    return False
  # A Keras FuncGraph does not count as a tf.function.
  if is_in_keras_graph():
    return False
  # Neither does a v1 `wrap_function` FuncGraph.
  graph = ops.get_default_graph()
  in_wrapped_function = (
      getattr(graph, 'name', False) and
      graph.name.startswith('wrapped_function'))
  return not in_wrapped_function
def uses_keras_history(tensors):
  """Checks whether at least one Tensor originates from a `keras.Input`.

  A Tensor that originates from a `keras.Input` has, somewhere among the
  tensors it depends on, one with a `_keras_history` attribute attached. The
  dependencies are searched level by level through `tensor.op.inputs`.
  Tensors already found not to originate from a `keras.Input` carry a
  `_keras_history_checked` mark and are not searched again.

  Arguments:
    tensors: An arbitrary nested structure of Tensors.

  Returns:
    Bool, whether at least one Tensor originates from a `keras.Input`.
  """
  visited_ids = set()
  frontier = nest.flatten(tensors)

  while frontier:
    next_frontier = []
    for candidate in frontier:
      candidate_id = id(candidate)
      if candidate_id in visited_ids:
        continue
      visited_ids.add(candidate_id)

      if getattr(candidate, '_keras_history_checked', None) is not None:
        # Marked by an earlier search; skip it and everything behind it.
        continue
      if getattr(candidate, '_keras_history', None) is not None:
        return True

      try:
        next_frontier.extend(candidate.op.inputs)
      except AttributeError:
        # In case `candidate` is a Variable created in an Eager context.
        pass
    frontier = next_frontier

  # Nothing was found: remember that for these Tensors so that they are not
  # searched again, for performance reasons.
  mark_checked(tensors)
  return False
def mark_checked(tensors):
  """Flags these Tensors so that they are not tracked.

  Setting `_keras_history_checked` prevents Layers from attempting to create
  TensorFlowOpLayers for these Tensors.

  Arguments:
    tensors: An arbitrary structure of Tensors.
  """
  for tensor in nest.flatten(tensors):
    tensor._keras_history_checked = True  # pylint: disable=protected-access
def call_context():
  """Returns the currently active `CallContext`, creating it on first use."""
  active = getattr(_call_context, 'call_context', None)
  if active is not None:
    return active
  # Nothing stored on `_call_context` yet: create the context and cache it.
  active = _call_context.call_context = CallContext()
  return active
# Hand the `call_context` accessor to `control_flow_util_v2` at import time.
# NOTE(review): presumably this lets control-flow code query the active Keras
# call context without importing Keras -- confirm in control_flow_util_v2.
control_flow_util_v2._register_keras_layer_context_function(call_context)  # pylint: disable=protected-access
class CallContext(object):
  """Tracks the state of the Layer/Model `call` that is currently running.

  Attributes:
    in_call: Whether currently inside the `call` of a Layer.
    layer: The `Layer` whose `call` is currently active.
    inputs: The inputs to the currently active `Layer`.
    build_graph: Whether currently inside a Graph or FuncGraph.
    training: Whether currently executing in training or inference mode.
    saving: Whether currently saving to SavedModel.
    frozen: Whether currently executing inside a `Layer` with `trainable` set
      to `False`.
    in_keras_graph: Whether executing inside the Keras Graph.
  """

  def __init__(self):
    # `in_call` is kept directly on the instance instead of inside `_state`:
    # it is the most-read attribute and reading it is on the hot path.
    self.in_call = False
    self._state = dict(
        layer=None,
        inputs=None,
        build_graph=False,
        training=None,
        saving=None)
    # TODO(b/150169018): This logic can be replaced after the Functional API
    # refactor.
    self._in_keras_graph = False

  def enter(self, layer, inputs, build_graph, training, saving=None):
    """Pushes a Layer with its inputs and state onto the current call context.

    Arguments:
      layer: The `Layer` whose `call` is currently active.
      inputs: The inputs to the currently active `Layer`.
      build_graph: Whether currently inside a Graph or FuncGraph.
      training: Whether currently executing in training or inference mode.
      saving: Whether currently saving to SavedModel.

    Returns:
      Context manager.
    """
    new_state = dict(
        layer=layer,
        inputs=inputs,
        build_graph=build_graph,
        training=training,
        saving=saving)
    return CallContextManager(self, new_state)

  @property
  def layer(self):
    return self._state['layer']

  @property
  def inputs(self):
    return self._state['inputs']

  @property
  def build_graph(self):
    return self._state['build_graph']

  @property
  def training(self):
    return self._state['training']

  @property
  def saving(self):
    return self._state['saving']

  @property
  def frozen(self):
    # With no active layer nothing can be frozen.
    active_layer = self._state['layer']
    return bool(active_layer) and not active_layer.trainable

  @property
  def in_keras_graph(self):
    # Returns True even if in a subgraph of the Keras graph, such as those
    # created by control flow ops.
    if context.executing_eagerly():
      return False
    if self._in_keras_graph:
      return self._in_keras_graph
    return getattr(backend.get_graph(), 'name', None) == 'keras_graph'
class CallContextManager(object):
  """Context manager that swaps a new state into a `CallContext`.

  On entry the previous `in_call` flag and state of the `CallContext` are
  saved and replaced; on exit they are put back.
  """

  def __init__(self, call_ctx, state):
    self._call_ctx = call_ctx
    self._state = state
    self._build_graph = state['build_graph']

  def __enter__(self):
    ctx = self._call_ctx
    # Remember what is being replaced so that `__exit__` can restore it.
    self._prev_in_call, self._prev_state = ctx.in_call, ctx._state
    ctx.in_call = True
    ctx._state = self._state

    # TODO(b/150169018): This logic can be removed after the Functional API
    # refactor.
    if self._build_graph:
      self._prev_in_keras_graph = ctx._in_keras_graph
      # Once set, the flag stays set; the graph is only inspected while it is
      # still unset.
      if not ctx._in_keras_graph:
        graph_name = getattr(backend.get_graph(), 'name', None)
        ctx._in_keras_graph = graph_name == 'keras_graph'

  def __exit__(self, *exc_info):
    ctx = self._call_ctx
    ctx.in_call, ctx._state = self._prev_in_call, self._prev_state
    if self._build_graph:
      ctx._in_keras_graph = self._prev_in_keras_graph
def training_arg_passed_to_call(argspec, args, kwargs):
  """Returns whether `training` was passed to `__call__` with a non-None value.

  Arguments:
    argspec: Argspec of the `call` method; only its `args` list is read.
    args: Positional arguments passed after `inputs`.
    kwargs: Keyword arguments passed by the caller.

  Returns:
    True if `training` was supplied, positionally or by keyword, and is not
    `None`.
  """
  # `argspec.args` starts with ['self', 'inputs']; the remaining names line up
  # with the positional `args`.
  positional_names = argspec.args[2:]
  supplied = dict(zip(positional_names, args))
  supplied.update(kwargs)
  return supplied.get('training') is not None
def is_subclassed(layer):
  """Returns True if the object is a subclassed layer or subclassed model.

  An object counts as subclassed when its `__module__` mentions neither
  'keras.engine' nor 'keras.layers'.
  """
  module_name = layer.__module__
  return ('keras.engine' not in module_name and
          'keras.layers' not in module_name)
def from_saved_model(layer):
  """Returns whether the layer is loaded from a SavedModel.

  The check looks for 'keras.saving.saved_model' in the layer's `__module__`.
  """
  return 'keras.saving.saved_model' in layer.__module__
def check_graph_consistency(tensor=None, method='add_loss', force_raise=False):
  """Checks that tensors passed to `add_*` method match the Keras graph.

  When one of the `add_*` method is called inside a V2 conditional branch,
  the underlying tensor gets created in a FuncGraph managed by control_flow_v2.
  We need to raise clear error messages in such cases.

  Nothing happens unless `force_raise` is set or, when executing eagerly
  outside functions, `tensor` has a `graph` attribute whose
  `is_control_flow_graph` is truthy.

  Arguments:
    tensor: Tensor to check, or `False` if it is known that an error
      should be raised.
    method: Caller method, one of {'add_metric', 'add_loss', 'add_update',
      'activity_regularizer'}. Any value other than 'add_metric', 'add_loss'
      and 'activity_regularizer' gets the `add_update` example in the error
      message.
    force_raise: If an error should be raised regardless of `tensor`.

  Raises:
    RuntimeError: In case of an out-of-graph tensor.
  """
  # `force_raise` is tested first so that nothing else is evaluated when the
  # caller already knows an error must be raised.
  in_control_flow_branch = force_raise or (
      ops.executing_eagerly_outside_functions() and
      hasattr(tensor, 'graph') and tensor.graph.is_control_flow_graph)
  if not in_control_flow_branch:
    return

  # The code samples below end up verbatim in the error message, so their
  # indentation has to be valid, readable Python.
  if method == 'activity_regularizer':
    bad_example = """
      class TestModel(tf.keras.Model):
        def __init__(self):
          super(TestModel, self).__init__(name='test_model')
          self.dense = tf.keras.layers.Dense(2, activity_regularizer='l2')
        def call(self, x, training=None):
          if training:
            return self.dense(x)
          else:
            return self.dense(x)
      """
    correct_example = """
      class TestModel(tf.keras.Model):
        def __init__(self):
          super(TestModel, self).__init__(name='test_model')
          self.dense = tf.keras.layers.Dense(2, activity_regularizer='l2')
        def call(self, x, training=None):
          return self.dense(x)
      """
    raise RuntimeError(
        'You are using a layer with `activity_regularizer` in a control flow '
        'branch, e.g.:\n{bad_example}\nThis is currently not supported. '
        'Please move your call to the layer with `activity_regularizer` out '
        'of the control flow branch, e.g.:\n{correct_example}\n'
        'You can also resolve this by marking your outer model/layer dynamic'
        ' (eager-only) by passing `dynamic=True` to the layer constructor. '
        'Any kind of control flow is supported with dynamic layers. '
        'Note that using `dynamic=True` requires you to implement static '
        'shape inference in the `compute_output_shape(input_shape)` '
        'method.'.format(
            bad_example=bad_example, correct_example=correct_example))

  if method == 'add_metric':
    bad_example = """
      def call(self, inputs, training=None):
        if training:
          metric = compute_metric(inputs)
          self.add_metric(metric, name='my_metric', aggregation='mean')
        return inputs
      """
    correct_example = """
      def call(self, inputs, training=None):
        if training:
          metric = compute_metric(inputs)
        else:
          metric = 0.
        self.add_metric(metric, name='my_metric', aggregation='mean')
        return inputs
      """
  elif method == 'add_loss':
    bad_example = """
      def call(self, inputs, training=None):
        if training:
          loss = compute_loss(inputs)
          self.add_loss(loss)
        return inputs
      """
    correct_example = """
      def call(self, inputs, training=None):
        if training:
          loss = compute_loss(inputs)
        else:
          loss = 0.
        self.add_loss(loss)
        return inputs
      """
  else:
    bad_example = """
      def call(self, inputs, training=None):
        if training:
          self.add_update(self.w.assign_add(1))
        return inputs
      """
    correct_example = """
      def call(self, inputs, training=None):
        if training:
          increment = 1
        else:
          increment = 0
        self.add_update(self.w.assign_add(increment))
        return inputs
      """
  raise RuntimeError(
      'You are using the method `{method}` in a control flow branch '
      'in your layer, e.g.:\n{bad_example}\n'
      'This is not currently supported. '
      'Please move your call to {method} out of the control flow branch, '
      'e.g.:\n{correct_example}\n'
      'You can also resolve this by marking your layer '
      'as dynamic (eager-only) by passing '
      '`dynamic=True` to the layer constructor. '
      'Any kind of control flow is supported with dynamic layers. '
      'Note that using `dynamic=True` requires you '
      'to implement static shape inference '
      'in the `compute_output_shape(input_shape)` method.'.format(
          method=method,
          bad_example=bad_example,
          correct_example=correct_example))
def mark_as_return(outputs, acd):
  """Marks `outputs` as the return values for automatic control deps.

  Arguments:
    outputs: An arbitrary structure; entries that are not tensors are left
      untouched.
    acd: Object whose `mark_as_return` method is applied to every tensor and
      to every non-None `_keras_mask`.

  Returns:
    A structure shaped like `outputs` holding the marked tensors.
  """

  def _mark_as_return(tensor):
    """Marks `tensor` as the return value for automatic control deps."""
    if not tensor_util.is_tensor(tensor):
      return tensor

    # pylint: disable=protected-access
    marked = acd.mark_as_return(tensor)

    # Carry the Keras mask over to the marked tensor, marking it as well.
    if getattr(tensor, '_keras_mask', None) is None:
      marked._keras_mask = None
    else:
      marked._keras_mask = acd.mark_as_return(tensor._keras_mask)

    # Handle TensorFlow Probability attached metadata.
    # TODO(b/132076537): Remove this once TFP uses `CompositeTensor`.
    if getattr(tensor, '_tfp_distribution', None) is not None:
      marked._tfp_distribution = tensor._tfp_distribution
    # pylint: enable=protected-access

    return marked

  return nest.map_structure(_mark_as_return, outputs)
# Tri-state override for the Keras V2 dtype behavior. `None` means it was never
# set explicitly, in which case `v2_dtype_behavior_enabled()` falls back to
# `tf2.enabled()`. `enable_v2_dtype_behavior()` / `disable_v2_dtype_behavior()`
# set it to `True` / `False`.
V2_DTYPE_BEHAVIOR = None
@keras_export(v1=['keras.layers.enable_v2_dtype_behavior'])
def enable_v2_dtype_behavior():
  """Turns on the V2 dtype behavior for Keras layers.

  TensorFlow 2 enables the V2 dtype behavior by default, so calling this only
  matters after `tf.compat.v1.disable_v2_behavior` has been called. Mixed
  precision requires the V2 dtype behavior, so this function makes it possible
  to use mixed precision in Keras layers even when `disable_v2_behavior` has
  been called.

  With the behavior enabled, a Keras layer's dtype defaults to floatx
  (typically float32) rather than None, and layers automatically cast
  floating-point inputs to the layer's dtype.

  >>> x = tf.ones((4, 4, 4, 4), dtype='float64')
  >>> layer = tf.keras.layers.Conv2D(filters=4, kernel_size=2)
  >>> print(layer.dtype)  # float32 since V2 dtype behavior is enabled
  float32
  >>> y = layer(x)  # Layer casts inputs since V2 dtype behavior is enabled
  >>> print(y.dtype.name)
  float32

  Layer authors can opt their layer out of the automatic input casting by
  passing `autocast=False` to the base Layer's constructor. That switches off
  only the autocasting half of the V2 behavior for that layer; the dtype still
  defaults to floatx.

  When a global `tf.keras.mixed_precision.experimental.Policy` is set, a Keras
  layer's dtype defaults to the global policy instead of floatx, and layers
  automatically cast inputs to the policy's compute_dtype.
  """
  # Record the explicit opt-in on the module-level override.
  global V2_DTYPE_BEHAVIOR
  V2_DTYPE_BEHAVIOR = True
@keras_export(v1=['keras.layers.disable_v2_dtype_behavior'])
def disable_v2_dtype_behavior():
  """Turns off the V2 dtype behavior for Keras layers.

  See `tf.compat.v1.keras.layers.enable_v2_dtype_behavior`.
  """
  # Record the explicit opt-out on the module-level override.
  global V2_DTYPE_BEHAVIOR
  V2_DTYPE_BEHAVIOR = False
def v2_dtype_behavior_enabled():
  """Returns True if the V2 dtype behavior is enabled.

  An explicit setting in `V2_DTYPE_BEHAVIOR` wins; when it is still `None` the
  answer comes from `tf2.enabled()`.
  """
  return tf2.enabled() if V2_DTYPE_BEHAVIOR is None else V2_DTYPE_BEHAVIOR
class TrackableWeightHandler(object):
  """Keras wrapper that saves and restores a `tracking.Trackable` object.

  Trackables are handled in both V1 and V2 modes, so that they can be saved
  and restored with the correct data and without adding additional ops on
  every save.

  Attributes:
    trackable: The trackable to wrap.
    num_tensors: The number of tensors that this trackable requires for saving.
  """

  def __init__(self, trackable):
    if not isinstance(trackable, tracking.Trackable):
      raise ValueError('%s is not a Trackable object.' % (trackable,))
    self._trackable = trackable
    self._distribute_strategy = distribution_strategy_context.get_strategy()

    # TODO(b/141682913): Figure out why this is private and fix it.
    saveables = trackable._gather_saveables_for_checkpoint().values()  # pylint: disable=protected-access
    if len(saveables) != 1:
      raise ValueError('Only Trackables with one Saveable are supported.')
    saveable = next(iter(saveables))

    if ops.executing_eagerly_outside_functions():
      # In eager mode the Trackable's saveable() callable must not be
      # evaluated for real until data export time. Calling it is safe any
      # number of times though, so it is called once here just to learn how
      # many tensors this Trackable produces.
      self._saveable = saveable
      self._num_tensors = len(self._saveable().specs)

      def _restore_eagerly(weights):
        return self._saveable().restore(weights, None)

      def _read_eagerly():
        return [spec.tensor for spec in self._saveable().specs]

      self._setter = _restore_eagerly
      self._getter = _read_eagerly
    else:
      # In Graph mode the Saveable has to be evaluated exactly once and the
      # resulting restore graph cached. Otherwise every set_weights() call
      # would add new assignment ops to the graph.
      self._placeholder_tensors = []
      self._saveable = saveable()
      self._num_tensors = len(self._saveable.specs)
      for spec in self._saveable.specs:
        saved_tensor = spec.tensor
        placeholder = array_ops.placeholder(
            saved_tensor.dtype, saved_tensor.shape)
        self._placeholder_tensors.append(placeholder)
      self._assign_op = self._saveable.restore(self._placeholder_tensors, None)

      def _read_from_graph():
        return [spec.tensor for spec in self._saveable.specs]

      self._setter = self._set_weights_v1
      self._getter = _read_from_graph

  @property
  def num_tensors(self):
    return self._num_tensors

  def set_weights(self, weights):
    """Restores the trackable from `weights`, after checking their count."""
    received = len(weights)
    if received != self._num_tensors:
      message = ('Weight handler for trackable %s received the wrong number '
                 'of weights: expected %s, got %s.')
      raise ValueError(
          message % (self._trackable, self._num_tensors, received))
    self._setter(weights)

  def get_tensors(self):
    """Returns the tensors of the wrapped trackable's save specs."""
    return self._getter()

  def _set_weights_v1(self, weights):
    """Feeds `weights` through the cached placeholders and runs the assign op."""
    feed_dict = {
        self._placeholder_tensors[position]: value
        for position, value in enumerate(weights)
    }
    backend.get_session().run(self._assign_op, feed_dict)
def no_ragged_support(inputs, layer_name):
  """Raises if any of `inputs` is a `RaggedTensor`.

  Arguments:
    inputs: An arbitrary structure of inputs to inspect.
    layer_name: Name of the layer, used in the error message.

  Raises:
    ValueError: If at least one flattened input is a `RaggedTensor`.
  """
  for candidate in nest.flatten(inputs):
    if isinstance(candidate, ragged_tensor.RaggedTensor):
      raise ValueError('Layer %s does not support RaggedTensors as input. '
                       'Inputs received: %s. You can try converting your '
                       'input to an uniform tensor.' % (layer_name, inputs))
def is_split_variable(v):
  """Returns True if `v` is either a PartitionedVariable or a ShardedVariable.

  The check is duck-typed: `v` qualifies when it has a `_variable_list` or a
  `_variables` attribute.
  """
  return any(hasattr(v, attr) for attr in ('_variable_list', '_variables'))
def has_weights(obj):
  """Returns True if `obj` is an instance whose type defines weight attributes.

  Both `trainable_weights` and `non_trainable_weights` are looked up on
  `type(obj)`, not on the instance. Classes themselves never qualify.
  """
  obj_type = type(obj)
  type_defines_weights = all(
      hasattr(obj_type, attr)
      for attr in ('trainable_weights', 'non_trainable_weights'))
  return type_defines_weights and not isinstance(obj, type)
# TODO(kathywu): This is a temporary hack. A network of layers revived from
# SavedModel only has losses on its top-level layer. In eager mode that is a
# problem, because the child layers may hold graph losses, so `model.losses`
# would return a mix of Eager and graph tensors. The workaround: whenever
# eager losses are added to one layer, eager losses are added to all of its
# child layers too, so that `.losses` only returns eager losses.
REVIVED_LOSS_PLACEHOLDER = (
    "This layer's losses have been added to the parent layer.")