| # Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| |
| """Class to hold a library of OpDefs and use it to create Brain operations.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import six |
| |
| from google.protobuf import text_format |
| from tensorflow.core.framework import attr_value_pb2 |
| from tensorflow.core.framework import tensor_pb2 |
| from tensorflow.core.framework import tensor_shape_pb2 |
| from tensorflow.core.framework import types_pb2 |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import op_callbacks |
| from tensorflow.python.framework import op_def_registry |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import tensor_shape |
| from tensorflow.python.platform import tf_logging as logging |
| from tensorflow.python.util import _pywrap_utils |
| from tensorflow.python.util import compat |
| from tensorflow.python.util import tf_contextlib |
| |
| |
def _Attr(op_def, name):
  """Looks up the attr definition called `name` in `op_def`.

  Args:
    op_def: Object whose `attr` entries are searched; its `name` is used in
      the error message.
    name: Name of the attr to find.

  Returns:
    The first entry of `op_def.attr` whose `name` equals `name`.

  Raises:
    TypeError: If no entry of `op_def.attr` has that name.
  """
  found = next(
      (candidate for candidate in op_def.attr if candidate.name == name), None)
  if found is None:
    raise TypeError("Inconsistent OpDef for '%s', missing attr '%s'" %
                    (op_def.name, name))
  return found
| |
| |
def _AttrValue(attr_protos, name):
  """Returns `attr_protos[name]`, or raises if `name` is not a key.

  Args:
    attr_protos: Mapping from attr name to its value.
    name: Name of the attr to fetch.

  Returns:
    The value stored under `name`.

  Raises:
    TypeError: If `name` is not in `attr_protos`.
  """
  if name not in attr_protos:
    raise TypeError("Inconsistent OpDef, missing attr '%s' from '%s'." %
                    (name, attr_protos))
  return attr_protos[name]
| |
| |
def _SatisfiesTypeConstraint(dtype, attr_def, param_name):
  """Raises if `dtype` is not among the types `attr_def` allows.

  Nothing is checked when `attr_def` has no `allowed_values` field.

  Args:
    dtype: The type to check against `attr_def.allowed_values.list.type`.
    attr_def: Attr definition that may carry an `allowed_values` field.
    param_name: Parameter name used in the error message.

  Raises:
    TypeError: If `allowed_values` is set and `dtype` is not listed in it.
  """
  if not attr_def.HasField("allowed_values"):
    return
  permitted = attr_def.allowed_values.list.type
  if dtype in permitted:
    return
  offending_name = dtypes.as_dtype(dtype).name
  permitted_names = ", ".join(dtypes.as_dtype(x).name for x in permitted)
  raise TypeError(
      "Value passed to parameter '%s' has DataType %s not in list of "
      "allowed values: %s" %
      (param_name, offending_name, permitted_names))
| |
| |
def _SatisfiesLengthConstraint(length, attr_def, param_name, op_type_name):
  """Raises if `length` is below the minimum declared by `attr_def`.

  Args:
    length: Length of the list that was passed.
    attr_def: Attr definition providing `has_minimum` and `minimum`.
    param_name: Attr name used in the error message.
    op_type_name: Op name used in the error message.

  Raises:
    ValueError: If `attr_def.has_minimum` is set and `length` is smaller than
      `attr_def.minimum`.
  """
  too_short = attr_def.has_minimum and length < attr_def.minimum
  if not too_short:
    return
  raise ValueError("Attr '%s' of '%s' Op passed list of length %d "
                   "less than minimum %d." %
                   (param_name, op_type_name, length, attr_def.minimum))
| |
| |
def _SatisfiesAllowedStringsConstraint(value, attr_def, arg_name, op_type_name):
  """Raises if `value` is not one of the strings `attr_def` allows.

  Args:
    value: The string to look up in `attr_def.allowed_values.list.s`.
    attr_def: Attr definition carrying the allowed strings.
    arg_name: Attr name used in the error message.
    op_type_name: Op name used in the error message.

  Raises:
    ValueError: If `value` is not in the allowed strings.
  """
  if value in attr_def.allowed_values.list.s:
    return
  shown_value = compat.as_text(value)
  shown_choices = '", "'.join(
      map(compat.as_text, attr_def.allowed_values.list.s))
  raise ValueError(
      "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
      (arg_name, op_type_name, shown_value, shown_choices))
| |
| |
def _SatisfiesIntMinimumConstraint(value, attr_def, arg_name, op_type_name):
  """Raises if the int `value` is below `attr_def.minimum`.

  Args:
    value: The int to check.
    attr_def: Attr definition providing `minimum`.
    arg_name: Attr name used in the error message.
    op_type_name: Op name used in the error message.

  Raises:
    ValueError: If `value` is smaller than `attr_def.minimum`.
  """
  below_minimum = value < attr_def.minimum
  if below_minimum:
    details = (arg_name, op_type_name, value, attr_def.minimum)
    raise ValueError(
        "Attr '%s' of '%s' Op passed %d less than minimum %d." % details)
| |
| |
def _IsListParameter(arg):
  """Returns True if `arg` sets `number_attr` or `type_list_attr`."""
  return bool(arg.number_attr or arg.type_list_attr)
| |
| |
def _NumTypeFields(arg):
  """Counts how many of `type`, `type_attr`, `type_list_attr` `arg` sets.

  `type` counts when it differs from `DT_INVALID`; the other two count when
  they are truthy.
  """
  populated = (
      arg.type != types_pb2.DT_INVALID,
      bool(arg.type_attr),
      bool(arg.type_list_attr),
  )
  return sum(populated)
| |
| |
def _IsListValue(v):
  """Returns True if `v` is a `list` or a `tuple` (including subclasses)."""
  return isinstance(v, list) or isinstance(v, tuple)
| |
| |
def _Flatten(l):
  """Flattens one level of nesting: [1, 2, [3, 4], [5]] -> [1, 2, 3, 4, 5].

  Entries for which `_IsListValue` is true are spliced into the result;
  every other entry is kept as a single item. Deeper nesting is left as is.
  """
  flattened = []
  for entry in l:
    if _IsListValue(entry):
      flattened.extend(entry)
    else:
      flattened.append(entry)
  return flattened
| |
| |
def _Restructure(l, structure):
  """Groups the elements of the flat list `l` as described by `structure`.

  Each entry of `structure` consumes elements from `l`, in order:
    * `None` takes the next element as is.
    * A non-negative integer N takes the next N elements as a nested list
      (a slice of `l`).

  When `structure` has exactly one entry, the single piece it describes is
  returned directly; otherwise the pieces are returned as a tuple.

  For example:

    _Restructure(["foo", "bar", "baz", "qux"], [None, 2, None])
      -> ("foo", ["bar", "baz"], "qux")

    _Restructure(["foo"], [None]) -> "foo"

    _Restructure(["foo"], [1]) -> ["foo"]

    _Restructure([], [0]) -> []

  Args:
    l: A list.
    structure: A list whose elements are either `None` or a non-negative
      integer.

  Returns:
    The elements of `l`, restructured according to `structure`: the single
    piece implied by `structure[0]` if `structure` has length 1, otherwise a
    tuple with one piece per entry of `structure`.
  """
  pieces = []
  cursor = 0
  for width in structure:
    if width is None:
      pieces.append(l[cursor])
      cursor += 1
      continue
    stop = cursor + width
    pieces.append(l[cursor:stop])
    cursor = stop

  return pieces[0] if len(pieces) == 1 else tuple(pieces)
| |
| |
def _MakeFloat(v, arg_name):
  """Returns `float(v)` if `v` is one of `compat.real_types`.

  Raises:
    TypeError: If `v` is not an instance of `compat.real_types`.
  """
  if isinstance(v, compat.real_types):
    return float(v)
  raise TypeError("Expected float for argument '%s' not %s." %
                  (arg_name, repr(v)))
| |
| |
def _MakeInt(v, arg_name):
  """Returns `int(v)`, refusing strings even when they look numeric.

  Raises:
    TypeError: If `v` is a string, or if `int(v)` raises `ValueError` or
      `TypeError`.
  """
  is_string = isinstance(v, six.string_types)
  if is_string:
    raise TypeError("Expected int for argument '%s' not %s." %
                    (arg_name, repr(v)))
  try:
    converted = int(v)
  except (ValueError, TypeError):
    raise TypeError("Expected int for argument '%s' not %s." %
                    (arg_name, repr(v)))
  else:
    return converted
| |
| |
def _MakeStr(v, arg_name):
  """Returns `v` passed through `compat.as_bytes`.

  Raises:
    TypeError: If `v` is not an instance of `compat.bytes_or_text_types`.
  """
  if isinstance(v, compat.bytes_or_text_types):
    # Unicode strings are converted to bytes here.
    return compat.as_bytes(v)
  raise TypeError("Expected string for argument '%s' not %s." %
                  (arg_name, repr(v)))
| |
| |
def _MakeBool(v, arg_name):
  """Returns `v` unchanged if it is a `bool`.

  Raises:
    TypeError: If `v` is not a `bool` (ints such as 0 and 1 are rejected).
  """
  if isinstance(v, bool):
    return v
  raise TypeError("Expected bool for argument '%s' not %s." %
                  (arg_name, repr(v)))
| |
| |
def _MakeType(v, arg_name):
  """Returns the datatype enum of the base dtype of `v`.

  `v` is converted with `dtypes.as_dtype`, and the `as_datatype_enum` of the
  result's `base_dtype` is returned.

  Raises:
    TypeError: If `dtypes.as_dtype(v)` raises `TypeError`.
  """
  try:
    base = dtypes.as_dtype(v).base_dtype
  except TypeError:
    raise TypeError("Expected DataType for argument '%s' not %s." %
                    (arg_name, repr(v)))
  return base.as_datatype_enum
| |
| |
def _MakeShape(v, arg_name):
  """Converts `v` into a TensorShapeProto.

  Args:
    v: A TensorShapeProto, a list of ints, or a tensor_shape.TensorShape.
    arg_name: String, for error messages.

  Returns:
    A TensorShapeProto. A `v` that already is a TensorShapeProto is returned
    as is, after logging a warning if any of its dimensions is named.

  Raises:
    TypeError: If `tensor_shape.as_shape(v)` raises `TypeError`.
    ValueError: If `tensor_shape.as_shape(v)` raises `ValueError`.
  """
  if not isinstance(v, tensor_shape_pb2.TensorShapeProto):
    try:
      return tensor_shape.as_shape(v).as_proto()
    except TypeError as e:
      raise TypeError(
          "Error converting %s to a TensorShape: %s" % (arg_name, e))
    except ValueError as e:
      raise ValueError(
          "Error converting %s to a TensorShape: %s" % (arg_name, e))

  # Warn once, however many dimensions carry a name.
  if any(d.name for d in v.dim):
    logging.warning("Warning: TensorShapeProto with a named dimension: %s",
                    str(v))
  return v
| |
| |
def _MakeTensor(v, arg_name):
  """Returns `v` if it is a TensorProto; no conversion is attempted.

  Raises:
    TypeError: If `v` is not a `tensor_pb2.TensorProto`.
  """
  if not isinstance(v, tensor_pb2.TensorProto):
    raise TypeError(
        "Don't know how to convert %s to a TensorProto for argument '%s'" %
        (repr(v), arg_name))
  return v
| |
| |
def _MakeFunc(v, arg_name):
  """Converts `v` into a `NameAttrList` describing a function.

  Accepted values for `v`:
    * a `NameAttrList`, returned as is;
    * a bytes/text name, wrapped as `NameAttrList(name=v)`;
    * an object with an `add_to_graph` method, which is first added to the
      default graph and then described by its `_as_name_attr_list` attribute
      when it has one, or by `NameAttrList(name=v.name)` otherwise.

  Raises:
    TypeError: If `v` is none of the above.
  """
  if isinstance(v, attr_value_pb2.NameAttrList):
    return v
  if isinstance(v, compat.bytes_or_text_types):
    return attr_value_pb2.NameAttrList(name=v)
  if not hasattr(v, "add_to_graph"):
    raise TypeError("Don't know how to convert {} to a func for "
                    "argument {}".format(v, arg_name))

  v.add_to_graph(ops.get_default_graph())
  if hasattr(v, "_as_name_attr_list"):
    return v._as_name_attr_list  # pylint: disable=protected-access
  return attr_value_pb2.NameAttrList(name=v.name)
| |
| |
| # pylint: disable=g-doc-return-or-yield |
@tf_contextlib.contextmanager
def _MaybeColocateWith(inputs):
  """Context manager that colocates with every element of `inputs`, if any.

  Args:
    inputs: A list of `Tensor` or `Operation` objects. May be empty, in which
      case the context manager does nothing.

  Returns:
    A context manager.
  """
  if inputs:
    # `ops.colocate_with()` takes a single op or tensor, so the first element
    # is handled here and the remaining ones by a recursive context manager.
    with ops.colocate_with(inputs[0]):
      with _MaybeColocateWith(inputs[1:]):
        yield
  else:
    yield
| # pylint: enable=g-doc-return-or-yield |
| |
| |
def apply_op(op_type_name, name=None, **keywords):  # pylint: disable=invalid-name
  """Adds a node that invokes a registered Op to a graph.

  Example usage:
     # input1 and input2 can be Tensors or anything ops.convert_to_tensor()
     # will convert to a Tensor.
     op_def_library.apply_op("op", input1=input1, input2=input2)
     # Can specify a node name.
     op_def_library.apply_op("op", input1=input1, name="node_name")
     # Must use keyword arguments, with the names specified in the OpDef.
     op_def_library.apply_op("op", input_name=input, attr_name=attr)

  Every attr must be either inferred from an input or passed explicitly; an
  inferred attr must not also be passed. For an attr that has a default value
  in the Op's OpDef, passing None selects that default.

  Args:
    op_type_name: string. Must match the name field of a registered Op.
    name: string. Optional name of the created op.
    **keywords: input Tensor and attr arguments specified by name,
      and optional parameters to pass when constructing the Operation.

  Returns:
    The Tensor(s) representing the output of the operation, or the Operation
    itself if there are no outputs.

  Raises:
    RuntimeError: On some errors.
    TypeError: On some errors.
    ValueError: On some errors.
  """
  output_structure, is_stateful, op, outputs = _apply_op_helper(
      op_type_name, name, **keywords)
  if not output_structure:
    return op

  res = _Restructure(ops.convert_n_to_tensor(outputs), output_structure)
  # A stateful op whose only output is an empty list is returned as the op.
  empty_list_result = isinstance(res, list) and not res
  if empty_list_result and is_stateful:
    return op
  return res
| |
| |
def _apply_op_helper(op_type_name, name=None, **keywords):  # pylint: disable=invalid-name
  """Implementation of apply_op; builds the op and returns its pieces.

  Looks up the OpDef registered under `op_type_name`, converts the keyword
  arguments that match the OpDef's inputs to tensors (inferring type and
  number attrs from them), converts the remaining attrs to `AttrValue`
  protos, and creates the op in the graph determined from the inputs.

  Args:
    op_type_name: string. Name used to look up the OpDef in
      `op_def_registry`.
    name: string. Optional name of the created op; defaults to
      `op_type_name`.
    **keywords: inputs and attrs keyed by the names in the OpDef. A name with
      a trailing "_" is also accepted (for names that clash with Python
      keywords or built-ins).

  Returns:
    A tuple `(output_structure, is_stateful, op, outputs)`:
      output_structure: A list with one entry per output arg of the OpDef:
        `None` for a single tensor, or an int length for a list output.
      is_stateful: The `is_stateful` field of the OpDef.
      op: The created op.
      outputs: `op.outputs`, or the value returned by the op callbacks when
        they are invoked and return something other than None.

  Raises:
    RuntimeError: If `op_type_name` is not registered, or the graph cannot be
      determined from the inputs.
    NotImplementedError: If the op is deprecated as of the graph's producer
      version.
    TypeError: On missing, unexpected or wrongly typed inputs and attrs.
    ValueError: On list inputs/attrs of invalid length and on some tensor
      conversion failures.
  """
  op_def = op_def_registry.get(op_type_name)
  if op_def is None:
    raise RuntimeError("Unrecognized Op name " + op_type_name)

  # Determine the graph context.
  try:
    # Need to flatten all the arguments into a list.
    # pylint: disable=protected-access
    g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
    # pylint: enable=protected-access
  except AssertionError as e:
    # Format the exception itself: exceptions have no `message` attribute on
    # Python 3, so reading it would replace this error with an AttributeError.
    raise RuntimeError(
        "Cannot determine graph for Op '%s' due to: %s"
        % (op_type_name, e))

  # Default name if not specified.
  if name is None:
    name = op_type_name

  # Check for deprecation
  deprecation_version = op_def.deprecation.version
  if deprecation_version:
    producer = g.graph_def_versions.producer
    if producer >= deprecation_version:
      raise NotImplementedError(
          ("Op %s is not available in GraphDef version %d. "
           "It has been removed in version %d. %s.") %
          (op_type_name, producer, deprecation_version,
           op_def.deprecation.explanation))

  # Fill in the list of default types for all "type" attrs.  This
  # will be used to choose a preferred dtype to convert to in the
  # absence of input type information.
  #
  # TODO(b/31302892): Currently the defaults don't work in the right
  # way if you have two inputs, one of whose type resolution depends
  # on the other.  Handling this will require restructuring this code
  # significantly.
  default_type_attr_map = {}
  allowed_list_attr_map = {}
  for attr_def in op_def.attr:
    if attr_def.type != "type":
      continue
    key = attr_def.name
    if attr_def.HasField("default_value"):
      default_type_attr_map[key] = dtypes.as_dtype(
          attr_def.default_value.type)
    if attr_def.HasField("allowed_values"):
      allowed_list_attr_map[key] = attr_def.allowed_values.list.type

  # Requires that op_def has passed validation (using the C++
  # ValidateOpDef() from ../framework/op_def_util.h).
  attrs = {}
  inputs = []
  input_types = []
  with g.as_default(), ops.name_scope(name) as scope:

    # Perform input type inference
    inferred_from = {}
    for input_arg in op_def.input_arg:
      input_name = input_arg.name
      if input_name in keywords:
        values = keywords.pop(input_name)
      elif input_name + "_" in keywords:
        # Handle the case where the name is a keyword or built-in
        # for Python so we use the name + _ instead.
        input_name += "_"
        values = keywords.pop(input_name)
      else:
        raise TypeError("No argument for input " + input_name)

      # Goals:
      # * Convert values to Tensors if it contains constants.
      # * Verify that values is a list if that matches the input_arg's
      #   type.
      # * If the input_arg's type is determined by attrs, either set
      #   those attrs and validate those attr values are legal (if
      #   they have not yet been set) or validate the input matches
      #   the type indicated by the attrs (if they have already been
      #   inferred via an earlier input).
      # * If the input_arg has an explicit type, make sure the input
      #   conforms.

      if _IsListParameter(input_arg):
        if not _IsListValue(values):
          raise TypeError(
              "Expected list for '%s' argument to '%s' Op, not %s." %
              (input_name, op_type_name, values))
        # In cases where we expect all elements of the list to have the
        # same dtype, try to cast non-Tensor elements to that type.
        dtype = None
        default_dtype = None
        if input_arg.type != types_pb2.DT_INVALID:
          dtype = input_arg.type
        elif input_arg.number_attr:
          if input_arg.type_attr in attrs:
            dtype = attrs[input_arg.type_attr]
          else:
            for t in values:
              if isinstance(t, ops.Tensor):
                dtype = t.dtype
                break

          # dtype still not found, prefer using the default dtype
          # from the attr.
          if dtype is None and input_arg.type_attr in default_type_attr_map:
            default_dtype = default_type_attr_map[input_arg.type_attr]

        try:
          if not input_arg.is_ref and dtype:
            dtype = dtypes.as_dtype(dtype).base_dtype
          values = ops.internal_convert_n_to_tensor(
              values,
              name=input_arg.name,
              dtype=dtype if dtype else None,
              preferred_dtype=default_dtype,
              as_ref=input_arg.is_ref)
          if input_arg.number_attr and len(
              set(v.dtype.base_dtype for v in values)) > 1:
            raise TypeError()  # All types should match.
        except (TypeError, ValueError):
          # What types does the conversion function think values have?
          observed_types = []
          for value in values:
            try:
              converted_value = ops.convert_to_tensor(
                  value, as_ref=input_arg.is_ref)
              observed_types.append(converted_value.dtype.base_dtype.name)
            except (TypeError, ValueError):
              observed_types.append("<NOT CONVERTIBLE TO TENSOR>")
          observed = ", ".join(observed_types)

          prefix = (
              "Tensors in list passed to '%s' of '%s' Op have types [%s]" %
              (input_name, op_type_name, observed))
          if input_arg.number_attr:
            if input_arg.type != types_pb2.DT_INVALID:
              raise TypeError("%s that do not match expected type %s." %
                              (prefix, dtype.name))
            elif input_arg.type_attr in attrs:
              raise TypeError("%s that do not match type %s inferred from "
                              "earlier arguments." %
                              (prefix, dtype.name))
            else:
              raise TypeError("%s that don't all match." % prefix)
          else:
            raise TypeError(
                "%s that are invalid. Tensors: %s" % (prefix, values))

        types = [x.dtype for x in values]
        inputs.extend(values)
      else:
        # In cases where we have an expected type, try to convert non-Tensor
        # arguments to that type.
        dtype = None
        default_dtype = None
        allowed_list = None
        if input_arg.type != types_pb2.DT_INVALID:
          dtype = input_arg.type
        elif input_arg.type_attr in attrs:
          dtype = attrs[input_arg.type_attr]
        elif input_arg.type_attr in default_type_attr_map:
          # The dtype could not be inferred solely from the inputs,
          # so we prefer the attr's default, so code that adds a new attr
          # with a default is backwards compatible.
          default_dtype = default_type_attr_map[input_arg.type_attr]
          allowed_list = allowed_list_attr_map.get(input_arg.type_attr)

        try:
          # First see if we can get a valid dtype with the default conversion
          # and see if it matches an allowed dtypes. Some ops like ConcatV2 may
          # not list allowed dtypes, in which case we should skip this.
          if dtype is None and allowed_list:
            inferred = None
            try:
              inferred = ops.convert_to_tensor(
                  values, name=input_arg.name, as_ref=input_arg.is_ref)
            except TypeError as err:
              # When converting a python object such as a list of Dimensions, we
              # need a dtype to be specified, thus tensor conversion may throw
              # an exception which we will ignore and try again below.
              pass

            # If we did not match an allowed dtype, try again with the default
            # dtype. This could be because we have an empty tensor and thus we
            # picked the wrong type.
            if inferred is not None and inferred.dtype in allowed_list:
              values = inferred
            else:
              values = ops.convert_to_tensor(
                  values,
                  name=input_arg.name,
                  as_ref=input_arg.is_ref,
                  preferred_dtype=default_dtype)
          else:
            values = ops.convert_to_tensor(
                values,
                name=input_arg.name,
                dtype=dtype,
                as_ref=input_arg.is_ref,
                preferred_dtype=default_dtype)
        except TypeError as err:
          if dtype is None:
            raise err
          else:
            raise TypeError(
                "Expected %s passed to parameter '%s' of op '%s', got %s of "
                "type '%s' instead. Error: %s" %
                (dtypes.as_dtype(dtype).name, input_arg.name, op_type_name,
                 repr(values), type(values).__name__, err))
        except ValueError:
          # What type does convert_to_tensor think it has?
          try:
            observed = ops.convert_to_tensor(
                values, as_ref=input_arg.is_ref).dtype.name
          except ValueError as err:
            raise ValueError(
                "Tried to convert '%s' to a tensor and failed. Error: %s" %
                (input_name, err))
          prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
                    (input_name, op_type_name, observed))
          if input_arg.type != types_pb2.DT_INVALID:
            raise TypeError("%s expected type of %s." %
                            (prefix, dtypes.as_dtype(input_arg.type).name))
          else:
            # Update the maps with the default, if needed.
            k = input_arg.type_attr
            if k in default_type_attr_map:
              if k not in attrs:
                attrs[k] = default_type_attr_map[k]
                if k not in inferred_from:
                  inferred_from[k] = "Default in OpDef"

            raise TypeError(
                "%s type %s of argument '%s'." %
                (prefix, dtypes.as_dtype(attrs[input_arg.type_attr]).name,
                 inferred_from[input_arg.type_attr]))

        types = [values.dtype]
        inputs.append(values)
      base_types = [x.base_dtype for x in types]

      if input_arg.number_attr:
        # <number-attr> * <type> or <number-attr> * <type-attr>
        if input_arg.number_attr in attrs:
          if len(values) != attrs[input_arg.number_attr]:
            raise ValueError(
                "List argument '%s' to '%s' Op with length %d must match "
                "length %d of argument '%s'." %
                (input_name, op_type_name, len(values),
                 attrs[input_arg.number_attr],
                 inferred_from[input_arg.number_attr]))
        else:
          attrs[input_arg.number_attr] = len(values)
          inferred_from[input_arg.number_attr] = input_name
          num_attr = _Attr(op_def, input_arg.number_attr)
          if num_attr.has_minimum and len(values) < num_attr.minimum:
            raise ValueError(
                "List argument '%s' to '%s' Op with length %d shorter "
                "than minimum length %d." %
                (input_name, op_type_name, len(values), num_attr.minimum))
        # All tensors must have the same base type.
        if any(bt != base_types[0] for bt in base_types):
          raise TypeError(
              "All tensors passed to '%s' of '%s' Op "
              "must have the same type." %
              (input_name, op_type_name))
        if input_arg.type != types_pb2.DT_INVALID:
          # <number-attr> * <type> case
          if base_types and base_types[0] != input_arg.type:
            assert False, "Unreachable"
        elif input_arg.type_attr in attrs:
          # <number-attr> * <type-attr> case, where <type-attr> already
          # has an inferred value.
          if base_types and base_types[0] != attrs[input_arg.type_attr]:
            assert False, "Unreachable"
        else:
          # <number-attr> * <type-attr> case, where we are now setting
          # the <type-attr> based on this input
          if not base_types:
            # If it's in default_type_attr_map, then wait to set it
            # (in "process remaining attrs", below).
            if input_arg.type_attr not in default_type_attr_map:
              raise TypeError(
                  "Don't know how to infer type variable from empty input "
                  "list passed to input '%s' of '%s' Op." %
                  (input_name, op_type_name))
          else:
            attrs[input_arg.type_attr] = base_types[0]
            inferred_from[input_arg.type_attr] = input_name
            type_attr = _Attr(op_def, input_arg.type_attr)
            _SatisfiesTypeConstraint(base_types[0], type_attr,
                                     param_name=input_name)
      elif input_arg.type_attr:
        # <type-attr>
        attr_value = base_types[0]
        if input_arg.type_attr in attrs:
          if attrs[input_arg.type_attr] != attr_value:
            raise TypeError(
                "Input '%s' of '%s' Op has type %s that does not "
                "match type %s of argument '%s'." %
                (input_name, op_type_name, dtypes.as_dtype(attr_value).name,
                 dtypes.as_dtype(attrs[input_arg.type_attr]).name,
                 inferred_from[input_arg.type_attr]))
        else:
          for base_type in base_types:
            _SatisfiesTypeConstraint(base_type,
                                     _Attr(op_def, input_arg.type_attr),
                                     param_name=input_name)
          attrs[input_arg.type_attr] = attr_value
          inferred_from[input_arg.type_attr] = input_name
      elif input_arg.type_list_attr:
        # <type-list-attr>
        attr_value = base_types
        if input_arg.type_list_attr in attrs:
          if attrs[input_arg.type_list_attr] != attr_value:
            raise TypeError(
                "Input '%s' of '%s' Op has type list of %s that does not "
                "match type list %s of argument '%s'." %
                (input_name, op_type_name,
                 ", ".join(dtypes.as_dtype(x).name for x in attr_value),
                 ", ".join(dtypes.as_dtype(x).name
                           for x in attrs[input_arg.type_list_attr]),
                 inferred_from[input_arg.type_list_attr]))
        else:
          for base_type in base_types:
            _SatisfiesTypeConstraint(base_type,
                                     _Attr(op_def, input_arg.type_list_attr),
                                     param_name=input_name)
          attrs[input_arg.type_list_attr] = attr_value
          inferred_from[input_arg.type_list_attr] = input_name
      else:
        # single Tensor with specified type
        if base_types[0] != input_arg.type:
          assert False, "Unreachable"

      if input_arg.is_ref:
        if not all(x._is_ref_dtype for x in types):  # pylint: disable=protected-access
          raise TypeError(
              ("'%s' Op requires that input '%s' be a mutable tensor "
               "(e.g.: a tf.Variable)") % (op_type_name, input_name))
        input_types.extend(types)
      else:
        input_types.extend(base_types)

    # Process remaining attrs
    for attr in op_def.attr:
      # Skip attrs that have already had their values inferred
      if attr.name in attrs:
        if attr.name in keywords:
          raise TypeError(
              "Should not specify value for inferred attr '%s'." % attr.name)
        continue
      if attr.name in keywords:
        attrs[attr.name] = keywords.pop(attr.name)
      elif attr.name + "_" in keywords:
        # Attrs whose names match Python keywords have an extra '_'
        # appended, so we must check for that as well.
        attrs[attr.name] = keywords.pop(attr.name + "_")
      elif attr.name in default_type_attr_map:
        attrs[attr.name] = default_type_attr_map[attr.name]
        inferred_from.setdefault(attr.name, "Default in OpDef")
      else:
        raise TypeError("No argument for attr " + attr.name)

    # Convert attr values to AttrValue protos.
    attr_protos = {}
    for attr_def in op_def.attr:
      key = attr_def.name
      value = attrs[key]

      # An explicit None selects the default value from the OpDef.
      if attr_def.HasField("default_value") and value is None:
        attr_value = attr_value_pb2.AttrValue()
        attr_value.CopyFrom(attr_def.default_value)
        attr_protos[key] = attr_value
        continue

      attr_value = value_to_attr_value(value, attr_def.type, key)
      if attr_def.type.startswith("list("):
        _SatisfiesLengthConstraint(len(value), attr_def, key, op_type_name)
      if attr_def.HasField("allowed_values"):
        if attr_def.type == "string":
          _SatisfiesAllowedStringsConstraint(attr_value.s, attr_def, key,
                                             op_type_name)
        elif attr_def.type == "list(string)":
          for value in attr_value.list.s:
            _SatisfiesAllowedStringsConstraint(value, attr_def, key,
                                               op_type_name)
      if attr_def.has_minimum and attr_def.type == "int":
        _SatisfiesIntMinimumConstraint(attr_value.i, attr_def, key,
                                       op_type_name)
      if attr_def.type == "type":
        _SatisfiesTypeConstraint(attr_value.type, attr_def, key)
      if attr_def.type == "list(type)":
        for value in attr_value.list.type:
          _SatisfiesTypeConstraint(value, attr_def, key)

      attr_protos[key] = attr_value
    del attrs  # attrs is no longer authoritative, use attr_protos instead

    # Determine output types (possibly using attrs)
    output_structure = []
    for arg in op_def.output_arg:
      if arg.number_attr:
        n = _AttrValue(attr_protos, arg.number_attr).i
        output_structure.append(n)
      elif arg.type_attr:
        # The looked-up value is unused; the call still raises TypeError if
        # the attr is missing.
        t = _AttrValue(attr_protos, arg.type_attr)
        output_structure.append(None)
      elif arg.type_list_attr:
        t = _AttrValue(attr_protos, arg.type_list_attr)
        output_structure.append(len(t.list.type))
      else:
        output_structure.append(None)

    # Every input and attr was popped above; anything left is unknown.
    if keywords:
      raise TypeError("apply_op() got unexpected keyword arguments: " +
                      ", ".join(sorted(keywords.keys())))

    # NOTE(mrry): We add an explicit colocation constraint between
    # the newly created op and any of its reference-typed inputs.
    must_colocate_inputs = [val for arg, val in zip(op_def.input_arg, inputs)
                            if arg.is_ref]
    with _MaybeColocateWith(must_colocate_inputs):
      # Add Op to graph
      # pylint: disable=protected-access
      op = g._create_op_internal(op_type_name, inputs, dtypes=None,
                                 name=scope, input_types=input_types,
                                 attrs=attr_protos, op_def=op_def)

    # `outputs` is returned as a separate return value so that the output
    # tensors and the `op` itself can be decoupled, which the
    # `op_callbacks` need in order to function properly. See
    # framework/op_callbacks.py for more details.
    outputs = op.outputs
    # Conditionally invoke tfdbg v2's op callback(s).
    if op_callbacks.should_invoke_op_callbacks():
      callback_outputs = op_callbacks.invoke_op_callbacks(
          op.node_def.op, tuple(op.inputs), attr_protos, tuple(outputs),
          op_name=op.name, graph=g)
      if callback_outputs is not None:
        outputs = callback_outputs

    return output_structure, op_def.is_stateful, op, outputs
| |
| |
def value_to_attr_value(value, attr_type, arg_name):  # pylint: disable=invalid-name
  """Encodes a Python value as an `AttrValue` proto message.

  Args:
    value: The value to convert.
    attr_type: The value type (string) -- see the AttrValue proto definition for
      valid strings.
    arg_name: Argument name (for error messages).

  Returns:
    An AttrValue proto message that encodes `value`.

  Raises:
    TypeError: If `attr_type` names a list type but `value` is not a list or
      tuple, if `attr_type` is not recognized, or if a converter rejects the
      value.
  """
  proto = attr_value_pb2.AttrValue()

  wants_list = attr_type.startswith("list(")
  if wants_list and not _IsListValue(value):
    raise TypeError("Expected list for attr " + arg_name)

  # Element type -> (AttrValue field name, converter, field is a message).
  # Message fields cannot be assigned directly and are filled with CopyFrom.
  element_handlers = {
      "string": ("s", _MakeStr, False),
      "int": ("i", _MakeInt, False),
      "float": ("f", _MakeFloat, False),
      "bool": ("b", _MakeBool, False),
      "type": ("type", _MakeType, False),
      "shape": ("shape", _MakeShape, True),
      "tensor": ("tensor", _MakeTensor, True),
      "func": ("func", _MakeFunc, True),
  }

  if not wants_list:
    element_type = attr_type
  elif attr_type.endswith(")"):
    # Only the exact spelling "list(<element type>)" is recognized.
    element_type = attr_type[len("list("):-1]
  else:
    element_type = None
  if element_type not in element_handlers:
    raise TypeError("Unrecognized Attr type " + attr_type)

  field, convert, is_message = element_handlers[element_type]
  if wants_list:
    getattr(proto.list, field).extend([convert(x, arg_name) for x in value])
  elif is_message:
    getattr(proto, field).CopyFrom(convert(value, arg_name))
  else:
    setattr(proto, field, convert(value, arg_name))
  return proto
| |
| |
# The following symbols are used by op_def_util.cc. Each one is registered
# under the name shown, in the order listed.
for _registered_name, _registered_object in (
    ("tf.dtypes.DType", dtypes.DType),
    ("tf.dtypes.as_dtype", dtypes.as_dtype),
    ("tf.TensorShape", tensor_shape.TensorShape),
    ("tf.as_shape", tensor_shape.as_shape),
    ("tf.TensorProto", tensor_pb2.TensorProto),
    ("text_format.Parse", text_format.Parse),
    ("tf.convert_to_tensor", ops.convert_to_tensor),
):
  _pywrap_utils.RegisterPyObject(_registered_name, _registered_object)
# Do not leave the loop variables behind in the module namespace.
del _registered_name, _registered_object