Add method for renaming custom ops
PiperOrigin-RevId: 370165436
Change-Id: I692573eeed53f372372d54e671553ed0cfb44f2f
diff --git a/tensorflow/lite/tools/flatbuffer_utils.py b/tensorflow/lite/tools/flatbuffer_utils.py
index 467dd75..b9d1d55 100644
--- a/tensorflow/lite/tools/flatbuffer_utils.py
+++ b/tensorflow/lite/tools/flatbuffer_utils.py
@@ -118,7 +118,6 @@
Args:
model: The model from which to remove nonessential strings.
-
"""
model.description = None
@@ -136,7 +135,6 @@
Args:
model: The model in which to randomize weights.
random_seed: The input to the random number generator (default value is 0).
-
"""
# The input to the random seed generator. The default value is 0.
@@ -158,6 +156,20 @@
buffer_i_data[j] = random.randint(0, 255)
+def rename_custom_ops(model, map_custom_op_renames):
+ """Rename custom ops so they use the same naming style as builtin ops.
+
+ Args:
+ model: The input tflite model.
+ map_custom_op_renames: A mapping from old to new custom op names.
+ """
+ for op_code in model.operatorCodes:
+ if op_code.customCode:
+ op_code_str = op_code.customCode.decode('ascii')
+ if op_code_str in map_custom_op_renames:
+ op_code.customCode = map_custom_op_renames[op_code_str].encode('ascii')
+
+
def xxd_output_to_bytes(input_cc_file):
"""Converts xxd output C++ source file to bytes (immutable).