Fix TensorFlow checkpoint and trackable imports.
PiperOrigin-RevId: 450079328
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD
index afce76e..f32188d 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD
@@ -137,7 +137,7 @@
"//tensorflow/python/saved_model:signature_constants",
"//tensorflow/python/saved_model:tag_constants",
"//tensorflow/python/saved_model:utils",
- "//tensorflow/python/training/tracking:trackable_utils",
+ "//tensorflow/python/trackable:trackable_utils",
],
)
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/concurrency_test.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/concurrency_test.py
index 9ac20fb..c5a59bf 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/concurrency_test.py
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/concurrency_test.py
@@ -28,7 +28,7 @@
from tensorflow.python.platform import test
from tensorflow.python.saved_model import save as saved_model_save
from tensorflow.python.saved_model import tag_constants
-from tensorflow.python.training.tracking import tracking
+from tensorflow.python.trackable import autotrackable
class MultiThreadedTest(test.TestCase):
@@ -40,7 +40,7 @@
def _convert_with_calibration(self):
- class ModelWithAdd(tracking.AutoTrackable):
+ class ModelWithAdd(autotrackable.AutoTrackable):
"""Basic model with addition."""
@def_function.function(input_signature=[