Add Handling of Cat in Shape Analysis (#65575)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65575
This is needed for lowering an NNC model to mobile. It is also the last class of unhandled ops which NNC fuses, and we need integration this for computing output symbolic shapes.
The graph of with two dynamic shape inputs produces:
```
graph(%x.1 : Tensor(SS(-2), 2, 3),
%y.1 : Tensor(SS(-3), 2, 3)):
%5 : int = prim::Constant[value=0]()
%4 : Tensor[] = prim::ListConstruct(%x.1, %y.1)
%6 : Tensor(SS(-4), 2, 3) = aten::cat(%4, %5) # /private/home/eellison/pytorch/test/jit/test_symbolic_shape_analysis.py:290:19
return (%6)
```
With a partial eval graph of
```
Done with partial evaluation
graph(%129 : int[],
%130 : int[],
%dim.14 : int):
%738 : int = prim::Constant[value=3]()
%737 : int = prim::Constant[value=2]()
%132 : int = prim::Constant[value=0]()
%392 : int = aten::__getitem__(%129, %132) # <string>:339:44
%417 : int = aten::__getitem__(%130, %132) # <string>:339:44
%cat_dim_size.48 : int = aten::add(%392, %417) # <string>:339:29
%result_size.5 : int[] = prim::ListConstruct(%cat_dim_size.48, %737, %738)
return (%result_size.5)
```
To handle cat, I essentially make the cat shape op variadic,
replacing
```
torch.cat([x, y]
...
def cat_shape_op(tensors: List[List[int]], dim: int):
...
op(tensors)
```
with
```
def cat_shape_op(x: List[int], y: List[int], dim: int):
tensors = [x, y]
op(tensors)
```
This reuses the existing input Tensor properties partial evaluation path and avoids having to add special handling to optimize out `len(tensors)` calls in the IR.
Test Plan: Imported from OSS
Reviewed By: navahgar
Differential Revision: D31732416
Pulled By: eellison
fbshipit-source-id: 6d93ddf62c34846ec238159f75229632515530b7
diff --git a/test/jit/test_symbolic_shape_analysis.py b/test/jit/test_symbolic_shape_analysis.py
index 82216e9..2fb4edb 100644
--- a/test/jit/test_symbolic_shape_analysis.py
+++ b/test/jit/test_symbolic_shape_analysis.py
@@ -5,8 +5,10 @@
from torch.testing import FileCheck
from torch.testing._internal.common_utils import make_tensor
+from torch.testing._internal.common_methods_invocations import sample_inputs_cat_concat
from torch import nn
+
from textwrap import dedent
if __name__ == '__main__':
@@ -257,6 +259,31 @@
self.checkShapeAnalysis(out_size, fn.graph, assert_propagation=True, constant_prop=False)
+ def test_shape_concat(self):
+ # TODO: unify with opinfo tests, traces of lists dont preserve sizes in IR
+ sample_inputs = sample_inputs_cat_concat(None, "cpu", torch.float, False)
+
+ class CatMod(nn.Module):
+ __constants__ = ['dim']
+
+ def __init__(self, dim=0):
+ super(CatMod, self).__init__()
+ self.dim = dim
+
+ def forward(self, x, y):
+ return torch.cat([x, y], dim=self.dim)
+
+ for inp in sample_inputs:
+ mod = torch.jit.script(CatMod(**inp.kwargs).eval())
+
+ args = inp.input
+ self.assertTrue(len(args) == 2)
+ out_size = mod(*args).size()
+ inps = list(mod.graph.inputs())
+ inps[1].setType(inps[1].type().with_sizes(args[0].size()))
+ inps[2].setType(inps[2].type().with_sizes(args[1].size()))
+ self.checkShapeAnalysis(out_size, mod.graph, assert_propagation=True)
+
def test_partial_eval_graph_conv(self):
mm = torch.jit.freeze(torch.jit.script(nn.Conv2d(16, 33, 3, stride=2).eval()))
shape_compute_graph = torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mm.graph)
diff --git a/torch/csrc/jit/passes/symbolic_shape_analysis.cpp b/torch/csrc/jit/passes/symbolic_shape_analysis.cpp
index e00c185..b9ce451 100644
--- a/torch/csrc/jit/passes/symbolic_shape_analysis.cpp
+++ b/torch/csrc/jit/passes/symbolic_shape_analysis.cpp
@@ -23,6 +23,7 @@
#include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
#include <torch/csrc/utils/memory.h>
#include <memory>
+#include <numeric>
#include <unordered_map>
#include <vector>
@@ -136,6 +137,15 @@
namespace {
+IValue tensor_sizes_from_tensor_list(const IValue& iv) {
+ c10::List<c10::List<int64_t>> tensor_sizes;
+ auto tensor_list = iv.toTensorVector();
+ for (const auto& ten : tensor_list) {
+ tensor_sizes.push_back(c10::List<int64_t>(ten.sizes()));
+ }
+ return tensor_sizes;
+}
+
bool isListOfInts(const TypePtr& type) {
return type->cast<ListType>() &&
type->cast<ListType>()->getElementType()->cast<IntType>();
@@ -198,69 +208,90 @@
// NB: shape compute graphs may have less inputs than their node
// counterparts to allow e.g. sharing one single unary definition
// so iterate on # of shape inputs
- for (size_t i = 0; i < shape_compute_graph_->inputs().size(); i++) {
- auto type = node_->input(i)->type();
+ size_t shape_graph_initial_inputs = shape_compute_graph_->inputs().size();
+ // We make lists of Tensor inputs variadic, which results in
+ // offset between a node index and its corresponding graph index
+ size_t graph_index_offset = 0;
+ for (size_t node_index = 0; node_index < shape_graph_initial_inputs;
+ node_index++) {
+ auto type = node_->input(node_index)->type();
+ size_t graph_index = graph_index_offset + node_index;
if (auto opt_type = shape_compute_graph_->inputs()
- .at(i)
+ .at(graph_index)
->type()
->cast<OptionalType>()) {
// None will get handled with constant substitution later
if (!type->cast<OptionalType>() &&
!NoneType::get()->isSubtypeOf(*type)) {
- shape_compute_graph_->inputs().at(i)->setType(
- opt_type->getElementType());
+ shape_compute_graph_->inputs()
+ .at(graph_index)
+ ->setType(opt_type->getElementType());
}
} else if (shape_compute_graph_->inputs()
- .at(i)
+ .at(graph_index)
->type()
->cast<NumberType>()) {
- shape_compute_graph_->inputs().at(i)->setType(type);
+ shape_compute_graph_->inputs().at(graph_index)->setType(type);
}
if (auto tt = type->castRaw<TensorType>()) {
- // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
- c10::SymbolicShape symbolic_shapes = tt->symbolic_sizes();
-
- // for testing, we don't insert complete tensor shapes and rely on our
- // partial evaluation pipeline to propagate information.
- // this is a good proxy for our ability to propagate non-complete shape
- // information.
-
- if (symbolic_shapes.isComplete() &&
- !symbolic_shape_analysis_test_mode) {
- replaceWithIValue(
- shape_compute_graph_->inputs().at(i),
- *tt->sizes().concrete_sizes());
- continue;
- }
- // TODO: remove, all constant tensors should have typed sizes
- if (toIValue(node_->input(i))) {
- auto size = constant_as<at::Tensor>(node_->input(i))->sizes();
- if (!symbolic_shape_analysis_test_mode) {
- replaceWithIValue(shape_compute_graph_->inputs().at(i), size);
- } else {
- node_symbolic_input_indices_.emplace_back(
- i, c10::SymbolicShape(size));
- }
- continue;
- }
-
- // we can't optimize a tensor without fixed rank
- if (symbolic_shapes.rank()) {
- node_symbolic_input_indices_.emplace_back(i, symbolic_shapes);
- }
+ addTensorInputMetaData(node_->input(node_index), graph_index);
} else if (
type->cast<ListType>() &&
type->cast<ListType>()->getElementType()->cast<TensorType>()) {
- TORCH_INTERNAL_ASSERT(false); // not handled yet
- } else if (auto ival = toIValue(node_->input(i))) {
- replaceWithIValue(shape_compute_graph_->inputs().at(i), *ival);
+ // When we have partially evaluate a list of Tensors like cat(tensor[])
+ // We have a few problems:
+ // - optimizing out calls to the length of the list: len(tensors)
+ // - resolving accesses of the list to the tensor symbolic sizes the
+ // corresponding list element We can solve both of these problems by
+ // replacing the partial evaluation of cat([x, y]) def cat(tensors:
+ // List[List[int]], dim: int)
+ // body
+ // with
+ // def cat(x, y, dim: int)
+ // tensors = [x, y]
+ // body
+ // This reuses the existing input Tensors partial evaluation and allows
+ // our existing optimizations to optimize out len(tensors) instead of
+ // requiring extra partial evaluation within this pass
+ if (node_->input(node_index)->node()->kind() == prim::Constant) {
+ replaceWithIValue(
+ shape_compute_graph_->inputs().at(graph_index),
+ tensor_sizes_from_tensor_list(
+ *toIValue(node_->input(node_index))));
+ } else if (
+ node_->input(node_index)->node()->kind() == prim::ListConstruct &&
+ !db.hasWriters(node_->input(node_index))) {
+ auto li_construct_node = node_->input(node_index)->node();
+ std::vector<Value*> li_inputs;
+ Value* graph_input = shape_compute_graph_->inputs().at(graph_index);
+ for (size_t j = 0; j < li_construct_node->inputs().size(); ++j) {
+ auto new_inp = shape_compute_graph_->insertInput(graph_index + j);
+ new_inp->setType(ListType::ofInts());
+ li_inputs.push_back(new_inp);
+ }
+ WithInsertPoint guard(
+ *shape_compute_graph_->block()->nodes().begin());
+ auto new_li = shape_compute_graph_->insertNode(
+ shape_compute_graph_->createList(ListType::ofInts(), li_inputs));
+ graph_input->replaceAllUsesWith(new_li->output());
+ for (size_t j = 0; j < li_construct_node->inputs().size(); ++j) {
+ addTensorInputMetaData(
+ li_construct_node->input(j), graph_index + j);
+ }
+ shape_compute_graph_->eraseInput(
+ node_index + li_construct_node->inputs().size());
+ graph_index_offset += li_construct_node->inputs().size() - 1;
+ }
+ } else if (auto ival = toIValue(node_->input(node_index))) {
+ replaceWithIValue(
+ shape_compute_graph_->inputs().at(graph_index), *ival);
} else if (
type->cast<ListType>() &&
type->cast<ListType>()->getElementType()->cast<IntType>()) {
- if (node_->input(i)->node()->kind() == prim::ListConstruct &&
- !db.hasWriters(node_->input(i))) {
+ if (node_->input(node_index)->node()->kind() == prim::ListConstruct &&
+ !db.hasWriters(node_->input(node_index))) {
// it is a very common in graphs to see patterns like:
// z = x.view(y.size())
// or:
@@ -269,7 +300,7 @@
// from y to z. To do this we try to associate symbolic dimensions
// or concrete sizes with the integer list inputs that have a
// constructor taken from constants or y.size() or y.size(0)
- auto list_construct = node_->input(i)->node();
+ auto list_construct = node_->input(node_index)->node();
std::vector<ShapeArg> shape;
for (Value* v : list_construct->inputs()) {
if (auto constant = constant_as<int64_t>(v)) {
@@ -295,18 +326,57 @@
shape.emplace_back(ShapeArg::unknownInteger());
}
}
- node_symbolic_input_indices_.emplace_back(i, std::move(shape));
+ node_symbolic_input_indices_.emplace_back(
+ graph_index, std::move(shape));
} else if (
- node_->input(i)->node()->kind() == aten::size &&
- !db.hasWriters(node_->input(i))) {
- auto ten_inp = node_->input(i)->node()->input();
+ node_->input(node_index)->node()->kind() == aten::size &&
+ !db.hasWriters(node_->input(node_index))) {
+ auto ten_inp = node_->input(node_index)->node()->input();
auto ss = ten_inp->type()->expect<TensorType>()->symbolic_sizes();
- node_symbolic_input_indices_.emplace_back(i, ss);
+ node_symbolic_input_indices_.emplace_back(graph_index, ss);
}
}
}
}
+ void addTensorInputMetaData(
+ Value* tensor_v,
+ size_t shape_compute_graph_index) {
+ auto tt = tensor_v->type()->expect<TensorType>();
+ // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
+ c10::SymbolicShape symbolic_shapes = tt->symbolic_sizes();
+
+ // for testing, we don't insert complete tensor shapes and rely on our
+ // partial evaluation pipeline to propagate information.
+ // this is a good proxy for our ability to propagate non-complete shape
+ // information.
+
+ if (symbolic_shapes.isComplete() && !symbolic_shape_analysis_test_mode) {
+ replaceWithIValue(
+ shape_compute_graph_->inputs().at(shape_compute_graph_index),
+ *tt->sizes().concrete_sizes());
+ return;
+ }
+ // TODO: remove, all constant tensors should have typed sizes
+ if (toIValue(tensor_v)) {
+ auto size = constant_as<at::Tensor>(tensor_v)->sizes();
+ if (!symbolic_shape_analysis_test_mode) {
+ replaceWithIValue(
+ shape_compute_graph_->inputs().at(shape_compute_graph_index), size);
+ } else {
+ node_symbolic_input_indices_.emplace_back(
+ shape_compute_graph_index, c10::SymbolicShape(size));
+ }
+ return;
+ }
+
+ // we can't optimize a tensor without fixed rank
+ if (symbolic_shapes.rank()) {
+ node_symbolic_input_indices_.emplace_back(
+ shape_compute_graph_index, symbolic_shapes);
+ }
+ }
+
// returns partially evaluated shape compute graph
std::shared_ptr<Graph> run() {
bool made_change = true;
@@ -527,11 +597,7 @@
// of TensorTypes with a fixed dimension but not a complete shape,
// because a complete shape we can completely replace with a constant
// and non-fixed dimensions we cannot reason about at all
- // TODO: might be cleaner to store as a pair of index -> symbolic shape
- // but there were weird lifetime issues
std::vector<std::pair<int64_t, ShapeArguments>> node_symbolic_input_indices_;
- std::vector<std::pair<int64_t, c10::SymbolicShape>>
- node_symbolic_input_indices;
std::shared_ptr<Graph> shape_compute_graph_;
Node* node_;
};
diff --git a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp
index 911910e..3da0e4d 100644
--- a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp
+++ b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp
@@ -101,11 +101,14 @@
out[infer_dim] = numel // newsize
return out
- def view(self: List[int], sizes: List[int]):
+ def numel(sizes: List[int]):
numel = 1
- for elem in self:
+ for elem in sizes:
numel *= elem
- return infer_size_impl(sizes, numel)
+ return numel
+
+ def view(self: List[int], sizes: List[int]):
+ return infer_size_impl(sizes, numel(self))
def view_one_unused(self: List[int], sizes: List[int], *, implicit: bool=False):
return view(self, sizes)
@@ -309,6 +312,53 @@
out[dim] = (len + step - 1) // step
return out
+ def check_cat_no_zero_dim(tensors: List[List[int]]):
+ for tensor in tensors:
+ assert(len(tensor) > 0)
+
+ def legacy_cat_wrap_dim(dim: int, tensor_sizes: List[List[int]]):
+ out_dim : Optional[int] = None
+ for size in tensor_sizes:
+ if len(size) != 0 and size != [0] and out_dim is not None:
+ out_dim = maybe_wrap_dim(dim, len(size))
+ if out_dim is None:
+ out_dim = dim
+ return out_dim
+
+ def should_skip(tensor: List[int]):
+ return numel(tensor) == 0 and len(tensor) == 1
+
+ def check_cat_shape_except_dim(first: List[int], second: List[int], dimension: int, index: int):
+ first_dims = len(first)
+ second_dims = len(second)
+ assert first_dims == second_dims, "Tensors must have same number of dimensions"
+ for dim in range(0, first_dims):
+ if dim != dimension:
+ assert first[dim] == second[dim], "Sizes of tensors must match except in dimension"
+
+ def cat(tensors: List[List[int]], dim: int):
+ check_cat_no_zero_dim(tensors)
+ dim = legacy_cat_wrap_dim(dim, tensors)
+ assert len(tensors) > 0
+ not_skipped_tensor: Optional[List[int]] = None
+ for tensor in tensors:
+ if not should_skip(tensor):
+ not_skipped_tensor = tensor
+ if not_skipped_tensor is None:
+ return [0]
+
+ cat_dim_size = 0
+
+ for i in range(len(tensors)):
+ tensor = tensors[i]
+ if not should_skip(tensor):
+ check_cat_shape_except_dim(not_skipped_tensor, tensor, dim, i)
+ cat_dim_size = cat_dim_size + tensor[dim]
+
+ result_size = _copy(not_skipped_tensor)
+ result_size[dim] = cat_dim_size
+ return result_size
+
def select(self: List[int], dim: int, index: int):
ndim = len(self)
assert ndim != 0
@@ -639,6 +689,7 @@
{"aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor", "batch_norm"},
{"aten::conv3d(Tensor input, Tensor weight, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1, int groups=1) -> Tensor", "conv3d"},
{"aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)", "flatten"},
+ {"aten::cat(Tensor[] tensors, int dim=0) -> Tensor", "cat"},
{"aten::relu(Tensor self) -> Tensor", "unary"},
{"aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)", "permute"},
{"aten::view(Tensor(a) self, int[] size) -> Tensor(a)", "view"},