Update RNN conversion tflite g3doc

This uses the content from the blog post/dogfood announcement email

PiperOrigin-RevId: 317248288
Change-Id: I210c64bd54c70aa5b68742d59d6d36fa154e856c
diff --git a/tensorflow/lite/g3doc/convert/rnn.md b/tensorflow/lite/g3doc/convert/rnn.md
index 52bc287..734992c 100644
--- a/tensorflow/lite/g3doc/convert/rnn.md
+++ b/tensorflow/lite/g3doc/convert/rnn.md
@@ -1,99 +1,193 @@
-# Convert RNN models
+# TensorFlow RNN conversion to TensorFlow Lite
 
-The TensorFlow Lite interpreter currently implements a subset of TensorFlow
-operations, meaning some model architectures cannot immediately be converted due
-to missing operations.
+## Overview
 
-Some RNN-based architectures are affected by this. The following document
-outlines the current state of play and provides strategies for converting RNN
-models.
+TensorFlow Lite supports converting TensorFlow RNN models to TensorFlow Lite’s
+fused LSTM operators. Fused operators exist to maximize the performance of
+their underlying kernel implementations, as well as to provide a higher-level
+interface for defining complex transformations like quantization.
 
-## Currently supported
+Since there are many variants of RNN APIs in TensorFlow, our approach has been
+twofold:
 
-Currently, RNN models using
-[`tf.compat.v1.nn.static_rnn`](https://www.tensorflow.org/api_docs/python/tf/nn/static_rnn)
-can be converted successfully as long as no `sequence_length` is specified.
+1.  Provide **native support for standard TensorFlow RNN APIs** like Keras LSTM.
+    This is the recommended option.
+1.  Provide an **interface into the conversion infrastructure for user-defined
+    RNN implementations** to plug in and get converted to TensorFlow Lite. We
+    provide a couple of out-of-the-box examples of such conversion using
+    Lingvo’s
+    [LSTMCellSimple](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc#L123)
+    and
+    [LayerNormalizedLSTMCellSimple](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc#L519)
+    RNN interfaces.
 
-The following `tf.compat.v1.nn.rnn_cell` operations work with
-`tf.compat.v1.nn.static_rnn`:
+## Converter API
 
-*   [tf.compat.v1.nn.rnn_cell.LSTMCell](https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/LSTMCell)
-*   [tf.compat.v1.nn.rnn_cell.RNNCell](https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/RNNCell)
-*   [tf.compat.v1.nn.rnn_cell.GRUCell](https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/GRUCell)
-*   [tf.compat.v1.nn.rnn_cell.BasicLSTMCell](https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/BasicLSTMCell)
-*   [tf.compat.v1.nn.rnn_cell.BasicRNNCell](https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/BasicRNNCell)
+Currently this feature is available through the
+[tf-nightly](https://pypi.org/project/tf-nightly/) pip package or by building
+from head. It will be available in the TensorFlow 2.3 release.
 
-In addition, TensorFlow Lite provides some experimental drop-in replacements for
-RNN operations that enable dynamic RNN architectures with TensorFlow Lite.
+This conversion functionality is available when converting to TensorFlow Lite
+via a SavedModel or from a Keras model directly. See the example usages below.
 
-Drop-in replacements are available for the following:
+### From saved model
 
-*   [tf.compat.v1.nn.dynamic_rnn](https://www.tensorflow.org/api_docs/python/tf/nn/dynamic_rnn)
-*   [tf.compat.v1.nn.bidirectional_dynamic_rnn](https://www.tensorflow.org/api_docs/python/tf/nn/bidirectional_dynamic_rnn)
-*   [tf.compat.v1.nn.rnn_cell.RNNCell](https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/RNNCell)
-*   [tf.compat.v1.nn.rnn_cell.LSTMCell](https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/LSTMCell)
+```
+import tensorflow as tf
+
+# Build a saved model. Here concrete_func is the exported concrete function
+# corresponding to the TensorFlow model containing one or more
+# Keras LSTM layers.
+saved_model, saved_model_dir = build_saved_model_lstm(...)
+saved_model.save(saved_model_dir, save_format="tf", signatures=concrete_func)
 
-## Not currently supported
+# Convert the model.
+converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
+tflite_model = converter.convert()
+```
 
-TensorFlow Lite does not currently support
-[Control Flow](https://www.tensorflow.org/api_docs/cc/group/control-flow-ops)
-operations. This means that, unless one of the conversion strategies discussed
-in the next section are employed, models built with the following TensorFlow
-functions will not convert successfully:
+### From Keras model
 
-*   [tf.compat.v1.nn.static_rnn](https://www.tensorflow.org/api_docs/python/tf/nn/static_rnn)
-    where a `sequence_length` is specified
-*   [tf.compat.v1.nn.dynamic_rnn](https://www.tensorflow.org/api_docs/python/tf/nn/dynamic_rnn)
-*   [tf.compat.v1.nn.bidirectional_dynamic_rnn](https://www.tensorflow.org/api_docs/python/tf/nn/bidirectional_dynamic_rnn)
+```
+import tensorflow as tf
+
+# Build a Keras model containing one or more Keras LSTM layers.
+keras_model = build_keras_lstm(...)
 
-Note: TensorFlow Lite plans to implement all required Control Flow operations by
-the end of 2019. At this point, all RNN architectures will convert successfully.
+# Convert the model.
+converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
+tflite_model = converter.convert()
 
-## Conversion strategies
+```
 
-To convert an RNN model that uses the functions specified above, you will have
-to modify its architecture and retrain it. The following strategies can be used.
+## Example
 
-### 1. Refactoring
+The Keras LSTM to TensorFlow Lite
+[Colab](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/examples/experimental_new_converter/Keras_LSTM_fusion_Codelab.ipynb)
+illustrates the end-to-end usage with the TensorFlow Lite interpreter.
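+
+As a minimal sketch of that last step, assuming `tflite_model` is the
+flatbuffer produced by one of the converter calls above and feeding a dummy
+input whose shape is read back from the interpreter:
+
+```
+import numpy as np
+import tensorflow as tf
+
+# Load the converted flatbuffer into the TensorFlow Lite interpreter.
+interpreter = tf.lite.Interpreter(model_content=tflite_model)
+interpreter.allocate_tensors()
+
+input_details = interpreter.get_input_details()
+output_details = interpreter.get_output_details()
+
+# Feed a dummy input of the expected shape and run inference.
+dummy_input = np.zeros(input_details[0]["shape"], dtype=np.float32)
+interpreter.set_tensor(input_details[0]["index"], dummy_input)
+interpreter.invoke()
+output = interpreter.get_tensor(output_details[0]["index"])
+```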
 
-The simplest approach, if possible, is to refactor the model architecture to use
-[tf.compat.v1.nn.static_rnn](https://www.tensorflow.org/api_docs/python/tf/nn/static_rnn)
-without `sequence_length`.
+## TensorFlow RNN APIs supported
 
-### 2. Drop-in replacements that use op hints and fused ops
+### Keras LSTM conversion (recommended)
 
-TensorFlow Lite provides the some experimental drop-in replacements for RNN
-operations that enable dynamic RNN architectures with TensorFlow Lite. Using
-[OpHints](https://www.tensorflow.org/lite/guide/ops_custom#converting_tensorflow_models_to_convert_graphs),
-they run normally during training, but are substituted with special fused ops
-when run by the Lite interpreter.
+We support out-of-the-box conversion of Keras LSTM to TensorFlow Lite. For
+details on how this works, please refer to the
+[Keras LSTM interface](https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/examples/experimental_new_converter/Keras_LSTM_fusion_Codelab.ipynb)
+and to the conversion logic
+[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc#L627).
 
-The following drop-in replacements are available:
+It is also important to highlight TensorFlow Lite’s LSTM contract with respect
+to the Keras operation definition (a small sketch after the list illustrates
+this layout):
 
-*   [tf.compat.v1.lite.experimental.nn.dynamic_rnn](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/examples/lstm/rnn.py#L41)
-    *   replacement for tf.nn.dynamic_rnn
-*   [tf.compat.v1.lite.experimental.nn.bidirectional_dynamic_rnn](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/examples/lstm/rnn.py#L279)
-    *   replacement for tf.nn.bidirectional_dynamic_rnn
-*   [tf.compat.v1.lite.experimental.nn.TfLiteRNNCell](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/examples/lstm/rnn_cell.py#L39)
-    *   replacement for tf.nn.rnn_cell.RNNCell
-*   [tf.compat.v1.lite.experimental.nn.TfLiteLSTMCell](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/examples/lstm/rnn_cell.py#L159)
-    *   replacement for tf.nn.rnn_cell.LSTMCell
+1.  Dimension 0 of the **input** tensor is the batch size.
+1.  Dimension 0 of the **recurrent\_kernel** tensor is the number of outputs.
+1.  The **kernel** and **recurrent\_kernel** tensors are transposed.
+1.  The transposed kernel, transposed recurrent\_kernel and **bias** tensors
+    are split into 4 equal-sized tensors along dimension 0. These correspond to
+    the **input gate, forget gate, cell, and output gate**.
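+
+As a minimal sketch of that layout, assuming an illustrative Keras LSTM layer
+(the unit and feature sizes are arbitrary; this is not the converter’s actual
+code):
+
+```
+import numpy as np
+import tensorflow as tf
+
+# Illustrative Keras LSTM layer; the unit and feature sizes are arbitrary.
+lstm = tf.keras.layers.LSTM(units=4)
+lstm(tf.zeros([1, 3, 8]))  # build the layer on a [batch, time, feature] input
+kernel, recurrent_kernel, bias = lstm.get_weights()
+
+# Transpose the kernels as required by the contract.
+kernel_t = kernel.T                      # shape: (4 * units, features)
+recurrent_kernel_t = recurrent_kernel.T  # shape: (4 * units, units)
+
+# Split along dimension 0 into input gate, forget gate, cell and output gate.
+w_i, w_f, w_c, w_o = np.split(kernel_t, 4, axis=0)
+r_i, r_f, r_c, r_o = np.split(recurrent_kernel_t, 4, axis=0)
+b_i, b_f, b_c, b_o = np.split(bias, 4, axis=0)
+```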
 
-Note: These replacements must be used together. For example, if you are using
-`tf.compat.v1.lite.experimental.nn.dynamic_rnn`, you must combine it with
-`tf.compat.v1.lite.experimental.nn.TfLiteRNNCell` instead of using
-`tf.compat.v1.nn.rnn_cell.RNNCell`.
+#### Keras LSTM Variants
 
-Instead of
-[tf.compat.v1.nn.rnn_cell.MultiRNNCell](https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/MultiRNNCell),
-you should use
-[tf.compat.v1.keras.layers.StackedRNNCells](https://www.tensorflow.org/api_docs/python/tf/keras/layers/StackedRNNCells).
+##### Time major
 
-For a tutorial on using these replacements, see
-[TensorFlow Lite LSTM ops API](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/examples/lstm/g3doc/README.md).
+Users may choose time-major or non-time-major inputs. Keras LSTM adds a
+time-major attribute to the function definition attributes. For unidirectional
+sequence LSTM, we can simply map it to unidirectional\_sequence\_lstm's
+[time major attribute](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/ir/tfl_ops.td#L3508).
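+
+A minimal sketch of a time-major Keras LSTM (the shapes are illustrative); with
+`time_major=True` the inputs are laid out as `[time, batch, feature]`:
+
+```
+import tensorflow as tf
+
+# With time_major=True the inputs are [time, batch, feature] instead of
+# [batch, time, feature]; the converter carries this through to the fused op.
+lstm = tf.keras.layers.LSTM(units=8, time_major=True, return_sequences=True)
+
+x = tf.random.normal([20, 2, 16])  # [time, batch, feature]
+y = lstm(x)                        # [time, batch, units]
+```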
 
-For a Colab demonstrating these classes, refer to
-[TensorFlowLite_LSTM_Keras_Tutorial](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/examples/lstm/TensorFlowLite_LSTM_Keras_Tutorial.ipynb).
+##### Bidirectional LSTM
 
-Note: There is no replacement available for
-[tf.compat.v1.nn.rnn_cell.GRUCell](https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/GRUCell).
+Bidirectional LSTM can be implemented with two Keras LSTM layers, one forward
+and one backward; see examples
+[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/layers/wrappers.py#L381).
+Once we see the go\_backwards attribute, we recognize it as the backward LSTM
+and then group the forward and backward LSTMs together. **This is future
+work.** Currently, this creates two UnidirectionalSequenceLSTM operators in the
+TensorFlow Lite model.
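+
+A minimal sketch of such a model (layer sizes are illustrative), using the
+Keras `Bidirectional` wrapper with an explicit backward layer:
+
+```
+import tensorflow as tf
+
+# Two Keras LSTM layers, one forward and one with go_backwards=True,
+# combined with the Bidirectional wrapper (sizes are illustrative).
+forward_layer = tf.keras.layers.LSTM(units=8, return_sequences=True)
+backward_layer = tf.keras.layers.LSTM(units=8, return_sequences=True,
+                                      go_backwards=True)
+bidir = tf.keras.layers.Bidirectional(forward_layer,
+                                      backward_layer=backward_layer)
+
+x = tf.random.normal([2, 20, 16])  # [batch, time, feature]
+y = bidir(x)                       # [batch, time, 2 * units]
+```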
+
+### User-defined LSTM conversion examples
+
+TensorFlow Lite also provides a way to convert user-defined LSTM
+implementations. Here we use Lingvo’s LSTM as an example of how that can be
+done. For details, please refer to the
+[lingvo.LSTMCellSimple interface](https://github.com/tensorflow/lingvo/blob/master/lingvo/core/rnn_cell.py#L230)
+and the conversion logic
+[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc#L123).
+We also provide an example for another of Lingvo’s LSTM definitions in the
+[lingvo.LayerNormalizedLSTMCellSimple interface](https://github.com/tensorflow/lingvo/blob/master/lingvo/core/rnn_cell.py#L1179)
+and its conversion logic
+[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc#L130).
+
+## “Bring your own TensorFlow RNN” to TensorFlow Lite
+
+If a user's RNN interface is different from the standard supported ones, there
+are a couple of options:
+
+**Option 1:** Write adapter code in TensorFlow Python to adapt the RNN
+interface to the Keras RNN interface. This means a tf.function with the
+[tf\_implements annotation](https://github.com/tensorflow/community/pull/113)
+on the generated RNN interface’s function that is identical to the one
+generated by the Keras LSTM layer. After this, the same conversion API used for
+Keras LSTM will work.
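+
+As a rough sketch of what such an adapter could look like (the annotation
+string, the signature and `my_rnn_impl` below are placeholders, not the actual
+contract emitted by the Keras LSTM layer):
+
+```
+import tensorflow as tf
+
+# "my_project.custom_rnn" and my_rnn_impl are placeholders; a real adapter
+# must carry the same annotation and signature that the Keras LSTM layer
+# emits so the converter can recognize it.
+@tf.function(experimental_implements="my_project.custom_rnn")
+def custom_rnn_adapter(inputs, init_h, init_c, kernel, recurrent_kernel, bias):
+  outputs, state_h, state_c = my_rnn_impl(
+      inputs, init_h, init_c, kernel, recurrent_kernel, bias)
+  return outputs, state_h, state_c
+```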
+
+**Option 2:** If the above is not possible (e.g. the Keras LSTM is missing some
+functionality that is currently exposed by TensorFlow Lite’s fused LSTM op,
+like layer normalization), then extend the TensorFlow Lite converter by writing
+custom conversion code and plugging it into the prepare-composite-functions
+MLIR pass
+[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc#L108).
+The function’s interface should be treated like an API contract and should
+contain the arguments needed to convert to fused TensorFlow Lite LSTM
+operators, i.e. input, bias, weights, projection, layer normalization, etc. It
+is preferable for the tensors passed as arguments to this function to have
+known rank (i.e. RankedTensorType in MLIR). This makes it much easier to write
+conversion code that can assume these tensors are RankedTensorType and helps
+transform them to ranked tensors corresponding to the fused TensorFlow Lite
+operator’s operands.
+
+A complete example of such a conversion flow is Lingvo’s LSTMCellSimple to
+TensorFlow Lite conversion.
+
+The LSTMCellSimple in Lingvo is defined
+[here](https://github.com/tensorflow/lingvo/blob/master/lingvo/core/rnn_cell.py#L230).
+Models trained with this LSTM cell can be converted to TensorFlow Lite as
+follows:
+
+1.  Wrap all uses of LSTMCellSimple in a tf.function with a tf\_implements
+    annotation that labels it as such (e.g. lingvo.LSTMCellSimple would be a
+    good annotation name here). Make sure the tf.function that is generated
+    matches the interface of the function expected in the conversion code. This
+    is a contract between the model author adding the annotation and the
+    conversion code.
+1.  Extend the prepare-composite-functions pass to plug in a custom composite op
+    to TensorFlow Lite fused LSTM op conversion. See
+    [LSTMCellSimple](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc#L123)
+    conversion code.
+
+    The conversion contract:
+
+1.  **Weight** and **projection** tensors are transposed.
+
+1.  The **{input, recurrent}** weights for the **{cell, input gate, forget
+    gate, output gate}** are extracted by slicing the transposed weight tensor.
+
+1.  The **bias** tensors for the **{cell, input gate, forget gate, output
+    gate}** are extracted by slicing the bias tensor.
+
+1.  The **projection** is extracted by slicing the transposed projection tensor.
+
+1.  A similar conversion is written for
+    [LayerNormalizedLSTMCellSimple](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc#L519).
+
+1.  The rest of the TensorFlow Lite conversion infrastructure, including all
+    the
+    [MLIR passes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc#L58)
+    defined, as well as the final export to the TensorFlow Lite flatbuffer, can
+    be reused.
+
+## Known issues/limitations
+
+1.  Currently there is support only for converting stateless Keras LSTM (default
+    behavior in Keras). Stateful Keras LSTM conversion is future work.
+1.  It is still possible to model a stateful Keras LSTM layer using the
+    underlying stateless Keras LSTM layer and managing the state explicitly in
+    the user program (see the sketch after this list). Such a TensorFlow
+    program can still be converted to TensorFlow Lite using the feature
+    described here.
+1.  Bidirectional LSTM is currently modelled as two UnidirectionalSequenceLSTM
+    operators in TensorFlow Lite. This will be replaced with a single
+    BidirectionalSequenceLSTM op.
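+
+As a minimal sketch of the second point above (layer sizes and the chunking
+scheme are illustrative), the caller carries the LSTM state across invocations
+of a stateless Keras LSTM:
+
+```
+import tensorflow as tf
+
+# A stateless Keras LSTM whose state is carried explicitly by the caller
+# (sizes are illustrative).
+units = 8
+lstm = tf.keras.layers.LSTM(units, return_sequences=True, return_state=True)
+
+# Initialize the state once, then carry it forward between input chunks.
+state = [tf.zeros([1, units]), tf.zeros([1, units])]
+for _ in range(3):
+  chunk = tf.random.normal([1, 5, 16])  # [batch, time, feature]
+  outputs, state_h, state_c = lstm(chunk, initial_state=state)
+  state = [state_h, state_c]
+```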