[Profiler] Store Input shapes, dtypes, and metadata into flat AppendOnlyList (#74241)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74241

Adds the following changes:
- During collection, replaces the per-op vector-of-vector-of-int shapes and vector-of-string dtypes; instead, packs the IValue details into an InputOutputEncoder as flat AppendOnlyLists.
- Saves each IValue as an enum tag, metadata holding its dim and dtype, and its sizes (see the sketch below).
- During post-processing, reconstructs the vectors that are originally expected (struct Inputs).
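
A minimal, self-contained sketch of the encode/decode round trip (illustrative
only: std::vector stands in for AppendOnlyList, and Meta is a cut-down
stand-in for TensorMetadata; neither is the code in this diff):

  #include <cstdint>
  #include <iostream>
  #include <vector>

  // Stand-ins: std::vector models AppendOnlyList; Meta models TensorMetadata
  // (the real struct also stores ptr_ and dtype_).
  enum class Tag { Tensor, Scalar, TERMINATOR };
  struct Meta { uint32_t dim; };

  int main() {
    std::vector<Tag> tags;
    std::vector<Meta> metadata;
    std::vector<int64_t> sizes;

    // Encode one op's inputs: a 2x3 tensor followed by a scalar.
    tags.push_back(Tag::Tensor);
    metadata.push_back({2});
    sizes.insert(sizes.end(), {2, 3});
    tags.push_back(Tag::Scalar);
    tags.push_back(Tag::TERMINATOR); // closes this op's entry

    // Decode: rebuild the per-input shapes, as post-processing does.
    auto meta_it = metadata.begin();
    auto size_it = sizes.begin();
    std::vector<std::vector<int64_t>> shapes;
    for (const auto tag : tags) {
      if (tag == Tag::TERMINATOR) {
        break;
      }
      shapes.emplace_back(); // Scalars keep an empty shape entry
      if (tag == Tag::Tensor) {
        for (uint32_t i = 0; i < meta_it->dim; ++i) {
          shapes.back().push_back(*size_it++);
        }
        ++meta_it;
      }
    }
    std::cout << shapes[0][0] << "x" << shapes[0][1] << "\n"; // prints "2x3"
    return 0;
  }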

Reviewed By: chaekit

Differential Revision: D34823546

Pulled By: aaronenyeshi

fbshipit-source-id: 718fccaa8aab16128da986d665564a8fef5436c8
(cherry picked from commit 96a47c068e55220e7b7224c8a1935033859b5cd2)
diff --git a/torch/csrc/profiler/collection.cpp b/torch/csrc/profiler/collection.cpp
index bdadabd..c9c99e4 100644
--- a/torch/csrc/profiler/collection.cpp
+++ b/torch/csrc/profiler/collection.cpp
@@ -3,12 +3,112 @@
 #include <algorithm>
+#include <limits>
 
 #include <ATen/record_function.h>
+#include <c10/core/ScalarTypeToTypeMeta.h>
 #include <c10/util/overloaded.h>
 #include <torch/csrc/jit/runtime/interpreter.h>
 
 namespace torch {
 namespace profiler {
 namespace impl {
+
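+// Tag each top-level IValue of one op's inputs; Tensors additionally record
+// metadata and sizes (see push(const at::Tensor&) below). A trailing
+// TERMINATOR closes the op's entry.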
+void InputOutputEncoder::push(const std::vector<c10::IValue>& values) {
+  for (const auto& value : values) {
+    if (value.isTensor()) {
+      push(value.toTensor());
+    } else if (value.isScalar()) {
+      tags_.emplace_back(Tag::Scalar);
+    } else if (value.isTensorList()) {
+      tags_.emplace_back(Tag::TensorListBegin);
+      // TODO: Skip TensorList for now.
+      tags_.emplace_back(Tag::TERMINATOR);
+    } else {
+      tags_.emplace_back(Tag::Other);
+    }
+  }
+  tags_.emplace_back(Tag::TERMINATOR);
+}
+
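+// Append one defined Tensor's tag, metadata (TensorImpl pointer, dtype, dim),
+// and sizes to the flat lists; an undefined Tensor gets only a tag.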
+void InputOutputEncoder::push(const at::Tensor& t) {
+  if (t.defined()) {
+    tags_.emplace_back(Tag::Tensor);
+    const auto& sizes = t.sizes();
+    const auto dim = sizes.size();
+    TORCH_CHECK(
+      dim <= std::numeric_limits<uint32_t>::max(),
+      "Cannot profile Tensors of size > uint32 max. Got dim: ", dim);
+
+    tensor_metadata_.emplace_back(
+      /*ptr_=*/(void*)t.unsafeGetTensorImpl(),
+      /*dtype_=*/t.scalar_type(),
+      /*dim_=*/(uint32_t)dim
+    );
+
+    for (const auto i : sizes) {
+      tensor_sizes_.emplace_back(i);
+    }
+  } else {
+    tags_.emplace_back(Tag::UndefinedTensor);
+  }
+}
+
+// This is a custom-iterator-like getter to obtain input shapes and dtypes.
+auto InputOutputEncoder::getNextShapesAndDtypes() {
+  return [this, tag_it = tags_.begin(),
+          tensor_metadata_it = tensor_metadata_.begin(),
+          tensor_size_it = tensor_sizes_.begin()]() mutable {
+    struct Inputs out;
+    bool terminate = false;
+    while (!terminate && tag_it != tags_.end()) {
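+      // Speculatively start a shape entry for this tag; the TERMINATOR case
+      // pops the extra entry, since TERMINATOR does not describe an input.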
+      out.shapes_.emplace_back();
+      switch (*tag_it) {
+        case Tag::Tensor:
+          {
+            const auto& md = *tensor_metadata_it++;
+            for (const auto _ : c10::irange(md.dim_)) {
+              (void)_; // Suppress unused variable warning
+              out.shapes_.back().push_back(*tensor_size_it++);
+            }
+            out.dtypes_.emplace_back(scalarTypeToTypeMeta(md.dtype_).name());
+          }
+          break;
+
+        case Tag::TensorListBegin:
+          while (*(++tag_it) != Tag::TERMINATOR) {
+            // TODO: Skip TensorLists for now.
+          }
+          out.dtypes_.emplace_back("TensorList");
+          break;
+
+        case Tag::Scalar:
+          out.dtypes_.emplace_back("Scalar");
+          break;
+
+        case Tag::UndefinedTensor:
+        case Tag::Other:
+          out.dtypes_.emplace_back();
+          break;
+
+        case Tag::TERMINATOR:
+          // This marks the end of this op.
+          out.shapes_.pop_back();
+          terminate = true;
+          break;
+
+        default:
+          break;
+      }
+      ++tag_it;
+    }
+    return out;
+  };
+}
+
+void InputOutputEncoder::clear() {
+  tags_.clear();
+  tensor_metadata_.clear();
+  tensor_sizes_.clear();
+}
+
 namespace {
 // See `RecordQueue::getSubqueue()` for an overview of this cache.
 struct SubQueueThreadCache {
@@ -56,9 +156,7 @@
       fn.debugHandle(),
       fn.name());
   if (config_.report_input_shapes) {
-    inputs_.emplace_back(
-        torch::profiler::impl::inputSizes(fn),
-        torch::profiler::impl::inputTypes(fn));
+    inputs_outputs_.push(fn.inputs());
   }
 
 #if !defined BUILD_LITE_INTERPRETER && !defined C10_MOBILE
@@ -152,7 +250,7 @@
       out.push_back(std::move(r));
     }
 
-    auto input_it = queue.inputs_.begin();
+    auto input_getter = queue.inputs_outputs_.getNextShapesAndDtypes();
     auto jit_stack_it = queue.jit_stack_.begin();
     auto jit_module_it = queue.jit_modules_.begin();
     auto extra_args_it = queue.extra_args_.begin();
@@ -164,7 +262,7 @@
       r.start_tid_ = queue.tid();
       r.kineto_info_ = queue.kineto_info();
       r.event_ = std::move(i);
-      r.inputs_ = steal_or_default(input_it);
+      r.inputs_ = input_getter();
       r.jit_stack_ = steal_or_default(jit_stack_it);
       r.jit_modules_ = steal_or_default(jit_module_it);
       r.extra_args_ = steal_or_default(extra_args_it);
@@ -173,7 +271,7 @@
       out.push_back(std::move(r));
     }
     queue.op_events_.clear();
-    queue.inputs_.clear();
+    queue.inputs_outputs_.clear();
     queue.jit_stack_.clear();
     queue.jit_modules_.clear();
     queue.extra_args_.clear();
diff --git a/torch/csrc/profiler/collection.h b/torch/csrc/profiler/collection.h
index 8402512..a82d947 100644
--- a/torch/csrc/profiler/collection.h
+++ b/torch/csrc/profiler/collection.h
@@ -4,6 +4,7 @@
 #include <mutex>
 #include <utility>
 
+#include <ATen/Context.h>
 #include <c10/macros/Macros.h>
 #include <c10/util/flat_hash_map.h>
 #include <c10/util/variant.h>
@@ -95,6 +96,45 @@
   FallbackPair* fallback_ {nullptr};
 };
 
+constexpr int IO_ENCODER_DEFAULT_BLOCK_SIZE = 1024;
+
+// InputOutputEncoder
+// Stores each op event's input shapes and dtypes in contiguous AppendOnlyLists
+// so that we no longer create vectors for shapes and dtypes on every op.
+// Those vectors can be created during post-processing.
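+// Layout: one Tag per IValue plus a trailing TERMINATOR per op; each Tensor
+// tag owns one TensorMetadata entry and dim_ consecutive tensor_sizes_ entries.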
+class InputOutputEncoder final {
+ public:
+  void push(const std::vector<c10::IValue>& values);
+
+  // Used during post-processing to create vectors for shapes and dtypes.
+  auto getNextShapesAndDtypes();
+
+  void clear();
+
+ private:
+  enum class Tag {
+    Tensor = 0,
+    UndefinedTensor,
+    TensorListBegin, // TODO: generalize to other lists.
+    Scalar,
+    Other,
+    TERMINATOR
+  };
+
+  struct TensorMetadata {
+    void* ptr_;
+    c10::ScalarType dtype_;
+    uint32_t dim_;
+  };
+
+  void push(const at::Tensor& t);
+
+  AppendOnlyList<Tag, IO_ENCODER_DEFAULT_BLOCK_SIZE> tags_;
+  AppendOnlyList<TensorMetadata, IO_ENCODER_DEFAULT_BLOCK_SIZE> tensor_metadata_;
+  AppendOnlyList<int64_t, IO_ENCODER_DEFAULT_BLOCK_SIZE> tensor_sizes_;
+};
+
 class TORCH_API ThreadLocalSubqueue {
  public:
   ThreadLocalSubqueue(const uint64_t tid, const ProfilerConfig& config);
@@ -125,7 +165,7 @@
   AppendOnlyList<OpEvent, BlockSize> op_events_;
 
   // report_input_shapes
-  AppendOnlyList<Inputs, BlockSize> inputs_;
+  InputOutputEncoder inputs_outputs_;
 
   // with_stack
   AppendOnlyList<std::vector<std::string>, BlockSize> jit_stack_;