[TF-TRT] Adding `tearDown()` to unittest base function.
diff --git a/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py
index dc1ab08..3270724 100644
--- a/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py
+++ b/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py
@@ -155,6 +155,15 @@
if not is_tensorrt_enabled():
self.skipTest("Test requires TensorRT")
+ def tearDown(self):
+ """Making sure to clean artifact."""
+ idx = 0
+ while gc.garbage:
+ gc.collect() # Force GC to destroy the TRT engine cache.
+ idx += 1
+ if idx >= 10: # After 10 iterations, break to avoid infinite collect.
+ break
+
def _GetTensorSpec(self, shape, mask, dtype, name):
# Set dimension i to None if mask[i] == False
assert len(shape) == len(mask)