Skip using temp files/folders when save tensor and new checkpoint.

Bug: 281758731
Test: mma
Change-Id: Idcb6f5b552acce757e4806236c6848aa6264a809
diff --git a/tensorflow/core/data/Android.bp b/tensorflow/core/data/Android.bp
index 08ea84e..c10e64f 100644
--- a/tensorflow/core/data/Android.bp
+++ b/tensorflow/core/data/Android.bp
@@ -1,4 +1,3 @@
-
 // Copyright (C) 2023 The Android Open Source Project
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/tensorflow/core/platform/env.cc b/tensorflow/core/platform/env.cc
index 9961fff..93b9037 100644
--- a/tensorflow/core/platform/env.cc
+++ b/tensorflow/core/platform/env.cc
@@ -327,6 +327,12 @@
   return fs->HasAtomicMove(path, has_atomic_move);
 }
 
+Status Env::CanCreateTempFile(const string& fname, bool* can_create_temp_file) {
+  FileSystem* fs;
+  TF_RETURN_IF_ERROR(GetFileSystemForFile(fname, &fs));
+  return fs->CanCreateTempFile(fname, can_create_temp_file);
+}
+
 Status Env::DeleteRecursively(const string& dirname, int64_t* undeleted_files,
                               int64_t* undeleted_dirs) {
   FileSystem* fs;
diff --git a/tensorflow/core/platform/env.h b/tensorflow/core/platform/env.h
index 86b1077..67b45c6 100644
--- a/tensorflow/core/platform/env.h
+++ b/tensorflow/core/platform/env.h
@@ -333,6 +333,15 @@
   ///  TF
   Status HasAtomicMove(const std::string& path, bool* has_atomic_move);
 
+  /// Returns whether the give path is on a file system
+  /// that has ability to create a new temp file. This can be used
+  /// to determine if there needs to be a temp location to safely write objects.
+  /// If this returns false, TensorFlow will write directly to output files
+  /// instead of creating a temporary file and swapping it in. This may mean
+  /// that incomplete writes are visible to consumers.
+  Status CanCreateTempFile(const std::string& fname,
+                           bool* can_create_temp_file);
+
   /// Stores the size of `fname` in `*file_size`.
   Status GetFileSize(const std::string& fname, uint64* file_size);
 
diff --git a/tensorflow/core/platform/file_system.cc b/tensorflow/core/platform/file_system.cc
index e170b09..5847e76 100644
--- a/tensorflow/core/platform/file_system.cc
+++ b/tensorflow/core/platform/file_system.cc
@@ -87,6 +87,12 @@
   return OkStatus();
 }
 
+Status FileSystem::CanCreateTempFile(const std::string& fname,
+                                     bool* can_create_temp_file) {
+  *can_create_temp_file = true;
+  return OkStatus();
+}
+
 void FileSystem::FlushCaches(TransactionToken* token) {}
 
 bool FileSystem::FilesExist(const std::vector<string>& files,
diff --git a/tensorflow/core/platform/file_system.h b/tensorflow/core/platform/file_system.h
index 3b74a47..be0266d 100644
--- a/tensorflow/core/platform/file_system.h
+++ b/tensorflow/core/platform/file_system.h
@@ -386,6 +386,15 @@
   ///  TF
   virtual Status HasAtomicMove(const std::string& path, bool* has_atomic_move);
 
+  /// Returns whether the give path is on a file system
+  /// that has ability to create a new temp file. This can be used
+  /// to determine if there needs to be a temp location to safely write objects.
+  /// If this returns false, TensorFlow will write directly to output files
+  /// instead of creating a temporary file and swapping it in. This may mean
+  /// that incomplete writes are visible to consumers.
+  virtual Status CanCreateTempFile(const std::string& fname,
+                                   bool* can_create_temp_file);
+
   /// \brief Flushes any cached filesystem objects from memory.
   virtual void FlushCaches() { FlushCaches(nullptr); }
 
diff --git a/tensorflow/core/platform/file_system_test.cc b/tensorflow/core/platform/file_system_test.cc
index f5ca57b..d49a252 100644
--- a/tensorflow/core/platform/file_system_test.cc
+++ b/tensorflow/core/platform/file_system_test.cc
@@ -273,6 +273,14 @@
   EXPECT_EQ(has_atomic_move, true);
 }
 
+TEST(InterPlanetaryFileSystemTest, CanCreateTempFile) {
+  InterPlanetaryFileSystem ipfs;
+  const string dirname = io::JoinPath(kPrefix, "match-00/abc/00");
+  bool can_create_temp_file;
+  TF_EXPECT_OK(ipfs.CanCreateTempFile(dirname, &can_create_temp_file));
+  EXPECT_EQ(can_create_temp_file, true);
+}
+
 // A simple file system with a root directory and a single file underneath it.
 class TestFileSystem : public NullFileSystem {
  public:
diff --git a/tensorflow/core/platform/retrying_file_system.h b/tensorflow/core/platform/retrying_file_system.h
index 1543345..ec739fb 100644
--- a/tensorflow/core/platform/retrying_file_system.h
+++ b/tensorflow/core/platform/retrying_file_system.h
@@ -144,6 +144,12 @@
     return base_file_system_->HasAtomicMove(path, has_atomic_move);
   }
 
