Unsqueeze ops to reduce the number of reshapes in we use in LTC (#72011)
Summary:
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72011
Reviewed By: gmagogsfm
Differential Revision: D33855760
Pulled By: Krovatkin
fbshipit-source-id: abe5572567c8f7746e7b06a552dfbe5566c3d3ce
(cherry picked from commit 8eac12685f17a145e1d5d78fcf0d65131248c5c3)
diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl
index 61545d5..88d6eeb 100644
--- a/tools/build_variables.bzl
+++ b/tools/build_variables.bzl
@@ -413,6 +413,7 @@
"torch/csrc/lazy/core/view_ops/resize.cpp",
"torch/csrc/lazy/core/view_ops/select.cpp",
"torch/csrc/lazy/core/view_ops/squeeze.cpp",
+ "torch/csrc/lazy/core/view_ops/unsqueeze.cpp",
"torch/csrc/lazy/core/view_ops/select_view_update.cpp",
"torch/csrc/lazy/core/view_ops/view.cpp",
"torch/csrc/lazy/ts_backend/config.cpp",
diff --git a/torch/csrc/lazy/core/view_ops/unsqueeze.cpp b/torch/csrc/lazy/core/view_ops/unsqueeze.cpp
new file mode 100644
index 0000000..4cea538
--- /dev/null
+++ b/torch/csrc/lazy/core/view_ops/unsqueeze.cpp
@@ -0,0 +1,38 @@
+#include <torch/csrc/lazy/core/view_ops/unsqueeze.h>
+#include <torch/csrc/lazy/ts_backend/ts_lowering_context.h>
+
+namespace torch {
+namespace lazy {
+
+std::vector<int64_t> BuildUnsqueezedDimensions(
+ c10::ArrayRef<int64_t> dimensions,
+ int64_t squeeze_dim) {
+ std::vector<int64_t> output_dimensions(
+ dimensions.cbegin(), dimensions.cend());
+ output_dimensions.insert(output_dimensions.begin() + squeeze_dim, 1);
+ return output_dimensions;
+}
+
+Unsqueeze::Unsqueeze(const torch::lazy::Value& input, int dim)
+ : torch::lazy::TsNode(
+ torch::lazy::OpKind(at::aten::unsqueeze),
+ {input},
+ /*num_outputs=*/1,
+ torch::lazy::MHash(dim)),
+ dim_(dim) {
+ SetShapeDeferred([&]() {
+ const auto& input_shape = GetShapeFromTsValue(input);
+ return torch::lazy::Shape(
+ input_shape.scalar_type(),
+ BuildUnsqueezedDimensions(input_shape.sizes(), dim));
+ });
+}
+
+std::string Unsqueeze::ToString() const {
+ std::stringstream ss;
+ ss << torch::lazy::TsNode::ToString() << ", dim=" << dim_;
+ return ss.str();
+}
+
+} // namespace lazy
+} // namespace torch
diff --git a/torch/csrc/lazy/core/view_ops/unsqueeze.h b/torch/csrc/lazy/core/view_ops/unsqueeze.h
new file mode 100644
index 0000000..113e2da
--- /dev/null
+++ b/torch/csrc/lazy/core/view_ops/unsqueeze.h
@@ -0,0 +1,27 @@
+#pragma once
+
+#include <torch/csrc/lazy/ts_backend/ts_node.h>
+
+namespace torch {
+namespace lazy {
+
+TORCH_API std::vector<int64_t> BuildUnsqueezedDimensions(
+ c10::ArrayRef<int64_t> dimensions,
+ int64_t squeeze_dim);
+
+class TORCH_API Unsqueeze : public TsNode {
+ public:
+ Unsqueeze(const torch::lazy::Value& input, int dim);
+
+ std::string ToString() const override;
+
+ int dim() const {
+ return dim_;
+ }
+
+ private:
+ int dim_;
+};
+
+} // namespace lazy
+} // namespace torch