| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================= |
| """xla is an experimental library that provides XLA support APIs.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import contextlib |
| |
| from six.moves import xrange # pylint: disable=redefined-builtin |
| |
| from tensorflow.compiler.jit.ops import xla_ops |
| from tensorflow.compiler.jit.ops import xla_ops_grad # pylint: disable=unused-import |
| from tensorflow.core.framework import attr_value_pb2 |
| from tensorflow.python.distribute import summary_op_util |
| from tensorflow.python.eager import context |
| from tensorflow.python.eager import def_function |
| from tensorflow.python.framework import ops |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import control_flow_ops |
| from tensorflow.python.ops import variable_scope |
| from tensorflow.python.platform import tf_logging as logging |
| from tensorflow.python.util import compat |
| from tensorflow.python.util import nest |
| from tensorflow.python.util import tf_inspect |
| from tensorflow.python.util.compat import collections_abc |
| from tensorflow.python.util.deprecation import deprecated |
| from tensorflow.python.util.tf_export import tf_export |
| |
# Name of the attribute stamped on every op built inside an xla.compile()
# computation; its value is the cluster name the op belongs to.
_XLA_COMPILE_ATTR = '_xla_compile_id'

# How many unsupported ops are listed one per line before the warning log
# switches to reporting just a count of the rest.
_MAX_WARNING_LINES = 5

# Operations that indicate some error in the user's graph. For example, an XLA
# computation should not contain any Placeholder op.
_DENYLISTED_OPS = {
    'Placeholder',
}

# XLA doesn't currently support reading of intermediate tensors, so these ops
# (summaries and Print) are not supported inside a compiled computation.
_UNSUPPORTED_OPS = {
    'AudioSummary',
    'AudioSummaryV2',
    'HistogramSummary',
    'ImageSummary',
    'MergeSummary',
    'Print',
    'ScalarSummary',
    'TensorSummary',
    'TensorSummaryV2',
}
| |
| |
@tf_export('xla.experimental.compile')
@deprecated(
    None, 'xla.experimental.compile is deprecated. Consider using '
    'tf.function(experimental_compile=True)',
    warn_once=True)
def compile(computation, inputs=None):  # pylint: disable=redefined-builtin
  """Builds an operator that compiles and runs `computation` with XLA.

  NOTE: In eager mode, `computation` will have `@tf.function` semantics: it is
  traced inside a `tf.function` and the resulting graph is executed.

  Args:
    computation: A Python function that builds a computation to apply to the
      input. If the function takes n inputs, 'inputs' should be a list of n
      tensors.

      `computation` may return a list of operations and tensors. Tensors must
      come before operations in the returned list. The return value of
      `compile` is a list of tensors corresponding to the tensors from the
      output of `computation`.

      All `Operation`s returned from `computation` will be executed when
      evaluating any of the returned output tensors.
    inputs: A list of inputs or `None` (equivalent to an empty list). Each input
      can be a nested structure containing values that are convertible to
      tensors. Note that passing an N-dimension list of compatible values will
      result in a N-dimension list of scalar tensors rather than a single Rank-N
      tensors. If you need different behavior, convert part of inputs to tensors
      with `tf.convert_to_tensor`.

  Returns:
    Same data structure as if computation(*inputs) is called directly with some
    exceptions for correctness. Exceptions include:
      1) None output: a NoOp would be returned which control-depends on
         computation.
      2) Single value output: a list containing the value would be returned.
      3) Operation-only outputs: a NoOp would be returned which
         control-depends on computation.
      TODO(b/121383831): Investigate into removing these special cases.

  Raises:
    TypeError: if `inputs` is not a sequence.
    ValueError: if `computation` returns values that are neither Operations
      nor convertible to Tensors, returns Operations before Tensors, or
      returns Operations inside a nested output structure.

  Known issues:
    When a tf.random operation is built with XLA, the implementation doesn't
      pass the user provided seed to the XLA compiler. As such, the XLA compiler
      generates a random number and uses it as a seed when compiling the
      operation. This implementation causes a violation of the TensorFlow
      defined semantics in two aspects. First, changing the value of the user
      defined seed doesn't change the numbers generated by the operation.
      Second, when a seed is not specified, running the program multiple times
      will generate the same numbers.

  """
  if context.executing_eagerly():
    # Eager execution is supported by tracing the computation into a
    # tf.function and running the traced graph.
    @def_function.function
    def xla_compile_wrapper():
      return _compile_internal(computation, inputs)

    return xla_compile_wrapper()

  return _compile_internal(computation, inputs)
| |
| |
class XLACompileContext(control_flow_ops.XLAControlFlowContext):
  """A `ControlFlowContext` for nodes inside an XLA computation cluster.

  THIS IS ONLY FOR TENSORFLOW INTERNAL IMPLEMENTATION, DO NOT USE DIRECTLY.

  The primary role of `XLACompileContext` is to mark operators inside a
  xla.compile() computation with attribute "_xla_compile_id=XYZ", where XYZ is
  a unique name.

  `ControlFlowContext` is used to perform the annotation since it integrates
  with TensorFlow constructs like ResourceVariables. For example, if a
  `ResourceVariable` is constructed inside a xla.compile() block, the
  `ResourceVariable` implementation can use
  `with ops.control_dependencies(None)` to build the variable's definition
  outside the compiled computation.
  """

  def __init__(self, name, pivot):
    """Builds a new XLACompileContext.

    Args:
      name: a unique name for the context, used to populate the
        `_xla_compile_id` attribute.
      pivot: a pivot node. Nodes in the XLACompileContext that do not have any
        inputs will have a control dependency on the pivot node. This ensures
        that nodes are correctly included in any enclosing control flow
        contexts.
    """
    super(XLACompileContext, self).__init__()
    self._name = name
    # Value written into the `_xla_compile_id` attribute of every added op.
    self._name_as_bytes = compat.as_bytes(name)
    # Ops whose type is in _UNSUPPORTED_OPS; reported in one batch by
    # report_unsupported_operations() rather than one warning per op.
    self._unsupported_ops = []
    self._pivot = pivot

  def report_unsupported_operations(self):
    """Logs a warning listing the unsupported ops added to this context.

    At most `_MAX_WARNING_LINES` ops are listed by type and name; any further
    ones are summarized as a count. Logs nothing if none were seen.
    """
    if self._unsupported_ops:
      op_str = '\n'.join([
          ' %s (%s)' % (op.type, op.name)
          for op in self._unsupported_ops[:_MAX_WARNING_LINES]
      ])
      logging.warning('%d unsupported operations found: \n%s',
                      len(self._unsupported_ops), op_str)
      if len(self._unsupported_ops) > _MAX_WARNING_LINES:
        logging.warning('... and %d more',
                        len(self._unsupported_ops) - _MAX_WARNING_LINES)

  def _RemoveExternalControlEdges(self, op):
    """Remove any external control dependency on this op.

    A control input is internal when this context appears in the chain of
    control flow contexts enclosing it, and external otherwise.

    Args:
      op: the `Operation` whose control inputs are partitioned.

    Returns:
      A pair `(internal_control_inputs, external_control_inputs)`. On return
      `op` keeps only the internal control inputs.
    """
    internal_control_inputs = []
    external_control_inputs = []
    # pylint: disable=protected-access
    for x in op.control_inputs:
      is_internal_op = False
      ctxt = x._get_control_flow_context()
      while ctxt is not None:
        if ctxt == self:
          is_internal_op = True
          break
        ctxt = ctxt._outer_context
      if is_internal_op:
        internal_control_inputs.append(x)
      else:
        external_control_inputs.append(x)
    op._remove_all_control_inputs()
    op._add_control_inputs(internal_control_inputs)
    # pylint: enable=protected-access
    return internal_control_inputs, external_control_inputs

  def AddOp(self, op):
    """Create op in XLACompileContext and notifies outer context recursively.

    Tags `op` with the `_xla_compile_id` attribute, prevents it from being fed
    or fetched, re-routes control edges coming from outside this context
    through Identity ops built inside it, passes data inputs through
    AddValue(), and records the op's outputs as values of this context and of
    every enclosing context.

    Args:
      op: the `Operation` being added.

    Raises:
      NotImplementedError: if an input of `op` has a reference dtype
        (non-resource variables).
      ValueError: if `op` already carries `_xla_compile_id`, i.e. XLA compiled
        computations are nested.
    """
    # pylint: disable=protected-access
    if op.type in _DENYLISTED_OPS:
      logging.error(
          'Operation of type %s (%s) is not supported in XLA. Execution will '
          'fail if this op is used in the graph. ', op.type, op.name)

    # TODO(ycao): Automatically disable summaries instead of reporting them.
    if op.type in _UNSUPPORTED_OPS:
      self._unsupported_ops.append(op)

    if any(x.dtype._is_ref_dtype for x in op.inputs):
      raise NotImplementedError(
          'Non-resource Variables are not supported inside XLA computations '
          '(operator name: %s)' % op.name)

    if _XLA_COMPILE_ATTR in op.node_def.attr:
      raise ValueError('XLA compiled computations cannot be nested, (operator '
                       'name: %s)' % op.name)

    op._set_attr(
        _XLA_COMPILE_ATTR, attr_value_pb2.AttrValue(s=self._name_as_bytes))

    op.graph.prevent_feeding(op)
    op.graph.prevent_fetching(op)

    # Remove any control edges from outer control flow contexts. These may cause
    # mismatched frame errors. An example is when one of op's inputs is
    # generated in a different While control flow context.
    (internal_control_inputs,
     external_control_inputs) = self._RemoveExternalControlEdges(op)

    if not op.inputs:
      # Add a control edge from the control pivot to this op.
      if not internal_control_inputs:
        op._add_control_input(self._pivot)
    else:
      # Iterate over a snapshot of the inputs, since _update_input() replaces
      # entries of op.inputs while we walk them.
      for index, x in enumerate(list(op.inputs)):
        real_x = self.AddValue(x)
        if real_x is not x:
          op._update_input(index, real_x)

    if external_control_inputs:
      # Use an identity to pull control inputs as data inputs. Note that we
      # ignore ops which don't have outputs. TODO(phawkins): fix that.
      with ops.control_dependencies(None):
        self.Enter()
        try:
          external_control_inputs = [
              array_ops.identity(x.outputs[0]).op
              for x in external_control_inputs
              if x.outputs
          ]
        finally:
          # Leave the context even if building an identity fails.
          self.Exit()
      op._add_control_inputs(external_control_inputs)

    # Mark op's outputs as seen by this context and any outer contexts. The
    # loop variable is `ctxt` so the module-level `context` import is not
    # shadowed.
    output_names = [x.name for x in op.outputs]
    ctxt = self
    while ctxt is not None:
      ctxt._values.update(output_names)
      ctxt = ctxt._outer_context
    # pylint: enable=protected-access

    if self._outer_context:
      self._outer_context.AddInnerOp(op)

  def AddValue(self, val):
    """Add `val` to the current context and its outer context recursively.

    Args:
      val: the tensor consumed inside this context.

    Returns:
      The tensor to use in place of `val` inside this context: the value the
      outer context mapped it to, or `val` itself if there is no such mapping.
    """
    if val.name in self._values:
      # Use the real value if it comes from outer context.
      result = self._external_values.get(val.name)
      return val if result is None else result

    result = val
    self._values.add(val.name)
    if self._outer_context:
      result = self._outer_context.AddValue(val)
      self._values.add(result.name)

    self._external_values[val.name] = result

    return result

  def AddInnerOp(self, op):
    """Adds an op that was created in a context nested inside this one."""
    self.AddOp(op)
    # NOTE(review): AddOp() already forwards `op` to
    # self._outer_context.AddInnerOp(), so the outer context receives it twice
    # through this path. Kept unchanged; confirm whether that is intended.
    if self._outer_context:
      self._outer_context.AddInnerOp(op)

  @property
  def grad_state(self):
    # Define the gradient loop state associated with the XLACompileContext to
    # be None as the XLACompileContext does not get nested nor does the
    # grad_state outside the XLACompileContext affect the graph inside so the
    # grad_state should be as if this is the top-level gradient state.
    return None

  @property
  def back_prop(self):
    """Forwards to the enclosing while context, if any; False otherwise."""
    while_context = self.GetWhileContext()
    if while_context:
      return while_context.back_prop
    return False
| |
| |
def _compile_internal(computation, inputs=None):
  """Builds graph operators that compiles and symbolically executes computation.

  Args:
    computation: A Python function that builds the computation to compile and
      execute.
    inputs: A list of inputs or `None` (equivalent to an empty list). Each input
      can be a nested structure containing values that are convertible to
      tensors. Note that passing an N-dimension list of compatible values will
      result in a N-dimension list of scalar tensors rather than a single Rank-N
      tensors. If you need different behavior, convert part of inputs to tensors
      with `tf.convert_to_tensor`.

  Returns:
    Same data structure as if computation(*inputs) is called directly with some
    exceptions for correctness. Exceptions include: 1) None output 2) Single
    value output 3) Operation-only outputs

  Raises:
    ValueError: If any element in computation outputs is neither an operations
      or a value that can be converted to tensor.
    ValueError: If computation outputs is non-flat and contains any Operations.
    TypeError: If `inputs` is not a list or tuple.
  """
  if inputs is None:
    inputs = []

  if not isinstance(inputs, collections_abc.Sequence):
    raise TypeError('inputs must be a list')

  # Flatten inputs and convert every leaf to a Tensor.
  flat_inputs = [ops.convert_to_tensor(x) for x in nest.flatten(inputs)]

  cluster_name = ops.get_default_graph().unique_name('cluster')
  pivot = control_flow_ops.no_op(name=cluster_name + '/pivot')
  # Named `xla_context` so the module-level `context` import is not shadowed.
  xla_context = XLACompileContext(name=cluster_name, pivot=pivot)
  try:
    xla_context.Enter()

    # Add identity ops so even unused inputs are 'consumed' by the
    # computation.
    flat_inputs = [
        array_ops.identity(x, name='input_{}'.format(i))
        for i, x in enumerate(flat_inputs)
    ]

    # Re-pack flat_inputs in same structure as 'inputs'.
    computation_inputs = nest.pack_sequence_as(
        structure=inputs, flat_sequence=flat_inputs)

    # Only resource variables work inside an XLA computation, so turn on
    # resource variables for the computation.
    vscope = variable_scope.get_variable_scope()
    saved_use_resource = vscope.use_resource
    vscope.set_use_resource(True)
    try:
      with _disable_summary_context():
        outputs = computation(*computation_inputs)
    finally:
      # Restore the caller's variable scope setting even if `computation`
      # raises; otherwise use_resource=True would leak out of xla.compile.
      vscope.set_use_resource(saved_use_resource)

    outputs_is_flat = is_flat(outputs)
    if outputs_is_flat:
      output_tensors, control_deps = _postprocess_flat_outputs(outputs)
    else:
      output_tensors, control_deps = _postprocess_non_flat_outputs(outputs)

    xla_context.ExitResult(output_tensors)
  finally:
    xla_context.report_unsupported_operations()
    xla_context.Exit()

  # When XLA computation returns only operations and no tensors, a NoOp
  # dependent on the operations in outputs is returned. Otherwise final
  # outputs would be empty and there is no way to trigger returned
  # operations.
  if not output_tensors:
    return control_flow_ops.group(control_deps, name='output_0')

  output_tensors = [
      xla_ops.xla_cluster_output(o, name='output{}'.format(i))
      for i, o in enumerate(output_tensors)
  ]

  with ops.control_dependencies(control_deps):
    # Wraps the outputs in identity operators that carries control
    # dependencies.
    output_tensors = [
        array_ops.identity(o, name='output_%d' % i)
        for i, o in enumerate(output_tensors)
    ]

  # If `computation` returned non-flat output structure, pack output tensors
  # back into same structure.
  if not outputs_is_flat:
    output_tensors = nest.pack_sequence_as(
        structure=outputs, flat_sequence=output_tensors)

  return output_tensors
| |
| |
def is_flat(outputs):
  """Returns whether `outputs` contains no nested structure.

  The following are treated as flat:
    1) None
    2) A single object
    3) A list or tuple of Tensors/Operations

  Only sequences, mappings and classes generated by the attrs library (detected
  through the `__attrs_attrs__` class attribute) are recognized as structures.
  Anything else, including a single user-defined object, counts as one flat
  value; if such an object cannot be converted to a Tensor, the error is raised
  later on.

  Args:
    outputs: Output from `computation` inside `xla.compile`.

  Returns:
    True if `outputs` is flat, False otherwise.
  """

  def _is_structured(value):
    # Sequences, mappings and attrs instances all count as nested structures.
    return (isinstance(value, collections_abc.Sequence) or
            isinstance(value, collections_abc.Mapping) or
            hasattr(value.__class__, '__attrs_attrs__'))

  # A sequence stops being flat as soon as one element is itself structured.
  # (Strings are sequences too, and so are their characters.)
  if isinstance(outputs, collections_abc.Sequence) and any(
      _is_structured(element) for element in outputs):
    return False

  # Mappings and attrs classes are structures on their own; every other value
  # (None, a single object, or a flat sequence) is flat.
  return not (isinstance(outputs, collections_abc.Mapping) or
              hasattr(outputs.__class__, '__attrs_attrs__'))
| |
| |
def _postprocess_flat_outputs(outputs):
  """Validates flat outputs and adds back device assignments.

  Args:
    outputs: Output from `computation` inside `xla.compile`: None, a single
      value, or a flat sequence of values convertible to Tensors followed by
      Operations.

  Returns:
    A pair `(output_tensors, output_operations)`. `output_tensors` holds one
    Identity tensor per returned tensor value, placed on that value's device;
    `output_operations` holds the returned Operations plus a trailing no_op.

  Raises:
    ValueError: if a returned value is neither an Operation nor convertible to
      a Tensor, or if an Operation comes before a Tensor.
  """
  # Following code segment is to preserve legacy behavior. Previously we only
  # supported flat outputs and thus for consistency it was nice to convert even
  # single element into a tuple. But now that we support arbitrary output
  # structure, this is no longer necessary.
  # TODO(b/121383831): Migrate all legacy use cases and delete this special
  # case.
  # If the computation returns `None`, make it an empty tuple.
  if outputs is None:
    outputs = tuple()
  # If the computation only returned one value, make it a tuple. Strings are
  # sequences in Python but are single values here.
  if (isinstance(outputs, compat.bytes_or_text_types) or
      not isinstance(outputs, collections_abc.Sequence)):
    outputs = (outputs,)

  # Append `no_op` here so that return value of this function always contains
  # at least one op that can trigger XlaLaunch node. Build a new tuple rather
  # than using `+=`, which would extend a list returned by `computation` in
  # place and mutate the caller's object.
  outputs = tuple(outputs) + (control_flow_ops.no_op(),)
  try:
    outputs = [
        o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
        for o in outputs
    ]
  except Exception as e:
    raise ValueError(
        'XLA computation function return values must all either be Operations'
        ' or convertible to Tensors. Got error: "%s"' % str(e))

  # Separates the returned Operations and Tensors.
  output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
  output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)]

  # Compare by identity: the order is valid only if every position already
  # holds the same object as in "tensors first, then operations". This avoids
  # invoking Tensor.__eq__ on mismatched elements.
  expected_order = output_tensors + output_operations
  if any(got is not want for got, want in zip(outputs, expected_order)):
    raise ValueError(
        'XLA computation function must return zero or more Tensor values '
        'followed by zero or more Operations.')

  new_output_tensors = []
  for t in output_tensors:
    with ops.device(t.device if t.device else ''):
      new_output_tensors.append(array_ops.identity(t))

  return new_output_tensors, output_operations
| |
| |
def _postprocess_non_flat_outputs(outputs):
  """Validates non-flat outputs and adds back device assignments.

  Every leaf of `outputs` is converted to a Tensor and passed through an
  Identity op on the same device. This way even values that simply pass an
  input straight through get a node inside the compile context.

  Args:
    outputs: Output from `computation` inside `xla.compile`.

  Returns:
    A pair `(output_tensors, [])`. `output_tensors` holds the Identity tensors
    in `nest.flatten` order. The second element is always an empty list
    because Operations are not allowed in non-flat outputs.

  Raises:
    ValueError: if a leaf is an Operation or cannot be converted to a Tensor.
  """

  def _identity_in_context(leaf):
    # Operations cannot be placed inside a nested output structure.
    if isinstance(leaf, ops.Operation):
      raise ValueError(
          'xla.compile does not support Operation as return value in non-flat '
          'output structure. You can set returned Operations as control '
          'dependencies of returned Tensors so Operations are triggered when '
          'Tensors are evaluated. Operation found: "%s"' % leaf.name)

    try:
      tensor = ops.convert_to_tensor(leaf)
    except Exception as e:
      raise ValueError(
          'XLA computation function return values must all either be '
          'Operations or convertible to Tensors. Got error: "%s"' % str(e))

    # Touch the value inside the compile context, on its original device.
    with ops.device(tensor.device if tensor.device else ''):
      return array_ops.identity(tensor)

  identities = [_identity_in_context(leaf) for leaf in nest.flatten(outputs)]
  return identities, []
| |
| |
@contextlib.contextmanager
def _disable_summary_context():
  """Context manager inside which summary ops are skipped.

  xla.compile() does not support summary ops yet. While this context is
  active, `summary_op_util.skip_summary` is swapped for a function that always
  returns True, so that summary ops are not created. The original function is
  restored on exit, including when the body raises. This is a temporary
  workaround until XLA supports summary ops.

  Yields:
    None.
  """
  saved_skip_summary = summary_op_util.skip_summary

  def _always_skip():
    return True

  summary_op_util.skip_summary = _always_skip
  try:
    yield
  finally:
    summary_op_util.skip_summary = saved_skip_summary
| |
| |
class _CapturedObject(object):
  """A placeholder that captures one object, at most once."""

  def __init__(self):
    # None means "nothing captured yet".
    self._object = None

  def capture(self, o):
    """Stores `o` as the captured object.

    Args:
      o: the object to capture.

    Raises:
      RuntimeError: if a (non-None) object has already been captured.
    """
    # Test against None instead of truthiness, so that a falsy captured object
    # (0, an empty container, ...) still blocks a second capture.
    if self._object is not None:
      raise RuntimeError(
          'InternalError: _CapturedObject can capture only once. Please file '
          'bug.')

    self._object = o

  def get(self):
    """Returns the captured object, or None if nothing was captured."""
    return self._object
| |
| |
def _get_scaffold(captured_scaffold_fn):
  """Calls the captured scaffold function and returns its Scaffold.

  Args:
    captured_scaffold_fn: an object whose `get()` returns the scaffold
      function, or a falsy value if none was captured.

  Returns:
    The Scaffold produced by the scaffold function, or None when no scaffold
    function was captured.

  Raises:
    ValueError: if the scaffold function returns None.
  """
  scaffold_fn = captured_scaffold_fn.get()
  if scaffold_fn:
    scaffold = scaffold_fn()
    if scaffold is not None:
      return scaffold
    raise ValueError(
        'TPUEstimatorSpec.scaffold_fn returns None, which is not allowed')
  return None
| |
| |
def check_function_argument_count(func, input_arity, infeed_queue):
  """Checks that `func` can be called with the number of arguments it will get.

  Args:
    func: the Python function that will be called to generate the body of an XLA
      computation graph.
    input_arity: the number of explicit arguments supplied by the caller.
    infeed_queue: if not None, the infeed queue that will supply
      additional arguments to the function.

  Returns:
    None if `func` accepts that many positional arguments. Otherwise a string
    such as 'exactly 2 arguments', 'at least 1 argument' or 'at most 3
    arguments' describing what `func` accepts.
  """

  def _describe(qualifier, count):
    suffix = '' if count == 1 else 's'
    return '%s %d argument%s' % (qualifier, count, suffix)

  supplied = input_arity
  if infeed_queue is not None:
    supplied += infeed_queue.number_of_tuple_elements

  spec = tf_inspect.getargspec(func)
  max_args = len(spec.args)
  num_defaults = 0 if spec.defaults is None else len(spec.defaults)
  min_args = max_args - num_defaults
  has_varargs = spec.varargs is not None

  # Too few arguments to bind every parameter that lacks a default.
  if supplied < min_args:
    if num_defaults == 0 and not has_varargs:
      return _describe('exactly', max_args)
    return _describe('at least', min_args)

  # Too many arguments, and no *args to absorb the extra ones.
  if not has_varargs and supplied > max_args:
    if num_defaults == 0:
      return _describe('exactly', max_args)
    return _describe('at most', max_args)

  # Either `func` takes *args (any count above the minimum is fine) or the
  # supplied count lies within the accepted range.
  return None