Optimize alignBatchDimsAtFront (#41941)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/41941
If we know that the tensor already has the desired aligned size, we
don't need to put in the effort to align it.
Test Plan: - `./build/bin/vmap_test`, `pytest test/test_vmap.py -v`
Reviewed By: albanD
Differential Revision: D22764101
Pulled By: zou3519
fbshipit-source-id: a2ab7ce7b98d405ae905f7fd98db097210bfad65
diff --git a/aten/src/ATen/VmapTransforms.cpp b/aten/src/ATen/VmapTransforms.cpp
index c8bcae5..f9f3517 100644
--- a/aten/src/ATen/VmapTransforms.cpp
+++ b/aten/src/ATen/VmapTransforms.cpp
@@ -157,7 +157,13 @@
auto tensor_example_dim = physical_sizes.size() - /*num_batch_dims*/tensor_levels.count();
TORCH_INTERNAL_ASSERT(tensor_example_dim <= requested_example_dim);
- std::vector<int64_t> aligned_sizes(requested_levels.count() + requested_example_dim, 1);
+ if (tensor_levels == requested_levels && tensor_example_dim == requested_example_dim) {
+ // Optimization: no need to do another view if the physical tensor is
+ // already the correct shape
+ return physical_tensor;
+ }
+
+ VmapDimVector aligned_sizes(requested_levels.count() + requested_example_dim, 1);
// align the example dims (non-bdims dims) first
// aligned_sizes[-tensor_example_dim:] = tensor_sizes[-tensor_example_dim:]