[xla] Add better support for variables (num_variables, index lookup by name) in XLA AOT/JIT.

PiperOrigin-RevId: 300243405
Change-Id: Iab455be5b0d3ec594b8482de2e61f5049bc4cb14
diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc
index 53150e9..4a4fec5 100644
--- a/tensorflow/compiler/aot/codegen.cc
+++ b/tensorflow/compiler/aot/codegen.cc
@@ -26,6 +26,7 @@
 #include "absl/strings/str_split.h"
 #include "absl/types/span.h"
 #include "tensorflow/compiler/aot/embedded_protocol_buffers.h"
+#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
 #include "tensorflow/compiler/xla/cpu_function_runtime.h"
 #include "tensorflow/compiler/xla/service/compiler.h"
@@ -288,8 +289,8 @@
 }
 
 // Generates code implementing {Arg,Result}Names(), where T is one of
-// tf2xla::{Feed,Fetch}. Each feed or fetch name results in a C-style string
-// literal in the array, with nullptr terminating the array.
+// tf2xla::{Feed,Fetch,Variable}. Each feed, fetch, or variable name results in
+// a C-style string literal in the array, with nullptr terminating the array.
 template <typename T>
 string GenNameToIndexCode(const T& entries, bool generate) {
   // No need for a static array if we're not supposed to generate the data.
@@ -419,6 +420,16 @@
   // Generate metadata.
   const string arg_names_code =
       GenNameToIndexCode(config.feed(), opts.gen_name_to_index);
+
+  auto variable_copy = config.variable();
+  for (auto& var : variable_copy) {
+    if (var.name().empty()) {
+      var.set_name(var.node_name());
+    }
+  }
+  const string variable_names_code =
+      GenNameToIndexCode(variable_copy, opts.gen_name_to_index);
+
   const string result_names_code =
       GenNameToIndexCode(config.fetch(), opts.gen_name_to_index);
   const string include_xla_data_proto =
@@ -507,6 +518,9 @@
   // Number of input arguments for the compiled computation.
   static constexpr size_t kNumArgs = {{ARG_NUM}};
 
+  // Number of variables for the compiled computation.
+  static constexpr size_t kNumVariables = {{VARIABLE_NUM}};
+
   // Byte size of each argument buffer. There are kNumArgs entries.
   static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) {
     return BufferInfos()[ArgIndexToBufferIndex()[index]].size();
@@ -522,8 +536,10 @@
       set_static_data_num_buffers(data, kNumBuffers);
       set_static_data_arg_index_table(data, ArgIndexToBufferIndex());
       set_static_data_num_args(data, kNumArgs);
+      set_static_data_num_variables(data, kNumVariables);
       set_static_data_result_index(data, kResultIndex);
       set_static_data_arg_names(data, StaticArgNames());
+      set_static_data_variable_names(data, StaticVariableNames());
       set_static_data_result_names(data, StaticResultNames());
       set_static_data_program_shape(data, StaticProgramShape());
       set_static_data_hlo_profile_printer_data(
@@ -626,6 +642,9 @@
   // Array of names of each positional argument, terminated by nullptr.
   static const char** StaticArgNames() {{ARG_NAMES_CODE}}
 
+  // Array of names of each positional variable, terminated by nullptr.
+  static const char** StaticVariableNames() {{VARIABLE_NAMES_CODE}}
+
   // Array of names of each positional result, terminated by nullptr.
   static const char** StaticResultNames() {{RESULT_NAMES_CODE}}
 
@@ -654,6 +673,7 @@
       {"{{ARG_BYTES_TOTAL}}", absl::StrCat(arg_bytes_total)},
       {"{{ARG_NAMES_CODE}}", arg_names_code},
       {"{{ARG_NUM}}", absl::StrCat(arg_index_table.size())},
+      {"{{VARIABLE_NUM}}", absl::StrCat(config.variable_size())},
       {"{{ARG_INDEX_TABLE}}", absl::StrJoin(arg_index_table, ", ")},
       {"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size},
       {"{{CLASS}}", opts.class_name},
@@ -673,6 +693,7 @@
       {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(xla::ProgramShape(ps))},
       {"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}",
        metadata_result.program_shape_access_shim},
+      {"{{VARIABLE_NAMES_CODE}}", variable_names_code},
       {"{{RESULT_INDEX}}", absl::StrCat(result_index)},
       {"{{RESULT_NAMES_CODE}}", result_names_code},
       {"{{TEMP_BYTES_ALIGNED}}", absl::StrCat(temp_bytes_aligned)},
diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc
index 6206f68..babbd7f 100644
--- a/tensorflow/compiler/aot/codegen_test.cc
+++ b/tensorflow/compiler/aot/codegen_test.cc
@@ -156,17 +156,14 @@
   // bazel test --test_strategy=local \
   //   third_party/tensorflow/compiler/aot:codegen_test
   const bool update_golden = false;
-  string golden_file_name;
+  string golden_file_name =
+      GetDataDependencyFilepath(tensorflow_relative_golden_file_name);
 
   if (update_golden) {
-    golden_file_name = io::JoinPath(testing::TensorFlowSrcRoot(),
-                                    tensorflow_relative_golden_file_name);
     TF_EXPECT_OK(
         WriteStringToFile(Env::Default(), golden_file_name, expected_contents));
   }
 
-  golden_file_name =
-      GetDataDependencyFilepath(tensorflow_relative_golden_file_name);
   string golden_file_contents;
   TF_ASSERT_OK(ReadFileToString(Env::Default(), golden_file_name,
                                 &golden_file_contents));
@@ -220,10 +217,16 @@
       {},
       {BufferInfo::MakeTempBuffer(1),
        BufferInfo::MakeEntryParameter(/*size=*/8, /*param_number=*/0),
-       BufferInfo::MakeTempBuffer(2),
+       BufferInfo::MakeTempBuffer(1),
        BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/1),
-       BufferInfo::MakeTempBuffer(3), BufferInfo::MakeTempBuffer(120)},
-      5, {}));
+       BufferInfo::MakeTempBuffer(1),
+       BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/2),
+       BufferInfo::MakeTempBuffer(1),
+       BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/3),
+       BufferInfo::MakeTempBuffer(1),
+       BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/4),
+       BufferInfo::MakeTempBuffer(1), BufferInfo::MakeTempBuffer(120)},
+      11, {}));
   compile_result.program_shape =
       xla::ShapeUtil::MakeProgramShape(
           {
diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden
index 1669e72..af58ca2 100644
--- a/tensorflow/compiler/aot/codegen_test_h.golden
+++ b/tensorflow/compiler/aot/codegen_test_h.golden
@@ -55,14 +55,17 @@
 //   ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): f32[1], (unknown): f32[1], (unknown): s32[5]) -> (u32[5,6], f32[1], s32[5])
 //
 // Memory stats:
-//   arg bytes total:    104
-//   arg bytes aligned:  192
+//   arg bytes total:    392
+//   arg bytes aligned:  576
 //   temp bytes total:   126
-//   temp bytes aligned: 320
+//   temp bytes aligned: 512
 class MyClass final : public tensorflow::XlaCompiledCpuFunction {
  public:
   // Number of input arguments for the compiled computation.
-  static constexpr size_t kNumArgs = 2;
+  static constexpr size_t kNumArgs = 5;
+
+  // Number of variables for the compiled computation.
+  static constexpr size_t kNumVariables = 3;
 
   // Byte size of each argument buffer. There are kNumArgs entries.
   static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) {
@@ -79,8 +82,10 @@
       set_static_data_num_buffers(data, kNumBuffers);
       set_static_data_arg_index_table(data, ArgIndexToBufferIndex());
       set_static_data_num_args(data, kNumArgs);
+      set_static_data_num_variables(data, kNumVariables);
       set_static_data_result_index(data, kResultIndex);
       set_static_data_arg_names(data, StaticArgNames());
+      set_static_data_variable_names(data, StaticVariableNames());
       set_static_data_result_names(data, StaticResultNames());
       set_static_data_program_shape(data, StaticProgramShape());
       set_static_data_hlo_profile_printer_data(
@@ -295,16 +300,22 @@
 
  private:
   // Number of buffers for the compiled computation.
-  static constexpr size_t kNumBuffers = 6;
+  static constexpr size_t kNumBuffers = 12;
 
   static const ::xla::cpu_function_runtime::BufferInfo* BufferInfos() {
     static const ::xla::cpu_function_runtime::BufferInfo
       kBufferInfos[kNumBuffers] = {
 ::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
 ::xla::cpu_function_runtime::BufferInfo({34ULL, 0ULL}),
-::xla::cpu_function_runtime::BufferInfo({9ULL, ~0ULL}),
+::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
 ::xla::cpu_function_runtime::BufferInfo({386ULL, 1ULL}),
-::xla::cpu_function_runtime::BufferInfo({13ULL, ~0ULL}),
+::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
+::xla::cpu_function_runtime::BufferInfo({386ULL, 2ULL}),
+::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
+::xla::cpu_function_runtime::BufferInfo({386ULL, 3ULL}),
+::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
+::xla::cpu_function_runtime::BufferInfo({386ULL, 4ULL}),
+::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
 ::xla::cpu_function_runtime::BufferInfo({481ULL, ~0ULL})
       };
     return kBufferInfos;
@@ -312,13 +323,13 @@
 
   static const ::tensorflow::int32* ArgIndexToBufferIndex() {
     static constexpr ::tensorflow::int32 kArgIndexToBufferIndex[kNumArgs] = {
-1, 3
+1, 3, 5, 7, 9
     };
     return kArgIndexToBufferIndex;
   }
 
   // The 0-based index of the result tuple in the temporary buffers.
-  static constexpr size_t kResultIndex = 5;
+  static constexpr size_t kResultIndex = 11;
 
   // Array of names of each positional argument, terminated by nullptr.
   static const char** StaticArgNames() {
@@ -326,6 +337,12 @@
     return kNames;
   }
 
+  // Array of names of each positional variable, terminated by nullptr.
+  static const char** StaticVariableNames() {
+    static const char* kNames[] = {"myvar_readonly", "myvar", "myvar2", nullptr};
+    return kNames;
+  }
+
   // Array of names of each positional result, terminated by nullptr.
   static const char** StaticResultNames() {
     static const char* kNames[] = {"myfetch", nullptr};
diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc
index 5420cf3..3870a67 100644
--- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc
@@ -28,7 +28,9 @@
       buffer_infos_(static_data.buffer_infos_),
       arg_index_table_(static_data.arg_index_table_),
       num_args_(static_data.num_args_),
+      num_variables_(static_data.num_variables_),
       arg_names_(static_data.arg_names_),
+      variable_names_(static_data.variable_names_),
       result_names_(static_data.result_names_),
       program_shape_(static_data.program_shape_),
       hlo_profile_printer_data_(static_data.hlo_profile_printer_data_) {
@@ -63,6 +65,8 @@
 
 namespace {
 
+constexpr int kNotFound = -1;
+
 // Linear search through `names` looking for a match with `name`. Returns -1 if
 // the name isn't found, or is empty.
 //
@@ -72,7 +76,6 @@
   // for AOT try the setting the tfcompile --gen_name_to_index flag.
   assert(names != nullptr);
 
-  constexpr int kNotFound = -1;
   if (name.empty()) {
     return kNotFound;
   }
@@ -90,6 +93,14 @@
   return LookupNameIndex(name, arg_names_);
 }
 
+int XlaCompiledCpuFunction::LookupVariableIndex(const string& name) const {
+  int index = LookupNameIndex(name, variable_names_);
+  if (index == kNotFound) {
+    return kNotFound;
+  }
+  return num_args_ - num_variables_ + index;  // Assumes variables come last.
+}
+
 int XlaCompiledCpuFunction::LookupResultIndex(const string& name) const {
   return LookupNameIndex(name, result_names_);
 }
diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h
index 5e452b5..04d9086 100644
--- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h
+++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h
@@ -76,12 +76,16 @@
     // There are num_args entry parameters.
     int64 num_args_ = 0;
 
+    // There are num_variables variables.
+    int64 num_variables_ = 0;
+
     // The 0-based index of the result tuple, in the temp buffers.
     size_t result_index_ = 0;
 
     // [Optional] Arrays of arg and result names. These are arrays of C-style
     // strings, where the array is terminated by nullptr.
     const char** arg_names_ = nullptr;
+    const char** variable_names_ = nullptr;
     const char** result_names_ = nullptr;
 
     // [Optional] Arg and result shapes.
@@ -150,6 +154,8 @@
 
   int num_args() const { return num_args_; }
 
+  int num_variables() const { return num_variables_; }
+
   // Returns the size of entry parameter `idx`.
   //
   // There is a static version of this method on tfcompile generated subclasses
@@ -212,10 +218,11 @@
   // ------------------------------
   // Methods for extracting optional metadata.
 
-  // Returns true iff data is available for the Lookup{Arg,Result}Index methods.
-  // E.g. the data might not be compiled into the binary for AOT.
+  // Returns true iff data is available for the Lookup{Arg,Variable,Result}Index
+  // methods. E.g. the data might not be compiled into the binary for AOT.
   bool HasNameIndices() const {
-    return arg_names_ != nullptr && result_names_ != nullptr;
+    return arg_names_ != nullptr && variable_names_ != nullptr &&
+           result_names_ != nullptr;
   }
 
   // Returns the 0-based index for the argument with the given `name`.
@@ -226,6 +233,14 @@
   // Recommended usage is to capture this in a variable for re-use.
   int LookupArgIndex(const string& name) const;
 
+  // Returns the 0-based argument index (num_args - num_variables + position)
+  // of the variable `name`; -1 if the name wasn't found or data is unavailable.
+  //
+  // The index remains constant for every instance of XlaCompiledCpuFunction
+  // generated from the same static data, and might not be cheap to determine.
+  // Recommended usage is to capture this in a variable for re-use.
+  int LookupVariableIndex(const string& name) const;
+
   // Returns the 0-based index for the result with the given `name`.
   // Returns -1 if the name wasn't found, or data isn't available.
   //
@@ -280,6 +295,11 @@
     static_data->num_args_ = num_args;
   }
 
+  static void set_static_data_num_variables(StaticData* static_data,
+                                            int64 num_variables) {
+    static_data->num_variables_ = num_variables;
+  }
+
   static void set_static_data_result_index(StaticData* static_data,
                                            size_t result_index) {
     static_data->result_index_ = result_index;
@@ -290,6 +310,11 @@
     static_data->arg_names_ = arg_names;
   }
 
+  static void set_static_data_variable_names(StaticData* static_data,
+                                             const char** variable_names) {
+    static_data->variable_names_ = variable_names;
+  }
+
   static void set_static_data_result_names(StaticData* static_data,
                                            const char** result_names) {
     static_data->result_names_ = result_names;
@@ -334,6 +359,9 @@
   // The number of incoming arguments.
   const int32 num_args_;
 
+  // The number of incoming variables.
+  const int32 num_variables_;
+
   // Backing memory for buffer_table_ and args_, the latter depending on
   // AllocMode.
   void* alloc_buffer_table_ = nullptr;
@@ -346,6 +374,7 @@
 
   // Optional metadata.
   const char** arg_names_ = nullptr;
+  const char** variable_names_ = nullptr;
   const char** result_names_ = nullptr;
   const xla::ProgramShapeProto* program_shape_ = nullptr;
   const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr;
diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc
index 0392cc7..0deaa1e 100644
--- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc
+++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc
@@ -49,9 +49,9 @@
   return result_slice.index();
 }
 
-// Collect names from `entries`, where T is one of tf2xla::{Feed,Fetch}. We hold
-// the actual strings in nonempty_names, and hold arrays of pointers in
-// name_ptrs, terminated by a nullptr entry.
+// Collect names from `entries`, where T is one of
+// tf2xla::{Feed,Fetch,Variable}. We hold the actual strings in nonempty_names,
+// and hold arrays of pointers in name_ptrs, terminated by a nullptr entry.
 template <typename T>
 void CollectNames(const T& entries, std::vector<string>* nonempty_names,
                   std::vector<const char*>* name_ptrs) {
@@ -154,14 +154,28 @@
       &jit->static_data_, jit->arg_index_table_.data());
   XlaCompiledCpuFunction::set_static_data_num_args(
       &jit->static_data_, jit->arg_index_table_.size());
+  XlaCompiledCpuFunction::set_static_data_num_variables(&jit->static_data_,
+                                                        config.variable_size());
   XlaCompiledCpuFunction::set_static_data_result_index(&jit->static_data_,
                                                        result_index);
   // Optional metadata is collected and set below.
   CollectNames(config.feed(), &jit->nonempty_arg_names_, &jit->arg_names_);
+
+  auto variable_copy = config.variable();
+  for (auto& var : variable_copy) {
+    if (var.name().empty()) {
+      var.set_name(var.node_name());
+    }
+  }
+  CollectNames(variable_copy, &jit->nonempty_variable_names_,
+               &jit->variable_names_);
+
   CollectNames(config.fetch(), &jit->nonempty_result_names_,
                &jit->result_names_);
   XlaCompiledCpuFunction::set_static_data_arg_names(&jit->static_data_,
                                                     jit->arg_names_.data());
+  XlaCompiledCpuFunction::set_static_data_variable_names(
+      &jit->static_data_, jit->variable_names_.data());
   XlaCompiledCpuFunction::set_static_data_result_names(
       &jit->static_data_, jit->result_names_.data());
   XlaCompiledCpuFunction::set_static_data_program_shape(
diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h
index 11fc457..107968b 100644
--- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h
+++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h
@@ -77,8 +77,10 @@
   // nonempty_*_names_, and hold arrays of pointers in *_names_ for the static
   // data to refer to.
   std::vector<string> nonempty_arg_names_;
+  std::vector<string> nonempty_variable_names_;
   std::vector<string> nonempty_result_names_;
   std::vector<const char*> arg_names_;
+  std::vector<const char*> variable_names_;
   std::vector<const char*> result_names_;
 
   // The backing data for the program shape. The proto form of program shape is
diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc
index f5d6b52..880cb59 100644
--- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc
@@ -210,6 +210,9 @@
   EXPECT_EQ(function.LookupResultIndex("x_name"), -1);
   EXPECT_EQ(function.LookupResultIndex("y_name"), -1);
 
+  EXPECT_EQ(0, function.num_variables());
+  EXPECT_EQ(function.LookupVariableIndex("x"), -1);
+
   // Check program shape.
   using xla::ShapeUtil;
   const xla::Shape s32 = ShapeUtil::MakeShape(xla::S32, {});
@@ -252,6 +255,14 @@
   EXPECT_EQ(*static_cast<int32*>(function.result_data(0)), 100);
   EXPECT_EQ(*static_cast<int32*>(function.result_data(1)), 420);
 
+  // Check name to index lookups.
+  EXPECT_TRUE(function.HasNameIndices());
+
+  EXPECT_EQ(2, function.num_args());
+
+  EXPECT_EQ(1, function.num_variables());
+  EXPECT_EQ(function.LookupVariableIndex("myvar"), 1);
+
   // Check program shape.
   using xla::ShapeUtil;
   const xla::Shape s32 = ShapeUtil::MakeShape(xla::S32, {});