| # Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Ops to use variables as resources.""" |
| |
| # pylint: disable=g-bad-name |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import contextlib |
| import functools |
| import weakref |
| |
| import numpy as np |
| |
| from tensorflow.core.framework import attr_value_pb2 |
| from tensorflow.core.framework import variable_pb2 |
| from tensorflow.python.client import pywrap_tf_session |
| from tensorflow.python.eager import context |
| from tensorflow.python.eager import tape |
| from tensorflow.python.framework import auto_control_deps_utils as acd |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import cpp_shape_inference_pb2 |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import errors |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import tensor_shape |
| from tensorflow.python.framework import tensor_spec |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import gen_array_ops |
| from tensorflow.python.ops import gen_resource_variable_ops |
| from tensorflow.python.ops import gen_state_ops |
| from tensorflow.python.ops import handle_data_util |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops import state_ops |
| from tensorflow.python.ops import variables |
| # go/tf-wildcard-import |
| # pylint: disable=wildcard-import |
| from tensorflow.python.ops.gen_resource_variable_ops import * |
| # pylint: enable=wildcard-import |
| from tensorflow.python.training.tracking import base as trackable |
| from tensorflow.python.types import core |
| from tensorflow.python.util import _pywrap_utils |
| from tensorflow.python.util import compat |
| from tensorflow.python.util.deprecation import deprecated |
| |
| acd.register_read_only_resource_op("ReadVariableOp") |
| acd.register_read_only_resource_op("VariableShape") |
| acd.register_read_only_resource_op("ResourceGather") |
| acd.register_read_only_resource_op("ResourceGatherNd") |
| acd.register_read_only_resource_op("_ReadVariablesOp") |
| |
| |
| # TODO(allenl): Remove this alias and migrate callers. |
| get_resource_handle_data = handle_data_util.get_resource_handle_data |
| |
| |
| def get_eager_safe_handle_data(handle): |
| """Get the data handle from the Tensor `handle`.""" |
| assert isinstance(handle, ops.Tensor) |
| |
| if isinstance(handle, ops.EagerTensor): |
| return handle._handle_data # pylint: disable=protected-access |
| else: |
| return get_resource_handle_data(handle) |
| |
| |
| def _set_handle_shapes_and_types(tensor, handle_data, graph_mode): |
| """Sets the shape inference result HandleData on tensor. |
| |
| Args: |
| tensor: A `Tensor` or `EagerTensor`. |
| handle_data: A `CppShapeInferenceResult.HandleData`. |
| graph_mode: A python bool. |
| """ |
| tensor._handle_data = handle_data # pylint: disable=protected-access |
| if not graph_mode: |
| return |
| |
| # Not an EagerTensor, so a graph tensor. |
| shapes, types = zip(*[(pair.shape, pair.dtype) |
| for pair in handle_data.shape_and_type]) |
| ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes] |
| shapes = [ |
| [d.size for d in s.dim] # pylint: disable=g-complex-comprehension |
| if not s.unknown_rank else None for s in shapes |
| ] |
| pywrap_tf_session.TF_GraphSetOutputHandleShapesAndTypes_wrapper( |
| tensor._op._graph._c_graph, # pylint: disable=protected-access |
| tensor._as_tf_output(), # pylint: disable=protected-access |
| shapes, |
| ranks, |
| types) |
| |
| |
| def _combine_handle_data(handle, initial_value): |
| """Concats HandleData from tensors `handle` and `initial_value`. |
| |
| Args: |
| handle: A `Tensor` of dtype `resource`. |
| initial_value: A `Tensor`. |
| |
| Returns: |
| A `CppShapeInferenceResult.HandleData`. If `initial_value` has dtype |
| `variant`, the `HandleData` contains the concatenation of the shape_and_type |
| from both `handle` and `initial_value`. |
| |
| Raises: |
RuntimeError: If `handle`, which was returned by VarHandleOp, either has
no handle data or its `handle_data.shape_and_type` does not have length 1.
| """ |
| assert handle.dtype == dtypes.resource |
| |
| variable_handle_data = get_eager_safe_handle_data(handle) |
| |
| if initial_value.dtype != dtypes.variant: |
| return variable_handle_data |
| |
| extra_handle_data = get_eager_safe_handle_data(initial_value) |
| if extra_handle_data is not None and extra_handle_data.is_set: |
| if (variable_handle_data is None or not variable_handle_data.is_set or |
| len(variable_handle_data.shape_and_type) != 1): |
| raise RuntimeError( |
| "Expected VarHandleOp to return a length==1 shape_and_type, " |
| "but saw: '%s'" % (variable_handle_data,)) |
| variable_handle_data.shape_and_type.extend(extra_handle_data.shape_and_type) |
| return variable_handle_data |
| |
| |
| def _variable_handle_from_shape_and_dtype(shape, |
| dtype, |
| shared_name, |
| name, |
| graph_mode, |
| initial_value=None): |
| """Create a variable handle, copying in handle data from `initial_value`.""" |
| container = ops.get_default_graph()._container # pylint: disable=protected-access |
| if container is None: |
| container = "" |
| shape = tensor_shape.as_shape(shape) |
| dtype = dtypes.as_dtype(dtype) |
| if not graph_mode: |
| if shared_name is not None: |
| raise errors.InternalError( |
| "Using an explicit shared_name is not supported executing eagerly.") |
| shared_name = context.shared_name() |
| |
| handle = gen_resource_variable_ops.var_handle_op( |
| shape=shape, |
| dtype=dtype, |
| shared_name=shared_name, |
| name=name, |
| container=container) |
| if initial_value is None: |
| initial_value = handle |
| if graph_mode: |
| full_handle_data = _combine_handle_data(handle, initial_value) |
| _set_handle_shapes_and_types(handle, full_handle_data, graph_mode) |
| return handle |
| else: |
| handle_data = cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData() |
| handle_data.is_set = True |
| handle_data.shape_and_type.append( |
| cpp_shape_inference_pb2.CppShapeInferenceResult.HandleShapeAndType( |
| shape=shape.as_proto(), dtype=dtype.as_datatype_enum)) |
| |
| if initial_value is not None and initial_value.dtype == dtypes.variant: |
| extra_handle_data = get_eager_safe_handle_data(initial_value) |
| if extra_handle_data is not None and extra_handle_data.is_set: |
| if (not handle_data.is_set or len(handle_data.shape_and_type) != 1): |
| raise RuntimeError( |
| "Expected VarHandleOp to return a length==1 shape_and_type, " |
| "but saw: '%s'" % (handle_data,)) |
| handle_data.shape_and_type.extend(extra_handle_data.shape_and_type) |
| |
| _set_handle_shapes_and_types(handle, handle_data, graph_mode) |
| return handle |
| |
| |
| def eager_safe_variable_handle(initial_value, shape, shared_name, name, |
| graph_mode): |
| """Creates a variable handle with information to do shape inference. |
| |
| The dtype is read from `initial_value` and stored in the returned |
| resource tensor's handle data. |
| |
| If `initial_value.dtype == tf.variant`, we additionally extract the handle |
| data (if any) from `initial_value` and append it to the `handle_data`. |
| In this case, the returned tensor's handle data is in the form |
| |
| ``` |
| is_set: true |
| shape_and_type { |
| shape { |
| // initial_value.shape |
| } |
| dtype: DT_VARIANT |
| } |
| shape_and_type { |
| // handle_data(initial_value).shape_and_type[0] |
| } |
| shape_and_type { |
| // handle_data(initial_value).shape_and_type[1] |
| } |
| ... |
| ``` |
| |
| Ops that read from this tensor, such as `ReadVariableOp` and |
| `AssignVariableOp`, know that `handle_data(handle).shape_and_type[1:]` |
| correspond to the handle data of the variant(s) stored in the Variable. |
| |
| Args: |
| initial_value: A `Tensor`. |
| shape: The shape of the handle data. Can be `TensorShape(None)` (i.e. |
| unknown shape). |
| shared_name: A string. |
| name: A string. |
| graph_mode: A python bool. |
| |
| Returns: |
| The handle, a `Tensor` of type `resource`. |
| """ |
| dtype = initial_value.dtype.base_dtype |
| return _variable_handle_from_shape_and_dtype(shape, dtype, shared_name, name, |
| graph_mode, initial_value) |
| |
| |
| @contextlib.contextmanager |
| def _handle_graph(handle): |
| # Note: might have an eager tensor but not be executing eagerly when building |
| # functions. |
| if (context.executing_eagerly() or isinstance(handle, ops.EagerTensor) or |
| ops.has_default_graph()): |
| yield |
| else: |
| with handle.graph.as_default(): |
| yield |
| |
| |
| class EagerResourceDeleter(object): |
| """An object which cleans up a resource handle. |
| |
| An alternative to defining a __del__ method on an object. The intended use is |
| that ResourceVariables or other objects with resource handles will maintain a |
| single reference to this object. When the parent object is collected, this |
| object will be too. Even if the parent object is part of a reference cycle, |
| the cycle will be collectable. |
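
A minimal sketch of the intended ownership pattern (the owner class below is
purely illustrative, not part of this module; eager execution is assumed):

```python
class _HandleOwner(object):
  # Hypothetical owner of an eager resource handle.

  def __init__(self, handle):
    self._handle = handle
    # Keep the only reference to the deleter. When this owner is garbage
    # collected, the deleter is collected too and the resource is destroyed,
    # even if the owner participates in a reference cycle.
    self._handle_deleter = EagerResourceDeleter(
        handle=handle, handle_device=handle.device)
```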
| """ |
| |
| __slots__ = ["_handle", "_handle_device", "_context"] |
| |
| def __init__(self, handle, handle_device): |
| if not isinstance(handle, ops.Tensor): |
| raise ValueError( |
| ("Passed handle=%s to EagerResourceDeleter. Was expecting a handle " |
| "Tensor." % (handle,))) |
| self._handle = handle |
| self._handle_device = handle_device |
| # This is held since the __del__ function runs an op, and if the context() |
| # is collected before this object, there will be a segfault when running the |
| # op. |
| self._context = context.context() |
| |
| def __del__(self): |
| # Resources follow object-identity when executing eagerly, so it is safe to |
| # delete the resource we have a handle to. |
| try: |
| # A packed EagerTensor doesn't own any resource. |
| if isinstance(self._handle, ops.EagerTensor) and self._handle.is_packed: |
| return |
| # This resource was created in eager mode. However, this destructor may be |
| # running in graph mode (especially during unit tests). To clean up |
| # successfully, we switch back into eager mode temporarily. |
| with context.eager_mode(): |
| with ops.device(self._handle_device): |
| gen_resource_variable_ops.destroy_resource_op( |
| self._handle, ignore_lookup_error=True) |
| except TypeError: |
| # Suppress some exceptions, mainly for the case when we're running on |
| # module deletion. Things that can go wrong include the context module |
| # already being unloaded, self._handle._handle_data no longer being |
| # valid, and so on. Printing warnings in these cases is silly |
| # (exceptions raised from __del__ are printed as warnings to stderr). |
| pass # 'NoneType' object is not callable when the handle has been |
| # partially unloaded. |
| except AttributeError: |
| pass # 'NoneType' object has no attribute 'eager_mode' when context has |
| # been unloaded. Will catch other module unloads as well. |
| |
| |
| def shape_safe_assign_variable_handle(handle, shape, value, name=None): |
| """Helper that checks shape compatibility and assigns variable.""" |
| with _handle_graph(handle): |
| value_tensor = ops.convert_to_tensor(value) |
| shape.assert_is_compatible_with(value_tensor.shape) |
| return gen_resource_variable_ops.assign_variable_op( |
| handle, value_tensor, name=name) |
| |
| |
| def _maybe_set_handle_data(dtype, handle, tensor): |
| if dtype == dtypes.variant: |
| # For DT_VARIANT types, the handle's shape_and_type[1:] stores the |
| # variant's handle data. Extract it. |
| handle_data = get_eager_safe_handle_data(handle) |
| if handle_data.is_set and len(handle_data.shape_and_type) > 1: |
| tensor._handle_data = ( # pylint: disable=protected-access |
| cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData( |
| is_set=True, shape_and_type=handle_data.shape_and_type[1:])) |
| |
| |
| def variable_accessed(variable): |
| """Records that `variable` was accessed for the tape and FuncGraph.""" |
| if hasattr(ops.get_default_graph(), "watch_variable"): |
| ops.get_default_graph().watch_variable(variable) |
| if variable.trainable: |
| tape.variable_accessed(variable) |
| |
| |
| class BaseResourceVariable(variables.VariableV1, core.Tensor): |
| """A python variable from an existing handle.""" |
| |
# TODO(wangpeng): Deprecate `constraint` when callers no longer pass it in.
| def __init__( # pylint: disable=super-init-not-called |
| self, |
| trainable=None, |
| shape=None, |
| dtype=None, |
| handle=None, |
| constraint=None, |
| synchronization=None, |
| aggregation=None, |
| distribute_strategy=None, |
| name=None, |
| unique_id=None, |
| handle_name=None, |
| graph_element=None, |
| initial_value=None, |
| initializer_op=None, |
| is_initialized_op=None, |
| cached_value=None, |
| save_slice_info=None, |
| handle_deleter=None, |
| caching_device=None, |
| **unused_kwargs): |
| """Creates a variable from a handle. |
| |
| Args: |
| trainable: If `True`, GradientTapes automatically watch uses of this |
| Variable. |
| shape: The variable's shape. |
| dtype: The variable's dtype. |
| handle: The variable's handle |
| constraint: An optional projection function to be applied to the variable |
| after being updated by an `Optimizer` (e.g. used to implement norm |
| constraints or value constraints for layer weights). The function must |
| take as input the unprojected Tensor representing the value of the |
| variable and return the Tensor for the projected value (which must have |
| the same shape). Constraints are not safe to use when doing asynchronous |
| distributed training. |
synchronization: Indicates when a distributed variable will be
| aggregated. Accepted values are constants defined in the class |
| `tf.VariableSynchronization`. By default the synchronization is set to |
| `AUTO` and the current `DistributionStrategy` chooses when to |
| synchronize. |
| aggregation: Indicates how a distributed variable will be aggregated. |
| Accepted values are constants defined in the class |
| `tf.VariableAggregation`. |
| distribute_strategy: The distribution strategy this variable was created |
| under. |
| name: The name for this variable. |
| unique_id: Internal. Unique ID for this variable's handle. |
| handle_name: The name for the variable's handle. |
| graph_element: Optional, required only in session.run-mode. Pre-created |
| tensor which reads this variable's value. |
| initial_value: Optional. Variable's initial value. |
| initializer_op: Operation which assigns the variable's initial value. |
| is_initialized_op: Pre-created operation to check whether this variable is |
| initialized. |
cached_value: Pre-created operation to read this variable on a specific
device.
| save_slice_info: Metadata for variable partitioning. |
| handle_deleter: EagerResourceDeleter responsible for cleaning up the |
| handle. |
| caching_device: Optional device string or function describing where the |
| Variable should be cached for reading. Defaults to the Variable's |
| device. If not `None`, caches on another device. Typical use is to |
| cache on the device where the Ops using the Variable reside, to |
| deduplicate copying through `Switch` and other conditional statements. |
| """ |
| with ops.init_scope(): |
| self._in_graph_mode = not context.executing_eagerly() |
| synchronization, aggregation, trainable = ( |
| variables.validate_synchronization_aggregation_trainable( |
| synchronization, aggregation, trainable, name)) |
| self._trainable = trainable |
| self._synchronization = synchronization |
| self._aggregation = aggregation |
| self._save_slice_info = save_slice_info |
| self._initial_value = initial_value |
| self._initializer_op = initializer_op |
| self._is_initialized_op = is_initialized_op |
| self._graph_element = graph_element |
| self._caching_device = caching_device |
| self._cached_value = cached_value |
| self._distribute_strategy = distribute_strategy |
| # Store the graph key so optimizers know how to only retrieve variables from |
| # this graph. Guaranteed to be the same as the eager graph_key. |
| self._graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access |
| self._shape = tensor_shape.as_shape(shape) |
| self._dtype = dtypes.as_dtype(dtype) |
| self._handle = handle |
| self._unique_id = unique_id |
| self._handle_name = handle_name + ":0" |
| self._constraint = constraint |
| # After the handle has been created, set up a way to clean it up when |
| # executing eagerly. We'll hold the only reference to the deleter, so that |
| # when this object is garbage collected the deleter will be too. This |
| # means ResourceVariables can be part of reference cycles without those |
| # cycles being uncollectable. |
| if not self._in_graph_mode: |
| if handle_deleter is None: |
| handle_deleter = EagerResourceDeleter( |
| handle=self._handle, handle_device=self._handle.device) |
| self._handle_deleter = handle_deleter |
| self._cached_shape_as_list = None |
| |
| def __repr__(self): |
| if context.executing_eagerly() and not self._in_graph_mode: |
| # If we cannot read the value for any reason, still produce a __repr__. |
| try: |
| value_text = ops.numpy_text(self.read_value(), is_repr=True) |
| except: # pylint: disable=bare-except |
| value_text = "<unavailable>" |
| |
| return "<tf.Variable '%s' shape=%s dtype=%s, numpy=%s>" % ( |
| self.name, self.get_shape(), self.dtype.name, value_text) |
| else: |
| return "<tf.Variable '%s' shape=%s dtype=%s>" % ( |
| self.name, self.get_shape(), self.dtype.name) |
| |
| @contextlib.contextmanager |
| def _assign_dependencies(self): |
| """Makes assignments depend on the cached value, if any. |
| |
This prevents undefined behavior with reads that are not ordered with
respect to writes.
| |
| Yields: |
| None. |
| """ |
| if self._cached_value is not None: |
| with ops.control_dependencies([self._cached_value]): |
| yield |
| else: |
| yield |
| |
| def __array__(self): |
| """Allows direct conversion to a numpy array. |
| |
| >>> np.array(tf.Variable([1.0])) |
| array([1.], dtype=float32) |
| |
| Returns: |
| The variable value as a numpy array. |
| """ |
| # You can't return `self.numpy()` here because for scalars |
| # that raises: |
| # ValueError: object __array__ method not producing an array |
| # Even `self.read_value().__array__()` and `self.read_value()._numpy()` give |
| # the same error. The `EagerTensor` class must be doing something behind the |
| # scenes to make `np.array(tf.constant(1))` work. |
| return np.asarray(self.numpy()) |
| |
| def __nonzero__(self): |
| return self.__bool__() |
| |
| def __bool__(self): |
| return bool(self.read_value()) |
| |
| def __copy__(self): |
| return self |
| |
| def __deepcopy__(self, memo): |
| if not context.executing_eagerly(): |
| raise NotImplementedError( |
| "__deepcopy__() is only available when eager execution is enabled.") |
| copied_variable = ResourceVariable( |
| initial_value=self.read_value(), |
| trainable=self._trainable, |
| constraint=self._constraint, |
| dtype=self._dtype, |
| name=self._shared_name, |
| distribute_strategy=self._distribute_strategy) |
| memo[self._unique_id] = copied_variable |
| return copied_variable |
| |
| @property |
| def dtype(self): |
| """The dtype of this variable.""" |
| return self._dtype |
| |
| @property |
| def device(self): |
| """The device this variable is on.""" |
| return self._handle.device |
| |
| @property |
| def graph(self): |
| """The `Graph` of this variable.""" |
| return self._handle.graph |
| |
| @property |
| def name(self): |
| """The name of the handle for this variable.""" |
| return self._handle_name |
| |
| @property |
| def shape(self): |
| """The shape of this variable.""" |
| return self._shape |
| |
| def set_shape(self, shape): |
| self._shape = self._shape.merge_with(shape) |
| |
| def _shape_as_list(self): |
| if self.shape.ndims is None: |
| return None |
| return [dim.value for dim in self.shape.dims] |
| |
| def _shape_tuple(self): |
| shape = self._shape_as_list() |
| if shape is None: |
| return None |
| return tuple(shape) |
| |
| @property |
| def create(self): |
| """The op responsible for initializing this variable.""" |
| if not self._in_graph_mode: |
| raise RuntimeError("Calling create is not supported when eager execution" |
| " is enabled.") |
| return self._initializer_op |
| |
| @property |
| def handle(self): |
| """The handle by which this variable can be accessed.""" |
| return self._handle |
| |
| def value(self): |
| """A cached operation which reads the value of this variable.""" |
| if self._cached_value is not None: |
| return self._cached_value |
| with ops.colocate_with(None, ignore_existing=True): |
| return self._read_variable_op() |
| |
| def _as_graph_element(self): |
| """Conversion function for Graph.as_graph_element().""" |
| return self._graph_element |
| |
| @property |
| def initializer(self): |
| """The op responsible for initializing this variable.""" |
| return self._initializer_op |
| |
| @property |
| def initial_value(self): |
| """Returns the Tensor used as the initial value for the variable.""" |
| if context.executing_eagerly(): |
| raise RuntimeError("initial_value not supported in EAGER mode.") |
| return self._initial_value |
| |
| @property |
| def constraint(self): |
| """Returns the constraint function associated with this variable. |
| |
| Returns: |
| The constraint function that was passed to the variable constructor. |
| Can be `None` if no constraint was passed. |
| """ |
| return self._constraint |
| |
| @property |
| def op(self): |
| """The op for this variable.""" |
| return self._handle.op |
| |
| @property |
| def trainable(self): |
| return self._trainable |
| |
| @property |
| def synchronization(self): |
| return self._synchronization |
| |
| @property |
| def aggregation(self): |
| return self._aggregation |
| |
| def eval(self, session=None): |
| """Evaluates and returns the value of this variable.""" |
| if context.executing_eagerly(): |
| raise RuntimeError("Trying to eval in EAGER mode") |
| return self._graph_element.eval(session=session) |
| |
| def numpy(self): |
| if context.executing_eagerly(): |
| return self.read_value().numpy() |
| raise NotImplementedError( |
| "numpy() is only available when eager execution is enabled.") |
| |
| @deprecated(None, "Prefer Dataset.range instead.") |
| def count_up_to(self, limit): |
| """Increments this variable until it reaches `limit`. |
| |
When the returned Op is run it tries to increment the variable by `1`. If
incrementing the variable would bring it above `limit` then the Op raises
the exception `OutOfRangeError`.

If no error is raised, the Op outputs the value of the variable before
the increment.

This is essentially a shortcut for `tf.compat.v1.count_up_to(self, limit)`.
| |
| Args: |
| limit: value at which incrementing the variable raises an error. |
| |
| Returns: |
| A `Tensor` that will hold the variable value before the increment. If no |
| other Op modifies this variable, the values produced will all be |
| distinct. |
| """ |
| return gen_state_ops.resource_count_up_to( |
| self.handle, limit=limit, T=self.dtype) |
| |
| def _map_resources(self, save_options): |
| """For implementing `Trackable`.""" |
| new_variable = None |
| if save_options.experimental_variable_policy._save_variable_devices(): # pylint:disable=protected-access |
| with ops.device(self.device): |
| new_variable = copy_to_graph_uninitialized(self) |
| else: |
| new_variable = copy_to_graph_uninitialized(self) |
| obj_map = {self: new_variable} |
| resource_map = {self._handle: new_variable.handle} |
| return obj_map, resource_map |
| |
| def _read_variable_op(self): |
| variable_accessed(self) |
| |
| def read_and_set_handle(): |
| result = gen_resource_variable_ops.read_variable_op( |
| self._handle, self._dtype) |
| _maybe_set_handle_data(self._dtype, self._handle, result) |
| return result |
| |
| if getattr(self, "_caching_device", None) is not None: |
| with ops.colocate_with(None, ignore_existing=True): |
| with ops.device(self._caching_device): |
| result = read_and_set_handle() |
| else: |
| result = read_and_set_handle() |
| |
| if not context.executing_eagerly(): |
| # Note that if a control flow context is active the input of the read op |
| # might not actually be the handle. This line bypasses it. |
| tape.record_operation( |
| "ReadVariableOp", [result], [self._handle], |
| backward_function=lambda x: [x], |
| forward_function=lambda x: [x]) |
| return result |
| |
| def read_value(self): |
| """Constructs an op which reads the value of this variable. |
| |
| Should be used when there are multiple reads, or when it is desirable to |
| read the value only after some condition is true. |
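
For example, a minimal sketch assuming TF 2.x eager execution:

```python
v = tf.Variable(1.0)
t = v.read_value()  # A Tensor holding 1.0.
v.assign(2.0)       # `t` still holds 1.0; a new read would return 2.0.
```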
| |
| Returns: |
| the read operation. |
| """ |
| with ops.name_scope("Read"): |
| value = self._read_variable_op() |
| # Return an identity so it can get placed on whatever device the context |
| # specifies instead of the device where the variable is. |
| return array_ops.identity(value) |
| |
| def sparse_read(self, indices, name=None): |
| """Reads the value of this variable sparsely, using `gather`.""" |
| with ops.name_scope("Gather" if name is None else name) as name: |
| variable_accessed(self) |
| value = gen_resource_variable_ops.resource_gather( |
| self._handle, indices, dtype=self._dtype, name=name) |
| |
| if self._dtype == dtypes.variant: |
| # For DT_VARIANT types, the handle's shape_and_type[1:] stores the |
| # variant's handle data. Extract it. |
| handle_data = get_eager_safe_handle_data(self._handle) |
| if handle_data.is_set and len(handle_data.shape_and_type) > 1: |
| value._handle_data = ( # pylint: disable=protected-access |
| cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData( |
| is_set=True, shape_and_type=handle_data.shape_and_type[1:])) |
| |
| return array_ops.identity(value) |
| |
| def gather_nd(self, indices, name=None): |
| """Reads the value of this variable sparsely, using `gather_nd`.""" |
| with ops.name_scope("GatherNd" if name is None else name) as name: |
| if self.trainable: |
| variable_accessed(self) |
| value = gen_resource_variable_ops.resource_gather_nd( |
| self._handle, indices, dtype=self._dtype, name=name) |
| |
| return array_ops.identity(value) |
| |
| def to_proto(self, export_scope=None): |
| """Converts a `ResourceVariable` to a `VariableDef` protocol buffer. |
| |
| Args: |
| export_scope: Optional `string`. Name scope to remove. |
| |
| Raises: |
| RuntimeError: If run in EAGER mode. |
| |
| Returns: |
| A `VariableDef` protocol buffer, or `None` if the `Variable` is not |
| in the specified name scope. |
| """ |
| if context.executing_eagerly(): |
| raise RuntimeError("to_proto not supported in EAGER mode.") |
| if export_scope is None or self.handle.name.startswith(export_scope): |
| var_def = variable_pb2.VariableDef() |
| var_def.variable_name = ops.strip_name_scope(self.handle.name, |
| export_scope) |
| if self._initial_value is not None: |
| # This is inside an if-statement for backwards compatibility, since |
| # self._initial_value might be None for variables constructed from old |
| # protos. |
| var_def.initial_value_name = ops.strip_name_scope( |
| self._initial_value.name, export_scope) |
| var_def.initializer_name = ops.strip_name_scope(self.initializer.name, |
| export_scope) |
| if self._cached_value is not None: |
| var_def.snapshot_name = ops.strip_name_scope(self._cached_value.name, |
| export_scope) |
| else: |
| # Store the graph_element here |
| var_def.snapshot_name = ops.strip_name_scope(self._graph_element.name, |
| export_scope) |
| var_def.is_resource = True |
| var_def.trainable = self.trainable |
| var_def.synchronization = self.synchronization.value |
| var_def.aggregation = self.aggregation.value |
| if self._save_slice_info: |
| var_def.save_slice_info_def.MergeFrom( |
| self._save_slice_info.to_proto(export_scope=export_scope)) |
| return var_def |
| else: |
| return None |
| |
| @staticmethod |
| def from_proto(variable_def, import_scope=None): |
| if context.executing_eagerly(): |
| raise RuntimeError("from_proto not supported in EAGER mode.") |
| return ResourceVariable( |
| variable_def=variable_def, import_scope=import_scope) |
| |
| __array_priority__ = 100 |
| |
| def is_initialized(self, name=None): |
| """Checks whether a resource variable has been initialized. |
| |
Outputs a boolean scalar indicating whether the variable has been initialized.
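
For example, a minimal sketch assuming eager execution (where variables are
initialized on creation):

```python
v = tf.Variable([1.0, 2.0])
v.is_initialized()  # A scalar bool Tensor holding True.
```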
| |
| Args: |
| name: A name for the operation (optional). |
| |
| Returns: |
| A `Tensor` of type `bool`. |
| """ |
| # TODO(b/169792703): The current device placement logic never overrides an |
# explicit placement with a custom device, causing `v.is_initialized()` to
| # fail under a non-custom device context if `v` is in a custom device. The |
| # explicit placement below makes this work, but should not be necessary once |
| # the logic is updated to handle cases like this. |
| with ops.device(self.device): |
| return gen_resource_variable_ops.var_is_initialized_op(self.handle, name) |
| |
| def assign_sub(self, delta, use_locking=None, name=None, read_value=True): |
| """Subtracts a value from this variable. |
| |
| Args: |
| delta: A `Tensor`. The value to subtract from this variable. |
| use_locking: If `True`, use locking during the operation. |
| name: The name to use for the operation. |
| read_value: A `bool`. Whether to read and return the new value of the |
| variable or not. |
| |
| Returns: |
| If `read_value` is `True`, this method will return the new value of the |
| variable after the assignment has completed. Otherwise, when in graph mode |
| it will return the `Operation` that does the assignment, and when in eager |
| mode it will return `None`. |
| """ |
| # TODO(apassos): this here and below is not atomic. Consider making it |
| # atomic if there's a way to do so without a performance cost for those who |
| # don't need it. |
| with _handle_graph(self.handle), self._assign_dependencies(): |
| assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op( |
| self.handle, |
| ops.convert_to_tensor(delta, dtype=self.dtype), |
| name=name) |
| if read_value: |
| return self._lazy_read(assign_sub_op) |
| return assign_sub_op |
| |
| def assign_add(self, delta, use_locking=None, name=None, read_value=True): |
| """Adds a value to this variable. |
| |
| Args: |
| delta: A `Tensor`. The value to add to this variable. |
| use_locking: If `True`, use locking during the operation. |
| name: The name to use for the operation. |
| read_value: A `bool`. Whether to read and return the new value of the |
| variable or not. |
| |
| Returns: |
| If `read_value` is `True`, this method will return the new value of the |
| variable after the assignment has completed. Otherwise, when in graph mode |
| it will return the `Operation` that does the assignment, and when in eager |
| mode it will return `None`. |
| """ |
| with _handle_graph(self.handle), self._assign_dependencies(): |
| assign_add_op = gen_resource_variable_ops.assign_add_variable_op( |
| self.handle, |
| ops.convert_to_tensor(delta, dtype=self.dtype), |
| name=name) |
| if read_value: |
| return self._lazy_read(assign_add_op) |
| return assign_add_op |
| |
| def _lazy_read(self, op): |
| variable_accessed(self) |
| return _UnreadVariable( |
| handle=self._handle, |
| dtype=self.dtype, |
| shape=self._shape, |
| in_graph_mode=self._in_graph_mode, |
| deleter=self._handle_deleter if not self._in_graph_mode else None, |
| parent_op=op, |
| unique_id=self._unique_id) |
| |
| def assign(self, value, use_locking=None, name=None, read_value=True): |
| """Assigns a new value to this variable. |
| |
| Args: |
| value: A `Tensor`. The new value for this variable. |
| use_locking: If `True`, use locking during the assignment. |
| name: The name to use for the assignment. |
| read_value: A `bool`. Whether to read and return the new value of the |
| variable or not. |
| |
| Returns: |
| If `read_value` is `True`, this method will return the new value of the |
| variable after the assignment has completed. Otherwise, when in graph mode |
| it will return the `Operation` that does the assignment, and when in eager |
| mode it will return `None`. |
| """ |
| # Note: not depending on the cached value here since this can be used to |
| # initialize the variable. |
| with _handle_graph(self.handle): |
| value_tensor = ops.convert_to_tensor(value, dtype=self.dtype) |
| if not self._shape.is_compatible_with(value_tensor.shape): |
| if self.name is None: |
| tensor_name = "" |
| else: |
| tensor_name = " " + str(self.name) |
raise ValueError(
("Cannot assign to variable%s because the variable shape %s and the "
"value shape %s are incompatible.") %
(tensor_name, self._shape, value_tensor.shape))
| assign_op = gen_resource_variable_ops.assign_variable_op( |
| self.handle, value_tensor, name=name) |
| if read_value: |
| return self._lazy_read(assign_op) |
| return assign_op |
| |
| def __reduce__(self): |
| # The implementation mirrors that of __deepcopy__. |
| return functools.partial( |
| ResourceVariable, |
| initial_value=self.numpy(), |
| trainable=self.trainable, |
| name=self._shared_name, |
| dtype=self.dtype, |
| constraint=self.constraint, |
| distribute_strategy=self._distribute_strategy), () |
| |
| def scatter_sub(self, sparse_delta, use_locking=False, name=None): |
| """Subtracts `tf.IndexedSlices` from this variable. |
| |
| Args: |
| sparse_delta: `tf.IndexedSlices` to be subtracted from this variable. |
| use_locking: If `True`, use locking during the operation. |
| name: the name of the operation. |
| |
| Returns: |
| The updated variable. |
| |
| Raises: |
| TypeError: if `sparse_delta` is not an `IndexedSlices`. |
| """ |
| if not isinstance(sparse_delta, ops.IndexedSlices): |
| raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) |
| return self._lazy_read( |
| gen_resource_variable_ops.resource_scatter_sub( |
| self.handle, |
| sparse_delta.indices, |
| ops.convert_to_tensor(sparse_delta.values, self.dtype), |
| name=name)) |
| |
| def scatter_add(self, sparse_delta, use_locking=False, name=None): |
| """Adds `tf.IndexedSlices` to this variable. |
| |
| Args: |
| sparse_delta: `tf.IndexedSlices` to be added to this variable. |
| use_locking: If `True`, use locking during the operation. |
| name: the name of the operation. |
| |
| Returns: |
| The updated variable. |
| |
| Raises: |
| TypeError: if `sparse_delta` is not an `IndexedSlices`. |
| """ |
| if not isinstance(sparse_delta, ops.IndexedSlices): |
| raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) |
| return self._lazy_read( |
| gen_resource_variable_ops.resource_scatter_add( |
| self.handle, |
| sparse_delta.indices, |
| ops.convert_to_tensor(sparse_delta.values, self.dtype), |
| name=name)) |
| |
| def scatter_max(self, sparse_delta, use_locking=False, name=None): |
| """Updates this variable with the max of `tf.IndexedSlices` and itself. |
| |
| Args: |
| sparse_delta: `tf.IndexedSlices` to use as an argument of max with this |
| variable. |
| use_locking: If `True`, use locking during the operation. |
| name: the name of the operation. |
| |
| Returns: |
| The updated variable. |
| |
| Raises: |
| TypeError: if `sparse_delta` is not an `IndexedSlices`. |
| """ |
| if not isinstance(sparse_delta, ops.IndexedSlices): |
| raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) |
| return self._lazy_read( |
| gen_resource_variable_ops.resource_scatter_max( |
| self.handle, |
| sparse_delta.indices, |
| ops.convert_to_tensor(sparse_delta.values, self.dtype), |
| name=name)) |
| |
| def scatter_min(self, sparse_delta, use_locking=False, name=None): |
| """Updates this variable with the min of `tf.IndexedSlices` and itself. |
| |
| Args: |
| sparse_delta: `tf.IndexedSlices` to use as an argument of min with this |
| variable. |
| use_locking: If `True`, use locking during the operation. |
| name: the name of the operation. |
| |
| Returns: |
| The updated variable. |
| |
| Raises: |
| TypeError: if `sparse_delta` is not an `IndexedSlices`. |
| """ |
| if not isinstance(sparse_delta, ops.IndexedSlices): |
| raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) |
| return self._lazy_read( |
| gen_resource_variable_ops.resource_scatter_min( |
| self.handle, |
| sparse_delta.indices, |
| ops.convert_to_tensor(sparse_delta.values, self.dtype), |
| name=name)) |
| |
| def scatter_mul(self, sparse_delta, use_locking=False, name=None): |
| """Multiply this variable by `tf.IndexedSlices`. |
| |
| Args: |
| sparse_delta: `tf.IndexedSlices` to multiply this variable by. |
| use_locking: If `True`, use locking during the operation. |
| name: the name of the operation. |
| |
| Returns: |
| The updated variable. |
| |
| Raises: |
| TypeError: if `sparse_delta` is not an `IndexedSlices`. |
| """ |
| if not isinstance(sparse_delta, ops.IndexedSlices): |
| raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) |
| return self._lazy_read( |
| gen_resource_variable_ops.resource_scatter_mul( |
| self.handle, |
| sparse_delta.indices, |
| ops.convert_to_tensor(sparse_delta.values, self.dtype), |
| name=name)) |
| |
| def scatter_div(self, sparse_delta, use_locking=False, name=None): |
| """Divide this variable by `tf.IndexedSlices`. |
| |
| Args: |
| sparse_delta: `tf.IndexedSlices` to divide this variable by. |
| use_locking: If `True`, use locking during the operation. |
| name: the name of the operation. |
| |
| Returns: |
| The updated variable. |
| |
| Raises: |
| TypeError: if `sparse_delta` is not an `IndexedSlices`. |
| """ |
| if not isinstance(sparse_delta, ops.IndexedSlices): |
| raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) |
| return self._lazy_read( |
| gen_resource_variable_ops.resource_scatter_div( |
| self.handle, |
| sparse_delta.indices, |
| ops.convert_to_tensor(sparse_delta.values, self.dtype), |
| name=name)) |
| |
| def scatter_update(self, sparse_delta, use_locking=False, name=None): |
| """Assigns `tf.IndexedSlices` to this variable. |
| |
| Args: |
| sparse_delta: `tf.IndexedSlices` to be assigned to this variable. |
| use_locking: If `True`, use locking during the operation. |
| name: the name of the operation. |
| |
| Returns: |
| The updated variable. |
| |
| Raises: |
| TypeError: if `sparse_delta` is not an `IndexedSlices`. |
| """ |
| if not isinstance(sparse_delta, ops.IndexedSlices): |
| raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) |
| return self._lazy_read( |
| gen_resource_variable_ops.resource_scatter_update( |
| self.handle, |
| sparse_delta.indices, |
| ops.convert_to_tensor(sparse_delta.values, self.dtype), |
| name=name)) |
| |
| def batch_scatter_update(self, sparse_delta, use_locking=False, name=None): |
| """Assigns `tf.IndexedSlices` to this variable batch-wise. |
| |
| Analogous to `batch_gather`. This assumes that this variable and the |
| sparse_delta IndexedSlices have a series of leading dimensions that are the |
| same for all of them, and the updates are performed on the last dimension of |
| indices. In other words, the dimensions should be the following: |
| |
| `num_prefix_dims = sparse_delta.indices.ndims - 1` |
| `batch_dim = num_prefix_dims + 1` |
| `sparse_delta.updates.shape = sparse_delta.indices.shape + var.shape[ |
| batch_dim:]` |
| |
| where |
| |
| `sparse_delta.updates.shape[:num_prefix_dims]` |
| `== sparse_delta.indices.shape[:num_prefix_dims]` |
| `== var.shape[:num_prefix_dims]` |
| |
| And the operation performed can be expressed as: |
| |
| `var[i_1, ..., i_n, |
| sparse_delta.indices[i_1, ..., i_n, j]] = sparse_delta.updates[ |
| i_1, ..., i_n, j]` |
| |
| When sparse_delta.indices is a 1D tensor, this operation is equivalent to |
| `scatter_update`. |
| |
To avoid this operation one can loop over the first `ndims` of the
variable and use `scatter_update` on the subtensors that result from slicing
the first dimension. This is a valid option for `ndims = 1`, but less
efficient than this implementation.
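
For example, a minimal sketch assuming eager execution (a batch of two rows
with two updates per row):

```python
v = tf.Variable([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
delta = tf.IndexedSlices(values=[[5.0, 6.0], [7.0, 8.0]],
                         indices=[[0, 2], [1, 2]])
v.batch_scatter_update(delta)
# v is now [[5.0, 1.0, 6.0], [1.0, 7.0, 8.0]]
```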
| |
| Args: |
| sparse_delta: `tf.IndexedSlices` to be assigned to this variable. |
| use_locking: If `True`, use locking during the operation. |
| name: the name of the operation. |
| |
| Returns: |
| The updated variable. |
| |
| Raises: |
| TypeError: if `sparse_delta` is not an `IndexedSlices`. |
| """ |
| if not isinstance(sparse_delta, ops.IndexedSlices): |
| raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) |
| return self._lazy_read( |
| state_ops.batch_scatter_update( |
| self, |
| sparse_delta.indices, |
| sparse_delta.values, |
| use_locking=use_locking, |
| name=name)) |
| |
| def scatter_nd_sub(self, indices, updates, name=None): |
| """Applies sparse subtraction to individual values or slices in a Variable. |
| |
| `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. |
| |
`indices` must be an integer tensor containing indices into `ref`.
It must be of shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
| |
| The innermost dimension of `indices` (with length `K`) corresponds to |
| indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th |
| dimension of `ref`. |
| |
`updates` is a `Tensor` of rank `Q-1+P-K` with shape:
| |
| ``` |
| [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. |
| ``` |
| |
For example, say we want to subtract 4 scattered elements from a rank-1 tensor
with 8 elements. In Python, that update would look like this:
| |
| ```python |
| ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) |
| indices = tf.constant([[4], [3], [1] ,[7]]) |
| updates = tf.constant([9, 10, 11, 12]) |
| op = ref.scatter_nd_sub(indices, updates) |
| with tf.compat.v1.Session() as sess: |
print(sess.run(op))
| ``` |
| |
| The resulting update to ref would look like this: |
| |
| [1, -9, 3, -6, -6, 6, 7, -4] |
| |
| See `tf.scatter_nd` for more details about how to make updates to |
| slices. |
| |
| Args: |
| indices: The indices to be used in the operation. |
| updates: The values to be used in the operation. |
| name: the name of the operation. |
| |
| Returns: |
| The updated variable. |
| """ |
| return self._lazy_read( |
| gen_state_ops.resource_scatter_nd_sub( |
| self.handle, |
| indices, |
| ops.convert_to_tensor(updates, self.dtype), |
| name=name)) |
| |
| def scatter_nd_add(self, indices, updates, name=None): |
| """Applies sparse addition to individual values or slices in a Variable. |
| |
| `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. |
| |
`indices` must be an integer tensor containing indices into `ref`.
It must be of shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
| |
| The innermost dimension of `indices` (with length `K`) corresponds to |
| indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th |
| dimension of `ref`. |
| |
`updates` is a `Tensor` of rank `Q-1+P-K` with shape:
| |
| ``` |
| [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. |
| ``` |
| |
For example, say we want to add 4 scattered elements to a rank-1 tensor
with 8 elements. In Python, that update would look like this:
| |
| ```python |
| ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) |
| indices = tf.constant([[4], [3], [1] ,[7]]) |
| updates = tf.constant([9, 10, 11, 12]) |
| add = ref.scatter_nd_add(indices, updates) |
| with tf.compat.v1.Session() as sess: |
print(sess.run(add))
| ``` |
| |
| The resulting update to ref would look like this: |
| |
| [1, 13, 3, 14, 14, 6, 7, 20] |
| |
| See `tf.scatter_nd` for more details about how to make updates to |
| slices. |
| |
| Args: |
| indices: The indices to be used in the operation. |
| updates: The values to be used in the operation. |
| name: the name of the operation. |
| |
| Returns: |
| The updated variable. |
| """ |
| return self._lazy_read( |
| gen_state_ops.resource_scatter_nd_add( |
| self.handle, |
| indices, |
| ops.convert_to_tensor(updates, self.dtype), |
| name=name)) |
| |
| def scatter_nd_update(self, indices, updates, name=None): |
| """Applies sparse assignment to individual values or slices in a Variable. |
| |
| `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. |
| |
`indices` must be an integer tensor containing indices into `ref`.
It must be of shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
| |
| The innermost dimension of `indices` (with length `K`) corresponds to |
| indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th |
| dimension of `ref`. |
| |
`updates` is a `Tensor` of rank `Q-1+P-K` with shape:
| |
| ``` |
| [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. |
| ``` |
| |
For example, say we want to update 4 scattered elements in a rank-1 tensor
with 8 elements. In Python, that update would look like this:
| |
| ```python |
| ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) |
| indices = tf.constant([[4], [3], [1] ,[7]]) |
| updates = tf.constant([9, 10, 11, 12]) |
| op = ref.scatter_nd_update(indices, updates) |
| with tf.compat.v1.Session() as sess: |
print(sess.run(op))
| ``` |
| |
| The resulting update to ref would look like this: |
| |
| [1, 11, 3, 10, 9, 6, 7, 12] |
| |
| See `tf.scatter_nd` for more details about how to make updates to |
| slices. |
| |
| Args: |
| indices: The indices to be used in the operation. |
| updates: The values to be used in the operation. |
| name: the name of the operation. |
| |
| Returns: |
| The updated variable. |
| """ |
| return self._lazy_read( |
| gen_state_ops.resource_scatter_nd_update( |
| self.handle, |
| indices, |
| ops.convert_to_tensor(updates, self.dtype), |
| name=name)) |
| |
| def scatter_nd_max(self, indices, updates, name=None): |
| """Updates this variable with the max of `tf.IndexedSlices` and itself. |
| |
| `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. |
| |
`indices` must be an integer tensor containing indices into `ref`.
It must be of shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
| |
| The innermost dimension of `indices` (with length `K`) corresponds to |
| indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th |
| dimension of `ref`. |
| |
`updates` is a `Tensor` of rank `Q-1+P-K` with shape:
| |
| ``` |
| [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. |
| ``` |
| |
| See `tf.scatter_nd` for more details about how to make updates to |
| slices. |
| |
| Args: |
| indices: The indices to be used in the operation. |
| updates: The values to be used in the operation. |
| name: the name of the operation. |
| |
| Returns: |
| The updated variable. |
| """ |
| return self._lazy_read( |
| gen_state_ops.resource_scatter_nd_max( |
| self.handle, |
| indices, |
| ops.convert_to_tensor(updates, self.dtype), |
| name=name)) |
| |
| def scatter_nd_min(self, indices, updates, name=None): |
| """Updates this variable with the min of `tf.IndexedSlices` and itself. |
| |
| `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. |
| |
`indices` must be an integer tensor containing indices into `ref`.
It must be of shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
| |
| The innermost dimension of `indices` (with length `K`) corresponds to |
| indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th |
| dimension of `ref`. |
| |
`updates` is a `Tensor` of rank `Q-1+P-K` with shape:
| |
| ``` |
| [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. |
| ``` |
| |
| See `tf.scatter_nd` for more details about how to make updates to |
| slices. |
| |
| Args: |
| indices: The indices to be used in the operation. |
| updates: The values to be used in the operation. |
| name: the name of the operation. |
| |
| Returns: |
| The updated variable. |
| """ |
| return self._lazy_read( |
| gen_state_ops.resource_scatter_nd_min( |
| self.handle, |
| indices, |
| ops.convert_to_tensor(updates, self.dtype), |
| name=name)) |
| |
| def _strided_slice_assign(self, begin, end, strides, value, name, begin_mask, |
| end_mask, ellipsis_mask, new_axis_mask, |
| shrink_axis_mask): |
| with _handle_graph(self.handle), self._assign_dependencies(): |
| return self._lazy_read( |
| gen_array_ops.resource_strided_slice_assign( |
| ref=self.handle, |
| begin=begin, |
| end=end, |
| strides=strides, |
| value=ops.convert_to_tensor(value, dtype=self.dtype), |
| name=name, |
| begin_mask=begin_mask, |
| end_mask=end_mask, |
| ellipsis_mask=ellipsis_mask, |
| new_axis_mask=new_axis_mask, |
| shrink_axis_mask=shrink_axis_mask)) |
| |
| def __complex__(self): |
| return complex(self.value().numpy()) |
| |
| def __int__(self): |
| return int(self.value().numpy()) |
| |
| def __long__(self): |
| return long(self.value().numpy()) |
| |
| def __float__(self): |
| return float(self.value().numpy()) |
| |
| def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False): |
| del name |
| if dtype is not None and not dtype.is_compatible_with(self.dtype): |
| raise ValueError( |
| "Incompatible type conversion requested to type {!r} for variable " |
| "of type {!r}".format(dtype.name, self.dtype.name)) |
| if as_ref: |
| return self.read_value().op.inputs[0] |
| else: |
| return self.value() |
| |
| def __iadd__(self, unused_other): |
| raise RuntimeError("Variable += value not supported. Use " |
| "variable.assign_add(value) to modify the variable " |
| "value and variable = variable + value to get a new " |
| "Tensor object.") |
| |
| def __isub__(self, unused_other): |
| raise RuntimeError("Variable -= value not supported. Use " |
| "variable.assign_sub(value) to modify the variable " |
| "value and variable = variable - value to get a new " |
| "Tensor object.") |
| |
| def __imul__(self, unused_other): |
| raise RuntimeError("Variable *= value not supported. Use " |
| "`var.assign(var * value)` to modify the variable or " |
| "`var = var * value` to get a new Tensor object.") |
| |
| def __idiv__(self, unused_other): |
| raise RuntimeError("Variable /= value not supported. Use " |
| "`var.assign(var / value)` to modify the variable or " |
| "`var = var / value` to get a new Tensor object.") |
| |
| def __itruediv__(self, unused_other): |
| raise RuntimeError("Variable /= value not supported. Use " |
| "`var.assign(var / value)` to modify the variable or " |
| "`var = var / value` to get a new Tensor object.") |
| |
| def __irealdiv__(self, unused_other): |
| raise RuntimeError("Variable /= value not supported. Use " |
| "`var.assign(var / value)` to modify the variable or " |
| "`var = var / value` to get a new Tensor object.") |
| |
| def __ipow__(self, unused_other): |
| raise RuntimeError("Variable **= value not supported. Use " |
| "`var.assign(var ** value)` to modify the variable or " |
| "`var = var ** value` to get a new Tensor object.") |
| |
| |
| class ResourceVariable(BaseResourceVariable): |
| """Variable based on resource handles. |
| |
| See the [Variables How To](https://tensorflow.org/guide/variables) |
| for a high level overview. |
| |
| A `ResourceVariable` allows you to maintain state across subsequent calls to |
| session.run. |
| |
| The `ResourceVariable` constructor requires an initial value for the variable, |
| which can be a `Tensor` of any type and shape. The initial value defines the |
| type and shape of the variable. After construction, the type and shape of |
| the variable are fixed. The value can be changed using one of the assign |
| methods. |
| |
| Just like any `Tensor`, variables created with |
| `tf.Variable(use_resource=True)` can be used as inputs for other Ops in the |
| graph. Additionally, all the operators overloaded for the `Tensor` class are |
| carried over to variables, so you can also add nodes to the graph by just |
| doing arithmetic on variables. |
| |
Unlike ref-based variables, a ResourceVariable has well-defined semantics. Each
usage of a ResourceVariable in a TensorFlow graph adds a read_value operation
to the graph. The Tensors returned by a read_value operation are guaranteed to
see all modifications to the value of the variable which happen in any
operation on which the read_value depends (either directly, indirectly, or
via a control dependency) and guaranteed to not see any modification to the
value of the variable from operations that depend on the read_value operation.
Updates from operations that have no dependency relationship to the read_value
operation might or might not be visible to read_value.
| |
| For example, if there is more than one assignment to a ResourceVariable in |
| a single session.run call there is a well-defined value for each operation |
| which uses the variable's value if the assignments and the read are connected |
| by edges in the graph. Consider the following example, in which two writes |
| can cause tf.Variable and tf.ResourceVariable to behave differently: |
| |
| ```python |
| a = tf.Variable(1.0, use_resource=True) |
| a.initializer.run() |
| |
| assign = a.assign(2.0) |
| with tf.control_dependencies([assign]): |
| b = a.read_value() |
| with tf.control_dependencies([b]): |
| other_assign = a.assign(3.0) |
| with tf.control_dependencies([other_assign]): |
| # Will print 2.0 because the value was read before other_assign ran. If |
| # `a` was a tf.Variable instead, 2.0 or 3.0 could be printed. |
| tf.compat.v1.Print(b, [b]).eval() |
| ``` |
| """ |
| |
| def __init__( |
| self, # pylint: disable=super-init-not-called |
| initial_value=None, |
| trainable=None, |
| collections=None, |
| validate_shape=True, # pylint: disable=unused-argument |
| caching_device=None, |
| name=None, |
| dtype=None, |
| variable_def=None, |
| import_scope=None, |
| constraint=None, |
| distribute_strategy=None, |
| synchronization=None, |
| aggregation=None, |
| shape=None): |
| """Creates a variable. |
| |
| Args: |
| initial_value: A `Tensor`, or Python object convertible to a `Tensor`, |
| which is the initial value for the Variable. Can also be a callable with |
| no argument that returns the initial value when called. (Note that |
| initializer functions from init_ops.py must first be bound to a shape |
| before being used here.) |
| trainable: If `True`, the default, also adds the variable to the graph |
| collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as |
| the default list of variables to use by the `Optimizer` classes. |
| Defaults to `True`, unless `synchronization` is set to `ON_READ`, in |
| which case it defaults to `False`. |
| collections: List of graph collections keys. The new variable is added to |
| these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`. |
| validate_shape: Ignored. Provided for compatibility with tf.Variable. |
| caching_device: Optional device string or function describing where the |
| Variable should be cached for reading. Defaults to the Variable's |
| device. If not `None`, caches on another device. Typical use is to |
| cache on the device where the Ops using the Variable reside, to |
| deduplicate copying through `Switch` and other conditional statements. |
| name: Optional name for the variable. Defaults to `'Variable'` and gets |
| uniquified automatically. |
| dtype: If set, initial_value will be converted to the given type. If None, |
| either the datatype will be kept (if initial_value is a Tensor) or |
| float32 will be used (if it is a Python object convertible to a Tensor). |
| variable_def: `VariableDef` protocol buffer. If not None, recreates the |
| `ResourceVariable` object with its contents. `variable_def` and other |
| arguments (except for import_scope) are mutually exclusive. |
| import_scope: Optional `string`. Name scope to add to the |
| ResourceVariable. Only used when `variable_def` is provided. |
| constraint: An optional projection function to be applied to the variable |
| after being updated by an `Optimizer` (e.g. used to implement norm |
| constraints or value constraints for layer weights). The function must |
| take as input the unprojected Tensor representing the value of the |
| variable and return the Tensor for the projected value (which must have |
| the same shape). Constraints are not safe to use when doing asynchronous |
| distributed training. |
| distribute_strategy: The tf.distribute.Strategy this variable is being |
| created inside of. |
synchronization: Indicates when a distributed variable will be
| aggregated. Accepted values are constants defined in the class |
| `tf.VariableSynchronization`. By default the synchronization is set to |
| `AUTO` and the current `DistributionStrategy` chooses when to |
| synchronize. |
| aggregation: Indicates how a distributed variable will be aggregated. |
| Accepted values are constants defined in the class |
| `tf.VariableAggregation`. |
| shape: (optional) The shape of this variable. If None, the shape of |
| `initial_value` will be used. When setting this argument to |
| `tf.TensorShape(None)` (representing an unspecified shape), the variable |
| can be assigned with values of different shapes. |
| |
| Raises: |
| ValueError: If the initial value is not specified, or does not have a |
| shape and `validate_shape` is `True`. |
| |
| @compatibility(eager) |
| When Eager Execution is enabled, the default for the `collections` argument |
| is `None`, which signifies that this `Variable` will not be added to any |
| collections. |
| @end_compatibility |
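
For illustration, a minimal sketch of constructing a variable with an
unspecified shape and a projection constraint (the values here are arbitrary
examples, not part of the API contract):

```python
v = tf.Variable(
    initial_value=[1.0, 2.0],
    constraint=lambda t: tf.clip_by_value(t, 0.0, 10.0),
    shape=tf.TensorShape(None))
# Because the shape is left unspecified, assignments with a different shape
# are allowed.
v.assign([3.0, 4.0, 5.0])
```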
| """ |
| if variable_def: |
| if initial_value is not None: |
| raise ValueError("variable_def and initial_value are mutually " |
| "exclusive.") |
| if context.executing_eagerly(): |
| raise ValueError("Creating ResourceVariable from variable_def is " |
| "not supported when eager execution is enabled.") |
| self._init_from_proto(variable_def, import_scope=import_scope) |
| else: |
| self._init_from_args( |
| initial_value=initial_value, |
| trainable=trainable, |
| collections=collections, |
| caching_device=caching_device, |
| name=name, |
| dtype=dtype, |
| constraint=constraint, |
| synchronization=synchronization, |
| aggregation=aggregation, |
| shape=shape, |
| distribute_strategy=distribute_strategy) |
| |
| def _init_from_args(self, |
| initial_value=None, |
| trainable=None, |
| collections=None, |
| caching_device=None, |
| name=None, |
| dtype=None, |
| constraint=None, |
| synchronization=None, |
| aggregation=None, |
| distribute_strategy=None, |
| shape=None): |
| """Creates a variable. |
| |
| Args: |
| initial_value: A `Tensor`, or Python object convertible to a `Tensor`, |
| which is the initial value for the Variable. The initial value must have |
| a shape specified unless `validate_shape` is set to False. Can also be a |
| callable with no argument that returns the initial value when called. |
| (Note that initializer functions from init_ops.py must first be bound to |
| a shape before being used here.) |
trainable: If `True`, also adds the variable to the graph collection
`GraphKeys.TRAINABLE_VARIABLES`. This collection is used as the default
list of variables to use by the `Optimizer` classes. Defaults to `True`,
unless `synchronization` is set to `ON_READ`, in which case it defaults
to `False`.
| collections: List of graph collections keys. The new variable is added to |
| these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`. |
| caching_device: Optional device string or function describing where the |
| Variable should be cached for reading. Defaults to the Variable's |
| device. If not `None`, caches on another device. Typical use is to |
| cache on the device where the Ops using the Variable reside, to |
| deduplicate copying through `Switch` and other conditional statements. |
| name: Optional name for the variable. Defaults to `'Variable'` and gets |
| uniquified automatically. |
| dtype: If set, initial_value will be converted to the given type. If None, |
| either the datatype will be kept (if initial_value is a Tensor) or |
| float32 will be used (if it is a Python object convertible to a Tensor). |
| constraint: An optional projection function to be applied to the variable |
| after being updated by an `Optimizer` (e.g. used to implement norm |
| constraints or value constraints for layer weights). The function must |
| take as input the unprojected Tensor representing the value of the |
| variable and return the Tensor for the projected value (which must have |
| the same shape). Constraints are not safe to use when doing asynchronous |
| distributed training. |
synchronization: Indicates when a distributed variable will be
| aggregated. Accepted values are constants defined in the class |
| `tf.VariableSynchronization`. By default the synchronization is set to |
| `AUTO` and the current `DistributionStrategy` chooses when to |
| synchronize. |
| aggregation: Indicates how a distributed variable will be aggregated. |
| Accepted values are constants defined in the class |
| `tf.VariableAggregation`. |
| distribute_strategy: DistributionStrategy under which this variable was |
| created. |
| shape: (optional) The shape of this variable. If None, the shape of |
| `initial_value` will be used. When setting this argument to |
| `tf.TensorShape(None)` (representing an unspecified shape), the variable |
| can be assigned with values of different shapes. |
| |
| Raises: |
| ValueError: If the initial value is not specified, or does not have a |
| shape and `validate_shape` is `True`. |
| |
| @compatibility(eager) |
When Eager Execution is enabled, variables are never added to collections.
They are not implicitly added to the `GLOBAL_VARIABLES` or
`TRAINABLE_VARIABLES` collections, and the `collections` argument is
ignored.
| @end_compatibility |
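
As a sketch of the callable-initializer convention described above (the shape
and distribution are arbitrary illustrative choices): passing a zero-argument
callable defers creation of the initial-value tensor until the variable itself
is created, which is what makes this form safe inside function-building code.

```python
v = tf.Variable(lambda: tf.random.truncated_normal([10, 40]))
```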
| """ |
| synchronization, aggregation, trainable = ( |
| variables.validate_synchronization_aggregation_trainable( |
| synchronization, aggregation, trainable, name)) |
| if initial_value is None: |
| raise ValueError("initial_value must be specified.") |
| init_from_fn = callable(initial_value) |
| |
| if isinstance(initial_value, ops.Tensor) and hasattr( |
| initial_value, "graph") and initial_value.graph.building_function: |
| raise ValueError("Tensor-typed variable initializers must either be " |
| "wrapped in an init_scope or callable " |
| "(e.g., `tf.Variable(lambda : " |
| "tf.truncated_normal([10, 40]))`) when building " |
| "functions. Please file a feature request if this " |
| "restriction inconveniences you.") |
| |
| if collections is None: |
| collections = [ops.GraphKeys.GLOBAL_VARIABLES] |
| if not isinstance(collections, (list, tuple, set)): |
| raise ValueError( |
| "collections argument to Variable constructor must be a list, tuple, " |
| "or set. Got %s of type %s" % (collections, type(collections))) |
| if constraint is not None and not callable(constraint): |
| raise ValueError("The `constraint` argument must be a callable.") |
| |
| if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections: |
| collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES] |
| with ops.init_scope(): |
| self._in_graph_mode = not context.executing_eagerly() |
| with ops.name_scope( |
| name, |
| "Variable", [] if init_from_fn else [initial_value], |
| skip_on_eager=False) as name: |
| # pylint: disable=protected-access |
| handle_name = ops.name_from_scope_name(name) |
| if self._in_graph_mode: |
| shared_name = handle_name |
| unique_id = shared_name |
| else: |
| # When in eager mode use a uid for the shared_name, to prevent |
| # accidental sharing. |
| unique_id = "%s_%d" % (handle_name, ops.uid()) |
| shared_name = None # Never shared |
| # Use attr_scope and device(None) to simulate the behavior of |
| # colocate_with when the variable we want to colocate with doesn't |
| # yet exist. |
| device_context_manager = ( |
| ops.device if self._in_graph_mode else ops.NullContextmanager) |
| attr = attr_value_pb2.AttrValue( |
| list=attr_value_pb2.AttrValue.ListValue( |
| s=[compat.as_bytes("loc:@%s" % handle_name)])) |
| with ops.get_default_graph()._attr_scope({"_class": attr}): |
| with ops.name_scope("Initializer"), device_context_manager(None): |
| if init_from_fn: |
| initial_value = initial_value() |
| if isinstance(initial_value, trackable.CheckpointInitialValue): |
| self._maybe_initialize_trackable() |
| self._update_uid = initial_value.checkpoint_position.restore_uid |
| initial_value = initial_value.wrapped_value |
| initial_value = ops.convert_to_tensor(initial_value, |
| name="initial_value", |
| dtype=dtype) |
| if shape is not None: |
| if not initial_value.shape.is_compatible_with(shape): |
| raise ValueError( |
| "The initial value's shape (%s) is not compatible with " |
| "the explicitly supplied `shape` argument (%s)." % |
| (initial_value.shape, shape)) |
| else: |
| shape = initial_value.shape |
| handle = eager_safe_variable_handle( |
| initial_value=initial_value, |
| shape=shape, |
| shared_name=shared_name, |
| name=name, |
| graph_mode=self._in_graph_mode) |
| # pylint: disable=protected-access |
| if (self._in_graph_mode and initial_value is not None and |
| initial_value.op._get_control_flow_context() is not None): |
| raise ValueError( |
| "Initializer for variable %s is from inside a control-flow " |
| "construct, such as a loop or conditional. When creating a " |
| "variable inside a loop or conditional, use a lambda as the " |
| "initializer." % name) |
| # pylint: enable=protected-access |
| dtype = initial_value.dtype.base_dtype |
| |
| if self._in_graph_mode: |
| with ops.name_scope("IsInitialized"): |
| is_initialized_op = ( |
| gen_resource_variable_ops.var_is_initialized_op(handle)) |
| if initial_value is not None: |
| # pylint: disable=g-backslash-continuation |
| with ops.name_scope("Assign") as n, \ |
| ops.colocate_with(None, ignore_existing=True), \ |
| ops.device(handle.device): |
| # pylint: disable=protected-access |
| initializer_op = ( |
| gen_resource_variable_ops.assign_variable_op( |
| handle, |
| variables._try_guard_against_uninitialized_dependencies( |
| name, initial_value), |
| name=n)) |
| # pylint: enable=protected-access |
| # pylint: enable=g-backslash-continuation |
| with ops.name_scope("Read"): |
| # Manually assign reads to the handle's device to avoid log |
| # messages. |
| with ops.device(handle.device): |
| value = gen_resource_variable_ops.read_variable_op(handle, dtype) |
| _maybe_set_handle_data(dtype, handle, value) |
| graph_element = value |
| if caching_device is not None: |
| # Variables may be created in a tf.device() or ops.colocate_with() |
| # context. At the same time, users would expect caching device to |
| # be independent of this context, and/or would not expect the |
| # current device context to be merged with the caching device |
| # spec. Therefore we reset the colocation stack before creating |
| # the cached value. Note that resetting the colocation stack will |
| # also reset the device stack. |
| with ops.colocate_with(None, ignore_existing=True): |
| with ops.device(caching_device): |
| cached_value = array_ops.identity(value) |
| else: |
| cached_value = None |
| else: |
| gen_resource_variable_ops.assign_variable_op(handle, initial_value) |
| is_initialized_op = None |
| initializer_op = None |
| graph_element = None |
| if caching_device: |
| with ops.device(caching_device): |
| cached_value = gen_resource_variable_ops.read_variable_op( |
| handle, dtype) |
| _maybe_set_handle_data(dtype, handle, cached_value) |
| else: |
| cached_value = None |
| |
| if cached_value is not None: |
| # Store the variable object so that the original variable can be |
| # accessed to generate functions that are compatible with SavedModel. |
| cached_value._cached_variable = weakref.ref(self) # pylint: disable=protected-access |
| |
| if not context.executing_eagerly(): |
| # Eager variables are only added to collections if they are part of an |
| # eager variable store (otherwise in an interactive session they would |
| # hog memory and cause OOM). This is done in ops/variable_scope.py. |
| ops.add_to_collections(collections, self) |
| elif ops.GraphKeys.GLOBAL_STEP in collections: |
| ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, self) |
| initial_value = initial_value if self._in_graph_mode else None |
| super(ResourceVariable, self).__init__( |
| trainable=trainable, |
| shape=shape, |
| dtype=dtype, |
| handle=handle, |
| synchronization=synchronization, |
| constraint=constraint, |
| aggregation=aggregation, |
| distribute_strategy=distribute_strategy, |
| name=name, |
| unique_id=unique_id, |
| handle_name=handle_name, |
| graph_element=graph_element, |
| initial_value=initial_value, |
| initializer_op=initializer_op, |
| is_initialized_op=is_initialized_op, |
| cached_value=cached_value, |
| caching_device=caching_device) |
| |
| def _init_from_proto(self, variable_def, import_scope=None): |
| """Initializes from `VariableDef` proto.""" |
| # Note that init_from_proto is currently not supported in Eager mode. |
| assert not context.executing_eagerly() |
| self._in_graph_mode = True |
| assert isinstance(variable_def, variable_pb2.VariableDef) |
| if not variable_def.is_resource: |
| raise ValueError("Trying to restore Variable as ResourceVariable.") |
| |
| # Create from variable_def. |
| g = ops.get_default_graph() |
| self._handle = g.as_graph_element( |
| ops.prepend_name_scope( |
| variable_def.variable_name, import_scope=import_scope)) |
| self._shape = tensor_shape.TensorShape(self._handle.op.get_attr("shape")) |
| self._handle_name = self._handle.name |
| self._unique_id = self._handle_name |
| self._initializer_op = g.as_graph_element( |
| ops.prepend_name_scope( |
| variable_def.initializer_name, import_scope=import_scope)) |
| # Check whether initial_value_name exists for backwards compatibility. |
| if (hasattr(variable_def, "initial_value_name") and |
| variable_def.initial_value_name): |
| self._initial_value = g.as_graph_element( |
| ops.prepend_name_scope( |
| variable_def.initial_value_name, import_scope=import_scope)) |
| else: |
| self._initial_value = None |
| synchronization, aggregation, trainable = ( |
| variables.validate_synchronization_aggregation_trainable( |
| variable_def.synchronization, variable_def.aggregation, |
| variable_def.trainable, variable_def.variable_name)) |
| self._synchronization = synchronization |
| self._aggregation = aggregation |
| self._trainable = trainable |
| if variable_def.snapshot_name: |
| snapshot = g.as_graph_element( |
| ops.prepend_name_scope( |
| variable_def.snapshot_name, import_scope=import_scope)) |
| if snapshot.op.type != "ReadVariableOp": |
| self._cached_value = snapshot |
| else: |
| self._cached_value = None |
| while snapshot.op.type != "ReadVariableOp": |
| snapshot = snapshot.op.inputs[0] |
| self._graph_element = snapshot |
| else: |
| self._cached_value = None |
| # Legacy case for protos without the snapshot name; assume it's the |
| # following. |
| self._graph_element = g.get_tensor_by_name(self._handle.op.name + |
| "/Read/ReadVariableOp:0") |
| if variable_def.HasField("save_slice_info_def"): |
| self._save_slice_info = variables.Variable.SaveSliceInfo( |
| save_slice_info_def=variable_def.save_slice_info_def, |
| import_scope=import_scope) |
| else: |
| self._save_slice_info = None |
| self._caching_device = None |
| self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype")) |
| self._constraint = None |
| |
| |
| class UninitializedVariable(BaseResourceVariable): |
| """A variable with no initializer.""" |
| |
| def __init__( # pylint: disable=super-init-not-called |
| self, |
| trainable=None, |
| caching_device=None, |
| name=None, |
| shape=None, |
| dtype=None, |
| constraint=None, |
| synchronization=None, |
| aggregation=None, |
| extra_handle_data=None, |
| distribute_strategy=None, |
| **unused_kwargs): |
| """Creates the variable handle. |
| |
| Args: |
| trainable: If `True`, GradientTapes automatically watch uses of this |
| Variable. |
| caching_device: Optional device string or function describing where the |
| Variable should be cached for reading. Defaults to the Variable's |
| device. If not `None`, caches on another device. Typical use is to |
| cache on the device where the Ops using the Variable reside, to |
| deduplicate copying through `Switch` and other conditional statements. |
| name: Optional name for the variable. Defaults to `'Variable'` and gets |
| uniquified automatically. |
| shape: The variable's shape. |
| dtype: The variable's dtype. |
| constraint: An optional projection function to be applied to the variable |
| after being updated by an `Optimizer` (e.g. used to implement norm |
| constraints or value constraints for layer weights). The function must |
| take as input the unprojected Tensor representing the value of the |
| variable and return the Tensor for the projected value (which must have |
| the same shape). Constraints are not safe to use when doing asynchronous |
| distributed training. |
synchronization: Indicates when a distributed variable will be
| aggregated. Accepted values are constants defined in the class |
| `tf.VariableSynchronization`. By default the synchronization is set to |
| `AUTO` and the current `DistributionStrategy` chooses when to |
| synchronize. |
| aggregation: Indicates how a distributed variable will be aggregated. |
| Accepted values are constants defined in the class |
| `tf.VariableAggregation`. |
| extra_handle_data: Optional, another resource handle or Tensor with handle |
| data to merge with `shape` and `dtype`. |
| distribute_strategy: The tf.distribute.Strategy this variable is being |
| created inside of. |
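
For illustration only (this is an internal class, so the snippet below is a
hedged sketch rather than a supported public workflow, and assumes eager
execution): an uninitialized variable carries a handle, shape, and dtype but
no initializer, and must be assigned, e.g. by a restore op, before any read.

```python
v = UninitializedVariable(shape=[2], dtype=tf.float32, name="restored")
v.assign([1.0, 2.0])  # Must run before any read of `v`.
print(v.numpy())      # [1. 2.]
```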
| """ |
| with ops.init_scope(): |
| self._in_graph_mode = not context.executing_eagerly() |
| with ops.init_scope(): |
| with ops.name_scope(name, "Variable", skip_on_eager=False) as name: |
| handle_name = ops.name_from_scope_name(name) |
| if self._in_graph_mode: |
| shared_name = handle_name |
| unique_id = shared_name |
| else: |
| unique_id = "%s_%d" % (handle_name, ops.uid()) |
| shared_name = None # Never shared |
| handle = _variable_handle_from_shape_and_dtype( |
| shape=shape, |
| dtype=dtype, |
| shared_name=shared_name, |
| name=name, |
| graph_mode=self._in_graph_mode, |
| initial_value=extra_handle_data) |
| if not context.executing_eagerly(): |
| with ops.name_scope("Read"): |
| # Manually assign reads to the handle's device to avoid log |
| # messages. |
| with ops.device(handle.device): |
| value = gen_resource_variable_ops.read_variable_op(handle, dtype) |
| _maybe_set_handle_data(dtype, handle, value) |
| graph_element = value |
| ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES, self) |
| # Do *not* add to TRAINABLE_VARIABLES here, even if self._trainable, |
| # because retraining or frozen use of imported SavedModels is |
| # controlled at higher levels of model building. |
| else: |
| graph_element = None |
| super(UninitializedVariable, self).__init__( |
| distribute_strategy=distribute_strategy, |
| shape=shape, |
| dtype=dtype, |
| unique_id=unique_id, |
| handle_name=handle_name, |
| constraint=constraint, |
| handle=handle, |
| graph_element=graph_element, |
| trainable=trainable, |
| synchronization=synchronization, |
| aggregation=aggregation) |
| |
| |
| _pywrap_utils.RegisterType("ResourceVariable", ResourceVariable) |
| math_ops._resource_variable_type = ResourceVariable # pylint: disable=protected-access |
| |
| |
| def _dense_var_to_tensor(var, dtype=None, name=None, as_ref=False): |
| return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access |
| |
| |
| # Register a conversion function which reads the value of the variable, |
| # allowing instances of the class to be used as tensors. |
| ops.register_tensor_conversion_function(BaseResourceVariable, |
| _dense_var_to_tensor) |
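# For illustration (a hedged sketch in comments, not executed at import time):
# this registration is what lets ops that call `ops.convert_to_tensor` accept
# a resource variable directly, implicitly reading its current value, e.g.
#   v = tf.Variable([[1.0, 2.0]])
#   y = tf.matmul(v, tf.ones([2, 1]))  # `v` is converted via a read.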
| |
| |
| class _UnreadVariable(BaseResourceVariable): |
| """Represents a future for a read of a variable. |
| |
| Pretends to be the tensor if anyone looks. |
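
As a rough sketch of where instances come from (internal behavior, subject to
change): `assign` with `read_value=True` returns one of these, so that the
read is deferred until the result is actually used.

```python
v = tf.Variable(1.0)
w = v.assign(2.0)   # `w` behaves like a tensor holding the assigned value.
print(w.numpy())    # 2.0, read only after the assignment has run.
```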
| """ |
| |
| def __init__(self, handle, dtype, shape, in_graph_mode, deleter, parent_op, |
| unique_id): |
| if isinstance(handle, ops.EagerTensor): |
| handle_name = "" |
| else: |
| handle_name = handle.name |
| # Only create a graph_element if we're in session.run-land as only |
| # session.run requires a preexisting tensor to evaluate. Otherwise we can |
| # avoid accidentally reading the variable. |
| if context.executing_eagerly() or ops.inside_function(): |
| graph_element = None |
| else: |
| with ops.control_dependencies([parent_op]): |
| graph_element = gen_resource_variable_ops.read_variable_op( |
| handle, dtype) |
| _maybe_set_handle_data(dtype, handle, graph_element) |
| super(_UnreadVariable, self).__init__( |
| handle=handle, |
| shape=shape, |
| handle_name=handle_name, |
| unique_id=unique_id, |
| dtype=dtype, |
| handle_deleter=deleter, |
| graph_element=graph_element) |
| self._parent_op = parent_op |
| |
| @property |
| def name(self): |
| if self._in_graph_mode: |
| return self._parent_op.name |
| else: |
| return "UnreadVariable" |
| |
| def value(self): |
| return self._read_variable_op() |
| |
| def read_value(self): |
| return self._read_variable_op() |
| |
| def _read_variable_op(self): |
| with ops.control_dependencies([self._parent_op]): |
| result = gen_resource_variable_ops.read_variable_op( |
| self._handle, self._dtype) |
| _maybe_set_handle_data(self._dtype, self._handle, result) |
| return result |
| |
| def assign_sub(self, delta, use_locking=None, name=None, read_value=True): |
| with ops.control_dependencies([self._parent_op]): |
| return super(_UnreadVariable, self).assign_sub(delta, use_locking, name, |
| read_value) |
| |
| def assign_add(self, delta, use_locking=None, name=None, read_value=True): |
| with ops.control_dependencies([self._parent_op]): |
| return super(_UnreadVariable, self).assign_add(delta, use_locking, name, |
| read_value) |
| |
| def assign(self, value, use_locking=None, name=None, read_value=True): |
| with ops.control_dependencies([self._parent_op]): |
| return super(_UnreadVariable, self).assign(value, use_locking, name, |
| read_value) |
| |
| def scatter_sub(self, sparse_delta, use_locking=False, name=None): |
| with ops.control_dependencies([self._parent_op]): |
| return super(_UnreadVariable, self).scatter_sub(sparse_delta, use_locking, |
| name) |
| |
| def scatter_add(self, sparse_delta, use_locking=False, name=None): |
| with ops.control_dependencies([self._parent_op]): |
| return super(_UnreadVariable, self).scatter_add(sparse_delta, use_locking, |
| name) |
| |
| def scatter_max(self, sparse_delta, use_locking=False, name=None): |
| with ops.control_dependencies([self._parent_op]): |
| return super(_UnreadVariable, self).scatter_max(sparse_delta, use_locking, |
| name) |
| |
| def scatter_min(self, sparse_delta, use_locking=False, name=None): |
| with ops.control_dependencies([self._parent_op]): |
| return super(_UnreadVariable, self).scatter_min(sparse_delta, use_locking, |
| name) |
| |
| def scatter_mul(self, sparse_delta, use_locking=False, name=None): |
| with ops.control_dependencies([self._parent_op]): |
| return super(_UnreadVariable, self).scatter_mul(sparse_delta, use_locking, |
| name) |
| |
| def scatter_div(self, sparse_delta, use_locking=False, name=None): |
| with ops.control_dependencies([self._parent_op]): |
| return super(_UnreadVariable, self).scatter_div(sparse_delta, use_locking, |
| name) |
| |
| def scatter_update(self, sparse_delta, use_locking=False, name=None): |
| with ops.control_dependencies([self._parent_op]): |
| return super(_UnreadVariable, |
| self).scatter_update(sparse_delta, use_locking, name) |
| |
| def batch_scatter_update(self, sparse_delta, use_locking=False, name=None): |
| with ops.control_dependencies([self._parent_op]): |
| return super(_UnreadVariable, |
| self).batch_scatter_update(sparse_delta, use_locking, name) |
| |
| def scatter_nd_sub(self, indices, updates, name=None): |
| with ops.control_dependencies([self._parent_op]): |
| return super(_UnreadVariable, self).scatter_nd_sub(indices, updates, name) |
| |
| def scatter_nd_add(self, indices, updates, name=None): |
| with ops.control_dependencies([self._parent_op]): |
| return super(_UnreadVariable, self).scatter_nd_add(indices, updates, name) |
| |
| def scatter_nd_update(self, indices, updates, name=None): |
| with ops.control_dependencies([self._parent_op]): |
| return super(_UnreadVariable, |
| self).scatter_nd_update(indices, updates, name) |
| |
| def scatter_nd_max(self, indices, updates, name=None): |
| with ops.control_dependencies([self._parent_op]): |
| return super(_UnreadVariable, self).scatter_nd_max(indices, updates, name) |
| |
| def scatter_nd_min(self, indices, updates, name=None): |
| with ops.control_dependencies([self._parent_op]): |
| return super(_UnreadVariable, self).scatter_nd_min(indices, updates, name) |
| |
| @property |
| def op(self): |
| """The op for this variable.""" |
| return self._parent_op |
| |
| |
| @ops.RegisterGradient("ReadVariableOp") |
| def _ReadGrad(_, grad): |
| """Gradient for read op.""" |
| return grad |
| |
| |
| def variable_shape(handle, out_type=dtypes.int32): |
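"""Returns the shape of the resource variable pointed to by `handle`.

Uses the statically known shape recorded in the handle's shape-and-type data
when it is fully defined; otherwise falls back to the `VariableShape` op.
"""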
| if getattr(handle, "_handle_data", |
| None) is None or not handle._handle_data.is_set: # pylint: disable=protected-access |
| return gen_resource_variable_ops.variable_shape(handle, out_type=out_type) |
| shape_proto = handle._handle_data.shape_and_type[0].shape # pylint: disable=protected-access |
| if shape_proto.unknown_rank or any(x.size == -1 for x in shape_proto.dim): |
| return gen_resource_variable_ops.variable_shape(handle, out_type=out_type) |
| return constant_op.constant([x.size for x in shape_proto.dim], dtype=out_type) |
| |
| |
| @ops.RegisterGradient("ResourceGather") |
| def _GatherGrad(op, grad): |
| """Gradient for gather op.""" |
| # Build appropriately shaped IndexedSlices |
| handle = op.inputs[0] |
| indices = op.inputs[1] |
| params_shape = variable_shape(handle) |
| size = array_ops.expand_dims(array_ops.size(indices), 0) |
| values_shape = array_ops.concat([size, params_shape[1:]], 0) |
| values = array_ops.reshape(grad, values_shape) |
| indices = array_ops.reshape(indices, size) |
| return (ops.IndexedSlices(values, indices, params_shape), None) |
| |
| |
| def _to_proto_fn(v, export_scope=None): |
| """Converts Variable and ResourceVariable to VariableDef for collections.""" |
| return v.to_proto(export_scope=export_scope) |
| |
| |
| def _from_proto_fn(v, import_scope=None): |
| """Creates Variable or ResourceVariable from VariableDef as needed.""" |
| if v.is_resource: |
| return ResourceVariable.from_proto(v, import_scope=import_scope) |
| return variables.Variable.from_proto(v, import_scope=import_scope) |
| |
| |
| ops.register_proto_function( |
| ops.GraphKeys.GLOBAL_VARIABLES, |
| proto_type=variable_pb2.VariableDef, |
| to_proto=_to_proto_fn, |
| from_proto=_from_proto_fn) |
| ops.register_proto_function( |
| ops.GraphKeys.TRAINABLE_VARIABLES, |
| proto_type=variable_pb2.VariableDef, |
| to_proto=_to_proto_fn, |
| from_proto=_from_proto_fn) |
| ops.register_proto_function( |
| ops.GraphKeys.MOVING_AVERAGE_VARIABLES, |
| proto_type=variable_pb2.VariableDef, |
| to_proto=_to_proto_fn, |
| from_proto=_from_proto_fn) |
| ops.register_proto_function( |
| ops.GraphKeys.LOCAL_VARIABLES, |
| proto_type=variable_pb2.VariableDef, |
| to_proto=_to_proto_fn, |
| from_proto=_from_proto_fn) |
| ops.register_proto_function( |
| ops.GraphKeys.MODEL_VARIABLES, |
| proto_type=variable_pb2.VariableDef, |
| to_proto=_to_proto_fn, |
| from_proto=_from_proto_fn) |
| ops.register_proto_function( |
| ops.GraphKeys.GLOBAL_STEP, |
| proto_type=variable_pb2.VariableDef, |
| to_proto=_to_proto_fn, |
| from_proto=_from_proto_fn) |
| ops.register_proto_function( |
| ops.GraphKeys.METRIC_VARIABLES, |
| proto_type=variable_pb2.VariableDef, |
| to_proto=_to_proto_fn, |
| from_proto=_from_proto_fn) |
| |
| |
| def is_resource_variable(var): |
| """"Returns True if `var` is to be considered a ResourceVariable.""" |
| return isinstance(var, BaseResourceVariable) or hasattr( |
| var, "_should_act_as_resource_variable") |
| |
| |
| def copy_to_graph_uninitialized(var): |
| """Copies an existing variable to a new graph, with no initializer.""" |
| # Like ResourceVariable.__deepcopy__, but does not set an initializer on the |
| # new variable. |
| # pylint: disable=protected-access |
| new_variable = UninitializedVariable( |
| trainable=var.trainable, |
| constraint=var._constraint, |
| shape=var.shape, |
| dtype=var.dtype, |
| name=var._shared_name, |
| synchronization=var.synchronization, |
| aggregation=var.aggregation, |
| extra_handle_data=var.handle) |
| new_variable._maybe_initialize_trackable() |
| # pylint: enable=protected-access |
| return new_variable |
| |
| |
| ops.NotDifferentiable("Assert") |
| ops.NotDifferentiable("VarIsInitializedOp") |
| ops.NotDifferentiable("VariableShape") |
| |
| |
| class VariableSpec(tensor_spec.DenseSpec): |
| """Describes a tf.Variable.""" |
| |
| __slots__ = [] |
| |
| value_type = property(lambda self: BaseResourceVariable) |
| |
| def _to_components(self, value): |
| raise NotImplementedError |
| |
| def _from_components(self, components): |
| raise NotImplementedError |
| |
| def _from_compatible_tensor_list(self, tensor_list): |
| assert len(tensor_list) == 1 |
| return tensor_list[0] |
| |
| |
| _pywrap_utils.RegisterType("VariableSpec", VariableSpec) |