Fix pylint for backwards compatibility test
diff --git a/tensorflow/python/compiler/tensorrt/test/testdata/gen_tftrt_model.py b/tensorflow/python/compiler/tensorrt/test/testdata/gen_tftrt_model.py
index 6fdfde2..8f0acf4 100644
--- a/tensorflow/python/compiler/tensorrt/test/testdata/gen_tftrt_model.py
+++ b/tensorflow/python/compiler/tensorrt/test/testdata/gen_tftrt_model.py
@@ -28,8 +28,6 @@
from __future__ import division
from __future__ import print_function
-import numpy as np
-
import tensorflow as tf
from tensorflow.python.training.tracking import tracking
from tensorflow.python.eager import def_function
@@ -53,13 +51,14 @@
return out
-def generateModelV2(tf_saved_model_dir, tftrt_saved_model_dir):
+def GenerateModelV2(tf_saved_model_dir, tftrt_saved_model_dir):
+ """Generate and convert a model using TFv2 API."""
class SimpleModel(tracking.AutoTrackable):
"""Define model with a TF function."""
-
+
def __init__(self):
self.v = None
-
+
@def_function.function(input_signature=[
tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=tf.dtypes.float32),
tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=tf.dtypes.float32)
@@ -68,13 +67,14 @@
if self.v is None:
self.v = variables.Variable([[[1.0]]], dtype=tf.dtypes.float32)
return GetGraph(input1, input2, self.v)
-
+
root = SimpleModel()
-
+
# Saved TF model
- tf.saved_model.save(root, tf_saved_model_dir,
- {signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: root.run})
-
+ tf.saved_model.save(
+ root, tf_saved_model_dir,
+ {signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: root.run})
+
# Convert TF model to TensorRT
converter = trt_convert.TrtGraphConverterV2(
input_saved_model_dir=tf_saved_model_dir)
@@ -82,9 +82,8 @@
converter.save(tftrt_saved_model_dir)
-def generateModelV1(tf_saved_model_dir,
- tftrt_saved_model_dir,
- signature_key=None):
+def GenerateModelV1(tf_saved_model_dir, tftrt_saved_model_dir):
+ """Generate and convert a model using TFv1 API."""
def SimpleModel():
def GraphFn():
input1 = array_ops.placeholder(
@@ -94,11 +93,11 @@
var = variables.Variable([[[1.0]]], dtype=tf.dtypes.float32, name="v1")
out = GetGraph(input1, input2, var)
return g, var, input1, input2, out
-
+
g = ops.Graph()
with g.as_default():
return GraphFn()
-
+
g, var, input1, input2, out = SimpleModel()
signature_def = signature_def_utils.build_signature_def(
inputs={
@@ -107,12 +106,14 @@
},
outputs={"output": tf.saved_model.utils.build_tensor_info(out)},
method_name=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
- saved_model_builder = tf.saved_model.builder.SavedModelBuilder(tf_saved_model_dir)
+ saved_model_builder = tf.saved_model.builder.SavedModelBuilder(
+ tf_saved_model_dir)
with tf.Session(graph=g) as sess:
sess.run(var.initializer)
saved_model_builder.add_meta_graph_and_variables(
- sess, [tag_constants.SERVING],
- signature_def_map={signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature_def})
+ sess, [tag_constants.SERVING], signature_def_map={
+ signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
+ signature_def})
saved_model_builder.save()
# Convert TF model to TensorRT
@@ -124,5 +125,5 @@
if __name__ == "__main__":
- generateModelV2(tf_saved_model_dir = "tf_saved_model",
- tftrt_saved_model_dir = "tftrt_saved_model")
+ GenerateModelV2(tf_saved_model_dir="tf_saved_model",
+ tftrt_saved_model_dir="tftrt_saved_model")
diff --git a/tensorflow/python/compiler/tensorrt/trt_convert_test.py b/tensorflow/python/compiler/tensorrt/trt_convert_test.py
index 95d70bc..2ed4bea 100644
--- a/tensorflow/python/compiler/tensorrt/trt_convert_test.py
+++ b/tensorflow/python/compiler/tensorrt/trt_convert_test.py
@@ -21,9 +21,9 @@
import gc
import os
import tempfile
+import numpy as np
from absl.testing import parameterized
-import numpy as np
from tensorflow.compiler.tf2tensorrt.wrap_py_utils import is_tensorrt_enabled
from tensorflow.core.framework import graph_pb2
@@ -844,7 +844,7 @@
return
model_dir = test.test_src_dir_path(
- 'python/compiler/tensorrt/test/testdata/tftrt_2.0_saved_model')
+ 'python/compiler/tensorrt/test/testdata/tftrt_2.0_saved_model')
saved_model_loaded = load.load(
model_dir, tags=[tag_constants.SERVING])
graph_func = saved_model_loaded.signatures[
@@ -855,7 +855,7 @@
np_input2 = ops.convert_to_tensor(
np.ones([4, 1, 1]).astype(np.float32))
output = graph_func(input1=np_input1, input2=np_input2)['output_0']
-
+
self.assertTrue(output.shape == (4, 1, 1))
self.assertAllClose(
np.asarray([5.0, 5.0, 5.0, 5.0]).reshape([4, 1, 1]),