+  Status CanCreateTempFile(const std::string& fname,
+                           bool* can_create_temp_file) override {
+    // this method does not need to be retried
+    return base_file_system_->CanCreateTempFile(fname, can_create_temp_file);
+  }
+
   Status DeleteRecursively(const string& dirname, TransactionToken* token,
                            int64_t* undeleted_files,
                            int64_t* undeleted_dirs) override {
diff --git a/tensorflow/core/platform/test.h b/tensorflow/core/platform/test.h
index b598b6e..e49f479 100644
--- a/tensorflow/core/platform/test.h
+++ b/tensorflow/core/platform/test.h
@@ -16,10 +16,11 @@
 #ifndef TENSORFLOW_CORE_PLATFORM_TEST_H_
 #define TENSORFLOW_CORE_PLATFORM_TEST_H_
 
+#include <gtest/gtest.h>  // IWYU pragma: export
+
 #include <memory>
 #include <vector>
 
-#include <gtest/gtest.h>  // IWYU pragma: export
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/platform.h"
 #include "tensorflow/core/platform/types.h"
@@ -39,7 +40,8 @@
 // The advantages of using gmock matchers instead of self defined matchers are
 // better error messages, more maintainable tests and more test coverage.
 #if !defined(PLATFORM_GOOGLE) && !defined(PLATFORM_GOOGLE_ANDROID) && \
-    !defined(PLATFORM_CHROMIUMOS)
+    !defined(PLATFORM_CHROMIUMOS) &&                                  \
+    !defined(IS_MOBILE_PLATFORM)  // We have gmock.h in AOSP.
 #include <gmock/gmock-generated-matchers.h>
 #include <gmock/gmock-matchers.h>
 #include <gmock/gmock-more-matchers.h>
diff --git a/tensorflow/core/util/tensor_slice_writer.cc b/tensorflow/core/util/tensor_slice_writer.cc
index 75197a5..0880b65 100644
--- a/tensorflow/core/util/tensor_slice_writer.cc
+++ b/tensorflow/core/util/tensor_slice_writer.cc
@@ -85,8 +85,17 @@
                                      CreateBuilderFunction create_builder)
     : filename_(filename),
       create_builder_(std::move(create_builder)),
-      tmpname_(strings::StrCat(filename, ".tempstate", random::New64())),
       slices_(0) {
+  Env* env = Env::Default();
+  Status status = env->CanCreateTempFile(filename_, &use_temp_file_);
+  if (!status.ok()) {
+    LOG(ERROR) << "Failed to get CanCreateTempFile attribute: " << filename_;
+    use_temp_file_ = true;
+  }
+  data_filename_ = filename_;
+  if (use_temp_file_) {
+    data_filename_ = strings::StrCat(filename_, ".tempstate", random::New64());
+  }
   VersionDef* versions = sts_.mutable_meta()->mutable_versions();
   versions->set_producer(TF_CHECKPOINT_VERSION);
   versions->set_min_consumer(TF_CHECKPOINT_VERSION_MIN_CONSUMER);
@@ -94,7 +103,7 @@
 
 Status TensorSliceWriter::Finish() {
   Builder* b;
-  Status s = create_builder_(tmpname_, &b);
+  Status s = create_builder_(data_filename_, &b);
   if (!s.ok()) {
     delete b;
     return s;
@@ -113,18 +122,21 @@
 
   int64_t file_size;
   s = builder->Finish(&file_size);
-  // We need to rename the file to the proper name
-  if (s.ok()) {
-    s = Env::Default()->RenameFile(tmpname_, filename_);
+  // If use temp file, we need to rename the file to the proper name.
+  if (use_temp_file_) {
     if (s.ok()) {
-      VLOG(1) << "Written " << slices_ << " slices for "
-              << sts_.meta().tensor_size() << " tensors (" << file_size
-              << " bytes) to " << filename_;
+      s = Env::Default()->RenameFile(data_filename_, filename_);
+      if (s.ok()) {
+        VLOG(1) << "Written " << slices_ << " slices for "
+                << sts_.meta().tensor_size() << " tensors (" << file_size
+                << " bytes) to " << filename_;
+      } else {
+        LOG(ERROR) << "Failed to rename file " << data_filename_ << " to "
+                   << filename_;
+      }
     } else {
-      LOG(ERROR) << "Failed to rename file " << tmpname_ << " to " << filename_;
+      Env::Default()->DeleteFile(data_filename_).IgnoreError();
     }
-  } else {
-    Env::Default()->DeleteFile(tmpname_).IgnoreError();
   }
   return s;
 }
diff --git a/tensorflow/core/util/tensor_slice_writer.h b/tensorflow/core/util/tensor_slice_writer.h
index 34aa7c0..887db9d 100644
--- a/tensorflow/core/util/tensor_slice_writer.h
+++ b/tensorflow/core/util/tensor_slice_writer.h
@@ -83,7 +83,8 @@
 
   const string filename_;
   const CreateBuilderFunction create_builder_;
-  const string tmpname_;
+  string data_filename_;
+  bool use_temp_file_;
 
   // A mapping from the tensor names to their index in meta_.saved_slice_meta()
   std::unordered_map<string, int> name_to_index_;